~samwhited/xmpp

88ae82b5c9078be63fc1774330da6ae8cf0bb995 — Sam Whited 11 months ago ef35225 stanza_mux
mux: let ServeMux register stanza handlers

Previously the multiplexer was too generalized and only handled top
level elements. To handle children different types of alternative muxers
had to be registered. Instead, move all the functionality for matching
against different stanza types into the main muxer and separate the
stanza handlers from the generic top level handlers.

Fixes #25
5 files changed, 599 insertions(+), 547 deletions(-)

D mux/iq.go
D mux/iq_test.go
M mux/mux.go
M mux/mux_test.go
A mux/stanza.go
D mux/iq.go => mux/iq.go +0 -280
@@ 1,280 0,0 @@
// Copyright 2020 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package mux

import (
	"encoding/xml"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/stanza"
)

// IQHandler responds to an IQ stanza.
type IQHandler interface {
	HandleIQ(iq stanza.IQ, t xmlstream.TokenReadEncoder, start *xml.StartElement) error
}

// The IQHandlerFunc type is an adapter to allow the use of ordinary functions
// as IQ handlers. If f is a function with the appropriate signature,
// IQHandlerFunc(f) is an IQHandler that calls f.
type IQHandlerFunc func(iq stanza.IQ, t xmlstream.TokenReadEncoder, start *xml.StartElement) error

// HandleIQ calls f(iq, t, start).
func (f IQHandlerFunc) HandleIQ(iq stanza.IQ, t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	return f(iq, t, start)
}

type patternKey struct {
	xml.Name
	Type stanza.IQType
}

// IQMux is an XMPP multiplexer meant for handling IQ payloads.
//
// IQs are matched by the type and the XML name of their first child element (if
// any).
// If either the namespace or the localname is left off, any namespace or
// localname will be matched.
// Full XML names take precedence, followed by wildcard localnames, followed by
// wildcard namespaces.
//
// Unlike get and set type IQs, result IQs may have no child element, and error
// IQs may have more than one child element.
// Because of this it is normally adviseable to register handlers for type Error
// without any filter on the child element since we cannot guarantee what child
// token will come first and be matched against.
// Similarly, for IQs of type result, it is important to note that the start
// element passed to the handler may be nil, meaning that there is no child
// element.
type IQMux struct {
	patterns map[patternKey]IQHandler
}

// NewIQMux allocates and returns a new IQMux.
func NewIQMux(opt ...IQOption) *IQMux {
	m := &IQMux{}
	for _, o := range opt {
		o(m)
	}
	return m
}

// Handler returns the handler to use for an IQ payload with the given name and
// type.
// If no handler exists, a default handler is returned (h is always non-nil).
func (m *IQMux) Handler(iqType stanza.IQType, name xml.Name) (h IQHandler, ok bool) {
	pattern := patternKey{Name: name, Type: iqType}
	h = m.patterns[pattern]
	if h != nil {
		return h, true
	}

	n := name
	n.Space = ""
	pattern.Name = n
	h = m.patterns[pattern]
	if h != nil {
		return h, true
	}

	n = name
	n.Local = ""
	pattern.Name = n
	h = m.patterns[pattern]
	if h != nil {
		return h, true
	}

	pattern.Name = xml.Name{}
	h = m.patterns[pattern]
	if h != nil {
		return h, true
	}

	return IQHandlerFunc(iqFallback), false
}

func getPayload(t xmlstream.TokenReadEncoder, start *xml.StartElement) (stanza.IQ, *xml.StartElement, error) {
	iq, err := newIQFromStart(start)
	if err != nil {
		return iq, nil, err
	}

	tok, err := t.Token()
	if err != nil {
		return iq, nil, err
	}
	// If this is a result type IQ (or a malformed IQ) there may be no payload, so
	// make sure start is nil.
	payloadStart, ok := tok.(xml.StartElement)
	start = &payloadStart
	if !ok {
		start = nil
	}
	return iq, start, nil
}

// HandleXMPP dispatches the IQ to the handler whose pattern most closely
// matches start.Name.
func (m *IQMux) HandleXMPP(t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	iq, start, err := getPayload(t, start)
	if err != nil {
		return err
	}

	h, _ := m.Handler(iq.Type, start.Name)
	return h.HandleIQ(iq, t, start)
}

// IQOption configures an IQMux.
type IQOption func(m *IQMux)

// HandleIQ returns an option that matches the IQ payload by XML name and IQ
// type.
// For readability, users may want to use the GetIQ, SetIQ, ErrorIQ, and
// ResultIQ shortcuts instead.
//
// For more details, see the documentation on IQMux.
func HandleIQ(iqType stanza.IQType, n xml.Name, h IQHandler) IQOption {
	return func(m *IQMux) {
		if h == nil {
			panic("mux: nil handler")
		}
		pattern := patternKey{Name: n, Type: iqType}
		if _, ok := m.patterns[pattern]; ok {
			panic("mux: multiple registrations for {" + pattern.Space + "}" + pattern.Local)
		}
		if m.patterns == nil {
			m.patterns = make(map[patternKey]IQHandler)
		}
		m.patterns[pattern] = h
	}
}

// GetIQ is a shortcut for HandleIQ with the type set to "get".
func GetIQ(n xml.Name, h IQHandler) IQOption {
	return HandleIQ(stanza.GetIQ, n, h)
}

// GetIQFunc is a shortcut for HandleIQFunc with the type set to "get".
func GetIQFunc(n xml.Name, h IQHandlerFunc) IQOption {
	return GetIQ(n, h)
}

// SetIQ is a shortcut for HandleIQ with the type set to "set".
func SetIQ(n xml.Name, h IQHandler) IQOption {
	return HandleIQ(stanza.SetIQ, n, h)
}

// SetIQFunc is a shortcut for HandleIQ with the type set to "set".
func SetIQFunc(n xml.Name, h IQHandlerFunc) IQOption {
	return SetIQ(n, h)
}

// ErrorIQ is a shortcut for HandleIQ with the type set to "error" and a
// wildcard XML name.
//
// This differs from the other IQ types because error IQs may contain one or
// more child elements and we cannot guarantee the order of child elements and
// therefore won't know which element to match on.
// Instead it is normally wise to register a handler for all error type IQs and
// then skip or handle unnecessary payloads until we find the error itself.
func ErrorIQ(h IQHandler) IQOption {
	return HandleIQ(stanza.ErrorIQ, xml.Name{}, h)
}

// ErrorIQFunc is a shortcut for HandleIQFunc with the type set to "error" and a
// wildcard XML name.
//
// For more information, see ErrorIQ.
func ErrorIQFunc(h IQHandlerFunc) IQOption {
	return ErrorIQ(h)
}

// ResultIQ is a shortcut for HandleIQ with the type set to "result".
//
// Unlike IQs of type get, set, and error, result type IQs may or may not
// contain a payload.
// Because of this it is important to check whether the start element is nil in
// handlers meant to handle result type IQs.
func ResultIQ(n xml.Name, h IQHandler) IQOption {
	return HandleIQ(stanza.ResultIQ, n, h)
}

// ResultIQFunc is a shortcut for HandleIQFunc with the type set to "result".
//
// For more information, see ResultIQ.
func ResultIQFunc(n xml.Name, h IQHandlerFunc) IQOption {
	return ResultIQ(n, h)
}

// HandleIQFunc returns an option that matches the IQ payload by XML name and IQ
// type.
func HandleIQFunc(iqType stanza.IQType, n xml.Name, h IQHandlerFunc) IQOption {
	return HandleIQ(iqType, n, h)
}

// newIQFromStart takes a start element and returns an IQ.
func newIQFromStart(start *xml.StartElement) (stanza.IQ, error) {
	iq := stanza.IQ{}
	var err error
	for _, a := range start.Attr {
		switch a.Name.Local {
		case "id":
			if a.Name.Space != "" {
				continue
			}
			iq.ID = a.Value
		case "to":
			if a.Name.Space != "" {
				continue
			}
			iq.To, err = jid.Parse(a.Value)
			if err != nil {
				return iq, err
			}
		case "from":
			if a.Name.Space != "" {
				continue
			}
			iq.From, err = jid.Parse(a.Value)
			if err != nil {
				return iq, err
			}
		case "lang":
			if a.Name.Space != ns.XML {
				continue
			}
			iq.Lang = a.Value
		case "type":
			if a.Name.Space != "" {
				continue
			}
			iq.Type = stanza.IQType(a.Value)
		}
	}
	return iq, nil
}

func iqFallback(iq stanza.IQ, t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	if iq.Type == stanza.ErrorIQ {
		return nil
	}

	iq.To, iq.From = iq.From, iq.To
	iq.Type = "error"

	e := stanza.Error{
		Type:      stanza.Cancel,
		Condition: stanza.ServiceUnavailable,
	}
	_, err := xmlstream.Copy(t, stanza.WrapIQ(
		iq,
		e.TokenReader(),
	))
	return err
}

D mux/iq_test.go => mux/iq_test.go +0 -179
@@ 1,179 0,0 @@
// Copyright 2020 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package mux_test

import (
	"bytes"
	"encoding/xml"
	"errors"
	"io"
	"strconv"
	"strings"
	"testing"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal/xmpptest"
	"mellium.im/xmpp/mux"
	"mellium.im/xmpp/stanza"
)

const exampleNS = "com.example"

var passIQHandler mux.IQHandlerFunc = func(stanza.IQ, xmlstream.TokenReadEncoder, *xml.StartElement) error {
	return passTest
}

var failIQHandler mux.IQHandlerFunc = func(stanza.IQ, xmlstream.TokenReadEncoder, *xml.StartElement) error {
	return errors.New("mux_test: FAILED")
}

var iqTestCases = [...]struct {
	m      *mux.IQMux
	p      xml.Name
	iqType stanza.IQType
}{
	0: {
		// Exact match handler should be selected if available.
		m: mux.NewIQMux(
			mux.HandleIQ(stanza.GetIQ, xml.Name{Local: "a", Space: exampleNS}, failIQHandler),
			mux.HandleIQ(stanza.GetIQ, xml.Name{Local: "test", Space: "b"}, failIQHandler),
			mux.HandleIQFunc(stanza.GetIQ, xml.Name{Local: "test", Space: exampleNS}, passIQHandler),
		),
		p: xml.Name{Local: "test", Space: exampleNS},
	},
	1: {
		// If no exact match is available, fallback to the namespace wildcard
		// handler.
		m: mux.NewIQMux(
			mux.GetIQFunc(xml.Name{Local: "test", Space: ""}, passIQHandler),
			mux.HandleIQ(stanza.GetIQ, xml.Name{Local: "", Space: exampleNS}, failIQHandler),
		),
		p: xml.Name{Local: "test", Space: exampleNS},
	},
	2: {
		// If no exact match or namespace handler is available, fallback local name
		// handler.
		m: mux.NewIQMux(
			mux.HandleIQ(stanza.GetIQ, xml.Name{Local: "", Space: exampleNS}, passIQHandler),
		),
		p: xml.Name{Local: "test", Space: exampleNS},
	},
	3: {
		// If no exact match or localname/namespace wildcard is available, fallback
		// to just matching on type alone.
		m: mux.NewIQMux(
			mux.ResultIQFunc(xml.Name{Local: "test", Space: exampleNS}, failIQHandler),
			mux.ErrorIQFunc(passIQHandler),
		),
		p:      xml.Name{Local: "test", Space: exampleNS},
		iqType: stanza.ErrorIQ,
	},
	4: {
		// IQs must be routed correctly by type.
		m: mux.NewIQMux(
			mux.GetIQ(xml.Name{Local: "test", Space: exampleNS}, failIQHandler),
			mux.SetIQFunc(xml.Name{Local: "test", Space: exampleNS}, failIQHandler),
			mux.ResultIQ(xml.Name{Local: "test", Space: exampleNS}, passIQHandler),
			mux.ErrorIQ(passIQHandler),
		),
		p:      xml.Name{Local: "test", Space: exampleNS},
		iqType: stanza.ResultIQ,
	},
}

func TestIQMux(t *testing.T) {
	for i, tc := range iqTestCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			iqType := stanza.GetIQ
			if tc.iqType != "" {
				iqType = tc.iqType
			}
			err := tc.m.HandleXMPP(
				testReadEncoder{xmlstream.Wrap(nil, xml.StartElement{Name: tc.p})},
				&xml.StartElement{Name: xml.Name{Local: "iq"}, Attr: []xml.Attr{{Name: xml.Name{Local: "type"}, Value: string(iqType)}}},
			)
			if err != passTest {
				t.Fatalf("unexpected error: `%v'", err)
			}
		})
	}
}

type testReadEncoder struct {
	xml.TokenReader
}

func (testReadEncoder) EncodeToken(t xml.Token) error { return nil }

func (testReadEncoder) EncodeElement(interface{}, xml.StartElement) error {
	panic("unexpected EncodeElement")
}

func (testReadEncoder) Encode(interface{}) error { panic("unexpected Encode") }

func TestIQFallback(t *testing.T) {
	buf := &bytes.Buffer{}
	rw := struct {
		io.Reader
		io.Writer
	}{
		Reader: strings.NewReader(`<iq to="romeo@example.com" from="juliet@example.com" id="123"><test/></iq>`),
		Writer: buf,
	}
	s := xmpptest.NewSession(0, rw)

	r := s.TokenReader()
	defer r.Close()
	tok, err := r.Token()
	if err != nil {
		t.Fatalf("Bad start token read: `%v'", err)
	}
	start := tok.(xml.StartElement)
	w := s.TokenWriter()
	defer w.Close()
	err = mux.NewIQMux().HandleXMPP(testEncoder{
		TokenReader: r,
		TokenWriter: w,
	}, &start)
	if err != nil {
		t.Errorf("Unexpected error: `%v'", err)
	}
	if err := w.Flush(); err != nil {
		t.Errorf("Unexpected error flushing token writer: %q", err)
	}

	const expected = `<iq type="error" to="juliet@example.com" from="romeo@example.com" id="123"><error type="cancel"><service-unavailable xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"></service-unavailable></error></iq>`
	if buf.String() != expected {
		t.Errorf("Bad output:\nwant=`%v'\n got=`%v'", expected, buf.String())
	}
}

func TestNilIQHandlerPanics(t *testing.T) {
	defer func() {
		if r := recover(); r == nil {
			t.Error("Expected a panic when trying to register a nil IQ handler")
		}
	}()
	mux.NewIQMux(mux.GetIQ(xml.Name{}, nil))
}

func TestIdenticalIQHandlerPanics(t *testing.T) {
	defer func() {
		if r := recover(); r == nil {
			t.Error("Expected a panic when trying to register a duplicate IQ handler")
		}
	}()
	mux.NewIQMux(
		mux.GetIQ(xml.Name{Space: "space", Local: "local"}, failIQHandler),
		mux.GetIQ(xml.Name{Space: "space", Local: "local"}, failIQHandler),
	)
}

func TestLazyIQMuxMapInitialization(t *testing.T) {
	m := &mux.IQMux{}

	// This will panic if the map isn't initialized lazily.
	mux.GetIQ(xml.Name{}, failIQHandler)(m)
}

M mux/mux.go => mux/mux.go +323 -40
@@ 11,8 11,15 @@ import (
	"mellium.im/xmlstream"
	"mellium.im/xmpp"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/stanza"
)

type iqPattern struct {
	Payload xml.Name
	Type    stanza.IQType
}

// ServeMux is an XMPP stream multiplexer.
// It matches the start element token of each top level stream element against a
// list of registered patterns and calls the handler for the pattern that most


@@ 24,20 31,10 @@ import (
// Full XML names take precedence, followed by wildcard localnames, followed by
// wildcard namespaces.
type ServeMux struct {
	patterns map[xml.Name]xmpp.Handler
}

func fallback(t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	if start.Name.Local != "iq" {
		return nil
	}

	iq, start, err := getPayload(t, start)
	if err != nil {
		return err
	}

	return iqFallback(iq, t, start)
	patterns         map[xml.Name]xmpp.Handler
	iqPatterns       map[iqPattern]IQHandler
	msgPatterns      map[stanza.MessageType]MessageHandler
	presencePatterns map[stanza.PresenceType]PresenceHandler
}

// New allocates and returns a new ServeMux.


@@ 73,11 70,80 @@ func (m *ServeMux) Handler(name xml.Name) (h xmpp.Handler, ok bool) {
		return h, true
	}

	return xmpp.HandlerFunc(fallback), false
	// Route stanzas.
	if name.Space == ns.Client || name.Space == ns.Server {
		switch name.Local {
		case "iq":
			return xmpp.HandlerFunc(m.iqRouter), true
		case "message":
			return xmpp.HandlerFunc(m.msgRouter), true
		case "presence":
			return xmpp.HandlerFunc(m.presenceRouter), true
		}
	}

	return nopHandler{}, false
}

// IQHandler returns the handler to use for an IQ payload with the given type
// and payload name.
// If no handler exists, a default handler is returned (h is always non-nil).
func (m *ServeMux) IQHandler(typ stanza.IQType, payload xml.Name) (h IQHandler, ok bool) {
	pattern := iqPattern{Payload: payload, Type: typ}
	h = m.iqPatterns[pattern]
	if h != nil {
		return h, true
	}

	pattern.Payload.Space = ""
	pattern.Payload.Local = payload.Local
	h = m.iqPatterns[pattern]
	if h != nil {
		return h, true
	}

	pattern.Payload.Space = payload.Space
	pattern.Payload.Local = ""
	h = m.iqPatterns[pattern]
	if h != nil {
		return h, true
	}

	pattern.Payload.Space = ""
	pattern.Payload.Local = ""
	h = m.iqPatterns[pattern]
	if h != nil {
		return h, true
	}

	return IQHandlerFunc(iqFallback), false
}

// MessageHandler returns the handler to use for a message payload with the
// given type.
// If no handler exists, a default handler is returned (h is always non-nil).
func (m *ServeMux) MessageHandler(typ stanza.MessageType) (h MessageHandler, ok bool) {
	h = m.msgPatterns[typ]
	if h != nil {
		return h, true
	}

	return nopHandler{}, false
}

// PresenceHandler returns the handler to use for a presence payload with the
// given type.
// If no handler exists, a default handler is returned (h is always non-nil).
func (m *ServeMux) PresenceHandler(typ stanza.PresenceType) (h PresenceHandler, ok bool) {
	h = m.presencePatterns[typ]
	if h != nil {
		return h, true
	}

	return nopHandler{}, false
}

// HandleXMPP dispatches the request to the handler whose pattern most closely
// matches start.Name.
// HandleXMPP dispatches the request to the handler that most closely matches.
func (m *ServeMux) HandleXMPP(t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	h, _ := m.Handler(start.Name)
	return h.HandleXMPP(t, start)


@@ 86,41 152,72 @@ func (m *ServeMux) HandleXMPP(t xmlstream.TokenReadEncoder, start *xml.StartElem
// Option configures a ServeMux.
type Option func(m *ServeMux)

func registerStanza(local string, h xmpp.Handler) Option {
// IQ returns an option that matches IQ stanzas based on their type and the name
// of the payload.
func IQ(typ stanza.IQType, payload xml.Name, h IQHandler) Option {
	return func(m *ServeMux) {
		Handle(xml.Name{Local: local, Space: ns.Client}, h)(m)
		Handle(xml.Name{Local: local, Space: ns.Server}, h)(m)
		if h == nil {
			panic("mux: nil IQ handler")
		}
		pattern := iqPattern{Payload: payload, Type: typ}
		if _, ok := m.iqPatterns[pattern]; ok {
			panic("mux: multiple registrations for " + string(typ) + " iq with payload {" + pattern.Payload.Space + "}" + pattern.Payload.Local)
		}
		if m.iqPatterns == nil {
			m.iqPatterns = make(map[iqPattern]IQHandler)
		}
		m.iqPatterns[pattern] = h
	}
}

// IQ returns an option that matches on all IQ stanzas.
func IQ(h xmpp.Handler) Option {
	return registerStanza("iq", h)
}

// IQFunc returns an option that matches on all IQ stanzas.
func IQFunc(h xmpp.HandlerFunc) Option {
	return IQ(h)
// IQFunc returns an option that matches IQ stanzas.
// For more information see IQ.
func IQFunc(typ stanza.IQType, payload xml.Name, h IQHandler) Option {
	return IQ(typ, payload, h)
}

// Message returns an option that matches on all message stanzas.
func Message(h xmpp.Handler) Option {
	return registerStanza("message", h)
// Message returns an option that matches message stanzas by type.
func Message(typ stanza.MessageType, h MessageHandler) Option {
	return func(m *ServeMux) {
		if h == nil {
			panic("mux: nil message handler")
		}
		if _, ok := m.msgPatterns[typ]; ok {
			panic("mux: multiple registrations for " + typ + " message")
		}
		if m.msgPatterns == nil {
			m.msgPatterns = make(map[stanza.MessageType]MessageHandler)
		}
		m.msgPatterns[typ] = h
	}
}

// MessageFunc returns an option that matches on all message stanzas.
func MessageFunc(h xmpp.HandlerFunc) Option {
	return Message(h)
// MessageFunc returns an option that matches message stanzas.
// For more information see Message.
func MessageFunc(typ stanza.MessageType, h MessageHandlerFunc) Option {
	return Message(typ, h)
}

// Presence returns an option that matches on all presence stanzas.
func Presence(h xmpp.Handler) Option {
	return registerStanza("presence", h)
// Presence returns an option that matches presence stanzas by type.
func Presence(typ stanza.PresenceType, h PresenceHandler) Option {
	return func(m *ServeMux) {
		if h == nil {
			panic("mux: nil presence handler")
		}
		if _, ok := m.presencePatterns[typ]; ok {
			panic("mux: multiple registrations for " + typ + " presence")
		}
		if m.presencePatterns == nil {
			m.presencePatterns = make(map[stanza.PresenceType]PresenceHandler)
		}
		m.presencePatterns[typ] = h
	}
}

// PresenceFunc returns an option that matches on all presence stanzas.
func PresenceFunc(h xmpp.HandlerFunc) Option {
	return Presence(h)
// PresenceFunc returns an option that matches on presence stanzas.
// For more information see Presence.
func PresenceFunc(typ stanza.PresenceType, h PresenceHandlerFunc) Option {
	return Presence(typ, h)
}

// Handle returns an option that matches on the provided XML name.


@@ 145,3 242,189 @@ func Handle(n xml.Name, h xmpp.Handler) Option {
func HandleFunc(n xml.Name, h xmpp.HandlerFunc) Option {
	return Handle(n, h)
}

type nopHandler struct{}

func (nopHandler) HandleXMPP(t xmlstream.TokenReadEncoder, start *xml.StartElement) error { return nil }
func (nopHandler) HandleMessage(msg stanza.Message, t xmlstream.TokenReadEncoder) error   { return nil }
func (nopHandler) HandlePresence(p stanza.Presence, t xmlstream.TokenReadEncoder) error   { return nil }

func (m *ServeMux) iqRouter(t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	iq, err := newIQFromStart(start)
	if err != nil {
		return err
	}

	tok, err := t.Token()
	if err != nil {
		return err
	}
	payloadStart, _ := tok.(xml.StartElement)
	h, _ := m.IQHandler(iq.Type, payloadStart.Name)
	return h.HandleIQ(iq, t, &payloadStart)
}

func (m *ServeMux) msgRouter(t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	msg, err := newMsgFromStart(start)
	if err != nil {
		return err
	}

	h, _ := m.MessageHandler(msg.Type)
	return h.HandleMessage(msg, t)
}

func (m *ServeMux) presenceRouter(t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	presence, err := newPresenceFromStart(start)
	if err != nil {
		return err
	}

	h, _ := m.PresenceHandler(presence.Type)
	return h.HandlePresence(presence, t)
}

func iqFallback(iq stanza.IQ, t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	if iq.Type == stanza.ErrorIQ {
		return nil
	}

	iq.To, iq.From = iq.From, iq.To
	iq.Type = "error"

	e := stanza.Error{
		Type:      stanza.Cancel,
		Condition: stanza.ServiceUnavailable,
	}
	_, err := xmlstream.Copy(t, stanza.WrapIQ(
		iq,
		e.TokenReader(),
	))
	return err
}

// newIQFromStart takes a start element and returns an IQ.
func newIQFromStart(start *xml.StartElement) (stanza.IQ, error) {
	iq := stanza.IQ{}
	var err error
	for _, a := range start.Attr {
		switch a.Name.Local {
		case "id":
			if a.Name.Space != "" {
				continue
			}
			iq.ID = a.Value
		case "to":
			if a.Name.Space != "" {
				continue
			}
			iq.To, err = jid.Parse(a.Value)
			if err != nil {
				return iq, err
			}
		case "from":
			if a.Name.Space != "" {
				continue
			}
			iq.From, err = jid.Parse(a.Value)
			if err != nil {
				return iq, err
			}
		case "lang":
			if a.Name.Space != ns.XML {
				continue
			}
			iq.Lang = a.Value
		case "type":
			if a.Name.Space != "" {
				continue
			}
			iq.Type = stanza.IQType(a.Value)
		}
	}
	return iq, nil
}

// newMsgFromStart takes a start element and returns a message.
func newMsgFromStart(start *xml.StartElement) (stanza.Message, error) {
	msg := stanza.Message{}
	var err error
	for _, a := range start.Attr {
		switch a.Name.Local {
		case "id":
			if a.Name.Space != "" {
				continue
			}
			msg.ID = a.Value
		case "to":
			if a.Name.Space != "" {
				continue
			}
			msg.To, err = jid.Parse(a.Value)
			if err != nil {
				return msg, err
			}
		case "from":
			if a.Name.Space != "" {
				continue
			}
			msg.From, err = jid.Parse(a.Value)
			if err != nil {
				return msg, err
			}
		case "lang":
			if a.Name.Space != ns.XML {
				continue
			}
			msg.Lang = a.Value
		case "type":
			if a.Name.Space != "" {
				continue
			}
			msg.Type = stanza.MessageType(a.Value)
		}
	}
	return msg, nil
}

// newPresenceFromStart takes a start element and returns a message.
func newPresenceFromStart(start *xml.StartElement) (stanza.Presence, error) {
	presence := stanza.Presence{}
	var err error
	for _, a := range start.Attr {
		switch a.Name.Local {
		case "id":
			if a.Name.Space != "" {
				continue
			}
			presence.ID = a.Value
		case "to":
			if a.Name.Space != "" {
				continue
			}
			presence.To, err = jid.Parse(a.Value)
			if err != nil {
				return presence, err
			}
		case "from":
			if a.Name.Space != "" {
				continue
			}
			presence.From, err = jid.Parse(a.Value)
			if err != nil {
				return presence, err
			}
		case "lang":
			if a.Name.Space != ns.XML {
				continue
			}
			presence.Lang = a.Value
		case "type":
			if a.Name.Space != "" {
				continue
			}
			presence.Type = stanza.PresenceType(a.Value)
		}
	}
	return presence, nil
}

M mux/mux_test.go => mux/mux_test.go +216 -48
@@ 14,70 14,256 @@ import (
	"testing"

	"mellium.im/xmlstream"
	"mellium.im/xmpp"
	"mellium.im/xmpp/internal/marshal"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/internal/xmpptest"
	"mellium.im/xmpp/mux"
	"mellium.im/xmpp/stanza"
)

var passTest = errors.New("mux_test: PASSED")
var (
	passTest = errors.New("mux_test: PASSED")
	failTest = errors.New("mux_test: FAILED")
)

const exampleNS = "com.example"

var passHandler xmpp.HandlerFunc = func(xmlstream.TokenReadEncoder, *xml.StartElement) error {
type passHandler struct{}

func (passHandler) HandleXMPP(xmlstream.TokenReadEncoder, *xml.StartElement) error   { return passTest }
func (passHandler) HandleMessage(stanza.Message, xmlstream.TokenReadEncoder) error   { return passTest }
func (passHandler) HandlePresence(stanza.Presence, xmlstream.TokenReadEncoder) error { return passTest }
func (passHandler) HandleIQ(stanza.IQ, xmlstream.TokenReadEncoder, *xml.StartElement) error {
	return passTest
}

var failHandler xmpp.HandlerFunc = func(xmlstream.TokenReadEncoder, *xml.StartElement) error {
	return errors.New("mux_test: FAILED")
type failHandler struct{}

func (failHandler) HandleXMPP(t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	return failTest
}
func (failHandler) HandleMessage(msg stanza.Message, t xmlstream.TokenReadEncoder) error {
	return failTest
}
func (failHandler) HandlePresence(p stanza.Presence, t xmlstream.TokenReadEncoder) error {
	return failTest
}
func (failHandler) HandleIQ(iq stanza.IQ, t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	return failTest
}

var testCases = [...]struct {
	m *mux.ServeMux
	p xml.Name
	m           []mux.Option
	x           string
	expectNil   bool
	expectPanic bool
}{
	0: {
		m: mux.New(mux.IQ(passHandler), mux.Presence(failHandler)),
		p: xml.Name{Local: "iq", Space: ns.Client},
		// Basic muxing based on localname and IQ type should work.
		m: []mux.Option{
			mux.IQ(stanza.GetIQ, xml.Name{}, passHandler{}),
			mux.IQ(stanza.SetIQ, xml.Name{}, failHandler{}),
			mux.Presence(stanza.AvailablePresence, failHandler{}),
		},
		x: `<iq xml:lang="en-us" type="get" xmlns="jabber:client"></iq>`,
	},
	1: {
		m: mux.New(mux.IQFunc(passHandler), mux.Presence(failHandler)),
		p: xml.Name{Local: "iq", Space: ns.Server},
		// Basic muxing isn't affected by the server namespace.
		m: []mux.Option{
			mux.IQFunc(stanza.SetIQ, xml.Name{}, passHandler{}),
			mux.IQ(stanza.GetIQ, xml.Name{}, failHandler{}),
			mux.Presence(stanza.AvailablePresence, failHandler{}),
		},
		x: `<iq type="set" xmlns="jabber:server"></iq>`,
	},
	2: {
		m: mux.New(mux.IQ(failHandler), mux.Message(passHandler)),
		p: xml.Name{Local: "message", Space: ns.Client},
		// The message option works with a client namespace.
		m: []mux.Option{
			mux.IQ(stanza.GetIQ, xml.Name{}, failHandler{}),
			mux.Message(stanza.ChatMessage, passHandler{}),
		},
		x: `<message id="123" type="chat" xmlns="jabber:client"></message>`,
	},
	3: {
		m: mux.New(mux.IQ(failHandler), mux.MessageFunc(passHandler)),
		p: xml.Name{Local: "message", Space: ns.Server},
		// The message option works with a server namespace.
		m: []mux.Option{
			mux.IQ(stanza.GetIQ, xml.Name{}, failHandler{}),
			mux.MessageFunc(stanza.ChatMessage, passHandler{}.HandleMessage),
		},
		x: `<message to="feste@example.net" from="olivia@example.net" type="chat" xmlns="jabber:server"></message>`,
	},
	4: {
		m: mux.New(mux.Message(failHandler), mux.IQ(failHandler), mux.Presence(passHandler)),
		p: xml.Name{Local: "presence", Space: ns.Client},
		// The presence option works with a client namespace and no type attribute.
		m: []mux.Option{
			mux.Message(stanza.HeadlineMessage, failHandler{}),
			mux.IQ(stanza.SetIQ, xml.Name{}, failHandler{}),
			mux.Presence(stanza.AvailablePresence, passHandler{}),
		},
		x: `<presence id="484" xml:lang="es" xmlns="jabber:client"></presence>`,
	},
	5: {
		m: mux.New(mux.Message(failHandler), mux.IQ(failHandler), mux.PresenceFunc(passHandler)),
		p: xml.Name{Local: "presence", Space: ns.Server},
		m: []mux.Option{
			// The presence option works with a server namespace and an empty type
			// attribute.
			mux.Message(stanza.ChatMessage, failHandler{}),
			mux.IQ(stanza.GetIQ, xml.Name{}, failHandler{}),
			mux.PresenceFunc(stanza.AvailablePresence, passHandler{}.HandlePresence),
		},
		x: `<presence type="" xmlns="jabber:server"></presence>`,
	},
	6: {
		m: mux.New(mux.IQ(passHandler)),
		p: xml.Name{Local: "iq", Space: ns.Server},
		// Other top level elements can be routed with a wildcard namespace.
		m: []mux.Option{mux.Handle(xml.Name{Local: "test"}, passHandler{})},
		x: `<test xmlns="summertime"/>`,
	},
	7: {
		m: mux.New(mux.Handle(xml.Name{Local: "test"}, passHandler)),
		p: xml.Name{Local: "test", Space: "summertime"},
		// Other top level elements can be routed with a wildcard localname.
		m: []mux.Option{mux.HandleFunc(xml.Name{Space: "summertime"}, passHandler{}.HandleXMPP)},
		x: `<test xmlns="summertime"/>`,
	},
	8: {
		m: mux.New(mux.HandleFunc(xml.Name{Space: "summertime"}, passHandler)),
		p: xml.Name{Local: "test", Space: "summertime"},
		// Other top level elements can be routed with an exact match.
		m: []mux.Option{
			mux.Handle(xml.Name{Local: "test"}, failHandler{}),
			mux.HandleFunc(xml.Name{Space: "summertime"}, failHandler{}.HandleXMPP),
			mux.HandleFunc(xml.Name{Local: "test", Space: "summertime"}, passHandler{}.HandleXMPP),
		},
		x: `<test xmlns="summertime"/>`,
	},
	9: {
		// IQ exact child match handler.
		m: []mux.Option{
			mux.IQ(stanza.GetIQ, xml.Name{Local: "a", Space: exampleNS}, failHandler{}),
			mux.IQ(stanza.GetIQ, xml.Name{Local: "test", Space: "b"}, failHandler{}),
			mux.IQ(stanza.SetIQ, xml.Name{Local: "a", Space: exampleNS}, failHandler{}),
			mux.IQ(stanza.SetIQ, xml.Name{Local: "test", Space: "b"}, failHandler{}),
			mux.IQ(stanza.SetIQ, xml.Name{Local: "test", Space: exampleNS}, failHandler{}),
			mux.IQ(stanza.GetIQ, xml.Name{Local: "test", Space: exampleNS}, passHandler{}),
		},
		x: `<iq type="get" xmlns="jabber:client"><test xmlns="com.example"/></iq>`,
	},
	10: {
		// If no exact match is available, fallback to the namespace wildcard
		// handler.
		m: []mux.Option{
			mux.IQ(stanza.GetIQ, xml.Name{Local: "test", Space: ""}, passHandler{}),
			mux.IQ(stanza.GetIQ, xml.Name{Local: "", Space: exampleNS}, failHandler{}),
		},
		x: `<iq type="get" xmlns="jabber:client"><test xmlns="com.example"/></iq>`,
	},
	11: {
		// If no exact match or namespace handler is available, fallback local name
		// handler.
		m: []mux.Option{
			mux.IQ(stanza.GetIQ, xml.Name{Local: "", Space: exampleNS}, passHandler{}),
			mux.IQ(stanza.ResultIQ, xml.Name{Local: "", Space: exampleNS}, failHandler{}),
		},
		x: `<iq type="get" xmlns="jabber:client"><test xmlns="com.example"/></iq>`,
	},
	12: {
		// If no exact match or localname/namespace wildcard is available, fallback
		// to just matching on type alone.
		m: []mux.Option{
			mux.IQ(stanza.ResultIQ, xml.Name{Local: "test", Space: exampleNS}, failHandler{}),
			mux.IQ(stanza.ErrorIQ, xml.Name{}, passHandler{}),
		},
		x: `<iq type="error" xmlns="jabber:client"><test xmlns="com.example"/></iq>`,
	},
	13: {
		// Test nop non-stanza handler.
		x:         `<nop/>`,
		expectNil: true,
	},
	14: {
		// Test nop message handler.
		m:         []mux.Option{mux.Message(stanza.HeadlineMessage, failHandler{})},
		x:         `<message xml:lang="de" type="chat" xmlns="jabber:server"/>`,
		expectNil: true,
	},
	15: {
		// Test nop presence handler.
		m:         []mux.Option{mux.Presence(stanza.SubscribedPresence, failHandler{})},
		x:         `<presence to="romeo@example.net" from="mercutio@example.net" xmlns="jabber:server"/>`,
		expectNil: true,
	},
	16: {
		// Expect nil IQ handler to panic.
		m:           []mux.Option{mux.IQ(stanza.GetIQ, xml.Name{}, nil)},
		expectPanic: true,
	},
	17: {
		// Expect nil message handler to panic.
		m:           []mux.Option{mux.Message(stanza.ChatMessage, nil)},
		expectPanic: true,
	},
	18: {
		// Expect nil presence handler to panic.
		m:           []mux.Option{mux.Presence(stanza.ProbePresence, nil)},
		expectPanic: true,
	},
	19: {
		// Expect nil top level element handler to panic.
		m:           []mux.Option{mux.Handle(xml.Name{Local: "test"}, nil)},
		expectPanic: true,
	},
	20: {
		// Expect duplicate IQ handler to panic.
		m: []mux.Option{
			mux.IQ(stanza.GetIQ, xml.Name{}, failHandler{}),
			mux.IQ(stanza.GetIQ, xml.Name{}, failHandler{}),
		},
		expectPanic: true,
	},
	21: {
		// Expect duplicate message handler to panic.
		m: []mux.Option{
			mux.Message(stanza.ChatMessage, failHandler{}),
			mux.Message(stanza.ChatMessage, failHandler{}),
		},
		expectPanic: true,
	},
	22: {
		// Expect duplicate presence handler to panic.
		m: []mux.Option{
			mux.Presence(stanza.ProbePresence, failHandler{}),
			mux.Presence(stanza.ProbePresence, failHandler{}),
		},
		expectPanic: true,
	},
	23: {
		// Expect duplicate top level element handler to panic.
		m: []mux.Option{
			mux.Handle(xml.Name{Local: "test"}, failHandler{}),
			mux.Handle(xml.Name{Local: "test"}, failHandler{}),
		},
		expectPanic: true,
	},
}

type nopEncoder struct {
	xml.TokenReader
}

func (nopEncoder) Encode(interface{}) error                          { return nil }
func (nopEncoder) EncodeElement(interface{}, xml.StartElement) error { return nil }
func (nopEncoder) EncodeToken(xml.Token) error                       { return nil }

func TestMux(t *testing.T) {
	for i, tc := range testCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			err := tc.m.HandleXMPP(nil, &xml.StartElement{Name: tc.p})
			if err != passTest {
			if tc.expectPanic {
				defer func() {
					if r := recover(); r == nil {
						t.Errorf("Expected panic")
					}
				}()
			}
			m := mux.New(tc.m...)
			d := xml.NewDecoder(strings.NewReader(tc.x))
			tok, _ := d.Token()
			start := tok.(xml.StartElement)

			err := m.HandleXMPP(nopEncoder{TokenReader: d}, &start)
			if (!tc.expectNil && err != passTest) || (tc.expectNil && err != nil) {
				t.Fatalf("unexpected error: `%v'", err)
			}
		})


@@ 102,7 288,7 @@ func TestFallback(t *testing.T) {
		io.Reader
		io.Writer
	}{
		Reader: strings.NewReader(`<iq to="romeo@example.com" from="juliet@example.com" id="123"><test/></iq>`),
		Reader: strings.NewReader(`<iq xmlns="jabber:client" to="romeo@example.com" from="juliet@example.com" id="123"><test/></iq>`),
		Writer: buf,
	}
	s := xmpptest.NewSession(0, rw)


@@ 133,27 319,9 @@ func TestFallback(t *testing.T) {
	}
}

func TestNilHandlerPanics(t *testing.T) {
	defer func() {
		if r := recover(); r == nil {
			t.Error("Expected a panic when trying to register a nil handler")
		}
	}()
	mux.New(mux.IQ(nil))
}

func TestIdenticalHandlerPanics(t *testing.T) {
	defer func() {
		if r := recover(); r == nil {
			t.Error("Expected a panic when trying to register a duplicate handler")
		}
	}()
	mux.New(mux.IQ(failHandler), mux.IQ(failHandler))
}

func TestLazyServeMuxMapInitialization(t *testing.T) {
	m := &mux.ServeMux{}

	// This will panic if the map isn't initialized lazily.
	mux.IQ(failHandler)(m)
	mux.IQ(stanza.GetIQ, xml.Name{}, failHandler{})(m)
}

A mux/stanza.go => mux/stanza.go +60 -0
@@ 0,0 1,60 @@
// Copyright 2020 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package mux

import (
	"encoding/xml"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/stanza"
)

// IQHandler responds to IQ stanzas.
type IQHandler interface {
	HandleIQ(iq stanza.IQ, t xmlstream.TokenReadEncoder, start *xml.StartElement) error
}

// The IQHandlerFunc type is an adapter to allow the use of ordinary functions
// as IQ handlers.
// If f is a function with the appropriate signature, IQHandlerFunc(f) is an
// IQHandler that calls f.
type IQHandlerFunc func(iq stanza.IQ, t xmlstream.TokenReadEncoder, start *xml.StartElement) error

// HandleIQ calls f(iq, t, start).
func (f IQHandlerFunc) HandleIQ(iq stanza.IQ, t xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	return f(iq, t, start)
}

// MessageHandler responds to message stanzas.
type MessageHandler interface {
	HandleMessage(msg stanza.Message, t xmlstream.TokenReadEncoder) error
}

// The MessageHandlerFunc type is an adapter to allow the use of ordinary
// functions as message handlers.
// If f is a function with the appropriate signature, MessageHandlerFunc(f) is a
// MessageHandler that calls f.
type MessageHandlerFunc func(msg stanza.Message, t xmlstream.TokenReadEncoder) error

// HandleMessage calls f(msg, t).
func (f MessageHandlerFunc) HandleMessage(msg stanza.Message, t xmlstream.TokenReadEncoder) error {
	return f(msg, t)
}

// PresenceHandler responds to message stanzas.
type PresenceHandler interface {
	HandlePresence(p stanza.Presence, t xmlstream.TokenReadEncoder) error
}

// The PresenceHandlerFunc type is an adapter to allow the use of ordinary
// functions as presence handlers.
// If f is a function with the appropriate signature, PresenceHandlerFunc(f) is
// a PresenceHandler that calls f.
type PresenceHandlerFunc func(p stanza.Presence, t xmlstream.TokenReadEncoder) error

// HandlePresence calls f(p, t).
func (f PresenceHandlerFunc) HandlePresence(p stanza.Presence, t xmlstream.TokenReadEncoder) error {
	return f(p, t)
}