~samwhited/xmpp

18e80dff43d31c3a153b5ed778a0bd6e53d08489 — Sam Whited 2 years ago 8390ed7
xmpp: make *Conn ConnectionState match *tls.Conn
4 files changed, 19 insertions(+), 9 deletions(-)

M conn.go
M conn_test.go
M sasl.go
M session.go
M conn.go => conn.go +15 -5
@@ 44,12 44,22 @@ func newConn(rw io.ReadWriter) *Conn {
}

// ConnectionState returns basic TLS details about the connection if TLS has
// been negotiated. If TLS has not been negotiated, ok is false.
func (c *Conn) ConnectionState() (connState tls.ConnectionState, ok bool) {
	if c.tlsConn != nil {
		return c.tlsConn.ConnectionState(), true
// been negotiated.
// If TLS has not been negotiated it returns a zero value tls.ConnectionState.
//
// To check if TLS has been negotiated, see the Secure method.
func (c *Conn) ConnectionState() tls.ConnectionState {
	if c.tlsConn == nil {
		return tls.ConnectionState{}
	}
	return connState, false
	return c.tlsConn.ConnectionState()
}

// Secure returns whether the Conn is backed by an underlying tls.Conn.
// If Secure returns true, ConnectionState will proxy to the underlying tls.Conn
// instead of returning an empty connectiono state.
func (c *Conn) Secure() bool {
	return c.tlsConn != nil
}

// Close closes the connection.

M conn_test.go => conn_test.go +1 -1
@@ 37,7 37,7 @@ func TestConn(t *testing.T) {
			conn := newConn(tc.rw)

			_, isTLSConn := tc.rw.(*tls.Conn)
			if _, ok := conn.ConnectionState(); ok != isTLSConn {
			if ok := conn.Secure(); ok != isTLSConn {
				t.Errorf("TLS conn not wrapped properly: want=%t, got=%t", isTLSConn, ok)
			}


M sasl.go => sasl.go +2 -2
@@ 101,8 101,8 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) StreamFeature
				}),
				sasl.RemoteMechanisms(data.([]string)...),
			}
			if connState, ok := conn.ConnectionState(); ok {
				opts = append(opts, sasl.TLSState(connState))
			if conn.Secure() {
				opts = append(opts, sasl.TLSState(conn.ConnectionState()))
			}
			client := sasl.NewClient(selected, opts...)


M session.go => session.go +1 -1
@@ 123,7 123,7 @@ func NegotiateSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
	// If rw was already a *tls.Conn or some other Conn that is secure, go ahead
	// and mark the connection as secure so that we don't try to negotiate
	// StartTLS.
	if _, ok := s.conn.ConnectionState(); ok {
	if s.conn.Secure() {
		s.state |= Secure
	}