~samwhited/xmpp

16fcfed077a257f695ebd100e188e0ceaa5049c7 — Sam Whited 2 months ago 6026cbc
internal/stream: remove s2s bool from stream send

Previously we set the namespace depending on whether a server-to-server
value was set. However, the namespace is set in the stream and may be
other values (such as the Jabber Component namespace) as well, so just
leave this alone and let the negotiator handle it.

Signed-off-by: Sam Whited <sam@samwhited.com>
M blocklist/blocking_test.go => blocklist/blocking_test.go +1 -1
@@ 190,7 190,7 @@ func (errReadWriter) Read([]byte) (int, error) {
}

func TestErroredDoesNotPanic(t *testing.T) {
	s := xmpptest.NewSession(0, errReadWriter{})
	s := xmpptest.NewClientSession(0, errReadWriter{})
	iter := blocklist.Fetch(context.Background(), s)
	if iter.Next() {
		t.Errorf("expected false from call to next")

M internal/stream/stream.go => internal/stream/stream.go +1 -8
@@ 26,14 26,7 @@ import (
// is much faster than encoding.
// Afterwards, clear the StreamRestartRequired bit and set the output stream
// information.
func Send(rw io.ReadWriter, streamData *stream.Info, s2s, ws bool, version stream.Version, lang string, to, from, id string) error {
	switch s2s {
	case true:
		streamData.XMLNS = ns.Server
	case false:
		streamData.XMLNS = ns.Client
	}

func Send(rw io.ReadWriter, streamData *stream.Info, ws bool, version stream.Version, lang, to, from, id string) error {
	streamData.ID = id
	b := bufio.NewWriter(rw)
	var err error

M internal/stream/stream_test.go => internal/stream/stream_test.go +8 -2
@@ 12,6 12,7 @@ import (
	"testing"

	"mellium.im/xmpp/internal/decl"
	"mellium.im/xmpp/internal/ns"
	intstream "mellium.im/xmpp/internal/stream"
	"mellium.im/xmpp/stream"
)


@@ 36,7 37,12 @@ func TestSendNewS2S(t *testing.T) {
				ids = "abc"
			}
			out := stream.Info{}
			err := intstream.Send(&b, &out, tc.s2s, false, stream.Version{Major: 1, Minor: 0}, "und", "example.net", "test@example.net", ids)
			if tc.s2s {
				out.XMLNS = ns.Server
			} else {
				out.XMLNS = ns.Client
			}
			err := intstream.Send(&b, &out, false, stream.Version{Major: 1, Minor: 0}, "und", "example.net", "test@example.net", ids)

			str := b.String()
			if !strings.HasPrefix(str, decl.XMLHeader) {


@@ 86,7 92,7 @@ func TestSendNewS2SReturnsWriteErr(t *testing.T) {
	}{
		nopReader{},
		errWriter{},
	}, &out, true, false, stream.Version{Major: 1, Minor: 0}, "und", "example.net", "test@example.net", "abc")
	}, &out, false, stream.Version{Major: 1, Minor: 0}, "und", "example.net", "test@example.net", "abc")
	if err != io.ErrUnexpectedEOF {
		t.Errorf("Expected errWriterErr (%s) but got `%s`", io.ErrUnexpectedEOF, err)
	}

M internal/xmpptest/features.go => internal/xmpptest/features.go +1 -1
@@ 66,7 66,7 @@ func RunFeatureTests(t *testing.T, tcs []FeatureTestCase) {
			}

			buf.Reset()
			s := NewSession(tc.State, struct {
			s := NewClientSession(tc.State, struct {
				io.Reader
				io.Writer
			}{

M internal/xmpptest/session.go => internal/xmpptest/session.go +22 -11
@@ 22,7 22,7 @@ import (
// pops the first token (likely <stream:stream>) but does not perform any
// validation on the token, transmit any data over the wire, or perform any
// other session negotiation.
func NopNegotiator(state xmpp.SessionState) xmpp.Negotiator {
func NopNegotiator(state xmpp.SessionState, streamNS string) xmpp.Negotiator {
	return func(ctx context.Context, in, out *stream.Info, s *xmpp.Session, data interface{}) (xmpp.SessionState, io.ReadWriter, interface{}, error) {
		// Pop the stream start token.
		rc := s.TokenReader()


@@ 32,24 32,35 @@ func NopNegotiator(state xmpp.SessionState) xmpp.Negotiator {
		if err != nil {
			return state | xmpp.Ready, nil, nil, err
		}
		out.XMLNS = streamNS
		err = intstream.Send(struct {
			io.Reader
			io.Writer
		}{
			Writer: io.Discard,
		}, out, s.State()&xmpp.S2S == xmpp.S2S, false, stream.DefaultVersion, "", "example.net", "test@example.net", "123")
		}, out, false, stream.DefaultVersion, "", "example.net", "test@example.net", "123")

		return state | xmpp.Ready, nil, nil, err
	}
}

// NewSession returns a new client-to-client XMPP session with the state bits
// set to finalState|xmpp.Ready, the origin JID set to "test@example.net" and
// the location JID set to "example.net".
// NewClientSession returns a new client-to-client XMPP session with the state
// bits set to finalState|xmpp.Ready, the origin JID set to "test@example.net"
// and the location JID set to "example.net".
//
// NewSession panics on error for ease of use in testing, where a panic is
// NewClientSession panics on error for ease of use in testing, where a panic is
// acceptable.
func NewSession(finalState xmpp.SessionState, rw io.ReadWriter) *xmpp.Session {
func NewClientSession(finalState xmpp.SessionState, rw io.ReadWriter) *xmpp.Session {
	return newSession(finalState, rw, ns.Client)
}

// NewServerSession is like NewClientSession except that the stream uses the
// server-to-server namespace.
func NewServerSession(finalState xmpp.SessionState, rw io.ReadWriter) *xmpp.Session {
	return newSession(finalState, rw, ns.Server)
}

func newSession(finalState xmpp.SessionState, rw io.ReadWriter, streamNS string) *xmpp.Session {
	location := jid.MustParse("example.net")
	origin := jid.MustParse("test@example.net")



@@ 65,14 76,14 @@ func NewSession(finalState xmpp.SessionState, rw io.ReadWriter) *xmpp.Session {
			io.Writer
		}{
			Reader: io.MultiReader(
				strings.NewReader(`<stream:stream from="`+from.String()+`" to="`+to.String()+`" id="123" version="1.0" xmlns="`+ns.Client+`" xmlns:stream="`+stream.NS+`">`),
				strings.NewReader(`<stream:stream from="`+from.String()+`" to="`+to.String()+`" id="123" version="1.0" xmlns="`+streamNS+`" xmlns:stream="`+stream.NS+`">`),
				rw,
				strings.NewReader(`</stream:stream>`),
			),
			Writer: rw,
		},
		0,
		NopNegotiator(finalState),
		NopNegotiator(finalState, streamNS),
	)
	if err != nil {
		panic(err)


@@ 148,8 159,8 @@ func NewClientServer(opts ...Option) *ClientServer {
	}

	clientConn, serverConn := net.Pipe()
	cs.Client = NewSession(cs.clientState, clientConn)
	cs.Server = NewSession(cs.serverState, serverConn)
	cs.Client = NewClientSession(cs.clientState, clientConn)
	cs.Server = NewServerSession(cs.serverState, serverConn)
	/* #nosec */
	go cs.Client.Serve(cs.clientHandler)
	/* #nosec */

M internal/xmpptest/session_test.go => internal/xmpptest/session_test.go +1 -1
@@ 19,7 19,7 @@ import (
func TestNewSession(t *testing.T) {
	state := xmpp.Secure | xmpp.InputStreamClosed
	buf := new(bytes.Buffer)
	s := xmpptest.NewSession(state, buf)
	s := xmpptest.NewClientSession(state, buf)

	if mask := s.State(); mask != state|xmpp.Ready {
		t.Errorf("Got invalid state value: want=%d, got=%d", state, mask)

M mux/mux_test.go => mux/mux_test.go +1 -1
@@ 544,7 544,7 @@ func TestFallback(t *testing.T) {
		Reader: strings.NewReader(`<iq xmlns="jabber:client" to="romeo@example.com" from="juliet@example.com" id="123"><test/></iq>`),
		Writer: buf,
	}
	s := xmpptest.NewSession(0, rw)
	s := xmpptest.NewClientSession(0, rw)

	r := s.TokenReader()
	defer r.Close()

M negotiator.go => negotiator.go +13 -2
@@ 10,6 10,7 @@ import (
	"io"

	"mellium.im/xmpp/internal/attr"
	"mellium.im/xmpp/internal/ns"
	intstream "mellium.im/xmpp/internal/stream"
	"mellium.im/xmpp/internal/wskey"
	"mellium.im/xmpp/jid"


@@ 158,7 159,12 @@ func negotiator(f func(*Session, *StreamConfig) StreamConfig) Negotiator {
				location = in.To
				origin = in.From

				err = intstream.Send(s.Conn(), out, s.State()&S2S == S2S, websocket, stream.DefaultVersion, cfg.Lang, origin.String(), location.String(), attr.RandomID())
				if s.State()&S2S == S2S {
					out.XMLNS = ns.Server
				} else {
					out.XMLNS = ns.Client
				}
				err = intstream.Send(s.Conn(), out, websocket, stream.DefaultVersion, cfg.Lang, origin.String(), location.String(), attr.RandomID())
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err


@@ 168,7 174,12 @@ func negotiator(f func(*Session, *StreamConfig) StreamConfig) Negotiator {
				// one in response.
				origin := s.LocalAddr()
				location := s.RemoteAddr()
				err = intstream.Send(s.Conn(), out, s.State()&S2S == S2S, websocket, stream.DefaultVersion, cfg.Lang, location.String(), origin.String(), "")
				if s.State()&S2S == S2S {
					out.XMLNS = ns.Server
				} else {
					out.XMLNS = ns.Client
				}
				err = intstream.Send(s.Conn(), out, websocket, stream.DefaultVersion, cfg.Lang, location.String(), origin.String(), "")
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err

M receipts/receipts_test.go => receipts/receipts_test.go +2 -2
@@ 147,7 147,7 @@ func TestClosedDoesNotPanic(t *testing.T) {
	h := &receipts.Handler{}

	bw := &bytes.Buffer{}
	s := xmpptest.NewSession(0, bw)
	s := xmpptest.NewClientSession(0, bw)
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	err := h.SendMessageElement(ctx, s, nil, stanza.Message{


@@ 188,7 188,7 @@ func TestRoundTrip(t *testing.T) {
	h := &receipts.Handler{}

	var req bytes.Buffer
	s := xmpptest.NewSession(0, &req)
	s := xmpptest.NewClientSession(0, &req)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

M roster/roster_test.go => roster/roster_test.go +1 -1
@@ 225,7 225,7 @@ func (errReadWriter) Read([]byte) (int, error) {
}

func TestErroredDoesNotPanic(t *testing.T) {
	s := xmpptest.NewSession(0, errReadWriter{})
	s := xmpptest.NewClientSession(0, errReadWriter{})
	iter := roster.Fetch(context.Background(), s)
	if iter.Next() {
		t.Errorf("expected false from call to next")

M session_test.go => session_test.go +29 -15
@@ 33,7 33,7 @@ func TestClosedInputStream(t *testing.T) {
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			mask := xmpp.SessionState(i)
			buf := new(bytes.Buffer)
			s := xmpptest.NewSession(mask, buf)
			s := xmpptest.NewClientSession(mask, buf)
			r := s.TokenReader()
			defer r.Close()



@@ 53,7 53,7 @@ func TestClosedOutputStream(t *testing.T) {
		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
			mask := xmpp.SessionState(i)
			buf := new(bytes.Buffer)
			s := xmpptest.NewSession(mask, buf)
			s := xmpptest.NewClientSession(mask, buf)
			w := s.TokenWriter()
			defer w.Close()



@@ 182,6 182,7 @@ var serveTests = [...]struct {
	err          error
	errStringCmp bool
	state        xmpp.SessionState
	serverNS     bool
}{
	0: {
		in:  `<test></test>`,


@@ 343,9 344,10 @@ var serveTests = [...]struct {
			}.Wrap(nil))
			return err
		}),
		in:    `<iq type="get" id="1234"><unknownpayload xmlns="unknown"/></iq>`,
		out:   `<iq xmlns="jabber:server" type="result" id="1234" from="test@example.net"></iq></stream:stream>`,
		state: xmpp.S2S,
		in:       `<iq type="get" id="1234"><unknownpayload xmlns="unknown"/></iq>`,
		out:      `<iq xmlns="jabber:server" type="result" id="1234" from="test@example.net"></iq></stream:stream>`,
		state:    xmpp.S2S,
		serverNS: true,
	},
	16: {
		// S2S stanzas always have "from" set, unless it was already set.


@@ 357,9 359,10 @@ var serveTests = [...]struct {
			}.Wrap(nil))
			return err
		}),
		in:    `<iq type="get" id="1234"><unknownpayload xmlns="unknown"/></iq>`,
		out:   `<iq xmlns="jabber:server" type="result" from="from@example.net" id="1234"></iq></stream:stream>`,
		state: xmpp.S2S,
		in:       `<iq type="get" id="1234"><unknownpayload xmlns="unknown"/></iq>`,
		out:      `<iq xmlns="jabber:server" type="result" from="from@example.net" id="1234"></iq></stream:stream>`,
		state:    xmpp.S2S,
		serverNS: true,
	},
}



@@ 368,13 371,24 @@ func TestServe(t *testing.T) {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			out := &bytes.Buffer{}
			in := strings.NewReader(tc.in)
			s := xmpptest.NewSession(tc.state, struct {
				io.Reader
				io.Writer
			}{
				Reader: in,
				Writer: out,
			})
			var s *xmpp.Session
			if tc.serverNS {
				s = xmpptest.NewServerSession(tc.state, struct {
					io.Reader
					io.Writer
				}{
					Reader: in,
					Writer: out,
				})
			} else {
				s = xmpptest.NewClientSession(tc.state, struct {
					io.Reader
					io.Writer
				}{
					Reader: in,
					Writer: out,
				})
			}

			err := s.Serve(tc.handler)
			switch {

M starttls_test.go => starttls_test.go +2 -2
@@ 123,7 123,7 @@ func (nopRWC) Close() error {
func TestNegotiateServer(t *testing.T) {
	stls := xmpp.StartTLS(&tls.Config{})
	var b bytes.Buffer
	c := xmpptest.NewSession(xmpp.Received, nopRWC{&b, &b})
	c := xmpptest.NewClientSession(xmpp.Received, nopRWC{&b, &b})
	_, rw, err := stls.Negotiate(context.Background(), c, nil)
	switch {
	case err != nil:


@@ 162,7 162,7 @@ func TestNegotiateClient(t *testing.T) {
			stls := xmpp.StartTLS(&tls.Config{})
			r := strings.NewReader(strings.Join(test.responses, "\n"))
			var b bytes.Buffer
			c := xmpptest.NewSession(0, nopRWC{r, &b})
			c := xmpptest.NewClientSession(0, nopRWC{r, &b})
			mask, rw, err := stls.Negotiate(context.Background(), c, nil)
			switch {
			case test.err && err == nil: