~samwhited/xmpp

16f780a014326e6ada272c968e280873dafb6c4b — Sam Whited 4 years ago eee8e73
Session should not close the underlying Conn
8 files changed, 62 insertions(+), 69 deletions(-)

M bind.go
M features.go
M sasl.go
M session.go
M starttls.go
M starttls_test.go
M stream.go
M stream_test.go
M bind.go => bind.go +1 -1
@@ 45,7 45,7 @@ func BindResource() StreamFeature {
			}{}
			return true, nil, d.DecodeElement(&parsed, start)
		},
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			if (session.state & Received) == Received {
				panic("xmpp: bind not yet implemented")
			}

M features.go => features.go +12 -12
@@ 54,17 54,17 @@ type StreamFeature struct {
	// this feature creates a security layer (such as TLS) and performs
	// authentication, mask would be set to Authn|Secure, but if it does not
	// authenticate the connection it would just return Secure. If negotiate
	// returns a new io.ReadWriteCloser (probably wrapping the old conn.Conn()) the
	// returns a new io.ReadWriter (probably wrapping the old session.Conn()) the
	// stream will be restarted automatically after Negotiate returns using the
	// new RWC. If this is an initiated connection and the features List call
	// returned a value, that value is passed to the data parameter when Negotiate
	// is called. For instance, in the case of compression this data parameter
	// might be the list of supported algorithms as a slice of strings (or in
	// whatever format the feature implementation has decided upon).
	Negotiate func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error)
	// new ReadWriter. If this is an initiated connection and the features List
	// call returned a value, that value is passed to the data parameter when
	// Negotiate is called. For instance, in the case of compression this data
	// parameter might be the list of supported algorithms as a slice of strings
	// (or in whatever format the feature implementation has decided upon).
	Negotiate func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error)
}

func (s *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.ReadWriteCloser, err error) {
func (s *Session) negotiateFeatures(ctx context.Context) (done bool, rw io.ReadWriter, err error) {
	server := (s.state & Received) == Received

	// If we're the server, write the initial stream features.


@@ 129,7 129,7 @@ func (s *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.Read
			data, sent = list.cache[start.Name.Space]
			if !sent || negotiated {
				// TODO: What should we return here?
				return done, rwc, streamerror.PolicyViolation
				return done, rw, streamerror.PolicyViolation
			}
		} else {
			// If we're the client, iterate through the cached features and select one


@@ 159,7 159,7 @@ func (s *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.Read
			}
		}

		mask, rwc, err = data.feature.Negotiate(ctx, s, data.data)
		mask, rw, err = data.feature.Negotiate(ctx, s, data.data)
		if err == nil {
			s.state |= mask
		}


@@ 167,12 167,12 @@ func (s *Session) negotiateFeatures(ctx context.Context) (done bool, rwc io.Read

		// If we negotiated a required feature or a stream restart is required
		// we're done with this feature set.
		if rwc != nil || data.req {
		if rw != nil || data.req {
			break
		}
	}

	return !list.req || (s.state&Ready == Ready), rwc, err
	return !list.req || (s.state&Ready == Ready), rw, err
}

type sfData struct {

M sasl.go => sasl.go +1 -1
@@ 66,7 66,7 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
			err := d.DecodeElement(&parsed, start)
			return true, parsed.List, err
		},
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			if (session.state & Received) == Received {
				panic("SASL server not yet implemented")
			}

M session.go => session.go +18 -18
@@ 47,15 47,15 @@ const (
	InputStreamClosed
)

// A Session represents an XMPP connection that can perform SRV lookups for a given
// server and connect to the correct ports.
// A Session represents an XMPP connection that can perform SRV lookups for a
// given server and connect to the correct ports.
type Session struct {
	config *Config
	rwc    io.ReadWriteCloser

	// If the initial rwc is a conn, save a reference to that as well so that we
	// can set deadlines on it later even if the rwc is upgraded.
	// If the initial ReadWriter is a conn, save a reference to that as well so
	// that we can use it directly without type casting constantly.
	conn net.Conn
	rw   io.ReadWriter

	state SessionState



@@ 99,21 99,22 @@ func (s *Session) Feature(namespace string) (data interface{}, ok bool) {
// negotiate an XMPP session based on the given config. If the provided context
// is canceled before stream negotiation is complete an error is returned. After
// stream negotiation if the context is canceled it has no effect.
func NewSession(ctx context.Context, config *Config, rwc io.ReadWriteCloser) (*Session, error) {
func NewSession(ctx context.Context, config *Config, rw io.ReadWriter) (*Session, error) {
	s := &Session{
		config: config,
		rw:     rw,
	}

	if conn, ok := rwc.(net.Conn); ok {
	if conn, ok := rw.(net.Conn); ok {
		s.conn = conn
	}

	return s, s.negotiateStreams(ctx, rwc)
	return s, s.negotiateStreams(ctx, rw)
}

// Conn returns the Session's backing net.Conn or other ReadWriteCloser.
func (s *Session) Conn() io.ReadWriteCloser {
	return s.rwc
// Conn returns the Session's backing net.Conn or other ReadWriter.
func (s *Session) Conn() io.ReadWriter {
	return s.rw
}

// Decoder returns the XML decoder that was used to negotiate the latest stream.


@@ 139,7 140,8 @@ func (s *Session) read(b []byte) (n int, err error) {
		return 0, errors.New("XML input stream is closed")
	}

	return s.rwc.Read(b)
	n, err = s.rw.Read(b)
	return
}

func (s *Session) write(b []byte) (n int, err error) {


@@ 150,18 152,16 @@ func (s *Session) write(b []byte) (n int, err error) {
		return 0, errors.New("XML output stream is closed")
	}

	return s.rwc.Write(b)
	n, err = s.rw.Write(b)
	return
}

// Close ends the output stream and blocks until the remote client closes the
// input stream.
func (s *Session) Close() (err error) {
	_, err = s.write([]byte(`</stream:stream>`))
	if err != nil {
		return
	}
	// TODO: Block until input stream is closed?
	return s.rwc.Close()
	_, err = s.write([]byte(`</stream:stream>`))
	return
}

// State returns the current state of the session. For more information, see the

M starttls.go => starttls.go +3 -3
@@ 55,7 55,7 @@ func StartTLS(required bool) StreamFeature {
			err := d.DecodeElement(&parsed, start)
			return parsed.Required.XMLName.Local == "required" && parsed.Required.XMLName.Space == ns.StartTLS, nil, err
		},
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) {
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			conn := session.conn
			if conn == nil {
				return mask, nil, ErrTLSUpgradeFailed


@@ 73,7 73,7 @@ func StartTLS(required bool) StreamFeature {

			if (session.state & Received) == Received {
				fmt.Fprint(conn, `<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)
				rwc = tls.Server(conn, tlsconf)
				rw = tls.Server(conn, tlsconf)
			} else {
				// Select starttls for negotiation.
				fmt.Fprint(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)


@@ 93,7 93,7 @@ func StartTLS(required bool) StreamFeature {
						if err = session.in.d.Skip(); err != nil {
							return mask, nil, streamerror.InvalidXML
						}
						rwc = tls.Client(conn, tlsconf)
						rw = tls.Client(conn, tlsconf)
					case tok.Name.Local == "failure":
						// Skip the </failure> token.
						if err = session.in.d.Skip(); err != nil {

M starttls_test.go => starttls_test.go +15 -15
@@ 144,12 144,12 @@ func (dummyConn) SetWriteDeadline(t time.Time) error {
	return nil
}

// We can't create a tls.Client or tls.Server for a generic RWC, so ensure that
// we fail (with a specific error) if this is the case.
// We can't create a tls.Client or tls.Server for a generic ReadWriter, so
// ensure that we fail (with a specific error) if this is the case.
func TestNegotiationFailsForNonNetSession(t *testing.T) {
	stls := StartTLS(true)
	var b bytes.Buffer
	_, _, err := stls.Negotiate(context.Background(), &Session{rwc: nopRWC{&b, &b}}, nil)
	_, _, err := stls.Negotiate(context.Background(), &Session{rw: nopRWC{&b, &b}}, nil)
	if err != ErrTLSUpgradeFailed {
		t.Errorf("Expected error `%v` but got `%v`", ErrTLSUpgradeFailed, err)
	}


@@ 159,13 159,13 @@ func TestNegotiateServer(t *testing.T) {
	stls := StartTLS(true)
	var b bytes.Buffer
	c := &Session{state: Received, conn: dummyConn{nopRWC{&b, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
	c.rwc = c.conn
	_, rwc, err := stls.Negotiate(context.Background(), c, nil)
	c.rw = c.conn
	_, rw, err := stls.Negotiate(context.Background(), c, nil)
	switch {
	case err != nil:
		t.Fatal(err)
	case rwc == nil:
		t.Fatal("Expected a new RWC when negotiating STARTTLS as a server")
	case rw == nil:
		t.Fatal("Expected a new ReadWriter when negotiating STARTTLS as a server")
	}

	// The server should send a proceed element.


@@ 182,7 182,7 @@ func TestNegotiateClient(t *testing.T) {
	for _, test := range []struct {
		responses []string
		err       bool
		rwc       bool
		rw        bool
		state     SessionState
	}{
		{[]string{`<proceed xmlns="badns"/>`}, true, false, Secure},


@@ 198,9 198,9 @@ func TestNegotiateClient(t *testing.T) {
		r := strings.NewReader(strings.Join(test.responses, "\n"))
		var b bytes.Buffer
		c := &Session{conn: dummyConn{nopRWC{r, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
		c.rwc = c.conn
		c.in.d = xml.NewDecoder(c.rwc)
		mask, rwc, err := stls.Negotiate(context.Background(), c, nil)
		c.rw = c.conn
		c.in.d = xml.NewDecoder(c.rw)
		mask, rw, err := stls.Negotiate(context.Background(), c, nil)
		switch {
		case test.err && err == nil:
			t.Error("Expected an error from starttls client negotiation")


@@ 214,10 214,10 @@ func TestNegotiateClient(t *testing.T) {
			t.Errorf("Expected client to send starttls element but got `%s`", b.String())
		case test.state != mask:
			t.Errorf("Expected session state mask %v but got %v", test.state, mask)
		case test.rwc && rwc == nil:
			t.Error("Expected a new RWC when negotiating STARTTLS as a client")
		case !test.rwc && rwc != nil:
			t.Error("Did not expect a new RWC when negotiating STARTTLS as a client")
		case test.rw && rw == nil:
			t.Error("Expected a new ReadWriter when negotiating STARTTLS as a client")
		case !test.rw && rw != nil:
			t.Error("Did not expect a new ReadWriter when negotiating STARTTLS as a client")
		}
	}
}

M stream.go => stream.go +8 -8
@@ 169,17 169,17 @@ func expectNewStream(ctx context.Context, s *Session) error {
	}
}

func (s *Session) negotiateStreams(ctx context.Context, rwc io.ReadWriteCloser) (err error) {
func (s *Session) negotiateStreams(ctx context.Context, rw io.ReadWriter) (err error) {
	// Loop for as long as we're not done negotiating features or a stream restart
	// is still required.
	for done := false; !done || rwc != nil; {
		if rwc != nil {
	for done := false; !done || rw != nil; {
		if rw != nil {
			s.features = make(map[string]interface{})
			s.negotiated = make(map[string]struct{})
			s.rwc = rwc
			s.in.d = xml.NewDecoder(s.rwc)
			s.out.e = xml.NewEncoder(s.rwc)
			rwc = nil
			s.rw = rw
			s.in.d = xml.NewDecoder(s.rw)
			s.out.e = xml.NewEncoder(s.rw)
			rw = nil

			if (s.state & Received) == Received {
				// If we're the receiving entity wait for a new stream, then send one in


@@ 202,7 202,7 @@ func (s *Session) negotiateStreams(ctx context.Context, rwc io.ReadWriteCloser) 
			}
		}

		if done, rwc, err = s.negotiateFeatures(ctx); err != nil {
		if done, rw, err = s.negotiateFeatures(ctx); err != nil {
			return err
		}
	}

M stream_test.go => stream_test.go +4 -11
@@ 8,7 8,6 @@ import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"strings"
	"testing"



@@ 36,7 35,7 @@ func TestSendNewS2S(t *testing.T) {
				ids = "abc"
			}
			config.S2S = tc.s2s
			err := sendNewStream(&Session{rwc: rwNopCloser{&b}}, config, ids)
			err := sendNewStream(&Session{rw: &b}, config, ids)

			str := b.String()
			if !strings.HasPrefix(str, xmlHeader) {


@@ 80,19 79,13 @@ func (nopReader) Read(p []byte) (n int, err error) {

func TestSendNewS2SReturnsWriteErr(t *testing.T) {
	config := NewClientConfig(jid.MustParse("test@example.net"))
	if err := sendNewStream(&Session{rwc: struct {
		io.ReadCloser
	if err := sendNewStream(&Session{rw: struct {
		io.Reader
		io.Writer
	}{
		ioutil.NopCloser(nopReader{}),
		nopReader{},
		errWriter{},
	}}, config, "abc"); err != io.ErrUnexpectedEOF {
		t.Errorf("Expected errWriterErr (%s) but got `%s`", io.ErrUnexpectedEOF, err)
	}
}

type rwNopCloser struct {
	io.ReadWriter
}

func (rwNopCloser) Close() error { return nil }