~samwhited/xmpp

978b4142ff34aae1cb821cb527b9067e2571fa7a — Sam Whited 5 months ago d4187a5 126_verify_valid_tokens
internal/stream: limit XML token types

Make sure that we don't allow comments, proc insts, or directives
anywhere in the stream.

Signed-off-by: Sam Whited <sam@samwhited.com>
4 files changed, 83 insertions(+), 28 deletions(-)

M internal/stream/reader.go
M internal/stream/reader_test.go
M sasl_test.go
M session_test.go
M internal/stream/reader.go => internal/stream/reader.go +7 -1
@@ 7,6 7,7 @@ package stream
import (
	"encoding/xml"
	"errors"
	"fmt"
	"io"

	"mellium.im/xmpp/stream"


@@ 63,8 64,13 @@ func (r reader) Token() (xml.Token, error) {
		// If this is a stream level end element but not </stream:stream>,
		// something is really weird…
		return nil, stream.BadFormat
	case xml.CharData:
		// Pass chardata through. We ensure that any chardata at the top level of
		// the stream is only whitespace elsewhere.
		return tok, err
	}
	return tok, err
	// Other XML tokens are forbidden.
	return tok, fmt.Errorf("invalid token type: %T", tok)
}

// Reader returns a token reader that handles stream level tokens on an already

M internal/stream/reader_test.go => internal/stream/reader_test.go +41 -19
@@ 5,7 5,6 @@
package stream_test

import (
	"bytes"
	"encoding/xml"
	"errors"
	"reflect"


@@ 20,10 19,11 @@ import (
)

var readerTestCases = [...]struct {
	in      string
	skip    int
	err     error
	errType error
	in        string
	skip      int
	err       error
	errType   error
	errStrCmp bool
}{
	0: {},
	1: {


@@ 83,6 83,32 @@ var readerTestCases = [...]struct {
		in:   `<stream:stream xmlns:stream='http://etherx.jabber.org/streams'></stream:stream>`,
		skip: 1,
	},
	12: {
		in:        `<message><!-- Test --></message></stream:stream>`,
		err:       errors.New("invalid token type: xml.Comment"),
		errStrCmp: true,
	},
	13: {
		in:        `<iq><!dir></iq></stream:stream>`,
		err:       errors.New("invalid token type: xml.Directive"),
		errStrCmp: true,
	},
	14: {
		in:        `<iq><!dir></iq>`,
		err:       errors.New("invalid token type: xml.Directive"),
		errStrCmp: true,
	},
	15: {
		in:        `<iq><?xml?></iq>`,
		err:       errors.New("invalid token type: xml.ProcInst"),
		errStrCmp: true,
	},
	16: {
		// Chardata is checked elsewhere because it is valid as long as it's not at
		// the top level, or is at the top level but is only whitespace and this
		// reader doesn't track the nesting level.
		in: `aaa<iq></iq>`,
	},
}

func TestReader(t *testing.T) {


@@ 106,6 132,16 @@ func TestReader(t *testing.T) {
			}
			_, err := xmlstream.Copy(e, stream.Reader(d))
			switch {
			case tc.errStrCmp:
				if err == nil {
					err = errors.New("nil")
				}
				if tc.err == nil {
					tc.err = errors.New("nil")
				}
				if tc.err.Error() != err.Error() {
					t.Errorf("unexpected error string: want=%q, got=%q", tc.err, err)
				}
			case tc.errType != nil:
				if reflect.TypeOf(tc.errType) != reflect.TypeOf(err) {
					t.Errorf("unexpected error type: want=%T, got=%T", tc.err, err)


@@ 138,17 174,3 @@ func TestBadFormat(t *testing.T) {
	}
	t.Logf("output: %q", out.String())
}

func TestDisallowedTokenType(t *testing.T) {
	comment := xml.Comment("foo")
	toks := &xmpptest.Tokens{comment}
	r := stream.Reader(toks)
	tok, err := r.Token()
	if err != nil {
		t.Errorf("unexpected error: %v", err)
	}

	if c, ok := tok.(xml.Comment); !ok || !bytes.Equal(c, comment) {
		t.Errorf("expected unknown token type to be passed through: want=%#v, got=%#v", comment, tok)
	}
}

M sasl_test.go => sasl_test.go +11 -8
@@ 8,6 8,7 @@ import (
	"bytes"
	"context"
	"encoding/xml"
	"errors"
	"strings"
	"testing"



@@ 147,10 148,11 @@ var saslTestCases = [...]xmpptest.FeatureTestCase{
		Err:     xmpp.ErrUnexpectedPayload,
	},
	1: {
		Feature: xmpp.SASL("", "", sasl.Plain),
		In:      `<!-- not a start element -->`,
		Out:     `<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="PLAIN">AHRlc3QA</auth>`,
		Err:     xmpp.ErrUnexpectedPayload,
		Feature:   xmpp.SASL("", "", sasl.Plain),
		In:        `<!-- not a start element -->`,
		Out:       `<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="PLAIN">AHRlc3QA</auth>`,
		Err:       errors.New("invalid token type: xml.Comment"),
		ErrStrCmp: true,
	},
	2: {
		Feature: xmpp.SASL("", "", sasl.Plain),


@@ 175,10 177,11 @@ var saslTestCases = [...]xmpptest.FeatureTestCase{
		Err:     xmpp.ErrUnexpectedPayload,
	},
	5: {
		State:   xmpp.Received,
		Feature: xmpp.SASLServer(panicPerms, sasl.Plain),
		In:      `<!-- not a start element -->`,
		Err:     xmpp.ErrUnexpectedPayload,
		State:     xmpp.Received,
		Feature:   xmpp.SASLServer(panicPerms, sasl.Plain),
		In:        `<!-- not a start element -->`,
		Err:       errors.New("invalid token type: xml.Comment"),
		ErrStrCmp: true,
	},
	6: {
		// TODO: can the client send failure?

M session_test.go => session_test.go +24 -0
@@ 359,6 359,30 @@ var serveTests = [...]struct {
		out:   `<iq xmlns="jabber:server" type="result" from="from@example.net" id="1234"></iq></stream:stream>`,
		state: xmpp.S2S,
	},
	17: {
		in:           `<?xml version="1.0" encoding="UTF-8"?>`,
		out:          `</stream:stream>`,
		err:          errors.New("invalid token type: xml.ProcInst"),
		errStringCmp: true,
	},
	18: {
		in:           `<!-- Test -->`,
		out:          `</stream:stream>`,
		err:          errors.New("invalid token type: xml.Comment"),
		errStringCmp: true,
	},
	19: {
		in:           `<!dir>`,
		out:          `</stream:stream>`,
		err:          errors.New("invalid token type: xml.Directive"),
		errStringCmp: true,
	},
	20: {
		in:           `<iq xmlns="jabber:client"><!-- Test --></iq>`,
		out:          `</stream:stream>`,
		err:          errors.New("invalid token type: xml.Comment"),
		errStringCmp: true,
	},
}

func TestServe(t *testing.T) {