~samwhited/xmpp

ref: a8e3ffb3b023272e5b3010a96da9db9e0103a2b3 xmpp/stream.go -rw-r--r-- 5.3 KiB
a8e3ffb3Sam Whited stream: rename streamerror package to stream 3 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package xmpp

import (
	"context"
	"encoding/xml"
	"fmt"
	"io"

	"golang.org/x/text/language"
	"mellium.im/xmpp/internal"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/stream"
)

// xmlHeader is the XML declaration that sendNewStream writes immediately
// before each stream header it sends.
const xmlHeader = `<?xml version="1.0" encoding="UTF-8"?>`

// streamInfo holds the attributes of a stream header, either as parsed from
// the peer's header by streamFromStartElement or as sent by sendNewStream.
type streamInfo struct {
	// to and from are the addresses from the header's "to" and "from"
	// attributes; nil when the attribute was absent.
	to, from *jid.JID

	// id is the value of the header's "id" attribute ("" when absent).
	id string

	// version is the value of the header's "version" attribute.
	version internal.Version

	// xmlns is the default namespace of the stream (the client or server
	// namespace).
	xmlns string

	// lang is the language tag from the header's xml:lang attribute.
	lang language.Tag
}

// streamFromStartElement extracts the stream header attributes (to, from, id,
// version, the default namespace, the stream namespace declaration, and
// xml:lang) from the start element s.
//
// Every non-nil error it returns is one of the stream errors from the stream
// package, so callers can report it to the peer as-is:
//   - stream.ImproperAddressing if "to" or "from" is not a valid JID,
//   - stream.UnsupportedVersion if "version" cannot be parsed,
//   - stream.InvalidNamespace if the default namespace is neither the client
//     nor the server namespace, or xmlns:stream is not the streams namespace.
//
// On error the returned streamInfo is the zero value and must not be used.
// Unrecognized attributes are ignored.
func streamFromStartElement(s xml.StartElement) (streamInfo, error) {
	var info streamInfo
	for _, attr := range s.Attr {
		switch attr.Name {
		case xml.Name{Space: "", Local: "to"}:
			info.to = &jid.JID{}
			if err := info.to.UnmarshalXMLAttr(attr); err != nil {
				return streamInfo{}, stream.ImproperAddressing
			}
		case xml.Name{Space: "", Local: "from"}:
			info.from = &jid.JID{}
			if err := info.from.UnmarshalXMLAttr(attr); err != nil {
				return streamInfo{}, stream.ImproperAddressing
			}
		case xml.Name{Space: "", Local: "id"}:
			info.id = attr.Value
		case xml.Name{Space: "", Local: "version"}:
			// A version we cannot even parse is certainly not one we support.
			if err := (&info.version).UnmarshalXMLAttr(attr); err != nil {
				return streamInfo{}, stream.UnsupportedVersion
			}
		case xml.Name{Space: "", Local: "xmlns"}:
			if attr.Value != ns.Client && attr.Value != ns.Server {
				return streamInfo{}, stream.InvalidNamespace
			}
			info.xmlns = attr.Value
		case xml.Name{Space: "xmlns", Local: "stream"}:
			if attr.Value != ns.Stream {
				return streamInfo{}, stream.InvalidNamespace
			}
		// encoding/xml's decoder may report the reserved "xml" prefix either
		// literally or translated to the XML namespace URI; accept both.
		case xml.Name{Space: "xml", Local: "lang"},
			xml.Name{Space: "http://www.w3.org/XML/1998/namespace", Local: "lang"}:
			info.lang = language.Make(attr.Value)
		}
	}
	return info, nil
}

// sendNewStream writes an XML declaration followed by a stream header
// (<stream:stream ...>) to s.Conn() and, only if the write succeeds, records
// the header's attributes in s.out.streamInfo.
//
// The header is printed directly rather than produced with an xml.Encoder
// because Go's encoding/xml package does not handle the prefixed stream:stream
// element and its xmlns:stream attribute well, and printing is cheaper than
// encoding. To keep the output well-formed regardless of their contents, the
// id, to, and from values are escaped before being placed inside
// single-quoted attributes.
//
// The id attribute is omitted when id is empty. The to and from attributes are
// omitted when cfg.Location or cfg.Origin, respectively, is nil, instead of
// calling String on a nil JID.
func sendNewStream(s *Session, cfg *Config, id string) error {
	streamData := streamInfo{
		to:      cfg.Location,
		from:    cfg.Origin,
		id:      id,
		lang:    cfg.Lang,
		version: cfg.Version,
		xmlns:   ns.Client,
	}
	if cfg.S2S {
		streamData.xmlns = ns.Server
	}

	// attr renders ` name='value'` with value escaped for use inside a
	// single-quoted XML attribute. Every escaped character is ASCII, so
	// processing byte-by-byte is safe for UTF-8 input.
	attr := func(name, value string) string {
		escaped := make([]byte, 0, len(value))
		for i := 0; i < len(value); i++ {
			switch c := value[i]; c {
			case '&':
				escaped = append(escaped, "&amp;"...)
			case '<':
				escaped = append(escaped, "&lt;"...)
			case '>':
				escaped = append(escaped, "&gt;"...)
			case '\'':
				escaped = append(escaped, "&apos;"...)
			case '"':
				escaped = append(escaped, "&quot;"...)
			default:
				escaped = append(escaped, c)
			}
		}
		return " " + name + "='" + string(escaped) + "'"
	}

	var idAttr, toAttr, fromAttr string
	if id != "" {
		idAttr = attr("id", id)
	}
	if cfg.Location != nil {
		toAttr = attr("to", cfg.Location.String())
	}
	if cfg.Origin != nil {
		fromAttr = attr("from", cfg.Origin.String())
	}

	_, err := fmt.Fprintf(s.Conn(),
		xmlHeader+`<stream:stream%s%s%s version='%s' xml:lang='%s' xmlns='%s' xmlns:stream='%s'>`,
		idAttr,
		toAttr,
		fromAttr,
		cfg.Version,
		cfg.Lang,
		streamData.xmlns,
		ns.Stream,
	)
	if err != nil {
		return err
	}

	s.out.streamInfo = streamData
	return nil
}

// expectNewStream reads tokens from the session's input decoder (s.in.d) until
// it finds the peer's stream header. It then validates the header and stores
// the parsed attributes in s.in.streamInfo.
//
// Before the header, an optional XML declaration (only as the very first
// token) and whitespace are accepted. If the peer sends a stream error element
// instead of a header, that error is decoded and returned. Otherwise the
// function fails with a stream error:
//   - bad-format for an unexpected element,
//   - invalid-namespace for a header in the wrong namespace,
//   - unsupported-version for a version other than internal.DefaultVersion,
//   - restricted-xml for other processing instructions, text, comments, and
//     similar tokens,
//   - not-well-formed for an end element.
//
// When we are the initiating entity (the Received state bit is not set), a
// header without an id attribute is rejected with bad-format.
//
// ctx is only checked between tokens; cancellation does not interrupt a
// d.Token call that is already in progress.
func expectNewStream(ctx context.Context, s *Session) error {
	// The XML declaration may only appear before anything else in the stream.
	declAllowed := true

	// isXMLSpace reports whether b consists solely of XML whitespace (the S
	// production of XML 1.0: space, tab, carriage return, line feed).
	isXMLSpace := func(b []byte) bool {
		for _, c := range b {
			switch c {
			case ' ', '\t', '\r', '\n':
			default:
				return false
			}
		}
		return true
	}

	d := s.in.d
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		t, err := d.Token()
		if err != nil {
			return err
		}
		switch tok := t.(type) {
		case xml.StartElement:
			switch {
			case tok.Name.Local == "error" && tok.Name.Space == ns.Stream:
				se := stream.StreamError{}
				if err := d.DecodeElement(&se, &tok); err != nil {
					return err
				}
				return se
			case tok.Name.Local != "stream":
				return stream.BadFormat
			case tok.Name.Space != ns.Stream:
				return stream.InvalidNamespace
			}

			streamData, err := streamFromStartElement(tok)
			switch {
			case err != nil:
				return err
			case streamData.version != internal.DefaultVersion:
				return stream.UnsupportedVersion
			}

			if (s.state&Received) != Received && streamData.id == "" {
				// if we are the initiating entity and there is no stream ID…
				return stream.BadFormat
			}
			s.in.streamInfo = streamData
			return nil
		case xml.ProcInst:
			// TODO: If version or encoding are declared, validate XML 1.0 and UTF-8
			if declAllowed && tok.Target == "xml" {
				declAllowed = false
				continue
			}
			return stream.RestrictedXML
		case xml.CharData:
			// Whitespace between the XML declaration and the stream header is
			// permitted by XML 1.0 and is commonly sent (e.g. a newline after
			// the declaration), so skip it rather than rejecting the stream.
			if isXMLSpace(tok) {
				declAllowed = false
				continue
			}
			return stream.RestrictedXML
		case xml.EndElement:
			return stream.NotWellFormed
		default:
			return stream.RestrictedXML
		}
	}
}

// negotiateStreams runs stream setup and feature negotiation on rw.
//
// Whenever a new transport is pending (rw is non-nil, either initially or
// because negotiateFeatures returned one), the session's per-stream state is
// reset. The feature and negotiated maps are recreated, and a fresh XML decoder
// and encoder are built on the new transport. Then stream headers are
// exchanged: as the receiving entity (Received bit set) it waits for the
// peer's header and answers with one carrying a random ID; otherwise it sends
// its header first and then waits for the peer's.
//
// It returns nil once negotiateFeatures reports it is done and no new
// transport is pending, or the first error encountered.
func (s *Session) negotiateStreams(ctx context.Context, rw io.ReadWriter) (err error) {
	var done bool
	for {
		if done && rw == nil {
			return nil
		}

		if rw != nil {
			s.features = make(map[string]interface{})
			s.negotiated = make(map[string]struct{})
			s.rw = rw
			s.in.d = xml.NewDecoder(s.rw)
			s.out.e = xml.NewEncoder(s.rw)
			rw = nil

			if s.state&Received == Received {
				// Receiving entity: the peer speaks first, then we respond.
				err = expectNewStream(ctx, s)
				if err == nil {
					err = sendNewStream(s, s.config, internal.RandomID())
				}
			} else {
				// Initiating entity: we speak first, then await the response.
				err = sendNewStream(s, s.config, "")
				if err == nil {
					err = expectNewStream(ctx, s)
				}
			}
			if err != nil {
				return err
			}
		}

		done, rw, err = s.negotiateFeatures(ctx)
		if err != nil {
			return err
		}
	}
}