~samwhited/xmpp

ref: b8d4b070f83a6621be2cfadf3a31042a1ae6ecc1 xmpp/sasl_test.go -rw-r--r-- 3.3 KiB
b8d4b070Sam Whited Add partial Send() function 5 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package xmpp

import (
	"bytes"
	"context"
	"encoding/xml"
	"strings"
	"testing"

	"mellium.im/sasl"
	"mellium.im/xmpp/ns"
)

// TestSASLPanicsNoMechanisms checks that calling SASL without any mechanisms
// panics.
func TestSASLPanicsNoMechanisms(t *testing.T) {
	defer func() {
		// A non-nil recovered value means the expected panic happened.
		if recover() != nil {
			return
		}
		t.Error("Expected call to SASL() with no mechanisms to panic")
	}()
	_ = SASL()
}

// TestSASLList encodes the mechanisms list of a SASL feature built from the
// PLAIN and SCRAM-SHA-256 mechanisms and checks that the feature reports itself
// as required, that both mechanisms appear verbatim in the output, and that the
// output is a single well-formed <mechanisms/> element in the SASL namespace
// wrapping two child elements.
func TestSASLList(t *testing.T) {
	var b bytes.Buffer
	e := xml.NewEncoder(&b)
	start := xml.StartElement{Name: xml.Name{Space: ns.SASL, Local: "mechanisms"}}
	s := SASL(sasl.Plain, sasl.ScramSha256)
	req, err := s.List(context.Background(), e, start)
	switch {
	case err != nil:
		t.Fatal(err)
	case !req:
		t.Error("Expected SASL to be a required feature")
	}
	if err = e.Flush(); err != nil {
		t.Fatal(err)
	}

	// Mechanisms should be printed exactly thus:
	if !bytes.Contains(b.Bytes(), []byte(`<mechanism>PLAIN</mechanism>`)) {
		t.Error("Expected mechanisms list to include PLAIN")
	}
	if !bytes.Contains(b.Bytes(), []byte(`<mechanism>SCRAM-SHA-256</mechanism>`)) {
		t.Error("Expected mechanisms list to include SCRAM-SHA-256")
	}

	// The wrapper can be a bit more flexible as long as the mechanisms are there.
	d := xml.NewDecoder(&b)
	tok, err := d.Token()
	if err != nil {
		t.Fatal(err)
	}
	// Use the comma-ok form so an unexpected token fails the test with a
	// message instead of panicking.
	se, ok := tok.(xml.StartElement)
	if !ok {
		t.Fatalf("Expected mechanisms start element, got %T", tok)
	}
	if se.Name.Local != "mechanisms" || se.Name.Space != ns.SASL {
		t.Errorf("Unexpected name for mechanisms start element: %+v", se.Name)
	}

	// Skip two mechanisms: consume each opening token, then skip to its
	// matching end element.
	for i := 0; i < 2; i++ {
		if _, err = d.Token(); err != nil {
			t.Fatal(err)
		}
		if err = d.Skip(); err != nil {
			t.Fatal(err)
		}
	}

	// Check the end token.
	tok, err = d.Token()
	if err != nil {
		t.Fatal(err)
	}
	if _, ok := tok.(xml.EndElement); !ok {
		t.Errorf("Expected mechanisms end element, got %T", tok)
	}
}

// TestSASLParse runs a table of XML payloads through the Parse function of a
// SASL feature. For each case it reads the first token as the start element,
// hands the decoder to Parse, and then checks either that an error was
// reported (when one is expected) or that the feature is required and the
// parsed data is a []string containing exactly the expected mechanism names.
// Values returned alongside an expected error are not inspected.
func TestSASLParse(t *testing.T) {
	s := SASL(sasl.Plain)
	for i, test := range []struct {
		xml   string
		items []string
		err   bool
	}{
		{`<mechanisms xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>
		<mechanism>EXTERNAL</mechanism>
		<mechanism>SCRAM-SHA-1-PLUS</mechanism>
		<mechanism>SCRAM-SHA-1</mechanism>
		<mechanism>PLAIN</mechanism>
		</mechanisms>`, []string{"EXTERNAL", "PLAIN", "SCRAM-SHA-1-PLUS", "SCRAM-SHA-1"}, false},
		{`<oops xmlns='urn:ietf:params:xml:ns:xmpp-sasl'><mechanism>PLAIN</mechanism></oop>`, nil, true},
		{`<mechanisms xmlns='badns'><mechanism>PLAIN</mechanism></mechanisms>`, nil, true},
		{`<mechanisms xmlns='urn:ietf:params:xml:ns:xmpp-sasl'><mechanism xmlns="nope">PLAIN</mechanism></mechanisms>`, []string{}, false},
	} {
		d := xml.NewDecoder(strings.NewReader(test.xml))

		// Every fixture must begin with a start element; anything else is a
		// broken test table, so stop immediately rather than panic below.
		tok, err := d.Token()
		if err != nil {
			t.Fatalf("case %d: reading start token: %v", i, err)
		}
		start, ok := tok.(xml.StartElement)
		if !ok {
			t.Fatalf("case %d: expected a start element, got %T", i, tok)
		}

		req, list, err := s.Parse(context.Background(), d, &start)
		if test.err {
			if err == nil {
				t.Errorf("case %d: Expected sasl mechanism parsing to error", i)
			}
			continue
		}
		if err != nil {
			t.Errorf("case %d: %v", i, err)
			continue
		}
		if !req {
			t.Errorf("case %d: Expected parsed SASL feature to be required", i)
		}

		// Guard the assertion so an unexpected data type is reported as a
		// test failure instead of a panic.
		mechs, ok := list.([]string)
		if !ok {
			t.Errorf("case %d: Expected parsed data to be []string, got %T", i, list)
			continue
		}
		if len(mechs) != len(test.items) {
			t.Errorf("case %d: Expected data to contain %d items, got %d", i, len(test.items), len(mechs))
		}

		// Order is not significant; each expected mechanism only has to be
		// present somewhere in the parsed list.
		for _, want := range test.items {
			found := false
			for _, got := range mechs {
				if want == got {
					found = true
					break
				}
			}
			if !found {
				t.Errorf("case %d: Expected data to contain %v", i, want)
			}
		}
	}
}