~samwhited/xmpp

9f862df11007620fe671db6cac0a7b2be068e31d — Sam Whited 5 years ago 4b4f51d
Change stream features list API to use xml encoder

Fixes #11
7 files changed, 56 insertions(+), 21 deletions(-)

M bind.go
M bind_test.go
M features.go
M sasl.go
M sasl_test.go
M starttls.go
M starttls_test.go
M bind.go => bind.go +11 -3
@@ 27,9 27,17 @@ func BindResource() StreamFeature {
		Name:       xml.Name{Space: ns.Bind, Local: "bind"},
		Necessary:  Authn,
		Prohibited: Ready,
		List: func(ctx context.Context, w io.Writer) (bool, error) {
			_, err := fmt.Fprintf(w, `<bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/>`)
			return true, err
		List: func(ctx context.Context, e *xml.Encoder, start xml.StartElement) (req bool, err error) {
			req = true
			if err = e.EncodeToken(start); err != nil {
				return req, err
			}
			if err = e.EncodeToken(start.End()); err != nil {
				return req, err
			}

			err = e.Flush()
			return req, err
		},
		Parse: func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) {
			parsed := struct {

M bind_test.go => bind_test.go +6 -2
@@ 10,19 10,23 @@ import (
	"encoding/xml"
	"strings"
	"testing"

	"mellium.im/xmpp/ns"
)

func TestBindList(t *testing.T) {
	buf := &bytes.Buffer{}
	bind := BindResource()
	req, err := bind.List(context.Background(), buf)
	e := xml.NewEncoder(buf)
	start := xml.StartElement{Name: xml.Name{Space: ns.Bind, Local: "bind"}}
	req, err := bind.List(context.Background(), e, start)
	if err != nil {
		t.Fatal(err)
	}
	if !req {
		t.Error("Bind must always be required")
	}
	if buf.String() != `<bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/>` {
	if buf.String() != `<bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"></bind>` {
		t.Errorf("Got unexpected value for bind listing: `%s`", buf.String())
	}
}

M features.go => features.go +9 -3
@@ 35,8 35,12 @@ type StreamFeature struct {
	// set this to "Authn".
	Prohibited SessionState

	// Used to send the feature in a features list for server connections.
	List func(ctx context.Context, conn io.Writer) (req bool, err error)
	// Used to send the feature in a features list for server connections. The
	// start element will have a name that matches the features name and should be
	// used as the outermost tag in the stream (but also may be ignored). List
	// implementations that call e.EncodeToken directly need to call e.Flush when
	// finished to ensure that the XML is written to the underlying writer.
	List func(ctx context.Context, e *xml.Encoder, start xml.StartElement) (req bool, err error)

	// Used to parse the feature that begins with the given xml start element
	// (which should have a Name that matches this stream feature's Name).


@@ 74,7 78,9 @@ func writeStreamFeatures(ctx context.Context, conn *Conn) (n int, req int, err e
		// are set.
		if (conn.state&feature.Necessary) == feature.Necessary && (conn.state&feature.Prohibited) == 0 {
			var r bool
			r, err = feature.List(ctx, conn)
			r, err = feature.List(ctx, conn.out.e, xml.StartElement{
				Name: feature.Name,
			})
			if err != nil {
				return
			}

M sasl.go => sasl.go +10 -7
@@ 33,13 33,13 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
		Name:       xml.Name{Space: ns.SASL, Local: "mechanisms"},
		Necessary:  Secure,
		Prohibited: Authn,
		List: func(ctx context.Context, conn io.Writer) (req bool, err error) {
		List: func(ctx context.Context, e *xml.Encoder, start xml.StartElement) (req bool, err error) {
			req = true
			_, err = fmt.Fprint(conn, `<mechanisms xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>`)
			if err != nil {
			if err = e.EncodeToken(start); err != nil {
				return
			}

			startMechanism := xml.StartElement{Name: xml.Name{Space: "", Local: "mechanism"}}
			for _, m := range mechanisms {
				select {
				case <-ctx.Done():


@@ 47,17 47,20 @@ func SASL(mechanisms ...sasl.Mechanism) StreamFeature {
				default:
				}

				if _, err = fmt.Fprint(conn, `<mechanism>`); err != nil {
				if err = e.EncodeToken(startMechanism); err != nil {
					return
				}
				if err = xml.EscapeText(conn, []byte(m.Name)); err != nil {
				if err = e.EncodeToken(xml.CharData(m.Name)); err != nil {
					return
				}
				if _, err = fmt.Fprint(conn, `</mechanism>`); err != nil {
				if err = e.EncodeToken(startMechanism.End()); err != nil {
					return
				}
			}
			_, err = fmt.Fprint(conn, `</mechanisms>`)
			if err = e.EncodeToken(start.End()); err != nil {
				return
			}
			err = e.Flush()
			return
		},
		Parse: func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) {

M sasl_test.go => sasl_test.go +3 -1
@@ 26,8 26,10 @@ func TestSASLPanicsNoMechanisms(t *testing.T) {

func TestSASLList(t *testing.T) {
	var b bytes.Buffer
	e := xml.NewEncoder(&b)
	start := xml.StartElement{Name: xml.Name{Space: ns.SASL, Local: "mechanisms"}}
	s := SASL(sasl.Plain, sasl.ScramSha256)
	req, err := s.List(context.Background(), &b)
	req, err := s.List(context.Background(), e, start)
	switch {
	case err != nil:
		t.Fatal(err)

M starttls.go => starttls.go +14 -4
@@ 31,13 31,23 @@ func StartTLS(required bool) StreamFeature {
	return StreamFeature{
		Name:       xml.Name{Local: "starttls", Space: ns.StartTLS},
		Prohibited: Secure,
		List: func(ctx context.Context, conn io.Writer) (req bool, err error) {
		List: func(ctx context.Context, e *xml.Encoder, start xml.StartElement) (req bool, err error) {
			if err = e.EncodeToken(start); err != nil {
				return required, err
			}
			if required {
				_, err = fmt.Fprint(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'><required/></starttls>`)
				startRequired := xml.StartElement{Name: xml.Name{Space: "", Local: "required"}}
				if err = e.EncodeToken(startRequired); err != nil {
					return required, err
				}
				if err = e.EncodeToken(startRequired.End()); err != nil {
					return required, err
				}
			}
			if err = e.EncodeToken(start.End()); err != nil {
				return required, err
			}
			_, err = fmt.Fprint(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)
			return
			return required, e.Flush()
		},
		Parse: func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) {
			parsed := struct {

M starttls_test.go => starttls_test.go +3 -1
@@ 24,7 24,9 @@ func TestStartTLSList(t *testing.T) {
	for _, req := range []bool{true, false} {
		stls := StartTLS(req)
		var b bytes.Buffer
		r, err := stls.List(context.Background(), &b)
		e := xml.NewEncoder(&b)
		start := xml.StartElement{Name: xml.Name{Space: ns.StartTLS, Local: "starttls"}}
		r, err := stls.List(context.Background(), e, start)
		switch {
		case err != nil:
			t.Fatal(err)