// Copyright 2021 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.
package disco
import (
"context"
"encoding/xml"
"errors"
"mellium.im/xmlstream"
"mellium.im/xmpp"
"mellium.im/xmpp/jid"
"mellium.im/xmpp/paging"
"mellium.im/xmpp/stanza"
)
// defPageSize is the maximum number of items requested per page when an items
// response is wrapped in a paging iterator.
const defPageSize = 32
// ItemsQuery is the payload of a disco#items request.
// Node selects which node's items are requested; when it is empty the node
// attribute is omitted.
type ItemsQuery struct {
	XMLName xml.Name `xml:"http://jabber.org/protocol/disco#items query"`
	Node    string   `xml:"node,attr,omitempty"`
}
// TokenReader implements xmlstream.Marshaler.
//
// It yields an empty disco#items query element. The node attribute is added
// only when q.Node is non-empty.
func (q ItemsQuery) TokenReader() xml.TokenReader {
	var attrs []xml.Attr
	if q.Node != "" {
		attrs = []xml.Attr{{Name: xml.Name{Local: "node"}, Value: q.Node}}
	}
	return xmlstream.Wrap(nil, xml.StartElement{
		Name: xml.Name{Space: NSItems, Local: "query"},
		Attr: attrs,
	})
}
// WriteXML implements xmlstream.WriterTo.
// It copies the tokens produced by TokenReader into w and returns whatever
// xmlstream.Copy reports.
func (q ItemsQuery) WriteXML(w xmlstream.TokenWriter) (int, error) {
	tokens := q.TokenReader()
	return xmlstream.Copy(w, tokens)
}
// Item is a single entry in a disco#items result.
// JID is always encoded as an attribute; Name and Node are encoded only when
// they are non-empty.
type Item struct {
	XMLName xml.Name `xml:"http://jabber.org/protocol/disco#items item"`
	JID     jid.JID  `xml:"jid,attr"`
	Name    string   `xml:"name,attr,omitempty"`
	Node    string   `xml:"node,attr,omitempty"`
}
// TokenReader implements xmlstream.Marshaler.
//
// The resulting item element always carries a jid attribute, followed (in this
// order) by node and name attributes for whichever of those fields are
// non-empty. No child tokens are wrapped.
func (i Item) TokenReader() xml.TokenReader {
	attrs := []xml.Attr{{Name: xml.Name{Local: "jid"}, Value: i.JID.String()}}
	optional := [...]struct{ local, value string }{
		{"node", i.Node},
		{"name", i.Name},
	}
	for _, opt := range optional {
		if opt.value == "" {
			continue
		}
		attrs = append(attrs, xml.Attr{Name: xml.Name{Local: opt.local}, Value: opt.value})
	}
	return xmlstream.Wrap(nil, xml.StartElement{
		Name: xml.Name{Space: NSItems, Local: "item"},
		Attr: attrs,
	})
}
// WriteXML implements xmlstream.WriterTo.
// It copies the tokens produced by TokenReader into w and returns whatever
// xmlstream.Copy reports.
func (i Item) WriteXML(w xmlstream.TokenWriter) (int, error) {
	tokens := i.TokenReader()
	return xmlstream.Copy(w, tokens)
}
// ItemIter is an iterator over discovered items.
// It supports paging: when the current page has been consumed and the
// underlying paging iterator reports another page, the original query is
// re-sent with the paging request embedded and iteration continues over the
// new results.
type ItemIter struct {
	// NOTE(review): max is not read by the code in this file; the page size
	// given to paging.WrapIter is defPageSize.
	max     uint64
	iter    *paging.Iter
	current Item
	err     error
	ctx     context.Context
	session *xmpp.Session

	// iq and node are the parameters of the original request. They are kept so
	// that later pages are requested from the same JID and node, not from
	// whichever item happened to be decoded last.
	iq   stanza.IQ
	node string
}

// Next returns true if there are more items to decode.
//
// When the current page is exhausted and another page is available, Next
// closes the current response and requests the next page before continuing.
// Any error (including one encountered while requesting a new page) stops
// iteration and is reported by Err.
func (i *ItemIter) Next() bool {
	if i.err != nil || i.iter == nil {
		return false
	}

	// If there is a next element on the current page, just decode it.
	if i.iter.Next() {
		start, r := i.iter.Current()
		d := xml.NewTokenDecoder(r)
		item := Item{}
		i.err = d.DecodeElement(&item, start)
		if i.err != nil {
			return false
		}
		i.current = item
		return true
	}

	// The current page is finished. Stop if it ended in an error (surfaced by
	// Err) or if there is no further page to request.
	if i.iter.Err() != nil {
		return false
	}
	nextPage := i.iter.NextPage()
	if nextPage == nil {
		return false
	}

	// Turn the page.
	i.err = i.iter.Close()
	if i.err != nil {
		return false
	}
	// Re-send the original query (same JID, node and IQ customizations) with the
	// paging request embedded in it. The ID is cleared so that a caller-supplied
	// ID is not reused for a second request; GetItems sends its IQs without an
	// explicit ID as well.
	// TODO: set context based on a deadline?
	iq := i.iq
	iq.ID = ""
	nextIter := getItemsPage(i.ctx, i.node, iq, nextPage.TokenReader(), i.session)
	if nextIter.err != nil {
		// The previous iter is already closed; drop it so that Close does not
		// close it a second time.
		i.iter = nil
		i.err = nextIter.err
		return false
	}
	i.iter = nextIter.iter
	return i.Next()
}

// Err returns the last error encountered by the iterator (if any).
func (i *ItemIter) Err() error {
	if i.err != nil {
		return i.err
	}
	if i.iter == nil {
		return nil
	}
	return i.iter.Err()
}

// Item returns the last item decoded by the iterator.
func (i *ItemIter) Item() Item {
	return i.current
}

// Close indicates that we are finished with the given iterator and processing
// the stream may continue.
// Calling it multiple times has no effect.
func (i *ItemIter) Close() error {
	if i.iter == nil {
		return nil
	}
	return i.iter.Close()
}

// GetItems discovers a set of items associated with a JID and optional node of
// the provided item.
// The Name attribute of the query item is ignored.
// An empty Node means to query the root items for the JID.
// It blocks until a response is received.
//
// The iterator must be closed before anything else is done on the session.
// Any errors encountered while creating the iter are deferred until the iter is
// used.
func GetItems(ctx context.Context, item Item, s *xmpp.Session) *ItemIter {
	return GetItemsIQ(ctx, item.Node, stanza.IQ{To: item.JID}, s)
}

// GetItemsIQ is like GetItems but it allows you to customize the IQ.
// Changing the type of the provided IQ has no effect: a get IQ is always sent.
func GetItemsIQ(ctx context.Context, node string, iq stanza.IQ, s *xmpp.Session) *ItemIter {
	return getItemsPage(ctx, node, iq, nil, s)
}

// getItemsPage sends a disco#items get request for node, using iq as a
// template, and returns an iterator over the response.
// If page is non-nil its tokens are placed inside the query element; Next uses
// this to carry the paging request for subsequent pages.
// Errors sending the request are stored in the returned iterator's err field.
func getItemsPage(ctx context.Context, node string, iq stanza.IQ, page xml.TokenReader, s *xmpp.Session) *ItemIter {
	iq.Type = stanza.GetIQ
	start := xml.StartElement{Name: xml.Name{Space: NSItems, Local: "query"}}
	if node != "" {
		start.Attr = append(start.Attr, xml.Attr{Name: xml.Name{Local: "node"}, Value: node})
	}
	iter, err := s.IterIQ(ctx, iq.Wrap(xmlstream.Wrap(page, start)))
	if err != nil {
		return &ItemIter{err: err}
	}
	return &ItemIter{
		iter:    paging.WrapIter(iter, defPageSize),
		ctx:     ctx,
		session: s,
		iq:      iq,
		node:    node,
	}
}
// ErrSkipItem may be returned by a WalkItemFunc to tell WalkItem not to walk
// into the item named in the call.
// It is not returned as an error by any function.
var ErrSkipItem = errors.New("skip this item")
// WalkItemFunc is the type of function called by WalkItem to visit each item in
// an item hierarchy.
// Item nodes are unique and absolute (in particular they should not be treated
// like paths, even if a particular implementation uses paths for node names).
//
// The level argument is the depth of the item in the walk: 0 for the item
// passed to WalkItem, increasing by one for each level of children.
//
// The error result returned by the function controls how WalkItem continues.
// If the function returns the special value ErrSkipItem, WalkItem skips the
// current item.
// Otherwise, if the function returns a non-nil error, WalkItem stops entirely
// and returns that error.
//
// A non-nil err argument reports an error related to the item, signaling that
// WalkItem will not walk into that item.
// The function may decide how to handle that error, including returning it to
// stop walking the entire tree.
//
// The function is called before querying for an item so that returning
// ErrSkipItem can bypass the query entirely.
// If an error occurs while making the query, the function will be called again
// with the same item to report the error.
// Stanza errors with the feature-not-implemented or service-unavailable
// conditions are not reported this way.
type WalkItemFunc func(level int, item Item, err error) error
// WalkItem walks the item tree rooted at item, calling fn for every item in the
// tree, the root included.
// Leave item.Node empty to start from the root items of item.JID.
// The Name attribute of item is not used.
//
// Every error that arises while visiting an item is passed through fn; see
// WalkItemFunc for how its return value steers the walk.
//
// Items are visited in the order they arrive on the wire, so the sequence of
// calls to fn may not be deterministic.
func WalkItem(ctx context.Context, item Item, s *xmpp.Session, fn WalkItemFunc) error {
	const rootLevel, rootIdx = 0, 0
	root := []Item{item}
	return walkItem(ctx, rootLevel, rootIdx, root, s, fn)
}
// ignoredErr reports whether err matches a stanza error with the
// feature-not-implemented or service-unavailable condition.
// walkItem does not report such errors to the WalkItemFunc.
func ignoredErr(err error) bool {
	switch {
	case errors.Is(err, stanza.Error{Condition: stanza.FeatureNotImplemented}):
		return true
	case errors.Is(err, stanza.Error{Condition: stanza.ServiceUnavailable}):
		return true
	default:
		return false
	}
}
// walkItem calls fn for items[itemIdx] at the given level and then walks each
// of its children at level+1. It does not descend if fn skips the item, if the
// item repeats one discovered earlier in the walk, or if querying the item
// fails.
//
// items holds the items discovered so far along this branch of the walk.
// Items are only ever appended, so everything at a lower index than itemIdx
// (including every ancestor of the current item) was discovered before it.
// The current item's children are appended to the end of the slice before
// they are visited.
//
// walkItem never returns ErrSkipItem.
func walkItem(ctx context.Context, level, itemIdx int, items []Item, s *xmpp.Session, fn WalkItemFunc) error {
	item := items[itemIdx]
	if err := fn(level, item, nil); err != nil {
		if errors.Is(err, ErrSkipItem) {
			return nil
		}
		return err
	}

	// Look for loops and duplicates. Only items discovered earlier are checked.
	// Checking later items too would make every copy of a duplicated item skip
	// the others, so its children would never be visited at all.
	for _, seen := range items[:itemIdx] {
		if seen.Node == item.Node && seen.JID.Equal(item.JID) {
			return nil
		}
	}

	firstChild := len(items)
	items, err := appendItems(ctx, s, itemIdx, items)
	if ignoredErr(err) {
		err = nil
	}
	if err != nil {
		// Report the error with a second call to fn. Whatever fn decides, the
		// item is not walked into: any children decoded before the error form an
		// incomplete list.
		err = fn(level, item, err)
		if err != nil && !errors.Is(err, ErrSkipItem) {
			return err
		}
		return nil
	}

	for n := firstChild; n < len(items); n++ {
		if err := walkItem(ctx, level+1, n, items, s, fn); err != nil {
			return err
		}
	}
	return nil
}
// appendItems queries the items of items[itemIdx] and returns items with every
// discovered child appended to the end, together with the iterator's error.
// The iterator is always closed. A Close error is returned only when no
// earlier error occurred.
func appendItems(ctx context.Context, s *xmpp.Session, itemIdx int, items []Item) (out []Item, err error) {
	iter := GetItems(ctx, items[itemIdx], s)
	defer func() {
		if closeErr := iter.Close(); err == nil {
			err = closeErr
		}
	}()

	out = items
	for iter.Next() {
		out = append(out, iter.Item())
	}
	err = iter.Err()
	return out, err
}