~samwhited/xmpp

256a70e9908261dcd59420a78cc332d84df06734 — Sam Whited 4 years ago af5b489
xmpp: improve starttls test output
1 files changed, 106 insertions(+), 95 deletions(-)

M starttls_test.go
M starttls_test.go => starttls_test.go +106 -95
@@ 9,6 9,7 @@ import (
	"context"
	"crypto/tls"
	"encoding/xml"
	"fmt"
	"io"
	"strings"
	"testing"


@@ 20,90 21,98 @@ import (
// through the list process token for token.
func TestStartTLSList(t *testing.T) {
	for _, req := range []bool{true, false} {
		stls := StartTLS(req, nil)
		var b bytes.Buffer
		e := xml.NewEncoder(&b)
		start := xml.StartElement{Name: xml.Name{Space: ns.StartTLS, Local: "starttls"}}
		r, err := stls.List(context.Background(), e, start)
		switch {
		case err != nil:
			t.Fatal(err)
		case r != req:
			t.Errorf("Expected StartTLS listing required to be %v but got %v", req, r)
		}
		if err = e.Flush(); err != nil {
			t.Fatal(err)
		name := "optional"
		if req {
			name = "required"
		}
		t.Run(name, func(t *testing.T) {
			stls := StartTLS(req, nil)
			var b bytes.Buffer
			e := xml.NewEncoder(&b)
			start := xml.StartElement{Name: xml.Name{Space: ns.StartTLS, Local: "starttls"}}
			r, err := stls.List(context.Background(), e, start)
			switch {
			case err != nil:
				t.Fatal(err)
			case r != req:
				t.Errorf("Expected StartTLS listing required to be %v but got %v", req, r)
			}
			if err = e.Flush(); err != nil {
				t.Fatal(err)
			}

		d := xml.NewDecoder(&b)
		tok, err := d.Token()
		if err != nil {
			t.Fatal(err)
		}
		se := tok.(xml.StartElement)
		switch {
		case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
			t.Errorf("Expected starttls to start with %+v token but got %+v", ns.StartTLS, se.Name)
		case len(se.Attr) != 1:
			t.Errorf("Expected starttls start element to have 1 attribute (xmlns), but got %+v", se.Attr)
		}
		if req {
			tok, err = d.Token()
			d := xml.NewDecoder(&b)
			tok, err := d.Token()
			if err != nil {
				t.Fatal(err)
			}
			se := tok.(xml.StartElement)
			switch {
			case se.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
				t.Errorf("Expected required start element but got %+v", se)
			case len(se.Attr) > 0:
				t.Errorf("Expected starttls required to have no attributes but got %d", len(se.Attr))
			case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
				t.Errorf("Expected starttls to start with %+v token but got %+v", ns.StartTLS, se.Name)
			case len(se.Attr) != 1:
				t.Errorf("Expected starttls start element to have 1 attribute (xmlns), but got %+v", se.Attr)
			}
			if req {
				tok, err = d.Token()
				if err != nil {
					t.Fatal(err)
				}
				se := tok.(xml.StartElement)
				switch {
				case se.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
					t.Errorf("Expected required start element but got %+v", se)
				case len(se.Attr) > 0:
					t.Errorf("Expected starttls required to have no attributes but got %d", len(se.Attr))
				}
				tok, err = d.Token()
				ee := tok.(xml.EndElement)
				switch {
				case se.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
					t.Errorf("Expected required end element but got %+v", ee)
				}
			}
			tok, err = d.Token()
			if err != nil {
				t.Fatal(err)
			}
			ee := tok.(xml.EndElement)
			switch {
			case se.Name != xml.Name{Space: ns.StartTLS, Local: "required"}:
				t.Errorf("Expected required end element but got %+v", ee)
			case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
				t.Errorf("Expected starttls end element but got %+v", ee)
			}
		}
		tok, err = d.Token()
		if err != nil {
			t.Fatal(err)
		}
		ee := tok.(xml.EndElement)
		switch {
		case se.Name != xml.Name{Space: ns.StartTLS, Local: "starttls"}:
			t.Errorf("Expected starttls end element but got %+v", ee)
		}
		})
	}
}

func TestStartTLSParse(t *testing.T) {
	stls := StartTLS(true, nil)
	for _, test := range []struct {
	for i, test := range [...]struct {
		msg string
		req bool
		err bool
	}{
		{`<starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"/>`, false, false},
		{`<starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"></starttls>`, false, false},
		{`<starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls>`, true, false},
		{`<starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required></required></starttls>`, true, false},
		{`<endtls xmlns="urn:ietf:params:xml:ns:xmpp-tls"/>`, false, true},
		{`<starttls xmlns="badurn"/>`, false, true},
		0: {`<starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"/>`, false, false},
		1: {`<starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"></starttls>`, false, false},
		2: {`<starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls>`, true, false},
		3: {`<starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required></required></starttls>`, true, false},
		4: {`<endtls xmlns="urn:ietf:params:xml:ns:xmpp-tls"/>`, false, true},
		5: {`<starttls xmlns="badurn"/>`, false, true},
	} {
		d := xml.NewDecoder(bytes.NewBufferString(test.msg))
		tok, _ := d.Token()
		se := tok.(xml.StartElement)
		req, _, err := stls.Parse(context.Background(), d, &se)
		switch {
		case test.err && (err == nil):
			t.Error("Expected starttls.Parse to error")
		case !test.err && (err != nil):
			t.Error(err)
		case req != test.req:
			t.Errorf("STARTTLS required was wrong; expected %v but got %v", test.req, req)
		}
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			d := xml.NewDecoder(bytes.NewBufferString(test.msg))
			tok, _ := d.Token()
			se := tok.(xml.StartElement)
			req, _, err := stls.Parse(context.Background(), d, &se)
			switch {
			case test.err && (err == nil):
				t.Error("Expected starttls.Parse to error")
			case !test.err && (err != nil):
				t.Error(err)
			case req != test.req:
				t.Errorf("STARTTLS required was wrong; expected %v but got %v", test.req, req)
			}
		})
	}
}



@@ 139,44 148,46 @@ func TestNegotiateServer(t *testing.T) {
}

func TestNegotiateClient(t *testing.T) {
	for _, test := range []struct {
	for i, test := range [...]struct {
		responses []string
		err       bool
		rw        bool
		state     SessionState
	}{
		{[]string{`<proceed xmlns="badns"/>`}, true, false, Secure},
		{[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, true, Secure},
		{[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, false, 0},
		{[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, false, 0},
		{[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, false, 0},
		{[]string{`</somethingbadhappened>`}, true, false, 0},
		{[]string{`<notproceedorfailure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, true, false, 0},
		{[]string{`chardata not start element`}, true, false, 0},
		0: {[]string{`<proceed xmlns="badns"/>`}, true, false, Secure},
		1: {[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, true, Secure},
		2: {[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, false, 0},
		3: {[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, false, 0},
		4: {[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, false, 0},
		5: {[]string{`</somethingbadhappened>`}, true, false, 0},
		6: {[]string{`<notproceedorfailure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, true, false, 0},
		7: {[]string{`chardata not start element`}, true, false, 0},
	} {
		stls := StartTLS(true, &tls.Config{})
		r := strings.NewReader(strings.Join(test.responses, "\n"))
		var b bytes.Buffer
		c := &Session{conn: newConn(nopRWC{r, &b})}
		c.in.d = xml.NewDecoder(c.conn)
		mask, rw, err := stls.Negotiate(context.Background(), c, nil)
		switch {
		case test.err && err == nil:
			t.Error("Expected an error from starttls client negotiation")
			continue
		case !test.err && err != nil:
			t.Error(err)
			continue
		case test.err && err != nil:
			continue
		case b.String() != `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`:
			t.Errorf("Expected client to send starttls element but got `%s`", b.String())
		case test.state != mask:
			t.Errorf("Expected session state mask %v but got %v", test.state, mask)
		case test.rw && rw == nil:
			t.Error("Expected a new ReadWriter when negotiating STARTTLS as a client")
		case !test.rw && rw != nil:
			t.Error("Did not expect a new ReadWriter when negotiating STARTTLS as a client")
		}
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			stls := StartTLS(true, &tls.Config{})
			r := strings.NewReader(strings.Join(test.responses, "\n"))
			var b bytes.Buffer
			c := &Session{conn: newConn(nopRWC{r, &b})}
			c.in.d = xml.NewDecoder(c.conn)
			mask, rw, err := stls.Negotiate(context.Background(), c, nil)
			switch {
			case test.err && err == nil:
				t.Error("Expected an error from starttls client negotiation")
				return
			case !test.err && err != nil:
				t.Error(err)
				return
			case test.err && err != nil:
				return
			case b.String() != `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`:
				t.Errorf("Expected client to send starttls element but got `%s`", b.String())
			case test.state != mask:
				t.Errorf("Expected session state mask %v but got %v", test.state, mask)
			case test.rw && rw == nil:
				t.Error("Expected a new ReadWriter when negotiating STARTTLS as a client")
			case !test.rw && rw != nil:
				t.Error("Did not expect a new ReadWriter when negotiating STARTTLS as a client")
			}
		})
	}
}