~samwhited/xmpp

ref: 52be89df867be9fc4e8ba83b0c3880887a6485ba xmpp/starttls.go -rw-r--r-- 3.6 KiB
52be89dfSam Whited form: Rename Option and FieldOption types 4 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package xmpp

import (
	"context"
	"crypto/tls"
	"encoding/xml"
	"errors"
	"fmt"
	"io"

	"mellium.im/xmpp/ns"
	"mellium.im/xmpp/streamerror"
)

// BUG(ssw): STARTTLS feature does not have security layer byte precision.

// ErrTLSUpgradeFailed is the error used by STARTTLS to report that the
// underlying connection cannot be upgraded to TLS.
var ErrTLSUpgradeFailed = errors.New("The underlying connection cannot be upgraded to TLS")

// StartTLS returns a new stream feature that can be used for negotiating TLS.
// For StartTLS to work, the underlying connection must support TLS (it must
// implement net.Conn).
//
// If required is true, the feature is listed with a <required/> child element
// and List reports the feature as required.
func StartTLS(required bool) StreamFeature {
	return StreamFeature{
		Name:       xml.Name{Local: "starttls", Space: ns.StartTLS},
		Prohibited: Secure,

		// List writes the feature advertisement to e: the start element, an
		// optional empty <required/> child, and the matching end element. The
		// returned bool always mirrors the required argument, even on error.
		List: func(ctx context.Context, e *xml.Encoder, start xml.StartElement) (req bool, err error) {
			if err = e.EncodeToken(start); err != nil {
				return required, err
			}
			if required {
				// The child is written without an explicit namespace.
				startRequired := xml.StartElement{Name: xml.Name{Space: "", Local: "required"}}
				if err = e.EncodeToken(startRequired); err != nil {
					return required, err
				}
				if err = e.EncodeToken(startRequired.End()); err != nil {
					return required, err
				}
			}
			return required, e.EncodeToken(start.End())
		},

		// Parse decodes a <starttls/> advertisement and reports whether it
		// contained a <required/> child in the STARTTLS namespace. No feature
		// data is returned.
		Parse: func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) {
			parsed := struct {
				XMLName  xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
				Required struct {
					XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls required"`
				}
			}{}
			err := d.DecodeElement(&parsed, start)
			return parsed.Required.XMLName.Local == "required" && parsed.Required.XMLName.Space == ns.StartTLS, nil, err
		},

		// Negotiate performs the STARTTLS exchange and, on success, returns
		// the Secure state bit together with the connection wrapped in TLS.
		// The TLS handshake itself is not started here.
		//
		// It returns ErrTLSUpgradeFailed if the session has no underlying
		// connection, the write error if the <proceed/> or <starttls/>
		// element cannot be sent, and a stream error if the peer's reply is
		// not a well-formed <proceed/> or <failure/> element in the STARTTLS
		// namespace.
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			conn := session.conn
			if conn == nil {
				return mask, nil, ErrTLSUpgradeFailed
			}

			// Use the session's TLS configuration if one was provided;
			// otherwise build a minimal one from the domain of the local
			// address.
			tlsconf := session.config.TLSConfig
			if tlsconf == nil {
				tlsconf = &tls.Config{
					ServerName: session.LocalAddr().Domain().String(),
				}
			}

			if (session.state & Received) == Received {
				// Receiving side: tell the peer to proceed. If that write
				// fails the peer never learns it may start TLS, so do not
				// report the session as secured.
				if _, err = fmt.Fprint(conn, `<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`); err != nil {
					return mask, nil, err
				}
				return Secure, tls.Server(conn, tlsconf), nil
			}

			// Initiating side: select starttls for negotiation. Waiting for
			// a reply is pointless if the request could not be written.
			if _, err = fmt.Fprint(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`); err != nil {
				return mask, nil, err
			}

			// Receive a <proceed/> or <failure/> response from the server.
			t, err := session.in.d.Token()
			if err != nil {
				return mask, nil, err
			}
			tok, ok := t.(xml.StartElement)
			if !ok {
				// Anything other than a start element is rejected.
				return mask, nil, streamerror.RestrictedXML
			}

			switch {
			case tok.Name.Space != ns.StartTLS:
				return mask, nil, streamerror.UnsupportedStanzaType
			case tok.Name.Local == "proceed":
				// Skip the </proceed> token.
				if err = session.in.d.Skip(); err != nil {
					return mask, nil, streamerror.InvalidXML
				}
				return Secure, tls.Client(conn, tlsconf), nil
			case tok.Name.Local == "failure":
				// Skip the </failure> token.
				if err = session.in.d.Skip(); err != nil {
					err = streamerror.InvalidXML
				}
				// Failure is not an "error", it's expected behavior.
				// Immediately afterwards the server will end the stream.
				// However, if we encounter bad XML while skipping the
				// </failure> token, return that error.
				return mask, nil, err
			default:
				return mask, nil, streamerror.UnsupportedStanzaType
			}
		},
	}
}