// Copyright 2021 The Mellium Contributors. // Use of this source code is governed by the BSD 2-clause // license that can be found in the LICENSE file. package delay_test import ( "encoding/xml" "strconv" "strings" "testing" "time" "mellium.im/xmlstream" "mellium.im/xmpp/delay" "mellium.im/xmpp/jid" ) var ( _ xml.Marshaler = delay.Delay{} _ xmlstream.Marshaler = delay.Delay{} _ xmlstream.WriterTo = delay.Delay{} _ xml.Unmarshaler = (*delay.Delay)(nil) ) var insertTestCases = [...]struct { in string out string stanza bool }{ 0: {}, 1: { stanza: true, in: ``, out: `foo`, }, 2: { stanza: true, in: `test`, out: `foofootest`, }, 3: { stanza: true, in: ``, out: ``, }, 4: { in: ``, out: `foo`, }, 5: { in: `test`, out: `foofootest`, }, 6: { in: ``, out: `foo`, }, } func TestInsert(t *testing.T) { for i, tc := range insertTestCases { t.Run(strconv.Itoa(i), func(t *testing.T) { d := delay.Delay{From: jid.MustParse("me@example.net"), Time: time.Time{}, Reason: "foo"} var r xml.TokenReader if tc.stanza { stanzaDelayer := delay.Stanza(d) r = stanzaDelayer(xml.NewDecoder(strings.NewReader(tc.in))) } else { r = delay.Insert(d)(xml.NewDecoder(strings.NewReader(tc.in))) } // Prevent duplicate xmlns attributes. See https://mellium.im/issue/75 r = xmlstream.RemoveAttr(func(start xml.StartElement, attr xml.Attr) bool { return (start.Name.Local == "message" || start.Name.Local == "iq") && attr.Name.Local == "xmlns" })(r) var buf strings.Builder e := xml.NewEncoder(&buf) _, err := xmlstream.Copy(e, r) if err != nil { t.Fatalf("error encoding: %v", err) } if err = e.Flush(); err != nil { t.Fatalf("error flushing: %v", err) } if out := buf.String(); tc.out != out { t.Errorf("wrong output:\nwant=%s,\n got=%s", tc.out, out) } }) } } var marshalTests = [...]struct { unmarshal bool // true if we should only unmarshal for this test. in delay.Delay out string }{ 0: { out: ``, }, 1: { in: delay.Delay{From: jid.MustParse("me@example.net")}, out: ``, }, 2: { in: delay.Delay{Time: time.Time{}.Add(24 * time.Hour)}, out: ``, }, 3: { in: delay.Delay{Reason: "foo"}, out: `foo`, }, 4: { in: delay.Delay{From: jid.MustParse("me@example.net"), Time: time.Time{}.Add(24 * time.Hour), Reason: "foo"}, out: `foo`, }, 5: { unmarshal: true, in: delay.Delay{Time: time.Time{}.Add(24 * time.Hour), Reason: "foo"}, out: `foo`, }, } func TestMarshal(t *testing.T) { for i, tc := range marshalTests { t.Run(strconv.Itoa(i), func(t *testing.T) { if !tc.unmarshal { b, err := xml.Marshal(tc.in) if err != nil { t.Fatalf("unexpected error marshaling: %v", err) } if out := string(b); out != tc.out { t.Fatalf("wrong value:\nwant=%v,\n got=%v", tc.out, out) } } d := delay.Delay{} err := xml.Unmarshal([]byte(tc.out), &d) if err != nil { t.Fatalf("error unmarshaling: %v", err) } if !d.From.Equal(tc.in.From) { t.Errorf("wrong from JID: want=%v, got=%v", tc.in.From, d.From) } if !d.Time.Equal(tc.in.Time) { t.Errorf("wrong timestamp: want=%v, got=%v", tc.in.Time, d.Time) } if d.Reason != tc.in.Reason { t.Errorf("wrong chardata: want=%q, got=%q", tc.in.Reason, d.Reason) } }) } }