~samwhited/xmpp

f53fe84240a18afcd1d91c5192c1e8b77787464f — Sam Whited 5 years ago bc544a0
Add Session API and remove Conn
17 files changed, 202 insertions(+), 228 deletions(-)

M bind.go
M compress/compression.go
M config.go
M config_test.go
D conn_test.go
M dial.go
M dial_test.go
M example_test.go
M features.go
M listen.go
M lookup.go
M message_test.go
M sasl.go
R conn.go => session.go
M starttls.go
M starttls_test.go
M stream.go
M bind.go => bind.go +1 -1
@@ 45,7 45,7 @@ func BindResource() StreamFeature {
			}{}
			return true, nil, d.DecodeElement(&parsed, start)
		},
		Negotiate: func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
		Negotiate: func(ctx context.Context, conn *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			if (conn.state & Received) == Received {
				panic("xmpp: bind not yet implemented")
			}

M compress/compression.go => compress/compression.go +8 -6
@@ 83,14 83,16 @@ func New(methods ...Method) xmpp.StreamFeature {

			return true, listed.Methods, nil
		},
		Negotiate: func(ctx context.Context, conn *xmpp.Conn, data interface{}) (mask xmpp.SessionState, rwc io.ReadWriteCloser, err error) {
		Negotiate: func(ctx context.Context, session *xmpp.Session, data interface{}) (mask xmpp.SessionState, rwc io.ReadWriteCloser, err error) {
			conn := session.Raw()

			// If we're a server.
			if (conn.State() & xmpp.Received) == xmpp.Received {
			if (session.State() & xmpp.Received) == xmpp.Received {
				clientSelected := struct {
					XMLName xml.Name `xml:"http://jabber.org/protocol/compress compress"`
					Method  string   `xml:"method"`
				}{}
				if err = conn.Decoder().Decode(&clientSelected); err != nil {
				if err = session.Decoder().Decode(&clientSelected); err != nil {
					return
				}



@@ 118,7 120,7 @@ func New(methods ...Method) xmpp.StreamFeature {
					return
				}

				rwc, err = selected.Wrapper(conn.Raw())
				rwc, err = selected.Wrapper(conn)
				return mask, rwc, err
			}



@@ 142,7 144,7 @@ func New(methods ...Method) xmpp.StreamFeature {
				return
			}

			d := conn.Decoder()
			d := session.Decoder()
			tok, err := d.Token()
			if err != nil {
				return mask, nil, err


@@ 152,7 154,7 @@ func New(methods ...Method) xmpp.StreamFeature {
				if err = d.Skip(); err != nil {
					return mask, nil, err
				}
				rwc, err = selected.Wrapper(conn.Raw())
				rwc, err = selected.Wrapper(conn)
				return mask, rwc, err
			}


M config.go => config.go +0 -7
@@ 71,10 71,3 @@ func NewServerConfig(location, origin *jid.JID, features ...StreamFeature) (c *C
	}
	return c
}

func connType(s2s bool) string {
	if s2s {
		return "xmpp-server"
	}
	return "xmpp-client"
}

M config_test.go => config_test.go +2 -2
@@ 9,7 9,7 @@ import (
)

// The default value of config.conntype should return "xmpp-client"
func TestDefaultConnType(t *testing.T) {
func TestDefaultSessionType(t *testing.T) {
	c := &Config{}
	if ct := connType(c.S2S); ct != "xmpp-client" {
		t.Errorf("Wrong default value for conntype; expected xmpp-client but got %s", ct)


@@ 17,7 17,7 @@ func TestDefaultConnType(t *testing.T) {
}

// If S2S is true, config.conntype should return "xmpp-server"
func TestS2SConnType(t *testing.T) {
func TestS2SSessionType(t *testing.T) {
	c := &Config{S2S: true}
	if ct := connType(c.S2S); ct != "xmpp-server" {
		t.Errorf("Wrong s2s value for conntype; expected xmpp-server but got %s", ct)

D conn_test.go => conn_test.go +0 -11
@@ 1,11 0,0 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package xmpp

import (
	"net"
)

var _ net.Conn = (*Conn)(nil)

M dial.go => dial.go +75 -102
@@ 13,11 13,41 @@ import (
	"mellium.im/xmpp/jid"
)

// DialClient discovers and connects to the address on the named network with a
// client-to-server (c2s) connection.
//
// If the context expires before the connection is complete, an error is
// returned. Once successfully connected, any expiration of the context will not
// affect the connection.
//
// addr is a JID with a domainpart of the server we wish to connect too.
// DialClient will attempt to look up SRV records for the given JIDs domainpart
// or connect to the domainpart directly.
//
// Network may be any of the network types supported by net.Dial, but you almost
// certainly want to use one of the tcp connection types ("tcp", "tcp4", or
// "tcp6").
func DialClient(ctx context.Context, network string, addr *jid.JID) (net.Conn, error) {
	var d Dialer
	return d.Dial(ctx, network, addr)
}

// DialServer discovers and connects to the address on the named network with a
// server-to-server connection (s2s).
//
// For more info see the DialClient function.
func DialServer(ctx context.Context, network string, addr *jid.JID) (net.Conn, error) {
	d := Dialer{
		S2S: true,
	}
	return d.Dial(ctx, network, addr)
}

// A Dialer contains options for connecting to an XMPP address.
//
// The zero value for each field is equivalent to dialing without that option.
// Dialing with the zero value of Dialer is therefore equivalent to just calling
// the Dial function.
// the DialClient function.
type Dialer struct {
	net.Dialer



@@ 25,32 55,51 @@ type Dialer struct {
	// domain. It also prevents fetching of the host metadata file.
	// Instead, it will try to connect to the domain directly.
	NoLookup bool

	// Attempt to dial a server-to-server connection.
	S2S bool
}

// Dial discovers and connects to the address on the named network that services
// the given local address with a client-to-server (c2s) connection.
// Dial discovers and connects to the address on the named network.
//
// laddr is the clients origin address. The remote address is taken from the
// origins domain part or from the domains SRV records. For a description of the
// ctx and network arguments, see the Dial function.
func Dial(ctx context.Context, network string, laddr *jid.JID) (*Conn, error) {
	var d Dialer
	return d.Dial(ctx, network, laddr)
// For a description of the arguments see the DialClient function.
func (d *Dialer) Dial(ctx context.Context, network string, addr *jid.JID) (net.Conn, error) {
	return d.dial(ctx, network, addr)
}

// DialConfig connects to the address on the named network using the provided
// config.
//
// The context must be non-nil. If the context expires before the connection is
// complete, an error is returned. Once successfully connected, any expiration
// of the context will not affect the connection.
//
// Network may be any of the network types supported by net.Dial, but you almost
// certainly want to use one of the tcp connection types ("tcp", "tcp4", or
// "tcp6").
func DialConfig(ctx context.Context, network string, config *Config) (*Conn, error) {
	var d Dialer
	return d.DialConfig(ctx, network, config)
func (d *Dialer) dial(ctx context.Context, network string, addr *jid.JID) (net.Conn, error) {
	if d.NoLookup {
		p, err := lookupPort(network, connType(d.S2S))
		if err != nil {
			return nil, err
		}
		return d.Dialer.DialContext(ctx, network, net.JoinHostPort(
			addr.Domainpart(),
			strconv.FormatUint(uint64(p), 10),
		))
	}

	addrs, err := lookupService(connType(d.S2S), network, addr)
	if err != nil {
		return nil, err
	}

	// Try dialing all of the SRV records we know about, breaking as soon as the
	// connection is established.
	for _, addr := range addrs {
		conn, e := d.Dialer.DialContext(
			ctx, network, net.JoinHostPort(
				addr.Target, strconv.FormatUint(uint64(addr.Port), 10),
			),
		)
		if e != nil {
			err = e
			continue
		}

		return conn, nil
	}
	return nil, err
}

// Copied from the net package in the standard library. Copyright The Go


@@ 83,85 132,9 @@ func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Tim
	return minNonzeroTime(earliest, d.Deadline)
}

// Dial discovers and connects to the address on the named network that services
// the given local address with a client-to-server (c2s) connection.
//
// For a description of the arguments see the Dial function.
func (d *Dialer) Dial(ctx context.Context, network string, laddr *jid.JID) (*Conn, error) {
	c := NewClientConfig(laddr)
	return d.DialConfig(ctx, network, c)
}

// DialConfig connects to the address on the named network using the provided
// config.
//
// For a description of the arguments see the Dial function.
func (d *Dialer) DialConfig(ctx context.Context, network string, config *Config) (*Conn, error) {
	c, err := d.dial(ctx, network, config)
	if err != nil {
		return c, err
func connType(s2s bool) string {
	if s2s {
		return "xmpp-server"
	}

	return c, err
}

func (d *Dialer) dial(ctx context.Context, network string, config *Config) (*Conn, error) {
	if ctx == nil {
		panic("xmpp.Dial: nil context")
	}

	// If we haven't specified any stream features, set some default ones.
	// if config.Features == nil || len(config.Features) == 0 {
	// 	stls := StartTLS(config.TLSConfig != nil)
	// 	bind := BindResource()
	// 	username, password := config.Origin.Domain().String(), config.Secret
	// 	sasl := SASL(
	// 		sasl.Plain("",           username, "password"),
	// 		sasl.ScramSha256("",     username, "password"),
	// 		sasl.ScramSha256Plus("", username, "password"),
	// 		sasl.ScramSha1("",       username, "password"),
	// 		sasl.ScramSha1Plus("",   username, "password"),
	// 	)
	// 	config.Features = map[xml.Name]StreamFeature{
	// 		stls.Name: stls,
	// 		sasl.Name: sasl,
	// 		bind.Name: bind,
	// 	}
	// }

	if d.NoLookup {
		p, err := lookupPort(network, connType(config.S2S))
		if err != nil {
			return nil, err
		}
		conn, err := d.Dialer.DialContext(ctx, network, net.JoinHostPort(
			config.Location.Domainpart(),
			strconv.FormatUint(uint64(p), 10),
		))
		if err != nil {
			return nil, err
		}
		return NewConn(ctx, config, conn)
	}

	addrs, err := lookupService(connType(config.S2S), network, config.Location)
	if err != nil {
		return nil, err
	}

	// Try dialing all of the SRV records we know about, breaking as soon as the
	// connection is established.
	for _, addr := range addrs {
		if conn, e := d.Dialer.DialContext(
			ctx, network, net.JoinHostPort(
				addr.Target, strconv.FormatUint(uint64(addr.Port), 10),
			),
		); e != nil {
			err = e
			continue
		} else {
			return NewConn(ctx, config, conn)
		}
	}
	return nil, err
	return "xmpp-client"
}

M dial_test.go => dial_test.go +17 -1
@@ 16,5 16,21 @@ func TestDialClientPanicsIfNilContext(t *testing.T) {
			t.Error("Expected Dial to panic when passed a nil context.")
		}
	}()
	Dial(nil, "tcp", jid.MustParse("feste@shakespeare.lit"))
	DialClient(nil, "tcp", jid.MustParse("feste@shakespeare.lit"))
}

// The default value of config.conntype should return "xmpp-client"
func TestDefaultConnType(t *testing.T) {
	c := &Config{}
	if ct := connType(c.S2S); ct != "xmpp-client" {
		t.Errorf("Wrong default value for conntype; expected xmpp-client but got %s", ct)
	}
}

// If S2S is true, config.conntype should return "xmpp-server"
func TestS2SConnType(t *testing.T) {
	c := &Config{S2S: true}
	if ct := connType(c.S2S); ct != "xmpp-server" {
		t.Errorf("Wrong s2s value for conntype; expected xmpp-server but got %s", ct)
	}
}

M example_test.go => example_test.go +8 -3
@@ 34,14 34,19 @@ func Example_rawSendMessage() {

	log.Printf("Dialing upstream XMPP server as %s…\n", laddr)

	c, err := xmpp.DialConfig(context.Background(), "tcp", config)
	c, err := xmpp.DialClient(context.Background(), "tcp", laddr)
	if err != nil {
		log.Fatal(err)
	}

	s, err := xmpp.NewSession(context.Background(), config, c)
	if err != nil {
		log.Fatal(err)
	}

	log.Printf("Connected with JID `%s`\n", c.LocalAddr())

	err = c.Encoder().Encode(struct {
	err = s.Encoder().Encode(struct {
		xmpp.Message
		Body string `xml:"body"`
	}{


@@ 56,7 61,7 @@ func Example_rawSendMessage() {
		log.Fatal(err)
	}

	err = c.Encoder().Flush()
	err = s.Encoder().Flush()
	if err != nil {
		log.Fatal(err)
	}

M features.go => features.go +17 -17
@@ 61,10 61,10 @@ type StreamFeature struct {
	// is called. For instance, in the case of compression this data parameter
	// might be the list of supported algorithms as a slice of strings (or in
	// whatever format the feature implementation has decided upon).
	Negotiate func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error)
	Negotiate func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error)
}

func (c *Conn) negotiateFeatures(ctx context.Context) (done bool, rwc io.ReadWriteCloser, err error) {
func (c *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.ReadWriteCloser, err error) {
	server := (c.state & Received) == Received

	// If we're the server, write the initial stream features.


@@ 189,8 189,8 @@ type streamFeaturesList struct {
	cache map[string]sfData
}

func writeStreamFeatures(ctx context.Context, conn *Conn) (list *streamFeaturesList, err error) {
	e := conn.Encoder()
func writeStreamFeatures(ctx context.Context, s *Session) (list *streamFeaturesList, err error) {
	e := s.Encoder()

	start := xml.StartElement{Name: xml.Name{Space: "", Local: "stream:features"}}
	if err = e.EncodeToken(start); err != nil {


@@ 202,12 202,12 @@ func writeStreamFeatures(ctx context.Context, conn *Conn) (list *streamFeaturesL
		cache: make(map[string]sfData),
	}

	for _, feature := range conn.config.Features {
	for _, feature := range s.config.Features {
		// Check if all the necessary bits are set and none of the prohibited bits
		// are set.
		if (conn.state&feature.Necessary) == feature.Necessary && (conn.state&feature.Prohibited) == 0 {
		if (s.state&feature.Necessary) == feature.Necessary && (s.state&feature.Prohibited) == 0 {
			var r bool
			r, err = feature.List(ctx, conn.out.e, xml.StartElement{
			r, err = feature.List(ctx, s.out.e, xml.StartElement{
				Name: feature.Name,
			})
			if err != nil {


@@ 233,7 233,7 @@ func writeStreamFeatures(ctx context.Context, conn *Conn) (list *streamFeaturesL
	return
}

func readStreamFeatures(ctx context.Context, conn *Conn, start xml.StartElement) (*streamFeaturesList, error) {
func readStreamFeatures(ctx context.Context, s *Session, start xml.StartElement) (*streamFeaturesList, error) {
	switch {
	case start.Name.Local != "features":
		return nil, streamerror.InvalidXML


@@ 242,8 242,8 @@ func readStreamFeatures(ctx context.Context, conn *Conn, start xml.StartElement)
	}

	// Lock the connection features list.
	conn.flock.Lock()
	defer conn.flock.Unlock()
	s.flock.Lock()
	defer s.flock.Unlock()

	sf := &streamFeaturesList{
		cache: make(map[string]sfData),


@@ 251,7 251,7 @@ func readStreamFeatures(ctx context.Context, conn *Conn, start xml.StartElement)

parsefeatures:
	for {
		t, err := conn.in.d.Token()
		t, err := s.in.d.Token()
		if err != nil {
			return nil, err
		}


@@ 263,15 263,15 @@ parsefeatures:

			// Always add the feature to the list of features, even if we don't
			// support it.
			conn.features[tok.Name.Space] = nil
			s.features[tok.Name.Space] = nil

			if feature, ok := conn.config.Features[tok.Name]; ok && (conn.state&feature.Necessary) == feature.Necessary && (conn.state&feature.Prohibited) == 0 {
				req, data, err := feature.Parse(ctx, conn.in.d, &tok)
			if feature, ok := s.config.Features[tok.Name]; ok && (s.state&feature.Necessary) == feature.Necessary && (s.state&feature.Prohibited) == 0 {
				req, data, err := feature.Parse(ctx, s.in.d, &tok)
				if err != nil {
					return nil, err
				}

				// TODO: Since we're storing the features data on conn.features we can
				// TODO: Since we're storing the features data on s.features we can
				// probably remove it from this temporary cache.
				sf.cache[tok.Name.Space] = sfData{
					req:     req,


@@ 281,14 281,14 @@ parsefeatures:

				// Since we do support the feature, add it to the connections list along
				// with any data returned from Parse.
				conn.features[tok.Name.Space] = data
				s.features[tok.Name.Space] = data
				if req {
					sf.req = true
				}
				continue parsefeatures
			}
			// If the feature is not one we support, skip it.
			if err := conn.in.d.Skip(); err != nil {
			if err := s.in.d.Skip(); err != nil {
				return nil, err
			}
		case xml.EndElement:

M listen.go => listen.go +3 -3
@@ 34,7 34,7 @@ func (l *Listener) Accept() (net.Conn, error) {
		return nil, err
	}

	c := &Conn{conn: conn, state: Received, config: l.config}
	c := &Session{conn: conn, state: Received, config: l.config}

	// If the connection is a tls.Conn already, make sure we don't advertise
	// StartTLS.


@@ 46,9 46,9 @@ func (l *Listener) Accept() (net.Conn, error) {
}

// AcceptXMPP accepts the next incoming call and returns the new connection.
func (l *Listener) AcceptXMPP() (*Conn, error) {
func (l *Listener) AcceptXMPP() (*Session, error) {
	c, err := l.Accept()
	return c.(*Conn), err
	return c.(*Session), err
}

func (l *Listener) Addr() net.Addr {

M lookup.go => lookup.go +4 -4
@@ 106,14 106,14 @@ func LookupBOSH(ctx context.Context, client *http.Client, addr *jid.JID) (urls [
	return lookupEndpoint(ctx, client, addr, "bosh")
}

func validateConnTypeOrPanic(conntype string) {
func validateSessionTypeOrPanic(conntype string) {
	if conntype != "ws" && conntype != "bosh" {
		panic("xmpp.lookupEndpoint: Invalid conntype specified")
	}
}

func lookupEndpoint(ctx context.Context, client *http.Client, addr *jid.JID, conntype string) (urls []string, err error) {
	validateConnTypeOrPanic(conntype)
	validateSessionTypeOrPanic(conntype)

	var (
		u  []string


@@ 166,7 166,7 @@ func lookupEndpoint(ctx context.Context, client *http.Client, addr *jid.JID, con
// TODO(ssw): Rely on the OS DNS cache, or cache lookups ourselves?

func lookupDNS(ctx context.Context, name, conntype string) (urls []string, err error) {
	validateConnTypeOrPanic(conntype)
	validateSessionTypeOrPanic(conntype)
	select {
	case <-ctx.Done():
		return urls, ctx.Err()


@@ 203,7 203,7 @@ func lookupDNS(ctx context.Context, name, conntype string) (urls []string, err e
// TODO(ssw): Memoize the following functions?

func lookupHostMeta(ctx context.Context, client *http.Client, name, conntype string) (urls []string, err error) {
	validateConnTypeOrPanic(conntype)
	validateSessionTypeOrPanic(conntype)
	select {
	case <-ctx.Done():
		return urls, ctx.Err()

M message_test.go => message_test.go +4 -8
@@ 52,8 52,7 @@ func TestUnmarshalMessage(t *testing.T) {
	}
	err := xml.Unmarshal(mb, m)
	if err != nil {
		t.Log(err)
		t.Fail()
		t.Error(err)
	}

	if m.Type != ChatMessage {


@@ 61,16 60,13 @@ func TestUnmarshalMessage(t *testing.T) {
		t.Fail()
	}
	if m.To.String() != "romeo@example.net" {
		t.Logf("Expected %s but got %s", "romeo@example.net", m.To.String())
		t.Fail()
		t.Errorf("Expected %s but got %s", "romeo@example.net", m.To.String())
	}
	if m.To.String() != "romeo@example.net" {
		t.Logf("Expected %s but got %s", "romeo@example.net", m.To.String())
		t.Fail()
		t.Errorf("Expected %s but got %s", "romeo@example.net", m.To.String())
	}
	if m.ID != "ktx72v49" {
		t.Logf("Expected %s but got %s", "ktx72v49", m.To.String())
		t.Fail()
		t.Errorf("Expected %s but got %s", "ktx72v49", m.To.String())
	}
}


M sasl.go => sasl.go +1 -1
@@ 67,7 67,7 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
			err := d.DecodeElement(&parsed, start)
			return true, parsed.List, err
		},
		Negotiate: func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
		Negotiate: func(ctx context.Context, conn *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			if (conn.state & Received) == Received {
				panic("SASL server not yet implemented")
			}

R conn.go => session.go +52 -20
@@ 16,9 16,41 @@ import (
	"mellium.im/xmpp/jid"
)

// A Conn represents an XMPP connection that can perform SRV lookups for a given
// SessionState is a bitmask that represents the current state of an XMPP
// session. For a description of each bit, see the various SessionState typed
// constants.
type SessionState uint8

const (
	// Secure indicates that the underlying connection has been secured. For
	// instance, after STARTTLS has been performed or if a pre-secured connection
	// is being used such as websockets over HTTPS.
	Secure SessionState = 1 << iota

	// Authn indicates that the session has been authenticated (probably with
	// SASL).
	Authn

	// Ready indicates that the session is fully negotiated and that XMPP stanzas
	// may be sent and received.
	Ready

	// Received indicates that the session was initiated by a foreign entity.
	Received

	// OutputStreamClosed indicates that the output stream has been closed with a
	// stream end tag.  When set all write operations will return an error even if
	// the underlying TCP connection is still open.
	OutputStreamClosed

	// InputStreamClosed indicates that the input stream has been closed with a
	// stream end tag. When set all read operations will return an error.
	InputStreamClosed
)

// A Session represents an XMPP connection that can perform SRV lookups for a given
// server and connect to the correct ports.
type Conn struct {
type Session struct {
	config *Config
	rwc    io.ReadWriteCloser



@@ 55,7 87,7 @@ type Conn struct {
// Feature checks if a feature with the given namespace was advertised
// by the server for the current stream. If it was data will be the canonical
// representation of the feature as returned by the feature's Parse function.
func (c *Conn) Feature(namespace string) (data interface{}, ok bool) {
func (c *Session) Feature(namespace string) (data interface{}, ok bool) {
	c.flock.Lock()
	defer c.flock.Unlock()



@@ 64,12 96,12 @@ func (c *Conn) Feature(namespace string) (data interface{}, ok bool) {
	return
}

// NewConn attempts to use an existing connection (or any io.ReadWriteCloser) to
// NewSession attempts to use an existing connection (or any io.ReadWriteCloser) to
// negotiate an XMPP session based on the given config. If the provided context
// is canceled before stream negotiation is complete an error is returned. After
// stream negotiation if the context is canceled it has no effect.
func NewConn(ctx context.Context, config *Config, rwc io.ReadWriteCloser) (*Conn, error) {
	c := &Conn{
func NewSession(ctx context.Context, config *Config, rwc io.ReadWriteCloser) (*Session, error) {
	c := &Session{
		config: config,
	}



@@ 80,28 112,28 @@ func NewConn(ctx context.Context, config *Config, rwc io.ReadWriteCloser) (*Conn
	return c, c.negotiateStreams(ctx, rwc)
}

// Raw returns the Conn's backing net.Conn or other ReadWriteCloser.
func (c *Conn) Raw() io.ReadWriteCloser {
// Raw returns the Session's backing net.Conn or other ReadWriteCloser.
func (c *Session) Raw() io.ReadWriteCloser {
	return c.rwc
}

// Decoder returns the XML decoder that was used to negotiate the latest stream.
func (c *Conn) Decoder() *xml.Decoder {
func (c *Session) Decoder() *xml.Decoder {
	return c.in.d
}

// Encoder returns the XML encoder that was used to negotiate the latest stream.
func (c *Conn) Encoder() *xml.Encoder {
func (c *Session) Encoder() *xml.Encoder {
	return c.out.e
}

// Config returns the connections config.
func (c *Conn) Config() *Config {
func (c *Session) Config() *Config {
	return c.config
}

// Read reads data from the connection.
func (c *Conn) Read(b []byte) (n int, err error) {
func (c *Session) Read(b []byte) (n int, err error) {
	c.in.Lock()
	defer c.in.Unlock()



@@ 113,7 145,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}

// Write writes data to the connection.
func (c *Conn) Write(b []byte) (n int, err error) {
func (c *Session) Write(b []byte) (n int, err error) {
	c.out.Lock()
	defer c.out.Unlock()



@@ 126,19 158,19 @@ func (c *Conn) Write(b []byte) (n int, err error) {

// Close closes the underlying connection.
// Any blocked Read or Write operations will be unblocked and return errors.
func (c *Conn) Close() error {
func (c *Session) Close() error {
	return c.rwc.Close()
}

// State returns the current state of the session. For more information, see the
// SessionState type.
func (c *Conn) State() SessionState {
func (c *Session) State() SessionState {
	return c.state
}

// LocalAddr returns the Origin address for initiated connections, or the
// Location for received connections.
func (c *Conn) LocalAddr() net.Addr {
func (c *Session) LocalAddr() net.Addr {
	if (c.state & Received) == Received {
		return c.config.Location
	}


@@ 150,7 182,7 @@ func (c *Conn) LocalAddr() net.Addr {

// RemoteAddr returns the Location address for initiated connections, or the
// Origin address for received connections.
func (c *Conn) RemoteAddr() net.Addr {
func (c *Session) RemoteAddr() net.Addr {
	if (c.state & Received) == Received {
		return c.config.Origin
	}


@@ 170,7 202,7 @@ var errSetDeadline = errors.New("xmpp: cannot set deadline: not using a net.Conn
// successful Read or Write calls.
//
// A zero value for t means I/O operations will not time out.
func (c *Conn) SetDeadline(t time.Time) error {
func (c *Session) SetDeadline(t time.Time) error {
	if c.conn != nil {
		return c.conn.SetDeadline(t)
	}


@@ 179,7 211,7 @@ func (c *Conn) SetDeadline(t time.Time) error {

// SetReadDeadline sets the deadline for future Read calls. A zero value for t
// means Read will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
func (c *Session) SetReadDeadline(t time.Time) error {
	if c.conn != nil {
		return c.conn.SetReadDeadline(t)
	}


@@ 189,7 221,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
// SetWriteDeadline sets the deadline for future Write calls. Even if write
// times out, it may return n > 0, indicating that some of the data was
// successfully written. A zero value for t means Write will not time out.
func (c *Conn) SetWriteDeadline(t time.Time) error {
func (c *Session) SetWriteDeadline(t time.Time) error {
	if c.conn != nil {
		return c.conn.SetWriteDeadline(t)
	}

M starttls.go => starttls.go +1 -1
@@ 57,7 57,7 @@ func StartTLS(required bool) StreamFeature {
			err := d.DecodeElement(&parsed, start)
			return parsed.Required.XMLName.Local == "required" && parsed.Required.XMLName.Space == ns.StartTLS, nil, err
		},
		Negotiate: func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
		Negotiate: func(ctx context.Context, conn *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
			netconn, ok := conn.Raw().(net.Conn)
			if !ok {
				return mask, nil, ErrTLSUpgradeFailed

M starttls_test.go => starttls_test.go +4 -4
@@ 144,10 144,10 @@ func (dummyConn) SetWriteDeadline(t time.Time) error {

// We can't create a tls.Client or tls.Server for a generic RWC, so ensure that
// we fail (with a specific error) if this is the case.
func TestNegotiationFailsForNonNetConn(t *testing.T) {
func TestNegotiationFailsForNonNetSession(t *testing.T) {
	stls := StartTLS(true)
	var b bytes.Buffer
	_, _, err := stls.Negotiate(context.Background(), &Conn{rwc: nopRWC{&b, &b}}, nil)
	_, _, err := stls.Negotiate(context.Background(), &Session{rwc: nopRWC{&b, &b}}, nil)
	if err != ErrTLSUpgradeFailed {
		t.Errorf("Expected error `%v` but got `%v`", ErrTLSUpgradeFailed, err)
	}


@@ 156,7 156,7 @@ func TestNegotiationFailsForNonNetConn(t *testing.T) {
func TestNegotiateServer(t *testing.T) {
	stls := StartTLS(true)
	var b bytes.Buffer
	c := &Conn{state: Received, rwc: dummyConn{nopRWC{&b, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
	c := &Session{state: Received, rwc: dummyConn{nopRWC{&b, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
	_, rwc, err := stls.Negotiate(context.Background(), c, nil)
	switch {
	case err != nil:


@@ 194,7 194,7 @@ func TestNegotiateClient(t *testing.T) {
		stls := StartTLS(true)
		r := strings.NewReader(strings.Join(test.responses, "\n"))
		var b bytes.Buffer
		c := &Conn{rwc: dummyConn{nopRWC{r, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
		c := &Session{rwc: dummyConn{nopRWC{r, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
		c.in.d = xml.NewDecoder(c.rwc)
		mask, rwc, err := stls.Negotiate(context.Background(), c, nil)
		switch {

M stream.go => stream.go +5 -37
@@ 23,38 23,6 @@ const (

const streamIDLength = 16

// SessionState is a bitmask that represents the current state of an XMPP
// session. For a description of each bit, see the various SessionState typed
// constants.
type SessionState uint8

const (
	// Secure indicates that the underlying connection has been secured. For
	// instance, after STARTTLS has been performed or if a pre-secured connection
	// is being used such as websockets over HTTPS.
	Secure SessionState = 1 << iota

	// Authn indicates that the session has been authenticated (probably with
	// SASL).
	Authn

	// Ready indicates that the session is fully negotiated and that XMPP stanzas
	// may be sent and received.
	Ready

	// Received indicates that the session was initiated by a foreign entity.
	Received

	// OutputStreamClosed indicates that the output stream has been closed with a
	// stream end tag.  When set all write operations will return an error even if
	// the underlying TCP connection is still open.
	OutputStreamClosed

	// InputStreamClosed indicates that the input stream has been closed with a
	// stream end tag. When set all read operations will return an error.
	InputStreamClosed
)

type stream struct {
	to      *jid.JID
	from    *jid.JID


@@ 139,7 107,7 @@ func sendNewStream(w io.Writer, cfg *Config, id string) error {
		return err
	}

	if conn, ok := w.(*Conn); ok {
	if conn, ok := w.(*Session); ok {
		conn.out.stream = stream
	}
	return nil


@@ 148,9 116,9 @@ func sendNewStream(w io.Writer, cfg *Config, id string) error {
func expectNewStream(ctx context.Context, r io.Reader) error {
	var foundHeader bool

	// If the reader is a Conn, use its decoder, otherwise make a new one.
	// If the reader is a Session, use its decoder, otherwise make a new one.
	var d *xml.Decoder
	if conn, ok := r.(*Conn); ok {
	if conn, ok := r.(*Session); ok {
		d = conn.in.d
	} else {
		d = xml.NewDecoder(r)


@@ 188,7 156,7 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
				return streamerror.UnsupportedVersion
			}

			if conn, ok := r.(*Conn); ok {
			if conn, ok := r.(*Session); ok {
				if (conn.state&Received) != Received && stream.id == "" {
					// if we are the initiating entity and there is no stream ID…
					return streamerror.BadFormat


@@ 211,7 179,7 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
	}
}

func (c *Conn) negotiateStreams(ctx context.Context, rwc io.ReadWriteCloser) (err error) {
func (c *Session) negotiateStreams(ctx context.Context, rwc io.ReadWriteCloser) (err error) {
	// Loop for as long as we're not done negotiating features or a stream restart
	// is still required.
	for done := false; !done || rwc != nil; {