~samwhited/xmpp

xmpp/blocklist/handler.go -rw-r--r-- 2.2 KiB
e9b0a2deSam Whited docs: do a quick editing pass over the docs a day ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// Copyright 2021 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package blocklist

import (
	"encoding/xml"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/mux"
	"mellium.im/xmpp/stanza"
)

// Handle returns a mux.Option that installs h as the IQ handler for every
// blocking command payload: "blocklist" get requests and "block"/"unblock" set
// requests, all qualified by the NS namespace.
func Handle(h Handler) mux.Option {
	return func(m *mux.ServeMux) {
		routes := []struct {
			typ   stanza.IQType
			local string
		}{
			{typ: stanza.GetIQ, local: "blocklist"},
			{typ: stanza.SetIQ, local: "block"},
			{typ: stanza.SetIQ, local: "unblock"},
		}
		for _, rt := range routes {
			payload := xml.Name{Space: NS, Local: rt.local}
			mux.IQ(rt.typ, payload, h)(m)
		}
	}
}

// Handler can be used to respond to incoming blocking command requests.
//
// Every callback is optional: when a field is nil the corresponding request
// is still handled, but no callback is invoked for it.
type Handler struct {
	// Block is called once for each item listed in a "block" request.
	Block      func(jid.JID)
	// Unblock is called once for each item listed in an "unblock" request.
	Unblock    func(jid.JID)
	// UnblockAll is called when an "unblock" request contains no items.
	UnblockAll func()
	// List is called to answer a "blocklist" request. It should send each
	// blocked JID on the channel; each one becomes an <item/> in the result.
	// The handler closes the channel after List returns, so List should not
	// close it itself.
	List       func(chan<- jid.JID)
}

// HandleIQ implements mux.IQHandler.
//
// A "blocklist" request is answered with a result IQ that wraps the original
// payload element. The result contains one <item jid='…'/> child for every JID
// that h.List sends on its channel (none if h.List is nil).
//
// For "block" and "unblock" requests, h.Block or h.Unblock (when non-nil) is
// called once per child element, using the value of that element's "jid"
// attribute. An "unblock" request with no child elements calls h.UnblockAll
// (when non-nil) instead.
//
// An error is returned in any of these cases:
//   - writing the response fails;
//   - reading the payload fails;
//   - a child element has no "jid" attribute (a bad-request stanza error);
//   - the attribute is not a valid JID.
//
// Items processed before such an error have already been passed to the
// callbacks.
func (h Handler) HandleIQ(iq stanza.IQ, r xmlstream.TokenReadEncoder, start *xml.StartElement) error {
	if start.Name.Local == "blocklist" {
		res := iq.Result(xmlstream.Wrap(nil, *start))
		// Copy the start IQ and start payload first.
		_, err := xmlstream.Copy(r, xmlstream.LimitReader(res, 2))
		if err != nil {
			return err
		}
		if h.List != nil {
			c := make(chan jid.JID)
			go func() {
				defer close(c)
				h.List(c)
			}()
			for j := range c {
				_, err = xmlstream.Copy(r, xmlstream.Wrap(nil, xml.StartElement{
					Name: xml.Name{Space: NS, Local: "item"},
					Attr: []xml.Attr{{
						Name:  xml.Name{Local: "jid"},
						Value: j.String(),
					}},
				}))
				if err != nil {
					// List may still be blocked trying to send on the unbuffered
					// channel. Keep receiving in the background so its goroutine
					// can return (and close c) instead of leaking.
					go func() {
						for range c {
						}
					}()
					return err
				}
			}
		}
		// Copy the end payload and end IQ.
		_, err = xmlstream.Copy(r, xmlstream.LimitReader(res, 2))
		return err
	}

	iter := xmlstream.NewIter(r)
	var found bool
	for iter.Next() {
		itemStart, _ := iter.Current()
		if itemStart == nil {
			// Defensive: skip anything the iterator yields that is not an
			// element, rather than dereferencing a nil start element.
			continue
		}
		found = true

		// The payload comes from the network: look the attribute up by name
		// instead of assuming it exists and is first.
		var jstr string
		var hasJID bool
		for _, attr := range itemStart.Attr {
			if attr.Name.Local == "jid" {
				jstr, hasJID = attr.Value, true
				break
			}
		}
		if !hasJID {
			return stanza.Error{Type: stanza.Modify, Condition: stanza.BadRequest}
		}
		// Parse (not MustParse) so a malformed JID is reported instead of
		// panicking.
		j, err := jid.Parse(jstr)
		if err != nil {
			return err
		}

		switch start.Name.Local {
		case "block":
			if h.Block != nil {
				h.Block(j)
			}
		case "unblock":
			if h.Unblock != nil {
				h.Unblock(j)
			}
		}
	}
	if err := iter.Err(); err != nil {
		return err
	}
	// An unblock request with no items means "unblock everything".
	if !found && start.Name.Local == "unblock" && h.UnblockAll != nil {
		h.UnblockAll()
	}
	return nil
}