@@ 22,6 22,10 @@ func TokenReader(v interface{}) (xml.TokenReader, error) {
return r, nil
}
+ return tokenDecoder(v)
+}
+
+func tokenDecoder(v interface{}) (*xml.Decoder, error) {
var b bytes.Buffer
err := xml.NewEncoder(&b).Encode(v)
if err != nil {
@@ 30,6 34,15 @@ func TokenReader(v interface{}) (xml.TokenReader, error) {
return xml.NewDecoder(&b), nil
}
+// rawTokenReader maps a decoders RawToken method onto its Token method.
+type rawTokenReader struct {
+ *xml.Decoder
+}
+
+func (r rawTokenReader) Token() (xml.Token, error) {
+ return r.RawToken()
+}
+
// EncodeXML writes the XML encoding of v to the stream.
//
// See the documentation for xml.Marshal for details about the conversion of Go
@@ 38,11 51,11 @@ func TokenReader(v interface{}) (xml.TokenReader, error) {
// If the stream is an xmlstream.Flusher, EncodeXML calls Flush before
// returning.
func EncodeXML(w xmlstream.TokenWriter, v interface{}) error {
- r, err := TokenReader(v)
+ d, err := tokenDecoder(v)
if err != nil {
return err
}
- _, err = xmlstream.Copy(w, r)
+ _, err = xmlstream.Copy(w, rawTokenReader{Decoder: d})
if err != nil {
return err
}
@@ 62,11 75,11 @@ func EncodeXML(w xmlstream.TokenWriter, v interface{}) error {
// If the stream is an xmlstream.Flusher, EncodeXMLElement calls Flush before
// returning.
func EncodeXMLElement(w xmlstream.TokenWriter, v interface{}, start xml.StartElement) error {
- r, err := TokenReader(v)
+ d, err := tokenDecoder(v)
if err != nil {
return err
}
- _, err = xmlstream.Copy(w, xmlstream.Wrap(r, start))
+ _, err = xmlstream.Copy(w, rawTokenReader{Decoder: d})
if err != nil {
return err
}
@@ 122,3 122,24 @@ func TestMarshalTokenReader(t *testing.T) {
t.Errorf("got different xml.TokenReader out: want=%v, got=%v", r, rr)
}
}
+
+func TestTokenDecoder(t *testing.T) {
+ r := stanza.IQ{}
+ _, err := marshal.TokenReader(r)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+}
+
+func TestEncodeXMLNS(t *testing.T) {
+ var buf bytes.Buffer
+ e := xml.NewEncoder(&buf)
+ err := marshal.EncodeXML(e, simpleIn)
+ if err != nil {
+ t.Errorf("unexpected error encoding: %v", err)
+ }
+ const expected = `<local xmlns="space"></local>`
+ if s := buf.String(); s != expected {
+ t.Errorf("wrong output: want=%s, got=%s", expected, s)
+ }
+}