// Copyright 2017 The Mellium Contributors. // Use of this source code is governed by the BSD 2-clause // license that can be found in the LICENSE file. package xmpp_test import ( "bytes" "context" "encoding/xml" "errors" "fmt" "io" "math" "net" "strconv" "strings" "testing" "mellium.im/xmlstream" "mellium.im/xmpp" intstream "mellium.im/xmpp/internal/stream" "mellium.im/xmpp/internal/xmpptest" "mellium.im/xmpp/jid" "mellium.im/xmpp/stanza" "mellium.im/xmpp/stream" ) var _ fmt.Stringer = xmpp.SessionState(0) func TestClosedInputStream(t *testing.T) { for i := 0; i <= math.MaxUint8; i++ { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { mask := xmpp.SessionState(i) buf := new(bytes.Buffer) s := xmpptest.NewSession(mask, buf) r := s.TokenReader() defer r.Close() _, err := r.Token() switch { case mask&xmpp.InputStreamClosed == xmpp.InputStreamClosed && err != xmpp.ErrInputStreamClosed: t.Errorf("Unexpected error: want=`%v', got=`%v'", xmpp.ErrInputStreamClosed, err) case mask&xmpp.InputStreamClosed == 0 && err != io.EOF: t.Errorf("Unexpected error: `%v'", err) } }) } } func TestClosedOutputStream(t *testing.T) { for i := 0; i <= math.MaxUint8; i++ { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { mask := xmpp.SessionState(i) buf := new(bytes.Buffer) s := xmpptest.NewSession(mask, buf) w := s.TokenWriter() defer w.Close() switch err := w.EncodeToken(xml.CharData("chartoken")); { case mask&xmpp.OutputStreamClosed == xmpp.OutputStreamClosed && err != xmpp.ErrOutputStreamClosed: t.Errorf("Unexpected error: want=`%v', got=`%v'", xmpp.ErrOutputStreamClosed, err) case mask&xmpp.OutputStreamClosed == 0 && err != nil: t.Errorf("Unexpected error: `%v'", err) } switch err := w.Flush(); { case mask&xmpp.OutputStreamClosed == xmpp.OutputStreamClosed && err != xmpp.ErrOutputStreamClosed: t.Errorf("Unexpected error flushing: want=`%v', got=`%v'", xmpp.ErrOutputStreamClosed, err) case mask&xmpp.OutputStreamClosed == 0 && err != nil: t.Errorf("Unexpected error: `%v'", err) } }) } } func TestNilNegotiatorPanics(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected panic, did not get one") } }() xmpp.NewSession(context.Background(), jid.JID{}, jid.JID{}, nil, 0, nil) } var errTestNegotiate = errors.New("a test error") func errNegotiator(ctx context.Context, _, _ *stream.Info, session *xmpp.Session, data interface{}) (mask xmpp.SessionState, rw io.ReadWriter, cache interface{}, err error) { err = errTestNegotiate return mask, rw, cache, err } type negotiateTestCase struct { negotiator xmpp.Negotiator in string out string location jid.JID origin jid.JID err error initialState xmpp.SessionState finalState xmpp.SessionState } var readyFeature = xmpp.StreamFeature{ Name: xml.Name{Space: "urn:example", Local: "ready"}, Parse: func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) { _, err := d.Token() return false, nil, err }, Negotiate: func(ctx context.Context, session *xmpp.Session, data interface{}) (xmpp.SessionState, io.ReadWriter, error) { return xmpp.Ready, nil, nil }, } var negotiateTests = [...]negotiateTestCase{ 0: {negotiator: errNegotiator, err: errTestNegotiate}, 1: { negotiator: xmpp.NewNegotiator(xmpp.StreamConfig{ Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature { return []xmpp.StreamFeature{xmpp.StartTLS(nil)} }, }), in: ``, out: ``, err: errors.New("XML syntax error on line 1: unexpected EOF"), }, 2: { negotiator: xmpp.NewNegotiator(xmpp.StreamConfig{}), in: ``, out: ``, err: errors.New("xmpp: features advertised out of order"), }, 3: { negotiator: xmpp.NewNegotiator(xmpp.StreamConfig{ Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature { return []xmpp.StreamFeature{readyFeature} }, }), in: ``, out: ``, initialState: xmpp.S2S, finalState: xmpp.Ready | xmpp.S2S, }, } func TestNegotiator(t *testing.T) { for i, tc := range negotiateTests { t.Run(strconv.Itoa(i), func(t *testing.T) { buf := &bytes.Buffer{} rw := struct { io.Reader io.Writer }{ Reader: strings.NewReader(tc.in), Writer: buf, } session, err := xmpp.NewSession(context.Background(), tc.location, tc.origin, rw, tc.initialState, tc.negotiator) if ((err == nil || tc.err == nil) && (err != nil || tc.err != nil)) && err.Error() != tc.err.Error() { t.Errorf("unexpected error: want=%q, got=%q", tc.err, err) } if out := buf.String(); out != tc.out { t.Errorf("unexpected output:\nwant=%q,\n got=%q", tc.out, out) } if s := session.State(); s != tc.finalState { t.Errorf("unexpected state: want=%v, got=%v", tc.finalState, s) } }) } } const invalidIQ = `` var failHandler xmpp.HandlerFunc = func(r xmlstream.TokenReadEncoder, t *xml.StartElement) error { return errors.New("session_test: FAILED") } var serveTests = [...]struct { handler xmpp.Handler out string in string err error errStringCmp bool state xmpp.SessionState }{ 0: { in: ``, out: ``, }, 1: { in: `a`, out: ``, err: errors.New("xmpp: unexpected stream-level chardata"), errStringCmp: true, }, 2: { in: ``, out: invalidIQ + ``, }, 3: { handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { _, err := xmlstream.Copy(rw, stanza.IQ{ ID: "1234", Type: stanza.ResultIQ, }.Wrap(nil)) return err }), in: ``, out: ``, }, 4: { handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { _, err := xmlstream.Copy(rw, stanza.IQ{ ID: "wrongid", Type: stanza.ResultIQ, }.Wrap(nil)) return err }), in: ``, out: `` + invalidIQ + ``, }, 5: { handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { _, err := xmlstream.Copy(rw, stanza.IQ{ ID: "1234", Type: stanza.ErrorIQ, }.Wrap(nil)) return err }), in: ``, out: ``, }, 6: { handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { _, err := xmlstream.Copy(rw, stanza.IQ{ ID: "1234", Type: stanza.GetIQ, }.Wrap(nil)) return err }), in: ``, out: `` + invalidIQ + ``, }, 7: { handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { for _, attr := range start.Attr { if attr.Name.Local == "from" && attr.Value != "" { panic("expected attr to be normalized") } } return nil }), in: ``, out: ``, }, 8: { handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { for _, attr := range start.Attr { if attr.Name.Local == "from" && attr.Value == "" { panic("expected attr not to be normalized") } } return nil }), in: ``, out: ``, }, 9: { handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { for _, attr := range start.Attr { if attr.Name.Local == "from" && attr.Value == "" { panic("expected attr not to be normalized") } } return nil }), in: ``, out: ``, }, 10: { handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { for _, attr := range start.Attr { if attr.Name.Local == "from" && attr.Value == "" { panic("expected attr not to be normalized") } } return nil }), in: ``, out: ``, }, 11: { handler: failHandler, in: "\n\t \r\n \t ", out: ``, }, 12: { handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { if start.Name.Space == stream.NS || start.Name.Space == "stream" { return fmt.Errorf("handler should never receive stream namespaced elements but got %v", start) } return nil }), in: ``, out: ``, err: stream.NotWellFormed, }, 13: { handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { if start.Name.Space == stream.NS || start.Name.Space == "stream" { return fmt.Errorf("handler should never receive stream namespaced elements but got %v", start) } return nil }), in: ``, out: ``, err: intstream.ErrUnknownStreamElement, }, 14: { // Regression test to ensure that we can't advance beyond the end of the // current element and that the close element is included in the stream. handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { if start.Name.Local == "b" { return nil } err := rw.EncodeToken(*start) if err != nil { return err } _, err = xmlstream.Copy(rw, rw) return err }), in: `test`, out: `test`, }, 15: { // S2S stanzas always have "from" set if not already set. handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { _, err := xmlstream.Copy(rw, stanza.IQ{ ID: "1234", Type: stanza.ResultIQ, }.Wrap(nil)) return err }), in: ``, out: ``, state: xmpp.S2S, }, 16: { // S2S stanzas always have "from" set, unless it was already set. handler: xmpp.HandlerFunc(func(rw xmlstream.TokenReadEncoder, start *xml.StartElement) error { _, err := xmlstream.Copy(rw, stanza.IQ{ ID: "1234", From: jid.MustParse("from@example.net"), Type: stanza.ResultIQ, }.Wrap(nil)) return err }), in: ``, out: ``, state: xmpp.S2S, }, } func TestServe(t *testing.T) { for i, tc := range serveTests { t.Run(strconv.Itoa(i), func(t *testing.T) { out := &bytes.Buffer{} in := strings.NewReader(tc.in) s := xmpptest.NewSession(tc.state, struct { io.Reader io.Writer }{ Reader: in, Writer: out, }) err := s.Serve(tc.handler) switch { case tc.errStringCmp && err.Error() != tc.err.Error(): t.Errorf("unexpected error: want=%v, got=%v", tc.err, err) case !tc.errStringCmp && !errors.Is(err, tc.err): t.Errorf("unexpected error: want=%v, got=%v", tc.err, err) } if s := out.String(); s != tc.out { t.Errorf("unexpected output:\nwant=%s,\n got=%s", tc.out, s) } if l := in.Len(); l != 0 { t.Errorf("did not finish read, %d bytes left", l) } }) } } func errorStartTLS(err error) xmpp.StreamFeature { startTLS := xmpp.StartTLS(nil) startTLS.Negotiate = func(ctx context.Context, session *xmpp.Session, data interface{}) (xmpp.SessionState, io.ReadWriter, error) { session.Encode(ctx, err) return 0, nil, nil } return startTLS } func TestNegotiateStreamError(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() clientConn, serverConn := net.Pipe() clientJID := jid.MustParse("me@example.net") semaphore := make(chan struct{}) go func() { defer close(semaphore) _, err := xmpp.ReceiveSession(ctx, serverConn, 0, xmpp.NewNegotiator(xmpp.StreamConfig{ Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature { return []xmpp.StreamFeature{errorStartTLS(stream.Conflict)} }, })) if err != nil { t.Logf("error receiving session: %v", err) } }() _, err := xmpp.NewSession(ctx, clientJID, clientJID.Bare(), clientConn, 0, xmpp.NewNegotiator(xmpp.StreamConfig{ Features: func(*xmpp.Session, ...xmpp.StreamFeature) []xmpp.StreamFeature { return []xmpp.StreamFeature{xmpp.StartTLS(nil)} }, })) if !errors.Is(err, stream.Conflict) { t.Errorf("unexpected client err: want=%v, got=%v", stream.Conflict, err) } <-semaphore }