~samwhited/xmpp

0e217cd1ebe1902b1853370d00822f282ae3064e — Sam Whited 1 year, 10 days ago 3ca2802
xmpp: make STARTTLS always required

TLS (or at the time, SSL) may have been an optional feature in the past,
but it's not anymore. These days it's far more likely that a server will
always want to require TLS in some form, so giving the user the ability
to turn it off just means we're giving users who won't understand the
consequences of their actions a knob to twiddle. In the very rare case
that a user actually *does* need STARTTLS to be an optional stream
feature, I don't think it's something we should support. For this rare
use case, they'll have to take the maintenance burden on themselves by
copy/pasting the StartTLS feature code and tweaking it for their needs.

Fixes #50

Signed-off-by: Sam Whited <sam@samwhited.com>
M CHANGELOG.md => CHANGELOG.md +5 -0
@@ 5,6 5,11 @@ All notable changes to this project will be documented in this file.

## Unreleased

### Breaking

- xmpp: remove option to make STARTTLS feature optional


### Added

- xmpp: `ConnectionState` method

M echobot_example_test.go => echobot_example_test.go +1 -1
@@ 35,7 35,7 @@ func Example_echobot() {
	s, err := xmpp.DialClientSession(
		context.TODO(), j,
		xmpp.BindResource(),
		xmpp.StartTLS(true, &tls.Config{
		xmpp.StartTLS(&tls.Config{
			ServerName: j.Domain().String(),
		}),
		xmpp.SASL("", pass, sasl.ScramSha1Plus, sasl.ScramSha1, sasl.Plain),

M examples/echobot/echo.go => examples/echobot/echo.go +1 -1
@@ 42,7 42,7 @@ func echo(ctx context.Context, addr, pass string, xmlIn, xmlOut io.Writer, logge
		Lang: "en",
		Features: []xmpp.StreamFeature{
			xmpp.BindResource(),
			xmpp.StartTLS(true, &tls.Config{
			xmpp.StartTLS(&tls.Config{
				ServerName: j.Domain().String(),
			}),
			xmpp.SASL("", pass, sasl.ScramSha1Plus, sasl.ScramSha1, sasl.Plain),

M examples/im/main.go => examples/im/main.go +1 -1
@@ 129,7 129,7 @@ func main() {
	session, err := xmpp.DialClientSession(
		dialCtx, parsedAddr,
		xmpp.BindResource(),
		xmpp.StartTLS(true, &tls.Config{
		xmpp.StartTLS(&tls.Config{
			ServerName: parsedAddr.Domain().String(),
		}),
		xmpp.SASL("", pass, sasl.ScramSha256Plus, sasl.ScramSha1Plus, sasl.ScramSha256, sasl.ScramSha1, sasl.Plain),

M examples/msgrepl/main.go => examples/msgrepl/main.go +1 -1
@@ 68,7 68,7 @@ func main() {
	session, err := xmpp.DialClientSession(
		dialCtx, parsedAddr,
		xmpp.BindResource(),
		xmpp.StartTLS(true, &tls.Config{
		xmpp.StartTLS(&tls.Config{
			ServerName: parsedAddr.Domain().String(),
		}),
		xmpp.SASL("", pass, sasl.ScramSha1Plus, sasl.ScramSha1, sasl.Plain),

M session_test.go => session_test.go +1 -1
@@ 99,7 99,7 @@ var negotiateTests = [...]negotiateTestCase{
	0: {negotiator: errNegotiator, err: errTestNegotiate},
	1: {
		negotiator: xmpp.NewNegotiator(xmpp.StreamConfig{
			Features: []xmpp.StreamFeature{xmpp.StartTLS(true, nil)},
			Features: []xmpp.StreamFeature{xmpp.StartTLS(nil)},
		}),
		in:  `<stream:stream id='316732270768047465' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:client'><stream:features><other/></stream:features>`,
		out: `<?xml version="1.0" encoding="UTF-8"?><stream:stream to='' from='' version='1.0' xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams'><starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`,

M starttls.go => starttls.go +9 -11
@@ 19,24 19,22 @@ import (
// StartTLS returns a new stream feature that can be used for negotiating TLS.
// If cfg is nil, a default configuration is used that uses the domainpart of
// the sessions local address as the ServerName.
func StartTLS(required bool, cfg *tls.Config) StreamFeature {
func StartTLS(cfg *tls.Config) StreamFeature {
	return StreamFeature{
		Name:       xml.Name{Local: "starttls", Space: ns.StartTLS},
		Prohibited: Secure,
		List: func(ctx context.Context, e xmlstream.TokenWriter, start xml.StartElement) (req bool, err error) {
			if err = e.EncodeToken(start); err != nil {
				return required, err
				return true, err
			}
			if required {
				startRequired := xml.StartElement{Name: xml.Name{Space: "", Local: "required"}}
				if err = e.EncodeToken(startRequired); err != nil {
					return required, err
				}
				if err = e.EncodeToken(startRequired.End()); err != nil {
					return required, err
				}
			startRequired := xml.StartElement{Name: xml.Name{Space: "", Local: "required"}}
			if err = e.EncodeToken(startRequired); err != nil {
				return true, err
			}
			if err = e.EncodeToken(startRequired.End()); err != nil {
				return true, err
			}
			return required, e.EncodeToken(start.End())
			return true, e.EncodeToken(start.End())
		},
		Parse: func(ctx context.Context, r xml.TokenReader, start *xml.StartElement) (bool, interface{}, error) {
			parsed := struct {

M starttls_test.go => starttls_test.go +57 -67
@@ 22,76 22,66 @@ import (
// There is no room for variation on the starttls feature negotiation, so step
// through the list process token for token.
func TestStartTLSList(t *testing.T) {
	for _, req := range []bool{true, false} {
		name := "optional"
		if req {
			name = "required"
		}
		t.Run(name, func(t *testing.T) {
			stls := xmpp.StartTLS(req, nil)
			var b bytes.Buffer
			e := xml.NewEncoder(&b)
			start := xml.StartElement{Name: xml.Name{Space: ns.StartTLS, Local: "starttls"}}
			r, err := stls.List(context.Background(), e, start)
			switch {
			case err != nil:
				t.Fatal(err)
			case r != req:
				t.Errorf("Expected StartTLS listing required to be %v but got %v", req, r)
			}
			if err = e.Flush(); err != nil {
				t.Fatal(err)
			}
	stls := xmpp.StartTLS(nil)
	var b bytes.Buffer
	e := xml.NewEncoder(&b)
	start := xml.StartElement{Name: xml.Name{Space: ns.StartTLS, Local: "starttls"}}
	r, err := stls.List(context.Background(), e, start)
	switch {
	case err != nil:
		t.Fatal(err)
	case !r:
		t.Error("Expected StartTLS listing to be required")
	}
	if err = e.Flush(); err != nil {
		t.Fatal(err)
	}

			d := xml.NewDecoder(&b)
			tok, err := d.Token()
			if err != nil {
				t.Fatal(err)
			}
			se := tok.(xml.StartElement)
			switch {
			case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
				t.Errorf("Expected starttls to start with %+v token but got %+v", ns.StartTLS, se.Name)
			case len(se.Attr) != 1:
				t.Errorf("Expected starttls start element to have 1 attribute (xmlns), but got %+v", se.Attr)
			}
			if req {
				tok, err = d.Token()
				if err != nil {
					t.Fatal(err)
				}
				se := tok.(xml.StartElement)
				switch {
				case se.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
					t.Errorf("Expected required start element but got %+v", se)
				case len(se.Attr) > 0:
					t.Errorf("Expected starttls required to have no attributes but got %d", len(se.Attr))
				}
				tok, err = d.Token()
				if err != nil {
					t.Fatal(err)
				}
				ee := tok.(xml.EndElement)
				switch {
				case se.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
					t.Errorf("Expected required end element but got %+v", ee)
				}
			}
			tok, err = d.Token()
			if err != nil {
				t.Fatal(err)
			}
			ee := tok.(xml.EndElement)
			switch {
			case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
				t.Errorf("Expected starttls end element but got %+v", ee)
			}
		})
	d := xml.NewDecoder(&b)
	tok, err := d.Token()
	if err != nil {
		t.Fatal(err)
	}
	se := tok.(xml.StartElement)
	switch {
	case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
		t.Errorf("Expected starttls to start with %+v token but got %+v", ns.StartTLS, se.Name)
	case len(se.Attr) != 1:
		t.Errorf("Expected starttls start element to have 1 attribute (xmlns), but got %+v", se.Attr)
	}
	tok, err = d.Token()
	if err != nil {
		t.Fatal(err)
	}
	reqStart := tok.(xml.StartElement)
	switch {
	case reqStart.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
		t.Errorf("Expected required start element but got %+v", se)
	case len(reqStart.Attr) > 0:
		t.Errorf("Expected starttls required to have no attributes but got %d", len(reqStart.Attr))
	}
	tok, err = d.Token()
	if err != nil {
		t.Fatal(err)
	}
	ee := tok.(xml.EndElement)
	switch {
	case reqStart.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
		t.Errorf("Expected required end element but got %+v", ee)
	}
	tok, err = d.Token()
	if err != nil {
		t.Fatal(err)
	}
	ee = tok.(xml.EndElement)
	switch {
	case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
		t.Errorf("Expected starttls end element but got %+v", ee)
	}
}

func TestStartTLSParse(t *testing.T) {
	stls := xmpp.StartTLS(true, nil)
	stls := xmpp.StartTLS(nil)
	for i, test := range [...]struct {
		msg string
		req bool


@@ 131,7 121,7 @@ func (nopRWC) Close() error {
}

func TestNegotiateServer(t *testing.T) {
	stls := xmpp.StartTLS(true, &tls.Config{})
	stls := xmpp.StartTLS(&tls.Config{})
	var b bytes.Buffer
	c := xmpptest.NewSession(xmpp.Received, nopRWC{&b, &b})
	_, rw, err := stls.Negotiate(context.Background(), c, nil)


@@ 169,7 159,7 @@ func TestNegotiateClient(t *testing.T) {
		7: {[]string{`chardata not start element`}, true, false, 0},
	} {
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			stls := xmpp.StartTLS(true, &tls.Config{})
			stls := xmpp.StartTLS(&tls.Config{})
			r := strings.NewReader(strings.Join(test.responses, "\n"))
			var b bytes.Buffer
			c := xmpptest.NewSession(0, nopRWC{r, &b})