~samwhited/xmpp

2f55994bc8e59bc8e211713c1b2b403c9f767fa2 — Sam Whited 5 months ago b885e24
blocklist: new package implementing XEP-0191

Fixes #139

Signed-off-by: Sam Whited <sam@samwhited.com>
4 files changed, 449 insertions(+), 0 deletions(-)

M CHANGELOG.md
A blocklist/blocking.go
A blocklist/blocking_test.go
A blocklist/handler.go
M CHANGELOG.md => CHANGELOG.md +2 -0
@@ 15,6 15,7 @@ All notable changes to this project will be documented in this file.

### Added

- blocklist: new package implementing [XEP-0191: Blocking Command]
- commands: new package implementing [XEP-0050: Ad-Hoc Commands]

### Fixed


@@ 29,6 30,7 @@ All notable changes to this project will be documented in this file.
- xmpp: empty IQ iters no longer return EOF when there is no payload


[XEP-0191: Blocking Command]: https://xmpp.org/extensions/xep-0191.html
[XEP-0050: Ad-Hoc Commands]: https://xmpp.org/extensions/xep-0050.html



A blocklist/blocking.go => blocklist/blocking.go +156 -0
@@ 0,0 1,156 @@
// Copyright 2021 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

// Package blocking implements blocking and unblocking of contacts.
package blocklist // import "mellium.im/xmpp/blocklist"

import (
	"context"
	"encoding/xml"

	"mellium.im/xmlstream"
	"mellium.im/xmpp"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/stanza"
)

// NS is the namespace used by this package, provided as a convenience.
const NS = `urn:xmpp:blocklist`

// Match checks j1 aginst a JID in the blocklist (j2) and returns true if they
// are a match.
//
// The JID matches the blocklist JID if any of the following compare to the
// blocklist JID (falling back in this order):
//
//  - Full JID (user@domain/resource)
//  - Bare JID (user@domain)
//  - Full domain (domain/resource)
//  - Bare domain
func Match(j1, j2 jid.JID) bool {
	return j1.Equal(j2) ||
		j1.Bare().Equal(j2) ||
		jid.NewUnsafe("", j1.Domainpart(), j1.Resourcepart()).JID.Equal(j2) ||
		j1.Domain().Equal(j2)
}

// Iter is an iterator over blocklist JIDs.
type Iter struct {
	iter    *xmlstream.Iter
	current jid.JID
	err     error
}

// Next returns true if there are more items to decode.
func (i *Iter) Next() bool {
	if i.err != nil || !i.iter.Next() {
		return false
	}
	start, _ := i.iter.Current()
	// If we encounter a lone token that doesn't begin with a start element (eg.
	// a comment) skip it. This should never happen with XMPP, but we don't want
	// to panic in case this somehow happens so just skip it.
	if start == nil {
		return i.Next()
	}
	for _, attr := range start.Attr {
		if attr.Name.Local == "jid" {
			i.current, i.err = jid.Parse(attr.Value)
			break
		}
	}
	return true
}

// Err returns the last error encountered by the iterator (if any).
func (i *Iter) Err() error {
	if i.err != nil {
		return i.err
	}

	return i.iter.Err()
}

// JID returns the last blocked JID parsed by the iterator.
func (i *Iter) JID() jid.JID {
	return i.current
}

// Close indicates that we are finished with the given iterator and processing
// the stream may continue.
// Calling it multiple times has no effect.
func (i *Iter) Close() error {
	if i.iter == nil {
		return nil
	}
	return i.iter.Close()
}

// Fetch sends a request to the JID asking for the blocklist.
func Fetch(ctx context.Context, s *xmpp.Session) *Iter {
	return FetchIQ(ctx, stanza.IQ{}, s)
}

// FetchIQ is like Fetch except that it lets you customize the IQ.
// Changing the type of the provided IQ has no effect.
func FetchIQ(ctx context.Context, iq stanza.IQ, s *xmpp.Session) *Iter {
	if iq.Type != stanza.GetIQ {
		iq.Type = stanza.GetIQ
	}
	iter, err := s.IterIQ(ctx, iq.Wrap(xmlstream.Wrap(nil, xml.StartElement{
		Name: xml.Name{Space: NS, Local: "blocklist"},
	})))
	if err != nil {
		return &Iter{err: err}
	}
	return &Iter{
		iter: iter,
	}
}

// Add adds JIDs to the blocklist.
func Add(ctx context.Context, s *xmpp.Session, j ...jid.JID) error {
	return AddIQ(ctx, stanza.IQ{}, s, j...)
}

// AddIQ is like Add except that it lets you customize the IQ.
// Changing the type of the provided IQ has no effect.
func AddIQ(ctx context.Context, iq stanza.IQ, s *xmpp.Session, j ...jid.JID) error {
	return doIQ(ctx, "block", iq, s, j...)
}

// Remove removes JIDs from the blocklist.
// If no JIDs are provided the entire blocklist is cleared.
func Remove(ctx context.Context, s *xmpp.Session, j ...jid.JID) error {
	return RemoveIQ(ctx, stanza.IQ{}, s, j...)
}

// RemoveIQ is like Remove except that it lets you customize the IQ.
// Changing the type of the provided IQ has no effect.
func RemoveIQ(ctx context.Context, iq stanza.IQ, s *xmpp.Session, j ...jid.JID) error {
	return doIQ(ctx, "unblock", iq, s, j...)
}

func doIQ(ctx context.Context, local string, iq stanza.IQ, s *xmpp.Session, j ...jid.JID) error {
	if iq.Type != stanza.SetIQ {
		iq.Type = stanza.SetIQ
	}
	var jids []xml.TokenReader
	for _, jj := range j {
		jids = append(jids, xmlstream.Wrap(nil, xml.StartElement{
			Name: xml.Name{Local: "item"},
			Attr: []xml.Attr{{Name: xml.Name{Local: "jid"}, Value: jj.String()}},
		}))
	}
	r, err := s.SendIQ(ctx, iq.Wrap(xmlstream.Wrap(
		xmlstream.MultiReader(jids...),
		xml.StartElement{
			Name: xml.Name{Space: NS, Local: local},
		},
	)))
	if err != nil {
		return err
	}
	return r.Close()
}

A blocklist/blocking_test.go => blocklist/blocking_test.go +201 -0
@@ 0,0 1,201 @@
// Copyright 2021 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package blocklist_test

import (
	"context"
	"encoding/xml"
	"errors"
	"reflect"
	"strconv"
	"strings"
	"testing"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/blocklist"
	"mellium.im/xmpp/internal/xmpptest"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/mux"
	"mellium.im/xmpp/stanza"
)

var matchTestCases = [...]struct {
	j1, j2 jid.JID
	result bool
}{
	// Full JID match
	0: {j1: jid.MustParse("user@domain/resource"), j2: jid.MustParse("user@domain/resource"), result: true},
	1: {j1: jid.MustParse("otheruser@domain/resource"), j2: jid.MustParse("user@domain/resource"), result: false},
	2: {j1: jid.MustParse("user@otherdomain/resource"), j2: jid.MustParse("user@domain/resource"), result: false},
	3: {j1: jid.MustParse("user@domain/otherresource"), j2: jid.MustParse("user@domain/resource"), result: false},
	4: {j1: jid.MustParse("otherdomain/resource"), j2: jid.MustParse("user@domain/resource"), result: false},
	5: {j1: jid.MustParse("user@domain"), j2: jid.MustParse("user@domain/resource"), result: false},
	6: {j1: jid.MustParse("domain"), j2: jid.MustParse("user@domain/resource"), result: false},

	// Bare JID match
	7:  {j1: jid.MustParse("user@domain"), j2: jid.MustParse("user@domain"), result: true},
	8:  {j1: jid.MustParse("user@domain/res"), j2: jid.MustParse("user@domain"), result: true},
	9:  {j1: jid.MustParse("domain"), j2: jid.MustParse("user@domain"), result: false},
	10: {j1: jid.MustParse("domain/res"), j2: jid.MustParse("user@domain"), result: false},
	11: {j1: jid.MustParse("otheruser@domain"), j2: jid.MustParse("user@domain"), result: false},

	// Full domain match
	12: {j1: jid.MustParse("domain/resource"), j2: jid.MustParse("domain/resource"), result: true},
	13: {j1: jid.MustParse("user@domain/resource"), j2: jid.MustParse("domain/resource"), result: true},
	14: {j1: jid.MustParse("domain"), j2: jid.MustParse("domain/resource"), result: false},
	15: {j1: jid.MustParse("user@domain"), j2: jid.MustParse("domain/resource"), result: false},
	16: {j1: jid.MustParse("otherdomain/resource"), j2: jid.MustParse("domain/resource"), result: false},
	17: {j1: jid.MustParse("domain/otherresource"), j2: jid.MustParse("domain/resource"), result: false},

	// Bare domain match
	18: {j1: jid.MustParse("domain"), j2: jid.MustParse("domain"), result: true},
	19: {j1: jid.MustParse("domain/resource"), j2: jid.MustParse("domain"), result: true},
	20: {j1: jid.MustParse("user@domain"), j2: jid.MustParse("domain"), result: true},
	21: {j1: jid.MustParse("user@domain/resource"), j2: jid.MustParse("domain"), result: true},
	22: {j1: jid.MustParse("otherdomain"), j2: jid.MustParse("domain"), result: false},
	23: {j1: jid.MustParse("user@otherdomain"), j2: jid.MustParse("domain"), result: false},
	24: {j1: jid.MustParse("user@otherdomain/res"), j2: jid.MustParse("domain"), result: false},
	25: {j1: jid.MustParse("otherdomain/res"), j2: jid.MustParse("domain"), result: false},
}

func TestMatch(t *testing.T) {
	for i, tc := range matchTestCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			result := blocklist.Match(tc.j1, tc.j2)
			if tc.result != result {
				t.Errorf("unexpected result: got=%t, want=%t", result, tc.result)
			}
		})
	}
}

var testCases = [...]struct {
	items []jid.JID
	err   error
}{
	0: {},
	1: {
		items: []jid.JID{
			jid.MustParse("juliet@example.com"),
			jid.MustParse("benvolio@example.org"),
		},
	},
}

func TestFetch(t *testing.T) {
	var IQ = stanza.IQ{ID: "123"}
	for i, tc := range testCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			var list []jid.JID
			h := blocklist.Handler{
				Block: func(j jid.JID) {
					list = append(list, j)
				},
				Unblock: func(j jid.JID) {
					b := list[:0]
					for _, x := range list {
						if !j.Equal(x) {
							b = append(b, x)
						}
					}
					list = b
				},
				UnblockAll: func() {
					list = list[:0]
				},
				List: func(j chan<- jid.JID) {
					for _, jj := range list {
						j <- jj
					}
				},
			}
			m := mux.New(blocklist.Handle(h))
			cs := xmpptest.NewClientServer(xmpptest.ServerHandler(m))

			err := blocklist.AddIQ(context.Background(), IQ, cs.Client, tc.items...)
			if err != nil {
				t.Fatalf("error setting the blocklist: %v", err)
			}

			iter := blocklist.FetchIQ(context.Background(), IQ, cs.Client)
			items := make([]jid.JID, 0, len(tc.items))
			for iter.Next() {
				items = append(items, iter.JID())
			}
			if err := iter.Err(); err != tc.err {
				t.Errorf("wrong error after iter: want=%v, got=%v", tc.err, err)
			}
			iter.Close()

			// Don't try to compare nil and empty slice with DeepEqual
			if len(tc.items) == 0 {
				tc.items = make([]jid.JID, 0)
			}

			if !reflect.DeepEqual(items, tc.items) {
				t.Errorf("wrong items:\nwant=\n%+v,\ngot=\n%+v", tc.items, items)
			}

			// Test removing one item.
			if len(tc.items) > 0 {
				err = blocklist.RemoveIQ(context.Background(), IQ, cs.Client, tc.items[0])
				if err != nil {
					t.Errorf("error removing first blocklist item: %v", err)
				}
				if !reflect.DeepEqual(list, tc.items[1:]) {
					t.Errorf("wrong items after removing %s:\nwant=\n%+v,\ngot=\n%+v", tc.items[0], tc.items[1:], list)
				}
			}

			// Test removing all items.
			err = blocklist.RemoveIQ(context.Background(), IQ, cs.Client)
			if err != nil {
				t.Errorf("error removing remaining blocklist items: %v", err)
			}
			if len(list) > 0 {
				t.Errorf("failed to remove remaining items")
			}
		})
	}
}

func TestFetchNoStart(t *testing.T) {
	cs := xmpptest.NewClientServer(
		xmpptest.ServerHandlerFunc(func(e xmlstream.TokenReadEncoder, start *xml.StartElement) error {
			const resp = `<iq id="123" type="result"><blocklist xmlns='urn:xmpp:blocklist'><!-- comment --></blocklist></iq>`
			_, err := xmlstream.Copy(e, xml.NewDecoder(strings.NewReader(resp)))
			return err
		}),
	)
	iter := blocklist.FetchIQ(context.Background(), stanza.IQ{ID: "123"}, cs.Client)
	for iter.Next() {
		// Just iterate
	}
	if err := iter.Err(); err != nil {
		t.Errorf("Wrong error after iter: want=nil, got=%q", err)
	}
	iter.Close()
}

type errReadWriter struct{}

func (errReadWriter) Write([]byte) (int, error) {
	return 0, errors.New("called Write on errReadWriter")
}

func (errReadWriter) Read([]byte) (int, error) {
	return 0, errors.New("called Read on errReadWriter")
}

func TestErroredDoesNotPanic(t *testing.T) {
	s := xmpptest.NewSession(0, errReadWriter{})
	iter := blocklist.Fetch(context.Background(), s)
	if iter.Next() {
		t.Errorf("expected false from call to next")
	}
	if err := iter.Close(); err != nil {
		t.Errorf("got unexpected error closing iter: %v", err)
	}
}

A blocklist/handler.go => blocklist/handler.go +90 -0
@@ 0,0 1,90 @@
// Copyright 2021 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package blocklist

import (
	"encoding/xml"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/mux"
	"mellium.im/xmpp/stanza"
)

// Handle returns an option that registers the given handler on the mux for the
// various blocking command payloads.
func Handle(h Handler) mux.Option {
	return func(m *mux.ServeMux) {
		mux.IQ(stanza.GetIQ, xml.Name{Space: NS, Local: "blocklist"}, h)(m)
		mux.IQ(stanza.SetIQ, xml.Name{Space: NS, Local: "block"}, h)(m)
		mux.IQ(stanza.SetIQ, xml.Name{Space: NS, Local: "unblock"}, h)(m)
	}
}

// Handler can be used to respond to incoming blocking command requests.
type Handler struct {
	Block      func(jid.JID)
	Unblock    func(jid.JID)
	UnblockAll func()
	List       func(chan<- jid.JID)
}

// HandleIQ implements mux.IQHandler.
func (h Handler) HandleIQ(iq stanza.IQ, r xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	if start.Name.Local == "blocklist" {
		res := iq.Result(xmlstream.Wrap(nil, *start))
		// Copy the start IQ and start payload first.
		_, err := xmlstream.Copy(r, xmlstream.LimitReader(res, 2))
		if err != nil {
			return err
		}
		if h.List != nil {
			c := make(chan jid.JID)
			go func() {
				h.List(c)
				close(c)
			}()
			for j := range c {
				xmlstream.Copy(r, xmlstream.Wrap(nil, xml.StartElement{
					Name: xml.Name{Space: NS, Local: "item"},
					Attr: []xml.Attr{{
						Name:  xml.Name{Local: "jid"},
						Value: j.String(),
					}},
				}))
			}
		}
		// Copy the end payload and end IQ.
		_, err = xmlstream.Copy(r, xmlstream.LimitReader(res, 2))
		return err
	}

	iter := xmlstream.NewIter(r)
	var found bool
	for iter.Next() {
		found = true
		itemStart, _ := iter.Current()
		jstr := itemStart.Attr[0].Value
		j := jid.MustParse(jstr)
		switch start.Name.Local {
		case "block":
			if h.Block != nil {
				h.Block(j)
			}
		case "unblock":
			if h.Unblock != nil {
				h.Unblock(j)
			}
		}
	}
	err := iter.Err()
	if err != nil {
		return err
	}
	if !found && start.Name.Local == "unblock" && h.UnblockAll != nil {
		h.UnblockAll()
	}
	return nil
}