~samwhited/xmpp

74518cc1fa0621ade0b31aac25634add39bd6995 — Sam Whited 2 years ago 5ae11bf
all: add new API to make writing tokens safe
7 files changed, 74 insertions(+), 28 deletions(-)

M CHANGELOG.md
M bind.go
M doc.go
M features.go
M ibr2/ibr2.go
M session.go
M session_test.go
M CHANGELOG.md => CHANGELOG.md +1 -0
@@ 10,6 10,7 @@ All notable changes to this project will be documented in this file.
- dial: use underlying net.Dialer's DNS Resolver in Dialer.
- stanza: change API of `WrapIQ` and `WrapPresence` to not abuse pointers
- xmpp: add new `SendIQ` API and remove response from `Send` and `SendElement`
- xmpp: new API for writing custom tokens to a session

### Fixed


M bind.go => bind.go +4 -2
@@ 101,6 101,8 @@ func bind(server func(jid.JID, string) (jid.JID, error)) StreamFeature {
		},
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			d := xml.NewTokenDecoder(session)
			w := session.TokenWriter()
			defer w.Close()

			// Handle the server side of resource binding if we're on the receiving
			// end of the connection.


@@ 153,7 155,7 @@ func bind(server func(jid.JID, string) (jid.JID, error)) StreamFeature {
					resp.Bind = bindPayload{JID: j}
				}

				_, err = resp.WriteXML(session)
				_, err = resp.WriteXML(w)
				if err != nil {
					return mask, nil, err
				}


@@ 171,7 173,7 @@ func bind(server func(jid.JID, string) (jid.JID, error)) StreamFeature {
					Resource: session.origin.Resourcepart(),
				},
			}
			_, err = req.WriteXML(session)
			_, err = req.WriteXML(w)
			if err != nil {
				return mask, nil, err
			}

M doc.go => doc.go +9 -6
@@ 76,16 76,18 @@
// To receive XML on the input stream, Session implements the xml.TokenReader
// interface defined in encoding/xml; this allows session to be wrapped with
// xml.NewTokenDecoder.
// To send XML on the output stream, Session has an EncodeToken method like the
// "mellium.im/xmlstream".TokenWriter interface.
// To send XML on the output stream, Session has a TokenEncoder method that
// returns a token encoder that holds a lock on the output stream until it is
// closed.
// The session may also buffer writes and has a Flush method which will write
// any buffered XML to the underlying connection.
//
// However, writing individual XML tokens can be tedious and error prone.
// The mellium.im/xmpp/stanza package contains functions and structs that aid in
// the construction of message, presence and info/query (IQ) elements which have
// Writing individual XML tokens can be tedious and error prone.
// The stanza package contains functions and structs that aid in the
// construction of message, presence and info/query (IQ) elements which have
// special semantics in XMPP and are known as "stanzas".
// These can be sent with the Send, SendElement, and SendIQ methods.
// These can be sent with the Send, SendElement, SendIQ, and SendIQElement
// methods.
//
//     // Send initial presence to let the server know we want to receive messages.
//     _, err = session.Send(context.TODO(), stanza.WrapPresence(jid.JID{}, stanza.AvailablePresence, nil))


@@ 119,6 121,7 @@
//         }
//     }))
//
//
// Be Advised
//
// This API is unstable and subject to change.

M features.go => features.go +4 -2
@@ 252,7 252,9 @@ func getFeature(name xml.Name, features []StreamFeature) (feature StreamFeature,

func writeStreamFeatures(ctx context.Context, s *Session, features []StreamFeature) (list *streamFeaturesList, err error) {
	start := xml.StartElement{Name: xml.Name{Space: "", Local: "stream:features"}}
	if err = s.EncodeToken(start); err != nil {
	w := s.TokenWriter()
	defer w.Close()
	if err = w.EncodeToken(start); err != nil {
		return
	}



@@ 283,7 285,7 @@ func writeStreamFeatures(ctx context.Context, s *Session, features []StreamFeatu
			list.total++
		}
	}
	if err = s.EncodeToken(start.End()); err != nil {
	if err = w.EncodeToken(start.End()); err != nil {
		return list, err
	}
	if err = s.Flush(); err != nil {

M ibr2/ibr2.go => ibr2/ibr2.go +8 -6
@@ 133,6 133,8 @@ func decodeClientResp(ctx context.Context, r xml.TokenReader, decode func(ctx co
func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session, interface{}) (xmpp.SessionState, io.ReadWriter, error) {
	return func(ctx context.Context, session *xmpp.Session, supported interface{}) (mask xmpp.SessionState, rw io.ReadWriter, err error) {
		server := (session.State() & xmpp.Received) == xmpp.Received
		w := session.TokenWriter()
		defer w.Close()

		if !server && !supported.(bool) {
			// We don't support some of the challenge types advertised by the server.


@@ 147,15 149,15 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
			for _, c := range challenges {
				// Send the challenge.
				start := challengeStart(c.Type)
				err = session.EncodeToken(start)
				err = w.EncodeToken(start)
				if err != nil {
					return
				}
				err = c.Send(ctx, session)
				err = c.Send(ctx, w)
				if err != nil {
					return
				}
				err = session.EncodeToken(start.End())
				err = w.EncodeToken(start.End())
				if err != nil {
					return
				}


@@ 214,16 216,16 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
			respStart := xml.StartElement{
				Name: xml.Name{Local: "response"},
			}
			if err = session.EncodeToken(respStart); err != nil {
			if err = w.EncodeToken(respStart); err != nil {
				return
			}
			if c.Respond != nil {
				err = c.Respond(ctx, session)
				err = c.Respond(ctx, w)
				if err != nil {
					return
				}
			}
			if err = session.EncodeToken(respStart.End()); err != nil {
			if err = w.EncodeToken(respStart.End()); err != nil {
				return
			}


M session.go => session.go +45 -11
@@ 254,7 254,7 @@ func (s *Session) sendError(err error) (e error) {

	switch typErr := err.(type) {
	case stream.Error:
		if _, e = typErr.WriteXML(s); e != nil {
		if _, e = typErr.WriteXML(s.out.e); e != nil {
			return e
		}
		if e = s.closeSession(); e != nil {


@@ 268,7 268,7 @@ func (s *Session) sendError(err error) (e error) {
	//     The error condition is not one of those defined by the other
	//     conditions in this list; this error condition SHOULD NOT be used
	//     except in conjunction with an application-specific condition.
	if _, e = stream.UndefinedCondition.WriteXML(s); e != nil {
	if _, e = stream.UndefinedCondition.WriteXML(s.out.e); e != nil {
		return e
	}
	return err


@@ 407,7 407,7 @@ func handleInputStream(s *Session, handler Handler) (err error) {

		rw := &responseChecker{
			TokenReader: xmlstream.Inner(s),
			TokenWriter: s,
			TokenWriter: s.out.e,
			id:          id,
		}
		if err = handler.HandleXMPP(rw, &start); err != nil {


@@ 416,7 416,7 @@ func handleInputStream(s *Session, handler Handler) (err error) {

		// If the user did not write a response to an IQ, send a default one.
		if needsResp && !rw.wroteResp {
			_, err := xmlstream.Copy(s, stanza.WrapIQ(stanza.IQ{
			_, err := xmlstream.Copy(s.out.e, stanza.WrapIQ(stanza.IQ{
				ID:   id,
				Type: stanza.ErrorIQ,
			}, stanza.Error{


@@ 493,12 493,46 @@ func (s *Session) Token() (xml.Token, error) {
	return s.in.d.Token()
}

// EncodeToken satisfies the xmlstream.TokenWriter interface.
func (s *Session) EncodeToken(t xml.Token) error {
	if s.state&OutputStreamClosed == OutputStreamClosed {
type lockWriteCloser struct {
	w   *Session
	err error
	m   *sync.Mutex
}

func (lwc *lockWriteCloser) EncodeToken(t xml.Token) error {
	if lwc.err != nil {
		return lwc.err
	}

	if lwc.w.state&OutputStreamClosed == OutputStreamClosed {
		return ErrOutputStreamClosed
	}
	return s.out.e.EncodeToken(t)

	return lwc.w.out.e.EncodeToken(t)
}

func (lwc *lockWriteCloser) Close() error {
	if lwc.err != nil {
		return nil
	}
	lwc.err = io.EOF
	lwc.m.Unlock()
	return nil
}

// TokenWriter returns a new xmlstream.TokenWriteCloser that can be used to
// write raw XML tokens to the session.
// All other writes and future calls to TokenWriter will block until the Close
// method is called.
// After the TokenWriteCloser has been closed, any future writes will return
// io.EOF.
func (s *Session) TokenWriter() xmlstream.TokenWriteCloser {
	s.out.Lock()

	return &lockWriteCloser{
		m: &s.out.Mutex,
		w: s,
	}
}

// Flush satisfies the xmlstream.TokenWriter interface.


@@ 598,15 632,15 @@ func (s *Session) SendElement(ctx context.Context, r xml.TokenReader, start xml.
		}
	}

	err := s.EncodeToken(start)
	err := s.out.e.EncodeToken(start)
	if err != nil {
		return err
	}
	_, err = xmlstream.Copy(s, xmlstream.Inner(r))
	_, err = xmlstream.Copy(s.out.e, xmlstream.Inner(r))
	if err != nil {
		return err
	}
	err = s.EncodeToken(start.End())
	err = s.out.e.EncodeToken(start.End())
	if err != nil {
		return err
	}

M session_test.go => session_test.go +3 -1
@@ 48,8 48,10 @@ func TestClosedOutputStream(t *testing.T) {
			mask := xmpp.SessionState(i)
			buf := new(bytes.Buffer)
			s := xmpptest.NewSession(mask, buf)
			w := s.TokenWriter()
			defer w.Close()

			switch err := s.EncodeToken(xml.CharData("chartoken")); {
			switch err := w.EncodeToken(xml.CharData("chartoken")); {
			case mask&xmpp.OutputStreamClosed == xmpp.OutputStreamClosed && err != xmpp.ErrOutputStreamClosed:
				t.Errorf("Unexpected error: want=`%v', got=`%v'", xmpp.ErrOutputStreamClosed, err)
			case mask&xmpp.OutputStreamClosed == 0 && err != nil: