~samwhited/xmpp

3aee0cc5a4163a287c4c1cfc7795199e884561ea — Sam Whited 2 years ago 5da0271
Revert "all: new session XML read/write API"

This reverts commit 5265955b9c79963ea0c1b9f392db0e56b6a242a8.
M bind.go => bind.go +3 -9
@@ 103,13 103,7 @@ func bind(server func(jid.JID, string) (jid.JID, error)) StreamFeature {
			return true, nil, xml.NewTokenDecoder(r).DecodeElement(&parsed, start)
		},
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			rc := session.TokenReader()
			/* #nosec */
			defer rc.Close()
			d := xml.NewTokenDecoder(rc)
			wc := session.TokenWriter()
			/* #nosec */
			defer wc.Close()
			d := xml.NewTokenDecoder(session)

			// Handle the server side of resource binding if we're on the receiving
			// end of the connection.


@@ 162,7 156,7 @@ func bind(server func(jid.JID, string) (jid.JID, error)) StreamFeature {
					resp.Bind = bindPayload{JID: j}
				}

				_, err = resp.WriteXML(wc)
				_, err = resp.WriteXML(session)
				return mask, nil, err
			}



@@ 177,7 171,7 @@ func bind(server func(jid.JID, string) (jid.JID, error)) StreamFeature {
					Resource: session.origin.Resourcepart(),
				},
			}
			_, err = req.WriteXML(wc)
			_, err = req.WriteXML(session)
			if err != nil {
				return mask, nil, err
			}

M compress/compression.go => compress/compression.go +1 -4
@@ 85,10 85,7 @@ func New(methods ...Method) xmpp.StreamFeature {
		},
		Negotiate: func(ctx context.Context, session *xmpp.Session, data interface{}) (mask xmpp.SessionState, rw io.ReadWriter, err error) {
			conn := session.Conn()
			rc := session.TokenReader()
			/* #nosec */
			defer rc.Close()
			d := xml.NewTokenDecoder(rc)
			d := xml.NewTokenDecoder(session)

			// If we're a server.
			if (session.State() & xmpp.Received) == xmpp.Received {

M doc.go => doc.go +6 -7
@@ 73,18 73,17 @@
// required or may be received out of order.
// This is accomplished with two XML streams: an input stream and an output
// stream.
// To receive XML on the input stream, Session has a TokenReader function which
// returns a value that can be wrapped with xml.NewTokenDecoder.
// To send XML on the output stream, Session has the TokenWriter method.
// To receive XML on the input stream, Session implements the xml.TokenReader
// interface defined in encoding/xml; this allows session to be wrapped with
// xml.NewTokenDecoder.
// To send XML on the output stream, Session has an EncodeToken and Flush
// method like the "mellium.im/xmlstream".TokenWriter interface.
// The mellium.im/xmpp/stanza package contains functions and structs that aid in
// the construction of message, presence and IQ elements which have special
// semantics in XMPP and are known as "stanzas".
//
//     // Send initial presence to let the server know we want to receive messages.
//     wc := session.TokenWriter()
//     defer wc.Close()
//
//     _, err = xmlstream.Copy(wc, stanza.WrapPresence(nil, stanza.AvailablePresence, nil))
//     _, err = xmlstream.Copy(session, stanza.WrapPresence(nil, stanza.AvailablePresence, nil))
//
// To make the common case of polling for incoming XML on the input stream—and
// possibly writing to the output stream in response—easier, Session includes

M echobot_example_test.go => echobot_example_test.go +2 -4
@@ 49,13 49,11 @@ func Example_echobot() {
	}()

	// Send initial presence to let the server know we want to receive messages.
	wc := s.TokenWriter()
	_, err = xmlstream.Copy(wc, stanza.WrapPresence(nil, stanza.AvailablePresence, nil))
	_, err = xmlstream.Copy(s, stanza.WrapPresence(nil, stanza.AvailablePresence, nil))
	if err != nil {
		log.Printf("Error sending initial presence: %q", err)
		return
	}
	wc.Close()

	s.Serve(xmpp.HandlerFunc(func(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
		d := xml.NewTokenDecoder(t)


@@ 88,7 86,7 @@ func Example_echobot() {
				return xml.CharData(msg.Body), io.EOF
			}), xml.StartElement{Name: xml.Name{Local: "body"}}),
		)
		_, err = xmlstream.Copy(t, reply)
		_, err = xmlstream.Copy(s, reply)
		if err != nil {
			log.Printf("Error responding to mid-%s: %q", msg.ID, err)
		}

M features.go => features.go +4 -4
@@ 103,7 103,7 @@ func negotiateFeatures(ctx context.Context, s *Session, first bool, features []S
	var doStartTLS bool
	if !server {
		// Read a new start stream:features token.
		t, err = s.in.d.Token()
		t, err = s.Token()
		if err != nil {
			return mask, nil, err
		}


@@ 152,7 152,7 @@ func negotiateFeatures(ctx context.Context, s *Session, first bool, features []S

		if server {
			// Read a new feature to negotiate.
			t, err = s.in.d.Token()
			t, err = s.Token()
			if err != nil {
				return mask, nil, err
			}


@@ 252,7 252,7 @@ func getFeature(name xml.Name, features []StreamFeature) (feature StreamFeature,

func writeStreamFeatures(ctx context.Context, s *Session, features []StreamFeature) (list *streamFeaturesList, err error) {
	start := xml.StartElement{Name: xml.Name{Space: "", Local: "stream:features"}}
	if err = s.out.e.EncodeToken(start); err != nil {
	if err = s.EncodeToken(start); err != nil {
		return
	}



@@ 283,7 283,7 @@ func writeStreamFeatures(ctx context.Context, s *Session, features []StreamFeatu
			list.total++
		}
	}
	if err = s.out.e.EncodeToken(start.End()); err != nil {
	if err = s.EncodeToken(start.End()); err != nil {
		return
	}
	if err = s.Flush(); err != nil {

M go.mod => go.mod +1 -1
@@ 6,5 6,5 @@ require (
	golang.org/x/net v0.0.0-20180216171745-136a25c244d3
	golang.org/x/text v0.0.0-20180208041248-4e4a3210bb54
	mellium.im/sasl v0.1.1
	mellium.im/xmlstream v0.12.3
	mellium.im/xmlstream v0.12.1
)

M go.sum => go.sum +0 -2
@@ 20,5 20,3 @@ mellium.im/xmlstream v0.12.0 h1:0GPPdFc9L8QAA/sV5xw0JqpVrfzWq9MkMaYExqU5ET0=
mellium.im/xmlstream v0.12.0/go.mod h1:EadHCpZVaEeBmMJ276+Rw8hEBZWcKjv/y0njL1k5zCU=
mellium.im/xmlstream v0.12.1 h1:sa2cmyWsV62j/iNN4nAMFrRNRNXil2GivcjgRxpn+n8=
mellium.im/xmlstream v0.12.1/go.mod h1:PDwtcQeiAyF7FrUBRWnVO91nIskvhOFYzOeftngiM7Q=
mellium.im/xmlstream v0.12.3 h1:gm7mpQCUPeyi2WcfxMN2n6PoFNdqkDFvK3pYNITYjtc=
mellium.im/xmlstream v0.12.3/go.mod h1:PDwtcQeiAyF7FrUBRWnVO91nIskvhOFYzOeftngiM7Q=

M ibr2/ibr2.go => ibr2/ibr2.go +9 -15
@@ 135,12 135,6 @@ func decodeClientResp(ctx context.Context, r xml.TokenReader, decode func(ctx co
func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session, interface{}) (xmpp.SessionState, io.ReadWriter, error) {
	return func(ctx context.Context, session *xmpp.Session, supported interface{}) (mask xmpp.SessionState, rw io.ReadWriter, err error) {
		server := (session.State() & xmpp.Received) == xmpp.Received
		wc := session.TokenWriter()
		/* #nosec */
		defer wc.Close()
		rc := session.TokenReader()
		/* #nosec */
		defer rc.Close()

		if !server && !supported.(bool) {
			// We don't support some of the challenge types advertised by the server.


@@ 155,15 149,15 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
			for _, c := range challenges {
				// Send the challenge.
				start := challengeStart(c.Type)
				err = wc.EncodeToken(start)
				err = session.EncodeToken(start)
				if err != nil {
					return
				}
				err = c.Send(ctx, wc)
				err = c.Send(ctx, session)
				if err != nil {
					return
				}
				err = wc.EncodeToken(start.End())
				err = session.EncodeToken(start.End())
				if err != nil {
					return
				}


@@ 174,7 168,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,

				// Decode the clients response
				var cancel bool
				cancel, err = decodeClientResp(ctx, rc, c.Receive)
				cancel, err = decodeClientResp(ctx, session, c.Receive)
				if err != nil || cancel {
					return
				}


@@ 183,7 177,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
		}

		// If we're the client, decode the challenge.
		tok, err = rc.Token()
		tok, err = session.Token()
		if err != nil {
			return
		}


@@ 214,7 208,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
				continue
			}

			err = c.Receive(ctx, false, rc, &start)
			err = c.Receive(ctx, false, session, &start)
			if err != nil {
				return
			}


@@ 222,16 216,16 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
			respStart := xml.StartElement{
				Name: xml.Name{Local: "response"},
			}
			if err = wc.EncodeToken(respStart); err != nil {
			if err = session.EncodeToken(respStart); err != nil {
				return
			}
			if c.Respond != nil {
				err = c.Respond(ctx, wc)
				err = c.Respond(ctx, session)
				if err != nil {
					return
				}
			}
			if err = wc.EncodeToken(respStart.End()); err != nil {
			if err = session.EncodeToken(respStart.End()); err != nil {
				return
			}


M sasl.go => sasl.go +1 -4
@@ 130,10 130,7 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) StreamFeature
				return mask, nil, err
			}

			rc := session.TokenReader()
			/* #nosec */
			defer rc.Close()
			d := xml.NewTokenDecoder(rc)
			d := xml.NewTokenDecoder(session)

			// If we're already done after the first step, decode the <success/> or
			// <failure/> before we exit.

M sasl2/sasl.go => sasl2/sasl.go +4 -7
@@ 94,9 94,6 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) xmpp.StreamFe
			}

			conn := session.Conn()
			rc := session.TokenReader()
			/* #nosec */
			defer rc.Close()

			// Select a mechanism, preferring the client order.
			var selected sasl.Mechanism


@@ 156,14 153,14 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) xmpp.StreamFe
			// If we're already done after the first step, decode the <success/> or
			// <failure/> before we exit.
			if !more {
				tok, err := rc.Token()
				tok, err := session.Token()
				if err != nil {
					return mask, nil, err
				}
				if t, ok := tok.(xml.StartElement); ok {
					// TODO: Handle the additional data that could be returned if
					// success?
					_, _, err := decodeSASLChallenge(rc, t, false)
					_, _, err := decodeSASLChallenge(session, t, false)
					if err != nil {
						return mask, nil, err
					}


@@ 179,13 176,13 @@ func SASL(identity, password string, mechanisms ...sasl.Mechanism) xmpp.StreamFe
					return mask, nil, ctx.Err()
				default:
				}
				tok, err := rc.Token()
				tok, err := session.Token()
				if err != nil {
					return mask, nil, err
				}
				var challenge []byte
				if t, ok := tok.(xml.StartElement); ok {
					challenge, success, err = decodeSASLChallenge(rc, t, true)
					challenge, success, err = decodeSASLChallenge(session, t, true)
					if err != nil {
						return mask, nil, err
					}

M session.go => session.go +72 -144
@@ 83,7 83,6 @@ type Session struct {
		cancel context.CancelFunc
	}
	out struct {
		sync.Mutex
		internal.StreamInfo
		e xmlstream.TokenWriter
	}


@@ 104,36 103,6 @@ type Session struct {
// (encoders, decoders, etc.) will be reset.
type Negotiator func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, cache interface{}, err error)

type stateCheckReader struct {
	s *Session
}

func (r stateCheckReader) Read(p []byte) (int, error) {
	r.s.slock.RLock()
	defer r.s.slock.RUnlock()

	if r.s.state&InputStreamClosed == InputStreamClosed {
		return 0, ErrInputStreamClosed
	}

	return r.s.conn.Read(p)
}

type stateCheckWriter struct {
	s *Session
}

func (w stateCheckWriter) Write(p []byte) (int, error) {
	w.s.slock.RLock()
	defer w.s.slock.RUnlock()

	if w.s.state&OutputStreamClosed == OutputStreamClosed {
		return 0, ErrOutputStreamClosed
	}

	return w.s.conn.Write(p)
}

// NegotiateSession creates an XMPP session using a custom negotiate function.
// Calling NegotiateSession with a nil Negotiator panics.
//


@@ 149,8 118,8 @@ func NegotiateSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
		features:   make(map[string]interface{}),
		negotiated: make(map[string]struct{}),
	}
	s.in.d = xml.NewDecoder(stateCheckReader{s: s})
	s.out.e = xml.NewEncoder(stateCheckWriter{s: s})
	s.in.d = xml.NewDecoder(s.conn)
	s.out.e = xml.NewEncoder(s.conn)
	s.in.ctx, s.in.cancel = context.WithCancel(context.Background())

	// If rw was already a *tls.Conn, go ahead and mark the connection as secure


@@ 177,8 146,8 @@ func NegotiateSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
				delete(s.negotiated, k)
			}
			s.conn = newConn(rw, s.conn)
			s.in.d = xml.NewDecoder(stateCheckReader{s: s})
			s.out.e = xml.NewEncoder(stateCheckWriter{s: s})
			s.in.d = xml.NewDecoder(s.conn)
			s.out.e = xml.NewEncoder(s.conn)
		}
		s.state |= mask
	}


@@ 284,6 253,9 @@ func NewServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
// If the input stream is closed, Serve returns.
// Serve does not close the output stream.
func (s *Session) Serve(h Handler) error {
	s.in.Lock()
	defer s.in.Unlock()

	return s.handleInputStream(h)
}



@@ 292,13 264,9 @@ func (s *Session) Serve(h Handler) error {
// If an error is returned (the original error or a different one), it has not
// been handled fully and must be handled by the caller.
func (s *Session) sendError(err error) (e error) {
	wc := s.TokenWriter()
	/* #nosec */
	defer wc.Close()

	switch typErr := err.(type) {
	case stream.Error:
		if _, e = typErr.WriteXML(wc); e != nil {
		if _, e = typErr.WriteXML(s); e != nil {
			return e
		}
		if e = s.Close(); e != nil {


@@ 312,7 280,7 @@ func (s *Session) sendError(err error) (e error) {
	//     The error condition is not one of those defined by the other
	//     conditions in this list; this error condition SHOULD NOT be used
	//     except in conjunction with an application-specific condition.
	if _, e = stream.UndefinedCondition.WriteXML(wc); e != nil {
	if _, e = stream.UndefinedCondition.WriteXML(s); e != nil {
		return e
	}
	return err


@@ 334,78 302,60 @@ func (s *Session) handleInputStream(handler Handler) (err error) {
			return s.in.ctx.Err()
		default:
		}
		tok, err := s.Token()
		// TODO: If this is a network issue we should return it, if not we should
		// handle it.
		if err != nil {
			return s.sendError(err)
		}

		err = func() error {
			rc := s.TokenReader()
			/* #nosec */
			defer rc.Close()

			// TODO: should we change handle to not pass a writer? Why lock this if we
			// don't know it's necessary? Maybe pass the session and the locked reader
			// instead so they can aquire the write lock only if necessary.
			wc := s.TokenWriter()
			/* #nosec */
			defer wc.Close()

			tok, err := rc.Token()
			// TODO: If this is a network issue we should return it, if not we should
			// handle it.
			if err != nil {
				return s.sendError(err)
			}

			var start xml.StartElement
			switch t := tok.(type) {
			case xml.StartElement:
				start = t
			case xml.EndElement:
				if t.Name.Space == ns.Stream && t.Name.Local == "stream" {
					return nil
				}
				// If this is a stream level end element but not </stream:stream>,
				// something is really weird…
				return s.sendError(stream.BadFormat)
			default:
				// If this isn't a start element, the stream is in a bad state.
				return s.sendError(stream.BadFormat)
		var start xml.StartElement
		switch t := tok.(type) {
		case xml.StartElement:
			start = t
		case xml.EndElement:
			if t.Name.Space == ns.Stream && t.Name.Local == "stream" {
				return nil
			}
			// If this is a stream level end element but not </stream:stream>,
			// something is really weird…
			return s.sendError(stream.BadFormat)
		default:
			// If this isn't a start element, the stream is in a bad state.
			return s.sendError(stream.BadFormat)
		}

			rw := struct {
				xml.TokenReader
				xmlstream.TokenWriter
			}{
				TokenReader: xmlstream.Inner(rc),
				TokenWriter: wc,
			}
		rw := struct {
			xml.TokenReader
			xmlstream.TokenWriter
		}{
			TokenReader: xmlstream.Inner(s),
			TokenWriter: s,
		}

			// Handle stream errors and unknown stream namespaced tokens first, before
			// delegating to the normal handler.
			if start.Name.Space == ns.Stream {
				switch start.Name.Local {
				case "error":
					// TODO: Unmarshal the error and return it.
					return nil
				default:
					return s.sendError(stream.UnsupportedStanzaType)
				}
		// Handle stream errors and unknown stream namespaced tokens first, before
		// delegating to the normal handler.
		if start.Name.Space == ns.Stream {
			switch start.Name.Local {
			case "error":
				// TODO: Unmarshal the error and return it.
				return nil
			default:
				return s.sendError(stream.UnsupportedStanzaType)
			}
		}

			if err = handler.HandleXMPP(rw, &start); err != nil {
				return s.sendError(err)
			}
			// Advance to the end of the current element before attempting to read the
			// next.
			//
			// TODO: Error handling should be the same here as it would be for the
			// rest of this loop.
			_, err = xmlstream.Copy(discard, rw)
			if err != nil {
				return s.sendError(err)
			}
			return nil
		}()
		if err = handler.HandleXMPP(rw, &start); err != nil {
			return s.sendError(err)
		}
		// Advance to the end of the current element before attempting to read the
		// next.
		//
		// TODO: Error handling should be the same here as it would be for the rest
		// of this loop.
		_, err = xmlstream.Copy(discard, rw)
		if err != nil {
			return err
			return s.sendError(err)
		}
	}
}


@@ 428,50 378,26 @@ func (s *Session) Conn() net.Conn {
	return s.conn
}

type readCloser struct {
	xml.TokenReader
	c func() error
}

func (c readCloser) Close() error {
	return c.c()
}

type writeCloser struct {
	xmlstream.TokenWriter
	c func() error
}

func (c writeCloser) Close() error {
	return c.c()
}

// TokenReader locks the underlying XML input stream and returns a value that
// has exclusive read access to it until Close is called.
func (s *Session) TokenReader() xmlstream.TokenReadCloser {
	s.in.Lock()
// Token satisfies the xml.TokenReader interface for Session.
func (s *Session) Token() (xml.Token, error) {
	s.slock.RLock()
	defer s.slock.RUnlock()

	return readCloser{
		TokenReader: s.in.d,
		c: func() error {
			s.in.Unlock()
			return nil
		},
	if s.state&InputStreamClosed == InputStreamClosed {
		return nil, ErrInputStreamClosed
	}
	return s.in.d.Token()
}

// TokenWriter locks the underlying XML output stream and returns a value that
// has exclusive write access to it until Close is called.
func (s *Session) TokenWriter() xmlstream.TokenWriteCloser {
	s.out.Lock()
// EncodeToken satisfies the xmlstream.TokenWriter interface.
func (s *Session) EncodeToken(t xml.Token) error {
	s.slock.RLock()
	defer s.slock.RUnlock()

	return writeCloser{
		TokenWriter: s.out.e,
		c: func() error {
			s.out.Unlock()
			return nil
		},
	if s.state&OutputStreamClosed == OutputStreamClosed {
		return ErrOutputStreamClosed
	}
	return s.out.e.EncodeToken(t)
}

// Flush satisfies the xmlstream.TokenWriter interface.


@@ 506,6 432,8 @@ func (s *Session) Close() error {
// State returns the current state of the session. For more information, see the
// SessionState type.
func (s *Session) State() SessionState {
	s.slock.RLock()
	defer s.slock.RUnlock()
	return s.state
}


M session_test.go => session_test.go +5 -4
@@ 27,9 27,8 @@ func TestClosedInputStream(t *testing.T) {
			mask := xmpp.SessionState(i)
			buf := new(bytes.Buffer)
			s := xmpptest.NewSession(mask, buf)
			rc := s.TokenReader()

			_, err := rc.Token()
			_, err := s.Token()
			switch {
			case mask&xmpp.InputStreamClosed == xmpp.InputStreamClosed && err != xmpp.ErrInputStreamClosed:
				t.Errorf("Unexpected error: want=`%v', got=`%v'", xmpp.ErrInputStreamClosed, err)


@@ 46,9 45,11 @@ func TestClosedOutputStream(t *testing.T) {
			mask := xmpp.SessionState(i)
			buf := new(bytes.Buffer)
			s := xmpptest.NewSession(mask, buf)
			wc := s.TokenWriter()

			if err := wc.EncodeToken(xml.CharData("chartoken")); err != nil {
			switch err := s.EncodeToken(xml.CharData("chartoken")); {
			case mask&xmpp.OutputStreamClosed == xmpp.OutputStreamClosed && err != xmpp.ErrOutputStreamClosed:
				t.Errorf("Unexpected error: want=`%v', got=`%v'", xmpp.ErrOutputStreamClosed, err)
			case mask&xmpp.OutputStreamClosed == 0 && err != nil:
				t.Errorf("Unexpected error: `%v'", err)
			}
			switch err := s.Flush(); {

M starttls.go => starttls.go +1 -4
@@ 53,10 53,7 @@ func StartTLS(required bool, cfg *tls.Config) StreamFeature {
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			conn := session.Conn()
			state := session.State()
			rc := session.TokenReader()
			/* #nosec */
			defer rc.Close()
			d := xml.NewTokenDecoder(rc)
			d := xml.NewTokenDecoder(session)

			// If no TLSConfig was specified, use a default config.
			if cfg == nil {