~samwhited/xmpp

12826d6fdcabacd3e9d5b46d4c8201fd0239931a — Sam Whited 3 years ago 29e21e8
stream: new SAX like API

internal/ns: add required stream namespaces
3 files changed, 54 insertions(+), 26 deletions(-)

M internal/ns/ns.go
M stream/error.go
M stream/error_test.go
M internal/ns/ns.go => internal/ns/ns.go +1 -0
@@ 15,6 15,7 @@ const (
	Stanza   = "urn:ietf:params:xml:ns:xmpp-stanzas"
	StartTLS = "urn:ietf:params:xml:ns:xmpp-tls"
	Stream   = "http://etherx.jabber.org/streams"
	Streams  = "urn:ietf:params:xml:ns:xmpp-streams"
	WS       = "urn:ietf:params:xml:ns:xmpp-framing"
	XML      = "http://www.w3.org/XML/1998/namespace"
)

M stream/error.go => stream/error.go +50 -23
@@ 10,7 10,11 @@ package stream // import "mellium.im/xmpp/stream"

import (
	"encoding/xml"
	"io"
	"net"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal/ns"
)

// A list of stream errors defined in RFC 6120 §4.9.3


@@ 149,7 153,7 @@ var (
// SeeOtherHostError returns a new see-other-host error with the given network
// address as the host. If the address appears to be a raw IPv6 address (eg.
// "::1"), the error wraps it in brackets ("[::1]").
func SeeOtherHostError(addr net.Addr) Error {
func SeeOtherHostError(addr net.Addr, payload xmlstream.TokenReader) Error {
	var cdata string

	// If the address looks like an IPv6 literal, wrap it in []


@@ 159,14 163,28 @@ func SeeOtherHostError(addr net.Addr) Error {
		cdata = addr.String()
	}

	return Error{"see-other-host", []byte(cdata)}
	if payload != nil {
		payload = xmlstream.MultiReader(
			xmlstream.ReaderFunc(func() (xml.Token, error) {
				return xml.CharData(cdata), io.EOF
			}),
			payload,
		)
	} else {
		payload = xmlstream.ReaderFunc(func() (xml.Token, error) {
			return xml.CharData(cdata), io.EOF
		})
	}

	return Error{Err: "see-other-host", innerXML: payload}
}

// A Error represents an unrecoverable stream-level error that may include
// character data or arbitrary inner XML.
type Error struct {
	Err      string
	InnerXML []byte
	Err string

	innerXML xmlstream.TokenReader `xml:"-"`
}

// Error satisfies the builtin error interface and returns the name of the


@@ 196,31 214,40 @@ func (s *Error) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
		return err
	}
	s.Err = se.Err.XMLName.Local
	s.InnerXML = se.Err.InnerXML
	// TODO: s.InnerXML = se.Err.InnerXML
	return nil
}

// MarshalXML satisfies the xml package's Marshaler interface and allows
// StreamError's to be correctly marshaled back into XML.
func (s Error) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
	return e.EncodeElement(
		struct {
			Err struct {
				XMLName  xml.Name
				InnerXML []byte `xml:",innerxml"`
			}
		}{
			struct {
				XMLName  xml.Name
				InnerXML []byte `xml:",innerxml"`
			}{
				XMLName:  xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-streams", Local: s.Err},
				InnerXML: s.InnerXML,
			},
		},
func (s Error) MarshalXML(e *xml.Encoder, _ xml.StartElement) error {
	return s.WriteXML(e, xml.StartElement{})
}

// WriteXML satisfies the xmlstream.Marshaler interface.
// It is like MarshalXML except it writes tokens to w.
func (s Error) WriteXML(w xmlstream.TokenWriter, _ xml.StartElement) error {
	_, err := xmlstream.Copy(w, s.TokenReader(nil))
	if err != nil {
		return err
	}
	return w.Flush()
}

// TokenReader returns a new xmlstream.TokenReader that returns an encoding of
// the error.
func (s Error) TokenReader(payload xmlstream.TokenReader) xmlstream.TokenReader {
	inner := xmlstream.Wrap(s.innerXML, xml.StartElement{Name: xml.Name{Local: s.Err, Space: ns.Streams}})
	if payload != nil {
		inner = xmlstream.MultiReader(
			inner,
			payload,
		)
	}
	return xmlstream.Wrap(
		inner,
		xml.StartElement{
			Name: xml.Name{Space: "", Local: "stream:error"},
			Attr: []xml.Attr{},
			Name: xml.Name{Local: "error", Space: ns.Stream},
		},
	)
}

M stream/error_test.go => stream/error_test.go +3 -3
@@ 27,14 27,14 @@ var marshalSeeOtherHostTests = [...]struct {
	err    bool
}{
	// see-other-host errors should wrap IPv6 addresses in brackets.
	0: {&net.IPAddr{IP: net.ParseIP("::1")}, `<stream:error><see-other-host xmlns="urn:ietf:params:xml:ns:xmpp-streams">[::1]</see-other-host></stream:error>`, false},
	1: {&net.IPAddr{IP: net.ParseIP("127.0.0.1")}, `<stream:error><see-other-host xmlns="urn:ietf:params:xml:ns:xmpp-streams">127.0.0.1</see-other-host></stream:error>`, false},
	0: {&net.IPAddr{IP: net.ParseIP("::1")}, `<error xmlns="http://etherx.jabber.org/streams"><see-other-host xmlns="urn:ietf:params:xml:ns:xmpp-streams">[::1]</see-other-host></error>`, false},
	1: {&net.IPAddr{IP: net.ParseIP("127.0.0.1")}, `<error xmlns="http://etherx.jabber.org/streams"><see-other-host xmlns="urn:ietf:params:xml:ns:xmpp-streams">127.0.0.1</see-other-host></error>`, false},
}

func TestMarshalSeeOtherHost(t *testing.T) {
	for i, test := range marshalSeeOtherHostTests {
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			soh := stream.SeeOtherHostError(test.ipaddr)
			soh := stream.SeeOtherHostError(test.ipaddr, nil)
			xb, err := xml.Marshal(soh)
			switch xbs := string(xb); {
			case test.err && err == nil: