~samwhited/xmpp

3383439c54460ff5bbf71709d16763a45d4ba1a4 — Sam Whited 2 months ago c330385
xmpp, websocket: move websocket negotiator config

Previously we configured whether to use websockets in the stream config.
This meant that in a future version of this library where the stream
config can be changed on each feature negotiation step (eg. to set the
language once we know the JID we're trying to auth as) we could swap
back and forth between websockets and the normal way, which is not
ideal. It also just felt wrong there: the websocket negotiator should
obviously be in the websocket package.
However, splitting it out was difficult because unless we copy/paste the
entire implementation (a maintainability nightmare) we end up with
import loops.
To fix this a somewhat jank internal API was added so that we can use
the same implementation but copy/pate a single string key instead of the
whole thing. For now this works and is hidden from the user.

Signed-off-by: Sam Whited <sam@samwhited.com>
M CHANGELOG.md => CHANGELOG.md +4 -0
@@ 8,6 8,8 @@ All notable changes to this project will be documented in this file.

- roster: rename `version` attribute to `ver`
- styling: decoding tokens now uses an iterator pattern
- xmpp: the `WebSocket` option on `StreamConfig` has been removed in favor of
  `websocket.Negotiator`


### Security


@@ 23,6 25,8 @@ All notable changes to this project will be documented in this file.
- stanza: implement [XEP-0203: Delayed Delivery]
- stanza: more general `UnmarshalError` function that doesn't focus on IQs
- stanza: add `Error` method to `Presence` and `Message`
- websocket: add `Negotiator` to replace the `WebSocket` option on the stream
  config


### Fixed

A internal/wskey/key.go => internal/wskey/key.go +28 -0
@@ 0,0 1,28 @@
// Copyright 2021 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

// Package wskey is a context key used by negotiators.
//
// We are doing exactly what the context package tells us not to do and using it
// to pass optional arguments to the function (in this case, whether to use the
// WebSocket subprotocol or not).
// A better way to do this would be to move the negotiator to an
// internal/negotiator package and have xmpp and websocket both import and use
// that.
// Unfortunately, that would cause import loops (because the negotiator function
// takes an xmpp.Session, so the internal/negotiator package would also need to
// import the xmpp package).
// We could also copy/pate the entire implementation into websocket, but this is
// a maintainability nightmare.
//
// Having a secret internal API may not be ideal, but it does let us get away
// with a nice surface API without any real drawbacks other than an extra tiny
// internal package to house this key.
package wskey // import "mellium.im/xmpp/internal/wskey"

// Key is an internal type used as a context key by the xmpp and websocket
// packages.
// If it is provided on a context to xmpp.NewNegotiator, the WebSocket
// subprotocol is used instead of the normal XMPP protocol.
type Key struct{}

M negotiator.go => negotiator.go +13 -9
@@ 11,6 11,7 @@ import (

	"mellium.im/xmpp/internal/attr"
	intstream "mellium.im/xmpp/internal/stream"
	"mellium.im/xmpp/internal/wskey"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/stream"
)


@@ 59,10 60,6 @@ type StreamConfig struct {
	// be re-used or appended to if desired (however, this is not required).
	Features func(*Session, ...StreamFeature) []StreamFeature

	// WebSocket indicates that the negotiator should use the WebSocket
	// subprotocol defined in RFC 7395.
	WebSocket bool

	// If set a copy of any reads from the session will be written to TeeIn and
	// any writes to the session will be written to TeeOut (similar to the tee(1)
	// command).


@@ 99,6 96,13 @@ func negotiator(cfg StreamConfig) Negotiator {
			}
		}

		// This is a secret internal API that lets us use this same negotiator
		// implementation in the websocket package without copy/pasting the entire
		// implementation or creating import loops.
		// For more information see the internal/wskey package.
		wsCtx := ctx.Value(wskey.Key{})
		websocket := wsCtx != nil

		c := s.Conn()
		// If the session is not already using a tee conn, but we're configured to
		// use one, return the new teeConn and don't set any state bits.


@@ 126,7 130,7 @@ func negotiator(cfg StreamConfig) Negotiator {

				location := s.LocalAddr()
				origin := s.RemoteAddr()
				err = intstream.Expect(ctx, in, s.in.d, s.State()&Received == Received, cfg.WebSocket)
				err = intstream.Expect(ctx, in, s.in.d, s.State()&Received == Received, websocket)
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err


@@ 152,7 156,7 @@ func negotiator(cfg StreamConfig) Negotiator {
				location = in.To
				origin = in.From

				err = intstream.Send(s.Conn(), out, s.State()&S2S == S2S, cfg.WebSocket, stream.DefaultVersion, cfg.Lang, origin.String(), location.String(), attr.RandomID())
				err = intstream.Send(s.Conn(), out, s.State()&S2S == S2S, websocket, stream.DefaultVersion, cfg.Lang, origin.String(), location.String(), attr.RandomID())
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err


@@ 162,12 166,12 @@ func negotiator(cfg StreamConfig) Negotiator {
				// one in response.
				origin := s.LocalAddr()
				location := s.RemoteAddr()
				err = intstream.Send(s.Conn(), out, s.State()&S2S == S2S, cfg.WebSocket, stream.DefaultVersion, cfg.Lang, location.String(), origin.String(), "")
				err = intstream.Send(s.Conn(), out, s.State()&S2S == S2S, websocket, stream.DefaultVersion, cfg.Lang, location.String(), origin.String(), "")
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err
				}
				err = intstream.Expect(ctx, in, s.in.d, s.State()&Received == Received, cfg.WebSocket)
				err = intstream.Expect(ctx, in, s.in.d, s.State()&Received == Received, websocket)
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err


@@ 189,7 193,7 @@ func negotiator(cfg StreamConfig) Negotiator {
		if cfg.Features != nil {
			features = cfg.Features(s, features...)
		}
		mask, rw, err = negotiateFeatures(ctx, s, data == nil, cfg.WebSocket, features)
		mask, rw, err = negotiateFeatures(ctx, s, data == nil, websocket, features)
		nState.doRestart = rw != nil
		return mask, rw, nState, err
	}

M websocket/integration_test.go => websocket/integration_test.go +1 -2
@@ 74,8 74,7 @@ func integrationDialWebsocket(ctx context.Context, t *testing.T, cmd *integratio
	session, err := xmpp.NewSession(
		context.TODO(), j.Domain(), j, conn,
		xmpp.Secure,
		xmpp.NewNegotiator(xmpp.StreamConfig{
			WebSocket: true,
		websocket.Negotiator(xmpp.StreamConfig{
			Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
				return []xmpp.StreamFeature{
					xmpp.SASL("", pass, sasl.Plain),

A websocket/negotiator.go => websocket/negotiator.go +24 -0
@@ 0,0 1,24 @@
// Copyright 2021 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package websocket

import (
	"context"
	"io"

	"mellium.im/xmpp"
	"mellium.im/xmpp/internal/wskey"
	"mellium.im/xmpp/stream"
)

// Negotiator is like xmpp.NewNegotiator except that it uses the websocket
// subprotocol.
func Negotiator(cfg xmpp.StreamConfig) xmpp.Negotiator {
	xmppNegotiator := xmpp.NewNegotiator(cfg)
	return func(ctx context.Context, in, out *stream.Info, session *xmpp.Session, data interface{}) (xmpp.SessionState, io.ReadWriter, interface{}, error) {
		ctx = context.WithValue(ctx, wskey.Key{}, struct{}{})
		return xmppNegotiator(ctx, in, out, session, data)
	}
}

M websocket/ws.go => websocket/ws.go +2 -4
@@ 25,11 25,10 @@ import (
// client on rw using the WebSocket subprotocol.
// It does not perform the WebSocket handshake.
func NewSession(ctx context.Context, addr jid.JID, rw io.ReadWriter, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	n := xmpp.NewNegotiator(xmpp.StreamConfig{
	n := Negotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return features
		},
		WebSocket: true,
	})
	var mask xmpp.SessionState
	if wsConn, ok := rw.(*websocket.Conn); ok && wsConn.LocalAddr().(*websocket.Addr).Scheme == "wss" {


@@ 42,11 41,10 @@ func NewSession(ctx context.Context, addr jid.JID, rw io.ReadWriter, features ..
// receiving server on rw using the WebSocket subprotocol.
// It does not perform the WebSocket handshake.
func ReceiveSession(ctx context.Context, rw io.ReadWriter, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	n := xmpp.NewNegotiator(xmpp.StreamConfig{
	n := Negotiator(xmpp.StreamConfig{
		Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature {
			return features
		},
		WebSocket: true,
	})
	var mask xmpp.SessionState
	if wsConn, ok := rw.(*websocket.Conn); ok && wsConn.LocalAddr().(*websocket.Addr).Scheme == "wss" {