~samwhited/xmpp

b175c3b4857ee29910b20d2acb57e092fd86450c — Sam Whited 11 months ago 79e4455
receipts: add struct and stream APIs

Previously we had to block waiting on response if we wanted to request
read receipts. However, this is more likely to be used asynchronously
where we associate the receipt with the original message using the ID.
To faciliate this, one type and one function was added: one for the
struct based API and one for the stream based API.

Fixes #89

Signed-off-by: Sam Whited <sam@samwhited.com>
5 files changed, 180 insertions(+), 6 deletions(-)

M CHANGELOG.md
M go.mod
M go.sum
M receipts/receipts.go
M receipts/receipts_test.go
M CHANGELOG.md => CHANGELOG.md +2 -0
@@ 19,6 19,8 @@ All notable changes to this project will be documented in this file.
  connections
- oob: implementations of `xmlstream.Marshaler` and `xmlstream.WriterTo` for the
  types `IQ`, `Query`, and Data
- receipts: add `Request` and `Requested` to add receipt requests to messages
  without waiting on a response
- roster: add `Set` and `Delete` functions for roster management
- stream: new `InnerXML` and `ApplicationError` methods on `Error` provide a way
  to easily construct customized stream errors

M go.mod => go.mod +1 -1
@@ 8,5 8,5 @@ require (
	golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7
	golang.org/x/text v0.3.2
	mellium.im/sasl v0.2.1
	mellium.im/xmlstream v0.15.2-0.20201217125941-a994c65b8415
	mellium.im/xmlstream v0.15.2-0.20201219131358-a51cc5cf8151
)

M go.sum => go.sum +2 -2
@@ 17,5 17,5 @@ mellium.im/reader v0.1.0 h1:UUEMev16gdvaxxZC7fC08j7IzuDKh310nB6BlwnxTww=
mellium.im/reader v0.1.0/go.mod h1:F+X5HXpkIfJ9EE1zHQG9lM/hO946iYAmU7xjg5dsQHI=
mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w=
mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ=
mellium.im/xmlstream v0.15.2-0.20201217125941-a994c65b8415 h1:MUDmwN48+7kNNTH8X02YES5p8WyQxD3DeEeUHXFuaS0=
mellium.im/xmlstream v0.15.2-0.20201217125941-a994c65b8415/go.mod h1:7SUlP7f2qnMczK+Cu/OFgqaIhldMolVjo8np7xG41D0=
mellium.im/xmlstream v0.15.2-0.20201219131358-a51cc5cf8151 h1:wAvC0xKy3hCct2M7MizuqiNBH2VXtvVN9bmCrMsklPk=
mellium.im/xmlstream v0.15.2-0.20201219131358-a51cc5cf8151/go.mod h1:7SUlP7f2qnMczK+Cu/OFgqaIhldMolVjo8np7xG41D0=

M receipts/receipts.go => receipts/receipts.go +63 -3
@@ 25,6 25,68 @@ const (
	NS = "urn:xmpp:receipts"
)

// Requested is a type that can be added to messages to request a read receipt.
// When unmarshaled or marshaled its value indicates whether it was or will be
// present in the message.
//
// This type is used to manually include a request in a message struct.
// To send a message and wait for the receipt see the methods on Handler.
type Requested struct {
	XMLName xml.Name `xml:"urn:xmpp:receipts request"`
	Value   bool
}

// TokenReader implements xmlstream.Marshaler.
func (r Requested) TokenReader() xml.TokenReader {
	return xmlstream.Wrap(
		nil,
		xml.StartElement{Name: xml.Name{Space: NS, Local: "request"}},
	)
}

// WriteXML implements xmlstream.WriterTo.
func (r Requested) WriteXML(w xmlstream.TokenWriter) (int, error) {
	return xmlstream.Copy(w, r.TokenReader())
}

// MarshalXML implements xml.Marshaler.
func (r Requested) MarshalXML(e *xml.Encoder, _ xml.StartElement) error {
	_, err := r.WriteXML(e)
	if err != nil {
		return err
	}
	return e.Flush()
}

// UnmarshalXML implements xml.Unmarshaler.
func (r *Requested) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
	r.Value = start.Name.Space == NS && start.Name.Local == "request"
	return d.Skip()
}

var receiptInserter = xmlstream.InsertFunc(func(start xml.StartElement, w xmlstream.TokenWriter) error {
	if start.Name.Local != "message" || (start.Name.Space != ns.Client && start.Name.Space != ns.Server) {
		return nil
	}
	for _, attr := range start.Attr {
		if attr.Name.Local == "type" && attr.Value == "error" {
			return nil
		}
	}

	_, err := xmlstream.Copy(w, Requested{Value: true}.TokenReader())
	return err
})

// Request is an xmlstream.Transformer that inserts a request for a read receipt
// into any message read through r.
// It is provided to allow easily requesting read receipts asynchronously.
// To send a message and block waiting on a read receipt, see the methods on
// Handler.
func Request(r xml.TokenReader) xml.TokenReader {
	return receiptInserter(r)
}

// Handle returns an option that registers a Handler for message receipts.
func Handle(h *Handler) mux.Option {
	return func(m *mux.ServeMux) {


@@ 150,9 212,7 @@ func (h *Handler) SendMessageElement(ctx context.Context, s *xmpp.Session, paylo
	h.sent[msg.ID] = c
	h.m.Unlock()

	r := xmlstream.Wrap(nil, xml.StartElement{
		Name: xml.Name{Space: NS, Local: "request"},
	})
	r := Requested{Value: true}.TokenReader()
	if payload != nil {
		r = xmlstream.MultiReader(payload, r)
	}

M receipts/receipts_test.go => receipts/receipts_test.go +112 -0
@@ 8,6 8,7 @@ import (
	"bytes"
	"context"
	"encoding/xml"
	"strconv"
	"strings"
	"testing"



@@ 19,6 20,117 @@ import (
	"mellium.im/xmpp/stanza"
)

var (
	_ xml.Marshaler         = receipts.Requested{}
	_ xmlstream.Marshaler   = receipts.Requested{}
	_ xmlstream.WriterTo    = receipts.Requested{}
	_ xml.Unmarshaler       = (*receipts.Requested)(nil)
	_ xmlstream.Transformer = receipts.Request
)

var requestTestCases = [...]struct {
	in  string
	out string
}{
	0: {},
	1: {
		in:  `<message xmlns="jabber:client"/>`,
		out: `<message xmlns="jabber:client"><request xmlns="urn:xmpp:receipts"></request></message>`,
	},
	2: {
		in:  `<message xmlns="jabber:server"/><message xmlns="jabber:client"><body>test</body></message>`,
		out: `<message xmlns="jabber:server"><request xmlns="urn:xmpp:receipts"></request></message><message xmlns="jabber:client"><request xmlns="urn:xmpp:receipts"></request><body xmlns="jabber:client">test</body></message>`,
	},
	3: {
		in:  `<message xmlns="jabber:badns"/>`,
		out: `<message xmlns="jabber:badns"></message>`,
	},
	4: {
		in:  `<message xmlns="jabber:client" type="error"/>`,
		out: `<message xmlns="jabber:client" type="error"></message>`,
	},
	5: {
		in:  `<message xmlns="jabber:server" type="error"/>`,
		out: `<message xmlns="jabber:server" type="error"></message>`,
	},
}

func TestRequest(t *testing.T) {
	for i, tc := range requestTestCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			r := receipts.Request(xml.NewDecoder(strings.NewReader(tc.in)))
			// Prevent duplicate xmlns attributes. See https://mellium.im/issue/75
			r = xmlstream.RemoveAttr(func(start xml.StartElement, attr xml.Attr) bool {
				return (start.Name.Local == "message" || start.Name.Local == "iq") && attr.Name.Local == "xmlns"
			})(r)
			var buf strings.Builder
			e := xml.NewEncoder(&buf)
			_, err := xmlstream.Copy(e, r)
			if err != nil {
				t.Fatalf("error encoding: %v", err)
			}
			if err = e.Flush(); err != nil {
				t.Fatalf("error flushing: %v", err)
			}

			if out := buf.String(); tc.out != out {
				t.Errorf("wrong output:\nwant=%s,\n got=%s", tc.out, out)
			}
		})
	}
}

func TestMarshal(t *testing.T) {
	var buf strings.Builder
	e := xml.NewEncoder(&buf)
	err := e.Encode(struct {
		stanza.Message

		Requested receipts.Requested
	}{})
	if err != nil {
		t.Fatalf("error encoding: %v", err)
	}
	if err = e.Flush(); err != nil {
		t.Fatalf("error flushing: %v", err)
	}

	const expected = `<message to="" from=""><request xmlns="urn:xmpp:receipts"></request></message>`
	if out := buf.String(); expected != out {
		t.Errorf("wrong output:\nwant=%s,\n got=%s", expected, out)
	}
}

var unmarshalTestCases = [...]struct {
	in  string
	out bool
}{
	0: {
		in:  `<message><request xmlns="urn:xmpp:receipts"/></message>`,
		out: true,
	},
	1: {in: `<message><wrong xmlns="urn:xmpp:receipts"/></message>`},
	2: {in: `<message><request xmlns="urn:xmpp:wrongns"/></message>`},
}

func TestUnmarshal(t *testing.T) {
	for i, tc := range unmarshalTestCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			m := struct {
				stanza.Message
				Requested receipts.Requested
			}{}
			err := xml.NewDecoder(strings.NewReader(tc.in)).Decode(&m)
			if err != nil {
				t.Errorf("error decoding: %v", err)
			}
			if m.Requested.Value != tc.out {
				t.Errorf("bad decode: want=%t, got=%t", tc.out, m.Requested.Value)
			}
		})
	}
}

func TestClosedDoesNotPanic(t *testing.T) {
	h := &receipts.Handler{}