~samwhited/xmpp

9ca315a5a9bf2636d7de1efedd19fb7fd3459801 — Sam Whited 5 years ago 97239da
Move stream errors into their own package
8 files changed, 125 insertions(+), 94 deletions(-)

M bind.go
M features.go
M sasl.go
M starttls.go
M stream.go
R streamerror.go => streamerror/streamerror.go
A streamerror/streamerror_test.go
D streamerror_test.go
M bind.go => bind.go +4 -5
@@ 12,6 12,7 @@ import (

	"mellium.im/xmpp/internal"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/streamerror"
)

const (


@@ 56,7 57,7 @@ func BindResource() StreamFeature {
				}
				start, ok := tok.(xml.StartElement)
				if !ok {
					return mask, BadFormat
					return mask, streamerror.BadFormat
				}
				resp := struct {
					IQ


@@ 70,14 71,12 @@ func BindResource() StreamFeature {
						return mask, err
					}
				default:
					return mask, BadFormat
					return mask, streamerror.BadFormat
				}

				switch {
				case resp.ID != reqID:
					// TODO: Do we actually care about this? Should this be a stanza error
					// instead?
					return mask, UndefinedCondition
					return mask, streamerror.UndefinedCondition
				case resp.Type == ResultIQ:
					conn.origin = resp.Bind.JID
				case resp.Type == ErrorIQ:

M features.go => features.go +7 -5
@@ 9,6 9,8 @@ import (
	"encoding/xml"
	"fmt"
	"io"

	"mellium.im/xmpp/streamerror"
)

// A StreamFeature represents a feature that may be selected during stream


@@ 99,7 101,7 @@ func (c *Conn) negotiateFeatures(ctx context.Context) (done bool, err error) {
		}
		start, ok := t.(xml.StartElement)
		if !ok {
			return done, BadFormat
			return done, streamerror.BadFormat
		}
		list, err := readStreamFeatures(ctx, c, start)



@@ 144,9 146,9 @@ type streamFeaturesList struct {
func readStreamFeatures(ctx context.Context, conn *Conn, start xml.StartElement) (*streamFeaturesList, error) {
	switch {
	case start.Name.Local != "features":
		return nil, InvalidXML
		return nil, streamerror.InvalidXML
	case start.Name.Space != NSStream:
		return nil, BadNamespacePrefix
		return nil, streamerror.BadNamespacePrefix
	}

	sf := &streamFeaturesList{


@@ 190,9 192,9 @@ parsefeatures:
			}
			// Oops, how did that happen? We shouldn't have been able to hit an end
			// element that wasn't the </stream:features> token.
			return nil, InvalidXML
			return nil, streamerror.InvalidXML
		default:
			return nil, RestrictedXML
			return nil, streamerror.RestrictedXML
		}
	}
}

M sasl.go => sasl.go +3 -2
@@ 13,6 13,7 @@ import (

	"mellium.im/sasl"
	"mellium.im/xmpp/internal/saslerr"
	"mellium.im/xmpp/streamerror"
)

// BUG(ssw): We can't support server side SASL yet until the SASL library


@@ 118,7 119,7 @@ func SASL(mechanisms ...*sasl.Mechanism) StreamFeature {
							return mask, err
						}
					} else {
						return mask, BadFormat
						return mask, streamerror.BadFormat
					}
					if more, resp, err = selected.Step(challenge); err != nil {
						return mask, err


@@ 154,6 155,6 @@ func decodeSASLChallenge(d *xml.Decoder, start xml.StartElement) (challenge []by
		}
		return nil, false, fail
	default:
		return nil, false, UnsupportedStanzaType
		return nil, false, streamerror.UnsupportedStanzaType
	}
}

M starttls.go => starttls.go +7 -5
@@ 12,6 12,8 @@ import (
	"fmt"
	"io"
	"net"

	"mellium.im/xmpp/streamerror"
)

// BUG(ssw): STARTTLS feature does not have security layer byte precision.


@@ 67,27 69,27 @@ func StartTLS(required bool) StreamFeature {
				case xml.StartElement:
					switch {
					case tok.Name.Space != NSStartTLS:
						return mask, UnsupportedStanzaType
						return mask, streamerror.UnsupportedStanzaType
					case tok.Name.Local == "proceed":
						// Skip the </proceed> token.
						if err = conn.in.d.Skip(); err != nil {
							return EndStream, InvalidXML
							return EndStream, streamerror.InvalidXML
						}
						conn.rwc = tls.Client(netconn, conn.config.TLSConfig)
					case tok.Name.Local == "failure":
						// Skip the </failure> token.
						if err = conn.in.d.Skip(); err != nil {
							err = InvalidXML
							err = streamerror.InvalidXML
						}
						// Failure is not an "error", it's expected behavior. The server is
						// telling us to end the stream. However, if we encounter bad XML
						// while skipping the </failure> token, return that error.
						return EndStream, err
					default:
						return mask, UnsupportedStanzaType
						return mask, streamerror.UnsupportedStanzaType
					}
				default:
					return mask, RestrictedXML
					return mask, streamerror.RestrictedXML
				}
			}
			mask = Secure | StreamRestartRequired

M stream.go => stream.go +12 -11
@@ 13,6 13,7 @@ import (
	"golang.org/x/text/language"
	"mellium.im/xmpp/internal"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/streamerror"
)

const streamIDLength = 16


@@ 68,12 69,12 @@ func streamFromStartElement(s xml.StartElement) (stream, error) {
		case xml.Name{Space: "", Local: "to"}:
			stream.to = &jid.JID{}
			if err := stream.to.UnmarshalXMLAttr(attr); err != nil {
				return stream, ImproperAddressing
				return stream, streamerror.ImproperAddressing
			}
		case xml.Name{Space: "", Local: "from"}:
			stream.from = &jid.JID{}
			if err := stream.from.UnmarshalXMLAttr(attr); err != nil {
				return stream, ImproperAddressing
				return stream, streamerror.ImproperAddressing
			}
		case xml.Name{Space: "", Local: "id"}:
			stream.id = attr.Value


@@ 81,12 82,12 @@ func streamFromStartElement(s xml.StartElement) (stream, error) {
			(&stream.version).UnmarshalXMLAttr(attr)
		case xml.Name{Space: "", Local: "xmlns"}:
			if attr.Value != "jabber:client" && attr.Value != "jabber:server" {
				return stream, InvalidNamespace
				return stream, streamerror.InvalidNamespace
			}
			stream.xmlns = attr.Value
		case xml.Name{Space: "xmlns", Local: "stream"}:
			if attr.Value != NSStream {
				return stream, InvalidNamespace
				return stream, streamerror.InvalidNamespace
			}
		case xml.Name{Space: "xml", Local: "lang"}:
			stream.lang = language.Make(attr.Value)


@@ 172,9 173,9 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
		case xml.StartElement:
			switch {
			case tok.Name.Local != "stream":
				return BadFormat
				return streamerror.BadFormat
			case tok.Name.Space != NSStream:
				return InvalidNamespace
				return streamerror.InvalidNamespace
			}

			stream, err := streamFromStartElement(tok)


@@ 182,13 183,13 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
			case err != nil:
				return err
			case stream.version != internal.DefaultVersion:
				return UnsupportedVersion
				return streamerror.UnsupportedVersion
			}

			if conn, ok := r.(*Conn); ok {
				if (conn.state&Received) != Received && stream.id == "" {
					// if we are the initiating entity and there is no stream ID…
					return BadFormat
					return streamerror.BadFormat
				}
				if (conn.state & StreamRestartRequired) == StreamRestartRequired {
					conn.state &= ^StreamRestartRequired


@@ 203,11 204,11 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
				foundHeader = true
				continue
			}
			return RestrictedXML
			return streamerror.RestrictedXML
		case xml.EndElement:
			return NotWellFormed
			return streamerror.NotWellFormed
		default:
			return RestrictedXML
			return streamerror.RestrictedXML
		}
	}
}

R streamerror.go => streamerror/streamerror.go +9 -3
@@ 2,7 2,14 @@
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package xmpp
// The streamerror package contains XMPP stream errors as defined by RFC 6120
// §4.9.
//
// These error conditions are not in the main xmpp package to prevent naming
// conflicts with similarly named stanza error conditions and because they are
// less frequently used. Most people will want to use the facilities of the
// mellium.im/xmpp package and not create stream errors directly.
package streamerror // import "mellium.im/xmpp/streamerror"

import (
	"encoding/xml"


@@ 199,7 206,7 @@ func (s *StreamError) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error
// MarshalXML satisfies the xml package's Marshaler interface and allows
// StreamError's to be correctly marshaled back into XML.
func (s StreamError) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
	e.EncodeElement(
	return e.EncodeElement(
		struct {
			Err struct {
				XMLName  xml.Name


@@ 219,5 226,4 @@ func (s StreamError) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
			Attr: []xml.Attr{},
		},
	)
	return nil
}

A streamerror/streamerror_test.go => streamerror/streamerror_test.go +83 -0
@@ 0,0 1,83 @@
// Copyright 2015 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package streamerror

import (
	"encoding/xml"
	"net"
	"testing"
)

var _ error = (*StreamError)(nil)
var _ error = StreamError{}
var _ xml.Marshaler = (*StreamError)(nil)
var _ xml.Marshaler = StreamError{}
var _ xml.Unmarshaler = (*StreamError)(nil)

func TestMarshalSeeOtherHost(t *testing.T) {
	for _, test := range []struct {
		ipaddr net.Addr
		xml    string
		err    bool
	}{
		// see-other-host errors should wrap IPv6 addresses in brackets.
		{&net.IPAddr{IP: net.ParseIP("::1")}, `<stream:error><see-other-host xmlns="urn:ietf:params:xml:ns:xmpp-streams">[::1]</see-other-host></stream:error>`, false},
		{&net.IPAddr{IP: net.ParseIP("127.0.0.1")}, `<stream:error><see-other-host xmlns="urn:ietf:params:xml:ns:xmpp-streams">127.0.0.1</see-other-host></stream:error>`, false},
	} {
		soh := SeeOtherHostError(test.ipaddr)
		xb, err := xml.Marshal(soh)
		switch xbs := string(xb); {
		case test.err && err == nil:
			t.Error("Expected marshaling SeeOtherHost error for address `%v` to fail", test.ipaddr)
			continue
		case !test.err && err != nil:
			t.Error(err)
			continue
		case err != nil:
			continue
		case xbs != test.xml:
			t.Logf("Expected `%s` but got `%s`", test.xml, xbs)
			t.Fail()
		}
	}
}

func TestUnmarshal(t *testing.T) {
	for _, test := range []struct {
		xml string
		se  StreamError
		err bool
	}{
		{
			`<stream:error><restricted-xml xmlns="urn:ietf:params:xml:ns:xmpp-streams"></restricted-xml></stream:error>`,
			RestrictedXML, false,
		},
		{
			`<stream:error></a>`,
			RestrictedXML, true,
		},
	} {
		s := StreamError{}
		err := xml.Unmarshal([]byte(test.xml), &s)
		switch {
		case test.err && err == nil:
			t.Errorf("Expected unmarshaling error for `%v` to fail", test.xml)
			continue
		case !test.err && err != nil:
			t.Error(err)
			continue
		case err != nil:
			continue
		case s.Err != test.se.Err || string(s.InnerXML) != string(test.se.InnerXML):
			t.Errorf("Expected `%#v` but got `%#v`", test.se, s)
		}
	}
}

func TestErrorReturnsErr(t *testing.T) {
	if RestrictedXML.Error() != "restricted-xml" {
		t.Error("Error should return the name of the err")
	}
}

D streamerror_test.go => streamerror_test.go +0 -63
@@ 1,63 0,0 @@
// Copyright 2015 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package xmpp

import (
	"bytes"
	"encoding/xml"
	"net"
	"testing"
)

var _ error = (*StreamError)(nil)
var _ error = StreamError{}
var _ xml.Marshaler = (*StreamError)(nil)
var _ xml.Marshaler = StreamError{}
var _ xml.Unmarshaler = (*StreamError)(nil)

// see-other-host errors should wrap IPv6 addresses in brackets.
func TestMarshalSeeOtherHostV6(t *testing.T) {
	ipaddr := net.IPAddr{IP: net.ParseIP("::1")}
	soh := SeeOtherHostError(&ipaddr)
	xb, err := xml.Marshal(soh)
	if err != nil {
		t.Log(err)
		t.FailNow()
	}

	if xbs := string(xb); xbs != `<stream:error><see-other-host xmlns="urn:ietf:params:xml:ns:xmpp-streams">[::1]</see-other-host></stream:error>` {
		t.Logf("Expected [::1] but got %s", xbs)
		t.Fail()
	}
}

// Stream errors should be marshalable and unmarshalable.
func TestUnmarshalMarshalSteamError(t *testing.T) {
	b := []byte(`<stream:error>
	<restricted-xml xmlns="urn:ietf:params:xml:ns:xmpp-streams">a</restricted-xml>
</stream:error>`)
	mb := bytes.NewBuffer(b)
	d := xml.NewDecoder(mb)
	s := &StreamError{}
	err := d.Decode(s)
	if err != nil {
		t.Log(err)
		t.FailNow()
	}
	if s.Error() != "restricted-xml" {
		t.Logf("Expected restricted-xml but got %+v\n", s)
		t.FailNow()
	}

	xb, err := xml.MarshalIndent(s, "", "\t")
	if err != nil {
		t.Log(err)
		t.FailNow()
	}
	if string(b) != string(xb) {
		t.Logf("Expected %s but got %s", string(b), string(xb))
		t.Fail()
	}
}