~samwhited/xmpp

be25ce7b23f96712a76b65c28360b2f1de2514e6 — Sam Whited 1 year, 7 months ago 3a09fa2
all: rework API for SendIQ and SendIQElement
3 files changed, 186 insertions(+), 64 deletions(-)

M roster/roster.go
M send_test.go
M session.go
M roster/roster.go => roster/roster.go +1 -1
@@ 114,7 114,7 @@ func FetchIQ(ctx context.Context, iq stanza.IQ, s *xmpp.Session) *Iter {
	}
	rosterIQ := IQ{IQ: iq}
	payload := rosterIQ.payload()
	r, err := s.SendIQ(ctx, iq, payload)
	r, err := s.SendIQElement(ctx, payload, iq)
	if err != nil {
		return &Iter{err: err}
	}

M send_test.go => send_test.go +107 -34
@@ 11,6 11,7 @@ import (
	"errors"
	"io"
	"strconv"
	"strings"
	"testing"
	"time"



@@ 66,11 67,11 @@ var sendIQTests = [...]struct {
		resp:       stanza.WrapIQ(stanza.IQ{ID: testIQID, Type: stanza.ErrorIQ}, nil),
	},
	2: {
		iq:         stanza.IQ{Type: stanza.ResultIQ},
		iq:         stanza.IQ{Type: stanza.ResultIQ, ID: testIQID},
		writesBody: true,
	},
	3: {
		iq:         stanza.IQ{Type: stanza.ErrorIQ},
		iq:         stanza.IQ{Type: stanza.ErrorIQ, ID: testIQID},
		writesBody: true,
	},
}


@@ 79,14 80,6 @@ func TestSendIQ(t *testing.T) {
	for i, tc := range sendIQTests {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			br := &bytes.Buffer{}
			bw := &bytes.Buffer{}
			s := xmpptest.NewSession(0, struct {
				io.Reader
				io.Writer
			}{
				Reader: br,
				Writer: bw,
			})
			if tc.resp != nil {
				e := xml.NewEncoder(br)
				_, err := xmlstream.Copy(e, tc.resp)


@@ 99,33 92,108 @@ func TestSendIQ(t *testing.T) {
				}
			}

			go func() {
				err := s.Serve(nil)
				if err != nil && err != io.EOF {
					panic(err)
			t.Run("SendIQElement", func(t *testing.T) {
				bw := &bytes.Buffer{}
				s := xmpptest.NewSession(0, struct {
					io.Reader
					io.Writer
				}{
					Reader: strings.NewReader(br.String()),
					Writer: bw,
				})
				defer func() {
					if err := s.Close(); err != nil {
						t.Errorf("Error closing session: %q", err)
					}
				}()

				go func() {
					err := s.Serve(nil)
					if err != nil && err != io.EOF {
						panic(err)
					}
				}()

				ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
				defer cancel()

				resp, err := s.SendIQElement(ctx, tc.payload, tc.iq)
				if err != tc.err {
					t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err)
				}
			}()
				if resp != nil {
					defer func() {
						if err := resp.Close(); err != nil {
							t.Errorf("Error closing response: %q", err)
						}
					}()
				}
				if empty := bw.Len() != 0; tc.writesBody != empty {
					t.Errorf("Unexpected body, want=%t, got=%t", tc.writesBody, empty)
				}
				switch {
				case resp == nil && tc.resp != nil:
					t.Errorf("Expected response, but got none")
				case resp != nil && tc.resp == nil:
					buf := &bytes.Buffer{}
					_, err := xmlstream.Copy(xml.NewEncoder(buf), resp)
					if err != nil {
						t.Errorf("Error encoding unexpected response")
					}
					t.Errorf("Did not expect response, but got: %s", buf.String())
				}
			})
			t.Run("SendIQ", func(t *testing.T) {
				bw := &bytes.Buffer{}
				s := xmpptest.NewSession(0, struct {
					io.Reader
					io.Writer
				}{
					Reader: strings.NewReader(br.String()),
					Writer: bw,
				})
				defer func() {
					if err := s.Close(); err != nil {
						t.Errorf("Error closing session: %q", err)
					}
				}()

			ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
			defer cancel()
			resp, err := s.SendIQ(ctx, tc.iq, tc.payload)
			if err != tc.err {
				t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err)
			}
			if empty := bw.Len() != 0; tc.writesBody != empty {
				t.Errorf("Unexpected body, want=%t, got=%t", tc.writesBody, empty)
			}
			switch {
			case resp == nil && tc.resp != nil:
				t.Fatalf("Expected response, but got none")
			case resp != nil && tc.resp == nil:
				buf := &bytes.Buffer{}
				_, err := xmlstream.Copy(xml.NewEncoder(buf), resp)
				if err != nil {
					t.Fatalf("Error encoding unexpected response")
				go func() {
					err := s.Serve(nil)
					if err != nil && err != io.EOF {
						panic(err)
					}
				}()

				ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
				defer cancel()

				resp, err := s.SendIQ(ctx, stanza.WrapIQ(tc.iq, tc.payload))
				if err != tc.err {
					t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err)
				}
				t.Fatalf("Did not expect response, but got: %s", buf.String())
			}
				if resp != nil {
					defer func() {
						if err := resp.Close(); err != nil {
							t.Errorf("Error closing response: %q", err)
						}
					}()
				}
				if empty := bw.Len() != 0; tc.writesBody != empty {
					t.Errorf("Unexpected body, want=%t, got=%t", tc.writesBody, empty)
				}
				switch {
				case resp == nil && tc.resp != nil:
					t.Errorf("Expected response, but got none")
				case resp != nil && tc.resp == nil:
					buf := &bytes.Buffer{}
					_, err := xmlstream.Copy(xml.NewEncoder(buf), resp)
					if err != nil {
						t.Errorf("Error encoding unexpected response")
					}
					t.Errorf("Did not expect response, but got: %s", buf.String())
				}
			})
		})
	}
}


@@ 176,6 244,11 @@ func TestSend(t *testing.T) {
				Reader: br,
				Writer: bw,
			})
			defer func() {
				if err := s.Close(); err != nil {
					t.Errorf("Error closing session: %q", err)
				}
			}()
			if tc.resp != nil {
				e := xml.NewEncoder(br)
				_, err := xmlstream.Copy(e, tc.resp)

M session.go => session.go +78 -29
@@ 9,6 9,7 @@ import (
	"crypto/tls"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"


@@ 235,7 236,7 @@ func NewServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
// the input stream is closed by the remote entity as above, or the deadline set
// by SetCloseDeadline is reached in which case a timeout error is returned.
func (s *Session) Serve(h Handler) error {
	return s.handleInputStream(h)
	return handleInputStream(s, h)
}

// sendError transmits an error on the session. If the error is not a standard


@@ 246,6 247,10 @@ func (s *Session) sendError(err error) (e error) {
	s.out.Lock()
	defer s.out.Unlock()

	if s.state&OutputStreamClosed == OutputStreamClosed {
		return err
	}

	switch typErr := err.(type) {
	case stream.Error:
		if _, e = typErr.WriteXML(s); e != nil {


@@ 288,7 293,7 @@ func (r iqResponder) Close() error {
	return nil
}

func (s *Session) handleInputStream(handler Handler) (err error) {
func handleInputStream(s *Session, handler Handler) (err error) {
	if handler == nil {
		handler = nopHandler{}
	}


@@ 367,7 372,7 @@ func (s *Session) handleInputStream(handler Handler) (err error) {
		var id string
		var needsResp bool
		if isIQ(start.Name) {
			id = getID(start)
			_, id = getID(start)

			// If this is a response IQ (ie. an "error" or "result") check if we're
			// handling it as part of a SendIQ call.


@@ 449,7 454,7 @@ type responseChecker struct {
func (rw *responseChecker) EncodeToken(t xml.Token) error {
	switch tok := t.(type) {
	case xml.StartElement:
		id := getID(tok)
		_, id := getID(tok)
		if rw.level < 1 && isIQ(tok.Name) && id == rw.id && !iqNeedsResp(tok.Attr) {
			rw.wroteResp = true
		}


@@ 457,6 462,7 @@ func (rw *responseChecker) EncodeToken(t xml.Token) error {
	case xml.EndElement:
		rw.level--
	}

	return rw.TokenWriter.EncodeToken(t)
}



@@ 591,17 597,18 @@ func isStanza(name xml.Name) bool {
		(name.Space == "" || name.Space == ns.Client || name.Space == ns.Server)
}

func getID(start xml.StartElement) string {
	for _, attr := range start.Attr {
func getID(start xml.StartElement) (int, string) {
	for i, attr := range start.Attr {
		if attr.Name.Local == "id" {
			return attr.Value
			return i, attr.Value
		}
	}
	return ""
	return -1, ""
}

// SendIQ is like Send or SendElement except that it wraps the payload in an
// Info/Query (IQ) element and blocks until a response is received.
// SendIQ is like Send except that it returns an error if the first token read
// from the stream is not an Info/Query (IQ) start element and blocks until a
// response is received.
//
// If the input stream is not being processed (a call to Serve is not running),
// SendIQ will never receive a response and will block until the provided


@@ 609,8 616,8 @@ func getID(start xml.StartElement) string {
// If the response is non-nil, it does not need to be consumed in its entirety,
// but it must be closed before stream processing will resume.
// If the IQ type does not require a response—ie. it is a result or error IQ,
// meaning that it is a response itself—SendIQ does not block and the response
// is nil.
// meaning that it is a response itself—SendIQElemnt does not block and the
// response is nil.
//
// If the context is closed before the response is received, SendIQ immediately
// returns the context error.


@@ 620,7 627,46 @@ func getID(start xml.StartElement) string {
// If an error is returned, the response will be nil; the converse is not
// necessarily true.
// SendIQ is safe for concurrent use by multiple goroutines.
func (s *Session) SendIQ(ctx context.Context, iq stanza.IQ, payload xml.TokenReader) (xmlstream.TokenReadCloser, error) {
func (s *Session) SendIQ(ctx context.Context, r xml.TokenReader) (xmlstream.TokenReadCloser, error) {
	tok, err := r.Token()
	if err != nil {
		return nil, err
	}
	start, ok := tok.(xml.StartElement)
	if !ok {
		return nil, fmt.Errorf("expected IQ start element, got %T", start)
	}
	if !isIQ(start.Name) {
		return nil, fmt.Errorf("expected start element to be an IQ")
	}

	// If there's no ID, add one.
	idx, id := getID(start)
	if idx == -1 {
		idx = len(start.Attr)
		start.Attr = append(start.Attr, xml.Attr{Name: xml.Name{Local: "id"}, Value: ""})
	}
	if id == "" {
		id = internal.RandomID()
		start.Attr[idx].Value = id
	}

	// If this an IQ of type "set" or "get" we expect a response.
	if iqNeedsResp(start.Attr) {
		return s.sendResp(ctx, id, xmlstream.Wrap(r, start))
	}

	// If this is an IQ of type result or error, we don't expect a response so
	// just send it normally.
	return nil, s.SendElement(ctx, r, start)
}

// SendIQElement is like SendIQ except that it wraps the payload in an
// Info/Query (IQ) element and blocks until a response is received.
// For more information, see SendIQ.
//
// SendIQElement is safe for concurrent use by multiple goroutines.
func (s *Session) SendIQElement(ctx context.Context, payload xml.TokenReader, iq stanza.IQ) (xmlstream.TokenReadCloser, error) {
	// We need to add an id to the IQ if one wasn't already set by the user so
	// that we can use it to associate the response with the original query.
	if iq.ID == "" {


@@ 628,31 674,33 @@ func (s *Session) SendIQ(ctx context.Context, iq stanza.IQ, payload xml.TokenRea
	}
	needsResp := iq.Type == stanza.GetIQ || iq.Type == stanza.SetIQ

	var c chan xmlstream.TokenReadCloser
	// If this an IQ of type "set" or "get" we expect a response.
	if needsResp {
		c = make(chan xmlstream.TokenReadCloser)
		return s.sendResp(ctx, iq.ID, stanza.WrapIQ(iq, payload))
	}

	// If this is an IQ of type result or error, we don't expect a response so
	// just send it normally.
	return nil, s.Send(ctx, stanza.WrapIQ(iq, payload))
}

func (s *Session) sendResp(ctx context.Context, id string, payload xml.TokenReader) (xmlstream.TokenReadCloser, error) {
	c := make(chan xmlstream.TokenReadCloser)

	s.sentIQMutex.Lock()
	s.sentIQs[id] = c
	s.sentIQMutex.Unlock()
	defer func() {
		s.sentIQMutex.Lock()
		s.sentIQs[iq.ID] = c
		delete(s.sentIQs, id)
		s.sentIQMutex.Unlock()
		defer func() {
			s.sentIQMutex.Lock()
			delete(s.sentIQs, iq.ID)
			s.sentIQMutex.Unlock()
		}()
	}
	}()

	err := s.Send(ctx, stanza.WrapIQ(iq, payload))
	err := s.Send(ctx, payload)
	if err != nil {
		return nil, err
	}

	// If this is not an IQ of type "set" or "get" we don't expect a response and
	// merely transmit the information, so don't block.
	if !needsResp {
		return nil, nil
	}

	select {
	case rr := <-c:
		return rr, nil


@@ 741,6 789,7 @@ func stanzaAddID(w tokenWriteFlusher) tokenWriteFlusher {
			case xml.EndElement:
				depth--
			}

			return w.EncodeToken(t)
		},
		flush: w.Flush,