~samwhited/xmpp

xmpp/history/query.go -rw-r--r-- 4.3 KiB
60a076f3Sam Whited .builds: disable testing against gotip 2 days ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
// Copyright 2021 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package history

import (
	"encoding/xml"
	"time"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/form"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/paging"
)

// Query is a request to the archive for data.
//
// An empty query indicates all messages should be fetched without a filter.
// NOTE(review): this type was previously documented as also using a random
// ID, but TokenReader sends ID verbatim as the queryid attribute (an empty ID
// produces queryid=""); presumably a caller fills in a random ID — confirm.
type Query struct {
	// Query parameters

	// ID is sent as the queryid attribute of the <query/> element and is read
	// back from that attribute when unmarshaling.
	ID string

	// Filters are sent as fields of a submitted jabber:x:data form and are only
	// filled in when they hold a non-zero value:
	//
	//   - With is sent as the "with" field.
	//   - Start and End are converted to UTC and formatted with time.RFC3339 in
	//     the "start" and "end" fields, so sub-second precision is dropped.
	//   - BeforeID and AfterID are sent as the "before-id" and "after-id" fields.
	//   - IDs is sent as the multi-valued "ids" field.
	With     jid.JID
	Start    time.Time
	End      time.Time
	BeforeID string
	AfterID  string
	IDs      []string

	// Limit limits the total number of messages returned.
	// It is sent as the Max value of the RSM paging request.
	// NOTE(review): RSM max normally bounds a single page; confirm whether the
	// caller enforces a total limit across pages.
	Limit uint64

	// Last starts fetching from the last page (or before PageID if set).
	// When set, a "previous page" RSM request is sent instead of a "next page"
	// one, and on unmarshaling it reflects the presence of an RSM <before/>.
	Last bool

	// PageID is the ID of a message within the existing query that we should
	// start paging after (or before, if Last is set).
	// This lets us skip over the redundant message when querying with Start/End,
	// or skip to a later page within the query if we abandoned it and need to
	// start over (but don't want to fetch all the pages we've already processed).
	PageID string

	// Reverse flips messages returned within a page.
	// It is sent as an empty <flip-page/> child of the query.
	Reverse bool
}

// Identifiers ("var" names) of the jabber:x:data form fields used to filter an
// archive query. They are shared by the encoder and the decoder so both sides
// agree on the wire names.
const (
	fieldAfter  = "after-id"
	fieldBefore = "before-id"
	fieldEnd    = "end"
	fieldIDs    = "ids"
	fieldStart  = "start"
	fieldWith   = "with"
)

// TokenReader implements xmlstream.Marshaler.
//
// The returned stream is a <query/> element in the NS namespace whose queryid
// attribute is f.ID (the attribute is written even when ID is empty). Its
// children are, in order:
//
//   - a submitted jabber:x:data form with FORM_TYPE set to NS in which only the
//     filters holding non-zero values are filled in (Start and End are
//     converted to UTC and formatted with time.RFC3339, dropping sub-second
//     precision);
//   - an RSM paging request with Limit as its Max: a "previous page" request
//     with PageID as Before when Last is set, otherwise a "next page" request
//     with PageID as After;
//   - an empty <flip-page/> element, only when Reverse is set.
func (f *Query) TokenReader() xml.TokenReader {
	filters := form.New(
		form.Hidden("FORM_TYPE", form.Value(NS)),
		form.JID(fieldWith),
		form.Text(fieldStart),
		form.Text(fieldEnd),
		form.Text(fieldAfter),
		form.Text(fieldBefore),
		form.ListMulti(fieldIDs),
	)
	timestamp := func(t time.Time) string {
		return t.UTC().Format(time.RFC3339)
	}

	// The results of Set are ignored (gosec is silenced with #nosec); the values
	// passed are a jid.JID for the JID field, strings for the text fields, and a
	// []string for the list-multi field.
	if !f.With.Equal(jid.JID{}) {
		/* #nosec */
		filters.Set(fieldWith, f.With)
	}
	if !f.Start.IsZero() {
		/* #nosec */
		filters.Set(fieldStart, timestamp(f.Start))
	}
	if !f.End.IsZero() {
		/* #nosec */
		filters.Set(fieldEnd, timestamp(f.End))
	}
	if f.AfterID != "" {
		/* #nosec */
		filters.Set(fieldAfter, f.AfterID)
	}
	if f.BeforeID != "" {
		/* #nosec */
		filters.Set(fieldBefore, f.BeforeID)
	}
	if len(f.IDs) > 0 {
		/* #nosec */
		filters.Set(fieldIDs, f.IDs)
	}
	// NOTE(review): the second result of Submit is discarded; confirm it cannot
	// report a condition that matters for a freshly built form.
	submitted, _ := filters.Submit()

	var page xml.TokenReader
	if f.Last {
		page = (&paging.RequestPrev{Max: f.Limit, Before: f.PageID}).TokenReader()
	} else {
		page = (&paging.RequestNext{Max: f.Limit, After: f.PageID}).TokenReader()
	}

	children := []xml.TokenReader{submitted, page}
	if f.Reverse {
		flip := xml.StartElement{Name: xml.Name{Local: "flip-page"}}
		children = append(children, xmlstream.Wrap(nil, flip))
	}

	query := xml.StartElement{
		Name: xml.Name{Space: NS, Local: "query"},
		Attr: []xml.Attr{{Name: xml.Name{Local: "queryid"}, Value: f.ID}},
	}
	return xmlstream.Wrap(xmlstream.MultiReader(children...), query)
}

// WriteXML implements xmlstream.WriterTo.
//
// It streams the tokens produced by TokenReader into w and returns whatever
// count and error xmlstream.Copy reports.
func (f *Query) WriteXML(w xmlstream.TokenWriter) (int, error) {
	r := f.TokenReader()
	return xmlstream.Copy(w, r)
}

// MarshalXML implements xml.Marshaler.
//
// The start element supplied by the encoder is ignored: the <query/> element
// built by TokenReader is always written. Only the error from WriteXML is
// propagated; its count is discarded.
func (f *Query) MarshalXML(e *xml.Encoder, _ xml.StartElement) error {
	if _, err := f.WriteXML(e); err != nil {
		return err
	}
	return nil
}

// UnmarshalXML implements xml.Unmarshaler.
//
// It decodes a <query/> element in the urn:xmpp:mam:2 namespace:
//
//   - the queryid attribute becomes ID;
//   - filters are read from the optional jabber:x:data form, with start and end
//     parsed using time.RFC3339;
//   - Limit, Last and PageID are read from the optional RSM <set/> element.
//     Last is true when a <before/> is present, and PageID is the text of
//     <before/> when Last is set, otherwise the text of <after/>. This mirrors
//     how TokenReader encodes them.
//   - Reverse is true when a <flip-page/> child is present.
//
// Fields absent from the element are reset to their zero values. If an error
// is returned, f is left unmodified.
func (f *Query) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
	s := struct {
		XMLName xml.Name   `xml:"urn:xmpp:mam:2 query"`
		ID      string     `xml:"queryid,attr"`
		Form    *form.Data `xml:"jabber:x:data x"`
		Flip    struct {
			XMLName xml.Name `xml:"flip-page"`
		}
		Set struct {
			XMLName xml.Name `xml:"http://jabber.org/protocol/rsm set"`
			Max     uint64   `xml:"max"`
			After   string   `xml:"after"`
			Before  struct {
				XMLName xml.Name `xml:"before"`
				ID      string   `xml:",chardata"`
			}
		}
	}{}
	if err := d.DecodeElement(&s, &start); err != nil {
		return err
	}

	// Build the result separately and only assign it to f once everything has
	// been parsed, so a malformed timestamp cannot leave f half-overwritten.
	q := Query{
		ID:      s.ID,
		Limit:   s.Set.Max,
		Last:    s.Set.Before.XMLName.Local == "before",
		Reverse: s.Flip.XMLName.Local == "flip-page",
	}
	// TokenReader sends PageID as the RSM before value when Last is set and as
	// the after value otherwise; read it back from the matching element so that
	// PageID survives a marshal/unmarshal round trip.
	if q.Last {
		q.PageID = s.Set.Before.ID
	} else {
		q.PageID = s.Set.After
	}

	// The data form is optional in a query; without this guard a form-less
	// <query/> would call methods on a nil *form.Data.
	if s.Form != nil {
		// The "found" results are ignored: a missing field is assumed to yield
		// the zero value for the corresponding filter.
		q.With, _ = s.Form.GetJID(fieldWith)
		q.BeforeID, _ = s.Form.GetString(fieldBefore)
		q.AfterID, _ = s.Form.GetString(fieldAfter)
		q.IDs, _ = s.Form.GetStrings(fieldIDs)

		var err error
		if v, ok := s.Form.GetString(fieldStart); ok {
			if q.Start, err = time.Parse(time.RFC3339, v); err != nil {
				return err
			}
		}
		if v, ok := s.Form.GetString(fieldEnd); ok {
			if q.End, err = time.Parse(time.RFC3339, v); err != nil {
				return err
			}
		}
	}

	*f = q
	return nil
}