~samwhited/xmpp

92a682eade563e5ac455d489ba67a9610d4b99f4 — Sam Whited 2 months ago 404b735
xmpp: make stream config more flexible

Previously the stream config was a struct and the only thing that could
be changed between stream restarts was the features we advertised.
However, we may want to change other parts of the stream config between
restarts. For example, if we figure out the users JID after the first
step we may want to look them up in the database and set the default
stream language based on their preferences.

To accomplish this we now take a stream config function instead of
taking the struct directly (and the Features field has gone back to
being a slice and is no longer a function itself).
Each time we iterate we update the config by calling the function, which
can look up properties of the session before deciding what config needs
to change.

Fixes #106

Signed-off-by: Sam Whited <sam@samwhited.com>
M CHANGELOG.md => CHANGELOG.md +2 -0
@@ 21,6 21,8 @@ All notable changes to this project will be documented in this file.
  `websocket.Negotiator`
- xmpp: the `IterIQ` and `IterIQElement` methods on `Session` now return the
  start element token associated with the IQ payload
- xmpp: `Negotiator` now takes a stream config function instead of a
  `StreamConfig` struct


### Added

M doc.go => doc.go +5 -3
@@ 57,9 57,11 @@
//     …
//     session, err := xmpp.NewSession(
//         context.TODO(), addr.Domain(), addr, conn, xmpp.Secure,
//         xmpp.NewNegotiator(xmpp.StreamConfig{
//             Lang: "en",
//             …
//         xmpp.NewNegotiator(func(*xmpp.Session, *xmpp.StreamConfig) {
//             return xmpp.StreamConfig{
//                 Lang: "en",
//                 …
//             },
//         }),
//     )
//

M docs/overview.md => docs/overview.md +1 -1
@@ 174,7 174,7 @@ type StreamConfig struct {
// session.
// If StartTLS is one of the supported stream features, the Negotiator attempts
// to negotiate it whether the server advertises support or not.
func NewNegotiator(cfg StreamConfig) Negotiator
func NewNegotiator(func(*Session, *StreamConfig) StreamConfig) Negotiator
```

It uses stream features as discussed in the previous

M examples/echobot/echo.go => examples/echobot/echo.go +8 -11
@@ 38,23 38,20 @@ func echo(ctx context.Context, addr, pass string, xmlIn, xmlOut io.Writer, logge
		return fmt.Errorf("Error dialing sesion: %w", err)
	}

	s, err := xmpp.NewSession(ctx, j.Domain(), j, conn, 0, xmpp.NewNegotiator(xmpp.StreamConfig{
		Lang: "en",
		Features: func(_ *xmpp.Session, f ...xmpp.StreamFeature) []xmpp.StreamFeature {
			if f != nil {
				return f
			}
			return []xmpp.StreamFeature{
	s, err := xmpp.NewSession(ctx, j.Domain(), j, conn, 0, xmpp.NewNegotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Lang: "en",
			Features: []xmpp.StreamFeature{
				xmpp.BindResource(),
				xmpp.StartTLS(&tls.Config{
					ServerName: j.Domain().String(),
					MinVersion: tls.VersionTLS12,
				}),
				xmpp.SASL("", pass, sasl.ScramSha1Plus, sasl.ScramSha1, sasl.Plain),
			}
		},
		TeeIn:  xmlIn,
		TeeOut: xmlOut,
			},
			TeeIn:  xmlIn,
			TeeOut: xmlOut,
		}
	}))
	if err != nil {
		return fmt.Errorf("Error establishing a session: %w", err)

M examples/im/main.go => examples/im/main.go +7 -7
@@ 186,19 186,19 @@ func main() {
	if err != nil {
		logger.Fatalf("error dialing connection: %v", err)
	}
	negotiator := xmpp.NewNegotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return []xmpp.StreamFeature{
	negotiator := xmpp.NewNegotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features: []xmpp.StreamFeature{
				xmpp.BindResource(),
				xmpp.StartTLS(&tls.Config{
					ServerName: parsedAddr.Domain().String(),
					MinVersion: tls.VersionTLS12,
				}),
				xmpp.SASL(parsedAuthAddr.String(), pass, sasl.ScramSha256Plus, sasl.ScramSha1Plus, sasl.ScramSha256, sasl.ScramSha1, sasl.Plain),
			}
		},
		TeeIn:  logWriter{logger: recvXML},
		TeeOut: logWriter{logger: sentXML},
			},
			TeeIn:  logWriter{logger: recvXML},
			TeeOut: logWriter{logger: sentXML},
		}
	})
	session, err := xmpp.NewSession(dialCtx, parsedAddr.Domain(), parsedAddr, conn, 0, negotiator)
	dialCtxCancel()

M internal/integration/integration.go => internal/integration/integration.go +6 -6
@@ 450,12 450,12 @@ func (cmd *Cmd) dial(ctx context.Context, s2s bool, location, origin jid.JID, t 
	if err != nil {
		return nil, err
	}
	negotiator := xmpp.NewNegotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return features
		},
		TeeIn:  cmd.in,
		TeeOut: cmd.out,
	negotiator := xmpp.NewNegotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features: features,
			TeeIn:    cmd.in,
			TeeOut:   cmd.out,
		}
	})
	var mask xmpp.SessionState
	if s2s {

M internal/integration/mellium/mellium.go => internal/integration/mellium/mellium.go +5 -7
@@ 182,19 182,17 @@ func listen(s2s bool, l net.Listener, logger *log.Logger, cfg Config) {
			streamCfg := xmpp.StreamConfig{}
			if s2s {
				mask |= xmpp.S2S
				streamCfg.Features = func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
					return cfg.S2SFeatures
				}
				streamCfg.Features = cfg.S2SFeatures
			} else {
				streamCfg.Features = func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
					return cfg.C2SFeatures
				}
				streamCfg.Features = cfg.C2SFeatures
			}
			if cfg.LogXML {
				streamCfg.TeeIn = logWriter{logger: log.New(logger.Writer(), "RECV ", log.LstdFlags)}
				streamCfg.TeeOut = logWriter{logger: log.New(logger.Writer(), "SEND ", log.LstdFlags)}
			}
			session, err := xmpp.ReceiveSession(context.TODO(), conn, mask, xmpp.NewNegotiator(streamCfg))
			session, err := xmpp.ReceiveSession(context.TODO(), conn, mask, xmpp.NewNegotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
				return streamCfg
			}))
			if err != nil {
				logger.Printf("error negotiating %s session: %v", connType, err)
				return

M negotiator.go => negotiator.go +25 -25
@@ 41,24 41,7 @@ type StreamConfig struct {
	Lang string

	// A list of stream features to attempt to negotiate.
	// Features will be called every time a new stream is started so that the user
	// may look up required stream features based on information about an incoming
	// stream such as the location and origin JID.
	// Individual features still control whether or not they are listed at any
	// given time, so all possible features should be returned on each step and
	// new features only added to the list when we learn that they are possible
	// eg. because the origin or location JID is set and we can look up that users
	// configuration in the database.
	// For example, you would not return StartTLS the first time this feature is
	// called then return Auth once you see that the secure bit is set on the
	// session state because the stream features themselves would handle this for
	// you.
	// Instead you would always return StartTLS and Auth, but you might only add
	// the "password reset" feature once you see that the origin JID is one that
	// has a backup email in the database.
	// The previous stream features list is passed in at each step so that it can
	// be re-used or appended to if desired (however, this is not required).
	Features func(*Session, ...StreamFeature) []StreamFeature
	Features []StreamFeature

	// If set a copy of any reads from the session will be written to TeeIn and
	// any writes to the session will be written to TeeOut (similar to the tee(1)


@@ 74,7 57,26 @@ type StreamConfig struct {
// session.
// If StartTLS is one of the supported stream features, the Negotiator attempts
// to negotiate it whether the server advertises support or not.
func NewNegotiator(cfg StreamConfig) Negotiator {
//
// The cfg function will be called every time a new stream is started so that
// the user may look up required stream features, the default language, and
// other properties based on information about an incoming stream such as the
// location and origin JID.
// Individual features still control whether or not they are listed at any
// given time, so all possible features should be returned on each step and
// new features only added to the list when we learn that they are possible
// eg. because the origin or location JID is set and we can look up that users
// configuration in the database.
// For example, you would not return StartTLS the first time this feature is
// called then return Auth once you see that the secure bit is set on the
// session state because the stream features themselves would handle this for
// you.
// Instead you would always return StartTLS and Auth, but you might only add
// the "password reset" feature once you see that the origin JID is one that
// has a backup email in the database.
// The previous config is passed in at each step so that it can be re-used or
// modified (however, this is not required).
func NewNegotiator(cfg func(*Session, *StreamConfig) StreamConfig) Negotiator {
	return negotiator(cfg)
}



@@ 83,8 85,8 @@ type negotiatorState struct {
	cancelTee context.CancelFunc
}

func negotiator(cfg StreamConfig) Negotiator {
	var features []StreamFeature
func negotiator(f func(*Session, *StreamConfig) StreamConfig) Negotiator {
	cfg := f(nil, nil)
	return func(ctx context.Context, in, out *stream.Info, s *Session, data interface{}) (mask SessionState, rw io.ReadWriter, restartNext interface{}, err error) {
		nState, ok := data.(negotiatorState)
		// If no state was passed in, this is the first negotiate call so make up a


@@ 190,10 192,8 @@ func negotiator(cfg StreamConfig) Negotiator {
			}
		}

		if cfg.Features != nil {
			features = cfg.Features(s, features...)
		}
		mask, rw, err = negotiateFeatures(ctx, s, data == nil, websocket, features)
		cfg = f(s, &cfg)
		mask, rw, err = negotiateFeatures(ctx, s, data == nil, websocket, cfg.Features)
		nState.doRestart = rw != nil
		return mask, rw, nState, err
	}

M session.go => session.go +24 -27
@@ 297,13 297,10 @@ func DialClientSession(ctx context.Context, origin jid.JID, features ...StreamFe
	if err != nil {
		return nil, err
	}
	return NewSession(ctx, origin.Domain(), origin, conn, 0, NewNegotiator(StreamConfig{
		Features: func(_ *Session, f ...StreamFeature) []StreamFeature {
			if f != nil {
				return f
			}
			return features
		},
	return NewSession(ctx, origin.Domain(), origin, conn, 0, NewNegotiator(func(*Session, *StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}



@@ 317,10 314,10 @@ func DialServerSession(ctx context.Context, location, origin jid.JID, features .
	if err != nil {
		return nil, err
	}
	return NewSession(ctx, location, origin, conn, S2S, NewNegotiator(StreamConfig{
		Features: func(*Session, ...StreamFeature) []StreamFeature {
			return features
		},
	return NewSession(ctx, location, origin, conn, S2S, NewNegotiator(func(*Session, *StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}



@@ 331,10 328,10 @@ func DialServerSession(ctx context.Context, location, origin jid.JID, features .
// error is returned.
// After stream negotiation if the context is canceled it has no effect.
func NewClientSession(ctx context.Context, origin jid.JID, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
	return NewSession(ctx, origin.Domain(), origin, rw, 0, NewNegotiator(StreamConfig{
		Features: func(*Session, ...StreamFeature) []StreamFeature {
			return features
		},
	return NewSession(ctx, origin.Domain(), origin, rw, 0, NewNegotiator(func(*Session, *StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}



@@ 345,10 342,10 @@ func NewClientSession(ctx context.Context, origin jid.JID, rw io.ReadWriter, fea
// error is returned.
// After stream negotiation if the context is canceled it has no effect.
func ReceiveClientSession(ctx context.Context, origin jid.JID, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
	return ReceiveSession(ctx, rw, 0, NewNegotiator(StreamConfig{
		Features: func(*Session, ...StreamFeature) []StreamFeature {
			return features
		},
	return ReceiveSession(ctx, rw, 0, NewNegotiator(func(*Session, *StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}



@@ 359,10 356,10 @@ func ReceiveClientSession(ctx context.Context, origin jid.JID, rw io.ReadWriter,
// error is returned.
// After stream negotiation if the context is canceled it has no effect.
func NewServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
	return NewSession(ctx, location, origin, rw, S2S, NewNegotiator(StreamConfig{
		Features: func(*Session, ...StreamFeature) []StreamFeature {
			return features
		},
	return NewSession(ctx, location, origin, rw, S2S, NewNegotiator(func(*Session, *StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}



@@ 373,10 370,10 @@ func NewServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
// error is returned.
// After stream negotiation if the context is canceled it has no effect.
func ReceiveServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
	return ReceiveSession(ctx, rw, S2S, NewNegotiator(StreamConfig{
		Features: func(*Session, ...StreamFeature) []StreamFeature {
			return features
		},
	return ReceiveSession(ctx, rw, S2S, NewNegotiator(func(*Session, *StreamConfig) StreamConfig {
		return StreamConfig{
			Features: features,
		}
	}))
}


M session_test.go => session_test.go +23 -20
@@ 114,26 114,28 @@ var readyFeature = xmpp.StreamFeature{
var negotiateTests = [...]negotiateTestCase{
	0: {negotiator: errNegotiator, err: errTestNegotiate},
	1: {
		negotiator: xmpp.NewNegotiator(xmpp.StreamConfig{
			Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
				return []xmpp.StreamFeature{xmpp.StartTLS(nil)}
			},
		negotiator: xmpp.NewNegotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
			return xmpp.StreamConfig{
				Features: []xmpp.StreamFeature{xmpp.StartTLS(nil)},
			}
		}),
		in:  `<stream:stream id='316732270768047465' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:client'><stream:features><other/></stream:features>`,
		out: `<?xml version="1.0" encoding="UTF-8"?><stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' version='1.0'><starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`,
		err: errors.New("XML syntax error on line 1: unexpected EOF"),
	},
	2: {
		negotiator: xmpp.NewNegotiator(xmpp.StreamConfig{}),
		in:         `<stream:stream id='316732270768047465' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:client'><stream:features><other/></stream:features>`,
		out:        `<?xml version="1.0" encoding="UTF-8"?><stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' version='1.0'>`,
		err:        errors.New("xmpp: features advertised out of order"),
		negotiator: xmpp.NewNegotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
			return xmpp.StreamConfig{}
		}),
		in:  `<stream:stream id='316732270768047465' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:client'><stream:features><other/></stream:features>`,
		out: `<?xml version="1.0" encoding="UTF-8"?><stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' version='1.0'>`,
		err: errors.New("xmpp: features advertised out of order"),
	},
	3: {
		negotiator: xmpp.NewNegotiator(xmpp.StreamConfig{
			Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
				return []xmpp.StreamFeature{readyFeature}
			},
		negotiator: xmpp.NewNegotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
			return xmpp.StreamConfig{
				Features: []xmpp.StreamFeature{readyFeature},
			}
		}),
		in:           `<stream:stream id='316732270768047465' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:server'><stream:features><ready xmlns='urn:example'/></stream:features>`,
		out:          `<?xml version="1.0" encoding="UTF-8"?><stream:stream xmlns='jabber:server' xmlns:stream='http://etherx.jabber.org/streams' version='1.0'>`,


@@ 408,19 410,20 @@ func TestNegotiateStreamError(t *testing.T) {
	semaphore := make(chan struct{})
	go func() {
		defer close(semaphore)
		_, err := xmpp.ReceiveSession(ctx, serverConn, 0, xmpp.NewNegotiator(xmpp.StreamConfig{
			Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
				return []xmpp.StreamFeature{errorStartTLS(stream.Conflict)}
			},

		_, err := xmpp.ReceiveSession(ctx, serverConn, 0, xmpp.NewNegotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
			return xmpp.StreamConfig{
				Features: []xmpp.StreamFeature{errorStartTLS(stream.Conflict)},
			}
		}))
		if err != nil {
			t.Logf("error receiving session: %v", err)
		}
	}()
	_, err := xmpp.NewSession(ctx, clientJID, clientJID.Bare(), clientConn, 0, xmpp.NewNegotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return []xmpp.StreamFeature{xmpp.StartTLS(nil)}
		},
	_, err := xmpp.NewSession(ctx, clientJID, clientJID.Bare(), clientConn, 0, xmpp.NewNegotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features: []xmpp.StreamFeature{xmpp.StartTLS(nil)},
		}
	}))
	if !errors.Is(err, stream.Conflict) {
		t.Errorf("unexpected client err: want=%v, got=%v", stream.Conflict, err)

M websocket/integration_test.go => websocket/integration_test.go +5 -5
@@ 74,13 74,13 @@ func integrationDialWebsocket(ctx context.Context, t *testing.T, cmd *integratio
	session, err := xmpp.NewSession(
		context.TODO(), j.Domain(), j, conn,
		xmpp.Secure,
		websocket.Negotiator(xmpp.StreamConfig{
			Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
				return []xmpp.StreamFeature{
		websocket.Negotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
			return xmpp.StreamConfig{
				Features: []xmpp.StreamFeature{
					xmpp.SASL("", pass, sasl.Plain),
					xmpp.BindResource(),
				}
			},
				},
			}
		}),
	)
	if err != nil {

M websocket/negotiator.go => websocket/negotiator.go +1 -1
@@ 15,7 15,7 @@ import (

// Negotiator is like xmpp.NewNegotiator except that it uses the websocket
// subprotocol.
func Negotiator(cfg xmpp.StreamConfig) xmpp.Negotiator {
func Negotiator(cfg func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig) xmpp.Negotiator {
	xmppNegotiator := xmpp.NewNegotiator(cfg)
	return func(ctx context.Context, in, out *stream.Info, session *xmpp.Session, data interface{}) (xmpp.SessionState, io.ReadWriter, interface{}, error) {
		ctx = context.WithValue(ctx, wskey.Key{}, struct{}{})

M websocket/ws.go => websocket/ws.go +8 -8
@@ 25,10 25,10 @@ import (
// client on rw using the WebSocket subprotocol.
// It does not perform the WebSocket handshake.
func NewSession(ctx context.Context, addr jid.JID, rw io.ReadWriter, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	n := Negotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return features
		},
	n := Negotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features: features,
		}
	})
	var mask xmpp.SessionState
	if wsConn, ok := rw.(*websocket.Conn); ok && wsConn.LocalAddr().(*websocket.Addr).Scheme == "wss" {


@@ 41,10 41,10 @@ func NewSession(ctx context.Context, addr jid.JID, rw io.ReadWriter, features ..
// receiving server on rw using the WebSocket subprotocol.
// It does not perform the WebSocket handshake.
func ReceiveSession(ctx context.Context, rw io.ReadWriter, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	n := Negotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return features
		},
	n := Negotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features: features,
		}
	})
	var mask xmpp.SessionState
	if wsConn, ok := rw.(*websocket.Conn); ok && wsConn.LocalAddr().(*websocket.Addr).Scheme == "wss" {