~samwhited/xmpp

76c24b3f547857bdef3334e3a2a4d2695b36c6dc — Sam Whited 3 months ago e7a6fe8
all: new stream config API

This is am more flexible API that lets us change stream properties such
as the language or advertised features list based on previous feature
negotiation.

Fixes #106

Signed-off-by: Sam Whited <sam@samwhited.com>
M doc.go => doc.go +5 -3
@@ 57,9 57,11 @@
//     …
//     session, err := xmpp.NewSession(
//         context.TODO(), addr.Domain(), addr, conn, xmpp.Secure,
//         xmpp.NewNegotiator(xmpp.StreamConfig{
//             Lang: "en",
//             …
//         xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
//             return xmpp.StreamConfig{
//                 Lang: "en",
//                 …
//             }
//         }),
//     )
//

M examples/echobot/echo.go => examples/echobot/echo.go +8 -11
@@ 38,23 38,20 @@ func echo(ctx context.Context, addr, pass string, xmlIn, xmlOut io.Writer, logge
		return fmt.Errorf("Error dialing sesion: %w", err)
	}

	s, err := xmpp.NewSession(ctx, j.Domain(), j, conn, 0, xmpp.NewNegotiator(xmpp.StreamConfig{
		Lang: "en",
		Features: func(_ *xmpp.Session, f ...xmpp.StreamFeature) []xmpp.StreamFeature {
			if f != nil {
				return f
			}
			return []xmpp.StreamFeature{
	s, err := xmpp.NewSession(ctx, j.Domain(), j, conn, 0, xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Lang: "en",
			Features: []xmpp.StreamFeature{
				xmpp.BindResource(),
				xmpp.StartTLS(&tls.Config{
					ServerName: j.Domain().String(),
					MinVersion: tls.VersionTLS12,
				}),
				xmpp.SASL("", pass, sasl.ScramSha1Plus, sasl.ScramSha1, sasl.Plain),
			}
		},
		TeeIn:  xmlIn,
		TeeOut: xmlOut,
			},
			TeeIn:  xmlIn,
			TeeOut: xmlOut,
		}
	}))
	if err != nil {
		return fmt.Errorf("Error establishing a session: %w", err)

M examples/im/main.go => examples/im/main.go +7 -7
@@ 186,19 186,19 @@ func main() {
	if err != nil {
		logger.Fatalf("error dialing connection: %v", err)
	}
	negotiator := xmpp.NewNegotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return []xmpp.StreamFeature{
	negotiator := xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features: []xmpp.StreamFeature{
				xmpp.BindResource(),
				xmpp.StartTLS(&tls.Config{
					ServerName: parsedAddr.Domain().String(),
					MinVersion: tls.VersionTLS12,
				}),
				xmpp.SASL(parsedAuthAddr.String(), pass, sasl.ScramSha256Plus, sasl.ScramSha1Plus, sasl.ScramSha256, sasl.ScramSha1, sasl.Plain),
			}
		},
		TeeIn:  logWriter{logger: recvXML},
		TeeOut: logWriter{logger: sentXML},
			},
			TeeIn:  logWriter{logger: recvXML},
			TeeOut: logWriter{logger: sentXML},
		}
	})
	session, err := xmpp.NewSession(dialCtx, parsedAddr.Domain(), parsedAddr, conn, 0, negotiator)
	dialCtxCancel()

M internal/integration/integration.go => internal/integration/integration.go +6 -6
@@ 393,12 393,12 @@ func (cmd *Cmd) dial(ctx context.Context, s2s bool, location, origin jid.JID, t 
	if err != nil {
		return nil, err
	}
	negotiator := xmpp.NewNegotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return features
		},
		TeeIn:  cmd.in,
		TeeOut: cmd.out,
	negotiator := xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features: features,
			TeeIn:    cmd.in,
			TeeOut:   cmd.out,
		}
	})
	var mask xmpp.SessionState
	if s2s {

M internal/integration/mellium/mellium.go => internal/integration/mellium/mellium.go +5 -7
@@ 182,19 182,17 @@ func listen(s2s bool, l net.Listener, logger *log.Logger, cfg Config) {
			streamCfg := xmpp.StreamConfig{}
			if s2s {
				mask |= xmpp.S2S
				streamCfg.Features = func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
					return cfg.S2SFeatures
				}
				streamCfg.Features = cfg.S2SFeatures
			} else {
				streamCfg.Features = func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
					return cfg.C2SFeatures
				}
				streamCfg.Features = cfg.C2SFeatures
			}
			if cfg.LogXML {
				streamCfg.TeeIn = logWriter{logger: log.New(logger.Writer(), "RECV ", log.LstdFlags)}
				streamCfg.TeeOut = logWriter{logger: log.New(logger.Writer(), "SEND ", log.LstdFlags)}
			}
			session, err := xmpp.ReceiveSession(context.TODO(), conn, mask, xmpp.NewNegotiator(streamCfg))
			session, err := xmpp.ReceiveSession(context.TODO(), conn, mask, xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
				return streamCfg
			}))
			if err != nil {
				logger.Printf("error negotiating %s session: %v", connType, err)
				return

M negotiator.go => negotiator.go +26 -25
@@ 40,24 40,7 @@ type StreamConfig struct {
	Lang string

	// A list of stream features to attempt to negotiate.
	// Features will be called every time a new stream is started so that the user
	// may look up required stream features based on information about an incoming
	// stream such as the location and origin JID.
	// Individual features still control whether or not they are listed at any
	// given time, so all possible features should be returned on each step and
	// new features only added to the list when we learn that they are possible
	// eg. because the origin or location JID is set and we can look up that users
	// configuration in the database.
	// For example, you would not return StartTLS the first time this feature is
	// called then return Auth once you see that the secure bit is set on the
	// session state because the stream features themselves would handle this for
	// you.
	// Instead you would always return StartTLS and Auth, but you might only add
	// the "password reset" feature once you see that the origin JID is one that
	// has a backup email in the database.
	// The previous stream features list is passed in at each step so that it can
	// be re-used or appended to if desired (however, this is not required).
	Features func(*Session, ...StreamFeature) []StreamFeature
	Features []StreamFeature

	// WebSocket indicates that the negotiator should use the WebSocket
	// subprotocol defined in RFC 7395.


@@ 77,7 60,27 @@ type StreamConfig struct {
// session.
// If StartTLS is one of the supported stream features, the Negotiator attempts
// to negotiate it whether the server advertises support or not.
func NewNegotiator(cfg StreamConfig) Negotiator {
//
// The function will be called every time a new stream is started so that the
// user may look up required stream features (and other stream configuration)
// based on information about an incoming stream such as the location and origin
// JID.
// Individual features still control whether or not they are listed at any
// given time, so all possible features should be returned on each step and
// new features only added to the list when we learn that they are possible
// eg. because the origin or location JID is set and we can look up that users
// configuration in the database.
// For example, you would not return StartTLS the first time this function is
// called then return Auth once you see that the secure bit is set on the
// session state because the stream features themselves would handle this for
// you.
// Instead you would always return StartTLS and Auth, but you might only add
// the "password reset" feature once you see that the origin JID is one that
// has a backup email in the database.
// The previous stream config is passed in at each step so that it can be
// re-used or the stream features may be appended to if desired (however, this
// is not required).
func NewNegotiator(cfg func(*Session, StreamConfig) StreamConfig) Negotiator {
	return negotiator(cfg)
}



@@ 86,9 89,10 @@ type negotiatorState struct {
	cancelTee context.CancelFunc
}

func negotiator(cfg StreamConfig) Negotiator {
	var features []StreamFeature
func negotiator(cf func(*Session, StreamConfig) StreamConfig) Negotiator {
	var cfg StreamConfig
	return func(ctx context.Context, in, out *stream.Info, s *Session, data interface{}) (mask SessionState, rw io.ReadWriter, restartNext interface{}, err error) {
		cfg = cf(s, cfg)
		nState, ok := data.(negotiatorState)
		// If no state was passed in, this is the first negotiate call so make up a
		// default.


@@ 186,10 190,7 @@ func negotiator(cfg StreamConfig) Negotiator {
			}
		}

		if cfg.Features != nil {
			features = cfg.Features(s, features...)
		}
		mask, rw, err = negotiateFeatures(ctx, s, data == nil, cfg.WebSocket, features)
		mask, rw, err = negotiateFeatures(ctx, s, data == nil, cfg.WebSocket, cfg.Features)
		nState.doRestart = rw != nil
		return mask, rw, nState, err
	}

M session.go => session.go +24 -27
@@ 297,13 297,10 @@ func DialClientSession(ctx context.Context, origin jid.JID, features ...StreamFe
	if err != nil {
		return nil, err
	}
	return NewSession(ctx, origin.Domain(), origin, conn, 0, NewNegotiator(StreamConfig{
		Features: func(_ *Session, f ...StreamFeature) []StreamFeature {
			if f != nil {
				return f
			}
			return features
		},
	return NewSession(ctx, origin.Domain(), origin, conn, 0, NewNegotiator(func(*Session, StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}



@@ 317,10 314,10 @@ func DialServerSession(ctx context.Context, location, origin jid.JID, features .
	if err != nil {
		return nil, err
	}
	return NewSession(ctx, location, origin, conn, S2S, NewNegotiator(StreamConfig{
		Features: func(*Session, ...StreamFeature) []StreamFeature {
			return features
		},
	return NewSession(ctx, location, origin, conn, S2S, NewNegotiator(func(*Session, StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}



@@ 331,10 328,10 @@ func DialServerSession(ctx context.Context, location, origin jid.JID, features .
// error is returned.
// After stream negotiation if the context is canceled it has no effect.
func NewClientSession(ctx context.Context, origin jid.JID, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
	return NewSession(ctx, origin.Domain(), origin, rw, 0, NewNegotiator(StreamConfig{
		Features: func(*Session, ...StreamFeature) []StreamFeature {
			return features
		},
	return NewSession(ctx, origin.Domain(), origin, rw, 0, NewNegotiator(func(*Session, StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}



@@ 345,10 342,10 @@ func NewClientSession(ctx context.Context, origin jid.JID, rw io.ReadWriter, fea
// error is returned.
// After stream negotiation if the context is canceled it has no effect.
func ReceiveClientSession(ctx context.Context, origin jid.JID, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
	return ReceiveSession(ctx, rw, 0, NewNegotiator(StreamConfig{
		Features: func(*Session, ...StreamFeature) []StreamFeature {
			return features
		},
	return ReceiveSession(ctx, rw, 0, NewNegotiator(func(*Session, StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}



@@ 359,10 356,10 @@ func ReceiveClientSession(ctx context.Context, origin jid.JID, rw io.ReadWriter,
// error is returned.
// After stream negotiation if the context is canceled it has no effect.
func NewServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
	return NewSession(ctx, location, origin, rw, S2S, NewNegotiator(StreamConfig{
		Features: func(*Session, ...StreamFeature) []StreamFeature {
			return features
		},
	return NewSession(ctx, location, origin, rw, S2S, NewNegotiator(func(*Session, StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}



@@ 373,10 370,10 @@ func NewServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
// error is returned.
// After stream negotiation if the context is canceled it has no effect.
func ReceiveServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
	return ReceiveSession(ctx, rw, S2S, NewNegotiator(StreamConfig{
		Features: func(*Session, ...StreamFeature) []StreamFeature {
			return features
		},
	return ReceiveSession(ctx, rw, S2S, NewNegotiator(func(*Session, StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}


M session_test.go => session_test.go +22 -20
@@ 114,26 114,28 @@ var readyFeature = xmpp.StreamFeature{
var negotiateTests = [...]negotiateTestCase{
	0: {negotiator: errNegotiator, err: errTestNegotiate},
	1: {
		negotiator: xmpp.NewNegotiator(xmpp.StreamConfig{
			Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
				return []xmpp.StreamFeature{xmpp.StartTLS(nil)}
			},
		negotiator: xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
			return xmpp.StreamConfig{
				Features: []xmpp.StreamFeature{xmpp.StartTLS(nil)},
			}
		}),
		in:  `<stream:stream id='316732270768047465' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:client'><stream:features><other/></stream:features>`,
		out: `<?xml version="1.0" encoding="UTF-8"?><stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' version='1.0'><starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`,
		err: errors.New("XML syntax error on line 1: unexpected EOF"),
	},
	2: {
		negotiator: xmpp.NewNegotiator(xmpp.StreamConfig{}),
		in:         `<stream:stream id='316732270768047465' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:client'><stream:features><other/></stream:features>`,
		out:        `<?xml version="1.0" encoding="UTF-8"?><stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' version='1.0'>`,
		err:        errors.New("xmpp: features advertised out of order"),
		negotiator: xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
			return xmpp.StreamConfig{}
		}),
		in:  `<stream:stream id='316732270768047465' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:client'><stream:features><other/></stream:features>`,
		out: `<?xml version="1.0" encoding="UTF-8"?><stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' version='1.0'>`,
		err: errors.New("xmpp: features advertised out of order"),
	},
	3: {
		negotiator: xmpp.NewNegotiator(xmpp.StreamConfig{
			Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
				return []xmpp.StreamFeature{readyFeature}
			},
		negotiator: xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
			return xmpp.StreamConfig{
				Features: []xmpp.StreamFeature{readyFeature},
			}
		}),
		in:           `<stream:stream id='316732270768047465' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:server'><stream:features><ready xmlns='urn:example'/></stream:features>`,
		out:          `<?xml version="1.0" encoding="UTF-8"?><stream:stream xmlns='jabber:server' xmlns:stream='http://etherx.jabber.org/streams' version='1.0'>`,


@@ 408,19 410,19 @@ func TestNegotiateStreamError(t *testing.T) {
	semaphore := make(chan struct{})
	go func() {
		defer close(semaphore)
		_, err := xmpp.ReceiveSession(ctx, serverConn, 0, xmpp.NewNegotiator(xmpp.StreamConfig{
			Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
				return []xmpp.StreamFeature{errorStartTLS(stream.Conflict)}
			},
		_, err := xmpp.ReceiveSession(ctx, serverConn, 0, xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
			return xmpp.StreamConfig{
				Features: []xmpp.StreamFeature{errorStartTLS(stream.Conflict)},
			}
		}))
		if err != nil {
			t.Logf("error receiving session: %v", err)
		}
	}()
	_, err := xmpp.NewSession(ctx, clientJID, clientJID.Bare(), clientConn, 0, xmpp.NewNegotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return []xmpp.StreamFeature{xmpp.StartTLS(nil)}
		},
	_, err := xmpp.NewSession(ctx, clientJID, clientJID.Bare(), clientConn, 0, xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features: []xmpp.StreamFeature{xmpp.StartTLS(nil)},
		}
	}))
	if !errors.Is(err, stream.Conflict) {
		t.Errorf("unexpected client err: want=%v, got=%v", stream.Conflict, err)

M websocket/integration_test.go => websocket/integration_test.go +6 -6
@@ 74,14 74,14 @@ func integrationDialWebsocket(ctx context.Context, t *testing.T, cmd *integratio
	session, err := xmpp.NewSession(
		context.TODO(), j.Domain(), j, conn,
		xmpp.Secure,
		xmpp.NewNegotiator(xmpp.StreamConfig{
			WebSocket: true,
			Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
				return []xmpp.StreamFeature{
		xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
			return xmpp.StreamConfig{
				WebSocket: true,
				Features: []xmpp.StreamFeature{
					xmpp.SASL("", pass, sasl.Plain),
					xmpp.BindResource(),
				}
			},
				},
			}
		}),
	)
	if err != nil {

M websocket/ws.go => websocket/ws.go +10 -10
@@ 25,11 25,11 @@ import (
// client on rw using the WebSocket subprotocol.
// It does not perform the WebSocket handshake.
func NewSession(ctx context.Context, addr jid.JID, rw io.ReadWriter, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	n := xmpp.NewNegotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return features
		},
		WebSocket: true,
	n := xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features:  features,
			WebSocket: true,
		}
	})
	var mask xmpp.SessionState
	if wsConn, ok := rw.(*websocket.Conn); ok && wsConn.LocalAddr().(*websocket.Addr).Scheme == "wss" {


@@ 42,11 42,11 @@ func NewSession(ctx context.Context, addr jid.JID, rw io.ReadWriter, features ..
// receiving server on rw using the WebSocket subprotocol.
// It does not perform the WebSocket handshake.
func ReceiveSession(ctx context.Context, rw io.ReadWriter, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	n := xmpp.NewNegotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return features
		},
		WebSocket: true,
	n := xmpp.NewNegotiator(func(*xmpp.Session, xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features:  features,
			WebSocket: true,
		}
	})
	var mask xmpp.SessionState
	if wsConn, ok := rw.(*websocket.Conn); ok && wsConn.LocalAddr().(*websocket.Addr).Scheme == "wss" {