~samwhited/xmpp

b864a86168a3544477a680516d837f92360d4022 — Sam Whited 5 years ago 9ca315a
Return encoding errors from sasl errors
2 files changed, 35 insertions(+), 20 deletions(-)

M internal/saslerr/errors.go
M internal/saslerr/errors_test.go
M internal/saslerr/errors.go => internal/saslerr/errors.go +20 -13
@@ 6,10 6,6 @@
// defined by RFC 6120 §6.5.
package saslerr // import "mellium.im/xmpp/internal/saslerr"

// TODO(ssw): I think these errors should really be created via code generation
//            in case more are added in the future and so that we can store them
//            in a more efficient way that doesn't require a giant switch.

import (
	"encoding/xml"



@@ 51,16 47,22 @@ func (f Failure) Error() string {
}

// MarshalXML satisfies the xml.Marshaler interface for a Failure.
func (f Failure) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
func (f Failure) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) {
	failure := xml.StartElement{
		Name: xml.Name{Space: `urn:ietf:params:xml:ns:xmpp-sasl`, Local: "failure"},
	}
	e.EncodeToken(failure)
	if err = e.EncodeToken(failure); err != nil {
		return
	}
	condition := xml.StartElement{
		Name: xml.Name{Space: "", Local: string(f.Condition)},
	}
	e.EncodeToken(condition)
	e.EncodeToken(condition.End())
	if err = e.EncodeToken(condition); err != nil {
		return
	}
	if err = e.EncodeToken(condition.End()); err != nil {
		return
	}
	if f.Text != "" {
		text := xml.StartElement{
			Name: xml.Name{Space: "", Local: "text"},


@@ 71,12 73,17 @@ func (f Failure) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
				},
			},
		}
		e.EncodeToken(text)
		e.EncodeToken(xml.CharData(f.Text))
		e.EncodeToken(text.End())
		if err = e.EncodeToken(text); err != nil {
			return
		}
		if err = e.EncodeToken(xml.CharData(f.Text)); err != nil {
			return
		}
		if err = e.EncodeToken(text.End()); err != nil {
			return
		}
	}
	e.EncodeToken(failure.End())
	return nil
	return e.EncodeToken(failure.End())
}

// UnmarshalXML satisfies the xml.Unmarshaler interface for a Failure. If

M internal/saslerr/errors_test.go => internal/saslerr/errors_test.go +15 -7
@@ 38,9 38,10 @@ func TestErrorTextOrCondition(t *testing.T) {
}

func TestMarshalCondition(t *testing.T) {
	for _, test := range []struct {
	for i, test := range []struct {
		Failure   Failure
		Marshaled string
		err       bool
	}{
		{
			Failure{


@@ 49,15 50,22 @@ func TestMarshalCondition(t *testing.T) {
				Lang:      language.BrazilianPortuguese,
			},
			`<failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl"><mechanism-too-weak></mechanism-too-weak><text xml:lang="pt-BR">Test</text></failure>`,
			false,
		},
		{Failure{Condition: IncorrectEncoding}, `<failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl"><incorrect-encoding></incorrect-encoding></failure>`},
		{Failure{Condition: Aborted, Lang: language.Polish}, `<failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl"><aborted></aborted></failure>`},
		{Failure{Condition: IncorrectEncoding}, `<failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl"><incorrect-encoding></incorrect-encoding></failure>`, false},
		{Failure{Condition: Aborted, Lang: language.Polish}, `<failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl"><aborted></aborted></failure>`, false},
	} {
		b, err := xml.Marshal(test.Failure)
		if err != nil {
			t.Fatal(err)
		}
		if string(b) != test.Marshaled {
		switch {
		case test.err && err == nil:
			t.Errorf("Expected error when marshaling condition %d", i)
			continue
		case !test.err && err != nil:
			t.Error(err)
			continue
		case err != nil:
			continue
		case string(b) != test.Marshaled:
			t.Errorf("Expected %s but got %s", test.Marshaled, b)
		}
	}