~samwhited/xmpp

8e7dd38c4e28fde6e683ffeedb8e40c102143f87 — Sam Whited 2 years ago a62964b
all: unexport and simplify Conn type
7 files changed, 52 insertions(+), 127 deletions(-)

M conn.go
D conn_test.go
M dial.go
M sasl.go
M sasl2/sasl.go
M session.go
M starttls_test.go
M conn.go => conn.go +25 -55
@@ 5,84 5,54 @@
package xmpp

import (
	"crypto/tls"
	"io"
	"net"
	"time"
)

// Conn is a net.Conn created for the purpose of establishing an XMPP session.
type Conn struct {
	tlsConn *tls.Conn
	c       net.Conn
	rw      io.ReadWriter
	closer  func() error
// conn is a net.Conn created for the purpose of establishing an XMPP session.
type conn struct {
	c  net.Conn
	rw io.ReadWriter
}

// newConn wraps an io.ReadWriter in a Conn.
// If rw is already a net.Conn or io.Closer the methods are proxied
// appropriately. If rw is a *tls.Conn then ConnectionState returns the
// appropriate value.
// If rw is already a *Conn, it is returned immediately.
func newConn(rw io.ReadWriter) *Conn {
	nc := &Conn{rw: rw}
	if closer, ok := rw.(io.Closer); ok {
		nc.closer = closer.Close
	}

	switch typrw := rw.(type) {
	case *Conn:
		return typrw
	case *tls.Conn:
		nc.tlsConn = typrw
		nc.c = typrw
	case net.Conn:
		nc.c = typrw
// If rw is already a net.Conn, it is returned without modification.
// If rw is not a conn but prev is, the various Conn methods that are not part
// of io.ReadWriter proxy through to prev.
func newConn(rw io.ReadWriter, prev net.Conn) net.Conn {
	if c, ok := rw.(net.Conn); ok {
		return c
	}

	nc := &conn{rw: rw, c: prev}
	return nc
}

// ConnectionState returns basic TLS details about the connection if TLS has
// been negotiated.
// If TLS has not been negotiated it returns a zero value tls.ConnectionState.
//
// To check if TLS has been negotiated, see the Secure method.
func (c *Conn) ConnectionState() tls.ConnectionState {
	if c.tlsConn == nil {
		return tls.ConnectionState{}
	}
	return c.tlsConn.ConnectionState()
}

// Secure returns whether the Conn is backed by an underlying tls.Conn.
// If Secure returns true, ConnectionState will proxy to the underlying tls.Conn
// instead of returning an empty connectiono state.
func (c *Conn) Secure() bool {
	return c.tlsConn != nil
}

// Close closes the connection.
func (c *Conn) Close() error {
	if c.closer == nil {
		return nil
func (c *conn) Close() error {
	if c.c != nil {
		return c.c.Close()
	}
	if closer, ok := c.rw.(io.Closer); ok {
		return closer.Close()
	}
	return c.closer()
	return nil
}

// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
func (c *conn) LocalAddr() net.Addr {
	return c.c.LocalAddr()
}

// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
func (c *Conn) Read(b []byte) (n int, err error) {
func (c *conn) Read(b []byte) (n int, err error) {
	return c.rw.Read(b)
}

// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
func (c *conn) RemoteAddr() net.Addr {
	return c.c.RemoteAddr()
}



@@ 90,13 60,13 @@ func (c *Conn) RemoteAddr() net.Addr {
// A zero value for t means Read and Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes
// will return the same error.
func (c *Conn) SetDeadline(t time.Time) error {
func (c *conn) SetDeadline(t time.Time) error {
	return c.c.SetDeadline(t)
}

// SetReadDeadline sets the read deadline on the underlying connection.
// A zero value for t means Read will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
func (c *conn) SetReadDeadline(t time.Time) error {
	return c.c.SetReadDeadline(t)
}



@@ 104,11 74,11 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
// A zero value for t means Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes
// will return the same error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
func (c *conn) SetWriteDeadline(t time.Time) error {
	return c.c.SetWriteDeadline(t)
}

// Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) {
func (c *conn) Write(b []byte) (int, error) {
	return c.rw.Write(b)
}

D conn_test.go => conn_test.go +0 -52
@@ 1,52 0,0 @@
// Copyright 2018 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp

import (
	"crypto/tls"
	"errors"
	"io"
	"strconv"
	"testing"
)

var errClose = errors.New("test close error")

type errCloser struct {
	io.ReadWriter
}

func (errCloser) Close() error {
	return errClose
}

var connTestCases = [...]struct {
	rw  io.ReadWriter
	err error
}{
	0: {rw: struct{ io.ReadWriter }{}},
	1: {rw: &tls.Conn{}},
	2: {rw: errCloser{}, err: errClose},
}

func TestConn(t *testing.T) {
	for i, tc := range connTestCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			conn := newConn(tc.rw)

			_, isTLSConn := tc.rw.(*tls.Conn)
			if ok := conn.Secure(); ok != isTLSConn {
				t.Errorf("TLS conn not wrapped properly: want=%t, got=%t", isTLSConn, ok)
			}

			// Don't run closer tests against dummy TLS connections that will panic.
			if !isTLSConn {
				if err := conn.Close(); err != tc.err {
					t.Errorf("Unexpected error closing conn: want=%q, got=%q", tc.err, err)
				}
			}
		})
	}
}

M dial.go => dial.go +5 -5
@@ 19,7 19,7 @@ import (
// client-to-server (c2s) connection.
//
// For more information see the Dialer type.
func DialClient(ctx context.Context, network string, addr jid.JID) (*Conn, error) {
func DialClient(ctx context.Context, network string, addr jid.JID) (net.Conn, error) {
	var d Dialer
	return d.Dial(ctx, network, addr)
}


@@ 28,7 28,7 @@ func DialClient(ctx context.Context, network string, addr jid.JID) (*Conn, error
// server-to-server connection (s2s).
//
// For more info see the Dialer type.
func DialServer(ctx context.Context, network string, addr jid.JID) (*Conn, error) {
func DialServer(ctx context.Context, network string, addr jid.JID) (net.Conn, error) {
	d := Dialer{
		S2S: true,
	}


@@ 80,11 80,11 @@ type Dialer struct {
// "tcp6").
//
// For more information see the Dialer type.
func (d *Dialer) Dial(ctx context.Context, network string, addr jid.JID) (*Conn, error) {
func (d *Dialer) Dial(ctx context.Context, network string, addr jid.JID) (net.Conn, error) {
	return d.dial(ctx, network, addr)
}

func (d *Dialer) dial(ctx context.Context, network string, addr jid.JID) (*Conn, error) {
func (d *Dialer) dial(ctx context.Context, network string, addr jid.JID) (net.Conn, error) {
	domain := addr.Domainpart()
	service := connType(!d.NoTLS, d.S2S)
	var addrs []*net.SRV


@@ 159,7 159,7 @@ func (d *Dialer) dial(ctx context.Context, network string, addr jid.JID) (*Conn,
			continue
		}

		return newConn(c), nil
		return c, nil
	}
	return nil, err
}

M sasl.go => sasl.go +8 -6
@@ 6,6 6,7 @@ package xmpp

import (
	"context"
	"crypto/tls"
	"encoding/xml"
	"errors"
	"fmt"


@@ 77,7 78,7 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) StreamFeature
				panic("SASL server not yet implemented")
			}

			conn := session.Conn()
			c := session.Conn()

			var selected sasl.Mechanism
			// Select a mechanism, preferring the client order.


@@ 101,8 102,9 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) StreamFeature
				}),
				sasl.RemoteMechanisms(data.([]string)...),
			}
			if conn.Secure() {
				opts = append(opts, sasl.TLSState(conn.ConnectionState()))

			if tlsConn, ok := c.(*tls.Conn); ok {
				opts = append(opts, sasl.TLSState(tlsConn.ConnectionState()))
			}
			client := sasl.NewClient(selected, opts...)



@@ 121,7 123,7 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) StreamFeature
			}

			// Send <auth/> and the initial payload to start SASL auth.
			if _, err = fmt.Fprintf(conn,
			if _, err = fmt.Fprintf(c,
				`<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='%s'>%s</auth>`,
				selected.Name, resp,
			); err != nil {


@@ 177,12 179,12 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) StreamFeature
					break
				}
				// TODO: What happens if there's more and success (broken server)?
				if _, err = fmt.Fprintf(conn,
				if _, err = fmt.Fprintf(c,
					`<response xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>%s</response>`, resp); err != nil {
					return mask, nil, err
				}
			}
			return Authn, conn, nil
			return Authn, c, nil
		},
	}
}

M sasl2/sasl.go => sasl2/sasl.go +5 -2
@@ 12,6 12,7 @@ package sasl2 // import "mellium.im/xmpp/sasl2"

import (
	"context"
	"crypto/tls"
	"encoding/xml"
	"errors"
	"fmt"


@@ 119,9 120,11 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) xmpp.StreamFe
				}),
				sasl.RemoteMechanisms(data.([]string)...),
			}
			if conn.Secure() {
				opts = append(opts, sasl.TLSState(conn.ConnectionState()))

			if tlsConn, ok := conn.(*tls.Conn); ok {
				opts = append(opts, sasl.TLSState(tlsConn.ConnectionState()))
			}

			client := sasl.NewClient(selected, opts...)

			// Calculate the initial response

M session.go => session.go +7 -5
@@ 6,9 6,11 @@ package xmpp

import (
	"context"
	"crypto/tls"
	"encoding/xml"
	"errors"
	"io"
	"net"
	"sync"

	"mellium.im/xmlstream"


@@ 59,7 61,7 @@ const (
// A Session represents an XMPP session comprising an input and an output XML
// stream.
type Session struct {
	conn *Conn
	conn net.Conn

	state SessionState
	slock sync.RWMutex


@@ 110,7 112,7 @@ func NegotiateSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
		panic("xmpp: attempted to negotiate session with nil negotiator")
	}
	s := &Session{
		conn:       newConn(rw),
		conn:       newConn(rw, nil),
		origin:     origin,
		location:   location,
		features:   make(map[string]interface{}),


@@ 123,7 125,7 @@ func NegotiateSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
	// If rw was already a *tls.Conn or some other Conn that is secure, go ahead
	// and mark the connection as secure so that we don't try to negotiate
	// StartTLS.
	if s.conn.Secure() {
	if _, ok := s.conn.(*tls.Conn); ok {
		s.state |= Secure
	}



@@ 140,7 142,7 @@ func NegotiateSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
		if rw != nil {
			s.features = make(map[string]interface{})
			s.negotiated = make(map[string]struct{})
			s.conn = newConn(rw)
			s.conn = newConn(rw, s.conn)
			s.in.d = xml.NewDecoder(s.conn)
			s.out.e = xml.NewEncoder(s.conn)
		}


@@ 369,7 371,7 @@ func (s *Session) Feature(namespace string) (data interface{}, ok bool) {
// This should almost never be read from or written to, but is useful during
// stream negotiation for wrapping the existing connection in a new layer (eg.
// compression or TLS).
func (s *Session) Conn() *Conn {
func (s *Session) Conn() net.Conn {
	return s.conn
}


M starttls_test.go => starttls_test.go +2 -2
@@ 128,7 128,7 @@ func (nopRWC) Close() error {
func TestNegotiateServer(t *testing.T) {
	stls := StartTLS(true, &tls.Config{})
	var b bytes.Buffer
	c := &Session{state: Received, conn: newConn(nopRWC{&b, &b})}
	c := &Session{state: Received, conn: newConn(nopRWC{&b, &b}, nil)}
	_, rw, err := stls.Negotiate(context.Background(), c, nil)
	switch {
	case err != nil:


@@ 167,7 167,7 @@ func TestNegotiateClient(t *testing.T) {
			stls := StartTLS(true, &tls.Config{})
			r := strings.NewReader(strings.Join(test.responses, "\n"))
			var b bytes.Buffer
			c := &Session{conn: newConn(nopRWC{r, &b})}
			c := &Session{conn: newConn(nopRWC{r, &b}, nil)}
			c.in.d = xml.NewDecoder(c.conn)
			mask, rw, err := stls.Negotiate(context.Background(), c, nil)
			switch {