~samwhited/xmpp

e0d114695950879ac895eafe7e469c8c2a5c7087 — Sam Whited 3 months ago 4b2af3b
mux: support responding to disco#items requests

Signed-off-by: Sam Whited <sam@samwhited.com>
5 files changed, 234 insertions(+), 10 deletions(-)

M CHANGELOG.md
M disco/handler.go
M disco/handler_test.go
M mux/mux.go
M mux/mux_test.go
M CHANGELOG.md => CHANGELOG.md +1 -1
@@ 30,7 30,7 @@ All notable changes to this project will be documented in this file.
- commands: new package implementing [XEP-0050: Ad-Hoc Commands]
- history: implement [XEP-0313: Message Archive Management]
- muc: new package implementing [XEP-0045: Multi-User Chat] and [XEP-0249: Direct MUC Invitations]
- mux: `mux.ServeMux` now implements `info.FeatureIter`
- mux: `mux.ServeMux` now implements `info.FeatureIter` and `items.Iter`
- roster: the roster `Iter` now returns the roster version being iterated over
  from the `Version` method
- roster: if a `stanza.Error` is returned from the `Push` handler it is now sent

M disco/handler.go => disco/handler.go +24 -8
@@ 9,17 9,19 @@ import (

	"mellium.im/xmlstream"
	"mellium.im/xmpp/disco/info"
	"mellium.im/xmpp/disco/items"
	"mellium.im/xmpp/mux"
	"mellium.im/xmpp/stanza"
)

// Handle returns an option that configures a multiplexer to handle service
// discovery requests by iterating over its own handlers and checking if they
// implement info.FeatureIter.
// implement info.FeatureIter or items.Iter.
func Handle() mux.Option {
	return func(m *mux.ServeMux) {
		h := &discoHandler{ServeMux: m}
		mux.IQ(stanza.GetIQ, xml.Name{Space: NSInfo, Local: "query"}, h)(m)
		mux.IQ(stanza.GetIQ, xml.Name{Space: NSItems, Local: "query"}, h)(m)
	}
}



@@ 35,15 37,15 @@ func (h *discoHandler) HandleIQ(iq stanza.IQ, r xmlstream.TokenReadEncoder, star
	seen := make(map[string]struct{})
	pr, pw := xmlstream.Pipe()
	go func() {
		var node string
		for _, attr := range start.Attr {
			if attr.Name.Local == "node" {
				node = attr.Value
				break
			}
		}
		switch start.Name.Space {
		case NSInfo:
			var node string
			for _, attr := range start.Attr {
				if attr.Name.Local == "node" {
					node = attr.Value
					break
				}
			}
			pw.CloseWithError(h.ServeMux.ForFeatures(node, func(f info.Feature) error {
				_, ok := seen[f.Var]
				if ok {


@@ 53,6 55,16 @@ func (h *discoHandler) HandleIQ(iq stanza.IQ, r xmlstream.TokenReadEncoder, star
				_, err := xmlstream.Copy(pw, f.TokenReader())
				return err
			}))
		case NSItems:
			pw.CloseWithError(h.ServeMux.ForItems(node, func(i items.Item) error {
				_, ok := seen[i.Node]
				if ok {
					return nil
				}
				seen[i.Node] = struct{}{}
				_, err := xmlstream.Copy(pw, i.TokenReader())
				return err
			}))
		}
	}()



@@ 62,3 74,7 @@ func (h *discoHandler) HandleIQ(iq stanza.IQ, r xmlstream.TokenReadEncoder, star
	)))
	return err
}

func (*discoHandler) ForItems(string, func(items.Item) error) error {
	return nil
}

M disco/handler_test.go => disco/handler_test.go +58 -1
@@ 6,15 6,18 @@ package disco_test

import (
	"context"
	"encoding/xml"
	"testing"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/disco"
	"mellium.im/xmpp/disco/items"
	"mellium.im/xmpp/internal/xmpptest"
	"mellium.im/xmpp/mux"
	"mellium.im/xmpp/stanza"
)

func TestRoundTrip(t *testing.T) {
func TestFeaturesRoundTrip(t *testing.T) {
	m := mux.New(disco.Handle())
	cs := xmpptest.NewClientServer(
		xmpptest.ServerHandler(m),


@@ 36,3 39,57 @@ func TestRoundTrip(t *testing.T) {
		t.Errorf("got unexpected features %v", info.Features)
	}
}

type itemHandler struct{}

func (itemHandler) HandleXMPP(xmlstream.TokenReadEncoder, *xml.StartElement) error {
	panic("should not be called")
}

func (itemHandler) ForItems(node string, f func(items.Item) error) error {
	if node != "" {
		return nil
	}

	return f(items.Item{
		Name: disco.NSItems,
	})
}

func TestItemsRoundTrip(t *testing.T) {
	m := mux.New(
		disco.Handle(),
		mux.Handle(xml.Name{}, itemHandler{}),
	)
	cs := xmpptest.NewClientServer(
		xmpptest.ServerHandler(m),
	)

	iter := disco.FetchItemsIQ(context.Background(), "", stanza.IQ{ID: "123"}, cs.Client)
	allItems := []items.Item{}
	for iter.Next() {
		allItems = append(allItems, iter.Item())
	}
	if err := iter.Err(); err != nil {
		t.Fatalf("error iterating over items: %v", err)
	}
	if len(allItems) != 1 || allItems[0].Name != disco.NSItems {
		t.Errorf("wrong items: want=%s, got=%v", disco.NSItems, allItems)
	}
	err := iter.Close()
	if err != nil {
		t.Fatalf("error closing iterator: %v", err)
	}

	iter = disco.FetchItemsIQ(context.Background(), "node", stanza.IQ{ID: "123"}, cs.Client)
	for iter.Next() {
		t.Fatalf("error, got item %v but did not expect any", iter.Item())
	}
	if err := iter.Err(); err != nil {
		t.Fatalf("error iterating over empty items: %v", err)
	}
	err = iter.Close()
	if err != nil {
		t.Fatalf("error closing empty iter: %v", err)
	}
}

M mux/mux.go => mux/mux.go +38 -0
@@ 18,6 18,7 @@ import (
	"mellium.im/xmlstream"
	"mellium.im/xmpp"
	"mellium.im/xmpp/disco/info"
	"mellium.im/xmpp/disco/items"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/stanza"
)


@@ 210,6 211,43 @@ func (m *ServeMux) HandleXMPP(t xmlstream.TokenReadEncoder, start *xml.StartElem
	return h.HandleXMPP(t, start)
}

// ForItems implements items.Iter for the mux by iterating over all child items.
func (m *ServeMux) ForItems(node string, f func(items.Item) error) error {
	for _, h := range m.patterns {
		if itemIter, ok := h.(items.Iter); ok {
			err := itemIter.ForItems(node, f)
			if err != nil {
				return err
			}
		}
	}
	for _, h := range m.iqPatterns {
		if itemIter, ok := h.(items.Iter); ok {
			err := itemIter.ForItems(node, f)
			if err != nil {
				return err
			}
		}
	}
	for _, h := range m.msgPatterns {
		if itemIter, ok := h.(items.Iter); ok {
			err := itemIter.ForItems(node, f)
			if err != nil {
				return err
			}
		}
	}
	for _, h := range m.presencePatterns {
		if itemIter, ok := h.(items.Iter); ok {
			err := itemIter.ForItems(node, f)
			if err != nil {
				return err
			}
		}
	}
	return nil
}

// ForFeatures implements info.FeatureIter for the mux by iterating over all
// child features.
func (m *ServeMux) ForFeatures(node string, f func(info.Feature) error) error {

M mux/mux_test.go => mux/mux_test.go +113 -0
@@ 16,6 16,7 @@ import (

	"mellium.im/xmlstream"
	"mellium.im/xmpp/disco/info"
	"mellium.im/xmpp/disco/items"
	"mellium.im/xmpp/internal/marshal"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/internal/xmpptest"


@@ 600,6 601,12 @@ func (handleFeature) ForFeatures(node string, f func(info.Feature) error) error 
	})
}

func (handleFeature) ForItems(node string, f func(items.Item) error) error {
	return f(items.Item{
		Name: testFeature,
	})
}

type iqFeature struct{}

func (iqFeature) HandleIQ(stanza.IQ, xmlstream.TokenReadEncoder, *xml.StartElement) error {


@@ 612,6 619,12 @@ func (iqFeature) ForFeatures(node string, f func(info.Feature) error) error {
	})
}

func (iqFeature) ForItems(node string, f func(items.Item) error) error {
	return f(items.Item{
		Name: iqTestFeature,
	})
}

type messageFeature struct{}

func (messageFeature) HandleMessage(stanza.Message, xmlstream.TokenReadEncoder) error {


@@ 624,6 637,12 @@ func (messageFeature) ForFeatures(node string, f func(info.Feature) error) error
	})
}

func (messageFeature) ForItems(node string, f func(items.Item) error) error {
	return f(items.Item{
		Name: msgTestFeature,
	})
}

type presenceFeature struct{}

func (presenceFeature) HandlePresence(stanza.Presence, xmlstream.TokenReadEncoder) error {


@@ 636,6 655,12 @@ func (presenceFeature) ForFeatures(node string, f func(info.Feature) error) erro
	})
}

func (presenceFeature) ForItems(node string, f func(items.Item) error) error {
	return f(items.Item{
		Name: presenceTestFeature,
	})
}

func TestFeatures(t *testing.T) {
	m := mux.New(
		mux.Handle(xml.Name{}, handleFeature{}),


@@ 679,6 704,49 @@ func TestFeatures(t *testing.T) {
	}
}

func TestItems(t *testing.T) {
	m := mux.New(
		mux.Handle(xml.Name{}, handleFeature{}),
		mux.IQ("", xml.Name{}, iqFeature{}),
		mux.Message("", xml.Name{}, messageFeature{}),
		mux.Presence("", xml.Name{}, presenceFeature{}),
	)
	var (
		foundHandler  bool
		foundIQ       bool
		foundPresence bool
		foundMsg      bool
	)
	err := m.ForItems("", func(i items.Item) error {
		switch i.Name {
		case testFeature:
			foundHandler = true
		case iqTestFeature:
			foundIQ = true
		case msgTestFeature:
			foundMsg = true
		case presenceTestFeature:
			foundPresence = true
		}
		return nil
	})
	if err != nil {
		t.Fatalf("unexpected error while iterating over features: %v", err)
	}
	if !foundHandler {
		t.Errorf("items iter did not find plain handler item")
	}
	if !foundIQ {
		t.Errorf("items iter did not find IQ item")
	}
	if !foundMsg {
		t.Errorf("items iter did not find message item")
	}
	if !foundPresence {
		t.Errorf("items iter did not find presence item")
	}
}

func TestFeaturesHandlerErr(t *testing.T) {
	m := mux.New(
		mux.Handle(xml.Name{}, handleFeature{}),


@@ 724,3 792,48 @@ func TestFeaturesHandlePresenceErr(t *testing.T) {
		t.Fatalf("wrong error: want=%v, got=%v", io.EOF, err)
	}
}

func TestItemsHandlerErr(t *testing.T) {
	m := mux.New(
		mux.Handle(xml.Name{}, handleFeature{}),
	)
	err := m.ForItems("", func(i items.Item) error {
		return io.EOF
	})
	if err != io.EOF {
		t.Fatalf("wrong error: want=%v, got=%v", io.EOF, err)
	}
}
func TestItemsHandleIQErr(t *testing.T) {
	m := mux.New(
		mux.IQ("", xml.Name{}, iqFeature{}),
	)
	err := m.ForItems("", func(i items.Item) error {
		return io.EOF
	})
	if err != io.EOF {
		t.Fatalf("wrong error: want=%v, got=%v", io.EOF, err)
	}
}
func TestItemsHandleMsgErr(t *testing.T) {
	m := mux.New(
		mux.Message("", xml.Name{}, messageFeature{}),
	)
	err := m.ForItems("", func(i items.Item) error {
		return io.EOF
	})
	if err != io.EOF {
		t.Fatalf("wrong error: want=%v, got=%v", io.EOF, err)
	}
}
func TestItemsHandlePresenceErr(t *testing.T) {
	m := mux.New(
		mux.Presence("", xml.Name{}, presenceFeature{}),
	)
	err := m.ForItems("", func(i items.Item) error {
		return io.EOF
	})
	if err != io.EOF {
		t.Fatalf("wrong error: want=%v, got=%v", io.EOF, err)
	}
}