~samwhited/xmpp

xmpp/negotiator.go -rw-r--r-- 4.8 KiB
81420f0eSam Whited xmpp: always set "from" on s2s stanzas 3 days ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright 2016 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp

import (
	"context"
	"io"

	"mellium.im/xmpp/internal/attr"
	"mellium.im/xmpp/internal/stream"
)

// Negotiator is a function that can be passed to NegotiateSession to perform
// custom session negotiation. This can be used for creating custom stream
// initialization logic that does not use XMPP feature negotiation such as the
// connection mechanism described in XEP-0114: Jabber Component Protocol.
// Normally NewClientSession or NewServerSession should be used instead.
//
// If a Negotiator is passed into NegotiateSession it will be called repeatedly
// until a mask is returned with the Ready bit set. Each time Negotiator is
// called any bits set in the state mask that it returns will be set on the
// session state and any cache value that is returned will be passed back in
// during the next iteration. If a new io.ReadWriter is returned, it is set as
// the session's underlying io.ReadWriter and the internal session state
// (encoders, decoders, etc.) will be reset.
//
// Arguments:
//
//   - ctx bounds the negotiation step.
//   - session is the session being negotiated.
//   - data is the cache value returned by the previous call; the default
//     negotiator treats a data value that is not its own state type as the
//     first call.
//
// Results:
//
//   - mask holds the state bits to set on the session.
//   - rw, if non-nil, replaces the session's underlying io.ReadWriter.
//   - cache is handed back as data on the next call.
//   - err reports a failed negotiation step.
type Negotiator func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, cache interface{}, err error)

// StreamConfig contains options for configuring the default Negotiator
// returned by NewNegotiator.
type StreamConfig struct {
	// The native language of the stream.
	// It is sent as part of every stream header written by the negotiator.
	Lang string

	// S2S causes the negotiator to negotiate a server-to-server (s2s) connection.
	// When set, the S2S bit is also added to the state mask returned after
	// feature negotiation.
	S2S bool

	// A list of stream features to attempt to negotiate.
	Features []StreamFeature

	// If set a copy of any reads from the session will be written to TeeIn and
	// any writes to the session will be written to TeeOut (similar to the tee(1)
	// command).
	// Setting either one is enough to have the session's connection wrapped in
	// a tee conn.
	// This can be used to build an "XML console", but users should be careful
	// since this bypasses TLS and could expose passwords and other sensitive
	// data.
	TeeIn, TeeOut io.Writer
}

// NewNegotiator creates a Negotiator that uses a collection of StreamFeatures
// to negotiate an XMPP client-to-server (c2s) or server-to-server (s2s)
// session.
// The stream language, the c2s/s2s mode, the features to negotiate, and the
// optional tee writers are all taken from cfg.
// If StartTLS is one of the supported stream features, the Negotiator attempts
// to negotiate it whether the server advertises support or not.
func NewNegotiator(cfg StreamConfig) Negotiator {
	return negotiator(cfg)
}

// negotiatorState is the cache value that the default negotiator returns from
// each call and receives back (as the data argument) on the next one.
type negotiatorState struct {
	// doRestart reports whether the next call must exchange stream headers
	// before negotiating features. It starts out true, is cleared if the header
	// exchange fails, and is otherwise set whenever feature negotiation returns
	// a new io.ReadWriter.
	doRestart bool
	// cancelTee cancels the context given to the most recently created tee
	// conn. It is nil until a tee conn has been created.
	cancelTee context.CancelFunc
}

// negotiator returns the default Negotiator configured by cfg.
//
// Each call of the returned function performs one step:
//
//  1. If a tee writer is configured and the session's connection is not yet a
//     tee conn, it wraps the connection and returns it without setting any
//     state bits.
//  2. Otherwise, if a stream (re)start is pending, it exchanges stream headers
//     in the order appropriate for the receiving or initiating entity.
//  3. Finally it negotiates stream features and records whether another
//     restart is needed (that is, whether a new io.ReadWriter was returned).
//
// The value returned in the cache position is always a negotiatorState.
func negotiator(cfg StreamConfig) Negotiator {
	return func(ctx context.Context, s *Session, data interface{}) (mask SessionState, rw io.ReadWriter, restartNext interface{}, err error) {
		// Recover the state cached by the previous call. When there is none this
		// is the first negotiate call, so start from a state that requests a
		// stream header exchange.
		st, cached := data.(negotiatorState)
		if !cached {
			st = negotiatorState{doRestart: true}
		}

		conn := s.Conn()
		wantTee := cfg.TeeIn != nil || cfg.TeeOut != nil
		// A tee is configured but the session is not using a tee conn yet: wrap
		// the connection and hand it back without setting any state bits.
		if _, isTee := conn.(teeConn); wantTee && !isTee {
			// Shut down any earlier tee conn first so that nothing is written to
			// the tee writers twice.
			if st.cancelTee != nil {
				st.cancelTee()
			}

			// The tee's context exists only to switch the tee off again, so it is
			// rooted at Background rather than at the negotiation context.
			teeCtx, stopTee := context.WithCancel(context.Background())
			conn = newTeeConn(teeCtx, conn, cfg.TeeIn, cfg.TeeOut)
			st.cancelTee = stopTee
			return mask, conn, st, err
		}

		// Exchange stream headers if a (re)start is pending.
		if st.doRestart {
			if s.state&Received == Received {
				// Receiving entity: wait for the peer's new stream, then answer it
				// with our own (carrying a freshly generated stream ID).
				s.in.Info, err = stream.Expect(ctx, s.in.d, s.State()&Received == Received)
				if err == nil {
					s.out.Info, err = stream.Send(s.Conn(), cfg.S2S, stream.DefaultVersion, cfg.Lang, s.location.String(), s.origin.String(), attr.RandomID())
				}
			} else {
				// Initiating entity: open a new stream first, then wait for the
				// peer's response.
				s.out.Info, err = stream.Send(s.Conn(), cfg.S2S, stream.DefaultVersion, cfg.Lang, s.location.String(), s.origin.String(), "")
				if err == nil {
					s.in.Info, err = stream.Expect(ctx, s.in.d, s.State()&Received == Received)
				}
			}
			// Whichever step failed, do not ask for another restart and report
			// the error without touching the mask or the connection.
			if err != nil {
				st.doRestart = false
				return mask, nil, st, err
			}
		}

		// NOTE(review): the "first" argument is data == nil. After the tee-conn
		// early return above the cache is non-nil, so this is false on the first
		// real feature negotiation when a tee is configured — confirm that this
		// is intended.
		mask, rw, err = negotiateFeatures(ctx, s, data == nil, cfg.Features)
		// A replacement io.ReadWriter means the stream has to be restarted on the
		// next call.
		st.doRestart = rw != nil
		if cfg.S2S {
			mask |= S2S
		}
		return mask, rw, st, err
	}
}