~samwhited/xmpp

900bf34ceaa0671a74f98137e776234e55f30fb2 — Sam Whited 4 years ago f53fe84
More session cleanup
7 files changed, 117 insertions(+), 110 deletions(-)

M bind.go
M features.go
M sasl.go
M session.go
M starttls.go
M starttls_test.go
M stream.go
M bind.go => bind.go +8 -6
@@ 45,13 45,15 @@ func BindResource() StreamFeature {
			}{}
			return true, nil, d.DecodeElement(&parsed, start)
		},
		Negotiate: func(ctx context.Context, conn *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			if (conn.state & Received) == Received {
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			if (session.state & Received) == Received {
				panic("xmpp: bind not yet implemented")
			}

			conn := session.Raw()

			reqID := internal.RandomID(internal.IDLen)
			if resource := conn.config.Origin.Resourcepart(); resource == "" {
			if resource := session.config.Origin.Resourcepart(); resource == "" {
				// Send a request for the server to set a resource part.
				_, err = fmt.Fprintf(conn, bindIQServerGeneratedRP, reqID)
			} else {


@@ 61,7 63,7 @@ func BindResource() StreamFeature {
			if err != nil {
				return mask, nil, err
			}
			tok, err := conn.in.d.Token()
			tok, err := session.in.d.Token()
			if err != nil {
				return mask, nil, err
			}


@@ 78,7 80,7 @@ func BindResource() StreamFeature {
			}{}
			switch start.Name {
			case xml.Name{Space: ns.Client, Local: "iq"}:
				if err = conn.in.d.DecodeElement(&resp, &start); err != nil {
				if err = session.in.d.DecodeElement(&resp, &start); err != nil {
					return mask, nil, err
				}
			default:


@@ 89,7 91,7 @@ func BindResource() StreamFeature {
			case resp.ID != reqID:
				return mask, nil, streamerror.UndefinedCondition
			case resp.Type == ResultIQ:
				conn.origin = resp.Bind.JID
				session.origin = resp.Bind.JID
			case resp.Type == ErrorIQ:
				return mask, nil, resp.Err
			default:

M features.go => features.go +12 -12
@@ 64,13 64,13 @@ type StreamFeature struct {
	Negotiate func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error)
}

func (c *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.ReadWriteCloser, err error) {
	server := (c.state & Received) == Received
func (s *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.ReadWriteCloser, err error) {
	server := (s.state & Received) == Received

	// If we're the server, write the initial stream features.
	var list *streamFeaturesList
	if server {
		list, err = writeStreamFeatures(ctx, c)
		list, err = writeStreamFeatures(ctx, s)
		if err != nil {
			return false, nil, err
		}


@@ 82,7 82,7 @@ func (c *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.Read

	if !server {
		// Read a new startstream:features token.
		t, err = c.Decoder().Token()
		t, err = s.Decoder().Token()
		if err != nil {
			return done, nil, err
		}


@@ 92,7 92,7 @@ func (c *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.Read
		}

		// If we're the client read the rest of the stream features list.
		list, err = readStreamFeatures(ctx, c, start)
		list, err = readStreamFeatures(ctx, s, start)

		switch {
		case err != nil:


@@ 114,7 114,7 @@ func (c *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.Read

		if server {
			// Read a new feature to negotiate.
			t, err = c.Decoder().Token()
			t, err = s.Decoder().Token()
			if err != nil {
				return done, nil, err
			}


@@ 125,7 125,7 @@ func (c *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.Read

			// If the feature was not sent or was already negotiated, error.

			_, negotiated := c.negotiated[start.Name.Space]
			_, negotiated := s.negotiated[start.Name.Space]
			data, sent = list.cache[start.Name.Space]
			if !sent || negotiated {
				// TODO: What should we return here?


@@ 135,7 135,7 @@ func (c *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.Read
			// If we're the client, iterate through the cached features and select one
			// to negotiate.
			for _, v := range list.cache {
				if _, ok := c.negotiated[v.feature.Name.Space]; ok {
				if _, ok := s.negotiated[v.feature.Name.Space]; ok {
					// If this feature has already been negotiated, skip it.
					continue
				}


@@ 159,11 159,11 @@ func (c *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.Read
			}
		}

		mask, rwc, err = data.feature.Negotiate(ctx, c, data.data)
		mask, rwc, err = data.feature.Negotiate(ctx, s, data.data)
		if err == nil {
			c.state |= mask
			s.state |= mask
		}
		c.negotiated[data.feature.Name.Space] = struct{}{}
		s.negotiated[data.feature.Name.Space] = struct{}{}

		// If we negotiated a required feature or a stream restart is required
		// we're done with this feature set.


@@ 172,7 172,7 @@ func (c *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.Read
		}
	}

	return !list.req || (c.state&Ready == Ready), rwc, err
	return !list.req || (s.state&Ready == Ready), rwc, err
}

type sfData struct {

M sasl.go => sasl.go +12 -10
@@ 67,11 67,13 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
			err := d.DecodeElement(&parsed, start)
			return true, parsed.List, err
		},
		Negotiate: func(ctx context.Context, conn *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			if (conn.state & Received) == Received {
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			if (session.state & Received) == Received {
				panic("SASL server not yet implemented")
			}

			conn := session.Raw()

			var selected sasl.Mechanism
			// Select a mechanism, prefering the client order.
		selectmechanism:


@@ 88,13 90,13 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
				return mask, nil, errors.New(`No matching SASL mechanisms found`)
			}

			c := conn.Config()
			c := session.Config()
			opts := []sasl.Option{
				sasl.Authz(c.Identity),
				sasl.Credentials(conn.LocalAddr().(*jid.JID).Localpart(), c.Password),
				sasl.Credentials(session.LocalAddr().(*jid.JID).Localpart(), c.Password),
				sasl.RemoteMechanisms(data.([]string)...),
			}
			if tlsconn, ok := conn.rwc.(*tls.Conn); ok {
			if tlsconn, ok := conn.(*tls.Conn); ok {
				opts = append(opts, sasl.ConnState(tlsconn.ConnectionState()))
			}
			client := sasl.NewClient(selected, opts...)


@@ 124,14 126,14 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
			// If we're already done after the first step, decode the <success/> or
			// <failure/> before we exit.
			if !more {
				tok, err := conn.in.d.Token()
				tok, err := session.in.d.Token()
				if err != nil {
					return mask, nil, err
				}
				if t, ok := tok.(xml.StartElement); ok {
					// TODO: Handle the additional data that could be returned if
					// success?
					_, _, err := decodeSASLChallenge(conn.in.d, t, false)
					_, _, err := decodeSASLChallenge(session.in.d, t, false)
					if err != nil {
						return mask, nil, err
					}


@@ 147,13 149,13 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
					return mask, nil, ctx.Err()
				default:
				}
				tok, err := conn.in.d.Token()
				tok, err := session.in.d.Token()
				if err != nil {
					return mask, nil, err
				}
				var challenge []byte
				if t, ok := tok.(xml.StartElement); ok {
					challenge, success, err = decodeSASLChallenge(conn.in.d, t, true)
					challenge, success, err = decodeSASLChallenge(session.in.d, t, true)
					if err != nil {
						return mask, nil, err
					}


@@ 173,7 175,7 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
					return mask, nil, err
				}
			}
			return Authn, conn.Raw(), nil
			return Authn, conn, nil
		},
	}
}

M session.go => session.go +48 -48
@@ 87,12 87,12 @@ type Session struct {
// Feature checks if a feature with the given namespace was advertised
// by the server for the current stream. If it was data will be the canonical
// representation of the feature as returned by the feature's Parse function.
func (c *Session) Feature(namespace string) (data interface{}, ok bool) {
	c.flock.Lock()
	defer c.flock.Unlock()
func (s *Session) Feature(namespace string) (data interface{}, ok bool) {
	s.flock.Lock()
	defer s.flock.Unlock()

	// TODO: Make the features struct actually store the parsed representation.
	data, ok = c.features[namespace]
	data, ok = s.features[namespace]
	return
}



@@ 101,92 101,92 @@ func (c *Session) Feature(namespace string) (data interface{}, ok bool) {
// is canceled before stream negotiation is complete an error is returned. After
// stream negotiation if the context is canceled it has no effect.
func NewSession(ctx context.Context, config *Config, rwc io.ReadWriteCloser) (*Session, error) {
	c := &Session{
	s := &Session{
		config: config,
	}

	if conn, ok := rwc.(net.Conn); ok {
		c.conn = conn
		s.conn = conn
	}

	return c, c.negotiateStreams(ctx, rwc)
	return s, s.negotiateStreams(ctx, rwc)
}

// Raw returns the Session's backing net.Conn or other ReadWriteCloser.
func (c *Session) Raw() io.ReadWriteCloser {
	return c.rwc
func (s *Session) Raw() io.ReadWriteCloser {
	return s.rwc
}

// Decoder returns the XML decoder that was used to negotiate the latest stream.
func (c *Session) Decoder() *xml.Decoder {
	return c.in.d
func (s *Session) Decoder() *xml.Decoder {
	return s.in.d
}

// Encoder returns the XML encoder that was used to negotiate the latest stream.
func (c *Session) Encoder() *xml.Encoder {
	return c.out.e
func (s *Session) Encoder() *xml.Encoder {
	return s.out.e
}

// Config returns the connections config.
func (c *Session) Config() *Config {
	return c.config
func (s *Session) Config() *Config {
	return s.config
}

// Read reads data from the connection.
func (c *Session) Read(b []byte) (n int, err error) {
	c.in.Lock()
	defer c.in.Unlock()
func (s *Session) Read(b []byte) (n int, err error) {
	s.in.Lock()
	defer s.in.Unlock()

	if c.state&InputStreamClosed == InputStreamClosed {
	if s.state&InputStreamClosed == InputStreamClosed {
		return 0, errors.New("XML input stream is closed")
	}

	return c.rwc.Read(b)
	return s.rwc.Read(b)
}

// Write writes data to the connection.
func (c *Session) Write(b []byte) (n int, err error) {
	c.out.Lock()
	defer c.out.Unlock()
func (s *Session) Write(b []byte) (n int, err error) {
	s.out.Lock()
	defer s.out.Unlock()

	if c.state&OutputStreamClosed == OutputStreamClosed {
	if s.state&OutputStreamClosed == OutputStreamClosed {
		return 0, errors.New("XML output stream is closed")
	}

	return c.rwc.Write(b)
	return s.rwc.Write(b)
}

// Close closes the underlying connection.
// Any blocked Read or Write operations will be unblocked and return errors.
func (c *Session) Close() error {
	return c.rwc.Close()
func (s *Session) Close() error {
	return s.rwc.Close()
}

// State returns the current state of the session. For more information, see the
// SessionState type.
func (c *Session) State() SessionState {
	return c.state
func (s *Session) State() SessionState {
	return s.state
}

// LocalAddr returns the Origin address for initiated connections, or the
// Location for received connections.
func (c *Session) LocalAddr() net.Addr {
	if (c.state & Received) == Received {
		return c.config.Location
func (s *Session) LocalAddr() net.Addr {
	if (s.state & Received) == Received {
		return s.config.Location
	}
	if c.origin != nil {
		return c.origin
	if s.origin != nil {
		return s.origin
	}
	return c.config.Origin
	return s.config.Origin
}

// RemoteAddr returns the Location address for initiated connections, or the
// Origin address for received connections.
func (c *Session) RemoteAddr() net.Addr {
	if (c.state & Received) == Received {
		return c.config.Origin
func (s *Session) RemoteAddr() net.Addr {
	if (s.state & Received) == Received {
		return s.config.Origin
	}
	return c.config.Location
	return s.config.Location
}

var errSetDeadline = errors.New("xmpp: cannot set deadline: not using a net.Conn")


@@ 202,18 202,18 @@ var errSetDeadline = errors.New("xmpp: cannot set deadline: not using a net.Conn
// successful Read or Write calls.
//
// A zero value for t means I/O operations will not time out.
func (c *Session) SetDeadline(t time.Time) error {
	if c.conn != nil {
		return c.conn.SetDeadline(t)
func (s *Session) SetDeadline(t time.Time) error {
	if s.conn != nil {
		return s.conn.SetDeadline(t)
	}
	return errSetDeadline
}

// SetReadDeadline sets the deadline for future Read calls. A zero value for t
// means Read will not time out.
func (c *Session) SetReadDeadline(t time.Time) error {
	if c.conn != nil {
		return c.conn.SetReadDeadline(t)
func (s *Session) SetReadDeadline(t time.Time) error {
	if s.conn != nil {
		return s.conn.SetReadDeadline(t)
	}
	return errSetDeadline
}


@@ 221,9 221,9 @@ func (c *Session) SetReadDeadline(t time.Time) error {
// SetWriteDeadline sets the deadline for future Write calls. Even if write
// times out, it may return n > 0, indicating that some of the data was
// successfully written. A zero value for t means Write will not time out.
func (c *Session) SetWriteDeadline(t time.Time) error {
	if c.conn != nil {
		return c.conn.SetWriteDeadline(t)
func (s *Session) SetWriteDeadline(t time.Time) error {
	if s.conn != nil {
		return s.conn.SetWriteDeadline(t)
	}
	return errSetDeadline
}

M starttls.go => starttls.go +12 -13
@@ 11,7 11,6 @@ import (
	"errors"
	"fmt"
	"io"
	"net"

	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/ns"


@@ 57,31 56,31 @@ func StartTLS(required bool) StreamFeature {
			err := d.DecodeElement(&parsed, start)
			return parsed.Required.XMLName.Local == "required" && parsed.Required.XMLName.Space == ns.StartTLS, nil, err
		},
		Negotiate: func(ctx context.Context, conn *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			netconn, ok := conn.Raw().(net.Conn)
			if !ok {
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			conn := session.conn
			if conn == nil {
				return mask, nil, ErrTLSUpgradeFailed
			}

			// Fetch or create a TLSConfig to use.
			var tlsconf *tls.Config
			if conn.config.TLSConfig == nil {
			if session.config.TLSConfig == nil {
				tlsconf = &tls.Config{
					ServerName: conn.LocalAddr().(*jid.JID).Domain().String(),
					ServerName: session.LocalAddr().(*jid.JID).Domain().String(),
				}
			} else {
				tlsconf = conn.config.TLSConfig
				tlsconf = session.config.TLSConfig
			}

			if (conn.state & Received) == Received {
			if (session.state & Received) == Received {
				fmt.Fprint(conn, `<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)
				rwc = tls.Server(netconn, tlsconf)
				rwc = tls.Server(conn, tlsconf)
			} else {
				// Select starttls for negotiation.
				fmt.Fprint(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)

				// Receive a <proceed/> or <failure/> response from the server.
				t, err := conn.in.d.Token()
				t, err := session.in.d.Token()
				if err != nil {
					return mask, nil, err
				}


@@ 92,13 91,13 @@ func StartTLS(required bool) StreamFeature {
						return mask, nil, streamerror.UnsupportedStanzaType
					case tok.Name.Local == "proceed":
						// Skip the </proceed> token.
						if err = conn.in.d.Skip(); err != nil {
						if err = session.in.d.Skip(); err != nil {
							return mask, nil, streamerror.InvalidXML
						}
						rwc = tls.Client(netconn, tlsconf)
						rwc = tls.Client(conn, tlsconf)
					case tok.Name.Local == "failure":
						// Skip the </failure> token.
						if err = conn.in.d.Skip(); err != nil {
						if err = session.in.d.Skip(); err != nil {
							err = streamerror.InvalidXML
						}
						// Failure is not an "error", it's expected behavior. Immediately

M starttls_test.go => starttls_test.go +6 -2
@@ 118,6 118,8 @@ func (nopRWC) Close() error {
	return nil
}

var _ net.Conn = dummyConn{}

type dummyConn struct {
	io.ReadWriteCloser
}


@@ 156,7 158,8 @@ func TestNegotiationFailsForNonNetSession(t *testing.T) {
func TestNegotiateServer(t *testing.T) {
	stls := StartTLS(true)
	var b bytes.Buffer
	c := &Session{state: Received, rwc: dummyConn{nopRWC{&b, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
	c := &Session{state: Received, conn: dummyConn{nopRWC{&b, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
	c.rwc = c.conn
	_, rwc, err := stls.Negotiate(context.Background(), c, nil)
	switch {
	case err != nil:


@@ 194,7 197,8 @@ func TestNegotiateClient(t *testing.T) {
		stls := StartTLS(true)
		r := strings.NewReader(strings.Join(test.responses, "\n"))
		var b bytes.Buffer
		c := &Session{rwc: dummyConn{nopRWC{r, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
		c := &Session{conn: dummyConn{nopRWC{r, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
		c.rwc = c.conn
		c.in.d = xml.NewDecoder(c.rwc)
		mask, rwc, err := stls.Negotiate(context.Background(), c, nil)
		switch {

M stream.go => stream.go +19 -19
@@ 107,8 107,8 @@ func sendNewStream(w io.Writer, cfg *Config, id string) error {
		return err
	}

	if conn, ok := w.(*Session); ok {
		conn.out.stream = stream
	if session, ok := w.(*Session); ok {
		session.out.stream = stream
	}
	return nil
}


@@ 118,8 118,8 @@ func expectNewStream(ctx context.Context, r io.Reader) error {

	// If the reader is a Session, use its decoder, otherwise make a new one.
	var d *xml.Decoder
	if conn, ok := r.(*Session); ok {
		d = conn.in.d
	if session, ok := r.(*Session); ok {
		d = session.in.d
	} else {
		d = xml.NewDecoder(r)
	}


@@ 156,12 156,12 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
				return streamerror.UnsupportedVersion
			}

			if conn, ok := r.(*Session); ok {
				if (conn.state&Received) != Received && stream.id == "" {
			if session, ok := r.(*Session); ok {
				if (session.state&Received) != Received && stream.id == "" {
					// if we are the initiating entity and there is no stream ID…
					return streamerror.BadFormat
				}
				conn.in.stream = stream
				session.in.stream = stream
			}
			return nil
		case xml.ProcInst:


@@ 179,40 179,40 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
	}
}

func (c *Session) negotiateStreams(ctx context.Context, rwc io.ReadWriteCloser) (err error) {
func (s *Session) negotiateStreams(ctx context.Context, rwc io.ReadWriteCloser) (err error) {
	// Loop for as long as we're not done negotiating features or a stream restart
	// is still required.
	for done := false; !done || rwc != nil; {
		if rwc != nil {
			c.features = make(map[string]interface{})
			c.negotiated = make(map[string]struct{})
			c.rwc = rwc
			c.in.d = xml.NewDecoder(c.rwc)
			c.out.e = xml.NewEncoder(c.rwc)
			s.features = make(map[string]interface{})
			s.negotiated = make(map[string]struct{})
			s.rwc = rwc
			s.in.d = xml.NewDecoder(s.rwc)
			s.out.e = xml.NewEncoder(s.rwc)
			rwc = nil

			if (c.state & Received) == Received {
			if (s.state & Received) == Received {
				// If we're the receiving entity wait for a new stream, then send one in
				// response.
				if err = expectNewStream(ctx, c); err != nil {
				if err = expectNewStream(ctx, s); err != nil {
					return err
				}
				if err = sendNewStream(c, c.config, internal.RandomID(streamIDLength)); err != nil {
				if err = sendNewStream(s, s.config, internal.RandomID(streamIDLength)); err != nil {
					return err
				}
			} else {
				// If we're the initiating entity, send a new stream and then wait for one
				// in response.
				if err := sendNewStream(c, c.config, ""); err != nil {
				if err := sendNewStream(s, s.config, ""); err != nil {
					return err
				}
				if err := expectNewStream(ctx, c); err != nil {
				if err := expectNewStream(ctx, s); err != nil {
					return err
				}
			}
		}

		if done, rwc, err = c.negotiateFeatures(ctx); err != nil {
		if done, rwc, err = s.negotiateFeatures(ctx); err != nil {
			return err
		}
	}