~samwhited/xmpp

d8bb9e72dd6794bd2c83a847a327eb685b5a114b — Sam Whited 1 year, 2 months ago 75e32aa
all: move stanza wrapping to methods

Previously to wrap a payload in a stanza you would use the functions
WrapIQ, WrapMessage, and WrapPresence. Each of these took their
respective stanza types and a payload.
These have been moved to Wrap methods on the various stanza types that
take a payload to make them easier to use in handlers where you already
have the stanza.
The down side is that these methods now exist on types that embed a
stanza, which may be confusing since the payload will be ignored and
only the stanza will be used.

Signed-off-by: Sam Whited <sam@samwhited.com>
M CHANGELOG.md => CHANGELOG.md +6 -0
@@ 4,6 4,12 @@ All notable changes to this project will be documented in this file.

## Unreleased

### Breaking

- mux: move `Wrap{IQ,Presence,Message}` functions to methods on the stanza types
- mux: new handler types and API


### Added

- mux: ability to select handlers by stanza payload

M bind.go => bind.go +2 -2
@@ 52,12 52,12 @@ type bindIQ struct {

func (biq *bindIQ) TokenReader() xml.TokenReader {
	if biq.Err != nil {
		return stanza.WrapIQ(biq.IQ, xmlstream.Wrap(biq.Err.TokenReader(),
		return biq.Wrap(xmlstream.Wrap(biq.Err.TokenReader(),
			xml.StartElement{Name: xml.Name{Local: "bind", Space: ns.Bind}},
		))
	}

	return stanza.WrapIQ(biq.IQ, xmlstream.Wrap(biq.Bind.TokenReader(),
	return biq.Wrap(xmlstream.Wrap(biq.Bind.TokenReader(),
		xml.StartElement{Name: xml.Name{Local: "bind", Space: ns.Bind}},
	))
}

M echobot_example_test.go => echobot_example_test.go +1 -1
@@ 56,7 56,7 @@ func Example_echobot() {
	}()

	// Send initial presence to let the server know we want to receive messages.
	err = s.Send(context.TODO(), stanza.WrapPresence(jid.JID{}, stanza.AvailablePresence, nil))
	err = s.Send(context.TODO(), stanza.Presence{Type: stanza.AvailablePresence}.Wrap(nil))
	if err != nil {
		log.Printf("Error sending initial presence: %q", err)
		return

M examples/echobot/echo.go => examples/echobot/echo.go +1 -1
@@ 71,7 71,7 @@ func echo(ctx context.Context, addr, pass string, xmlIn, xmlOut io.Writer, logge
	}()

	// Send initial presence to let the server know we want to receive messages.
	err = s.Send(ctx, stanza.WrapPresence(jid.JID{}, stanza.AvailablePresence, nil))
	err = s.Send(ctx, stanza.Presence{Type: stanza.AvailablePresence}.Wrap(nil))
	if err != nil {
		return fmt.Errorf("Error sending initial presence: %w", err)
	}

M examples/msgrepl/main.go => examples/msgrepl/main.go +1 -1
@@ 87,7 87,7 @@ func main() {
	}()

	// Send initial presence to let the server know we want to receive messages.
	err = session.Send(ctx, stanza.WrapPresence(jid.JID{}, stanza.AvailablePresence, nil))
	err = session.Send(ctx, stanza.Presence{Type: stanza.AvailablePresence}.Wrap(nil))
	if err != nil {
		logger.Fatalf("Error sending initial presence: %w", err)
	}

M mux/mux.go => mux/mux.go +1 -1
@@ 497,6 497,6 @@ func iqFallback(iq stanza.IQ, t xmlstream.TokenReadEncoder, start *xml.StartElem
		Type:      stanza.Cancel,
		Condition: stanza.ServiceUnavailable,
	}
	_, err := xmlstream.Copy(t, stanza.WrapIQ(iq, e.TokenReader()))
	_, err := xmlstream.Copy(t, iq.Wrap(e.TokenReader()))
	return err
}

M mux/mux_test.go => mux/mux_test.go +1 -1
@@ 518,7 518,7 @@ func TestFallback(t *testing.T) {
		t.Errorf("Unexpected error flushing token writer: %q", err)
	}

	const expected = `<iq type="error" to="juliet@example.com" from="romeo@example.com" id="123"><error type="cancel"><service-unavailable xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"></service-unavailable></error></iq>`
	const expected = `<iq xmlns="jabber:client" type="error" to="juliet@example.com" from="romeo@example.com" id="123"><error type="cancel"><service-unavailable xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"></service-unavailable></error></iq>`
	if buf.String() != expected {
		t.Errorf("Bad output:\nwant=`%v'\n got=`%v'", expected, buf.String())
	}

M ping/ping.go => ping/ping.go +1 -4
@@ 84,8 84,5 @@ func (iq IQ) WriteXML(w xmlstream.TokenWriter) (n int, err error) {
// TokenReader satisfies the xmlstream.Marshaler interface.
func (iq IQ) TokenReader() xml.TokenReader {
	start := xml.StartElement{Name: xml.Name{Local: "ping", Space: NS}}
	return stanza.WrapIQ(
		iq.IQ,
		xmlstream.Wrap(nil, start),
	)
	return iq.Wrap(xmlstream.Wrap(nil, start))
}

M ping/ping_test.go => ping/ping_test.go +1 -1
@@ 104,7 104,7 @@ func TestRoundTrip(t *testing.T) {
	out := b.String()
	// TODO: figure out a better way to ignore randomly generated IDs.
	out = regexp.MustCompile(`id=".*?"`).ReplaceAllString(out, `id="123"`)
	const expected = `<iq type="result" from="to@example.net" id="123"></iq>`
	const expected = `<iq xmlns="jabber:client" type="result" from="to@example.net" id="123"></iq>`
	if out != expected {
		t.Errorf("got=%s, want=%s", out, expected)
	}

M roster/roster.go => roster/roster.go +1 -1
@@ 158,7 158,7 @@ func (iq IQ) TokenReader() xml.TokenReader {
		iq.IQ.Type = stanza.GetIQ
	}

	return stanza.WrapIQ(iq.IQ, iq.payload())
	return iq.IQ.Wrap(iq.payload())
}

// Payload returns a stream of XML tokekns that match the roster query payload

M send_test.go => send_test.go +7 -7
@@ 59,12 59,12 @@ var sendIQTests = [...]struct {
	0: {
		iq:         stanza.IQ{ID: testIQID, Type: stanza.GetIQ},
		writesBody: true,
		resp:       stanza.WrapIQ(stanza.IQ{ID: testIQID, Type: stanza.ResultIQ}, nil),
		resp:       stanza.IQ{ID: testIQID, Type: stanza.ResultIQ}.Wrap(nil),
	},
	1: {
		iq:         stanza.IQ{ID: testIQID, Type: stanza.SetIQ},
		writesBody: true,
		resp:       stanza.WrapIQ(stanza.IQ{ID: testIQID, Type: stanza.ErrorIQ}, nil),
		resp:       stanza.IQ{ID: testIQID, Type: stanza.ErrorIQ}.Wrap(nil),
	},
	2: {
		iq:         stanza.IQ{Type: stanza.ResultIQ, ID: testIQID},


@@ 168,7 168,7 @@ func TestSendIQ(t *testing.T) {
				ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
				defer cancel()

				resp, err := s.SendIQ(ctx, stanza.WrapIQ(tc.iq, tc.payload))
				resp, err := s.SendIQ(ctx, tc.iq.Wrap(tc.payload))
				if err != tc.err {
					t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err)
				}


@@ 215,19 215,19 @@ var sendTests = [...]struct {
		err: xmpp.ErrNotStart,
	},
	2: {
		r:          stanza.WrapMessage(to, stanza.NormalMessage, nil),
		r:          stanza.Message{To: to, Type: stanza.NormalMessage}.Wrap(nil),
		writesBody: true,
	},
	3: {
		r:          stanza.WrapPresence(to, stanza.AvailablePresence, nil),
		r:          stanza.Presence{To: to, Type: stanza.AvailablePresence}.Wrap(nil),
		writesBody: true,
	},
	4: {
		r:          stanza.WrapIQ(stanza.IQ{Type: stanza.ResultIQ}, nil),
		r:          stanza.IQ{Type: stanza.ResultIQ}.Wrap(nil),
		writesBody: true,
	},
	5: {
		r:          stanza.WrapIQ(stanza.IQ{Type: stanza.ErrorIQ}, nil),
		r:          stanza.IQ{Type: stanza.ErrorIQ}.Wrap(nil),
		writesBody: true,
	},
}

M session.go => session.go +4 -4
@@ 410,10 410,10 @@ noreply:

	// If the user did not write a response to an IQ, send a default one.
	if needsResp && !rw.wroteResp {
		_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
		_, err := xmlstream.Copy(w, stanza.IQ{
			ID:   id,
			Type: stanza.ErrorIQ,
		}, stanza.Error{
		}.Wrap(stanza.Error{
			Type:      stanza.Cancel,
			Condition: stanza.ServiceUnavailable,
		}.TokenReader()))


@@ 794,12 794,12 @@ func (s *Session) SendIQElement(ctx context.Context, payload xml.TokenReader, iq

	// If this an IQ of type "set" or "get" we expect a response.
	if needsResp {
		return s.sendResp(ctx, iq.ID, stanza.WrapIQ(iq, payload))
		return s.sendResp(ctx, iq.ID, iq.Wrap(payload))
	}

	// If this is an IQ of type result or error, we don't expect a response so
	// just send it normally.
	return nil, s.Send(ctx, stanza.WrapIQ(iq, payload))
	return nil, s.Send(ctx, iq.Wrap(payload))
}

func (s *Session) sendResp(ctx context.Context, id string, payload xml.TokenReader) (xmlstream.TokenReadCloser, error) {

M session_test.go => session_test.go +8 -8
@@ 162,10 162,10 @@ var serveTests = [...]struct {
	},
	3: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
			_, err := xmlstream.Copy(rw, stanza.IQ{
				ID:   "1234",
				Type: stanza.ResultIQ,
			}, nil))
			}.Wrap(nil))
			return err
		}),
		in:  `<iq type="get" id="1234"><unknownpayload xmlns="unknown"/></iq>`,


@@ 173,10 173,10 @@ var serveTests = [...]struct {
	},
	4: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
			_, err := xmlstream.Copy(rw, stanza.IQ{
				ID:   "wrongid",
				Type: stanza.ResultIQ,
			}, nil))
			}.Wrap(nil))
			return err
		}),
		in:  `<iq type="get" id="1234"><unknownpayload xmlns="unknown"/></iq>`,


@@ 184,10 184,10 @@ var serveTests = [...]struct {
	},
	5: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
			_, err := xmlstream.Copy(rw, stanza.IQ{
				ID:   "1234",
				Type: stanza.ErrorIQ,
			}, nil))
			}.Wrap(nil))
			return err
		}),
		in:  `<iq type="get" id="1234"><unknownpayload xmlns="unknown"/></iq>`,


@@ 195,10 195,10 @@ var serveTests = [...]struct {
	},
	6: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
			_, err := xmlstream.Copy(rw, stanza.IQ{
				ID:   "1234",
				Type: stanza.GetIQ,
			}, nil))
			}.Wrap(nil))
			return err
		}),
		in:  `<iq type="get" id="1234"><unknownpayload xmlns="unknown"/></iq>`,

M stanza/example_pingstream_test.go => stanza/example_pingstream_test.go +1 -1
@@ 18,7 18,7 @@ import (
// a ping payload.
func WrapPingIQ(to jid.JID) xml.TokenReader {
	start := xml.StartElement{Name: xml.Name{Local: "ping", Space: "urn:xmpp:ping"}}
	return stanza.WrapIQ(stanza.IQ{To: to, Type: stanza.GetIQ}, xmlstream.Wrap(nil, start))
	return stanza.IQ{To: to, Type: stanza.GetIQ}.Wrap(xmlstream.Wrap(nil, start))
}

func Example_stream() {

M stanza/iq.go => stanza/iq.go +15 -46
@@ 12,40 12,6 @@ import (
	"mellium.im/xmpp/jid"
)

// WrapIQ wraps a payload in an IQ stanza.
// The resulting IQ may not contain an id or from attribute and thus may not be
// valid without further processing.
func WrapIQ(iq IQ, payload xml.TokenReader) xml.TokenReader {
	attr := []xml.Attr{
		{Name: xml.Name{Local: "type"}, Value: string(iq.Type)},
	}

	if !iq.To.Equal(jid.JID{}) {
		to, err := iq.To.MarshalXMLAttr(xml.Name{Space: "", Local: "to"})
		if err == nil && to.Value != "" {
			attr = append(attr, to)
		}
	}
	if !iq.From.Equal(jid.JID{}) {
		from, err := iq.From.MarshalXMLAttr(xml.Name{Space: "", Local: "from"})
		if err == nil && from.Value != "" {
			attr = append(attr, from)
		}
	}

	if iq.Lang != "" {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "lang", Space: ns.XML}, Value: iq.Lang})
	}
	if iq.ID != "" {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "id"}, Value: iq.ID})
	}

	return xmlstream.Wrap(payload, xml.StartElement{
		Name: xml.Name{Local: "iq"},
		Attr: attr,
	})
}

// IQ ("Information Query") is used as a general request response mechanism.
// IQ's are one-to-one, provide get and set semantics, and always require a
// response in the form of a result or an error.


@@ 74,21 40,19 @@ func (iq IQ) StartElement() xml.StartElement {
	name.Local = "iq"

	attr := make([]xml.Attr, 0, 5)
	if iq.ID != "" {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "id"}, Value: iq.ID})
	}
	attr = append(attr, xml.Attr{Name: xml.Name{Local: "type"}, Value: string(iq.Type)})
	if !iq.To.Equal(jid.JID{}) {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "to"}, Value: iq.To.String()})
	}
	if !iq.From.Equal(jid.JID{}) {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "from"}, Value: iq.From.String()})
	}
	if iq.ID != "" {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "id"}, Value: iq.ID})
	}
	if iq.Lang != "" {
		attr = append(attr, xml.Attr{Name: xml.Name{Space: ns.XML, Local: "lang"}, Value: iq.Lang})
	}
	if iq.Type != "" {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "type"}, Value: string(iq.Type)})
	}

	return xml.StartElement{
		Name: name,


@@ 96,16 60,21 @@ func (iq IQ) StartElement() xml.StartElement {
	}
}

// Wrap wraps the payload in a stanza.
//
// The resulting IQ may not contain an id or from attribute and thus may not be
// valid without further processing.
func (iq IQ) Wrap(payload xml.TokenReader) xml.TokenReader {
	return xmlstream.Wrap(payload, iq.StartElement())
}

// Result returns a token reader that wraps the first element from payload in an
// IQ stanza with the to and from attributes switched and the type set to
// ResultIQ.
func (iq IQ) Result(payload xml.TokenReader) xml.TokenReader {
	return WrapIQ(IQ{
		ID:   iq.ID,
		To:   iq.From,
		From: iq.To,
		Type: ResultIQ,
	}, payload)
	iq.Type = ResultIQ
	iq.From, iq.To = iq.To, iq.From
	return iq.Wrap(payload)
}

// IQType is the type of an IQ stanza.

M stanza/iq_test.go => stanza/iq_test.go +1 -1
@@ 48,7 48,7 @@ func TestIQ(t *testing.T) {
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			b := new(bytes.Buffer)
			e := xml.NewEncoder(b)
			iq := stanza.WrapIQ(stanza.IQ{To: jid.MustParse(tc.to), Type: tc.typ}, tc.payload)
			iq := stanza.IQ{To: jid.MustParse(tc.to), Type: tc.typ}.Wrap(tc.payload)
			if _, err := xmlstream.Copy(e, iq); err != tc.err {
				t.Errorf("Unexpected error: want=`%v', got=`%v'", tc.err, err)
			}

M stanza/message.go => stanza/message.go +6 -14
@@ 12,17 12,6 @@ import (
	"mellium.im/xmpp/jid"
)

// WrapMessage wraps a payload in a message stanza.
func WrapMessage(to jid.JID, typ MessageType, payload xml.TokenReader) xml.TokenReader {
	return xmlstream.Wrap(payload, xml.StartElement{
		Name: xml.Name{Local: "message"},
		Attr: []xml.Attr{
			{Name: xml.Name{Local: "to"}, Value: to.String()},
			{Name: xml.Name{Local: "type"}, Value: string(typ)},
		},
	})
}

// Message is an XMPP stanza that contains a payload for direct one-to-one
// communication with another network entity. It is often used for sending chat
// messages to an individual or group chat server, or for notifications and


@@ 52,6 41,7 @@ func (msg Message) StartElement() xml.StartElement {
	name.Local = "message"

	attr := make([]xml.Attr, 0, 5)
	attr = append(attr, xml.Attr{Name: xml.Name{Local: "type"}, Value: string(msg.Type)})
	if msg.ID != "" {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "id"}, Value: msg.ID})
	}


@@ 64,9 54,6 @@ func (msg Message) StartElement() xml.StartElement {
	if msg.Lang != "" {
		attr = append(attr, xml.Attr{Name: xml.Name{Space: ns.XML, Local: "lang"}, Value: msg.Lang})
	}
	if msg.Type != "" {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "type"}, Value: string(msg.Type)})
	}

	return xml.StartElement{
		Name: name,


@@ 74,6 61,11 @@ func (msg Message) StartElement() xml.StartElement {
	}
}

// Wrap wraps the payload in a stanza.
func (msg Message) Wrap(payload xml.TokenReader) xml.TokenReader {
	return xmlstream.Wrap(payload, msg.StartElement())
}

// MessageType is the type of a message stanza.
// It should normally be one of the constants defined in this package.
type MessageType string

M stanza/presence.go => stanza/presence.go +8 -18
@@ 12,24 12,6 @@ import (
	"mellium.im/xmpp/jid"
)

// WrapPresence wraps a payload in a presence stanza.
//
// If to is the zero value for jid.JID, no to attribute is set on the resulting
// presence.
func WrapPresence(to jid.JID, typ PresenceType, payload xml.TokenReader) xml.TokenReader {
	attrs := make([]xml.Attr, 0, 2)
	if !to.Equal(jid.JID{}) {
		attrs = append(attrs, xml.Attr{Name: xml.Name{Local: "to"}, Value: to.String()})
	}
	if typ != AvailablePresence {
		attrs = append(attrs, xml.Attr{Name: xml.Name{Local: "type"}, Value: string(typ)})
	}
	return xmlstream.Wrap(payload, xml.StartElement{
		Name: xml.Name{Local: "presence"},
		Attr: attrs,
	})
}

// Presence is an XMPP stanza that is used as an indication that an entity is
// available for communication. It is used to set a status message, broadcast
// availability, and advertise entity capabilities. It can be directed


@@ 81,6 63,14 @@ func (p Presence) StartElement() xml.StartElement {
	}
}

// Wrap wraps the payload in a stanza.
//
// If to is the zero value for jid.JID, no to attribute is set on the resulting
// presence.
func (p Presence) Wrap(payload xml.TokenReader) xml.TokenReader {
	return xmlstream.Wrap(payload, p.StartElement())
}

// PresenceType is the type of a presence stanza.
// It should normally be one of the constants defined in this package.
type PresenceType string

M stanza/presence_test.go => stanza/presence_test.go +1 -1
@@ 56,7 56,7 @@ func TestWrapPresence(t *testing.T) {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			buf := &bytes.Buffer{}
			e := xml.NewEncoder(buf)
			presence := stanza.WrapPresence(tc.to, tc.typ, tc.payload)
			presence := stanza.Presence{To: tc.to, Type: tc.typ}.Wrap(tc.payload)
			_, err := xmlstream.Copy(e, presence)
			if err != nil {
				t.Fatalf("Error encoding stream: %q", err)

M stanza/stanza_test.go => stanza/stanza_test.go +1 -1
@@ 58,7 58,7 @@ func TestMessage(t *testing.T) {
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			b := new(bytes.Buffer)
			e := xml.NewEncoder(b)
			message := stanza.WrapMessage(jid.MustParse(tc.to), tc.typ, tc.payload)
			message := stanza.Message{To: jid.MustParse(tc.to), Type: tc.typ}.Wrap(tc.payload)
			if _, err := xmlstream.Copy(e, message); err != tc.err {
				t.Errorf("Unexpected error: want=`%v', got=`%v'", tc.err, err)
			}

M xtime/time.go => xtime/time.go +2 -2
@@ 91,10 91,10 @@ func (t *Time) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {

// Get sends a request to the provided JID asking for its time.
func Get(ctx context.Context, s *xmpp.Session, to jid.JID) (time.Time, error) {
	result, err := s.SendIQ(ctx, stanza.WrapIQ(stanza.IQ{
	result, err := s.SendIQ(ctx, stanza.IQ{
		Type: stanza.GetIQ,
		To:   to,
	}, xmlstream.Wrap(nil, xml.StartElement{Name: xml.Name{Local: "time", Space: NS}})))
	}.Wrap(xmlstream.Wrap(nil, xml.StartElement{Name: xml.Name{Local: "time", Space: NS}})))
	var t time.Time
	if err != nil {
		return t, err

M xtime/time_test.go => xtime/time_test.go +1 -1
@@ 75,7 75,7 @@ func TestRoundTrip(t *testing.T) {
	out := b.String()
	// TODO: figure out a better way to ignore randomly generated IDs.
	out = regexp.MustCompile(`id=".*?"`).ReplaceAllString(out, `id="123"`)
	const expected = `<iq type="result" from="to@example.net" id="123"><time xmlns="urn:xmpp:time"><tzo>Z</tzo><utc>0001-01-01T00:00:00Z</utc></time></iq>`
	const expected = `<iq xmlns="jabber:server" type="result" from="to@example.net" id="123"><time xmlns="urn:xmpp:time"><tzo>Z</tzo><utc>0001-01-01T00:00:00Z</utc></time></iq>`
	if out != expected {
		t.Errorf("got=%s, want=%s", out, expected)
	}