~samwhited/xmpp

ref: 07c0fcc4e5ab13995ab52449e7ddc695dbe48b7d xmpp/negotiator.go -rw-r--r-- 6.2 KiB
07c0fcc4Sam Whited .builds: split testing and validation builds 8 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
// Copyright 2016 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp

import (
	"context"
	"crypto/tls"
	"io"
	"net"

	"mellium.im/xmpp/internal/attr"
	"mellium.im/xmpp/internal/stream"
)

// Negotiator is a function that can be passed to NegotiateSession to perform
// custom session negotiation, for example to implement stream initialization
// that does not use XMPP feature negotiation, such as the connection mechanism
// described in XEP-0114: Jabber Component Protocol.
// Most users should call NewClientSession or NewServerSession instead.
//
// NegotiateSession calls the Negotiator repeatedly until it returns a mask
// with the Ready bit set. After every call, the bits set in the returned mask
// are added to the session state, and the returned cache value is handed back
// as data on the next call. If a new io.ReadWriter is returned, it replaces
// the session's underlying io.ReadWriter and the session's internal state
// (encoders, decoders, etc.) is reset.
type Negotiator func(
	ctx context.Context,
	session *Session,
	data interface{},
) (mask SessionState, rw io.ReadWriter, cache interface{}, err error)

// Compile-time assertion that *teeConn implements the package's tlsConn
// interface (declared elsewhere in this package); the build fails if it
// stops doing so.
var _ tlsConn = (*teeConn)(nil)

// teeConn is a net.Conn that also copies reads and writes to the provided
// writers.
//
// All methods use value receivers so that a plain teeConn value (not only a
// pointer) satisfies net.Conn; the negotiator type-asserts connections to the
// value type teeConn. As a consequence, methods cannot record state on the
// struct (assignments to tc's fields inside a method are lost).
type teeConn struct {
	net.Conn
	// tlsConn is the wrapped connection if it is a *tls.Conn, otherwise nil.
	tlsConn *tls.Conn
	// ctx turns the tee off when canceled; afterwards reads and writes pass
	// straight through to Conn.
	ctx context.Context
	// multiWriter writes to both Conn and the outbound tee writer, or is nil if
	// no outbound writer was configured.
	multiWriter io.Writer
	// teeReader reads from Conn and copies to the inbound tee writer, or is nil
	// if no inbound writer was configured.
	teeReader io.Reader
}

// newTeeConn wraps c so that data read from it is also written to in and data
// written to it is also written to out. Either writer may be nil to disable
// copying in that direction. Once ctx is canceled, both reads and writes pass
// straight through to c and are no longer copied to in or out.
//
// If c is already a teeConn it is returned unchanged and ctx, in, and out are
// ignored.
func newTeeConn(ctx context.Context, c net.Conn, in, out io.Writer) teeConn {
	if tc, ok := c.(teeConn); ok {
		return tc
	}

	tc := teeConn{Conn: c, ctx: ctx}
	// The assertion is allowed to fail: non-TLS connections simply report an
	// empty ConnectionState.
	tc.tlsConn, _ = c.(*tls.Conn)
	if in != nil {
		tc.teeReader = io.TeeReader(c, in)
	}
	if out != nil {
		tc.multiWriter = io.MultiWriter(c, out)
	}
	return tc
}

// ConnectionState returns the TLS state of the wrapped connection, or the zero
// tls.ConnectionState if the wrapped connection is not a *tls.Conn.
func (tc teeConn) ConnectionState() tls.ConnectionState {
	if tc.tlsConn == nil {
		return tls.ConnectionState{}
	}
	return tc.tlsConn.ConnectionState()
}

// Write writes p to the underlying connection and, while the tee is active,
// also to the outbound tee writer.
//
// Because of the value receiver, cancellation cannot be cached on tc, so the
// (non-blocking) context check runs on every call.
func (tc teeConn) Write(p []byte) (int, error) {
	if tc.multiWriter == nil || tc.ctx.Err() != nil {
		return tc.Conn.Write(p)
	}
	return tc.multiWriter.Write(p)
}

// Read reads from the underlying connection and, while the tee is active,
// copies the bytes read to the inbound tee writer.
//
// Because of the value receiver, cancellation cannot be cached on tc, so the
// (non-blocking) context check runs on every call.
func (tc teeConn) Read(p []byte) (int, error) {
	if tc.teeReader == nil || tc.ctx.Err() != nil {
		return tc.Conn.Read(p)
	}
	return tc.teeReader.Read(p)
}

// StreamConfig contains options for configuring the default Negotiator.
type StreamConfig struct {
	// The native language of the stream.
	Lang string

	// S2S causes the negotiator to negotiate a server-to-server (s2s) connection.
	S2S bool

	// A list of stream features to attempt to negotiate.
	Features []StreamFeature

	// If set, a copy of any reads from the session will be written to TeeIn and
	// any writes to the session will be written to TeeOut (similar to the tee(1)
	// command). Either may be left nil independently to disable copying in that
	// direction.
	// This can be used to build an "XML console", but users should be careful
	// since this bypasses TLS and could expose passwords and other sensitive
	// data.
	TeeIn, TeeOut io.Writer
}

// NewNegotiator returns a Negotiator that drives XMPP stream feature
// negotiation using the StreamFeatures listed in cfg. The resulting session is
// client-to-server (c2s) unless cfg.S2S requests a server-to-server (s2s)
// session.
// If StartTLS is one of the supported stream features, the Negotiator attempts
// to negotiate it whether the server advertises support or not.
func NewNegotiator(cfg StreamConfig) Negotiator {
	n := negotiator(cfg)
	return n
}

// negotiatorState is the cache value passed between successive calls of the
// Negotiator built by negotiator.
type negotiatorState struct {
	// doRestart is true when stream headers must be (re)exchanged before the
	// next round of feature negotiation.
	doRestart bool
	// cancelTee stops the currently installed teeConn (if any) from copying
	// data. It is called before a replacement teeConn is installed.
	cancelTee context.CancelFunc
	// featuresStarted is true once negotiateFeatures has been called. It is
	// tracked explicitly instead of being inferred from a nil cache, because
	// installing a teeConn consumes the first (nil-cache) call without
	// negotiating any features.
	featuresStarted bool
}

// negotiator builds the default Negotiator returned by NewNegotiator.
//
// Each call performs one step and relies on the caller to invoke it again:
//   - If TeeIn or TeeOut is configured and s.Conn() is not yet a teeConn, the
//     connection is wrapped and returned as the new io.ReadWriter without
//     setting any state bits.
//   - Otherwise, if a stream (re)start is pending, stream headers are exchanged
//     (the receiving entity waits for the peer's header and then sends its own;
//     the initiating entity sends first and then waits), after which stream
//     features are negotiated.
//
// A stream restart is scheduled for the next call whenever feature
// negotiation returns a new io.ReadWriter.
func negotiator(cfg StreamConfig) Negotiator {
	return func(ctx context.Context, s *Session, data interface{}) (mask SessionState, rw io.ReadWriter, cache interface{}, err error) {
		nState, ok := data.(negotiatorState)
		// If no state was passed in, this is the first negotiate call so make up a
		// default.
		if !ok {
			nState = negotiatorState{
				doRestart: true,
				cancelTee: nil,
			}
		}

		c := s.Conn()
		// If the session is not already using a tee conn, but we're configured to
		// use one, return the new teeConn and don't set any state bits.
		if _, isTee := c.(teeConn); !isTee && (cfg.TeeIn != nil || cfg.TeeOut != nil) {
			// Cancel any previous teeConn so that we don't double write to TeeIn
			// and TeeOut.
			if nState.cancelTee != nil {
				nState.cancelTee()
			}

			// This context only controls the tee effect, so it is deliberately not
			// derived from the negotiation ctx; its parent is Background.
			teeCtx, cancel := context.WithCancel(context.Background())
			nState.cancelTee = cancel
			return mask, newTeeConn(teeCtx, c, cfg.TeeIn, cfg.TeeOut), nState, nil
		}

		// Exchange stream headers if this is the first step or the previous
		// feature negotiation replaced the connection and requires a restart.
		if nState.doRestart {
			if (s.state & Received) == Received {
				// If we're the receiving entity wait for a new stream, then send one in
				// response.
				s.in.Info, err = stream.Expect(ctx, s.in.d, s.State()&Received == Received)
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err
				}
				s.out.Info, err = stream.Send(s.Conn(), cfg.S2S, stream.DefaultVersion, cfg.Lang, s.location.String(), s.origin.String(), attr.RandomID())
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err
				}
			} else {
				// If we're the initiating entity, send a new stream and then wait for
				// one in response.
				s.out.Info, err = stream.Send(s.Conn(), cfg.S2S, stream.DefaultVersion, cfg.Lang, s.location.String(), s.origin.String(), "")
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err
				}
				s.in.Info, err = stream.Expect(ctx, s.in.d, s.State()&Received == Received)
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err
				}
			}
		}

		// Previously this flag was computed as data == nil, which is never true
		// when a tee is configured (the nil-cache call only installs the
		// teeConn). Tracking it in the state keeps the flag true for the first
		// round of feature negotiation in both cases.
		// NOTE(review): assumes negotiateFeatures' third argument means "first
		// round of feature negotiation" — confirm against its definition.
		first := !nState.featuresStarted
		nState.featuresStarted = true

		// TODO: Check if the first token is a stream error (if so, unmarshal and
		// return, otherwise pass the token into negotiateFeatures).
		mask, rw, err = negotiateFeatures(ctx, s, first, cfg.Features)
		nState.doRestart = rw != nil
		return mask, rw, nState, err
	}
}