M internal/marshal/encode.go => internal/marshal/encode.go +4 -0
@@ 17,6 17,10 @@ import (
// TokenReader returns a reader for the XML encoding of v.
func TokenReader(v interface{}) (xml.TokenReader, error) {
+ if r, ok := v.(xml.TokenReader); ok {
+ return r, nil
+ }
+
var b bytes.Buffer
err := xml.NewEncoder(&b).Encode(v)
if err != nil {
M internal/marshal/encode_test.go => internal/marshal/encode_test.go +13 -0
@@ 9,6 9,7 @@ import (
"testing"
"mellium.im/xmpp/internal/marshal"
+ "mellium.im/xmpp/stanza"
)
type testWriteFlusher struct {
@@ 47,3 48,15 @@ func TestFlushes(t *testing.T) {
}
})
}
+
+func TestMarshalTokenReader(t *testing.T) {
+ // If the payload to marshal is already a TokenReader, just return it.
+ r := stanza.IQ{}.Wrap(nil)
+ rr, err := marshal.TokenReader(r)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ if r != rr {
+ t.Errorf("got different xml.TokenReader out: want=%v, got=%v", r, rr)
+ }
+}