~samwhited/xmpp

xmpp/websocket/ws.go -rw-r--r-- 7.1 KiB
e9b0a2deSam Whited docs: do a quick editing pass over the docs a day ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
// Copyright 2020 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package websocket

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"net/http"
	"sort"
	"strings"

	"golang.org/x/net/websocket"

	"mellium.im/xmpp"
	"mellium.im/xmpp/internal/discover"
	"mellium.im/xmpp/jid"
)

// NewSession establishes an XMPP session from the perspective of the initiating
// client on rw using the WebSocket subprotocol.
// It does not perform the WebSocket handshake.
//
// If rw is a *websocket.Conn whose WebSocket location uses the "wss" scheme the
// session is marked as secure.
func NewSession(ctx context.Context, addr jid.JID, rw io.ReadWriter, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	n := Negotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features: features,
		}
	})
	var mask xmpp.SessionState
	// On a client connection LocalAddr reports the WebSocket Origin (normally an
	// http or https URL), not the endpoint that was dialed, so inspect the
	// configured location to learn whether the connection uses TLS.
	if wsConn, ok := rw.(*websocket.Conn); ok {
		if cfg := wsConn.Config(); cfg != nil && cfg.Location != nil && cfg.Location.Scheme == "wss" {
			mask |= xmpp.Secure
		}
	}
	return xmpp.NewSession(ctx, addr.Domain(), addr, rw, mask, n)
}

// ReceiveSession establishes an XMPP session from the perspective of the
// receiving server on rw using the WebSocket subprotocol.
// It does not perform the WebSocket handshake.
//
// If rw is a *websocket.Conn whose WebSocket location uses the "wss" scheme the
// session is marked as secure.
func ReceiveSession(ctx context.Context, rw io.ReadWriter, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	n := Negotiator(func(*xmpp.Session, *xmpp.StreamConfig) xmpp.StreamConfig {
		return xmpp.StreamConfig{
			Features: features,
		}
	})
	var mask xmpp.SessionState
	// Read the scheme from the connection's configured location rather than
	// type-asserting the result of LocalAddr, so that an unexpected net.Addr
	// implementation or a missing location cannot cause a panic.
	if wsConn, ok := rw.(*websocket.Conn); ok {
		if cfg := wsConn.Config(); cfg != nil && cfg.Location != nil && cfg.Location.Scheme == "wss" {
			mask |= xmpp.Secure
		}
	}
	return xmpp.ReceiveSession(ctx, rw, mask, n)
}

// NewClient runs the WebSocket opening handshake over rwc and then negotiates
// an XMPP session on top of the resulting WebSocket connection.
//
// origin is the WebSocket client origin, location is the WebSocket endpoint
// URL, and addr is the JID expected at the XMPP layer.
func NewClient(ctx context.Context, origin, location string, addr jid.JID, rwc io.ReadWriteCloser, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	wsCfg, err := (&Dialer{Origin: origin}).config(location)
	if err != nil {
		return nil, err
	}

	wsConn, err := websocket.NewClient(wsCfg, rwc)
	if err != nil {
		return nil, err
	}

	return NewSession(ctx, addr, wsConn, features...)
}

// DialSession uses a default dialer to create a WebSocket connection and
// attempts to negotiate an XMPP session over it.
//
// If session negotiation fails the WebSocket connection opened by DialSession
// is closed before the error is returned.
//
// If the provided context is canceled after stream negotiation is complete it
// has no effect on the session.
func DialSession(ctx context.Context, origin string, addr jid.JID, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	conn, err := Dial(ctx, origin, addr)
	if err != nil {
		return nil, err
	}
	session, err := NewSession(ctx, addr, conn, features...)
	if err != nil {
		// The connection was opened here and is never handed back to the
		// caller, so close it instead of leaking it. The close error is
		// discarded because the negotiation error is the one worth reporting.
		_ = conn.Close()
		return nil, err
	}
	return session, nil
}

// Dial looks up the WebSocket endpoints advertised for addr and connects to one
// of them, moving on to the next candidate when an attempt fails.
//
// It is shorthand for calling the Dial method of a Dialer whose only populated
// field is Origin.
func Dial(ctx context.Context, origin string, addr jid.JID) (net.Conn, error) {
	return (&Dialer{Origin: origin}).Dial(ctx, addr)
}

// DialDirect connects straight to the WebSocket endpoint addr, skipping TXT
// record and Web Host Metadata discovery entirely.
//
// It is shorthand for calling the DialDirect method of a Dialer whose only
// populated field is Origin.
func DialDirect(ctx context.Context, origin, addr string) (net.Conn, error) {
	return (&Dialer{Origin: origin}).DialDirect(ctx, addr)
}

// Dialer discovers WebSocket endpoints for an XMPP address and connects to
// them.
// The zero value for each field is equivalent to dialing without that option
// with the exception of Origin (which is required).
// Dialing with the zero value of Dialer (except Origin) is equivalent to
// calling the Dial function.
type Dialer struct {
	// Origin is the WebSocket client origin sent in the opening handshake.
	// It is required.
	Origin string

	// TLSConfig is the TLS configuration used for secure WebSockets (wss).
	// If TLSConfig is nil a default config requiring TLS 1.2 or later is used;
	// a non-nil TLSConfig is used as-is.
	TLSConfig *tls.Config

	// InsecureNoTLS allows falling back to insecure WebSocket connections
	// without TLS.
	// If endpoint discovery is used and a secure WebSocket endpoint is available
	// it will still be prioritized.
	//
	// The WebSocket transport does not support StartTLS, so enabling this
	// allows bare WebSockets (a scheme of ws:), which are insecure; it should
	// not be enabled in production.
	InsecureNoTLS bool

	// Header holds additional header fields to be sent in the WebSocket opening
	// handshake.
	Header http.Header

	// Dialer is the network dialer used when opening WebSocket connections.
	Dialer *net.Dialer

	// Resolver is used when looking up TXT records.
	// If nil, a zero-value net.Resolver is used.
	Resolver *net.Resolver

	// Client is the HTTP client used when looking up Web Host Metadata files.
	// If nil, a zero-value http.Client is used.
	Client *http.Client
}

// Dial discovers WebSocket endpoints for addr and opens a client connection to
// one of them.
//
// Endpoints are discovered by looking up TXT records and Web Host Metadata
// files for addr. Secure (wss:) endpoints are tried first, then insecure (ws:)
// endpoints, then anything else; the discovered order is preserved within each
// group. Insecure endpoints are skipped unless InsecureNoTLS is set.
// The first connection that succeeds is returned. If none succeed, the error
// from the last failed attempt is returned, or a descriptive error if every
// endpoint was skipped.
//
// To connect to a known WebSocket URL without any lookup use DialDirect.
//
// The context is used for endpoint discovery and is checked for cancellation
// between connection attempts, but it cannot interrupt an attempt that is
// already in progress.
func (d *Dialer) Dial(ctx context.Context, addr jid.JID) (net.Conn, error) {
	// Set up defaults for the underlying client and resolver.
	httpClient := d.Client
	if httpClient == nil {
		httpClient = &http.Client{}
	}
	netResolver := d.Resolver
	if netResolver == nil {
		netResolver = &net.Resolver{}
	}

	urls, err := discover.LookupWebSocket(ctx, netResolver, httpClient, addr)
	if err != nil {
		return nil, err
	}
	if len(urls) == 0 {
		return nil, fmt.Errorf("websocket: no XMPP websocket endpoint found on %s", addr.Domainpart())
	}

	// Prioritize wss over anything else, then ws, then anything else that will
	// likely just result in an error. URL schemes are case-insensitive, so
	// compare against a lower-cased copy.
	const (
		rankSecure = iota
		rankInsecure
		rankOther
	)
	rank := func(u string) int {
		u = strings.ToLower(u)
		switch {
		case strings.HasPrefix(u, "wss:"):
			return rankSecure
		case strings.HasPrefix(u, "ws:"):
			return rankInsecure
		}
		return rankOther
	}
	// A stable sort keeps the discovery order within each priority group.
	sort.SliceStable(urls, func(i, j int) bool {
		return rank(urls[i]) < rank(urls[j])
	})

	var lastErr error
	for _, u := range urls {
		if !d.InsecureNoTLS && rank(u) == rankInsecure {
			continue
		}
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
		cfg, err := d.config(u)
		if err != nil {
			lastErr = err
			continue
		}
		conn, err := websocket.DialConfig(cfg)
		if err != nil {
			lastErr = err
			continue
		}
		return conn, nil
	}
	if lastErr == nil {
		// Every endpoint was skipped: only insecure endpoints were discovered
		// and InsecureNoTLS is not set.
		lastErr = fmt.Errorf("websocket: no secure XMPP websocket endpoint found on %s", addr.Domainpart())
	}
	// Return an untyped nil so callers never receive a non-nil net.Conn that
	// wraps a nil *websocket.Conn.
	return nil, lastErr
}

// DialDirect dials the WebSocket endpoint addr without performing any TXT or
// Web Host Metadata file lookup.
//
// The underlying WebSocket implementation does not accept a context, so ctx is
// only checked for cancellation before dialing; it cannot interrupt a dial that
// is already in progress.
// This may change in the future.
func (d *Dialer) DialDirect(ctx context.Context, addr string) (net.Conn, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	cfg, err := d.config(addr)
	if err != nil {
		return nil, err
	}
	conn, err := websocket.DialConfig(cfg)
	if err != nil {
		// Return an untyped nil so callers never receive a non-nil net.Conn
		// that wraps a nil *websocket.Conn.
		return nil, err
	}
	return conn, nil
}

// config builds the WebSocket client configuration used to dial addr.
//
// The XMPP WebSocket subprotocol is always requested, and any fields in
// d.Header are added to the opening handshake. If d.TLSConfig is nil, a default
// TLS configuration is used that verifies the server against the host name of
// addr and requires TLS 1.2 or later; otherwise d.TLSConfig is used as-is.
func (d *Dialer) config(addr string) (*websocket.Config, error) {
	cfg, err := websocket.NewConfig(addr, d.Origin)
	if err != nil {
		return nil, err
	}
	cfg.Protocol = []string{WSProtocol}
	cfg.TlsConfig = d.TLSConfig
	if cfg.TlsConfig == nil {
		cfg.TlsConfig = &tls.Config{
			// Location.Host may include a port (eg. "example.net:5443"), which is
			// not a valid name to verify the server certificate against.
			ServerName: cfg.Location.Hostname(),
			MinVersion: tls.VersionTLS12,
		}
	}
	// Copy the user supplied headers instead of sharing the caller's map so
	// later changes to d.Header do not affect this configuration.
	if len(d.Header) > 0 {
		if cfg.Header == nil {
			cfg.Header = make(http.Header, len(d.Header))
		}
		for k, vs := range d.Header {
			for _, v := range vs {
				cfg.Header.Add(k, v)
			}
		}
	}
	cfg.Dialer = d.Dialer
	return cfg, nil
}