~samwhited/xmpp

ref: 3a5d79291c618f5ac03dc34e1e78af5a6e0ad514 xmpp/features.go -rw-r--r-- 6.5 KiB
3a5d7929Sam Whited Workaround for Go issue #16497 5 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package xmpp

import (
	"context"
	"encoding/xml"
	"fmt"
	"io"

	"mellium.im/xmpp/ns"
	"mellium.im/xmpp/streamerror"
)

// A StreamFeature represents a feature that may be selected during stream
// negotiation. Features should be stateless and usable from multiple goroutines
// unless otherwise specified.
type StreamFeature struct {
	// The XML name of the feature in the <stream:features/> list. If a start
	// element with this name is seen while the connection is reading the features
	// list (and the Necessary and Prohibited conditions below are met), it will
	// trigger this StreamFeature's Parse function as a callback.
	Name xml.Name

	// Bits that are required before this feature is advertised. For instance, if
	// this feature should only be advertised after the user is authenticated we
	// might set this to "Authn" or if it should be advertised only after the
	// connection is authenticated and encrypted we might set this to
	// "Authn|Secure". All of these bits must be set in the session state.
	Necessary SessionState

	// Bits that must be off for this feature to be advertised. For instance, if
	// this feature should only be advertised before the connection is
	// authenticated (eg. if the feature performs authentication itself), we might
	// set this to "Authn". If any of these bits is set in the session state the
	// feature is skipped.
	Prohibited SessionState

	// Used to send the feature in a features list for server connections. The
	// feature should write its own element to conn; the surrounding
	// <stream:features> start and end tags are written by the caller. The req
	// return value reports whether the feature is required and is used to count
	// the number of required features that were advertised.
	List func(ctx context.Context, conn io.Writer) (req bool, err error)

	// Used to parse the feature that begins with the given xml start element
	// (which should have a Name that matches this stream feature's Name).
	// Returns whether or not the feature is required, and any data that will be
	// needed if the feature is selected for negotiation (eg. the list of
	// mechanisms if the feature was SASL). A non-nil error aborts reading of the
	// features list.
	Parse func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (req bool, data interface{}, err error)

	// A function that will take over the session temporarily while negotiating
	// the feature. The "mask" SessionState represents the state bits that should
	// be set (they are OR'd into the session state) after negotiation of the
	// feature completes without an error. For instance, if this feature creates
	// a security layer (such as TLS) and performs authentication, mask would be
	// set to Authn|Secure|StreamRestartRequired, but if it does not authenticate
	// the connection it would return Secure|StreamRestartRequired. If the mask
	// includes the StreamRestartRequired bit, the stream will be restarted
	// automatically after Negotiate returns (unless it returns an error). If
	// this is an initiated connection and the feature's Parse call returned a
	// data value, that value is passed to the data parameter when Negotiate is
	// called. For instance, in the case of compression this data parameter might
	// be the list of supported algorithms as a slice of strings (or in whatever
	// format the feature implementation has decided upon).
	Negotiate func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, err error)
}

// writeStreamFeatures writes a <stream:features> list to conn containing every
// configured feature whose Necessary bits are all set, and whose Prohibited
// bits are all clear, in the connection's current state.
//
// It returns n, the number of stream features written (zero means we've
// reached the end of negotiation), and req, the number of those features that
// reported themselves as required (zero means we've potentially reached the
// end of negotiation, but the client may negotiate more optional features).
// If a write or a feature's List callback fails, the counts accumulated so far
// are returned along with the error and the closing tag is not written.
func writeStreamFeatures(ctx context.Context, conn *Conn) (n int, req int, err error) {
	if _, err = fmt.Fprint(conn, `<stream:features>`); err != nil {
		return n, req, err
	}

	for _, sf := range conn.config.Features {
		// Skip features that are not applicable to the current session state.
		missing := (conn.state & sf.Necessary) != sf.Necessary
		forbidden := (conn.state & sf.Prohibited) != 0
		if missing || forbidden {
			continue
		}

		required, listErr := sf.List(ctx, conn)
		if listErr != nil {
			return n, req, listErr
		}
		n++
		if required {
			req++
		}
	}

	_, err = fmt.Fprint(conn, `</stream:features>`)
	return n, req, err
}

// negotiateFeatures performs a single round of stream feature negotiation.
//
// If the Received bit is set in the connection state, the features list is
// written to the connection; the rest of that path is not yet implemented and
// panics. Otherwise the next token is read from the input decoder and must be
// the start of a <stream:features/> list, from which one feature is picked and
// its Negotiate function is run. On success the returned mask is OR'd into the
// connection state.
//
// done is true when the received list was empty or contained no supported
// features, or when the list contained no required features.
func (c *Conn) negotiateFeatures(ctx context.Context) (done bool, err error) {
	if (c.state & Received) == Received {
		if _, _, err = writeStreamFeatures(ctx, c); err != nil {
			return false, err
		}
		panic("Sending stream:features not yet implemented")
	}

	tok, err := c.in.d.Token()
	if err != nil {
		return false, err
	}
	startEl, isStart := tok.(xml.StartElement)
	if !isStart {
		return false, streamerror.BadFormat
	}

	list, err := readStreamFeatures(ctx, c, startEl)
	if err != nil {
		return false, err
	}
	if list.total == 0 || len(list.cache) == 0 {
		// If we received an empty list (or one with no supported features), we're
		// done.
		return true, nil
	}

	// If the list has any required items, negotiate a required feature.
	// Otherwise just negotiate any feature in the list. The cache is a map, so
	// which of several eligible features is picked depends on iteration order.
	var chosen sfData
	for _, candidate := range list.cache {
		if candidate.req || !list.req {
			chosen = candidate
			break
		}
	}

	mask, err := chosen.feature.Negotiate(ctx, c, chosen.data)
	if err != nil {
		return !list.req, err
	}
	c.state |= mask
	return !list.req, nil
}

// sfData holds the result of parsing a single supported feature from a
// received features list.
type sfData struct {
	// Whether the feature's Parse function reported it as required.
	req bool
	// The data value returned by the feature's Parse function; it is later
	// passed to the feature's Negotiate function.
	data interface{}
	// The configured StreamFeature that matched the element's name.
	feature StreamFeature
}

// streamFeaturesList summarizes a features list read by readStreamFeatures.
type streamFeaturesList struct {
	// The number of child start elements seen in the list, whether or not they
	// matched a supported feature.
	total int
	// True if at least one parsed (supported) feature was reported as required.
	req bool
	// The parsed features, keyed by the XML name of their element. Only
	// features that were found in the connection's configuration and were
	// applicable to its current state are stored.
	cache map[xml.Name]sfData
}

// readStreamFeatures reads the children of a <stream:features/> element from
// the connection's input decoder. The start argument is the already consumed
// start token of the list; it must be named "features" (otherwise
// streamerror.InvalidXML is returned) in the stream namespace (otherwise
// streamerror.BadNamespacePrefix is returned).
//
// Every child start element is counted. Children whose name matches a
// configured feature that is applicable to the connection's current state are
// handed to that feature's Parse function and cached; all other children are
// skipped. Reading stops at the matching </stream:features> end element. Any
// other end element yields streamerror.InvalidXML, and any token that is
// neither a start nor an end element yields streamerror.RestrictedXML.
func readStreamFeatures(ctx context.Context, conn *Conn, start xml.StartElement) (*streamFeaturesList, error) {
	if start.Name.Local != "features" {
		return nil, streamerror.InvalidXML
	}
	if start.Name.Space != ns.Stream {
		return nil, streamerror.BadNamespacePrefix
	}

	list := &streamFeaturesList{
		cache: make(map[xml.Name]sfData),
	}

	for {
		raw, err := conn.in.d.Token()
		if err != nil {
			return nil, err
		}

		switch el := raw.(type) {
		case xml.EndElement:
			// The only end element we should be able to see at this level is the
			// </stream:features> token that terminates the list.
			if el.Name.Local != "features" || el.Name.Space != ns.Stream {
				return nil, streamerror.InvalidXML
			}
			return list, nil

		case xml.StartElement:
			// Count every advertised feature, supported or not.
			list.total++

			feature, known := conn.config.Features[el.Name]
			usable := known &&
				(conn.state&feature.Necessary) == feature.Necessary &&
				(conn.state&feature.Prohibited) == 0
			if !usable {
				// If the feature is not one we support, skip it.
				if err := conn.in.d.Skip(); err != nil {
					return nil, err
				}
				continue
			}

			required, payload, err := feature.Parse(ctx, conn.in.d, &el)
			if err != nil {
				return nil, err
			}
			list.cache[el.Name] = sfData{
				req:     required,
				data:    payload,
				feature: feature,
			}
			if required {
				list.req = true
			}

		default:
			return nil, streamerror.RestrictedXML
		}
	}
}