~samwhited/xmpp

11eadb30f1eff7c52470d82608334b7ce9b5139c — Sam Whited 1 year, 9 months ago d0cd96c
Revert "all: pass Session directly to handler"

This reverts commit 7e76defde884af7cd63b3d6bb7065694db076b8f.
M echobot_example_test.go => echobot_example_test.go +2 -4
@@ 55,10 55,8 @@ func Example_echobot() {
		return
	}

	s.Serve(xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
		r := s.TokenReader()
		defer r.Close()
		d := xml.NewTokenDecoder(r)
	s.Serve(xmpp.HandlerFunc(func(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
		d := xml.NewTokenDecoder(t)

		// Ignore anything that's not a message. In a real system we'd want to at
		// least respond to IQs.

M examples/echobot/echo.go => examples/echobot/echo.go +3 -5
@@ 63,10 63,8 @@ func echo(addr, pass string, xmlIn, xmlOut io.Writer, logger, debug *log.Logger)
		return fmt.Errorf("Error sending initial presence: %w", err)
	}

	return s.Serve(xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
		r := s.TokenReader()
		defer r.Close()
		d := xml.NewTokenDecoder(r)
	return s.Serve(xmpp.HandlerFunc(func(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
		d := xml.NewTokenDecoder(t)

		// Ignore anything that's not a message. In a real system we'd want to at
		// least respond to IQs.


@@ 98,7 96,7 @@ func echo(addr, pass string, xmlIn, xmlOut io.Writer, logger, debug *log.Logger)
			}), xml.StartElement{Name: xml.Name{Local: "body"}}),
		)
		debug.Printf("Replying to message %q from %s with body %q", msg.ID, msg.From.Bare(), msg.Body)
		err = s.Send(context.TODO(), reply)
		_, err = xmlstream.Copy(t, reply)
		if err != nil {
			logger.Printf("Error responding to message %q: %q", msg.ID, err)
		}

M handler.go => handler.go +6 -4
@@ 6,20 6,22 @@ package xmpp

import (
	"encoding/xml"

	"mellium.im/xmlstream"
)

// A Handler triggers events or responds to incoming elements in an XML stream.
type Handler interface {
	HandleXMPP(s *Session, start *xml.StartElement) error
	HandleXMPP(t xmlstream.TokenReadWriter, start *xml.StartElement) error
}

// The HandlerFunc type is an adapter to allow the use of ordinary functions as
// XMPP handlers.
// If f is a function with the appropriate signature, HandlerFunc(f) is a
// Handler that calls f.
type HandlerFunc func(s *Session, start *xml.StartElement) error
type HandlerFunc func(t xmlstream.TokenReadWriter, start *xml.StartElement) error

// HandleXMPP calls f(t, start).
func (f HandlerFunc) HandleXMPP(s *Session, start *xml.StartElement) error {
	return f(s, start)
func (f HandlerFunc) HandleXMPP(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
	return f(t, start)
}

A handler_test.go => handler_test.go +35 -0
@@ 0,0 1,35 @@
package xmpp_test

import (
	"encoding/xml"
	"errors"
	"testing"

	"mellium.im/xmlstream"
	"mellium.im/xmpp"
)

var errHandlerFuncSentinal = errors.New("handler test")

type sentinalReadWriter struct{}

func (sentinalReadWriter) Token() (xml.Token, error)   { return nil, nil }
func (sentinalReadWriter) EncodeToken(xml.Token) error { return nil }

func TestHandlerFunc(t *testing.T) {
	s := &xml.StartElement{}
	var f xmpp.HandlerFunc = func(r xmlstream.TokenReadWriter, start *xml.StartElement) error {
		if _, ok := r.(sentinalReadWriter); !ok {
			t.Errorf("HandleXMPP did not pass reader to HandlerFunc")
		}
		if start != s {
			t.Errorf("HandleXMPP did not pass start token to HandlerFunc")
		}
		return errHandlerFuncSentinal
	}

	err := f.HandleXMPP(sentinalReadWriter{}, s)
	if err != errHandlerFuncSentinal {
		t.Errorf("HandleXMPP did not return handlerfunc error, got %q", err)
	}
}

M mux/mux.go => mux/mux.go +7 -12
@@ 32,7 32,7 @@ type ServeMux struct {
	patterns map[xml.Name]xmpp.Handler
}

func fallback(s *xmpp.Session, start *xml.StartElement) error {
func fallback(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
	if start.Name.Local != "iq" {
		return nil
	}


@@ 84,13 84,8 @@ func fallback(s *xmpp.Session, start *xml.StartElement) error {
		Type:      stanza.Cancel,
		Condition: stanza.FeatureNotImplemented,
	}
	w := s.TokenWriter()
	defer w.Close()
	_, err := xmlstream.Copy(w, xmlstream.Wrap(e.TokenReader(), *start))
	if err != nil {
		return err
	}
	return w.Flush()
	_, err := xmlstream.Copy(t, xmlstream.Wrap(e.TokenReader(), *start))
	return err
}

// New allocates and returns a new ServeMux.


@@ 118,23 113,23 @@ func (m *ServeMux) Handler(name xml.Name) (h xmpp.Handler, ok bool) {

// HandleXMPP dispatches the request to the handler whose pattern most closely
// matches start.Name.
func (m *ServeMux) HandleXMPP(s *xmpp.Session, start *xml.StartElement) error {
func (m *ServeMux) HandleXMPP(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
	h, ok := m.Handler(start.Name)
	if ok {
		return h.HandleXMPP(s, start)
		return h.HandleXMPP(t, start)
	}

	n := start.Name
	n.Space = ""
	h, ok = m.Handler(n)
	if ok {
		return h.HandleXMPP(s, start)
		return h.HandleXMPP(t, start)
	}

	n = start.Name
	n.Local = ""
	h, _ = m.Handler(n)
	return h.HandleXMPP(s, start)
	return h.HandleXMPP(t, start)
}

// Option configures a ServeMux.

M mux/mux_test.go => mux/mux_test.go +16 -4
@@ 13,6 13,7 @@ import (
	"strings"
	"testing"

	"mellium.im/xmlstream"
	"mellium.im/xmpp"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/internal/xmpptest"


@@ 21,11 22,11 @@ import (

var passTest = errors.New("mux_test: PASSED")

var passHandler xmpp.HandlerFunc = func(*xmpp.Session, *xml.StartElement) error {
var passHandler xmpp.HandlerFunc = func(xmlstream.TokenReadWriter, *xml.StartElement) error {
	return passTest
}

var failHandler xmpp.HandlerFunc = func(*xmpp.Session, *xml.StartElement) error {
var failHandler xmpp.HandlerFunc = func(xmlstream.TokenReadWriter, *xml.StartElement) error {
	return errors.New("mux_test: FAILED")
}



@@ 78,7 79,7 @@ var testCases = [...]struct {
func TestMux(t *testing.T) {
	for i, tc := range testCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			err := tc.m.HandleXMPP(&xmpp.Session{}, &xml.StartElement{Name: tc.p})
			err := tc.m.HandleXMPP(nil, &xml.StartElement{Name: tc.p})
			if err != passTest {
				t.Fatalf("unexpected error: `%v'", err)
			}


@@ 104,10 105,21 @@ func TestFallback(t *testing.T) {
		t.Fatalf("Bad start token read: `%v'", err)
	}
	start := tok.(xml.StartElement)
	err = mux.New().HandleXMPP(s, &start)
	w := s.TokenWriter()
	defer w.Close()
	err = mux.New().HandleXMPP(struct {
		xml.TokenReader
		xmlstream.TokenWriter
	}{
		TokenReader: r,
		TokenWriter: w,
	}, &start)
	if err != nil {
		t.Errorf("Unexpected error: `%v'", err)
	}
	if err := w.Flush(); err != nil {
		t.Errorf("Unexpected error flushing token writer: %q", err)
	}

	const expected = `<iq to="juliet@example.com" from="romeo@example.com" id="123" type="error"><error type="cancel"><feature-not-implemented xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"></feature-not-implemented></error></iq>`
	if buf.String() != expected {

M session.go => session.go +36 -41
@@ 280,7 280,7 @@ func (s *Session) sendError(err error) (e error) {

type nopHandler struct{}

func (nopHandler) HandleXMPP(_ *Session, _ *xml.StartElement) error {
func (nopHandler) HandleXMPP(_ xmlstream.TokenReadWriter, _ *xml.StartElement) error {
	return nil
}



@@ 409,45 409,44 @@ func handleInputStream(s *Session, handler Handler) (err error) {

	noreply:

		rw := &responseChecker{
			twf:         s.out.e,
			TokenReader: xmlstream.Inner(s.in.d),
			id:          id,
		}
		// Make a copy of the session and set its output stream to the response
		// checker. This means that HandleXMPP will see the state bits as they were
		// when it was first called and will not recieve updates, and that we don't
		// have to take a lock to ensure that nothing else reads or writes through
		// the responseChecker. Instead, the lock will only be taken if something
		// tries to read/write XML from inside the handler.
		ss := *s
		ss.out.e = rw
		ss.in.d = rw
		if err = handler.HandleXMPP(&ss, &start); err != nil {
			return s.sendError(err)
		}
		err = func() error {
			r := s.TokenReader()
			w := s.TokenWriter()
			defer r.Close()
			defer w.Close()

		// If the user did not write a response to an IQ, send a default one.
		if needsResp && !rw.wroteResp {
			_, err := xmlstream.Copy(s.out.e, stanza.WrapIQ(stanza.IQ{
				ID:   id,
				Type: stanza.ErrorIQ,
			}, stanza.Error{
				Type:      stanza.Cancel,
				Condition: stanza.ServiceUnavailable,
			}.TokenReader()))
			if err != nil {
			rw := &responseChecker{
				TokenReader: xmlstream.Inner(r),
				TokenWriter: w,
				id:          id,
			}
			if err := handler.HandleXMPP(rw, &start); err != nil {
				return err
			}
		}

		if err := s.out.e.Flush(); err != nil {
			return err
		}
			// If the user did not write a response to an IQ, send a default one.
			if needsResp && !rw.wroteResp {
				_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
					ID:   id,
					Type: stanza.ErrorIQ,
				}, stanza.Error{
					Type:      stanza.Cancel,
					Condition: stanza.ServiceUnavailable,
				}.TokenReader()))
				if err != nil {
					return err
				}
			}

			if err := w.Flush(); err != nil {
				return err
			}

		// Advance to the end of the current element before attempting to read the
		// next.
		_, err = xmlstream.Copy(discard, rw)
			// Advance to the end of the current element before attempting to read the
			// next.
			_, err = xmlstream.Copy(discard, rw)
			return err
		}()
		if err != nil {
			return s.sendError(err)
		}


@@ 455,17 454,13 @@ func handleInputStream(s *Session, handler Handler) (err error) {
}

type responseChecker struct {
	twf tokenWriteFlusher
	xml.TokenReader
	xmlstream.TokenWriter
	id        string
	wroteResp bool
	level     int
}

func (rw *responseChecker) Flush() error {
	return rw.twf.Flush()
}

func (rw *responseChecker) EncodeToken(t xml.Token) error {
	switch tok := t.(type) {
	case xml.StartElement:


@@ 478,7 473,7 @@ func (rw *responseChecker) EncodeToken(t xml.Token) error {
		rw.level--
	}

	return rw.twf.EncodeToken(t)
	return rw.TokenWriter.EncodeToken(t)
}

// Feature checks if a feature with the given namespace was advertised

M session_test.go => session_test.go +12 -20
@@ 158,10 158,8 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	3: {
		handler: xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
			w := s.TokenWriter()
			defer w.Close()
			_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
				ID:   "1234",
				Type: stanza.ResultIQ,
			}, nil))


@@ 172,10 170,8 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	4: {
		handler: xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
			w := s.TokenWriter()
			defer w.Close()
			_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
				ID:   "wrongid",
				Type: stanza.ResultIQ,
			}, nil))


@@ 186,10 182,8 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	5: {
		handler: xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
			w := s.TokenWriter()
			defer w.Close()
			_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
				ID:   "1234",
				Type: stanza.ErrorIQ,
			}, nil))


@@ 200,10 194,8 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	6: {
		handler: xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
			w := s.TokenWriter()
			defer w.Close()
			_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
				ID:   "1234",
				Type: stanza.GetIQ,
			}, nil))


@@ 214,7 206,7 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	7: {
		handler: xmpp.HandlerFunc(func(_ *xmpp.Session, start *xml.StartElement) error {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			for _, attr := range start.Attr {
				if attr.Name.Local == "from" && attr.Value != "" {
					panic("expected attr to be normalized")


@@ 227,7 219,7 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	8: {
		handler: xmpp.HandlerFunc(func(_ *xmpp.Session, start *xml.StartElement) error {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			for _, attr := range start.Attr {
				if attr.Name.Local == "from" && attr.Value == "" {
					panic("expected attr not to be normalized")


@@ 240,7 232,7 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	9: {
		handler: xmpp.HandlerFunc(func(_ *xmpp.Session, start *xml.StartElement) error {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			for _, attr := range start.Attr {
				if attr.Name.Local == "from" && attr.Value == "" {
					panic("expected attr not to be normalized")


@@ 253,7 245,7 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	10: {
		handler: xmpp.HandlerFunc(func(_ *xmpp.Session, start *xml.StartElement) error {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			for _, attr := range start.Attr {
				if attr.Name.Local == "from" && attr.Value == "" {
					panic("expected attr not to be normalized")