~samwhited/xmpp

723a3da2ab202d15651ed8414465670b138cf488 — Sam Whited 4 years ago 1ec4438
internal: move stream send/recv logic
3 files changed, 84 insertions(+), 86 deletions(-)

R stream.go => internal/stream.go
R stream_test.go => internal/stream_test.go
M session.go
R stream.go => internal/stream.go +35 -72
@@ 2,7 2,7 @@
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp
package internal

import (
	"context"


@@ 11,58 11,29 @@ import (
	"io"

	"golang.org/x/text/language"
	"mellium.im/xmpp/internal"

	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/stream"
)

const (
	xmlHeader = `<?xml version="1.0" encoding="UTF-8"?>`
	XMLHeader = `<?xml version="1.0" encoding="UTF-8"?>`
)

func negotiator(ctx context.Context, s *Session, doRestart interface{}) (mask SessionState, rw io.ReadWriter, restartNext interface{}, err error) {
	// Loop for as long as we're not done negotiating features or a stream restart
	// is still required.
	if rst, ok := doRestart.(bool); ok && rst {
		if (s.state & Received) == Received {
			// If we're the receiving entity wait for a new stream, then send one in
			// response.
			if err = expectNewStream(ctx, s); err != nil {
				return mask, nil, false, err
			}
			if err = sendNewStream(s, s.config, internal.RandomID()); err != nil {
				return mask, nil, false, err
			}
		} else {
			// If we're the initiating entity, send a new stream and then wait for
			// one in response.
			if err = sendNewStream(s, s.config, ""); err != nil {
				return mask, nil, false, err
			}
			if err = expectNewStream(ctx, s); err != nil {
				return mask, nil, false, err
			}
		}
	}

	mask, rw, err = negotiateFeatures(ctx, s)
	return mask, rw, rw != nil, err
}

type streamInfo struct {
type StreamInfo struct {
	to      *jid.JID
	from    *jid.JID
	id      string
	version internal.Version
	version Version
	xmlns   string
	lang    language.Tag
}

// This MUST only return stream errors.
// TODO: Is the above true? Just make it return a StreamError?
func streamFromStartElement(s xml.StartElement) (streamInfo, error) {
	streamData := streamInfo{}
func streamFromStartElement(s xml.StartElement) (StreamInfo, error) {
	streamData := StreamInfo{}
	for _, attr := range s.Attr {
		switch attr.Name {
		case xml.Name{Space: "", Local: "to"}:


@@ 101,14 72,9 @@ func streamFromStartElement(s xml.StartElement) (streamInfo, error) {
// because we can guarantee well-formedness of the XML with a print in this case
// and printing is much faster than encoding. Afterwards, clear the
// StreamRestartRequired bit and set the output stream information.
func sendNewStream(s *Session, cfg *Config, id string) error {
	streamData := streamInfo{
		to:      cfg.Location,
		from:    cfg.Origin,
		lang:    cfg.Lang,
		version: cfg.Version,
	}
	switch cfg.S2S {
func SendNewStream(rw io.ReadWriter, s2s bool, version Version, lang language.Tag, location, origin, id string) (StreamInfo, error) {
	streamData := StreamInfo{}
	switch s2s {
	case true:
		streamData.xmlns = ns.Server
	case false:


@@ 122,36 88,34 @@ func sendNewStream(s *Session, cfg *Config, id string) error {
		id = ` id='` + id + `' `
	}

	_, err := fmt.Fprintf(s.Conn(),
		xmlHeader+`<stream:stream%sto='%s' from='%s' version='%s' xml:lang='%s' xmlns='%s' xmlns:stream='http://etherx.jabber.org/streams'>`,
	_, err := fmt.Fprintf(rw,
		XMLHeader+`<stream:stream%sto='%s' from='%s' version='%s' xml:lang='%s' xmlns='%s' xmlns:stream='http://etherx.jabber.org/streams'>`,
		id,
		cfg.Location.String(),
		cfg.Origin.String(),
		cfg.Version,
		cfg.Lang,
		location,
		origin,
		version,
		lang,
		streamData.xmlns,
	)
	if err != nil {
		return err
		return streamData, err
	}

	s.out.streamInfo = streamData
	return nil
	return streamData, nil
}

func expectNewStream(ctx context.Context, s *Session) error {
func ExpectNewStream(ctx context.Context, d xml.TokenReader, recv bool) (streamData StreamInfo, err error) {
	var foundHeader bool

	d := s.in.d
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
			return streamData, ctx.Err()
		default:
		}
		t, err := d.Token()
		if err != nil {
			return err
			return streamData, err
		}
		switch tok := t.(type) {
		case xml.StartElement:


@@ 159,40 123,39 @@ func expectNewStream(ctx context.Context, s *Session) error {
			case tok.Name.Local == "error" && tok.Name.Space == ns.Stream:
				se := stream.Error{}
				if err := xml.NewTokenDecoder(d).DecodeElement(&se, &tok); err != nil {
					return err
					return streamData, err
				}
				return se
				return streamData, se
			case tok.Name.Local != "stream":
				return stream.BadFormat
				return streamData, stream.BadFormat
			case tok.Name.Space != ns.Stream:
				return stream.InvalidNamespace
				return streamData, stream.InvalidNamespace
			}

			streamData, err := streamFromStartElement(tok)
			streamData, err = streamFromStartElement(tok)
			switch {
			case err != nil:
				return err
			case streamData.version != internal.DefaultVersion:
				return stream.UnsupportedVersion
				return streamData, err
			case streamData.version != DefaultVersion:
				return streamData, stream.UnsupportedVersion
			}

			if (s.state&Received) != Received && streamData.id == "" {
			if !recv && streamData.id == "" {
				// if we are the initiating entity and there is no stream ID…
				return stream.BadFormat
				return streamData, stream.BadFormat
			}
			s.in.streamInfo = streamData
			return nil
			return streamData, nil
		case xml.ProcInst:
			// TODO: If version or encoding are declared, validate XML 1.0 and UTF-8
			if !foundHeader && tok.Target == "xml" {
				foundHeader = true
				continue
			}
			return stream.RestrictedXML
			return streamData, stream.RestrictedXML
		case xml.EndElement:
			return stream.NotWellFormed
			return streamData, stream.NotWellFormed
		default:
			return stream.RestrictedXML
			return streamData, stream.RestrictedXML
		}
	}
}

R stream_test.go => internal/stream_test.go +12 -12
@@ 1,8 1,8 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package xmpp
package internal_test

import (
	"bytes"


@@ 11,11 11,12 @@ import (
	"strings"
	"testing"

	"mellium.im/xmpp/jid"
	"golang.org/x/text/language"

	"mellium.im/xmpp/internal"
)

func TestSendNewS2S(t *testing.T) {
	config := NewClientConfig(jid.MustParse("test@example.net"))
	for i, tc := range []struct {
		s2s    bool
		id     bool


@@ 34,14 35,13 @@ func TestSendNewS2S(t *testing.T) {
			if tc.id {
				ids = "abc"
			}
			config.S2S = tc.s2s
			err := sendNewStream(&Session{rw: &b}, config, ids)
			_, err := internal.SendNewStream(&b, tc.s2s, internal.Version{1, 0}, language.Und, "example.net", "test@example.net", ids)

			str := b.String()
			if !strings.HasPrefix(str, xmlHeader) {
			if !strings.HasPrefix(str, internal.XMLHeader) {
				t.Errorf("Expected string to start with XML header but got: %s", str)
			}
			str = strings.TrimPrefix(str, xmlHeader)
			str = strings.TrimPrefix(str, internal.XMLHeader)

			switch {
			case err != tc.err:


@@ 78,14 78,14 @@ func (nopReader) Read(p []byte) (n int, err error) {
}

func TestSendNewS2SReturnsWriteErr(t *testing.T) {
	config := NewClientConfig(jid.MustParse("test@example.net"))
	if err := sendNewStream(&Session{rw: struct {
	_, err := internal.SendNewStream(struct {
		io.Reader
		io.Writer
	}{
		nopReader{},
		errWriter{},
	}}, config, "abc"); err != io.ErrUnexpectedEOF {
	}, true, internal.Version{1, 0}, language.Und, "example.net", "test@example.net", "abc")
	if err != io.ErrUnexpectedEOF {
		t.Errorf("Expected errWriterErr (%s) but got `%s`", io.ErrUnexpectedEOF, err)
	}
}

M session.go => session.go +37 -2
@@ 12,6 12,7 @@ import (
	"sync"

	"mellium.im/xmlstream"
	"mellium.im/xmpp/internal"
	"mellium.im/xmpp/internal/ns"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/stanza"


@@ 76,14 77,14 @@ type Session struct {

	in struct {
		sync.Mutex
		streamInfo
		internal.StreamInfo
		d      xmlstream.TokenReader
		ctx    context.Context
		cancel context.CancelFunc
	}
	out struct {
		sync.Mutex
		streamInfo
		internal.StreamInfo
		e *xml.Encoder
	}
}


@@ 311,3 312,37 @@ func (s *Session) handleInputStream(handler Handler) error {
		}
	}
}

func negotiator(ctx context.Context, s *Session, doRestart interface{}) (mask SessionState, rw io.ReadWriter, restartNext interface{}, err error) {
	// Loop for as long as we're not done negotiating features or a stream restart
	// is still required.
	if rst, ok := doRestart.(bool); ok && rst {
		if (s.state & Received) == Received {
			// If we're the receiving entity wait for a new stream, then send one in
			// response.

			s.in.StreamInfo, err = internal.ExpectNewStream(ctx, s.in.d, s.State()&Received == Received)
			if err != nil {
				return mask, nil, false, err
			}
			s.out.StreamInfo, err = internal.SendNewStream(s.Conn(), s.config.S2S, s.config.Version, s.config.Lang, s.config.Location.String(), s.config.Origin.String(), internal.RandomID())
			if err != nil {
				return mask, nil, false, err
			}
		} else {
			// If we're the initiating entity, send a new stream and then wait for
			// one in response.
			s.out.StreamInfo, err = internal.SendNewStream(s.Conn(), s.config.S2S, s.config.Version, s.config.Lang, s.config.Location.String(), s.config.Origin.String(), "")
			if err != nil {
				return mask, nil, false, err
			}
			s.in.StreamInfo, err = internal.ExpectNewStream(ctx, s.in.d, s.State()&Received == Received)
			if err != nil {
				return mask, nil, false, err
			}
		}
	}

	mask, rw, err = negotiateFeatures(ctx, s)
	return mask, rw, rw != nil, err
}