M roster/roster.go => roster/roster.go +1 -1
@@ 114,7 114,7 @@ func FetchIQ(ctx context.Context, iq stanza.IQ, s *xmpp.Session) *Iter {
}
rosterIQ := IQ{IQ: iq}
payload := rosterIQ.payload()
- r, err := s.SendIQ(ctx, iq, payload)
+ r, err := s.SendIQElement(ctx, payload, iq)
if err != nil {
return &Iter{err: err}
}
M send_test.go => send_test.go +107 -34
@@ 11,6 11,7 @@ import (
"errors"
"io"
"strconv"
+ "strings"
"testing"
"time"
@@ 66,11 67,11 @@ var sendIQTests = [...]struct {
resp: stanza.WrapIQ(stanza.IQ{ID: testIQID, Type: stanza.ErrorIQ}, nil),
},
2: {
- iq: stanza.IQ{Type: stanza.ResultIQ},
+ iq: stanza.IQ{Type: stanza.ResultIQ, ID: testIQID},
writesBody: true,
},
3: {
- iq: stanza.IQ{Type: stanza.ErrorIQ},
+ iq: stanza.IQ{Type: stanza.ErrorIQ, ID: testIQID},
writesBody: true,
},
}
@@ 79,14 80,6 @@ func TestSendIQ(t *testing.T) {
for i, tc := range sendIQTests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
br := &bytes.Buffer{}
- bw := &bytes.Buffer{}
- s := xmpptest.NewSession(0, struct {
- io.Reader
- io.Writer
- }{
- Reader: br,
- Writer: bw,
- })
if tc.resp != nil {
e := xml.NewEncoder(br)
_, err := xmlstream.Copy(e, tc.resp)
@@ 99,33 92,108 @@ func TestSendIQ(t *testing.T) {
}
}
- go func() {
- err := s.Serve(nil)
- if err != nil && err != io.EOF {
- panic(err)
+ t.Run("SendIQElement", func(t *testing.T) {
+ bw := &bytes.Buffer{}
+ s := xmpptest.NewSession(0, struct {
+ io.Reader
+ io.Writer
+ }{
+ Reader: strings.NewReader(br.String()),
+ Writer: bw,
+ })
+ defer func() {
+ if err := s.Close(); err != nil {
+ t.Errorf("Error closing session: %q", err)
+ }
+ }()
+
+ go func() {
+ err := s.Serve(nil)
+ if err != nil && err != io.EOF {
+ panic(err)
+ }
+ }()
+
+ ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
+ defer cancel()
+
+ resp, err := s.SendIQElement(ctx, tc.payload, tc.iq)
+ if err != tc.err {
+ t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err)
}
- }()
+ if resp != nil {
+ defer func() {
+ if err := resp.Close(); err != nil {
+ t.Errorf("Error closing response: %q", err)
+ }
+ }()
+ }
+ if empty := bw.Len() != 0; tc.writesBody != empty {
+ t.Errorf("Unexpected body, want=%t, got=%t", tc.writesBody, empty)
+ }
+ switch {
+ case resp == nil && tc.resp != nil:
+ t.Errorf("Expected response, but got none")
+ case resp != nil && tc.resp == nil:
+ buf := &bytes.Buffer{}
+ _, err := xmlstream.Copy(xml.NewEncoder(buf), resp)
+ if err != nil {
+ t.Errorf("Error encoding unexpected response")
+ }
+ t.Errorf("Did not expect response, but got: %s", buf.String())
+ }
+ })
+ t.Run("SendIQ", func(t *testing.T) {
+ bw := &bytes.Buffer{}
+ s := xmpptest.NewSession(0, struct {
+ io.Reader
+ io.Writer
+ }{
+ Reader: strings.NewReader(br.String()),
+ Writer: bw,
+ })
+ defer func() {
+ if err := s.Close(); err != nil {
+ t.Errorf("Error closing session: %q", err)
+ }
+ }()
- ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
- defer cancel()
- resp, err := s.SendIQ(ctx, tc.iq, tc.payload)
- if err != tc.err {
- t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err)
- }
- if empty := bw.Len() != 0; tc.writesBody != empty {
- t.Errorf("Unexpected body, want=%t, got=%t", tc.writesBody, empty)
- }
- switch {
- case resp == nil && tc.resp != nil:
- t.Fatalf("Expected response, but got none")
- case resp != nil && tc.resp == nil:
- buf := &bytes.Buffer{}
- _, err := xmlstream.Copy(xml.NewEncoder(buf), resp)
- if err != nil {
- t.Fatalf("Error encoding unexpected response")
+ go func() {
+ err := s.Serve(nil)
+ if err != nil && err != io.EOF {
+ panic(err)
+ }
+ }()
+
+ ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
+ defer cancel()
+
+ resp, err := s.SendIQ(ctx, stanza.WrapIQ(tc.iq, tc.payload))
+ if err != tc.err {
+ t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err)
}
- t.Fatalf("Did not expect response, but got: %s", buf.String())
- }
+ if resp != nil {
+ defer func() {
+ if err := resp.Close(); err != nil {
+ t.Errorf("Error closing response: %q", err)
+ }
+ }()
+ }
+ if empty := bw.Len() != 0; tc.writesBody != empty {
+ t.Errorf("Unexpected body, want=%t, got=%t", tc.writesBody, empty)
+ }
+ switch {
+ case resp == nil && tc.resp != nil:
+ t.Errorf("Expected response, but got none")
+ case resp != nil && tc.resp == nil:
+ buf := &bytes.Buffer{}
+ _, err := xmlstream.Copy(xml.NewEncoder(buf), resp)
+ if err != nil {
+ t.Errorf("Error encoding unexpected response")
+ }
+ t.Errorf("Did not expect response, but got: %s", buf.String())
+ }
+ })
})
}
}
@@ 176,6 244,11 @@ func TestSend(t *testing.T) {
Reader: br,
Writer: bw,
})
+ defer func() {
+ if err := s.Close(); err != nil {
+ t.Errorf("Error closing session: %q", err)
+ }
+ }()
if tc.resp != nil {
e := xml.NewEncoder(br)
_, err := xmlstream.Copy(e, tc.resp)
M session.go => session.go +78 -29
@@ 9,6 9,7 @@ import (
"crypto/tls"
"encoding/xml"
"errors"
+ "fmt"
"io"
"net"
"sync"
@@ 235,7 236,7 @@ func NewServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
// the input stream is closed by the remote entity as above, or the deadline set
// by SetCloseDeadline is reached in which case a timeout error is returned.
func (s *Session) Serve(h Handler) error {
- return s.handleInputStream(h)
+ return handleInputStream(s, h)
}
// sendError transmits an error on the session. If the error is not a standard
@@ 246,6 247,10 @@ func (s *Session) sendError(err error) (e error) {
s.out.Lock()
defer s.out.Unlock()
+ if s.state&OutputStreamClosed == OutputStreamClosed {
+ return err
+ }
+
switch typErr := err.(type) {
case stream.Error:
if _, e = typErr.WriteXML(s); e != nil {
@@ 288,7 293,7 @@ func (r iqResponder) Close() error {
return nil
}
-func (s *Session) handleInputStream(handler Handler) (err error) {
+func handleInputStream(s *Session, handler Handler) (err error) {
if handler == nil {
handler = nopHandler{}
}
@@ 367,7 372,7 @@ func (s *Session) handleInputStream(handler Handler) (err error) {
var id string
var needsResp bool
if isIQ(start.Name) {
- id = getID(start)
+ _, id = getID(start)
// If this is a response IQ (ie. an "error" or "result") check if we're
// handling it as part of a SendIQ call.
@@ 449,7 454,7 @@ type responseChecker struct {
func (rw *responseChecker) EncodeToken(t xml.Token) error {
switch tok := t.(type) {
case xml.StartElement:
- id := getID(tok)
+ _, id := getID(tok)
if rw.level < 1 && isIQ(tok.Name) && id == rw.id && !iqNeedsResp(tok.Attr) {
rw.wroteResp = true
}
@@ 457,6 462,7 @@ func (rw *responseChecker) EncodeToken(t xml.Token) error {
case xml.EndElement:
rw.level--
}
+
return rw.TokenWriter.EncodeToken(t)
}
@@ 591,17 597,18 @@ func isStanza(name xml.Name) bool {
(name.Space == "" || name.Space == ns.Client || name.Space == ns.Server)
}
-func getID(start xml.StartElement) string {
- for _, attr := range start.Attr {
+func getID(start xml.StartElement) (int, string) {
+ for i, attr := range start.Attr {
if attr.Name.Local == "id" {
- return attr.Value
+ return i, attr.Value
}
}
- return ""
+ return -1, ""
}
-// SendIQ is like Send or SendElement except that it wraps the payload in an
-// Info/Query (IQ) element and blocks until a response is received.
+// SendIQ is like Send except that it returns an error if the first token read
+// from the stream is not an Info/Query (IQ) start element and blocks until a
+// response is received.
//
// If the input stream is not being processed (a call to Serve is not running),
// SendIQ will never receive a response and will block until the provided
@@ 609,8 616,8 @@ func getID(start xml.StartElement) string {
// If the response is non-nil, it does not need to be consumed in its entirety,
// but it must be closed before stream processing will resume.
// If the IQ type does not require a response—ie. it is a result or error IQ,
-// meaning that it is a response itself—SendIQ does not block and the response
-// is nil.
+// meaning that it is a response itself—SendIQElemnt does not block and the
+// response is nil.
//
// If the context is closed before the response is received, SendIQ immediately
// returns the context error.
@@ 620,7 627,46 @@ func getID(start xml.StartElement) string {
// If an error is returned, the response will be nil; the converse is not
// necessarily true.
// SendIQ is safe for concurrent use by multiple goroutines.
-func (s *Session) SendIQ(ctx context.Context, iq stanza.IQ, payload xml.TokenReader) (xmlstream.TokenReadCloser, error) {
+func (s *Session) SendIQ(ctx context.Context, r xml.TokenReader) (xmlstream.TokenReadCloser, error) {
+ tok, err := r.Token()
+ if err != nil {
+ return nil, err
+ }
+ start, ok := tok.(xml.StartElement)
+ if !ok {
+ return nil, fmt.Errorf("expected IQ start element, got %T", start)
+ }
+ if !isIQ(start.Name) {
+ return nil, fmt.Errorf("expected start element to be an IQ")
+ }
+
+ // If there's no ID, add one.
+ idx, id := getID(start)
+ if idx == -1 {
+ idx = len(start.Attr)
+ start.Attr = append(start.Attr, xml.Attr{Name: xml.Name{Local: "id"}, Value: ""})
+ }
+ if id == "" {
+ id = internal.RandomID()
+ start.Attr[idx].Value = id
+ }
+
+ // If this an IQ of type "set" or "get" we expect a response.
+ if iqNeedsResp(start.Attr) {
+ return s.sendResp(ctx, id, xmlstream.Wrap(r, start))
+ }
+
+ // If this is an IQ of type result or error, we don't expect a response so
+ // just send it normally.
+ return nil, s.SendElement(ctx, r, start)
+}
+
+// SendIQElement is like SendIQ except that it wraps the payload in an
+// Info/Query (IQ) element and blocks until a response is received.
+// For more information, see SendIQ.
+//
+// SendIQElement is safe for concurrent use by multiple goroutines.
+func (s *Session) SendIQElement(ctx context.Context, payload xml.TokenReader, iq stanza.IQ) (xmlstream.TokenReadCloser, error) {
// We need to add an id to the IQ if one wasn't already set by the user so
// that we can use it to associate the response with the original query.
if iq.ID == "" {
@@ 628,31 674,33 @@ func (s *Session) SendIQ(ctx context.Context, iq stanza.IQ, payload xml.TokenRea
}
needsResp := iq.Type == stanza.GetIQ || iq.Type == stanza.SetIQ
- var c chan xmlstream.TokenReadCloser
+ // If this an IQ of type "set" or "get" we expect a response.
if needsResp {
- c = make(chan xmlstream.TokenReadCloser)
+ return s.sendResp(ctx, iq.ID, stanza.WrapIQ(iq, payload))
+ }
+
+ // If this is an IQ of type result or error, we don't expect a response so
+ // just send it normally.
+ return nil, s.Send(ctx, stanza.WrapIQ(iq, payload))
+}
+func (s *Session) sendResp(ctx context.Context, id string, payload xml.TokenReader) (xmlstream.TokenReadCloser, error) {
+ c := make(chan xmlstream.TokenReadCloser)
+
+ s.sentIQMutex.Lock()
+ s.sentIQs[id] = c
+ s.sentIQMutex.Unlock()
+ defer func() {
s.sentIQMutex.Lock()
- s.sentIQs[iq.ID] = c
+ delete(s.sentIQs, id)
s.sentIQMutex.Unlock()
- defer func() {
- s.sentIQMutex.Lock()
- delete(s.sentIQs, iq.ID)
- s.sentIQMutex.Unlock()
- }()
- }
+ }()
- err := s.Send(ctx, stanza.WrapIQ(iq, payload))
+ err := s.Send(ctx, payload)
if err != nil {
return nil, err
}
- // If this is not an IQ of type "set" or "get" we don't expect a response and
- // merely transmit the information, so don't block.
- if !needsResp {
- return nil, nil
- }
-
select {
case rr := <-c:
return rr, nil
@@ 741,6 789,7 @@ func stanzaAddID(w tokenWriteFlusher) tokenWriteFlusher {
case xml.EndElement:
depth--
}
+
return w.EncodeToken(t)
},
flush: w.Flush,