~samwhited/xmpp

e8839b44be95cb2e23353c5e4662fce9d6b492b3 — Sam Whited 1 year, 11 months ago d554e50
internal/stream: new package for stream parsing

Also update all uses of old internal identifiers that were moved into
internal/stream.
7 files changed, 53 insertions(+), 47 deletions(-)

M internal/decl/decl.go
R internal/{stream.go => stream/stream.go}
R internal/{stream_test.go => stream/stream_test.go}
R internal/{version.go => stream/version.go}
R internal/{version_test.go => stream/version_test.go}
M negotiator.go
M session.go
M internal/decl/decl.go => internal/decl/decl.go +6 -0
@@ 9,6 9,12 @@ import (
	"encoding/xml"
)

const (
	// XMLHeader is an XML header like the one in encoding/xml but without a
	// newline at the end.
	XMLHeader = `<?xml version="1.0" encoding="UTF-8"?>`
)

type skipper struct {
	r       xml.TokenReader
	started bool

R internal/stream.go => internal/stream/stream.go +3 -8
@@ 2,7 2,8 @@
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package internal
// Package stream contains internal stream parsing and handling behavior.
package stream // import "mellium.im/xmpp/internal/stream"

import (
	"bufio"


@@ 17,12 18,6 @@ import (
	"mellium.im/xmpp/stream"
)

const (
	// XMLHeader is an XML header like the one in encoding/xml but without a
	// newline at the end.
	XMLHeader = `<?xml version="1.0" encoding="UTF-8"?>`
)

// StreamInfo contains metadata extracted from a stream start token.
type StreamInfo struct {
	to      *jid.JID


@@ 96,7 91,7 @@ func SendNewStream(rw io.ReadWriter, s2s bool, version Version, lang string, loc

	b := bufio.NewWriter(rw)
	_, err := fmt.Fprintf(b,
		XMLHeader+`<stream:stream%sto='%s' from='%s' version='%s' `,
		decl.XMLHeader+`<stream:stream%sto='%s' from='%s' version='%s' `,
		id,
		location,
		origin,

R internal/stream_test.go => internal/stream/stream_test.go +8 -7
@@ 2,7 2,7 @@
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package internal_test
package stream_test

import (
	"bytes"


@@ 11,7 11,8 @@ import (
	"strings"
	"testing"

	"mellium.im/xmpp/internal"
	"mellium.im/xmpp/internal/decl"
	"mellium.im/xmpp/internal/stream"
)

func TestSendNewS2S(t *testing.T) {


@@ 33,13 34,13 @@ func TestSendNewS2S(t *testing.T) {
			if tc.id {
				ids = "abc"
			}
			_, err := internal.SendNewStream(&b, tc.s2s, internal.Version{Major: 1, Minor: 0}, "und", "example.net", "test@example.net", ids)
			_, err := stream.SendNewStream(&b, tc.s2s, stream.Version{Major: 1, Minor: 0}, "und", "example.net", "test@example.net", ids)

			str := b.String()
			if !strings.HasPrefix(str, internal.XMLHeader) {
			if !strings.HasPrefix(str, decl.XMLHeader) {
				t.Errorf("Expected string to start with XML header but got: %s", str)
			}
			str = strings.TrimPrefix(str, internal.XMLHeader)
			str = strings.TrimPrefix(str, decl.XMLHeader)

			switch {
			case err != tc.err:


@@ 76,13 77,13 @@ func (nopReader) Read(p []byte) (n int, err error) {
}

func TestSendNewS2SReturnsWriteErr(t *testing.T) {
	_, err := internal.SendNewStream(struct {
	_, err := stream.SendNewStream(struct {
		io.Reader
		io.Writer
	}{
		nopReader{},
		errWriter{},
	}, true, internal.Version{Major: 1, Minor: 0}, "und", "example.net", "test@example.net", "abc")
	}, true, stream.Version{Major: 1, Minor: 0}, "und", "example.net", "test@example.net", "abc")
	if err != io.ErrUnexpectedEOF {
		t.Errorf("Expected errWriterErr (%s) but got `%s`", io.ErrUnexpectedEOF, err)
	}

R internal/version.go => internal/stream/version.go +1 -1
@@ 2,7 2,7 @@
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package internal
package stream

import (
	"encoding/xml"

R internal/version_test.go => internal/stream/version_test.go +27 -25
@@ 2,36 2,38 @@
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package internal
package stream_test

import (
	"encoding/xml"
	"fmt"
	"testing"

	"mellium.im/xmpp/internal/stream"
)

// Compile time interface checks.
var _ fmt.Stringer = &Version{}
var _ fmt.Stringer = Version{}
var _ xml.MarshalerAttr = &Version{}
var _ xml.MarshalerAttr = Version{}
var _ xml.UnmarshalerAttr = (*Version)(nil)
var _ fmt.Stringer = &stream.Version{}
var _ fmt.Stringer = stream.Version{}
var _ xml.MarshalerAttr = &stream.Version{}
var _ xml.MarshalerAttr = stream.Version{}
var _ xml.UnmarshalerAttr = (*stream.Version)(nil)

// Strings must parse correctly.
func TestParseVersion(t *testing.T) {
	for _, data := range []struct {
		vs        string
		v         Version
		v         stream.Version
		shouldErr bool
	}{
		{"1.0", Version{1, 0}, false},
		{"1.0.0", Version{}, true},
		{"A.1", Version{}, true},
		{"1.a", Version{}, true},
		{"1.0xA", Version{}, true},
		{"", Version{}, true},
		{"1.0", stream.Version{1, 0}, false},
		{"1.0.0", stream.Version{}, true},
		{"A.1", stream.Version{}, true},
		{"1.a", stream.Version{}, true},
		{"1.0xA", stream.Version{}, true},
		{"", stream.Version{}, true},
	} {
		v, err := ParseVersion(data.vs)
		v, err := stream.ParseVersion(data.vs)
		switch {
		case data.shouldErr && err == nil:
			t.Logf("Version '%s' should fail with an error when parsed.", data.vs)


@@ 57,21 59,21 @@ func TestMustParseVersionPanics(t *testing.T) {
			t.Error("Expected MustParseVersion to panic when given invalid version")
		}
	}()
	MustParseVersion("a.0")
	stream.MustParseVersion("a.0")
}

func TestCompareVersion(t *testing.T) {
	for _, data := range []struct {
		v1, v2 Version
		v1, v2 stream.Version
		less   bool
	}{
		{Version{}, Version{}, false},
		{MustParseVersion("1.0"), MustParseVersion("1.1"), true},
		{MustParseVersion("1.1"), MustParseVersion("1.0"), false},
		{MustParseVersion("1.0"), MustParseVersion("2.0"), true},
		{MustParseVersion("2.0"), MustParseVersion("1.0"), false},
		{MustParseVersion("1.5"), MustParseVersion("2.0"), true},
		{MustParseVersion("2.0"), MustParseVersion("1.5"), false},
		{stream.Version{}, stream.Version{}, false},
		{stream.MustParseVersion("1.0"), stream.MustParseVersion("1.1"), true},
		{stream.MustParseVersion("1.1"), stream.MustParseVersion("1.0"), false},
		{stream.MustParseVersion("1.0"), stream.MustParseVersion("2.0"), true},
		{stream.MustParseVersion("2.0"), stream.MustParseVersion("1.0"), false},
		{stream.MustParseVersion("1.5"), stream.MustParseVersion("2.0"), true},
		{stream.MustParseVersion("2.0"), stream.MustParseVersion("1.5"), false},
	} {
		if data.v1.Less(data.v2) != data.less {
			if data.less {


@@ 88,7 90,7 @@ func TestMarshalVersion(t *testing.T) {
	for _, data := range []string{
		"1.0", "0.1", "10.0", "0.10",
	} {
		switch s2, _ := MustParseVersion(data).MarshalXMLAttr(n); {
		switch s2, _ := stream.MustParseVersion(data).MarshalXMLAttr(n); {
		case s2.Value != data:
			t.Errorf("Expected %s to parse and stringify to itself but got %s", data, s2)
		case s2.Name != n:


@@ 107,7 109,7 @@ func TestUnmarshalVersion(t *testing.T) {
		{xml.Attr{Value: "2.0"}, "2.0", false},
		{xml.Attr{Name: xml.Name{Space: "", Local: "Whatever"}, Value: "0.9"}, "0.9", false},
	} {
		v2 := Version{}
		v2 := stream.Version{}
		err := v2.UnmarshalXMLAttr(data.attr)
		switch {
		case data.err && err == nil:

M negotiator.go => negotiator.go +5 -4
@@ 10,6 10,7 @@ import (
	"net"

	"mellium.im/xmpp/internal"
	"mellium.im/xmpp/internal/stream"
)

// teeConn is a net.Conn that also copies reads and writes to the provided


@@ 135,12 136,12 @@ func negotiator(cfg StreamConfig) Negotiator {
				// If we're the receiving entity wait for a new stream, then send one in
				// response.

				s.in.StreamInfo, err = internal.ExpectNewStream(ctx, s.in.d, s.State()&Received == Received)
				s.in.StreamInfo, err = stream.ExpectNewStream(ctx, s.in.d, s.State()&Received == Received)
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err
				}
				s.out.StreamInfo, err = internal.SendNewStream(s.Conn(), cfg.S2S, internal.DefaultVersion, cfg.Lang, s.location.String(), s.origin.String(), internal.RandomID())
				s.out.StreamInfo, err = stream.SendNewStream(s.Conn(), cfg.S2S, stream.DefaultVersion, cfg.Lang, s.location.String(), s.origin.String(), internal.RandomID())
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err


@@ 148,12 149,12 @@ func negotiator(cfg StreamConfig) Negotiator {
			} else {
				// If we're the initiating entity, send a new stream and then wait for
				// one in response.
				s.out.StreamInfo, err = internal.SendNewStream(s.Conn(), cfg.S2S, internal.DefaultVersion, cfg.Lang, s.location.String(), s.origin.String(), "")
				s.out.StreamInfo, err = stream.SendNewStream(s.Conn(), cfg.S2S, stream.DefaultVersion, cfg.Lang, s.location.String(), s.origin.String(), "")
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err
				}
				s.in.StreamInfo, err = internal.ExpectNewStream(ctx, s.in.d, s.State()&Received == Received)
				s.in.StreamInfo, err = stream.ExpectNewStream(ctx, s.in.d, s.State()&Received == Received)
				if err != nil {
					nState.doRestart = false
					return mask, nil, nState, err

M session.go => session.go +3 -2
@@ 21,6 21,7 @@ import (
	"mellium.im/xmpp/internal"
	"mellium.im/xmpp/internal/marshal"
	"mellium.im/xmpp/internal/ns"
	intstream "mellium.im/xmpp/internal/stream"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/stanza"
	"mellium.im/xmpp/stream"


@@ 88,14 89,14 @@ type Session struct {
	sentIQs     map[string]chan xmlstream.TokenReadCloser

	in struct {
		internal.StreamInfo
		intstream.StreamInfo
		d      xml.TokenReader
		ctx    context.Context
		cancel context.CancelFunc
		sync.Locker
	}
	out struct {
		internal.StreamInfo
		intstream.StreamInfo
		e tokenWriteFlusher
		sync.Locker
	}