~samwhited/xmpp

ref: e8c09b3ff1c1489c21d96a0f7f8f0e8728fc095a xmpp/delay/delay.go -rw-r--r-- 3.4 KiB
e8c09b3fSam Whited design: fix typo in design doc template 1 year, 4 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Copyright 2021 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

// Package delay implements delayed delivery of stanzas.
package delay // import "mellium.im/xmpp/delay"

import (
	"encoding/xml"
	"fmt"
	"time"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/xtime"
)

// NS is the namespace used by this package. It qualifies the delay element
// written by Delay.TokenReader and is the only non-empty attribute namespace
// accepted by Delay.UnmarshalXML.
const NS = "urn:xmpp:delay"

// Delay is a type that can be added to stanzas to indicate that they have been
// delivered with a delay.
//
// Delay has hand-written MarshalXML and UnmarshalXML methods, so the struct
// tags below describe the wire format but are not what drives encoding.
type Delay struct {
	// XMLName is the qualified name of the element: "delay" in the NS namespace.
	XMLName xml.Name  `xml:"urn:xmpp:delay delay"`
	// From is encoded as the "from" attribute. It is left off the element when
	// it is equal to the zero value JID.
	From    jid.JID   `xml:"from,attr,omitempty"`
	// Time is encoded as the "stamp" attribute and is always written.
	Time    time.Time `xml:"stamp,attr"`
	// Reason is the character data of the element. No character data is written
	// when it is empty.
	Reason  string    `xml:",chardata"`
}

// TokenReader implements xmlstream.Marshaler.
//
// The returned reader yields a delay element in the NS namespace that carries
// a "stamp" attribute, a "from" attribute when From is not the zero JID, and
// Reason as character data when it is not empty.
func (d Delay) TokenReader() xml.TokenReader {
	stamp, err := xtime.Time{Time: d.Time}.MarshalXMLAttr(xml.Name{Local: "stamp"})
	if err != nil {
		panic(fmt.Errorf("delay: unreachable error reached while marshaling time: %w", err))
	}

	// The stamp is always present; the sender is optional.
	attrs := []xml.Attr{stamp}
	if !d.From.Equal(jid.JID{}) {
		from := xml.Attr{
			Name:  xml.Name{Local: "from"},
			Value: d.From.String(),
		}
		attrs = append(attrs, from)
	}

	// A nil inner reader results in an element without any payload.
	var inner xml.TokenReader
	if d.Reason != "" {
		inner = xmlstream.Token(xml.CharData(d.Reason))
	}

	return xmlstream.Wrap(inner, xml.StartElement{
		Name: xml.Name{Space: NS, Local: "delay"},
		Attr: attrs,
	})
}

// WriteXML implements xmlstream.WriterTo by copying the tokens produced by
// TokenReader into w.
func (d Delay) WriteXML(w xmlstream.TokenWriter) (int, error) {
	tokens := d.TokenReader()
	return xmlstream.Copy(w, tokens)
}

// MarshalXML implements xml.Marshaler.
//
// The start element offered by the encoder is ignored; the element is always
// written by WriteXML under its own name.
func (d Delay) MarshalXML(e *xml.Encoder, _ xml.StartElement) error {
	if _, err := d.WriteXML(e); err != nil {
		return err
	}
	return nil
}

// UnmarshalXML implements xml.Unmarshaler.
//
// The "stamp" and "from" attributes are decoded into Time and From. Only
// attributes that are unqualified or qualified with NS are considered, and
// attribute processing stops once both have been seen.
//
// All character data that is a direct child of the element is concatenated
// and stored in Reason. Child elements are skipped, and comments, processing
// instructions and directives are ignored. Reason is left untouched if the
// element contains no character data.
//
// The element is always read through to its matching end tag, which is what
// encoding/xml requires of an Unmarshaler. Any error from attribute decoding
// or from the underlying decoder is returned unchanged.
func (d *Delay) UnmarshalXML(decoder *xml.Decoder, start xml.StartElement) error {
	var foundStamp, foundFrom bool
	for _, attr := range start.Attr {
		if attr.Name.Space != "" && attr.Name.Space != NS {
			continue
		}
		switch attr.Name.Local {
		case "stamp":
			var xt xtime.Time
			if err := xt.UnmarshalXMLAttr(attr); err != nil {
				return err
			}
			// Only store the time once it has been parsed successfully.
			d.Time = xt.Time
			foundStamp = true
		case "from":
			if err := d.From.UnmarshalXMLAttr(attr); err != nil {
				return err
			}
			foundFrom = true
		}
		if foundStamp && foundFrom {
			break
		}
	}

	// Read until the end tag that matches start. Looking at only the first
	// token is not enough: text may be split into several CharData tokens (for
	// example by a comment or CDATA section) and unexpected child elements must
	// not leave the rest of the element unread.
	var reason []byte
	for {
		tok, err := decoder.Token()
		if err != nil {
			return err
		}
		switch t := tok.(type) {
		case xml.CharData:
			// The token's bytes are only valid until the next call to Token, so
			// they are copied into our own buffer.
			reason = append(reason, t...)
		case xml.StartElement:
			// Unknown child: discard it along with everything nested inside.
			if err := decoder.Skip(); err != nil {
				return err
			}
		case xml.EndElement:
			// Child end tags are consumed by Skip, so this one closes start.
			if len(reason) > 0 {
				d.Reason = string(reason)
			}
			return nil
		}
	}
}

// isStanza reports whether name is an iq, message, or presence element in
// either the ns.Client or the ns.Server namespace.
//
// TODO: replace when #113 is ready.
func isStanza(name xml.Name) bool {
	switch name.Local {
	case "iq", "message", "presence":
		// A stanza name; the namespace decides below.
	default:
		return false
	}
	return name.Space == ns.Client || name.Space == ns.Server
}

// Stanza inserts a delay into any stanza read through the stream.
//
// The delay is only written for start elements that are reported at level 1
// and that have a stanza name (iq, message, or presence) in the client or
// server namespace; for everything else nothing is written.
func Stanza(d Delay) xmlstream.Transformer {
	insert := func(start xml.StartElement, level uint64, w xmlstream.TokenWriter) error {
		if level == 1 && isStanza(start.Name) {
			_, err := xmlstream.Copy(w, d.TokenReader())
			return err
		}
		return nil
	}
	return xmlstream.InsertFunc(insert)
}

// Insert adds a delay into any element read through the transformer at the
// current nesting level.
//
// Unlike Stanza, the name of the element is not checked: the delay is written
// for every start element reported at level 1 and for nothing else.
func Insert(d Delay) xmlstream.Transformer {
	insert := func(_ xml.StartElement, level uint64, w xmlstream.TokenWriter) error {
		if level == 1 {
			_, err := xmlstream.Copy(w, d.TokenReader())
			return err
		}
		return nil
	}
	return xmlstream.InsertFunc(insert)
}