~samwhited/xmpp

1ca7f04257cfcc11a193a9c97c95c62072dd7be9 — Sam Whited 3 years ago 3ee9e06
all: always wrap ReadWriter's in a Conn
6 files changed, 116 insertions(+), 85 deletions(-)

M dial.go
M sasl.go
M sasl2/sasl.go
M session.go
M starttls.go
M starttls_test.go
M dial.go => dial.go +96 -7
@@ 6,6 6,8 @@ package xmpp

import (
	"context"
	"crypto/tls"
	"io"
	"net"
	"strconv"
	"time"


@@ 14,8 16,87 @@ import (
	"mellium.im/xmpp/jid"
)

// newConn wraps an io.ReadWriter in a Conn.
func newConn(rw io.ReadWriter) *Conn {
	nc := &Conn{}

	switch typrw := rw.(type) {
	case *Conn:
		return typrw
	case *tls.Conn:
		nc.tlsConn = typrw
		nc.c = typrw
	case net.Conn:
		nc.c = typrw
	}
	nc.rw = rw

	return nc
}

// Conn is a net.Conn created for the purpose of establishing an XMPP session.
type Conn net.Conn
type Conn struct {
	tlsConn *tls.Conn
	c       net.Conn
	rw      io.ReadWriter
}

// ConnectionState returns basic TLS details about the connection if TLS has
// been negotiated. If TLS has not been negotiated, ok is false.
func (c *Conn) ConnectionState() (connState tls.ConnectionState, ok bool) {
	if c.tlsConn != nil {
		return c.tlsConn.ConnectionState(), true
	}
	return connState, false
}

// Close closes the connection.
func (c *Conn) Close() error {
	return c.c.Close()
}

// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
	return c.c.LocalAddr()
}

// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
func (c *Conn) Read(b []byte) (n int, err error) {
	return c.rw.Read(b)
}

// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
	return c.c.RemoteAddr()
}

// SetDeadline sets the read and write deadlines associated with the connection.
// A zero value for t means Read and Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes
// will return the same error.
func (c *Conn) SetDeadline(t time.Time) error {
	return c.c.SetDeadline(t)
}

// SetReadDeadline sets the read deadline on the underlying connection.
// A zero value for t means Read will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return c.c.SetReadDeadline(t)
}

// SetWriteDeadline sets the write deadline on the underlying connection.
// A zero value for t means Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes
// will return the same error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	return c.c.SetWriteDeadline(t)
}

// Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) {
	return c.rw.Write(b)
}

// DialClient discovers and connects to the address on the named network with a
// client-to-server (c2s) connection.


@@ 31,7 112,7 @@ type Conn net.Conn
// Network may be any of the network types supported by net.Dial, but you almost
// certainly want to use one of the tcp connection types ("tcp", "tcp4", or
// "tcp6").
func DialClient(ctx context.Context, network string, addr *jid.JID) (Conn, error) {
func DialClient(ctx context.Context, network string, addr *jid.JID) (*Conn, error) {
	var d Dialer
	return d.Dial(ctx, network, addr)
}


@@ 40,7 121,7 @@ func DialClient(ctx context.Context, network string, addr *jid.JID) (Conn, error
// server-to-server connection (s2s).
//
// For more info see the DialClient function.
func DialServer(ctx context.Context, network string, addr *jid.JID) (Conn, error) {
func DialServer(ctx context.Context, network string, addr *jid.JID) (*Conn, error) {
	d := Dialer{
		S2S: true,
	}


@@ 67,20 148,26 @@ type Dialer struct {
// Dial discovers and connects to the address on the named network.
//
// For a description of the arguments see the DialClient function.
func (d *Dialer) Dial(ctx context.Context, network string, addr *jid.JID) (Conn, error) {
func (d *Dialer) Dial(ctx context.Context, network string, addr *jid.JID) (*Conn, error) {
	return d.dial(ctx, network, addr)
}

func (d *Dialer) dial(ctx context.Context, network string, addr *jid.JID) (Conn, error) {
func (d *Dialer) dial(ctx context.Context, network string, addr *jid.JID) (*Conn, error) {
	if d.NoLookup {
		p, err := internal.LookupPort(network, connType(d.S2S))
		if err != nil {
			return nil, err
		}
		return d.Dialer.DialContext(ctx, network, net.JoinHostPort(
		c, err := d.Dialer.DialContext(ctx, network, net.JoinHostPort(
			addr.Domainpart(),
			strconv.FormatUint(uint64(p), 10),
		))
		if err != nil {
			return nil, err
		}
		return &Conn{
			c: c,
		}, nil
	}

	addrs, err := internal.LookupService(connType(d.S2S), network, addr)


@@ 101,7 188,9 @@ func (d *Dialer) dial(ctx context.Context, network string, addr *jid.JID) (Conn,
			continue
		}

		return conn, nil
		return &Conn{
			c: conn,
		}, nil
	}
	return nil, err
}

M sasl.go => sasl.go +2 -3
@@ 6,7 6,6 @@ package xmpp

import (
	"context"
	"crypto/tls"
	"encoding/xml"
	"errors"
	"fmt"


@@ 96,8 95,8 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
				sasl.Credentials(session.LocalAddr().Localpart(), c.Password),
				sasl.RemoteMechanisms(data.([]string)...),
			}
			if tlsconn, ok := conn.(*tls.Conn); ok {
				opts = append(opts, sasl.ConnState(tlsconn.ConnectionState()))
			if connState, ok := conn.ConnectionState(); ok {
				opts = append(opts, sasl.ConnState(connState))
			}
			client := sasl.NewClient(selected, opts...)


M sasl2/sasl.go => sasl2/sasl.go +2 -3
@@ 12,7 12,6 @@ package sasl2 // import "mellium.im/xmpp/sasl2"

import (
	"context"
	"crypto/tls"
	"encoding/xml"
	"errors"
	"fmt"


@@ 115,8 114,8 @@ func SASL(mechanisms ...sasl.Mechanism) xmpp.StreamFeature {
				sasl.Credentials(session.LocalAddr().Localpart(), c.Password),
				sasl.RemoteMechanisms(data.([]string)...),
			}
			if tlsconn, ok := conn.(*tls.Conn); ok {
				opts = append(opts, sasl.ConnState(tlsconn.ConnectionState()))
			if connState, ok := conn.ConnectionState(); ok {
				opts = append(opts, sasl.ConnState(connState))
			}
			client := sasl.NewClient(selected, opts...)


M session.go => session.go +10 -20
@@ 8,7 8,6 @@ import (
	"context"
	"encoding/xml"
	"io"
	"net"
	"sync"

	"mellium.im/xmlstream"


@@ 56,10 55,7 @@ const (
type Session struct {
	config *Config

	// If the initial ReadWriter is a conn, save a reference to that as well so
	// that we can use it directly without type casting constantly.
	conn net.Conn
	rw   io.ReadWriter
	conn *Conn

	state SessionState
	slock sync.RWMutex


@@ 113,15 109,12 @@ func NegotiateSession(ctx context.Context, config *Config, rw io.ReadWriter, neg
	}
	s := &Session{
		config:     config,
		rw:         rw,
		conn:       newConn(rw),
		features:   make(map[string]interface{}),
		negotiated: make(map[string]struct{}),
	}
	if conn, ok := rw.(net.Conn); ok {
		s.conn = conn
	}
	s.in.d = xml.NewDecoder(s.rw)
	s.out.e = xml.NewEncoder(s.rw)
	s.in.d = xml.NewDecoder(s.conn)
	s.out.e = xml.NewEncoder(s.conn)
	s.in.ctx, s.in.cancel = context.WithCancel(context.Background())

	// Call negotiate until the ready bit is set.


@@ 137,12 130,9 @@ func NegotiateSession(ctx context.Context, config *Config, rw io.ReadWriter, neg
		if rw != nil {
			s.features = make(map[string]interface{})
			s.negotiated = make(map[string]struct{})
			s.rw = rw
			s.in.d = xml.NewDecoder(s.rw)
			s.out.e = xml.NewEncoder(s.rw)
			if conn, ok := rw.(net.Conn); ok {
				s.conn = conn
			}
			s.conn = newConn(rw)
			s.in.d = xml.NewDecoder(s.conn)
			s.out.e = xml.NewEncoder(s.conn)
		}
		s.state |= mask
	}


@@ 198,13 188,13 @@ func (s *Session) Feature(namespace string) (data interface{}, ok bool) {
	return
}

// Conn returns the Session's backing net.Conn or other io.ReadWriter.
// Conn returns the Session's backing connection.
//
// This should almost never be read from or written to, but is useful during
// stream negotiation for wrapping the existing connection in a new layer (eg.
// compression or TLS).
func (s *Session) Conn() io.ReadWriter {
	return s.rw
func (s *Session) Conn() *Conn {
	return s.conn
}

// Token satisfies the xml.TokenReader interface for Session.

M starttls.go => starttls.go +1 -6
@@ 11,7 11,6 @@ import (
	"errors"
	"fmt"
	"io"
	"net"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal/ns"


@@ 58,11 57,7 @@ func StartTLS(required bool, cfg *tls.Config) StreamFeature {
			return parsed.Required.XMLName.Local == "required" && parsed.Required.XMLName.Space == ns.StartTLS, nil, err
		},
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			conn, ok := session.Conn().(net.Conn)
			if !ok || conn == nil {
				return mask, nil, ErrTLSUpgradeFailed
			}

			conn := session.Conn()
			state := session.State()
			d := xml.NewTokenDecoder(session)


M starttls_test.go => starttls_test.go +5 -46
@@ 1,6 1,6 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp



@@ 10,10 10,8 @@ import (
	"crypto/tls"
	"encoding/xml"
	"io"
	"net"
	"strings"
	"testing"
	"time"

	"mellium.im/xmpp/internal/ns"
)


@@ 118,48 116,10 @@ func (nopRWC) Close() error {
	return nil
}

var _ net.Conn = dummyConn{}

type dummyConn struct {
	io.ReadWriteCloser
}

func (dummyConn) LocalAddr() net.Addr {
	return nil
}

func (dummyConn) RemoteAddr() net.Addr {
	return nil
}

func (dummyConn) SetDeadline(t time.Time) error {
	return nil
}

func (dummyConn) SetReadDeadline(t time.Time) error {
	return nil
}

func (dummyConn) SetWriteDeadline(t time.Time) error {
	return nil
}

// We can't create a tls.Client or tls.Server for a generic ReadWriter, so
// ensure that we fail (with a specific error) if this is the case.
func TestNegotiationFailsForNonNetSession(t *testing.T) {
	stls := StartTLS(true, nil)
	var b bytes.Buffer
	_, _, err := stls.Negotiate(context.Background(), &Session{rw: nopRWC{&b, &b}}, nil)
	if err != ErrTLSUpgradeFailed {
		t.Errorf("Expected error `%v` but got `%v`", ErrTLSUpgradeFailed, err)
	}
}

func TestNegotiateServer(t *testing.T) {
	stls := StartTLS(true, &tls.Config{})
	var b bytes.Buffer
	c := &Session{state: Received, conn: dummyConn{nopRWC{&b, &b}}}
	c.rw = c.conn
	c := &Session{state: Received, conn: newConn(nopRWC{&b, &b})}
	_, rw, err := stls.Negotiate(context.Background(), c, nil)
	switch {
	case err != nil:


@@ 197,9 157,8 @@ func TestNegotiateClient(t *testing.T) {
		stls := StartTLS(true, &tls.Config{})
		r := strings.NewReader(strings.Join(test.responses, "\n"))
		var b bytes.Buffer
		c := &Session{conn: dummyConn{nopRWC{r, &b}}}
		c.rw = c.conn
		c.in.d = xml.NewDecoder(c.rw)
		c := &Session{conn: newConn(nopRWC{r, &b})}
		c.in.d = xml.NewDecoder(c.conn)
		mask, rw, err := stls.Negotiate(context.Background(), c, nil)
		switch {
		case test.err && err == nil: