~samwhited/xmpp

6b2ccd65ab85c3ba839eb4863e60555dde13a95f — Sam Whited 9 months ago c1f6e40
stanza: support unique and stable stanza IDs

Fixes #111

Signed-off-by: Sam Whited <sam@samwhited.com>
5 files changed, 248 insertions(+), 3 deletions(-)

M CHANGELOG.md
M go.mod
M go.sum
A stanza/id.go
A stanza/id_test.go
M CHANGELOG.md => CHANGELOG.md +8 -0
@@ 2,6 2,14 @@

All notable changes to this project will be documented in this file.

## Unreleased

### Added

- stanza: new functions `AddID` and `AddOriginID` to support unique and stable
  stanza IDs


## v0.18.0 — 2021-02-14

### Breaking

M go.mod => go.mod +1 -1
@@ 9,5 9,5 @@ require (
	golang.org/x/sys v0.0.0-20210110051926-789bb1bd4061
	golang.org/x/text v0.3.2
	mellium.im/sasl v0.2.1
	mellium.im/xmlstream v0.15.2-0.20201219131358-a51cc5cf8151
	mellium.im/xmlstream v0.15.2
)

M go.sum => go.sum +2 -2
@@ 19,5 19,5 @@ mellium.im/reader v0.1.0 h1:UUEMev16gdvaxxZC7fC08j7IzuDKh310nB6BlwnxTww=
mellium.im/reader v0.1.0/go.mod h1:F+X5HXpkIfJ9EE1zHQG9lM/hO946iYAmU7xjg5dsQHI=
mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w=
mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ=
mellium.im/xmlstream v0.15.2-0.20201219131358-a51cc5cf8151 h1:wAvC0xKy3hCct2M7MizuqiNBH2VXtvVN9bmCrMsklPk=
mellium.im/xmlstream v0.15.2-0.20201219131358-a51cc5cf8151/go.mod h1:7SUlP7f2qnMczK+Cu/OFgqaIhldMolVjo8np7xG41D0=
mellium.im/xmlstream v0.15.2 h1:RleOK10lEsVtzpEZsJeRl4Iu0iC5SQnTQIGJZ7ZHGEc=
mellium.im/xmlstream v0.15.2/go.mod h1:7SUlP7f2qnMczK+Cu/OFgqaIhldMolVjo8np7xG41D0=

A stanza/id.go => stanza/id.go +102 -0
@@ 0,0 1,102 @@
// Copyright 2020 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package stanza

import (
	"encoding/xml"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal/attr"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/jid"
)

// Namespaces used by this package, provided as a convenience.
const (
	// The namespace for unique and stable stanza and origin IDs.
	NSSid = "urn:xmpp:sid:0"
)

const idLen = 32

type ID struct {
	XMLName xml.Name `xml:"urn:xmpp:sid:0 stanza-id"`
	ID      string   `xml:"id,attr"`
	By      jid.JID  `xml:"by,attr"`
}

// TokenReader implements xmlstream.Marshaler.
func (id ID) TokenReader() xml.TokenReader {
	return xmlstream.Wrap(nil, xml.StartElement{
		Name: xml.Name{Space: NSSid, Local: "stanza-id"},
		Attr: []xml.Attr{
			{Name: xml.Name{Local: "id"}, Value: id.ID},
			{Name: xml.Name{Local: "by"}, Value: id.By.String()},
		},
	})
}

// WriteXML implements xmlstream.WriterTo.
func (id ID) WriteXML(w xmlstream.TokenWriter) (int, error) {
	return xmlstream.Copy(w, id.TokenReader())
}

type OriginID struct {
	XMLName xml.Name `xml:"urn:xmpp:sid:0 origin-id"`
	ID      string   `xml:"id,attr"`
}

// TokenReader implements xmlstream.Marshaler.
func (id OriginID) TokenReader() xml.TokenReader {
	return xmlstream.Wrap(nil, xml.StartElement{
		Name: xml.Name{Space: NSSid, Local: "origin-id"},
		Attr: []xml.Attr{
			{Name: xml.Name{Local: "id"}, Value: id.ID},
		},
	})
}

// WriteXML implements xmlstream.WriterTo.
func (id OriginID) WriteXML(w xmlstream.TokenWriter) (int, error) {
	return xmlstream.Copy(w, id.TokenReader())
}

func isStanza(name xml.Name) bool {
	return (name.Local == "iq" || name.Local == "message" || name.Local == "presence") &&
		(name.Space == ns.Client || name.Space == ns.Server)
}

// AddID returns an transformer that adds a random stanza ID to any stanzas that
// does not already have one.
func AddID(by jid.JID) xmlstream.Transformer {
	return xmlstream.InsertFunc(func(start xml.StartElement, level uint64, w xmlstream.TokenWriter) error {
		if isStanza(start.Name) && level == 1 {
			_, err := ID{
				ID: attr.RandomLen(idLen),
				By: by,
			}.WriteXML(w)
			return err
		}
		return nil
	})
}

var (
	addOriginID = xmlstream.InsertFunc(func(start xml.StartElement, level uint64, w xmlstream.TokenWriter) error {
		if isStanza(start.Name) && level == 1 {
			_, err := OriginID{
				ID: attr.RandomLen(idLen),
			}.WriteXML(w)
			return err
		}
		return nil
	})
)

// AddOriginID is an xmlstream.Transformer that adds an origin ID to any stanzas
// found in the input stream.
func AddOriginID(r xml.TokenReader) xml.TokenReader {
	return addOriginID(r)
}

A stanza/id_test.go => stanza/id_test.go +135 -0
@@ 0,0 1,135 @@
// Copyright 2020 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package stanza_test

import (
	"encoding/xml"
	"regexp"
	"strconv"
	"strings"
	"testing"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/stanza"
)

const (
	testOrigin = `<origin-id xmlns="urn:xmpp:sid:0" id="abc"></origin-id>`
	testStanza = `<stanza-id xmlns="urn:xmpp:sid:0" id="abc" by="test@example.net"></stanza-id>`
)

var idTestCases = [...]struct {
	in     string
	origin string
	id     string
}{
	0: {
		in:     `<message xmlns="jabber:client"></message>`,
		origin: `<message xmlns="jabber:client">` + testOrigin + `</message>`,
		id:     `<message xmlns="jabber:client">` + testStanza + `</message>`,
	},
	1: {
		in:     `<iq xmlns="jabber:client"></iq>`,
		origin: `<iq xmlns="jabber:client">` + testOrigin + `</iq>`,
		id:     `<iq xmlns="jabber:client">` + testStanza + `</iq>`,
	},
	2: {
		in:     `<presence xmlns="jabber:client"></presence>`,
		origin: `<presence xmlns="jabber:client">` + testOrigin + `</presence>`,
		id:     `<presence xmlns="jabber:client">` + testStanza + `</presence>`,
	},
	3: {
		in:     `<message xmlns="jabber:server"></message>`,
		origin: `<message xmlns="jabber:server">` + testOrigin + `</message>`,
		id:     `<message xmlns="jabber:server">` + testStanza + `</message>`,
	},
	4: {
		in:     `<iq xmlns="jabber:server"></iq>`,
		origin: `<iq xmlns="jabber:server">` + testOrigin + `</iq>`,
		id:     `<iq xmlns="jabber:server">` + testStanza + `</iq>`,
	},
	5: {
		in:     `<presence xmlns="jabber:server"></presence>`,
		origin: `<presence xmlns="jabber:server">` + testOrigin + `</presence>`,
		id:     `<presence xmlns="jabber:server">` + testStanza + `</presence>`,
	},
	6: {
		in:     `<not-stanza><message xmlns="jabber:client"></message></not-stanza>`,
		origin: `<not-stanza><message xmlns="jabber:client"></message></not-stanza>`,
		id:     `<not-stanza><message xmlns="jabber:client"></message></not-stanza>`,
	},
	7: {
		in:     `<not-stanza><iq xmlns="jabber:client"></iq></not-stanza>`,
		origin: `<not-stanza><iq xmlns="jabber:client"></iq></not-stanza>`,
		id:     `<not-stanza><iq xmlns="jabber:client"></iq></not-stanza>`,
	},
	8: {
		in:     `<not-stanza><presence xmlns="jabber:client"></presence></not-stanza>`,
		origin: `<not-stanza><presence xmlns="jabber:client"></presence></not-stanza>`,
		id:     `<not-stanza><presence xmlns="jabber:client"></presence></not-stanza>`,
	},
	9: {
		in:     `<presence xmlns="jabber:badns"></presence>`,
		origin: `<presence xmlns="jabber:badns"></presence>`,
		id:     `<presence xmlns="jabber:badns"></presence>`,
	},
}

func TestAddID(t *testing.T) {
	idReplacer := regexp.MustCompile(`id="(.*?)"`)

	by := jid.MustParse("test@example.net")
	addID := stanza.AddID(by)

	for i, tc := range idTestCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			t.Run("origin", func(t *testing.T) {
				r := stanza.AddOriginID(xml.NewDecoder(strings.NewReader(tc.in)))
				// Prevent duplicate xmlns attributes. See https://mellium.im/issue/75
				r = xmlstream.RemoveAttr(func(start xml.StartElement, attr xml.Attr) bool {
					return attr.Name.Local == "xmlns"
				})(r)
				var buf strings.Builder
				e := xml.NewEncoder(&buf)
				_, err := xmlstream.Copy(e, r)
				if err != nil {
					t.Fatalf("error copying xml stream: %v", err)
				}
				if err = e.Flush(); err != nil {
					t.Fatalf("error flushing stream: %v", err)
				}
				out := buf.String()
				// We need this to be testable, not random.
				out = idReplacer.ReplaceAllString(out, `id="abc"`)
				if out != tc.origin {
					t.Errorf("wrong output:\nwant=%v,\n got=%v", tc.origin, out)
				}
			})
			t.Run("stanza", func(t *testing.T) {
				r := addID(xml.NewDecoder(strings.NewReader(tc.in)))
				// Prevent duplicate xmlns attributes. See https://mellium.im/issue/75
				r = xmlstream.RemoveAttr(func(start xml.StartElement, attr xml.Attr) bool {
					return attr.Name.Local == "xmlns"
				})(r)
				var buf strings.Builder
				e := xml.NewEncoder(&buf)
				_, err := xmlstream.Copy(e, r)
				if err != nil {
					t.Fatalf("error copying xml stream: %v", err)
				}
				if err = e.Flush(); err != nil {
					t.Fatalf("error flushing stream: %v", err)
				}
				out := buf.String()
				// We need this to be testable, not random.
				out = idReplacer.ReplaceAllString(out, `id="abc"`)
				if out != tc.id {
					t.Errorf("wrong output:\nwant=%v,\n got=%v", tc.id, out)
				}
			})
		})
	}
}