M doc.go => doc.go +2 -2
@@ 85,12 85,12 @@
// The mellium.im/xmpp/stanza package contains functions and structs that aid in
// the construction of message, presence and info/query (IQ) elements which have
// special semantics in XMPP and are known as "stanzas".
-// These can be sent with the Send and SendElement methods.
+// These can be sent with the Send, SendElement, and SendIQ methods.
//
// // Send initial presence to let the server know we want to receive messages.
// _, err = session.Send(context.TODO(), stanza.WrapPresence(jid.JID{}, stanza.AvailablePresence, nil))
//
-// For Send to correctly handle IQ responses, and to make the common case of
+// For SendIQ to correctly handle IQ responses, and to make the common case of
// polling for incoming XML on the input stream—and possibly writing to the
// output stream in response—easier, we need a long running goroutine.
// Session includes the Serve method for starting this processing.
M echobot_example_test.go => echobot_example_test.go +2 -2
@@ 49,7 49,7 @@ func Example_echobot() {
}()
// Send initial presence to let the server know we want to receive messages.
- _, err = s.Send(context.TODO(), stanza.WrapPresence(jid.JID{}, stanza.AvailablePresence, nil))
+ err = s.Send(context.TODO(), stanza.WrapPresence(jid.JID{}, stanza.AvailablePresence, nil))
if err != nil {
log.Printf("Error sending initial presence: %q", err)
return
@@ 86,7 86,7 @@ func Example_echobot() {
return xml.CharData(msg.Body), io.EOF
}), xml.StartElement{Name: xml.Name{Local: "body"}}),
)
- _, err = s.Send(context.TODO(), reply)
+ err = s.Send(context.TODO(), reply)
if err != nil {
log.Printf("Error responding to message %q: %q", msg.ID, err)
}
M roster/roster.go => roster/roster.go +18 -6
@@ 107,9 107,14 @@ func Fetch(ctx context.Context, s *xmpp.Session) *Iter {
}
// FetchIQ is like Fetch but it allows you to customize the IQ.
+// Changing the type of the provided IQ has no effect.
func FetchIQ(ctx context.Context, iq stanza.IQ, s *xmpp.Session) *Iter {
+ if iq.Type != stanza.GetIQ {
+ iq.Type = stanza.GetIQ
+ }
rosterIQ := IQ{IQ: iq}
- r, err := s.Send(ctx, rosterIQ.TokenReader())
+ payload := rosterIQ.payload()
+ r, err := s.SendIQ(ctx, iq, payload)
if err != nil {
return &Iter{err: err}
}
@@ 179,18 184,25 @@ func (m itemMarshaler) Token() (xml.Token, error) {
// TokenReader returns a stream of XML tokens that match the IQ.
func (iq IQ) TokenReader() xml.TokenReader {
+ if iq.IQ.Type != stanza.GetIQ {
+ iq.IQ.Type = stanza.GetIQ
+ }
+
+ return stanza.WrapIQ(iq.IQ, iq.payload())
+}
+
+// Payload returns a stream of XML tokekns that match the roster query payload
+// without the IQ wrapper.
+func (iq IQ) payload() xml.TokenReader {
attrs := []xml.Attr{}
if iq.Query.Ver != "" {
attrs = append(attrs, xml.Attr{Name: xml.Name{Local: "version"}, Value: iq.Query.Ver})
}
- if iq.IQ.Type != stanza.GetIQ {
- iq.IQ.Type = stanza.GetIQ
- }
- return stanza.WrapIQ(iq.IQ, xmlstream.Wrap(
+ return xmlstream.Wrap(
itemMarshaler{items: iq.Query.Item},
xml.StartElement{Name: xml.Name{Local: "query", Space: NS}, Attr: attrs},
- ))
+ )
}
// WriteXML satisfies the xmlstream.WriterTo interface.
M send_test.go => send_test.go +83 -22
@@ 48,6 48,88 @@ var (
to = jid.MustParse("test@example.net")
)
+var sendIQTests = [...]struct {
+ iq stanza.IQ
+ payload xml.TokenReader
+ err error
+ writesBody bool
+ resp xml.TokenReader
+}{
+ 0: {
+ iq: stanza.IQ{ID: testIQID, Type: stanza.GetIQ},
+ writesBody: true,
+ resp: stanza.WrapIQ(stanza.IQ{ID: testIQID, Type: stanza.ResultIQ}, nil),
+ },
+ 1: {
+ iq: stanza.IQ{ID: testIQID, Type: stanza.SetIQ},
+ writesBody: true,
+ resp: stanza.WrapIQ(stanza.IQ{ID: testIQID, Type: stanza.ErrorIQ}, nil),
+ },
+ 2: {
+ iq: stanza.IQ{Type: stanza.ResultIQ},
+ writesBody: true,
+ },
+ 3: {
+ iq: stanza.IQ{Type: stanza.ErrorIQ},
+ writesBody: true,
+ },
+}
+
+func TestSendIQ(t *testing.T) {
+ for i, tc := range sendIQTests {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ br := &bytes.Buffer{}
+ bw := &bytes.Buffer{}
+ s := xmpptest.NewSession(0, struct {
+ io.Reader
+ io.Writer
+ }{
+ Reader: br,
+ Writer: bw,
+ })
+ if tc.resp != nil {
+ e := xml.NewEncoder(br)
+ _, err := xmlstream.Copy(e, tc.resp)
+ if err != nil {
+ t.Logf("error responding: %q", err)
+ }
+ err = e.Flush()
+ if err != nil {
+ t.Logf("error flushing after responding: %q", err)
+ }
+ }
+
+ go func() {
+ err := s.Serve(nil)
+ if err != nil && err != io.EOF {
+ panic(err)
+ }
+ }()
+
+ ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
+ defer cancel()
+ resp, err := s.SendIQ(ctx, tc.iq, tc.payload)
+ if err != tc.err {
+ t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err)
+ }
+ if empty := bw.Len() != 0; tc.writesBody != empty {
+ t.Errorf("Unexpected body, want=%t, got=%t", tc.writesBody, empty)
+ }
+ switch {
+ case resp == nil && tc.resp != nil:
+ t.Fatalf("Expected response, but got none")
+ case resp != nil && tc.resp == nil:
+ buf := &bytes.Buffer{}
+ _, err := xmlstream.Copy(xml.NewEncoder(buf), resp)
+ if err != nil {
+ t.Fatalf("Error encoding unexpected response")
+ }
+ t.Fatalf("Did not expect response, but got: %s", buf.String())
+ }
+ })
+ }
+}
+
var sendTests = [...]struct {
r xml.TokenReader
err error
@@ 80,16 162,6 @@ var sendTests = [...]struct {
r: stanza.WrapIQ(stanza.IQ{Type: stanza.ErrorIQ}, nil),
writesBody: true,
},
- 6: {
- r: stanza.WrapIQ(stanza.IQ{ID: testIQID, Type: stanza.GetIQ}, nil),
- writesBody: true,
- resp: stanza.WrapIQ(stanza.IQ{ID: testIQID, Type: stanza.ResultIQ}, nil),
- },
- 7: {
- r: stanza.WrapIQ(stanza.IQ{ID: testIQID, Type: stanza.SetIQ}, nil),
- writesBody: true,
- resp: stanza.WrapIQ(stanza.IQ{ID: testIQID, Type: stanza.ErrorIQ}, nil),
- },
}
func TestSend(t *testing.T) {
@@ 125,24 197,13 @@ func TestSend(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
defer cancel()
- resp, err := s.Send(ctx, tc.r)
+ err := s.Send(ctx, tc.r)
if err != tc.err {
t.Errorf("Unexpected error, want=%q, got=%q", tc.err, err)
}
if empty := bw.Len() != 0; tc.writesBody != empty {
t.Errorf("Unexpected body, want=%t, got=%t", tc.writesBody, empty)
}
- switch {
- case resp == nil && tc.resp != nil:
- t.Fatalf("Expected response, but got none")
- case resp != nil && tc.resp == nil:
- buf := &bytes.Buffer{}
- _, err := xmlstream.Copy(xml.NewEncoder(buf), resp)
- if err != nil {
- t.Fatalf("Error encoding unexpected response")
- }
- t.Fatalf("Did not expect response, but got: %s", buf.String())
- }
})
}
}
M session.go => session.go +76 -65
@@ 78,7 78,8 @@ type Session struct {
// The negotiated features (by namespace) for the current session.
negotiated map[string]struct{}
- sentIQs map[string]chan xmlstream.TokenReadCloser
+ sentIQMutex sync.Mutex
+ sentIQs map[string]chan xmlstream.TokenReadCloser
in struct {
internal.StreamInfo
@@ 364,11 365,13 @@ func (s *Session) handleInputStream(handler Handler) (err error) {
id = getID(start)
// If this is a response IQ (ie. an "error" or "result") check if we're
- // handling it as part of a SendElement call.
+ // handling it as part of a SendIQ call.
// If not, record this so that we can check if the user sends a response
// later.
if !iqNeedsResp(start.Attr) {
+ s.sentIQMutex.Lock()
c := s.sentIQs[id]
+ s.sentIQMutex.Unlock()
if c == nil {
goto noreply
}
@@ 550,8 553,8 @@ func (s *Session) SetCloseDeadline(t time.Time) error {
// Send transmits the first element read from the provided token reader.
//
-// For more information, see SendElement.
-func (s *Session) Send(ctx context.Context, r xml.TokenReader) (xmlstream.TokenReadCloser, error) {
+// Send is safe for concurrent use by multiple goroutines.
+func (s *Session) Send(ctx context.Context, r xml.TokenReader) error {
return s.SendElement(ctx, r, xml.StartElement{})
}
@@ 564,7 567,7 @@ func iqNeedsResp(attrs []xml.Attr) bool {
}
}
- return typ == "get" || typ == "set"
+ return typ == string(stanza.GetIQ) || typ == string(stanza.SetIQ)
}
func isIQ(name xml.Name) bool {
@@ 585,94 588,102 @@ func getID(start xml.StartElement) string {
return ""
}
-// SendElement transmits the first element read from the provided token reader
-// using start as the outermost tag in the encoding.
+// SendIQ is like Send or SendElement except that it wraps the payload in an
+// Info/Query (IQ) element and blocks until a response is received.
//
-// If the element is an info/query (IQ) stanza, Send blocks until a response is
-// received and then returns a reader from which the response can be read.
// If the input stream is not being processed (a call to Serve is not running),
-// SendElement may block forever.
-// If the provided context is closed before the response is received SendElement
-// immediately returns an error and any response received at a later time will
-// not be associated with the original request.
-// The response does not need to be consumed in its entirety, but it must be
-// closed before stream processing will resume.
-// If an error is returned, xml.TokenReader will be nil; the converse is not
+// SendIQ will never receive a response and will block until the provided
+// context is canceled.
+// If the response is non-nil, it does not need to be consumed in its entirety,
+// but it must be closed before stream processing will resume.
+// If the IQ type does not require a response—ie. it is a result or error IQ,
+// meaning that it is a response itself—SendIQ does not block and the response
+// is nil.
+//
+// If the context is closed before the response is received, SendIQ immediately
+// returns the context error.
+// Any response received at a later time will not be associated with the
+// original request but can still be handled by the Serve handler.
+//
+// If an error is returned, the response will be nil; the converse is not
// necessarily true.
+// SendIQ is safe for concurrent use by multiple goroutines.
+func (s *Session) SendIQ(ctx context.Context, iq stanza.IQ, payload xml.TokenReader) (xmlstream.TokenReadCloser, error) {
+ // We need to add an id to the IQ if one wasn't already set by the user so
+ // that we can use it to associate the response with the original query.
+ if iq.ID == "" {
+ iq.ID = internal.RandomID()
+ }
+ needsResp := iq.Type == stanza.GetIQ || iq.Type == stanza.SetIQ
+
+ var c chan xmlstream.TokenReadCloser
+ if needsResp {
+ c = make(chan xmlstream.TokenReadCloser)
+
+ s.sentIQMutex.Lock()
+ s.sentIQs[iq.ID] = c
+ s.sentIQMutex.Unlock()
+ defer func() {
+ s.sentIQMutex.Lock()
+ delete(s.sentIQs, iq.ID)
+ s.sentIQMutex.Unlock()
+ }()
+ }
+
+ err := s.Send(ctx, stanza.WrapIQ(iq, payload))
+ if err != nil {
+ return nil, err
+ }
+
+ // If this is not an IQ of type "set" or "get" we don't expect a response and
+ // merely transmit the information, so don't block.
+ if !needsResp {
+ return nil, nil
+ }
+
+ select {
+ case rr := <-c:
+ return rr, nil
+ case <-ctx.Done():
+ close(c)
+ return nil, ctx.Err()
+ }
+}
+
+// SendElement is like Send except that it uses start as the outermost tag in
+// the encoding.
//
// SendElement is safe for concurrent use by multiple goroutines.
-func (s *Session) SendElement(ctx context.Context, r xml.TokenReader, start xml.StartElement) (xmlstream.TokenReadCloser, error) {
+func (s *Session) SendElement(ctx context.Context, r xml.TokenReader, start xml.StartElement) error {
s.out.Lock()
defer s.out.Unlock()
if start.Name.Local == "" {
tok, err := r.Token()
if err != nil {
- return nil, err
+ return err
}
var ok bool
start, ok = tok.(xml.StartElement)
if !ok {
- return nil, errNotStart
+ return errNotStart
}
}
- // If this is not an IQ (or is an IQ that's not of type "set" or "get") we
- // don't expect a response and merely transmit the information.
- if !isIQ(start.Name) || !iqNeedsResp(start.Attr) {
- err := s.EncodeToken(start)
- if err != nil {
- return nil, err
- }
- _, err = xmlstream.Copy(s, xmlstream.Inner(r))
- if err != nil {
- return nil, err
- }
- err = s.EncodeToken(start.End())
- if err != nil {
- return nil, err
- }
- return nil, s.Flush()
- }
-
- // We need to add an id to the IQ if one wasn't already set by the user so
- // that we can use it to associate the response with the original query.
- id := getID(start)
- if id == "" {
- id = internal.RandomID()
- start.Attr = append(start.Attr, xml.Attr{Name: xml.Name{Local: "id"}, Value: id})
- }
-
- c := make(chan xmlstream.TokenReadCloser)
- s.sentIQs[id] = c
-
err := s.EncodeToken(start)
if err != nil {
- return nil, err
+ return err
}
_, err = xmlstream.Copy(s, xmlstream.Inner(r))
if err != nil {
- return nil, err
+ return err
}
err = s.EncodeToken(start.End())
if err != nil {
- return nil, err
- }
- err = s.Flush()
- if err != nil {
- return nil, err
- }
-
- select {
- case rr := <-c:
- delete(s.sentIQs, id)
- return rr, nil
- case <-ctx.Done():
- delete(s.sentIQs, id)
- close(c)
- return nil, ctx.Err()
+ return err
}
+ return s.Flush()
}
// closeInputStream immediately marks the input stream as closed and cancels any