~samwhited/xmpp

7b9310cb11358835c951d5f6a41be72dfc76cd49 — Sam Whited 3 years ago 379c0dc
xmpp: use new JID API
M bind.go => bind.go +7 -16
@@ 39,7 39,7 @@ func BindResource() StreamFeature {
// requested by the client (or an empty string if a specific resource was not
// requested). Resources generated by the server function should be random to
// prevent certain security issues related to guessing resourceparts.
func BindCustom(server func(*jid.JID, string) (*jid.JID, error)) StreamFeature {
func BindCustom(server func(jid.JID, string) (jid.JID, error)) StreamFeature {
	return bind(server)
}



@@ 67,29 67,20 @@ func (biq *bindIQ) WriteXML(w xmlstream.TokenWriter) (n int, err error) {
}

type bindPayload struct {
	Resource string   `xml:"resource,omitempty"`
	JID      *jid.JID `xml:"jid,omitempty"`
	Resource string  `xml:"resource,omitempty"`
	JID      jid.JID `xml:"jid,omitempty"`
}

func (bp bindPayload) TokenReader() xml.TokenReader {
	if bp.JID != nil {
		return xmlstream.Wrap(
			xmlstream.ReaderFunc(func() (xml.Token, error) {
				return xml.CharData(bp.JID.String()), io.EOF
			}),
			xml.StartElement{Name: xml.Name{Local: "jid"}},
		)
	}

	return xmlstream.Wrap(
		xmlstream.ReaderFunc(func() (xml.Token, error) {
			return xml.CharData(bp.Resource), io.EOF
			return xml.CharData(bp.JID.String()), io.EOF
		}),
		xml.StartElement{Name: xml.Name{Local: "resource"}},
		xml.StartElement{Name: xml.Name{Local: "jid"}},
	)
}

func bind(server func(*jid.JID, string) (*jid.JID, error)) StreamFeature {
func bind(server func(jid.JID, string) (jid.JID, error)) StreamFeature {
	return StreamFeature{
		Name:       xml.Name{Space: ns.Bind, Local: "bind"},
		Necessary:  Authn,


@@ 137,7 128,7 @@ func bind(server func(*jid.JID, string) (*jid.JID, error)) StreamFeature {

				iqid := internal.GetAttr(start.Attr, "id")

				var j *jid.JID
				var j jid.JID
				if server != nil {
					j, err = server(session.RemoteAddr(), resReq.Bind.Resource)
				} else {

M component/component.go => component/component.go +3 -3
@@ 29,14 29,14 @@ const (

// NewClientSession initiates an XMPP session on the given io.ReadWriter using
// the component protocol.
func NewClientSession(ctx context.Context, addr *jid.JID, secret []byte, rw io.ReadWriter) (*xmpp.Session, error) {
func NewClientSession(ctx context.Context, addr jid.JID, secret []byte, rw io.ReadWriter) (*xmpp.Session, error) {
	addr = addr.Domain()
	return xmpp.NegotiateSession(ctx, addr, addr, rw, Negotiator(addr, secret, false))
}

// AcceptSession accepts an XMPP session on the given io.ReadWriter using the
// component protocol.
//func AcceptSession(ctx context.Context, addr *jid.JID, secret []byte, rw io.ReadWriter) (*xmpp.Session, error) {
//func AcceptSession(ctx context.Context, addr jid.JID, secret []byte, rw io.ReadWriter) (*xmpp.Session, error) {
//	return xmpp.NegotiateSession(ctx, nil, rw, Negotiator(addr, secret, true))
//}



@@ 44,7 44,7 @@ func NewClientSession(ctx context.Context, addr *jid.JID, secret []byte, rw io.R
// protocol connection on the provided io.ReadWriter.
//
// It currently only supports the client side of the component protocol.
func Negotiator(addr *jid.JID, secret []byte, recv bool) xmpp.Negotiator {
func Negotiator(addr jid.JID, secret []byte, recv bool) xmpp.Negotiator {
	return func(ctx context.Context, s *xmpp.Session, _ interface{}) (mask xmpp.SessionState, _ io.ReadWriter, _ interface{}, err error) {
		d := xml.NewDecoder(s.Conn())


M dial.go => dial.go +4 -4
@@ 111,7 111,7 @@ func (c *Conn) Write(b []byte) (int, error) {
// Network may be any of the network types supported by net.Dial, but you almost
// certainly want to use one of the tcp connection types ("tcp", "tcp4", or
// "tcp6").
func DialClient(ctx context.Context, network string, addr *jid.JID) (*Conn, error) {
func DialClient(ctx context.Context, network string, addr jid.JID) (*Conn, error) {
	var d Dialer
	return d.Dial(ctx, network, addr)
}


@@ 120,7 120,7 @@ func DialClient(ctx context.Context, network string, addr *jid.JID) (*Conn, erro
// server-to-server connection (s2s).
//
// For more info see the DialClient function.
func DialServer(ctx context.Context, network string, addr *jid.JID) (*Conn, error) {
func DialServer(ctx context.Context, network string, addr jid.JID) (*Conn, error) {
	d := Dialer{
		S2S: true,
	}


@@ 147,11 147,11 @@ type Dialer struct {
// Dial discovers and connects to the address on the named network.
//
// For a description of the arguments see the DialClient function.
func (d *Dialer) Dial(ctx context.Context, network string, addr *jid.JID) (*Conn, error) {
func (d *Dialer) Dial(ctx context.Context, network string, addr jid.JID) (*Conn, error) {
	return d.dial(ctx, network, addr)
}

func (d *Dialer) dial(ctx context.Context, network string, addr *jid.JID) (*Conn, error) {
func (d *Dialer) dial(ctx context.Context, network string, addr jid.JID) (*Conn, error) {
	if d.NoLookup {
		p, err := internal.LookupPort(network, connType(d.S2S))
		if err != nil {

M ping/ping.go => ping/ping.go +1 -1
@@ 18,7 18,7 @@ const NS = `urn:xmpp:ping`

// IQ returns an xml.TokenReader that outputs a new IQ stanza with a ping
// payload.
func IQ(to *jid.JID) xml.TokenReader {
func IQ(to jid.JID) xml.TokenReader {
	start := xml.StartElement{Name: xml.Name{Local: "ping", Space: NS}}
	return stanza.WrapIQ(&stanza.IQ{
		To:   to,

M session.go => session.go +7 -10
@@ 64,8 64,8 @@ type Session struct {
	state SessionState
	slock sync.RWMutex

	origin   *jid.JID
	location *jid.JID
	origin   jid.JID
	location jid.JID

	// The stream feature namespaces advertised for the current streams.
	features map[string]interface{}


@@ 105,7 105,7 @@ type Negotiator func(ctx context.Context, session *Session, data interface{}) (m
// Calling NegotiateSession with a nil Negotiator panics.
//
// For more information see the Negotiator type.
func NegotiateSession(ctx context.Context, location, origin *jid.JID, rw io.ReadWriter, negotiate Negotiator) (*Session, error) {
func NegotiateSession(ctx context.Context, location, origin jid.JID, rw io.ReadWriter, negotiate Negotiator) (*Session, error) {
	if negotiate == nil {
		panic("xmpp: attempted to negotiate session with nil negotiator")
	}


@@ 187,7 187,7 @@ func stanzaAddID(w xmlstream.TokenWriter) xmlstream.TokenWriter {
// If the provided context is canceled before stream negotiation is complete an
// error is returned.
// After stream negotiation if the context is canceled it has no effect.
func NewClientSession(ctx context.Context, origin *jid.JID, lang string, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
func NewClientSession(ctx context.Context, origin jid.JID, lang string, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
	return NegotiateSession(ctx, origin.Domain(), origin, rw, negotiator(false, lang, features))
}



@@ 196,7 196,7 @@ func NewClientSession(ctx context.Context, origin *jid.JID, lang string, rw io.R
// If the provided context is canceled before stream negotiation is complete an
// error is returned.
// After stream negotiation if the context is canceled it has no effect.
func NewServerSession(ctx context.Context, location, origin *jid.JID, lang string, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
func NewServerSession(ctx context.Context, location, origin jid.JID, lang string, rw io.ReadWriter, features ...StreamFeature) (*Session, error) {
	return NegotiateSession(ctx, location, origin, rw, negotiator(true, lang, features))
}



@@ 401,21 401,18 @@ func (s *Session) State() SessionState {

// LocalAddr returns the Origin address for initiated connections, or the
// Location for received connections.
func (s *Session) LocalAddr() *jid.JID {
func (s *Session) LocalAddr() jid.JID {
	s.slock.RLock()
	defer s.slock.RUnlock()
	if (s.state & Received) == Received {
		return s.location
	}
	if s.origin != nil {
		return s.origin
	}
	return s.origin
}

// RemoteAddr returns the Location address for initiated connections, or the
// Origin address for received connections.
func (s *Session) RemoteAddr() *jid.JID {
func (s *Session) RemoteAddr() jid.JID {
	s.slock.RLock()
	defer s.slock.RUnlock()
	if (s.state & Received) == Received {

M session_test.go => session_test.go +2 -1
@@ 19,6 19,7 @@ import (

	"mellium.im/xmpp"
	"mellium.im/xmpp/internal/xmpptest"
	"mellium.im/xmpp/jid"
)

func TestClosedInputStream(t *testing.T) {


@@ 102,7 103,7 @@ func TestNegotiator(t *testing.T) {
				Reader: rand.New(rand.NewSource(99)),
				Writer: ioutil.Discard,
			}
			_, err := xmpp.NegotiateSession(context.Background(), nil, nil, rw, tc.negotiator)
			_, err := xmpp.NegotiateSession(context.Background(), jid.JID{}, jid.JID{}, rw, tc.negotiator)
			if err != tc.err {
				t.Errorf("Unexpected error: want=%v, got=%v", tc.err, err)
			}

M stanza/error.go => stanza/error.go +4 -4
@@ 250,7 250,7 @@ const (
// unmarshalable as XML.
type Error struct {
	XMLName   xml.Name
	By        *jid.JID
	By        jid.JID
	Type      ErrorType
	Condition Condition
	Lang      language.Tag


@@ 275,8 275,8 @@ func (se Error) TokenReader() xml.TokenReader {
	if string(se.Type) != "" {
		start.Attr = append(start.Attr, xml.Attr{Name: xml.Name{Local: "type"}, Value: string(se.Type)})
	}
	if se.By != nil {
		a, _ := se.By.MarshalXMLAttr(xml.Name{Space: "", Local: "by"})
	a, _ := se.By.MarshalXMLAttr(xml.Name{Space: "", Local: "by"})
	if a.Value != "" {
		start.Attr = append(start.Attr, a)
	}



@@ 333,7 333,7 @@ func (se *Error) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
			XMLName xml.Name
		} `xml:",any"`
		Type ErrorType `xml:"type,attr"`
		By   *jid.JID  `xml:"by,attr"`
		By   jid.JID   `xml:"by,attr"`
		Text []struct {
			Lang string `xml:"http://www.w3.org/XML/1998/namespace lang,attr"`
			Data string `xml:",chardata"`

M stanza/error_test.go => stanza/error_test.go +5 -4
@@ 7,6 7,7 @@ package stanza
import (
	"encoding/xml"
	"fmt"
	"reflect"
	"testing"

	"golang.org/x/text/language"


@@ 90,7 91,7 @@ func TestUnmarshalStanzaError(t *testing.T) {
		9: {`<error type="auth"><remote-server-timeout xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"></remote-server-timeout><text xmlns="urn:ietf:params:xml:ns:xmpp-stanzas" xml:lang="en">test</text><text xmlns="urn:ietf:params:xml:ns:xmpp-stanzas" xml:lang="es">Spanish</text></error>`,
			language.LatinAmericanSpanish, Error{Type: Auth, Condition: RemoteServerTimeout, Text: "Spanish", Lang: language.Spanish}, false},
		10: {`<error by=""><remote-server-not-found xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"></remote-server-not-found></error>`,
			language.Und, Error{By: &jid.JID{}, Condition: RemoteServerNotFound}, false},
			language.Und, Error{By: jid.JID{}, Condition: RemoteServerNotFound}, false},
		11: {`<error><other xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"></other></error>`,
			language.Und, Error{Condition: Condition("other")}, false},
		12: {`<error><recipient-unavailable xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"></recipient-unavailable><text xmlns="urn:ietf:params:xml:ns:xmpp-stanzas" xml:lang="ac-u">test</text></error>`,


@@ 100,8 101,8 @@ func TestUnmarshalStanzaError(t *testing.T) {
			se2 := Error{Lang: data.lang}
			err := xml.Unmarshal([]byte(data.xml), &se2)
			j1, j2 := data.se.By, se2.By
			data.se.By = nil
			se2.By = nil
			data.se.By = jid.JID{}
			se2.By = jid.JID{}
			switch {
			case data.err && err == nil:
				t.Errorf("Expected an error when unmarshaling stanza error `%s`", data.xml)


@@ 115,7 116,7 @@ func TestUnmarshalStanzaError(t *testing.T) {
				// This case is included in the next one, but I wanted it to print
				// something nicer for languages…
				t.Errorf("Expected unmarshaled stanza error to have lang `%s` but got `%s`.", data.se.Lang, se2.Lang)
			case data.se != se2:
			case !reflect.DeepEqual(data.se, se2):
				t.Errorf("Expected unmarshaled stanza error:\n`%#v`\nbut got:\n`%#v`", data.se, se2)
			}
		})

M stanza/example_pingstream_test.go => stanza/example_pingstream_test.go +1 -1
@@ 16,7 16,7 @@ import (

// WrapPingIQ returns an xml.TokenReader that outputs a new IQ stanza with
// a ping payload.
func WrapPingIQ(to *jid.JID) xml.TokenReader {
func WrapPingIQ(to jid.JID) xml.TokenReader {
	start := xml.StartElement{Name: xml.Name{Local: "ping", Space: "urn:xmpp:ping"}}
	return stanza.WrapIQ(&stanza.IQ{To: to, Type: stanza.GetIQ}, xmlstream.Wrap(nil, start))
}

M stanza/iq.go => stanza/iq.go +10 -6
@@ 19,12 19,16 @@ func WrapIQ(iq *IQ, payload xml.TokenReader) xml.TokenReader {
	attr := []xml.Attr{
		{Name: xml.Name{Local: "type"}, Value: string(iq.Type)},
	}
	if iq.To != nil {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "to"}, Value: iq.To.String()})

	to, _ := iq.To.MarshalXMLAttr(xml.Name{Space: "", Local: "to"})
	if to.Value != "" {
		attr = append(attr, to)
	}
	if iq.From != nil {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "from"}, Value: iq.From.String()})
	from, _ := iq.From.MarshalXMLAttr(xml.Name{Space: "", Local: "from"})
	if from.Value != "" {
		attr = append(attr, from)
	}

	if iq.Lang != "" {
		attr = append(attr, xml.Attr{Name: xml.Name{Local: "lang", Space: ns.XML}, Value: iq.Lang})
	}


@@ 44,8 48,8 @@ func WrapIQ(iq *IQ, payload xml.TokenReader) xml.TokenReader {
type IQ struct {
	XMLName xml.Name `xml:"iq"`
	ID      string   `xml:"id,attr"`
	To      *jid.JID `xml:"to,attr,omitempty"`
	From    *jid.JID `xml:"from,attr,omitempty"`
	To      jid.JID  `xml:"to,attr,omitempty"`
	From    jid.JID  `xml:"from,attr,omitempty"`
	Lang    string   `xml:"http://www.w3.org/XML/1998/namespace lang,attr,omitempty"`
	Type    IQType   `xml:"type,attr"`
}

M stanza/message.go => stanza/message.go +3 -3
@@ 12,7 12,7 @@ import (
)

// WrapMessage wraps a payload in a message stanza.
func WrapMessage(to *jid.JID, typ MessageType, payload xml.TokenReader) xml.TokenReader {
func WrapMessage(to jid.JID, typ MessageType, payload xml.TokenReader) xml.TokenReader {
	return xmlstream.Wrap(payload, xml.StartElement{
		Name: xml.Name{Local: "message"},
		Attr: []xml.Attr{


@@ 29,8 29,8 @@ func WrapMessage(to *jid.JID, typ MessageType, payload xml.TokenReader) xml.Toke
type Message struct {
	XMLName xml.Name    `xml:"message"`
	ID      string      `xml:"id,attr"`
	To      *jid.JID    `xml:"to,attr"`
	From    *jid.JID    `xml:"from,attr"`
	To      jid.JID     `xml:"to,attr"`
	From    jid.JID     `xml:"from,attr"`
	Lang    string      `xml:"http://www.w3.org/XML/1998/namespace lang,attr,omitempty"`
	Type    MessageType `xml:"type,attr,omitempty"`
}

M stanza/presence.go => stanza/presence.go +3 -3
@@ 12,7 12,7 @@ import (
)

// WrapPresence wraps a payload in a presence stanza.
func WrapPresence(to *jid.JID, typ PresenceType, payload xml.TokenReader) xml.TokenReader {
func WrapPresence(to jid.JID, typ PresenceType, payload xml.TokenReader) xml.TokenReader {
	return xmlstream.Wrap(payload, xml.StartElement{
		Name: xml.Name{Local: "presence"},
		Attr: []xml.Attr{


@@ 29,8 29,8 @@ func WrapPresence(to *jid.JID, typ PresenceType, payload xml.TokenReader) xml.To
type Presence struct {
	XMLName xml.Name     `xml:"presence"`
	ID      string       `xml:"id,attr"`
	To      *jid.JID     `xml:"to,attr"`
	From    *jid.JID     `xml:"from,attr"`
	To      jid.JID      `xml:"to,attr"`
	From    jid.JID      `xml:"from,attr"`
	Lang    string       `xml:"http://www.w3.org/XML/1998/namespace lang,attr,omitempty"`
	Type    PresenceType `xml:"type,attr,omitempty"`
}

A stream/benchmark_test.go => stream/benchmark_test.go +16 -0
@@ 0,0 1,16 @@
package stream_test

import (
	"net"
	"testing"

	"mellium.im/xmpp/stream"
)

func BenchmarkSeeOtherHostError(b *testing.B) {
	ip := &net.IPAddr{IP: net.ParseIP("2001:db8::68")}
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = stream.SeeOtherHostError(ip, nil)
	}
}