@@ 17,6 17,7 @@ import (
// TokenReader returns a reader for the XML encoding of v.
func TokenReader(v interface{}) (xml.TokenReader, error) {
+ // If the payload to marshal is already a TokenReader, just return it.
if r, ok := v.(xml.TokenReader); ok {
return r, nil
}
@@ 5,9 5,13 @@
package marshal_test
import (
+ "bytes"
"encoding/xml"
+ "fmt"
+ "reflect"
"testing"
+ "mellium.im/xmlstream"
"mellium.im/xmpp/internal/marshal"
"mellium.im/xmpp/stanza"
)
@@ 25,7 29,66 @@ func (wf *testWriteFlusher) Flush() error {
return nil
}
-func TestFlushes(t *testing.T) {
+type noopTokenWriter struct{}
+
+func (noopTokenWriter) EncodeToken(xml.Token) error { return nil }
+
+type errTokenWriter struct{}
+
+func (errTokenWriter) EncodeToken(xml.Token) error { return bytes.ErrTooLarge }
+
+var simpleIn = struct {
+ XMLName xml.Name `xml:"space local"`
+}{}
+
+var encodeTestCases = [...]struct {
+ w xmlstream.TokenWriter
+ v interface{}
+ err error
+ errType error
+}{
+ 0: {
+ w: nil,
+ v: struct{}{},
+ errType: &xml.UnsupportedTypeError{},
+ },
+ 1: {
+ w: noopTokenWriter{},
+ v: simpleIn,
+ },
+ 2: {
+ w: errTokenWriter{},
+ v: simpleIn,
+ err: bytes.ErrTooLarge,
+ },
+}
+
+func TestEncode(t *testing.T) {
+ for i, tc := range encodeTestCases {
+ t.Run(fmt.Sprintf("%d/EncodeXML", i), func(t *testing.T) {
+ err := marshal.EncodeXML(tc.w, tc.v)
+ switch {
+ case tc.err != nil && err != tc.err:
+ t.Errorf("unexpected error: want=%v, got=%v", tc.err, err)
+ case tc.errType != nil && reflect.TypeOf(err) != reflect.TypeOf(tc.errType):
+ t.Errorf("error of wrong type: want=%T, got=%T", tc.errType, err)
+ }
+ })
+ t.Run(fmt.Sprintf("%d/EncodeXMLElement", i), func(t *testing.T) {
+ err := marshal.EncodeXMLElement(tc.w, tc.v, xml.StartElement{
+ Name: xml.Name{Space: "space", Local: "local"},
+ })
+ switch {
+ case tc.err != nil && err != tc.err:
+ t.Errorf("unexpected error: want=%v, got=%v", tc.err, err)
+ case tc.errType != nil && reflect.TypeOf(err) != reflect.TypeOf(tc.errType):
+ t.Errorf("error of wrong type: want=%T, got=%T", tc.errType, err)
+ }
+ })
+ }
+}
+
+func TestFlusher(t *testing.T) {
t.Run("EncodeXML", func(t *testing.T) {
f := &testWriteFlusher{}
if err := marshal.EncodeXML(f, 1); err != nil {
@@ 50,7 113,6 @@ func TestFlushes(t *testing.T) {
}
func TestMarshalTokenReader(t *testing.T) {
- // If the payload to marshal is already a TokenReader, just return it.
r := stanza.IQ{}.Wrap(nil)
rr, err := marshal.TokenReader(r)
if err != nil {