~samwhited/xmpp

3c3aee0ab83ab7aac5c40407b55a1c13ecc0ae48 — Sam Whited 3 years ago 6df3623
all: use xmlstream.TokenReader

See #14
M bind.go => bind.go +4 -3
@@ 9,6 9,7 @@ import (
	"encoding/xml"
	"io"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/jid"


@@ 71,15 72,15 @@ func bind(server func(*jid.JID, string) (*jid.JID, error)) StreamFeature {
			err = e.Flush()
			return req, err
		},
		Parse: func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) {
		Parse: func(ctx context.Context, r xmlstream.TokenReader, start *xml.StartElement) (bool, interface{}, error) {
			parsed := struct {
				XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
			}{}
			return true, nil, d.DecodeElement(&parsed, start)
			return true, nil, xml.NewTokenDecoder(r).DecodeElement(&parsed, start)
		},
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			e := session.Encoder()
			d := session.Decoder()
			d := xml.NewTokenDecoder(session.TokenReader())

			// Handle the server side of resource binding if we're on the receiving
			// end of the connection.

M compress/compression.go => compress/compression.go +7 -6
@@ 1,6 1,6 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

// Package compress implements XEP-0138: Stream Compression and XEP-0229: Stream
// Compression with LZW.


@@ 13,6 13,7 @@ import (
	"fmt"
	"io"

	"mellium.im/xmlstream"
	"mellium.im/xmpp"
	"mellium.im/xmpp/stream"
)


@@ 67,12 68,12 @@ func New(methods ...Method) xmpp.StreamFeature {
			}
			return false, e.Flush()
		},
		Parse: func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) {
		Parse: func(ctx context.Context, r xmlstream.TokenReader, start *xml.StartElement) (bool, interface{}, error) {
			listed := struct {
				XMLName xml.Name `xml:"http://jabber.org/features/compress compression"`
				Methods []string `xml:"http://jabber.org/features/compress method"`
			}{}
			if err := d.DecodeElement(&listed, start); err != nil {
			if err := xml.NewTokenDecoder(r).DecodeElement(&listed, start); err != nil {
				return false, nil, err
			}



@@ 87,6 88,7 @@ func New(methods ...Method) xmpp.StreamFeature {
		},
		Negotiate: func(ctx context.Context, session *xmpp.Session, data interface{}) (mask xmpp.SessionState, rw io.ReadWriter, err error) {
			conn := session.Conn()
			d := xml.NewTokenDecoder(session.TokenReader())

			// If we're a server.
			if (session.State() & xmpp.Received) == xmpp.Received {


@@ 94,7 96,7 @@ func New(methods ...Method) xmpp.StreamFeature {
					XMLName xml.Name `xml:"http://jabber.org/protocol/compress compress"`
					Method  string   `xml:"method"`
				}{}
				if err = session.Decoder().Decode(&clientSelected); err != nil {
				if err = d.Decode(&clientSelected); err != nil {
					return
				}



@@ 146,7 148,6 @@ func New(methods ...Method) xmpp.StreamFeature {
				return
			}

			d := session.Decoder()
			tok, err := d.Token()
			if err != nil {
				return mask, nil, err

M features.go => features.go +7 -6
@@ 1,6 1,6 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp



@@ 9,6 9,7 @@ import (
	"encoding/xml"
	"io"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/stream"
)


@@ 47,7 48,7 @@ type StreamFeature struct {
	// Returns whether or not the feature is required, and any data that will be
	// needed if the feature is selected for negotiation (eg. the list of
	// mechanisms if the feature was SASL).
	Parse func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (req bool, data interface{}, err error)
	Parse func(ctx context.Context, r xmlstream.TokenReader, start *xml.StartElement) (req bool, data interface{}, err error)

	// A function that will take over the session temporarily while negotiating
	// the feature. The "mask" SessionState represents the state bits that should


@@ 83,7 84,7 @@ func (s *Session) negotiateFeatures(ctx context.Context) (done bool, rw io.ReadW

	if !server {
		// Read a new startstream:features token.
		t, err = s.Decoder().Token()
		t, err = s.TokenReader().Token()
		if err != nil {
			return done, nil, err
		}


@@ 115,7 116,7 @@ func (s *Session) negotiateFeatures(ctx context.Context) (done bool, rw io.ReadW

		if server {
			// Read a new feature to negotiate.
			t, err = s.Decoder().Token()
			t, err = s.TokenReader().Token()
			if err != nil {
				return done, nil, err
			}


@@ 289,7 290,7 @@ parsefeatures:
				continue parsefeatures
			}
			// If the feature is not one we support, skip it.
			if err := s.in.d.Skip(); err != nil {
			if err := xmlstream.Skip(s.in.d); err != nil {
				return nil, err
			}
		case xml.EndElement:

M ibr2/challenge.go => ibr2/challenge.go +5 -3
@@ 1,12 1,14 @@
// Copyright 2017 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package ibr2

import (
	"context"
	"encoding/xml"

	"mellium.im/xmlstream"
)

// Challenge is an IBR challenge.


@@ 23,5 25,5 @@ type Challenge struct {

	// Receive is used by the client to receive and decode the server's challenge
	// and by the server to receive and decode the clients response.
	Receive func(ctx context.Context, server bool, d *xml.Decoder, start *xml.StartElement) error
	Receive func(ctx context.Context, server bool, r xmlstream.TokenReader, start *xml.StartElement) error
}

M ibr2/ibr2.go => ibr2/ibr2.go +13 -11
@@ 1,6 1,6 @@
// Copyright 2017 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

// Package ibr2 implements the Extensible In-Band Registration ProtoXEP.
//


@@ 15,6 15,7 @@ import (
	"errors"
	"io"

	"mellium.im/xmlstream"
	"mellium.im/xmpp"
	"mellium.im/xmpp/stream"
)


@@ 77,12 78,13 @@ func listFunc(challenges ...Challenge) func(context.Context, *xml.Encoder, xml.S
	}
}

func parseFunc(challenges ...Challenge) func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (req bool, supported interface{}, err error) {
	return func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) {
func parseFunc(challenges ...Challenge) func(context.Context, xmlstream.TokenReader, *xml.StartElement) (req bool, supported interface{}, err error) {
	return func(ctx context.Context, r xmlstream.TokenReader, start *xml.StartElement) (bool, interface{}, error) {
		// Parse the list of challenge types sent down by the server.
		parsed := struct {
			Challenges []string `xml:"urn:xmpp:register:0 challenge"`
		}{}
		d := xml.NewTokenDecoder(r)
		err := d.DecodeElement(&parsed, start)
		if err != nil {
			return false, false, err


@@ 105,9 107,9 @@ func parseFunc(challenges ...Challenge) func(ctx context.Context, d *xml.Decoder
	}
}

func decodeClientResp(ctx context.Context, d *xml.Decoder, decode func(ctx context.Context, server bool, d *xml.Decoder, start *xml.StartElement) error) (cancel bool, err error) {
func decodeClientResp(ctx context.Context, r xmlstream.TokenReader, decode func(ctx context.Context, server bool, r xmlstream.TokenReader, start *xml.StartElement) error) (cancel bool, err error) {
	var tok xml.Token
	tok, err = d.Token()
	tok, err = r.Token()
	if err != nil {
		return
	}


@@ 120,7 122,7 @@ func decodeClientResp(ctx context.Context, d *xml.Decoder, decode func(ctx conte
		cancel = true
		return
	case start.Name.Local == "response" && start.Name.Space == NS:
		err = decode(ctx, true, d, &start)
		err = decode(ctx, true, r, &start)
		if err != nil {
			return
		}


@@ 143,7 145,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,

		var tok xml.Token
		e := session.Encoder()
		d := session.Decoder()
		r := session.TokenReader()

		if server {
			for _, c := range challenges {


@@ 168,7 170,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,

				// Decode the clients response
				var cancel bool
				cancel, err = decodeClientResp(ctx, d, c.Receive)
				cancel, err = decodeClientResp(ctx, r, c.Receive)
				if err != nil || cancel {
					return
				}


@@ 177,7 179,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
		}

		// If we're the client, decode the challenge.
		tok, err = d.Token()
		tok, err = r.Token()
		if err != nil {
			return
		}


@@ 208,7 210,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
				continue
			}

			err = c.Receive(ctx, false, d, &start)
			err = c.Receive(ctx, false, r, &start)
			if err != nil {
				return
			}

M ibr2/oob.go => ibr2/oob.go +5 -4
@@ 1,6 1,6 @@
// Copyright 2017 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package ibr2



@@ 8,6 8,7 @@ import (
	"context"
	"encoding/xml"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/oob"
)



@@ 22,14 23,14 @@ func OOB(data *oob.Data, f func(*oob.Data) error) Challenge {
		Send: func(ctx context.Context, e *xml.Encoder) error {
			return e.Encode(data)
		},
		Receive: func(ctx context.Context, server bool, d *xml.Decoder, start *xml.StartElement) error {
		Receive: func(ctx context.Context, server bool, r xmlstream.TokenReader, start *xml.StartElement) error {
			// The server does not receive a reply for this mechanism.
			if server {
				return nil
			}

			oob := &oob.Data{}
			err := d.Decode(oob)
			err := xml.NewTokenDecoder(r).Decode(oob)
			if err != nil {
				return err
			}

M sasl.go => sasl.go +6 -5
@@ 1,6 1,6 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp



@@ 13,6 13,7 @@ import (
	"io"

	"mellium.im/sasl"
	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/internal/saslerr"
	"mellium.im/xmpp/stream"


@@ 58,12 59,12 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
			}
			return req, e.EncodeToken(start.End())
		},
		Parse: func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) {
		Parse: func(ctx context.Context, r xmlstream.TokenReader, start *xml.StartElement) (bool, interface{}, error) {
			parsed := struct {
				XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
				List    []string `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanism"`
			}{}
			err := d.DecodeElement(&parsed, start)
			err := xml.NewTokenDecoder(r).DecodeElement(&parsed, start)
			return true, parsed.List, err
		},
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {


@@ 122,7 123,7 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
				return mask, nil, err
			}

			d := session.Decoder()
			d := xml.NewTokenDecoder(session.TokenReader())

			// If we're already done after the first step, decode the <success/> or
			// <failure/> before we exit.

M sasl2/sasl.go => sasl2/sasl.go +12 -10
@@ 1,6 1,6 @@
// Copyright 2017 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

// Package sasl2 is an experimental implementation of XEP-0388: Extensible SASL
// Profile.


@@ 19,6 19,7 @@ import (
	"io"

	"mellium.im/sasl"
	"mellium.im/xmlstream"
	"mellium.im/xmpp"
	"mellium.im/xmpp/internal/saslerr"
	"mellium.im/xmpp/stream"


@@ 74,12 75,12 @@ func SASL(mechanisms ...sasl.Mechanism) xmpp.StreamFeature {
			}
			return req, e.EncodeToken(start.End())
		},
		Parse: func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) {
		Parse: func(ctx context.Context, r xmlstream.TokenReader, start *xml.StartElement) (bool, interface{}, error) {
			parsed := struct {
				XMLName xml.Name `xml:"urn:xmpp:sasl:0 mechanisms"`
				List    []string `xml:"urn:xmpp:sasl:0 mechanism"`
			}{}
			err := d.DecodeElement(&parsed, start)
			err := xml.NewTokenDecoder(r).DecodeElement(&parsed, start)
			return true, parsed.List, err
		},
		Negotiate: func(ctx context.Context, session *xmpp.Session, data interface{}) (mask xmpp.SessionState, rw io.ReadWriter, err error) {


@@ 142,19 143,19 @@ func SASL(mechanisms ...sasl.Mechanism) xmpp.StreamFeature {
				return mask, nil, err
			}

			d := session.Decoder()
			r := session.TokenReader()

			// If we're already done after the first step, decode the <success/> or
			// <failure/> before we exit.
			if !more {
				tok, err := d.Token()
				tok, err := r.Token()
				if err != nil {
					return mask, nil, err
				}
				if t, ok := tok.(xml.StartElement); ok {
					// TODO: Handle the additional data that could be returned if
					// success?
					_, _, err := decodeSASLChallenge(d, t, false)
					_, _, err := decodeSASLChallenge(r, t, false)
					if err != nil {
						return mask, nil, err
					}


@@ 170,13 171,13 @@ func SASL(mechanisms ...sasl.Mechanism) xmpp.StreamFeature {
					return mask, nil, ctx.Err()
				default:
				}
				tok, err := d.Token()
				tok, err := r.Token()
				if err != nil {
					return mask, nil, err
				}
				var challenge []byte
				if t, ok := tok.(xml.StartElement); ok {
					challenge, success, err = decodeSASLChallenge(d, t, true)
					challenge, success, err = decodeSASLChallenge(r, t, true)
					if err != nil {
						return mask, nil, err
					}


@@ 201,7 202,8 @@ func SASL(mechanisms ...sasl.Mechanism) xmpp.StreamFeature {
	}
}

func decodeSASLChallenge(d *xml.Decoder, start xml.StartElement, allowChallenge bool) (challenge []byte, success bool, err error) {
func decodeSASLChallenge(r xmlstream.TokenReader, start xml.StartElement, allowChallenge bool) (challenge []byte, success bool, err error) {
	d := xml.NewTokenDecoder(r)
	switch start.Name {
	case xml.Name{Space: NS, Local: "challenge"}:
		if !allowChallenge {

M session.go => session.go +9 -7
@@ 1,6 1,6 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp



@@ 11,6 11,7 @@ import (
	"net"
	"sync"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/stanza"


@@ 76,7 77,7 @@ type Session struct {
	in struct {
		sync.Mutex
		streamInfo
		d      *xml.Decoder
		d      xmlstream.TokenReader
		ctx    context.Context
		cancel context.CancelFunc
	}


@@ 139,8 140,9 @@ func (s *Session) Conn() io.ReadWriter {
	return s.rw
}

// Decoder returns the XML decoder that was used to negotiate the latest stream.
func (s *Session) Decoder() *xml.Decoder {
// TokenReader returns the XML token reader that was used to negotiate the
// latest stream.
func (s *Session) TokenReader() xmlstream.TokenReader {
	return s.in.d
}



@@ 208,7 210,7 @@ func (s *Session) handleInputStream(handler Handler) error {
			return nil
		default:
		}
		tok, err := s.Decoder().Token()
		tok, err := s.TokenReader().Token()
		if err != nil {
			select {
			case <-s.in.ctx.Done():


@@ 223,7 225,7 @@ func (s *Session) handleInputStream(handler Handler) error {
		case xml.StartElement:
			if t.Name.Local == "error" && t.Name.Space == ns.Stream {
				e := stream.Error{}
				err = s.Decoder().DecodeElement(&e, &t)
				err = xml.NewTokenDecoder(s.TokenReader()).DecodeElement(&e, &t)
				if err != nil {
					return err
				}

M starttls.go => starttls.go +6 -5
@@ 1,6 1,6 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp



@@ 13,6 13,7 @@ import (
	"io"
	"net"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/stream"
)


@@ 46,14 47,14 @@ func StartTLS(required bool) StreamFeature {
			}
			return required, e.EncodeToken(start.End())
		},
		Parse: func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) {
		Parse: func(ctx context.Context, r xmlstream.TokenReader, start *xml.StartElement) (bool, interface{}, error) {
			parsed := struct {
				XMLName  xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
				Required struct {
					XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls required"`
				}
			}{}
			err := d.DecodeElement(&parsed, start)
			err := xml.NewTokenDecoder(r).DecodeElement(&parsed, start)
			return parsed.Required.XMLName.Local == "required" && parsed.Required.XMLName.Space == ns.StartTLS, nil, err
		},
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {


@@ 64,7 65,7 @@ func StartTLS(required bool) StreamFeature {

			config := session.Config()
			state := session.State()
			d := session.Decoder()
			d := xml.NewTokenDecoder(session.TokenReader())

			// Fetch or create a TLSConfig to use.
			var tlsconf *tls.Config

M stream.go => stream.go +1 -1
@@ 129,7 129,7 @@ func expectNewStream(ctx context.Context, s *Session) error {
			switch {
			case tok.Name.Local == "error" && tok.Name.Space == ns.Stream:
				se := stream.Error{}
				if err := d.DecodeElement(&se, &tok); err != nil {
				if err := xml.NewTokenDecoder(d).DecodeElement(&se, &tok); err != nil {
					return err
				}
				return se