~samwhited/xmpp

ref: 7e76defde884af7cd63b3d6bb7065694db076b8f xmpp/features.go -rw-r--r-- 11.1 KiB
7e76defdSam Whited all: pass Session directly to handler 1 year, 10 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
// Copyright 2016 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp

import (
	"context"
	"encoding/xml"
	"errors"
	"io"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/stream"
)

const (
	// featuresLocal is the local name that readStreamFeatures expects on the
	// start and end elements that delimit a stream features list.
	featuresLocal = "features"
)

// A StreamFeature represents a feature that may be selected during stream
// negotiation, e.g. STARTTLS, compression, and SASL authentication are all
// stream features.
// Features should be stateless as they may be reused between connection
// attempts; however, state can be passed from the Parse function to the
// Negotiate function through Parse's data return value.
type StreamFeature struct {
	// The XML name of the feature in the <stream:features/> list. If a start
	// element with this name is seen while the connection is reading the features
	// list, it will trigger this StreamFeature's Parse function as a callback.
	Name xml.Name

	// Bits that are required before this feature is advertised. For instance, if
	// this feature should only be advertised after the user is authenticated we
	// might set this to "Authn" or if it should be advertised only after the
	// connection is authenticated and encrypted we might set this to
	// "Authn|Secure".
	Necessary SessionState

	// Bits that must be off for this feature to be advertised. For instance, if
	// this feature should only be advertised before the connection is
	// authenticated (e.g. if the feature performs authentication itself), we
	// might set this to "Authn".
	Prohibited SessionState

	// Used to send the feature in a features list for server connections. The
	// start element will have a name that matches the feature's name and should
	// be used as the outermost tag in the stream (but also may be ignored). List
	// implementations that call e.EncodeToken directly need to call e.Flush when
	// finished to ensure that the XML is written to the underlying writer.
	// The req return value reports whether the feature is required.
	List func(ctx context.Context, e xmlstream.TokenWriter, start xml.StartElement) (req bool, err error)

	// Used to parse the feature that begins with the given xml start element
	// (which should have a Name that matches this stream feature's Name).
	// Returns whether or not the feature is required, and any data that will be
	// needed if the feature is selected for negotiation (e.g. the list of
	// mechanisms if the feature was SASL authentication).
	Parse func(ctx context.Context, r xml.TokenReader, start *xml.StartElement) (req bool, data interface{}, err error)

	// A function that will take over the session temporarily while negotiating
	// the feature. The "mask" SessionState represents the state bits that should
	// be set after negotiation of the feature is complete. For instance, if
	// this feature creates a security layer (such as TLS) and performs
	// authentication, mask would be set to Authn|Secure, but if it does not
	// authenticate the connection it would just return Secure. If Negotiate
	// returns a new io.ReadWriter (probably wrapping the old session.Conn()) the
	// stream will be restarted automatically after Negotiate returns using the
	// new ReadWriter. If the features list was read from the peer and the
	// feature's Parse call returned a value, that value is passed to the data
	// parameter when Negotiate is called (when the list was written locally,
	// data is nil). For instance, in the case of compression this data
	// parameter might be the list of supported algorithms as a slice of strings
	// (or in whatever format the feature implementation has decided upon).
	Negotiate func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error)
}

// containsStartTLS looks through features for the first entry whose name is
// in the StartTLS namespace. When one is found it is returned along with
// true; otherwise the zero StreamFeature and false are returned.
func containsStartTLS(features []StreamFeature) (startTLS StreamFeature, ok bool) {
	for i := range features {
		if features[i].Name.Space != ns.StartTLS {
			continue
		}
		return features[i], true
	}
	return StreamFeature{}, false
}

// negotiateFeatures performs one round of stream feature negotiation on s
// using the supported features.
//
// If the Received bit is set in s.state the session is treated as the server
// side: the features list is written with writeStreamFeatures and the
// features selected by the peer are then read from the stream and negotiated
// one at a time. Otherwise the features list is read from the stream with
// readStreamFeatures and the features to negotiate are chosen locally.
//
// first should be true when this is the first features list of the stream;
// it is only consulted when deciding whether to attempt StartTLS even though
// it was not advertised.
//
// The returned mask is the mask from the last Negotiate call, with Ready
// added if the features list contained no required features; rw is whatever
// the last Negotiate call returned. When there is nothing (left) to negotiate
// on the client side the function returns Ready with a nil rw and nil error.
func negotiateFeatures(ctx context.Context, s *Session, first bool, features []StreamFeature) (mask SessionState, rw io.ReadWriter, err error) {
	// The Received bit selects between the server and client code paths below.
	server := (s.state & Received) == Received

	// If we're the server, write the initial stream features.
	var list *streamFeaturesList
	if server {
		list, err = writeStreamFeatures(ctx, s, features)
		if err != nil {
			return mask, nil, err
		}
	}

	var t xml.Token
	var start xml.StartElement
	var ok bool

	var startTLS StreamFeature
	var doStartTLS bool
	if !server {
		// Read a new start stream:features token.
		t, err = s.Token()
		if err != nil {
			return mask, nil, err
		}
		start, ok = t.(xml.StartElement)
		if !ok {
			return mask, nil, stream.BadFormat
		}

		// If we're the client read the rest of the stream features list.
		list, err = readStreamFeatures(ctx, s, start, features)
		if err != nil {
			return mask, nil, err
		}

		// At this point doStartTLS only reports whether StartTLS is among the
		// features we were asked to negotiate; it is narrowed below.
		startTLS, doStartTLS = containsStartTLS(features)
		_, advertisedStartTLS := list.cache[ns.StartTLS]
		// If this is the first features list and StartTLS isn't advertised (but
		// is in the features list to be negotiated) and we're not already on a
		// secure connection, try it anyways to prevent downgrade attacks per RFC
		// 7590.
		doStartTLS = first && !advertisedStartTLS && s.State()&Secure != Secure && doStartTLS

		switch {
		case doStartTLS:
			// Skip length checks if we need to negotiate StartTLS for downgrade
			// attack prevention.
		case list.total == 0:
			// If we received an empty list, we're done (total counts every feature
			// in the list, including ones we don't support).
			return Ready, nil, nil
		case len(list.cache) == 0:
			// If we received a list with features but none of them are ones that
			// we support and can negotiate in the current session state (eg. they
			// were advertised in the wrong order), this is an error:
			// TODO: This error isn't very good.
			return mask, nil, errors.New("xmpp: features advertised out of order")
		}
	}

	// sent records whether the feature selected by the peer was present in the
	// list that we wrote (server side only).
	var sent bool

	// If the list has any optional items that we support, negotiate them first
	// before moving on to the required items.
	for {
		// data is the feature selected for negotiation in this iteration.
		var data sfData

		if server {
			// Read a new feature to negotiate.
			t, err = s.Token()
			if err != nil {
				return mask, nil, err
			}
			start, ok = t.(xml.StartElement)
			if !ok {
				return mask, nil, stream.BadFormat
			}

			// If the feature was not sent or was already negotiated, error.
			// Features are looked up by the namespace of the element the peer
			// sent.

			_, negotiated := s.negotiated[start.Name.Space]
			data, sent = list.cache[start.Name.Space]
			if !sent || negotiated {
				// TODO: What should we return here?
				return mask, rw, stream.PolicyViolation
			}
		} else {
			// If we need to try and negotiate StartTLS even though it wasn't
			// advertised, select it.
			if doStartTLS && startTLS.Name.Space == ns.StartTLS {
				// There is no parsed data for an unadvertised feature, so only the
				// feature itself is set; treating it as required ends the loop
				// after it has been negotiated.
				data = sfData{
					req:     true,
					feature: startTLS,
				}
			} else {
				// If we're the client, iterate through the cached features and select one
				// to negotiate.
				for _, v := range list.cache {
					if _, ok := s.negotiated[v.feature.Name.Space]; ok {
						// If this feature has already been negotiated, skip it.
						continue
					}

					// If the feature is optional, select it.
					if !v.req {
						data = v
						break
					}

					// If the feature is required, tentatively select it (but finish looking
					// for optional features).
					if v.req {
						data = v
					}
				}
			}

			// No features that haven't already been negotiated were sent; we're
			// done. An unset selection is detected by its empty local name.
			if data.feature.Name.Local == "" {
				return Ready, nil, nil
			}
		}

		// Hand the session over to the feature. The state bits are only applied
		// if negotiation succeeded, but the feature is recorded as negotiated
		// either way so that it is not selected again.
		// NOTE(review): if Negotiate fails for an optional feature and returns a
		// nil rw, the loop continues and err is overwritten by the next iteration
		// (or dropped by the "Ready" return above); confirm that this
		// best-effort behavior is intended.
		mask, rw, err = data.feature.Negotiate(ctx, s, data.data)
		if err == nil {
			s.state |= mask
		}
		s.negotiated[data.feature.Name.Space] = struct{}{}

		// If we negotiated a required feature or a stream restart is required
		// we're done with this feature set.
		if rw != nil || data.req {
			break
		}
	}

	// If the list contains no required features, negotiation is complete.
	if !list.req {
		mask |= Ready
	}

	return mask, rw, err
}

// sfData is the per-feature record kept while a features list is written or
// read. req reports whether the feature was marked as required (by List when
// the list is written, or by Parse when it is read), data holds the value
// returned from Parse (it is nil for lists written by writeStreamFeatures),
// and feature is the StreamFeature that the record describes.
type sfData struct {
	req     bool
	data    interface{}
	feature StreamFeature
}

// streamFeaturesList summarizes a features list that was written by
// writeStreamFeatures or read by readStreamFeatures. total is the number of
// features that were written or, when reading, the number of feature elements
// seen (including unsupported ones). req is true if at least one of the
// cached features is required.
type streamFeaturesList struct {
	total int
	req   bool

	// Maps a feature's XML namespace to the sfData recorded for it. Only
	// features that were actually listed or parsed have an entry.
	cache map[string]sfData
}

// getFeature searches features for the first entry whose Name is equal to
// name (both namespace and local name must match). It returns the matching
// feature and true, or the zero StreamFeature and false if there is none.
func getFeature(name xml.Name, features []StreamFeature) (feature StreamFeature, ok bool) {
	for i := range features {
		if candidate := features[i]; candidate.Name == name {
			feature, ok = candidate, true
			break
		}
	}
	return feature, ok
}

// writeStreamFeatures writes a <stream:features/> element to the session and
// returns a summary of the features that were listed inside it.
//
// A feature is listed only if all of its Necessary bits are set in the
// session state and none of its Prohibited bits are. Each listed feature is
// cached under its namespace together with the "required" flag returned from
// its List function. If writing the opening element fails the returned list
// is nil; on any later error the partially filled list is returned along with
// the error.
func writeStreamFeatures(ctx context.Context, s *Session, features []StreamFeature) (list *streamFeaturesList, err error) {
	open := xml.StartElement{Name: xml.Name{Space: "", Local: "stream:features"}}
	tw := s.TokenWriter()
	defer tw.Close()

	err = tw.EncodeToken(open)
	if err != nil {
		return nil, err
	}

	list = &streamFeaturesList{
		cache: make(map[string]sfData),
	}

	for _, f := range features {
		// Skip features whose state requirements are not met: a necessary bit
		// is missing or a prohibited bit is set.
		if s.state&f.Necessary != f.Necessary || s.state&f.Prohibited != 0 {
			continue
		}

		var required bool
		required, err = f.List(ctx, s.out.e, xml.StartElement{Name: f.Name})
		if err != nil {
			return list, err
		}

		list.cache[f.Name.Space] = sfData{
			req:     required,
			feature: f,
		}
		list.req = list.req || required
		list.total++
	}

	err = tw.EncodeToken(open.End())
	if err != nil {
		return list, err
	}
	return list, s.Flush()
}

// readStreamFeatures reads the children of a stream features element whose
// start element (start) has already been consumed, and returns a summary of
// the features that were found.
//
// start must have the local name "features" (otherwise stream.InvalidXML is
// returned) and be in the stream namespace (otherwise
// stream.BadNamespacePrefix is returned). Every child element is counted and
// its namespace is recorded in s.features. Children that match one of the
// supported features, and whose Necessary/Prohibited bits are satisfied by
// the session state, are handed to the feature's Parse function and cached;
// all other children are skipped. Any token other than a start element or the
// closing features element is an error.
func readStreamFeatures(ctx context.Context, s *Session, start xml.StartElement, features []StreamFeature) (*streamFeaturesList, error) {
	if start.Name.Local != featuresLocal {
		return nil, stream.InvalidXML
	}
	if start.Name.Space != ns.Stream {
		return nil, stream.BadNamespacePrefix
	}

	parsed := &streamFeaturesList{
		cache: make(map[string]sfData),
	}

	for {
		raw, err := s.in.d.Token()
		if err != nil {
			return nil, err
		}

		switch tok := raw.(type) {
		case xml.EndElement:
			// The only end element that can legitimately show up here is the
			// one closing the features list itself.
			if tok.Name.Local != featuresLocal || tok.Name.Space != ns.Stream {
				return nil, stream.InvalidXML
			}
			return parsed, nil
		case xml.StartElement:
			// Every feature counts towards the total and is recorded on the
			// session, whether or not we know how to handle it.
			parsed.total++
			s.features[tok.Name.Space] = nil

			known, found := getFeature(tok.Name, features)
			if !found || s.state&known.Necessary != known.Necessary || s.state&known.Prohibited != 0 {
				// Unsupported (or not currently negotiable) feature: discard the
				// rest of its element.
				if err := xmlstream.Skip(s.in.d); err != nil {
					return nil, err
				}
				continue
			}

			required, payload, err := known.Parse(ctx, s.in.d, &tok)
			if err != nil {
				return nil, err
			}

			// TODO: the parsed data is also kept on s.features, so the copy in
			// this temporary cache is probably redundant.
			parsed.cache[tok.Name.Space] = sfData{
				req:     required,
				data:    payload,
				feature: known,
			}

			// Replace the nil placeholder with the data returned from Parse now
			// that we know the feature is supported.
			s.features[tok.Name.Space] = payload
			if required {
				parsed.req = true
			}
		default:
			return nil, stream.RestrictedXML
		}
	}
}