~samwhited/xmpp

7e76defde884af7cd63b3d6bb7065694db076b8f — Sam Whited 2 years ago bf9593a
all: pass Session directly to handler
8 files changed, 71 insertions(+), 92 deletions(-)

M echobot_example_test.go
M examples/echobot/echo.go
M handler.go
D handler_test.go
M mux/mux.go
M mux/mux_test.go
M session.go
M session_test.go
M echobot_example_test.go => echobot_example_test.go +2 -2
@@ 55,8 55,8 @@ func Example_echobot() {
		return
	}

	s.Serve(xmpp.HandlerFunc(func(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
		d := xml.NewTokenDecoder(t)
	s.Serve(xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
		d := xml.NewTokenDecoder(s)

		// Ignore anything that's not a message. In a real system we'd want to at
		// least respond to IQs.

M examples/echobot/echo.go => examples/echobot/echo.go +2 -2
@@ 63,8 63,8 @@ func echo(addr, pass string, xmlIn, xmlOut io.Writer, logger, debug *log.Logger)
		return fmt.Errorf("Error sending initial presence: %w", err)
	}

	return s.Serve(xmpp.HandlerFunc(func(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
		d := xml.NewTokenDecoder(t)
	return s.Serve(xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
		d := xml.NewTokenDecoder(s)

		// Ignore anything that's not a message. In a real system we'd want to at
		// least respond to IQs.

M handler.go => handler.go +4 -6
@@ 6,22 6,20 @@ package xmpp

import (
	"encoding/xml"

	"mellium.im/xmlstream"
)

// A Handler triggers events or responds to incoming elements in an XML stream.
type Handler interface {
	HandleXMPP(t xmlstream.TokenReadWriter, start *xml.StartElement) error
	HandleXMPP(s *Session, start *xml.StartElement) error
}

// The HandlerFunc type is an adapter to allow the use of ordinary functions as
// XMPP handlers.
// If f is a function with the appropriate signature, HandlerFunc(f) is a
// Handler that calls f.
type HandlerFunc func(t xmlstream.TokenReadWriter, start *xml.StartElement) error
type HandlerFunc func(s *Session, start *xml.StartElement) error

// HandleXMPP calls f(t, start).
func (f HandlerFunc) HandleXMPP(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
	return f(t, start)
func (f HandlerFunc) HandleXMPP(s *Session, start *xml.StartElement) error {
	return f(s, start)
}

D handler_test.go => handler_test.go +0 -35
@@ 1,35 0,0 @@
package xmpp_test

import (
	"encoding/xml"
	"errors"
	"testing"

	"mellium.im/xmlstream"
	"mellium.im/xmpp"
)

var errHandlerFuncSentinal = errors.New("handler test")

type sentinalReadWriter struct{}

func (sentinalReadWriter) Token() (xml.Token, error)   { return nil, nil }
func (sentinalReadWriter) EncodeToken(xml.Token) error { return nil }

func TestHandlerFunc(t *testing.T) {
	s := &xml.StartElement{}
	var f xmpp.HandlerFunc = func(r xmlstream.TokenReadWriter, start *xml.StartElement) error {
		if _, ok := r.(sentinalReadWriter); !ok {
			t.Errorf("HandleXMPP did not pass reader to HandlerFunc")
		}
		if start != s {
			t.Errorf("HandleXMPP did not pass start token to HandlerFunc")
		}
		return errHandlerFuncSentinal
	}

	err := f.HandleXMPP(sentinalReadWriter{}, s)
	if err != errHandlerFuncSentinal {
		t.Errorf("HandleXMPP did not return handlerfunc error, got %q", err)
	}
}

M mux/mux.go => mux/mux.go +8 -6
@@ 32,7 32,7 @@ type ServeMux struct {
	patterns map[xml.Name]xmpp.Handler
}

func fallback(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
func fallback(s *xmpp.Session, start *xml.StartElement) error {
	if start.Name.Local != "iq" {
		return nil
	}


@@ 84,7 84,9 @@ func fallback(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
		Type:      stanza.Cancel,
		Condition: stanza.FeatureNotImplemented,
	}
	_, err := xmlstream.Copy(t, xmlstream.Wrap(e.TokenReader(), *start))
	w := s.TokenWriter()
	defer w.Close()
	_, err := xmlstream.Copy(w, xmlstream.Wrap(e.TokenReader(), *start))
	return err
}



@@ 113,23 115,23 @@ func (m *ServeMux) Handler(name xml.Name) (h xmpp.Handler, ok bool) {

// HandleXMPP dispatches the request to the handler whose pattern most closely
// matches start.Name.
func (m *ServeMux) HandleXMPP(t xmlstream.TokenReadWriter, start *xml.StartElement) error {
func (m *ServeMux) HandleXMPP(s *xmpp.Session, start *xml.StartElement) error {
	h, ok := m.Handler(start.Name)
	if ok {
		return h.HandleXMPP(t, start)
		return h.HandleXMPP(s, start)
	}

	n := start.Name
	n.Space = ""
	h, ok = m.Handler(n)
	if ok {
		return h.HandleXMPP(t, start)
		return h.HandleXMPP(s, start)
	}

	n = start.Name
	n.Local = ""
	h, _ = m.Handler(n)
	return h.HandleXMPP(t, start)
	return h.HandleXMPP(s, start)
}

// Option configures a ServeMux.

M mux/mux_test.go => mux/mux_test.go +14 -21
@@ 13,28 13,22 @@ import (
	"strings"
	"testing"

	"mellium.im/xmlstream"
	"mellium.im/xmpp"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/internal/xmpptest"
	"mellium.im/xmpp/mux"
)

var passTest = errors.New("mux_test: PASSED")

var passHandler xmpp.HandlerFunc = func(xmlstream.TokenReadWriter, *xml.StartElement) error {
var passHandler xmpp.HandlerFunc = func(*xmpp.Session, *xml.StartElement) error {
	return passTest
}

var failHandler xmpp.HandlerFunc = func(xmlstream.TokenReadWriter, *xml.StartElement) error {
var failHandler xmpp.HandlerFunc = func(*xmpp.Session, *xml.StartElement) error {
	return errors.New("mux_test: FAILED")
}

type nopRW struct{}

func (nopRW) EncodeToken(xml.Token) error { return nil }
func (nopRW) Flush() error                { return nil }
func (nopRW) Token() (xml.Token, error)   { return nil, io.EOF }

var testCases = [...]struct {
	m *mux.ServeMux
	p xml.Name


@@ 84,7 78,7 @@ var testCases = [...]struct {
func TestMux(t *testing.T) {
	for i, tc := range testCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			err := tc.m.HandleXMPP(nopRW{}, &xml.StartElement{Name: tc.p})
			err := tc.m.HandleXMPP(&xmpp.Session{}, &xml.StartElement{Name: tc.p})
			if err != passTest {
				t.Fatalf("unexpected error: `%v'", err)
			}


@@ 93,31 87,30 @@ func TestMux(t *testing.T) {
}

func TestFallback(t *testing.T) {
	d := xml.NewDecoder(strings.NewReader(`<iq to="romeo@example.com" from="juliet@example.com"><test/></iq>`))
	buf := new(bytes.Buffer)
	e := xml.NewEncoder(buf)
	buf := &bytes.Buffer{}
	rw := struct {
		xml.TokenReader
		xmlstream.TokenWriter
		io.Reader
		io.Writer
	}{
		TokenReader: d,
		TokenWriter: e,
		Reader: strings.NewReader(`<iq to="romeo@example.com" from="juliet@example.com" id="123"><test/></iq>`),
		Writer: buf,
	}
	s := xmpptest.NewSession(0, rw)

	tok, err := rw.Token()
	tok, err := s.Token()
	if err != nil {
		t.Fatalf("Bad start token read: `%v'", err)
	}
	start := tok.(xml.StartElement)
	err = mux.New().HandleXMPP(rw, &start)
	err = mux.New().HandleXMPP(s, &start)
	if err != nil {
		t.Errorf("Unexpected error: `%v'", err)
	}
	if err = e.Flush(); err != nil {
	if err = s.Flush(); err != nil {
		t.Errorf("Unexpected error: `%v'", err)
	}

	const expected = `<iq to="juliet@example.com" from="romeo@example.com" type="error"><error type="cancel"><feature-not-implemented xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"></feature-not-implemented></error></iq>`
	const expected = `<iq to="juliet@example.com" from="romeo@example.com" id="123" type="error"><error type="cancel"><feature-not-implemented xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"></feature-not-implemented></error></iq>`
	if buf.String() != expected {
		t.Errorf("Bad output:\nwant=`%v'\n got=`%v'", expected, buf.String())
	}

M session.go => session.go +21 -8
@@ 95,7 95,7 @@ type Session struct {
	out struct {
		internal.StreamInfo
		e tokenWriteFlusher
		sync.Mutex
		sync.Locker
	}
}



@@ 133,6 133,7 @@ func NegotiateSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
	if received {
		s.state |= Received
	}
	s.out.Locker = &sync.Mutex{}
	s.in.d = xml.NewDecoder(s.conn)
	s.out.e = xml.NewEncoder(s.conn)
	s.in.ctx, s.in.cancel = context.WithCancel(context.Background())


@@ 277,7 278,7 @@ func (s *Session) sendError(err error) (e error) {

type nopHandler struct{}

func (nopHandler) HandleXMPP(_ xmlstream.TokenReadWriter, _ *xml.StartElement) error {
func (nopHandler) HandleXMPP(_ *Session, _ *xml.StartElement) error {
	return nil
}



@@ 407,11 408,19 @@ func handleInputStream(s *Session, handler Handler) (err error) {
	noreply:

		rw := &responseChecker{
			twf:         s.out.e,
			TokenReader: xmlstream.Inner(s),
			TokenWriter: s.out.e,
			id:          id,
		}
		if err = handler.HandleXMPP(rw, &start); err != nil {
		// Make a copy of the session and set its output stream to the response
		// checker. This means that HandleXMPP will see the state bits as they were
		// when it was first called and will not recieve updates, and that we don't
		// have to take a lock to ensure that nothing else reads or writes through
		// the responseChecker. Instead, the lock will only be taken if something
		// tries to read/write XML from inside the handler.
		ss := *s
		ss.out.e = rw
		if err = handler.HandleXMPP(&ss, &start); err != nil {
			return s.sendError(err)
		}



@@ 446,13 455,17 @@ func handleInputStream(s *Session, handler Handler) (err error) {
}

type responseChecker struct {
	twf tokenWriteFlusher
	xml.TokenReader
	xmlstream.TokenWriter
	id        string
	wroteResp bool
	level     int
}

func (rw *responseChecker) Flush() error {
	return rw.twf.Flush()
}

func (rw *responseChecker) EncodeToken(t xml.Token) error {
	switch tok := t.(type) {
	case xml.StartElement:


@@ 465,7 478,7 @@ func (rw *responseChecker) EncodeToken(t xml.Token) error {
		rw.level--
	}

	return rw.TokenWriter.EncodeToken(t)
	return rw.twf.EncodeToken(t)
}

// Feature checks if a feature with the given namespace was advertised


@@ 496,7 509,7 @@ func (s *Session) Token() (xml.Token, error) {
type lockWriteCloser struct {
	w   *Session
	err error
	m   *sync.Mutex
	m   sync.Locker
}

func (lwc *lockWriteCloser) EncodeToken(t xml.Token) error {


@@ 530,7 543,7 @@ func (s *Session) TokenWriter() xmlstream.TokenWriteCloser {
	s.out.Lock()

	return &lockWriteCloser{
		m: &s.out.Mutex,
		m: s.out.Locker,
		w: s,
	}
}

M session_test.go => session_test.go +20 -12
@@ 156,8 156,10 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	3: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
		handler: xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
			w := s.TokenWriter()
			defer w.Close()
			_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
				ID:   "1234",
				Type: stanza.ResultIQ,
			}, nil))


@@ 168,8 170,10 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	4: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
		handler: xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
			w := s.TokenWriter()
			defer w.Close()
			_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
				ID:   "wrongid",
				Type: stanza.ResultIQ,
			}, nil))


@@ 180,8 184,10 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	5: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
		handler: xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
			w := s.TokenWriter()
			defer w.Close()
			_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
				ID:   "1234",
				Type: stanza.ErrorIQ,
			}, nil))


@@ 192,8 198,10 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	6: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
			_, err := xmlstream.Copy(rw, stanza.WrapIQ(stanza.IQ{
		handler: xmpp.HandlerFunc(func(s *xmpp.Session, start *xml.StartElement) error {
			w := s.TokenWriter()
			defer w.Close()
			_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
				ID:   "1234",
				Type: stanza.GetIQ,
			}, nil))


@@ 204,7 212,7 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	7: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
		handler: xmpp.HandlerFunc(func(_ *xmpp.Session, start *xml.StartElement) error {
			for _, attr := range start.Attr {
				if attr.Name.Local == "from" && attr.Value != "" {
					panic("expected attr to be normalized")


@@ 217,7 225,7 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	8: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
		handler: xmpp.HandlerFunc(func(_ *xmpp.Session, start *xml.StartElement) error {
			for _, attr := range start.Attr {
				if attr.Name.Local == "from" && attr.Value == "" {
					panic("expected attr not to be normalized")


@@ 230,7 238,7 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	9: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
		handler: xmpp.HandlerFunc(func(_ *xmpp.Session, start *xml.StartElement) error {
			for _, attr := range start.Attr {
				if attr.Name.Local == "from" && attr.Value == "" {
					panic("expected attr not to be normalized")


@@ 243,7 251,7 @@ var serveTests = [...]struct {
		err: io.EOF,
	},
	10: {
		handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadWriter, start *xml.StartElement) error {
		handler: xmpp.HandlerFunc(func(_ *xmpp.Session, start *xml.StartElement) error {
			for _, attr := range start.Attr {
				if attr.Name.Local == "from" && attr.Value == "" {
					panic("expected attr not to be normalized")