// Copyright 2016 The Mellium Contributors. // Use of this source code is governed by the BSD 2-clause // license that can be found in the LICENSE file. package xmpp_test import ( "bytes" "context" "crypto/tls" "encoding/xml" "fmt" "io" "strings" "testing" "mellium.im/xmpp" "mellium.im/xmpp/internal/ns" "mellium.im/xmpp/internal/xmpptest" ) // There is no room for variation on the starttls feature negotiation, so step // through the list process token for token. func TestStartTLSList(t *testing.T) { stls := xmpp.StartTLS(nil) var b bytes.Buffer e := xml.NewEncoder(&b) start := xml.StartElement{Name: xml.Name{Space: ns.StartTLS, Local: "starttls"}} r, err := stls.List(context.Background(), e, start) switch { case err != nil: t.Fatal(err) case !r: t.Error("Expected StartTLS listing to be required") } if err = e.Flush(); err != nil { t.Fatal(err) } d := xml.NewDecoder(&b) tok, err := d.Token() if err != nil { t.Fatal(err) } se := tok.(xml.StartElement) switch { case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}: t.Errorf("Expected starttls to start with %+v token but got %+v", ns.StartTLS, se.Name) case len(se.Attr) != 1: t.Errorf("Expected starttls start element to have 1 attribute (xmlns), but got %+v", se.Attr) } tok, err = d.Token() if err != nil { t.Fatal(err) } reqStart := tok.(xml.StartElement) switch { case reqStart.Name != xml.Name{Space: ns.StartTLS, Local: "required"}: t.Errorf("Expected required start element but got %+v", se) case len(reqStart.Attr) > 0: t.Errorf("Expected starttls required to have no attributes but got %d", len(reqStart.Attr)) } tok, err = d.Token() if err != nil { t.Fatal(err) } ee := tok.(xml.EndElement) switch { case reqStart.Name != xml.Name{Space: ns.StartTLS, Local: "required"}: t.Errorf("Expected required end element but got %+v", ee) } tok, err = d.Token() if err != nil { t.Fatal(err) } ee = tok.(xml.EndElement) switch { case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}: t.Errorf("Expected starttls end element but got %+v", ee) } } func TestStartTLSParse(t *testing.T) { stls := xmpp.StartTLS(nil) for i, test := range [...]struct { msg string req bool err bool }{ 0: {``, false, false}, 1: {``, false, false}, 2: {``, true, false}, 3: {``, true, false}, 4: {``, false, true}, 5: {``, false, true}, } { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { d := xml.NewDecoder(bytes.NewBufferString(test.msg)) tok, _ := d.Token() se := tok.(xml.StartElement) req, _, err := stls.Parse(context.Background(), d, &se) switch { case test.err && (err == nil): t.Error("Expected starttls.Parse to error") case !test.err && (err != nil): t.Error(err) case req != test.req: t.Errorf("STARTTLS required was wrong; expected %v but got %v", test.req, req) } }) } } type nopRWC struct { io.Reader io.Writer } func (nopRWC) Close() error { return nil } func TestNegotiateServer(t *testing.T) { stls := xmpp.StartTLS(&tls.Config{}) var b bytes.Buffer c := xmpptest.NewSession(xmpp.Received, nopRWC{&b, &b}) _, rw, err := stls.Negotiate(context.Background(), c, nil) switch { case err != nil: t.Fatal(err) case rw == nil: t.Fatal("Expected a new ReadWriter when negotiating STARTTLS as a server") } // The server should send a proceed element. proceed := struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"` }{} d := xml.NewDecoder(&b) if err = d.Decode(&proceed); err != nil { t.Error(err) } } func TestNegotiateClient(t *testing.T) { for i, test := range [...]struct { responses []string err bool rw bool state xmpp.SessionState }{ 0: {[]string{``}, true, false, xmpp.Secure}, 1: {[]string{``}, false, true, xmpp.Secure}, 2: {[]string{``}, true, false, 0}, 3: {[]string{``}, false, false, 0}, 4: {[]string{``}, true, false, 0}, 5: {[]string{``}, true, false, 0}, 6: {[]string{``}, true, false, 0}, 7: {[]string{`chardata not start element`}, true, false, 0}, } { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { stls := xmpp.StartTLS(&tls.Config{}) r := strings.NewReader(strings.Join(test.responses, "\n")) var b bytes.Buffer c := xmpptest.NewSession(0, nopRWC{r, &b}) mask, rw, err := stls.Negotiate(context.Background(), c, nil) switch { case test.err && err == nil: t.Error("Expected an error from starttls client negotiation") return case !test.err && err != nil: t.Error(err) return case test.err && err != nil: return case b.String() != ``: t.Errorf("Expected client to send starttls element but got `%s`", b.String()) case test.state != mask: t.Errorf("Expected session state mask %v but got %v", test.state, mask) case test.rw && rw == nil: t.Error("Expected a new ReadWriter when negotiating STARTTLS as a client") case !test.rw && rw != nil: t.Error("Did not expect a new ReadWriter when negotiating STARTTLS as a client") } }) } }