~samwhited/xmpp

19734c657a478b3f4d0657eaea3c4476924a3f7a — Sam Whited 2 years ago 7e76def
all: add new API to make reading tokens safer
M bind.go => bind.go +3 -1
@@ 100,7 100,9 @@ func bind(server func(jid.JID, string) (jid.JID, error)) StreamFeature {
			return true, nil, xml.NewTokenDecoder(r).DecodeElement(&parsed, start)
		},
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			d := xml.NewTokenDecoder(session)
			r := session.TokenReader()
			defer r.Close()
			d := xml.NewTokenDecoder(r)
			w := session.TokenWriter()
			defer w.Close()


M compress/compression.go => compress/compression.go +3 -1
@@ 82,7 82,9 @@ func New(methods ...Method) xmpp.StreamFeature {
		},
		Negotiate: func(ctx context.Context, session *xmpp.Session, data interface{}) (mask xmpp.SessionState, rw io.ReadWriter, err error) {
			conn := session.Conn()
			d := xml.NewTokenDecoder(session)
			r := session.TokenReader()
			defer r.Close()
			d := xml.NewTokenDecoder(r)

			// If we're a server.
			if (session.State() & xmpp.Received) == xmpp.Received {

M echobot_example_test.go => echobot_example_test.go +3 -1
@@ 56,7 56,9 @@ func Example_echobot() {
	}

	s.Serve(xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
		d := xml.NewTokenDecoder(s)
		r := s.TokenReader()
		defer r.Close()
		d := xml.NewTokenDecoder(r)

		// Ignore anything that's not a message. In a real system we'd want to at
		// least respond to IQs.

M features.go => features.go +2 -2
@@ 103,7 103,7 @@ func negotiateFeatures(ctx context.Context, s *Session, first bool, features []S
	var doStartTLS bool
	if !server {
		// Read a new start stream:features token.
		t, err = s.Token()
		t, err = s.in.d.Token()
		if err != nil {
			return mask, nil, err
		}


@@ 152,7 152,7 @@ func negotiateFeatures(ctx context.Context, s *Session, first bool, features []S

		if server {
			// Read a new feature to negotiate.
			t, err = s.Token()
			t, err = s.in.d.Token()
			if err != nil {
				return mask, nil, err
			}

M ibr2/ibr2.go => ibr2/ibr2.go +5 -3
@@ 135,6 135,8 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
		server := (session.State() & xmpp.Received) == xmpp.Received
		w := session.TokenWriter()
		defer w.Close()
		r := session.TokenReader()
		defer r.Close()

		if !server && !supported.(bool) {
			// We don't support some of the challenge types advertised by the server.


@@ 168,7 170,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,

				// Decode the clients response
				var cancel bool
				cancel, err = decodeClientResp(ctx, session, c.Receive)
				cancel, err = decodeClientResp(ctx, r, c.Receive)
				if err != nil || cancel {
					return
				}


@@ 177,7 179,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
		}

		// If we're the client, decode the challenge.
		tok, err = session.Token()
		tok, err = r.Token()
		if err != nil {
			return
		}


@@ 208,7 210,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
				continue
			}

			err = c.Receive(ctx, false, session, &start)
			err = c.Receive(ctx, false, r, &start)
			if err != nil {
				return
			}

M mux/mux_test.go => mux/mux_test.go +3 -1
@@ 97,7 97,9 @@ func TestFallback(t *testing.T) {
	}
	s := xmpptest.NewSession(0, rw)

	tok, err := s.Token()
	r := s.TokenReader()
	defer r.Close()
	tok, err := r.Token()
	if err != nil {
		t.Fatalf("Bad start token read: `%v'", err)
	}

M sasl.go => sasl.go +3 -1
@@ 135,7 135,9 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) StreamFeature
				return mask, nil, err
			}

			d := xml.NewTokenDecoder(session)
			r := session.TokenReader()
			defer r.Close()
			d := xml.NewTokenDecoder(r)

			// If we're already done after the first step, decode the <success/> or
			// <failure/> before we exit.

M sasl2/sasl.go => sasl2/sasl.go +7 -4
@@ 155,17 155,20 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) xmpp.StreamFe
				return mask, nil, err
			}

			r := session.TokenReader()
			defer r.Close()

			// If we're already done after the first step, decode the <success/> or
			// <failure/> before we exit.
			if !more {
				tok, err := session.Token()
				tok, err := r.Token()
				if err != nil {
					return mask, nil, err
				}
				if t, ok := tok.(xml.StartElement); ok {
					// TODO: Handle the additional data that could be returned if
					// success?
					_, _, err := decodeSASLChallenge(session, t, false)
					_, _, err := decodeSASLChallenge(r, t, false)
					if err != nil {
						return mask, nil, err
					}


@@ 181,13 184,13 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) xmpp.StreamFe
					return mask, nil, ctx.Err()
				default:
				}
				tok, err := session.Token()
				tok, err := r.Token()
				if err != nil {
					return mask, nil, err
				}
				var challenge []byte
				if t, ok := tok.(xml.StartElement); ok {
					challenge, success, err = decodeSASLChallenge(session, t, true)
					challenge, success, err = decodeSASLChallenge(r, t, true)
					if err != nil {
						return mask, nil, err
					}

M session.go => session.go +48 -12
@@ 91,6 91,7 @@ type Session struct {
		d      xml.TokenReader
		ctx    context.Context
		cancel context.CancelFunc
		sync.Locker
	}
	out struct {
		internal.StreamInfo


@@ 134,6 135,7 @@ func NegotiateSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
		s.state |= Received
	}
	s.out.Locker = &sync.Mutex{}
	s.in.Locker = &sync.Mutex{}
	s.in.d = xml.NewDecoder(s.conn)
	s.out.e = xml.NewEncoder(s.conn)
	s.in.ctx, s.in.cancel = context.WithCancel(context.Background())


@@ 317,7 319,7 @@ func handleInputStream(s *Session, handler Handler) (err error) {
			return s.in.ctx.Err()
		default:
		}
		tok, err := s.Token()
		tok, err := s.in.d.Token()
		if err != nil {
			// If this was a read timeout, don't try to send it. Just try to read
			// again.


@@ 390,12 392,12 @@ func handleInputStream(s *Session, handler Handler) (err error) {
				}

				c <- iqResponder{
					r: xmlstream.MultiReader(xmlstream.Token(start), xmlstream.Inner(s), xmlstream.Token(start.End())),
					r: xmlstream.MultiReader(xmlstream.Token(start), xmlstream.Inner(s.in.d), xmlstream.Token(start.End())),
					c: c,
				}
				<-c
				// Consume the rest of the stream before continuing the loop.
				_, err = xmlstream.Copy(discard, s)
				_, err = xmlstream.Copy(discard, s.in.d)
				if err != nil {
					return s.sendError(err)
				}


@@ 409,7 411,7 @@ func handleInputStream(s *Session, handler Handler) (err error) {

		rw := &responseChecker{
			twf:         s.out.e,
			TokenReader: xmlstream.Inner(s),
			TokenReader: xmlstream.Inner(s.in.d),
			id:          id,
		}
		// Make a copy of the session and set its output stream to the response


@@ 498,14 500,6 @@ func (s *Session) Conn() net.Conn {
	return s.conn
}

// Token satisfies the xml.TokenReader interface for Session.
func (s *Session) Token() (xml.Token, error) {
	if s.state&InputStreamClosed == InputStreamClosed {
		return nil, ErrInputStreamClosed
	}
	return s.in.d.Token()
}

type lockWriteCloser struct {
	w   *Session
	err error


@@ 533,6 527,33 @@ func (lwc *lockWriteCloser) Close() error {
	return nil
}

type lockReadCloser struct {
	s   *Session
	err error
	m   sync.Locker
}

func (lrc *lockReadCloser) Token() (xml.Token, error) {
	if lrc.err != nil {
		return nil, lrc.err
	}

	if lrc.s.state&InputStreamClosed == InputStreamClosed {
		return nil, ErrInputStreamClosed
	}

	return lrc.s.in.d.Token()
}

func (lrc *lockReadCloser) Close() error {
	if lrc.err != nil {
		return nil
	}
	lrc.err = io.EOF
	lrc.m.Unlock()
	return nil
}

// TokenWriter returns a new xmlstream.TokenWriteCloser that can be used to
// write raw XML tokens to the session.
// All other writes and future calls to TokenWriter will block until the Close


@@ 548,6 569,21 @@ func (s *Session) TokenWriter() xmlstream.TokenWriteCloser {
	}
}

// TokenReader returns a new xmlstream.TokenReadCloser that can be used to read
// raw XML tokens from the session.
// All other reads and future calls to TokenReader will block until the Close
// method is called.
// After the TokenReadCloser has been closed, any future reads will return
// io.EOF.
func (s *Session) TokenReader() xmlstream.TokenReadCloser {
	s.in.Lock()

	return &lockReadCloser{
		m: s.in.Locker,
		s: s,
	}
}

// Flush satisfies the xmlstream.TokenWriter interface.
func (s *Session) Flush() error {
	if s.state&OutputStreamClosed == OutputStreamClosed {

M session_test.go => session_test.go +3 -1
@@ 30,8 30,10 @@ func TestClosedInputStream(t *testing.T) {
			mask := xmpp.SessionState(i)
			buf := new(bytes.Buffer)
			s := xmpptest.NewSession(mask, buf)
			r := s.TokenReader()
			defer r.Close()

			_, err := s.Token()
			_, err := r.Token()
			switch {
			case mask&xmpp.InputStreamClosed == xmpp.InputStreamClosed && err != xmpp.ErrInputStreamClosed:
				t.Errorf("Unexpected error: want=`%v', got=`%v'", xmpp.ErrInputStreamClosed, err)

M starttls.go => starttls.go +3 -1
@@ 53,7 53,9 @@ func StartTLS(required bool, cfg *tls.Config) StreamFeature {
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			conn := session.Conn()
			state := session.State()
			d := xml.NewTokenDecoder(session)
			r := session.TokenReader()
			defer r.Close()
			d := xml.NewTokenDecoder(r)

			// If no TLSConfig was specified, use a default config.
			if cfg == nil {

M starttls_test.go => starttls_test.go +12 -11
@@ 2,7 2,7 @@
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp
package xmpp_test

import (
	"bytes"


@@ 14,7 14,9 @@ import (
	"strings"
	"testing"

	"mellium.im/xmpp"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/internal/xmpptest"
)

// There is no room for variation on the starttls feature negotiation, so step


@@ 26,7 28,7 @@ func TestStartTLSList(t *testing.T) {
			name = "required"
		}
		t.Run(name, func(t *testing.T) {
			stls := StartTLS(req, nil)
			stls := xmpp.StartTLS(req, nil)
			var b bytes.Buffer
			e := xml.NewEncoder(&b)
			start := xml.StartElement{Name: xml.Name{Space: ns.StartTLS, Local: "starttls"}}


@@ 86,7 88,7 @@ func TestStartTLSList(t *testing.T) {
}

func TestStartTLSParse(t *testing.T) {
	stls := StartTLS(true, nil)
	stls := xmpp.StartTLS(true, nil)
	for i, test := range [...]struct {
		msg string
		req bool


@@ 126,9 128,9 @@ func (nopRWC) Close() error {
}

func TestNegotiateServer(t *testing.T) {
	stls := StartTLS(true, &tls.Config{})
	stls := xmpp.StartTLS(true, &tls.Config{})
	var b bytes.Buffer
	c := &Session{state: Received, conn: newConn(nopRWC{&b, &b}, nil)}
	c := xmpptest.NewSession(xmpp.Received, nopRWC{&b, &b})
	_, rw, err := stls.Negotiate(context.Background(), c, nil)
	switch {
	case err != nil:


@@ 152,10 154,10 @@ func TestNegotiateClient(t *testing.T) {
		responses []string
		err       bool
		rw        bool
		state     SessionState
		state     xmpp.SessionState
	}{
		0: {[]string{`<proceed xmlns="badns"/>`}, true, false, Secure},
		1: {[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, true, Secure},
		0: {[]string{`<proceed xmlns="badns"/>`}, true, false, xmpp.Secure},
		1: {[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, true, xmpp.Secure},
		2: {[]string{`<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, false, 0},
		3: {[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, false, false, 0},
		4: {[]string{`<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'></bad>`}, true, false, 0},


@@ 164,11 166,10 @@ func TestNegotiateClient(t *testing.T) {
		7: {[]string{`chardata not start element`}, true, false, 0},
	} {
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			stls := StartTLS(true, &tls.Config{})
			stls := xmpp.StartTLS(true, &tls.Config{})
			r := strings.NewReader(strings.Join(test.responses, "\n"))
			var b bytes.Buffer
			c := &Session{conn: newConn(nopRWC{r, &b}, nil)}
			c.in.d = xml.NewDecoder(c.conn)
			c := xmpptest.NewSession(0, nopRWC{r, &b})
			mask, rw, err := stls.Negotiate(context.Background(), c, nil)
			switch {
			case test.err && err == nil: