~samwhited/xmpp

ref: a2ffaf710170d692a9289c2b428a4efa90c8e703 xmpp/sasl2/sasl.go -rw-r--r-- 7.0 KiB
a2ffaf71Sam Whited all: update deps 3 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
// Copyright 2017 Sam Whited.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

// Package sasl2 is an experimental implementation of XEP-0388: Extensible SASL
// Profile.
//
// BE ADVISED: This API is incomplete and is subject to change.
// Core functionality of this package is missing, and the entire package may be
// removed at any time.
package sasl2 // import "mellium.im/xmpp/sasl2"

import (
	"context"
	"encoding/xml"
	"errors"
	"fmt"
	"io"

	"mellium.im/sasl"
	"mellium.im/xmlstream"
	"mellium.im/xmpp"
	"mellium.im/xmpp/internal/saslerr"
	"mellium.im/xmpp/stream"
)

// BUG(ssw): feature may provide a security layer, but is not byte precise.

// TODO(ssw): Support caching mechanisms on the feature and pipelining the
// selection.

// NS is the XML namespace used by SASL2.
const NS = "urn:xmpp:sasl:0"

// SASL returns a stream feature for performing authentication using the Simple
// Authentication and Security Layer (SASL) as defined in RFC 4422.
// It panics if no mechanisms are specified.
// The order in which mechanisms are specified will be the preferred order, so
// stronger mechanisms should be listed first.
//
// Identity is used when a user wants to act on behalf of another user.
// For instance, an admin might want to log in as another user to help them
// troubleshoot an issue.
// Normally it is left blank and the localpart of the Origin JID is used.
//
// Only the initiating side of the negotiation is implemented: the returned
// feature's Negotiate function panics if the session has the xmpp.Received
// state bit set.
func SASL(identity, password string, mechanisms ...sasl.Mechanism) xmpp.StreamFeature {
	if len(mechanisms) == 0 {
		panic("sasl2: Must specify at least 1 mechanism")
	}

	return xmpp.StreamFeature{
		Name:       xml.Name{Space: NS, Local: "mechanisms"},
		Necessary:  xmpp.Secure,
		Prohibited: xmpp.Authn,
		// List writes the start element, one <mechanism/> child per configured
		// mechanism containing the mechanism name, and then the matching end
		// element. It always reports the feature as required.
		List: func(ctx context.Context, e xmlstream.TokenWriter, start xml.StartElement) (req bool, err error) {
			req = true
			if err = e.EncodeToken(start); err != nil {
				return
			}

			startMechanism := xml.StartElement{Name: xml.Name{Space: "", Local: "mechanism"}}
			for _, m := range mechanisms {
				// Stop early if the context was canceled between mechanisms.
				select {
				case <-ctx.Done():
					return true, ctx.Err()
				default:
				}

				if err = e.EncodeToken(startMechanism); err != nil {
					return
				}
				if err = e.EncodeToken(xml.CharData(m.Name)); err != nil {
					return
				}
				if err = e.EncodeToken(startMechanism.End()); err != nil {
					return
				}
			}
			return req, e.EncodeToken(start.End())
		},
		// Parse decodes the <mechanisms/> element and returns the advertised
		// mechanism names as a []string. This is the value that Negotiate later
		// receives as its data argument.
		Parse: func(ctx context.Context, r xml.TokenReader, start *xml.StartElement) (bool, interface{}, error) {
			parsed := struct {
				XMLName xml.Name `xml:"urn:xmpp:sasl:0 mechanisms"`
				List    []string `xml:"urn:xmpp:sasl:0 mechanism"`
			}{}
			err := xml.NewTokenDecoder(r).DecodeElement(&parsed, start)
			return true, parsed.List, err
		},
		// Negotiate picks a mechanism, sends <authenticate/> with the initial
		// response, and then exchanges challenges and responses until the SASL
		// client reports that no more steps are required. On success it returns
		// xmpp.Authn and the session's connection.
		Negotiate: func(ctx context.Context, session *xmpp.Session, data interface{}) (mask xmpp.SessionState, rw io.ReadWriter, err error) {
			if (session.State() & xmpp.Received) == xmpp.Received {
				panic("SASL server not yet implemented")
			}

			conn := session.Conn()

			// Select a mechanism, preferring the client order.
			// The outer loop walks our own list, so the first of our mechanisms
			// that the remote side also advertised wins.
			var selected sasl.Mechanism
		selectmechanism:
			for _, m := range mechanisms {
				for _, name := range data.([]string) {
					if name == m.Name {
						selected = m
						break selectmechanism
					}
				}
			}
			// No matching mechanism found…
			// (selected is still the zero value, so its name is empty).
			if selected.Name == "" {
				return mask, nil, errors.New(`No matching SASL mechanisms found`)
			}

			// Create a new SASL client and give it access to credentials, other
			// mechanisms advertised by the server, and the TLS session state if
			// possible (for SCRAM-PLUS mechanisms).
			// The credentials callback returns, in order, the localpart of the
			// session's local address, the password, and the identity.
			opts := []sasl.Option{
				sasl.Credentials(func() ([]byte, []byte, []byte) {
					return []byte(session.LocalAddr().Localpart()), []byte(password), []byte(identity)
				}),
				sasl.RemoteMechanisms(data.([]string)...),
			}
			if connState, ok := conn.ConnectionState(); ok {
				opts = append(opts, sasl.TLSState(connState))
			}
			client := sasl.NewClient(selected, opts...)

			// Calculate the initial response
			more, resp, err := client.Step(nil)
			if err != nil {
				return mask, nil, err
			}

			// XEP-0388 §2.2:
			//     In order to explicitly transmit a zero-length SASL challenge or
			//     response, the sending party sends a single equals sign character
			//     ("=").
			if len(resp) == 0 {
				resp = []byte{'='}
			}

			// TODO: Printf'ing is probably a bad idea. Encode the tokens properly.
			// NOTE(review): resp is written verbatim with %s; no escaping or
			// transfer encoding is applied here.
			// Send <auth/> and the initial payload to start SASL auth.
			if _, err = fmt.Fprintf(conn,
				`<authenticate xmlns='%s' mechanism='%s'><initial-response>%s</initial-response></authenticate>`,
				NS, selected.Name, resp,
			); err != nil {
				return mask, nil, err
			}

			// If we're already done after the first step, decode the <success/> or
			// <failure/> before we exit.
			// A <challenge/> is not acceptable at this point, hence the false
			// passed as allowChallenge.
			if !more {
				tok, err := session.Token()
				if err != nil {
					return mask, nil, err
				}
				if t, ok := tok.(xml.StartElement); ok {
					// TODO: Handle the additional data that could be returned if
					// success?
					_, _, err := decodeSASLChallenge(session, t, false)
					if err != nil {
						return mask, nil, err
					}
				} else {
					return mask, nil, stream.BadFormat
				}
			}

			// Otherwise keep answering the server until the client state machine
			// reports that it needs no more steps.
			success := false
			for more {
				// Give up if the context was canceled between round trips.
				select {
				case <-ctx.Done():
					return mask, nil, ctx.Err()
				default:
				}
				tok, err := session.Token()
				if err != nil {
					return mask, nil, err
				}
				var challenge []byte
				// Anything other than a start element at this point is treated
				// as a malformed stream.
				if t, ok := tok.(xml.StartElement); ok {
					challenge, success, err = decodeSASLChallenge(session, t, true)
					if err != nil {
						return mask, nil, err
					}
				} else {
					return mask, nil, stream.BadFormat
				}
				if more, resp, err = client.Step(challenge); err != nil {
					return mask, nil, err
				}
				if !more && success {
					// We're done with SASL and we're successful
					break
				}
				// TODO: What happens if there's more and success (broken server)?
				if _, err = fmt.Fprintf(conn,
					`<response xmlns='urn:xmpp:sasl:0'>%s</response>`, resp); err != nil {
					return mask, nil, err
				}
			}
			return xmpp.Authn, conn, nil
		},
	}
}

// decodeSASLChallenge decodes the SASL2 element that begins with start from r
// and reports what the remote side sent:
//
//   - <challenge/> returns its character data with success set to false.
//     If allowChallenge is false the element is rejected with
//     stream.UnsupportedStanzaType before anything is decoded.
//   - <success/> returns the contents of its <success-data/> child (nil if
//     there is none) with success set to true.
//   - <failure/> is decoded into a saslerr.Failure, which is returned as the
//     error.
//   - Any other element results in stream.UnsupportedStanzaType.
//
// Payloads are returned as the raw character data found in the XML.
func decodeSASLChallenge(r xml.TokenReader, start xml.StartElement, allowChallenge bool) (challenge []byte, success bool, err error) {
	d := xml.NewTokenDecoder(r)
	switch start.Name {
	case xml.Name{Space: NS, Local: "challenge"}:
		if !allowChallenge {
			return nil, false, stream.UnsupportedStanzaType
		}
		// The locals are deliberately not called "challenge" or "success" so
		// that they cannot shadow the named results.
		payload := struct {
			Data []byte `xml:",chardata"`
		}{}
		if err = d.DecodeElement(&payload, &start); err != nil {
			return nil, false, err
		}
		return payload.Data, false, nil
	case xml.Name{Space: NS, Local: "success"}:
		payload := struct {
			XMLName xml.Name `xml:"urn:xmpp:sasl:0 success"`
			Data    []byte   `xml:"success-data"`
		}{}
		// Decode into the success payload itself so that the data sent along
		// with <success/> is actually returned to the caller.
		if err = d.DecodeElement(&payload, &start); err != nil {
			return nil, true, err
		}
		return payload.Data, true, nil
	case xml.Name{Space: NS, Local: "failure"}:
		fail := saslerr.Failure{}
		if err = d.DecodeElement(&fail, &start); err != nil {
			return nil, false, err
		}
		return nil, false, fail
	default:
		return nil, false, stream.UnsupportedStanzaType
	}
}