~samwhited/xmpp

2f75337c31e8f9fabb66145f1402ab427dde4e09 — Sam Whited 2 years ago 45f5e50
xmpp: fix proxying of Close on wrapped io.Closer's
2 files changed, 64 insertions(+), 2 deletions(-)

M conn.go
A conn_test.go
M conn.go => conn.go +12 -2
@@ 16,12 16,19 @@ type Conn struct {
	tlsConn *tls.Conn
	c       net.Conn
	rw      io.ReadWriter
	close   func() error
	closer  func() error
}

// newConn wraps an io.ReadWriter in a Conn.
// If rw is already a net.Conn or io.Closer the methods are proxied
// appropriately. If rw is a *tls.Conn then ConnectionState returns the
// appropriate value.
// If rw is already a *Conn, it is returned immediately.
func newConn(rw io.ReadWriter) *Conn {
	nc := &Conn{rw: rw}
	if closer, ok := rw.(io.Closer); ok {
		nc.closer = closer.Close
	}

	switch typrw := rw.(type) {
	case *Conn:


@@ 47,7 54,10 @@ func (c *Conn) ConnectionState() (connState tls.ConnectionState, ok bool) {

// Close closes the connection.
func (c *Conn) Close() error {
	return c.c.Close()
	if c.closer == nil {
		return nil
	}
	return c.closer()
}

// LocalAddr returns the local network address.

A conn_test.go => conn_test.go +52 -0
@@ 0,0 1,52 @@
// Copyright 2018 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp

import (
	"crypto/tls"
	"errors"
	"io"
	"strconv"
	"testing"
)

var closeErr = errors.New("test close error")

type errCloser struct {
	io.ReadWriter
}

func (errCloser) Close() error {
	return closeErr
}

var connTestCases = [...]struct {
	rw  io.ReadWriter
	err error
}{
	0: {rw: struct{ io.ReadWriter }{}},
	1: {rw: &tls.Conn{}},
	2: {rw: errCloser{}, err: closeErr},
}

func TestConn(t *testing.T) {
	for i, tc := range connTestCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			conn := newConn(tc.rw)

			_, isTLSConn := tc.rw.(*tls.Conn)
			if _, ok := conn.ConnectionState(); ok != isTLSConn {
				t.Errorf("TLS conn not wrapped properly: want=%t, got=%t", isTLSConn, ok)
			}

			// Don't run closer tests against dummy TLS connections that will panic.
			if !isTLSConn {
				if err := conn.Close(); err != tc.err {
					t.Errorf("Unexpected error closing conn: want=%q, got=%q", tc.err, err)
				}
			}
		})
	}
}