~samwhited/xmpp

c4172a3d65f08e835fb2280440148b47b270dbb9 — Sam Whited 3 years ago 6223470
all: make *Session an xml.TokenReader

This is less efficient than returning one when wraping it in an
xml.Decoder because the underlying decoder cannot be used, but it makes
the API nicer to use.
8 files changed, 18 insertions(+), 22 deletions(-)

M bind.go
M compress/compression.go
M features.go
M ibr2/ibr2.go
M sasl.go
M sasl2/sasl.go
M session.go
M starttls.go
M bind.go => bind.go +1 -1
@@ 79,7 79,7 @@ func bind(server func(*jid.JID, string) (*jid.JID, error)) StreamFeature {
		},
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			e := session.Encoder()
			d := xml.NewTokenDecoder(session.TokenReader())
			d := xml.NewTokenDecoder(session)

			// Handle the server side of resource binding if we're on the receiving
			// end of the connection.

M compress/compression.go => compress/compression.go +1 -1
@@ 88,7 88,7 @@ func New(methods ...Method) xmpp.StreamFeature {
		},
		Negotiate: func(ctx context.Context, session *xmpp.Session, data interface{}) (mask xmpp.SessionState, rw io.ReadWriter, err error) {
			conn := session.Conn()
			d := xml.NewTokenDecoder(session.TokenReader())
			d := xml.NewTokenDecoder(session)

			// If we're a server.
			if (session.State() & xmpp.Received) == xmpp.Received {

M features.go => features.go +2 -2
@@ 84,7 84,7 @@ func negotiateFeatures(ctx context.Context, s *Session) (mask SessionState, rw i

	if !server {
		// Read a new startstream:features token.
		t, err = s.TokenReader().Token()
		t, err = s.Token()
		if err != nil {
			return mask, nil, err
		}


@@ 115,7 115,7 @@ func negotiateFeatures(ctx context.Context, s *Session) (mask SessionState, rw i

		if server {
			// Read a new feature to negotiate.
			t, err = s.TokenReader().Token()
			t, err = s.Token()
			if err != nil {
				return mask, nil, err
			}

M ibr2/ibr2.go => ibr2/ibr2.go +3 -4
@@ 145,7 145,6 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,

		var tok xml.Token
		e := session.Encoder()
		r := session.TokenReader()

		if server {
			for _, c := range challenges {


@@ 170,7 169,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,

				// Decode the clients response
				var cancel bool
				cancel, err = decodeClientResp(ctx, r, c.Receive)
				cancel, err = decodeClientResp(ctx, session, c.Receive)
				if err != nil || cancel {
					return
				}


@@ 179,7 178,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
		}

		// If we're the client, decode the challenge.
		tok, err = r.Token()
		tok, err = session.Token()
		if err != nil {
			return
		}


@@ 210,7 209,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
				continue
			}

			err = c.Receive(ctx, false, r, &start)
			err = c.Receive(ctx, false, session, &start)
			if err != nil {
				return
			}

M sasl.go => sasl.go +1 -1
@@ 123,7 123,7 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
				return mask, nil, err
			}

			d := xml.NewTokenDecoder(session.TokenReader())
			d := xml.NewTokenDecoder(session)

			// If we're already done after the first step, decode the <success/> or
			// <failure/> before we exit.

M sasl2/sasl.go => sasl2/sasl.go +4 -6
@@ 143,19 143,17 @@ func SASL(mechanisms ...sasl.Mechanism) xmpp.StreamFeature {
				return mask, nil, err
			}

			r := session.TokenReader()

			// If we're already done after the first step, decode the <success/> or
			// <failure/> before we exit.
			if !more {
				tok, err := r.Token()
				tok, err := session.Token()
				if err != nil {
					return mask, nil, err
				}
				if t, ok := tok.(xml.StartElement); ok {
					// TODO: Handle the additional data that could be returned if
					// success?
					_, _, err := decodeSASLChallenge(r, t, false)
					_, _, err := decodeSASLChallenge(session, t, false)
					if err != nil {
						return mask, nil, err
					}


@@ 171,13 169,13 @@ func SASL(mechanisms ...sasl.Mechanism) xmpp.StreamFeature {
					return mask, nil, ctx.Err()
				default:
				}
				tok, err := r.Token()
				tok, err := session.Token()
				if err != nil {
					return mask, nil, err
				}
				var challenge []byte
				if t, ok := tok.(xml.StartElement); ok {
					challenge, success, err = decodeSASLChallenge(r, t, true)
					challenge, success, err = decodeSASLChallenge(session, t, true)
					if err != nil {
						return mask, nil, err
					}

M session.go => session.go +5 -6
@@ 190,10 190,9 @@ func (s *Session) Conn() io.ReadWriter {
	return s.rw
}

// TokenReader returns the XML token reader that was used to negotiate the
// latest stream.
func (s *Session) TokenReader() xmlstream.TokenReader {
	return s.in.d
// Token satisfies the xml.TokenReader interface for Session.
func (s *Session) Token() (xml.Token, error) {
	return s.in.d.Token()
}

// Encoder returns the XML encoder that was used to negotiate the latest stream.


@@ 260,7 259,7 @@ func (s *Session) handleInputStream(handler Handler) error {
			return nil
		default:
		}
		tok, err := s.TokenReader().Token()
		tok, err := s.Token()
		if err != nil {
			select {
			case <-s.in.ctx.Done():


@@ 275,7 274,7 @@ func (s *Session) handleInputStream(handler Handler) error {
		case xml.StartElement:
			if t.Name.Local == "error" && t.Name.Space == ns.Stream {
				e := stream.Error{}
				err = xml.NewTokenDecoder(s.TokenReader()).DecodeElement(&e, &t)
				err = xml.NewTokenDecoder(s).DecodeElement(&e, &t)
				if err != nil {
					return err
				}

M starttls.go => starttls.go +1 -1
@@ 65,7 65,7 @@ func StartTLS(required bool) StreamFeature {

			config := session.Config()
			state := session.State()
			d := xml.NewTokenDecoder(session.TokenReader())
			d := xml.NewTokenDecoder(session)

			// Fetch or create a TLSConfig to use.
			var tlsconf *tls.Config