~samwhited/xmpp

ref: eff9a1787b32b033cb8944ed99a0e0f71d5a02f6 xmpp/internal/integration/integration.go -rw-r--r-- 13.2 KiB
eff9a178Sam Whited internal/integration: add support for components 11 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
// Copyright 2020 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

// Package integration contains helpers for integration testing.
//
// Normally users writing integration tests should not use this package
// directly; instead they should use the packages in subdirectories of this
// package.
package integration // import "mellium.im/xmpp/internal/integration"

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"math/big"
	"net"
	"os"
	"os/exec"
	"path/filepath"
	"sync"
	"testing"
	"time"

	"mellium.im/xmpp"
	"mellium.im/xmpp/jid"
)

// Cmd is an external command being prepared or run.
//
// A Cmd cannot be reused after calling its Run, Output or CombinedOutput
// methods.
type Cmd struct {
	*exec.Cmd

	// name is the command name that was passed to New.
	name string

	// cfgDir is the temporary directory that holds generated config files.
	cfgDir string

	// kill cancels the context that the embedded exec.Cmd was created with.
	kill context.CancelFunc

	// cfgF holds the config file writers; New runs them once every option has
	// been applied.
	cfgF []func() error

	// deferF holds functions that Test runs after the command has started.
	deferF []func(*Cmd) error

	// in and out receive a copy of the XML read and written by sessions dialed
	// through this command (see LogXML); either may be nil.
	in  *testWriter
	out *testWriter

	// Listeners reserved by C2SListen, S2SListen and ComponentListen, together
	// with the network each of them was created on.
	c2sListener  net.Listener
	s2sListener  net.Listener
	compListener net.Listener
	c2sNetwork   string
	s2sNetwork   string
	compNetwork  string

	// shutdown, if non-nil, is called by Close before the config dir is removed.
	shutdown func(*Cmd) error

	// user and pass are the values reported by the User method.
	user jid.JID
	pass string

	// clientCrt (DER encoded) and clientCrtKey are populated when a client
	// authentication certificate is generated and are returned by ClientCert.
	clientCrt    []byte
	clientCrtKey interface{}

	// Config is meant to be used by internal packages like prosody and ejabberd
	// to store their internal representation of the config before writing it out.
	Config interface{}
}

// New creates a new, unstarted, command.
//
// The provided context is used to kill the process (by calling os.Process.Kill)
// if the context becomes done before the command completes on its own.
//
// New creates a temporary directory for config files (see ConfigDir), applies
// every option, and then runs the config file writers registered by those
// options.
// If any of these steps fails, the temporary directory is removed and the
// derived context is canceled before the error is returned, so a failed call
// does not leak resources.
func New(ctx context.Context, name string, opts ...Option) (*Cmd, error) {
	ctx, cancel := context.WithCancel(ctx)
	cmd := &Cmd{
		/* #nosec */
		Cmd:  exec.CommandContext(ctx, name),
		name: name,
		kill: cancel,
	}

	// Temp dir patterns may not contain path separators, and name may be a
	// path to the executable, so only its final element is used.
	cfgDir, err := ioutil.TempDir("", filepath.Base(cmd.name))
	if err != nil {
		cancel()
		return nil, err
	}
	cmd.cfgDir = cfgDir

	// fail releases everything acquired so far and returns err.
	fail := func(err error) (*Cmd, error) {
		cancel()
		// Best-effort cleanup: err is the more useful error to report.
		_ = os.RemoveAll(cmd.cfgDir)
		return nil, err
	}

	for _, opt := range opts {
		if err := opt(cmd); err != nil {
			return fail(fmt.Errorf("error applying option: %w", err))
		}
	}
	for _, f := range cmd.cfgF {
		if err := f(); err != nil {
			return fail(fmt.Errorf("error running config func: %w", err))
		}
	}

	return cmd, nil
}

// ClientCert returns the last configured client certificate.
// The certificate request info is currently ignored and is only there to make
// promoting this method to a function and using it as
// tls.Config.GetClientCertificate possible.
//
// If no client certificate has been configured, an empty certificate is
// returned. tls.Config.GetClientCertificate must not return nil, and an empty
// certificate tells crypto/tls not to send a certificate at all.
func (cmd *Cmd) ClientCert(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
	if cmd.clientCrt == nil {
		// Without this check a chain containing a single empty DER blob would be
		// sent, which the peer would fail to parse.
		return &tls.Certificate{}, nil
	}
	return &tls.Certificate{
		Certificate: [][]byte{cmd.clientCrt},
		PrivateKey:  cmd.clientCrtKey,
	}, nil
}

// C2SListen reserves the client-to-server address by listening on network and
// addr (an address with port 0 yields a random port).
// Only the first call creates a listener; later calls ignore their arguments
// and return the existing listener.
func (cmd *Cmd) C2SListen(network, addr string) (net.Listener, error) {
	if existing := cmd.c2sListener; existing != nil {
		return existing, nil
	}

	ln, err := net.Listen(network, addr)
	cmd.c2sListener, cmd.c2sNetwork = ln, network
	return ln, err
}

// S2SListen reserves the server-to-server address by listening on network and
// addr (an address with port 0 yields a random port).
// Only the first call creates a listener; later calls ignore their arguments
// and return the existing listener.
func (cmd *Cmd) S2SListen(network, addr string) (net.Listener, error) {
	if existing := cmd.s2sListener; existing != nil {
		return existing, nil
	}

	ln, err := net.Listen(network, addr)
	cmd.s2sListener, cmd.s2sNetwork = ln, network
	return ln, err
}

// ComponentListen reserves the component address by listening on network and
// addr (an address with port 0 yields a random port).
// Only the first call creates a listener; later calls ignore their arguments
// and return the existing listener.
func (cmd *Cmd) ComponentListen(network, addr string) (net.Listener, error) {
	if existing := cmd.compListener; existing != nil {
		return existing, nil
	}

	ln, err := net.Listen(network, addr)
	cmd.compListener, cmd.compNetwork = ln, network
	return ln, err
}

// ConfigDir reports where the command's generated config files are written:
// the temporary directory created by New and deleted by Close.
func (cmd *Cmd) ConfigDir() string {
	dir := cmd.cfgDir
	return dir
}

// Close kills the command if it is still running and cleans up any temporary
// resources that were created.
//
// If a Shutdown function was configured, it is run first, and then the config
// directory is removed. The command's context is canceled when Close returns.
// If both steps fail, the returned error wraps the removal error and also
// mentions the shutdown error, so neither failure is lost.
func (cmd *Cmd) Close() error {
	defer cmd.kill()

	var shutdownErr error
	if cmd.shutdown != nil {
		shutdownErr = cmd.shutdown(cmd)
	}
	err := os.RemoveAll(cmd.cfgDir)
	switch {
	case err != nil && shutdownErr != nil:
		return fmt.Errorf("error removing config dir: %w (shutdown also failed: %v)", err, shutdownErr)
	case err != nil:
		return err
	}
	return shutdownErr
}

// User reports the address and password recorded with the User option, if
// any; both are zero values when that option was not used.
// Presumably a server-specific package has created this account on the
// server — the User option itself does not.
func (cmd *Cmd) User() (jid.JID, string) {
	addr, password := cmd.user, cmd.pass
	return addr, password
}

// DialClient opens a client-to-server (c2s) connection to the address reserved
// by C2SListen and negotiates a stream over it.
// The stream's origin is j and its location is the domainpart of j.
func (cmd *Cmd) DialClient(ctx context.Context, j jid.JID, t *testing.T, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	const s2s = false
	location := j.Domain()
	return cmd.dial(ctx, s2s, location, j, t, features...)
}

// DialServer opens a server-to-server (s2s) connection to the address reserved
// by S2SListen and negotiates a stream from origin to location over it.
func (cmd *Cmd) DialServer(ctx context.Context, location, origin jid.JID, t *testing.T, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	const s2s = true
	return cmd.dial(ctx, s2s, location, origin, t, features...)
}

// C2SAddr reports the client-to-server listener's address and the network it
// was created on.
// It must only be called once C2SListen has succeeded; with no listener
// configured the call panics.
func (cmd *Cmd) C2SAddr() (net.Addr, string) {
	ln := cmd.c2sListener
	return ln.Addr(), cmd.c2sNetwork
}

// S2SAddr reports the server-to-server listener's address and the network it
// was created on.
// It must only be called once S2SListen has succeeded; with no listener
// configured the call panics.
func (cmd *Cmd) S2SAddr() (net.Addr, string) {
	ln := cmd.s2sListener
	return ln.Addr(), cmd.s2sNetwork
}

// ComponentAddr reports the component listener's address and the network it
// was created on.
// It must only be called once ComponentListen has succeeded; with no listener
// configured the call panics.
func (cmd *Cmd) ComponentAddr() (net.Addr, string) {
	ln := cmd.compListener
	return ln.Addr(), cmd.compNetwork
}

// ComponentConn dials a connection to the component socket reserved by
// ComponentListen and returns it without negotiating a session.
//
// ctx bounds the dial: if it is canceled or its deadline passes before the
// connection is established, an error is returned. Once connected, the
// context no longer affects the connection.
func (cmd *Cmd) ComponentConn(ctx context.Context) (net.Conn, error) {
	if cmd.compListener == nil {
		return nil, errors.New("component not configured, please configure a component listener")
	}
	addr := cmd.compListener.Addr().String()
	network := cmd.compNetwork

	// Dial with the context so that a canceled test does not hang on connect.
	var d net.Dialer
	conn, err := d.DialContext(ctx, network, addr)
	if err != nil {
		return nil, fmt.Errorf("error dialing %s: %w", addr, err)
	}
	return conn, nil
}

// Conn dials a connection and returns it without negotiating a session.
//
// If s2s is true the address reserved by S2SListen is dialed, otherwise the
// one reserved by C2SListen. An error is returned if that listener was never
// configured.
// ctx bounds the dial: if it is canceled or its deadline passes before the
// connection is established, an error is returned. Once connected, the
// context no longer affects the connection.
func (cmd *Cmd) Conn(ctx context.Context, s2s bool) (net.Conn, error) {
	switch {
	case s2s && cmd.s2sListener == nil:
		return nil, errors.New("s2s not configured, please configure an s2s listener")
	case !s2s && cmd.c2sListener == nil:
		return nil, errors.New("c2s not configured, please configure a c2s listener")
	}

	var addr, network string
	if s2s {
		addr = cmd.s2sListener.Addr().String()
		network = cmd.s2sNetwork
	} else {
		addr = cmd.c2sListener.Addr().String()
		network = cmd.c2sNetwork
	}

	// Dial with the context so that a canceled test does not hang on connect.
	var d net.Dialer
	conn, err := d.DialContext(ctx, network, addr)
	if err != nil {
		return nil, fmt.Errorf("error dialing %s: %w", addr, err)
	}
	return conn, nil
}

// dial opens a raw connection with Conn and negotiates an XMPP session over it
// from origin to location, offering features.
// XML read and written by the session is teed to cmd.in and cmd.out (set by
// LogXML). The t parameter is currently unused.
// If negotiation fails, the connection is closed before the error is returned.
func (cmd *Cmd) dial(ctx context.Context, s2s bool, location, origin jid.JID, t *testing.T, features ...xmpp.StreamFeature) (*xmpp.Session, error) {
	conn, err := cmd.Conn(ctx, s2s)
	if err != nil {
		return nil, err
	}
	negotiator := xmpp.NewNegotiator(xmpp.StreamConfig{
		Features: features,
		TeeIn:    cmd.in,
		TeeOut:   cmd.out,
		S2S:      s2s,
	})
	session, err := xmpp.NegotiateSession(
		ctx,
		location,
		origin,
		conn,
		false,
		negotiator,
	)
	if err != nil {
		// Close the connection so the socket is not leaked. Any error from
		// closing (e.g. if it was already closed) is less useful than err.
		/* #nosec */
		conn.Close()
		return nil, fmt.Errorf("error establishing session: %w", err)
	}
	return session, nil
}

// Option configures a Cmd. Options are applied by New in the order they are
// given, and a non-nil error aborts construction of the command.
type Option func(*Cmd) error

// User records an address and password to be reported later by cmd.User.
// It only stores the values; no account is created by this option.
func User(user jid.JID, pass string) Option {
	return func(c *Cmd) error {
		c.user, c.pass = user, pass
		return nil
	}
}

// Shutdown registers f to be called by Close before the configuration is
// removed. It is meant to stop the application gracefully in case it does not
// handle the kill signal correctly.
// Only the last Shutdown option takes effect.
func Shutdown(f func(*Cmd) error) Option {
	setShutdown := func(c *Cmd) error {
		c.shutdown = f
		return nil
	}
	return setShutdown
}

// Args appends extra command line arguments to those passed to the command.
func Args(f ...string) Option {
	return func(c *Cmd) error {
		c.Args = append(c.Cmd.Args, f...)
		return nil
	}
}

// Cert creates a private key and certificate with the given name.
//
// The key and a self-signed certificate are written to name+".key" and
// name+".crt" in the command's config directory. The certificate is valid for
// the base name of name and expires one year after Cert is called.
func Cert(name string) Option {
	now := time.Now()
	return cert(name, &x509.Certificate{
		NotBefore: now,
		NotAfter:  now.Add(365 * 24 * time.Hour),
		DNSNames:  []string{filepath.Base(name)},
	})
}

// ClientCert creates a private key and certificate with the given name that
// can be used for TLS authentication.
//
// It behaves like Cert, but the certificate is marked for client
// authentication and is also made available through (*Cmd).ClientCert.
func ClientCert(name string) Option {
	now := time.Now()
	return cert(name, &x509.Certificate{
		NotBefore:   now,
		NotAfter:    now.Add(365 * 24 * time.Hour),
		DNSNames:    []string{filepath.Base(name)},
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
	})
}

// cert returns an option that generates a 2048-bit RSA key and a certificate
// self-signed from the template crt. The files name+".key" and name+".crt" are
// written once configuration is complete.
//
// If crt has no serial number, a random positive one is used. The templates
// above leave the subject empty, so every generated certificate has the same
// issuer name, and a fixed serial would make the issuer/serial pair collide.
// RFC 5280 requires that pair to be unique, and some TLS stacks reject
// duplicates.
//
// If the template allows client authentication, the DER certificate and key
// are also stored on the command for use by (*Cmd).ClientCert.
func cert(name string, crt *x509.Certificate) Option {
	return func(cmd *Cmd) error {
		// Work on a copy so applying the option never mutates the template.
		tmpl := *crt
		if tmpl.SerialNumber == nil {
			serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 127))
			if err != nil {
				return err
			}
			// rand.Int may return 0; serial numbers must be positive.
			tmpl.SerialNumber = serial.Add(serial, big.NewInt(1))
		}

		clientAuth := false
		for _, usage := range tmpl.ExtKeyUsage {
			if usage == x509.ExtKeyUsageClientAuth {
				clientAuth = true
				break
			}
		}

		key, err := rsa.GenerateKey(rand.Reader, 2048)
		if err != nil {
			return err
		}
		err = TempFile(name+".key", func(_ *Cmd, w io.Writer) error {
			return pem.Encode(w, &pem.Block{
				Type:  "RSA PRIVATE KEY",
				Bytes: x509.MarshalPKCS1PrivateKey(key),
			})
		})(cmd)
		if err != nil {
			return err
		}
		return TempFile(name+".crt", func(_ *Cmd, w io.Writer) error {
			der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, key.Public(), key)
			if err != nil {
				return err
			}
			if clientAuth {
				cmd.clientCrt = der
				cmd.clientCrtKey = key
			}
			return pem.Encode(w, &pem.Block{
				Type:  "CERTIFICATE",
				Bytes: der,
			})
		})(cmd)
	}
}

// TempFile arranges for a file named cfgFileName to be created inside the
// command's temporary config directory.
// Any parent directories in cfgFileName are created immediately. The file
// itself is created, and f called to fill it, only after all options have
// been applied.
func TempFile(cfgFileName string, f func(*Cmd, io.Writer) error) Option {
	return func(cmd *Cmd) error {
		switch sub := filepath.Dir(cfgFileName); sub {
		case "", ".", "/", "..":
			// No subdirectory is created for these values.
		default:
			if err := os.MkdirAll(filepath.Join(cmd.cfgDir, sub), 0700); err != nil {
				return err
			}
		}

		writeFile := func() error {
			file, err := os.Create(filepath.Join(cmd.cfgDir, cfgFileName))
			if err != nil {
				return err
			}
			if err := f(cmd, file); err != nil {
				// The error from f is the one worth reporting.
				/* #nosec */
				file.Close()
				return err
			}
			return file.Close()
		}
		cmd.cfgF = append(cmd.cfgF, writeFile)
		return nil
	}
}

// testWriter is an io.Writer that forwards each write to the log of a
// testing.T, prefixed with tag. A nil *testWriter, or one with no T attached,
// accepts and discards all writes.
type testWriter struct {
	sync.Mutex

	// t is where writes are logged; nil discards them.
	t *testing.T

	// tag is prepended to every logged write.
	tag string
}

// Write logs p (prefixed by the tag) to the attached test, if any, and always
// reports the whole of p as written.
func (w *testWriter) Write(p []byte) (int, error) {
	n := len(p)
	if w == nil {
		return n, nil
	}
	w.Lock()
	defer w.Unlock()

	if dst := w.t; dst != nil {
		dst.Logf("%s%s", w.tag, p)
	}
	return n, nil
}

// Update attaches t as the destination for future writes. It is a no-op on a
// nil writer.
func (w *testWriter) Update(t *testing.T) {
	if w != nil {
		w.Lock()
		defer w.Unlock()
		w.t = t
	}
}

// Log sends the command's standard output and standard error to the log of
// the currently running testing.T (the one attached by Test).
func Log() Option {
	return func(cmd *Cmd) error {
		w := new(testWriter)
		cmd.Cmd.Stdout, cmd.Cmd.Stderr = w, w
		return nil
	}
}

// LogXML logs the XML sent and received by sessions dialed through the command
// to the currently running testing.T, tagging lines with "SENT" and "RECV".
func LogXML() Option {
	return func(cmd *Cmd) error {
		in, out := new(testWriter), new(testWriter)
		in.tag, out.tag = "RECV", "SENT"
		cmd.in, cmd.out = in, out
		return nil
	}
}

// Defer registers f to be called by Test once the command has started and its
// sockets are reachable. Deferred functions run in registration order.
func Defer(f func(*Cmd) error) Option {
	return func(c *Cmd) error {
		c.deferF = append(c.deferF, f)
		return nil
	}
}

// Test starts a command and returns a function that runs tests as a subtest
// using t.Run.
// Multiple calls to the returned function will result in uniquely named
// subtests.
// When all subtests have completed, the daemon is stopped.
//
// Before returning, Test waits until every configured listener (c2s, s2s and
// component) is accepting connections, and then runs any functions registered
// with the Defer option. Any setup failure is reported with t.Fatal.
func Test(ctx context.Context, name string, t *testing.T, opts ...Option) SubtestRunner {
	ctx, cancel := context.WithCancel(ctx)
	t.Cleanup(cancel)

	cmd, err := New(ctx, name, opts...)
	if err != nil {
		t.Fatal(err)
	}

	t.Cleanup(func() {
		err := cmd.Close()
		if err != nil {
			t.Logf("error cleaning up test: %v", err)
		}
	})

	// logTo sends process output (if Log was used) and XML traces (if LogXML
	// was used) to the log of t.
	logTo := func(t *testing.T) {
		if tw, ok := cmd.Cmd.Stdout.(*testWriter); ok {
			tw.Update(t)
		}
		cmd.in.Update(t)
		cmd.out.Update(t)
	}

	logTo(t)
	err = cmd.Start()
	if err != nil {
		t.Fatal(err)
	}

	// Wait until the daemon accepts connections on every reserved socket so
	// that tests do not race its startup. Without the component socket here,
	// component tests could dial before the daemon is listening.
	sockets := []struct {
		ln      net.Listener
		network string
	}{
		{ln: cmd.c2sListener, network: cmd.c2sNetwork},
		{ln: cmd.s2sListener, network: cmd.s2sNetwork},
		{ln: cmd.compListener, network: cmd.compNetwork},
	}
	for _, s := range sockets {
		if s.ln == nil {
			continue
		}
		err = waitSocket(s.network, s.ln.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
	}
	for _, f := range cmd.deferF {
		err := f(cmd)
		if err != nil {
			t.Fatal(err)
		}
	}

	i := -1
	return func(f func(context.Context, *testing.T, *Cmd)) bool {
		i++
		return t.Run(fmt.Sprintf("%s/%d", name, i), func(t *testing.T) {
			logTo(t)
			f(ctx, t, cmd)
		})
	}
}

// SubtestRunner starts subtests. The runner returned by Test runs its argument
// as a new, uniquely named subtest via t.Run and reports whether that subtest
// passed.
type SubtestRunner func(f func(ctx context.Context, t *testing.T, cmd *Cmd)) bool