From be25ce7b23f96712a76b65c28360b2f1de2514e6 Mon Sep 17 00:00:00 2001 From: Sam Whited Date: Thu, 18 Jul 2019 16:11:06 -0500 Subject: [PATCH] all: rework API for SendIQ and SendIQElement --- roster/roster.go | 2 +- send_test.go | 141 +++++++++++++++++++++++++++++++++++------------ session.go | 107 +++++++++++++++++++++++++---------- 3 files changed, 186 insertions(+), 64 deletions(-) diff --git a/roster/roster.go b/roster/roster.go index 777be7e..d5eabf6 100644 --- a/roster/roster.go +++ b/roster/roster.go @@ -114,7 +114,7 @@ func FetchIQ(ctx context.Context, iq stanza.IQ, s *xmpp.Session) *Iter { } rosterIQ := IQ{IQ: iq} payload := rosterIQ.payload() - r, err := s.SendIQ(ctx, iq, payload) + r, err := s.SendIQElement(ctx, payload, iq) if err != nil { return &Iter{err: err} } diff --git a/send_test.go b/send_test.go index 7093bbe..09bcbb3 100644 --- a/send_test.go +++ b/send_test.go @@ -11,6 +11,7 @@ import ( "errors" "io" "strconv" + "strings" "testing" "time" @@ -66,11 +67,11 @@ var sendIQTests = [...]struct { resp: stanza.WrapIQ(stanza.IQ{ID: testIQID, Type: stanza.ErrorIQ}, nil), }, 2: { - iq: stanza.IQ{Type: stanza.ResultIQ}, + iq: stanza.IQ{Type: stanza.ResultIQ, ID: testIQID}, writesBody: true, }, 3: { - iq: stanza.IQ{Type: stanza.ErrorIQ}, + iq: stanza.IQ{Type: stanza.ErrorIQ, ID: testIQID}, writesBody: true, }, } @@ -79,14 +80,6 @@ func TestSendIQ(t *testing.T) { for i, tc := range sendIQTests { t.Run(strconv.Itoa(i), func(t *testing.T) { br := &bytes.Buffer{} - bw := &bytes.Buffer{} - s := xmpptest.NewSession(0, struct { - io.Reader - io.Writer - }{ - Reader: br, - Writer: bw, - }) if tc.resp != nil { e := xml.NewEncoder(br) _, err := xmlstream.Copy(e, tc.resp) @@ -99,33 +92,108 @@ func TestSendIQ(t *testing.T) { } } - go func() { - err := s.Serve(nil) - if err != nil && err != io.EOF { - panic(err) + t.Run("SendIQElement", func(t *testing.T) { + bw := &bytes.Buffer{} + s := xmpptest.NewSession(0, struct { + io.Reader + io.Writer + }{ + Reader: strings.NewReader(br.String()), + Writer: bw, + }) + defer func() { + if err := s.Close(); err != nil { + t.Errorf("Error closing session: %q", err) + } + }() + + go func() { + err := s.Serve(nil) + if err != nil && err != io.EOF { + panic(err) + } + }() + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + + resp, err := s.SendIQElement(ctx, tc.payload, tc.iq) + if err != tc.err { + t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err) } - }() + if resp != nil { + defer func() { + if err := resp.Close(); err != nil { + t.Errorf("Error closing response: %q", err) + } + }() + } + if empty := bw.Len() != 0; tc.writesBody != empty { + t.Errorf("Unexpected body, want=%t, got=%t", tc.writesBody, empty) + } + switch { + case resp == nil && tc.resp != nil: + t.Errorf("Expected response, but got none") + case resp != nil && tc.resp == nil: + buf := &bytes.Buffer{} + _, err := xmlstream.Copy(xml.NewEncoder(buf), resp) + if err != nil { + t.Errorf("Error encoding unexpected response") + } + t.Errorf("Did not expect response, but got: %s", buf.String()) + } + }) + t.Run("SendIQ", func(t *testing.T) { + bw := &bytes.Buffer{} + s := xmpptest.NewSession(0, struct { + io.Reader + io.Writer + }{ + Reader: strings.NewReader(br.String()), + Writer: bw, + }) + defer func() { + if err := s.Close(); err != nil { + t.Errorf("Error closing session: %q", err) + } + }() - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) - defer cancel() - resp, err := s.SendIQ(ctx, tc.iq, tc.payload) - if err != tc.err { - t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err) - } - if empty := bw.Len() != 0; tc.writesBody != empty { - t.Errorf("Unexpected body, want=%t, got=%t", tc.writesBody, empty) - } - switch { - case resp == nil && tc.resp != nil: - t.Fatalf("Expected response, but got none") - case resp != nil && tc.resp == nil: - buf := &bytes.Buffer{} - _, err := xmlstream.Copy(xml.NewEncoder(buf), resp) - if err != nil { - t.Fatalf("Error encoding unexpected response") + go func() { + err := s.Serve(nil) + if err != nil && err != io.EOF { + panic(err) + } + }() + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + + resp, err := s.SendIQ(ctx, stanza.WrapIQ(tc.iq, tc.payload)) + if err != tc.err { + t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err) } - t.Fatalf("Did not expect response, but got: %s", buf.String()) - } + if resp != nil { + defer func() { + if err := resp.Close(); err != nil { + t.Errorf("Error closing response: %q", err) + } + }() + } + if empty := bw.Len() != 0; tc.writesBody != empty { + t.Errorf("Unexpected body, want=%t, got=%t", tc.writesBody, empty) + } + switch { + case resp == nil && tc.resp != nil: + t.Errorf("Expected response, but got none") + case resp != nil && tc.resp == nil: + buf := &bytes.Buffer{} + _, err := xmlstream.Copy(xml.NewEncoder(buf), resp) + if err != nil { + t.Errorf("Error encoding unexpected response") + } + t.Errorf("Did not expect response, but got: %s", buf.String()) + } + }) }) } } @@ -176,6 +244,11 @@ func TestSend(t *testing.T) { Reader: br, Writer: bw, }) + defer func() { + if err := s.Close(); err != nil { + t.Errorf("Error closing session: %q", err) + } + }() if tc.resp != nil { e := xml.NewEncoder(br) _, err := xmlstream.Copy(e, tc.resp) diff --git a/session.go b/session.go index 40b68f2..d526f98 100644 --- a/session.go +++ b/session.go @@ -9,6 +9,7 @@ import ( "crypto/tls" "encoding/xml" "errors" + "fmt" "io" "net" "sync" @@ -235,7 +236,7 @@ func NewServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadW // the input stream is closed by the remote entity as above, or the deadline set // by SetCloseDeadline is reached in which case a timeout error is returned. func (s *Session) Serve(h Handler) error { - return s.handleInputStream(h) + return handleInputStream(s, h) } // sendError transmits an error on the session. If the error is not a standard @@ -246,6 +247,10 @@ func (s *Session) sendError(err error) (e error) { s.out.Lock() defer s.out.Unlock() + if s.state&OutputStreamClosed == OutputStreamClosed { + return err + } + switch typErr := err.(type) { case stream.Error: if _, e = typErr.WriteXML(s); e != nil { @@ -288,7 +293,7 @@ func (r iqResponder) Close() error { return nil } -func (s *Session) handleInputStream(handler Handler) (err error) { +func handleInputStream(s *Session, handler Handler) (err error) { if handler == nil { handler = nopHandler{} } @@ -367,7 +372,7 @@ func (s *Session) handleInputStream(handler Handler) (err error) { var id string var needsResp bool if isIQ(start.Name) { - id = getID(start) + _, id = getID(start) // If this is a response IQ (ie. an "error" or "result") check if we're // handling it as part of a SendIQ call. @@ -449,7 +454,7 @@ type responseChecker struct { func (rw *responseChecker) EncodeToken(t xml.Token) error { switch tok := t.(type) { case xml.StartElement: - id := getID(tok) + _, id := getID(tok) if rw.level < 1 && isIQ(tok.Name) && id == rw.id && !iqNeedsResp(tok.Attr) { rw.wroteResp = true } @@ -457,6 +462,7 @@ func (rw *responseChecker) EncodeToken(t xml.Token) error { case xml.EndElement: rw.level-- } + return rw.TokenWriter.EncodeToken(t) } @@ -591,17 +597,18 @@ func isStanza(name xml.Name) bool { (name.Space == "" || name.Space == ns.Client || name.Space == ns.Server) } -func getID(start xml.StartElement) string { - for _, attr := range start.Attr { +func getID(start xml.StartElement) (int, string) { + for i, attr := range start.Attr { if attr.Name.Local == "id" { - return attr.Value + return i, attr.Value } } - return "" + return -1, "" } -// SendIQ is like Send or SendElement except that it wraps the payload in an -// Info/Query (IQ) element and blocks until a response is received. +// SendIQ is like Send except that it returns an error if the first token read +// from the stream is not an Info/Query (IQ) start element and blocks until a +// response is received. // // If the input stream is not being processed (a call to Serve is not running), // SendIQ will never receive a response and will block until the provided @@ -609,8 +616,8 @@ func getID(start xml.StartElement) string { // If the response is non-nil, it does not need to be consumed in its entirety, // but it must be closed before stream processing will resume. // If the IQ type does not require a response—ie. it is a result or error IQ, -// meaning that it is a response itself—SendIQ does not block and the response -// is nil. +// meaning that it is a response itself—SendIQElemnt does not block and the +// response is nil. // // If the context is closed before the response is received, SendIQ immediately // returns the context error. @@ -620,7 +627,46 @@ func getID(start xml.StartElement) string { // If an error is returned, the response will be nil; the converse is not // necessarily true. // SendIQ is safe for concurrent use by multiple goroutines. -func (s *Session) SendIQ(ctx context.Context, iq stanza.IQ, payload xml.TokenReader) (xmlstream.TokenReadCloser, error) { +func (s *Session) SendIQ(ctx context.Context, r xml.TokenReader) (xmlstream.TokenReadCloser, error) { + tok, err := r.Token() + if err != nil { + return nil, err + } + start, ok := tok.(xml.StartElement) + if !ok { + return nil, fmt.Errorf("expected IQ start element, got %T", start) + } + if !isIQ(start.Name) { + return nil, fmt.Errorf("expected start element to be an IQ") + } + + // If there's no ID, add one. + idx, id := getID(start) + if idx == -1 { + idx = len(start.Attr) + start.Attr = append(start.Attr, xml.Attr{Name: xml.Name{Local: "id"}, Value: ""}) + } + if id == "" { + id = internal.RandomID() + start.Attr[idx].Value = id + } + + // If this an IQ of type "set" or "get" we expect a response. + if iqNeedsResp(start.Attr) { + return s.sendResp(ctx, id, xmlstream.Wrap(r, start)) + } + + // If this is an IQ of type result or error, we don't expect a response so + // just send it normally. + return nil, s.SendElement(ctx, r, start) +} + +// SendIQElement is like SendIQ except that it wraps the payload in an +// Info/Query (IQ) element and blocks until a response is received. +// For more information, see SendIQ. +// +// SendIQElement is safe for concurrent use by multiple goroutines. +func (s *Session) SendIQElement(ctx context.Context, payload xml.TokenReader, iq stanza.IQ) (xmlstream.TokenReadCloser, error) { // We need to add an id to the IQ if one wasn't already set by the user so // that we can use it to associate the response with the original query. if iq.ID == "" { @@ -628,31 +674,33 @@ func (s *Session) SendIQ(ctx context.Context, iq stanza.IQ, payload xml.TokenRea } needsResp := iq.Type == stanza.GetIQ || iq.Type == stanza.SetIQ - var c chan xmlstream.TokenReadCloser + // If this an IQ of type "set" or "get" we expect a response. if needsResp { - c = make(chan xmlstream.TokenReadCloser) + return s.sendResp(ctx, iq.ID, stanza.WrapIQ(iq, payload)) + } + + // If this is an IQ of type result or error, we don't expect a response so + // just send it normally. + return nil, s.Send(ctx, stanza.WrapIQ(iq, payload)) +} +func (s *Session) sendResp(ctx context.Context, id string, payload xml.TokenReader) (xmlstream.TokenReadCloser, error) { + c := make(chan xmlstream.TokenReadCloser) + + s.sentIQMutex.Lock() + s.sentIQs[id] = c + s.sentIQMutex.Unlock() + defer func() { s.sentIQMutex.Lock() - s.sentIQs[iq.ID] = c + delete(s.sentIQs, id) s.sentIQMutex.Unlock() - defer func() { - s.sentIQMutex.Lock() - delete(s.sentIQs, iq.ID) - s.sentIQMutex.Unlock() - }() - } + }() - err := s.Send(ctx, stanza.WrapIQ(iq, payload)) + err := s.Send(ctx, payload) if err != nil { return nil, err } - // If this is not an IQ of type "set" or "get" we don't expect a response and - // merely transmit the information, so don't block. - if !needsResp { - return nil, nil - } - select { case rr := <-c: return rr, nil @@ -741,6 +789,7 @@ func stanzaAddID(w tokenWriteFlusher) tokenWriteFlusher { case xml.EndElement: depth-- } + return w.EncodeToken(t) }, flush: w.Flush, -- 2.30.2