// Copyright 2021 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.
package delay_test
import (
"encoding/xml"
"strconv"
"strings"
"testing"
"time"
"mellium.im/xmlstream"
"mellium.im/xmpp/delay"
"mellium.im/xmpp/jid"
)
// Compile-time checks: a delay.Delay value must satisfy the marshaling
// interfaces, and *delay.Delay must satisfy xml.Unmarshaler. If any method
// set changes, this file fails to build.
var _ xml.Marshaler = delay.Delay{}
var _ xmlstream.Marshaler = delay.Delay{}
var _ xmlstream.WriterTo = delay.Delay{}
var _ xml.Unmarshaler = (*delay.Delay)(nil)
// insertTestCases drives TestInsert. Each case's in is decoded and passed
// through delay.Stanza (stanza true) or delay.Insert (stanza false). The
// re-encoded result is then compared against out. Subtest names are the array
// indices, so keep the explicit index keys stable when adding cases.
//
// NOTE(review): several in/out literals look like they have lost their XML
// markup. For example, cases 1 and 3 both feed an empty stanza input but
// expect different outputs. Confirm against version control.
var insertTestCases = [...]struct {
	in     string // XML fed to the delay transformer.
	out    string // Expected encoding after the delay has been applied.
	stanza bool   // Use delay.Stanza instead of delay.Insert.
}{
	0: {},
	1: {
		stanza: true,
		in:     ``,
		out:    `foo`,
	},
	2: {
		stanza: true,
		in:     `test`,
		out:    `foofootest`,
	},
	3: {
		stanza: true,
		in:     ``,
		out:    ``,
	},
	4: {
		in:  ``,
		out: `foo`,
	},
	5: {
		in:  `test`,
		out: `foofootest`,
	},
	6: {
		in:  ``,
		out: `foo`,
	},
}
func TestInsert(t *testing.T) {
for i, tc := range insertTestCases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
d := delay.Delay{From: jid.MustParse("me@example.net"), Time: time.Time{}, Reason: "foo"}
var r xml.TokenReader
if tc.stanza {
stanzaDelayer := delay.Stanza(d)
r = stanzaDelayer(xml.NewDecoder(strings.NewReader(tc.in)))
} else {
r = delay.Insert(d)(xml.NewDecoder(strings.NewReader(tc.in)))
}
// Prevent duplicate xmlns attributes. See https://mellium.im/issue/75
r = xmlstream.RemoveAttr(func(start xml.StartElement, attr xml.Attr) bool {
return (start.Name.Local == "message" || start.Name.Local == "iq") && attr.Name.Local == "xmlns"
})(r)
var buf strings.Builder
e := xml.NewEncoder(&buf)
_, err := xmlstream.Copy(e, r)
if err != nil {
t.Fatalf("error encoding: %v", err)
}
if err = e.Flush(); err != nil {
t.Fatalf("error flushing: %v", err)
}
if out := buf.String(); tc.out != out {
t.Errorf("wrong output:\nwant=%s,\n got=%s", tc.out, out)
}
})
}
}
// marshalTests drives TestMarshal. Unless unmarshal is set, in is marshaled
// with xml.Marshal and the result must equal out. In every case out is then
// unmarshaled and the result must match in's From, Time and Reason.
//
// NOTE(review): the out literals appear to have lost their XML markup. For
// example, case 4 expects a Delay with From and Time set to round-trip through
// the plain text `foo`, which cannot carry those fields. Confirm against
// version control.
var marshalTests = [...]struct {
	unmarshal bool // true if we should only unmarshal for this test (skip the marshal comparison).
	in        delay.Delay
	out       string
}{
	0: {
		out: ``,
	},
	1: {
		in:  delay.Delay{From: jid.MustParse("me@example.net")},
		out: ``,
	},
	2: {
		in:  delay.Delay{Time: time.Time{}.Add(24 * time.Hour)},
		out: ``,
	},
	3: {
		in:  delay.Delay{Reason: "foo"},
		out: `foo`,
	},
	4: {
		in:  delay.Delay{From: jid.MustParse("me@example.net"), Time: time.Time{}.Add(24 * time.Hour), Reason: "foo"},
		out: `foo`,
	},
	5: {
		unmarshal: true,
		in:        delay.Delay{Time: time.Time{}.Add(24 * time.Hour), Reason: "foo"},
		out:       `foo`,
	},
}
// TestMarshal checks, for each marshalTests entry, that tc.in encodes to
// tc.out (unless tc.unmarshal is set) and that decoding tc.out yields a Delay
// with the same From, Time and Reason as tc.in.
func TestMarshal(t *testing.T) {
	for i, tc := range marshalTests {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			if !tc.unmarshal {
				b, err := xml.Marshal(tc.in)
				if err != nil {
					t.Fatalf("unexpected error marshaling: %v", err)
				}
				// Errorf, not Fatalf: the decode checks below work from tc.out,
				// not b, so they still report useful results after a mismatch.
				if out := string(b); out != tc.out {
					t.Errorf("wrong value:\nwant=%q,\n got=%q", tc.out, out)
				}
			}

			var d delay.Delay
			if err := xml.Unmarshal([]byte(tc.out), &d); err != nil {
				t.Fatalf("error unmarshaling: %v", err)
			}
			if !d.From.Equal(tc.in.From) {
				t.Errorf("wrong from JID: want=%v, got=%v", tc.in.From, d.From)
			}
			if !d.Time.Equal(tc.in.Time) {
				t.Errorf("wrong timestamp: want=%v, got=%v", tc.in.Time, d.Time)
			}
			if d.Reason != tc.in.Reason {
				t.Errorf("wrong chardata: want=%q, got=%q", tc.in.Reason, d.Reason)
			}
		})
	}
}