~samwhited/xmpp

74d83694004d9866034111f2f1587317cec38bb2 — Sam Whited 4 years ago 58a2177
Allow features to upgrade the underlying RWC

Fixes #9
7 files changed, 93 insertions(+), 91 deletions(-)

M bind.go
M conn.go
M features.go
M sasl.go
M starttls.go
M starttls_test.go
M stream.go
M bind.go => bind.go +10 -10
@@ 37,7 37,7 @@ func BindResource() StreamFeature {
			}{}
			return true, nil, d.DecodeElement(&parsed, start)
		},
		Negotiate: func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, err error) {
		Negotiate: func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			if (conn.state & Received) == Received {
				panic("xmpp: bind not yet implemented")
			} else {


@@ 50,15 50,15 @@ func BindResource() StreamFeature {
					_, err = fmt.Fprintf(conn, bindIQClientRequestedRP, reqID, resource)
				}
				if err != nil {
					return mask, err
					return mask, nil, err
				}
				tok, err := conn.in.d.Token()
				if err != nil {
					return mask, err
					return mask, nil, err
				}
				start, ok := tok.(xml.StartElement)
				if !ok {
					return mask, streamerror.BadFormat
					return mask, nil, streamerror.BadFormat
				}
				resp := struct {
					IQ


@@ 70,23 70,23 @@ func BindResource() StreamFeature {
				switch start.Name {
				case xml.Name{Space: ns.Client, Local: "iq"}:
					if err = conn.in.d.DecodeElement(&resp, &start); err != nil {
						return mask, err
						return mask, nil, err
					}
				default:
					return mask, streamerror.BadFormat
					return mask, nil, streamerror.BadFormat
				}

				switch {
				case resp.ID != reqID:
					return mask, streamerror.UndefinedCondition
					return mask, nil, streamerror.UndefinedCondition
				case resp.Type == ResultIQ:
					conn.origin = resp.Bind.JID
				case resp.Type == ErrorIQ:
					return mask, resp.Err
					return mask, nil, resp.Err
				default:
					return mask, StanzaError{Condition: BadRequest}
					return mask, nil, StanzaError{Condition: BadRequest}
				}
				return Ready, nil
				return Ready, nil, nil
			}
		},
	}

M conn.go => conn.go +1 -3
@@ 65,15 65,13 @@ func (c *Conn) Features() map[xml.Name]struct{} {
func NewConn(ctx context.Context, config *Config, rwc io.ReadWriteCloser) (*Conn, error) {
	c := &Conn{
		config: config,
		rwc:    rwc,
		state:  StreamRestartRequired,
	}

	if conn, ok := rwc.(net.Conn); ok {
		c.conn = conn
	}

	return c, c.negotiateStreams(ctx)
	return c, c.negotiateStreams(ctx, rwc)
}

// Raw returns the Conn's backing net.Conn or other ReadWriteCloser.

M features.go => features.go +18 -18
@@ 49,16 49,16 @@ type StreamFeature struct {
	// the feature. The "mask" SessionState represents the state bits that should
	// be flipped after negotiation of the feature is complete. For instance, if
	// this feature creates a security layer (such as TLS) and performs
	// authentication, mask would be set to Authn|Secure|StreamRestartRequired,
	// but if it does not authenticate the connection it would return
	// Secure|StreamRestartRequired. If the mask includes the StreamRestart bit,
	// the stream will be restarted automatically after Negotiate returns (unless
	// it returns an error). If this is an initiated connection and the features
	// List call returned a value, that value is passed to the data parameter when
	// Negotiate is called. For instance, in the case of compression this data
	// parameter might be the list of supported algorithms as a slice of strings
	// (or in whatever format the feature implementation has decided upon).
	Negotiate func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, err error)
	// authentication, mask would be set to Authn|Secure, but if it does not
	// authenticate the connection it would just return Secure. If negotiate
	// returns a new io.ReadWriteCloser (probably wrapping the old conn.Raw()) the
	// stream will be restarted automatically after Negotiate returns using the
	// new RWC. If this is an initiated connection and the features List call
	// returned a value, that value is passed to the data parameter when Negotiate
	// is called. For instance, in the case of compression this data parameter
	// might be the list of supported algorithms as a slice of strings (or in
	// whatever format the feature implementation has decided upon).
	Negotiate func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error)
}

// Returns the number of stream features written (zero means we've reached the


@@ 88,31 88,31 @@ func writeStreamFeatures(ctx context.Context, conn *Conn) (n int, req int, err e
	return
}

func (c *Conn) negotiateFeatures(ctx context.Context) (done bool, err error) {
func (c *Conn) negotiateFeatures(ctx context.Context) (done bool, rwc io.ReadWriteCloser, err error) {
	if (c.state & Received) == Received {
		_, _, err = writeStreamFeatures(ctx, c)
		if err != nil {
			return
			return false, nil, err
		}
		panic("Sending stream:features not yet implemented")
	} else {
		t, err := c.in.d.Token()
		if err != nil {
			return done, err
			return done, nil, err
		}
		start, ok := t.(xml.StartElement)
		if !ok {
			return done, streamerror.BadFormat
			return done, nil, streamerror.BadFormat
		}
		list, err := readStreamFeatures(ctx, c, start)

		switch {
		case err != nil:
			return done, err
			return done, nil, err
		case list.total == 0 || len(list.cache) == 0:
			// If we received an empty list (or one with no supported features, we're
			// done.
			return true, nil
			return true, nil, nil
		}

		// If the list has any required items, negotiate the first required feature.


@@ 124,11 124,11 @@ func (c *Conn) negotiateFeatures(ctx context.Context) (done bool, err error) {
				break
			}
		}
		mask, err := data.feature.Negotiate(ctx, c, data.data)
		mask, rwc, err := data.feature.Negotiate(ctx, c, data.data)
		if err == nil {
			c.state |= mask
		}
		return !list.req || (c.state&Ready == Ready), err
		return !list.req || (c.state&Ready == Ready), rwc, err
	}
}


M sasl.go => sasl.go +14 -14
@@ 67,7 67,7 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
			err := d.DecodeElement(&parsed, start)
			return true, parsed.List, err
		},
		Negotiate: func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, err error) {
		Negotiate: func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			if (conn.state & Received) == Received {
				panic("SASL server not yet implemented")
			} else {


@@ 84,7 84,7 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
				}
				// No matching mechanism found…
				if selected.Name == "" {
					return mask, errors.New(`No matching SASL mechanisms found`)
					return mask, nil, errors.New(`No matching SASL mechanisms found`)
				}

				c := conn.Config()


@@ 100,7 100,7 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {

				more, resp, err := client.Step(nil)
				if err != nil {
					return mask, err
					return mask, nil, err
				}

				// RFC6120 §6.4.2:


@@ 117,7 117,7 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
					`<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='%s'>%s</auth>`,
					selected.Name, resp,
				); err != nil {
					return mask, err
					return mask, nil, err
				}

				// If we're already done after the first step, decode the <success/> or


@@ 125,17 125,17 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
				if !more {
					tok, err := conn.in.d.Token()
					if err != nil {
						return mask, err
						return mask, nil, err
					}
					if t, ok := tok.(xml.StartElement); ok {
						// TODO: Handle the additional data that could be returned if
						// success?
						_, _, err := decodeSASLChallenge(conn.in.d, t, false)
						if err != nil {
							return mask, err
							return mask, nil, err
						}
					} else {
						return mask, streamerror.BadFormat
						return mask, nil, streamerror.BadFormat
					}
				}



@@ 143,24 143,24 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
				for more {
					select {
					case <-ctx.Done():
						return mask, ctx.Err()
						return mask, nil, ctx.Err()
					default:
					}
					tok, err := conn.in.d.Token()
					if err != nil {
						return mask, err
						return mask, nil, err
					}
					var challenge []byte
					if t, ok := tok.(xml.StartElement); ok {
						challenge, success, err = decodeSASLChallenge(conn.in.d, t, true)
						if err != nil {
							return mask, err
							return mask, nil, err
						}
					} else {
						return mask, streamerror.BadFormat
						return mask, nil, streamerror.BadFormat
					}
					if more, resp, err = client.Step(challenge); err != nil {
						return mask, err
						return mask, nil, err
					}
					if !more && success {
						// We're done with SASL and we're successful


@@ 169,10 169,10 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
					// TODO: What happens if there's more and success (broken server)?
					if _, err = fmt.Fprintf(conn,
						`<response xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>%s</response>`, resp); err != nil {
						return mask, err
						return mask, nil, err
					}
				}
				return Authn | StreamRestartRequired, nil
				return Authn, conn.Raw(), nil
			}
		},
	}

M starttls.go => starttls.go +24 -13
@@ 13,6 13,7 @@ import (
	"io"
	"net"

	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/ns"
	"mellium.im/xmpp/streamerror"
)


@@ 25,7 26,7 @@ var (

// StartTLS returns a new stream feature that can be used for negotiating TLS.
// For StartTLS to work, the underlying connection must support TLS (it must
// implement net.Conn) and the connection config must have a TLSConfig.
// implement net.Conn).
func StartTLS(required bool) StreamFeature {
	return StreamFeature{
		Name:       xml.Name{Local: "starttls", Space: ns.StartTLS},


@@ 48,15 49,25 @@ func StartTLS(required bool) StreamFeature {
			err := d.DecodeElement(&parsed, start)
			return parsed.Required.XMLName.Local == "required" && parsed.Required.XMLName.Space == ns.StartTLS, nil, err
		},
		Negotiate: func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, err error) {
			netconn, ok := conn.rwc.(net.Conn)
		Negotiate: func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			netconn, ok := conn.Raw().(net.Conn)
			if !ok {
				return mask, ErrTLSUpgradeFailed
				return mask, nil, ErrTLSUpgradeFailed
			}

			// Fetch or create a TLSConfig to use.
			var tlsconf *tls.Config
			if conn.config.TLSConfig == nil {
				tlsconf = &tls.Config{
					ServerName: conn.LocalAddr().(*jid.JID).Domain().String(),
				}
			} else {
				tlsconf = conn.config.TLSConfig
			}

			if (conn.state & Received) == Received {
				fmt.Fprint(conn, `<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)
				conn.rwc = tls.Server(netconn, conn.config.TLSConfig)
				rwc = tls.Server(netconn, tlsconf)
			} else {
				// Select starttls for negotiation.
				fmt.Fprint(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)


@@ 64,19 75,19 @@ func StartTLS(required bool) StreamFeature {
				// Receive a <proceed/> or <failure/> response from the server.
				t, err := conn.in.d.Token()
				if err != nil {
					return mask, err
					return mask, nil, err
				}
				switch tok := t.(type) {
				case xml.StartElement:
					switch {
					case tok.Name.Space != ns.StartTLS:
						return mask, streamerror.UnsupportedStanzaType
						return mask, nil, streamerror.UnsupportedStanzaType
					case tok.Name.Local == "proceed":
						// Skip the </proceed> token.
						if err = conn.in.d.Skip(); err != nil {
							return mask, streamerror.InvalidXML
							return mask, nil, streamerror.InvalidXML
						}
						conn.rwc = tls.Client(netconn, conn.config.TLSConfig)
						rwc = tls.Client(netconn, tlsconf)
					case tok.Name.Local == "failure":
						// Skip the </failure> token.
						if err = conn.in.d.Skip(); err != nil {


@@ 86,15 97,15 @@ func StartTLS(required bool) StreamFeature {
						// afterwards the server will end the stream. However, if we
						// encounter bad XML while skipping the </failure> token, return
						// that error.
						return mask, err
						return mask, nil, err
					default:
						return mask, streamerror.UnsupportedStanzaType
						return mask, nil, streamerror.UnsupportedStanzaType
					}
				default:
					return mask, streamerror.RestrictedXML
					return mask, nil, streamerror.RestrictedXML
				}
			}
			mask = Secure | StreamRestartRequired
			mask = Secure
			return
		},
	}

M starttls_test.go => starttls_test.go +20 -22
@@ 11,7 11,6 @@ import (
	"encoding/xml"
	"io"
	"net"
	"reflect"
	"strings"
	"testing"
	"time"


@@ 143,7 142,7 @@ func (dummyConn) SetWriteDeadline(t time.Time) error {
func TestNegotiationFailsForNonNetConn(t *testing.T) {
	stls := StartTLS(true)
	var b bytes.Buffer
	_, err := stls.Negotiate(context.Background(), &Conn{rwc: nopRWC{&b, &b}}, nil)
	_, _, err := stls.Negotiate(context.Background(), &Conn{rwc: nopRWC{&b, &b}}, nil)
	if err != ErrTLSUpgradeFailed {
		t.Errorf("Expected error `%v` but got `%v`", ErrTLSUpgradeFailed, err)
	}


@@ 153,9 152,12 @@ func TestNegotiateServer(t *testing.T) {
	stls := StartTLS(true)
	var b bytes.Buffer
	c := &Conn{state: Received, rwc: dummyConn{nopRWC{&b, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
	_, err := stls.Negotiate(context.Background(), c, nil)
	if err != nil {
	_, rwc, err := stls.Negotiate(context.Background(), c, nil)
	switch {
	case err != nil:
		t.Fatal(err)
	case rwc == nil:
		t.Fatal("Expected a new RWC when negotiating STARTTLS as a server")
	}

	// The server should send a proceed element.


@@ 166,34 168,30 @@ func TestNegotiateServer(t *testing.T) {
	if err = d.Decode(&proceed); err != nil {
		t.Error(err)
	}

	// The server should upgrade the connection to a tls.Conn
	if _, ok := c.rwc.(*tls.Conn); !ok {
		t.Errorf("Expected server conn to have been upgraded to a *tls.Conn but got %s", reflect.TypeOf(c.rwc))
	}
}

func TestNegotiateClient(t *testing.T) {
	for _, test := range []struct {
		responses []string
		err       bool
		rwc       bool
		state     SessionState
	}{
		{[]string{`<proceed xmlns="badns"/>`}, true, Secure | StreamRestartRequired},
		{[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, Secure | StreamRestartRequired},
		{[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, 0},
		{[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, 0},
		{[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, 0},
		{[]string{`</somethingbadhappened>`}, true, 0},
		{[]string{`<notproceedorfailure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, true, 0},
		{[]string{`chardata not start element`}, true, 0},
		{[]string{`<proceed xmlns="badns"/>`}, true, false, Secure},
		{[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, true, Secure},
		{[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, false, 0},
		{[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, false, 0},
		{[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, false, 0},
		{[]string{`</somethingbadhappened>`}, true, false, 0},
		{[]string{`<notproceedorfailure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, true, false, 0},
		{[]string{`chardata not start element`}, true, false, 0},
	} {
		stls := StartTLS(true)
		r := strings.NewReader(strings.Join(test.responses, "\n"))
		var b bytes.Buffer
		c := &Conn{rwc: dummyConn{nopRWC{r, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
		c.in.d = xml.NewDecoder(c.rwc)
		mask, err := stls.Negotiate(context.Background(), c, nil)
		mask, rwc, err := stls.Negotiate(context.Background(), c, nil)
		switch {
		case test.err && err == nil:
			t.Error("Expected an error from starttls client negotiation")


@@ 207,10 205,10 @@ func TestNegotiateClient(t *testing.T) {
			t.Errorf("Expected client to send starttls element but got `%s`", b.String())
		case test.state != mask:
			t.Errorf("Expected session state mask %v but got %v", test.state, mask)
		}
		// The client should upgrade the connection to a tls.Conn
		if _, ok := c.rwc.(*tls.Conn); test.state&Secure == Secure && !ok {
			t.Errorf("Expected client conn to have been upgraded to a *tls.Conn but got %s", reflect.TypeOf(c.rwc))
		case test.rwc && rwc == nil:
			t.Error("Expected a new RWC when negotiating STARTTLS as a client")
		case !test.rwc && rwc != nil:
			t.Error("Did not expect a new RWC when negotiating STARTTLS as a client")
		}
	}
}

M stream.go => stream.go +6 -11
@@ 47,11 47,6 @@ const (
	// Indicates that the input stream has been closed with a stream end tag. When
	// set all read operations will return an error.
	InputStreamClosed

	// Indicates that the session's streams must be restarted. This bit will
	// trigger an automatic restart and will be flipped back to off as soon as the
	// stream is restarted.
	StreamRestartRequired
)

type stream struct {


@@ 214,16 209,16 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
	}
}

func (c *Conn) negotiateStreams(ctx context.Context) (err error) {

func (c *Conn) negotiateStreams(ctx context.Context, rwc io.ReadWriteCloser) (err error) {
	// Loop for as long as we're not done negotiating features or a stream restart
	// is still required.
	for done := false; !done || c.state&StreamRestartRequired == StreamRestartRequired; {
		if c.state&StreamRestartRequired == StreamRestartRequired {
	for done := false; !done || rwc != nil; {
		if rwc != nil {
			c.features = make(map[xml.Name]struct{})
			c.rwc = rwc
			c.in.d = xml.NewDecoder(c.rwc)
			c.out.e = xml.NewEncoder(c.rwc)
			c.state &= ^StreamRestartRequired
			rwc = nil

			if (c.state & Received) == Received {
				// If we're the receiving entity wait for a new stream, then send one in


@@ 246,7 241,7 @@ func (c *Conn) negotiateStreams(ctx context.Context) (err error) {
			}
		}

		if done, err = c.negotiateFeatures(ctx); err != nil {
		if done, rwc, err = c.negotiateFeatures(ctx); err != nil {
			return err
		}
	}