~samwhited/xmpp

c67f48dea0c20ebd113402da1257b3f93045d520 — Sam Whited 5 years ago 6bc1589
Add more tests

Add test of adding features to a config
Add test of stanza error interface
Add a StartTLS negotiate tests
4 files changed, 136 insertions(+), 3 deletions(-)

M config_test.go
M errors_test.go
M starttls.go
M starttls_test.go
M config_test.go => config_test.go +8 -0
@@ 23,3 23,11 @@ func TestS2SConnType(t *testing.T) {
		t.Errorf("Wrong s2s value for conntype; expected xmpp-server but got %s", ct)
	}
}

// New configs should populate the features map with no duplicates.
func TestNewConfigShouldPopulateFeatures(t *testing.T) {
	c := NewServerConfig(nil, nil, BindResource(), BindResource(), StartTLS(true))
	if len(c.Features) != 2 {
		t.Errorf("Expected two features (Bind and StartTLS) but got: %v", c.Features)
	}
}

M errors_test.go => errors_test.go +8 -0
@@ 6,6 6,7 @@ package xmpp

import (
	"fmt"
	"testing"
)

var (


@@ 14,3 15,10 @@ var (
	_ fmt.Stringer = (*errorType)(nil)
	_ fmt.Stringer = Auth
)

func TestErrorReturnsCondition(t *testing.T) {
	s := StanzaError{Condition: "leprosy"}
	if s.Condition != s.Error() {
		t.Errorf("Expected stanza error to return condition `leprosy` but got %s", s.Error())
	}
}

M starttls.go => starttls.go +4 -3
@@ 46,13 46,14 @@ func StartTLS(required bool) StreamFeature {
			return parsed.Required.XMLName.Local == "required" && parsed.Required.XMLName.Space == NSStartTLS, nil, err
		},
		Negotiate: func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, err error) {
			if _, ok := conn.rwc.(net.Conn); !ok {
			netconn, ok := conn.rwc.(net.Conn)
			if !ok {
				return mask, ErrTLSUpgradeFailed
			}

			if (conn.state & Received) == Received {
				fmt.Fprint(conn, `<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)
				conn.rwc = tls.Server(conn.rwc.(net.Conn), conn.config.TLSConfig)
				conn.rwc = tls.Server(netconn, conn.config.TLSConfig)
			} else {
				// Select starttls for negotiation.
				fmt.Fprint(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)


@@ 72,7 73,7 @@ func StartTLS(required bool) StreamFeature {
						if err = conn.in.d.Skip(); err != nil {
							return EndStream, InvalidXML
						}
						conn.rwc = tls.Client(conn.rwc.(net.Conn), conn.config.TLSConfig)
						conn.rwc = tls.Client(netconn, conn.config.TLSConfig)
					case tok.Name.Local == "failure":
						// Skip the </failure> token.
						if err = conn.in.d.Skip(); err != nil {

M starttls_test.go => starttls_test.go +116 -0
@@ 7,8 7,14 @@ package xmpp
import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/xml"
	"io"
	"net"
	"reflect"
	"strings"
	"testing"
	"time"
)

// There is no room for variation on the starttls feature negotiation, so step


@@ 96,3 102,113 @@ func TestStartTLSParse(t *testing.T) {
		}
	}
}

type nopRWC struct {
	io.Reader
	io.Writer
}

func (nopRWC) Close() error {
	return nil
}

type dummyConn struct {
	io.ReadWriteCloser
}

func (dummyConn) LocalAddr() net.Addr {
	return nil
}

func (dummyConn) RemoteAddr() net.Addr {
	return nil
}

func (dummyConn) SetDeadline(t time.Time) error {
	return nil
}

func (dummyConn) SetReadDeadline(t time.Time) error {
	return nil
}

func (dummyConn) SetWriteDeadline(t time.Time) error {
	return nil
}

// We can't create a tls.Client or tls.Server for a generic RWC, so ensure that
// we fail (with a specific error) if this is the case.
func TestNegotiationFailsForNonNetConn(t *testing.T) {
	stls := StartTLS(true)
	var b bytes.Buffer
	_, err := stls.Negotiate(context.Background(), &Conn{rwc: nopRWC{&b, &b}}, nil)
	if err != ErrTLSUpgradeFailed {
		t.Errorf("Expected error `%v` but got `%v`", ErrTLSUpgradeFailed, err)
	}
}

func TestNegotiateServer(t *testing.T) {
	stls := StartTLS(true)
	var b bytes.Buffer
	c := &Conn{state: Received, rwc: dummyConn{nopRWC{&b, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
	_, err := stls.Negotiate(context.Background(), c, nil)
	if err != nil {
		t.Fatal(err)
	}

	// The server should send a proceed element.
	proceed := struct {
		XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"`
	}{}
	d := xml.NewDecoder(&b)
	if err = d.Decode(&proceed); err != nil {
		t.Error(err)
	}

	// The server should upgrade the connection to a tls.Conn
	if _, ok := c.rwc.(*tls.Conn); !ok {
		t.Errorf("Expected server conn to have been upgraded to a *tls.Conn but got %s", reflect.TypeOf(c.rwc))
	}
}

func TestNegotiateClient(t *testing.T) {
	for _, test := range []struct {
		responses []string
		err       bool
		state     SessionState
	}{
		{[]string{`<proceed xmlns="badns"/>`}, true, Secure | StreamRestartRequired},
		{[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, Secure | StreamRestartRequired},
		{[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, 0},
		{[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, EndStream},
		{[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, 0},
		{[]string{`</somethingbadhappened>`}, true, 0},
		{[]string{`<notproceedorfailure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, true, 0},
		{[]string{`chardata not start element`}, true, 0},
	} {
		stls := StartTLS(true)
		r := strings.NewReader(strings.Join(test.responses, "\n"))
		var b bytes.Buffer
		c := &Conn{rwc: dummyConn{nopRWC{r, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
		c.in.d = xml.NewDecoder(c.rwc)
		mask, err := stls.Negotiate(context.Background(), c, nil)
		switch {
		case test.err && err == nil:
			t.Error("Expected an error from starttls client negotiation")
			continue
		case !test.err && err != nil:
			t.Error(err)
			continue
		case test.err && err != nil:
			continue
		case b.String() != `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`:
			t.Errorf("Expected client to send starttls element but got `%s`", b.String())
		case test.state != mask:
			t.Errorf("Expected session state mask %v but got %v", test.state, mask)
		}
		// The client should upgrade the connection to a tls.Conn
		if _, ok := c.rwc.(*tls.Conn); test.state&Secure == Secure && !ok {
			t.Errorf("Expected client conn to have been upgraded to a *tls.Conn but got %s", reflect.TypeOf(c.rwc))
		}
	}
}