~samwhited/xmpp

29342212a2a277c13054669f7a14f9955f733de1 — Sam Whited 4 years ago 521af43
XEP-0106: Create JID Escaping text transformer

Fixes #24
5 files changed, 363 insertions(+), 127 deletions(-)

M jid/benchmark_test.go
A jid/escape.go
A jid/escape_test.go
M jid/jid.go
M jid/jid_test.go
M jid/benchmark_test.go => jid/benchmark_test.go +8 -2
@@ 60,13 60,19 @@ func BenchmarkString(b *testing.B) {
}

// BenchmarkEscape runs Escape.Transform over an input consisting solely of
// escapable characters, reusing preallocated source and destination buffers
// so that only the transform itself is timed.
func BenchmarkEscape(b *testing.B) {
	src := []byte(escape)
	// Every escapable byte expands to a three byte sequence, so the
	// destination must be three times the size of the source to hold the
	// complete result.
	dst := make([]byte, 3*len(src))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		Escape.Transform(dst, src, true)
	}
}

// BenchmarkUnescape runs Unescape.Transform over an input consisting solely
// of escape sequences, reusing preallocated source and destination buffers so
// that only the transform itself is timed.
func BenchmarkUnescape(b *testing.B) {
	src := []byte(allescaped)
	// Every three byte escape sequence collapses to a single byte.
	dst := make([]byte, len(src)/3)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		Unescape.Transform(dst, src, true)
	}
}

A jid/escape.go => jid/escape.go +243 -0
@@ 0,0 1,243 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package jid

import (
	"bytes"

	"golang.org/x/text/transform"
)

// Transformer implements the transform.Transformer and
// transform.SpanningTransformer interfaces by delegating to a wrapped
// transformer, and adds the Bytes and String convenience methods.
type Transformer struct {
	// t is the underlying transformer that every method forwards to.
	t transform.SpanningTransformer
}

// Reset implements the transform.Transformer interface by resetting the
// wrapped transformer.
func (t Transformer) Reset() { t.t.Reset() }

// Transform implements the transform.Transformer interface. The arguments are
// passed to the wrapped transformer unchanged and its results are returned as
// is.
func (t Transformer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
	return t.t.Transform(dst, src, atEOF)
}

// Span implements the transform.SpanningTransformer interface. The arguments
// are passed to the wrapped transformer unchanged and its results are returned
// as is.
func (t Transformer) Span(src []byte, atEOF bool) (n int, err error) {
	return t.t.Span(src, atEOF)
}

// Bytes applies t to b and returns the output in a new byte slice. The byte
// count and error reported by transform.Bytes are discarded.
func (t Transformer) Bytes(b []byte) []byte {
	out, _, _ := transform.Bytes(t, b)
	return out
}

// String applies t to s and returns the output as a string. The byte count
// and error reported by transform.String are discarded.
func (t Transformer) String(s string) string {
	out, _, _ := transform.String(t, s)
	return out
}

var (
	// Escape is a transform that maps escapable runes to their escaped form as
	// defined in XEP-0106: JID Escaping.
	Escape Transformer = Transformer{escapeMapping{}}

	// Unescape is a transform that maps escape sequences to their unescaped
	// form as defined in XEP-0106: JID Escaping.
	Unescape Transformer = Transformer{unescapeMapping{}}
)

// escape lists every character that the Escape transform replaces with an
// escape sequence.
const escape = ` "&'/:<>@\`

// escapeMapping is the transformer behind Escape. It keeps no state between
// calls, so its Reset method is the no-op provided by transform.NopResetter.
type escapeMapping struct {
	transform.NopResetter
}

// Span reports how many leading bytes of src need no escaping. If src contains
// an escapable character, its index is returned together with
// transform.ErrEndOfSpan; otherwise the whole of src is spanned.
func (escapeMapping) Span(src []byte, atEOF bool) (n int, err error) {
	if first := bytes.IndexAny(src, escape); first >= 0 {
		return first, transform.ErrEndOfSpan
	}
	return len(src), nil
}

// Transform copies src to dst, replacing every character listed in escape
// with a backslash followed by the two lower case hex digits of its byte
// value.
//
// It returns the number of bytes written to dst and consumed from src. If dst
// cannot hold the next piece of output, transform.ErrShortDst is returned; an
// escape sequence is only ever written in full, and the source byte it
// replaces is only consumed once that has happened, so the caller can retry
// from nSrc with a bigger buffer. No input is ever held back, so atEOF is not
// consulted.
func (t escapeMapping) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
	const hexDigits = "0123456789abcdef"

	for nSrc < len(src) {
		// Length of the run of bytes that can be copied verbatim.
		run := bytes.IndexAny(src[nSrc:], escape)
		if run == -1 {
			run = len(src) - nSrc
		}

		n := copy(dst[nDst:], src[nSrc:nSrc+run])
		nDst += n
		nSrc += n
		if n != run {
			return nDst, nSrc, transform.ErrShortDst
		}
		if nSrc == len(src) {
			break
		}

		// src[nSrc] must be escaped; make sure the whole three byte sequence
		// fits before writing any of it.
		if len(dst)-nDst < 3 {
			return nDst, nSrc, transform.ErrShortDst
		}
		c := src[nSrc]
		dst[nDst] = '\\'
		dst[nDst+1] = hexDigits[c>>4]
		dst[nDst+2] = hexDigits[c&15]
		nDst += 3
		nSrc++
	}
	return nDst, nSrc, nil
}

// unescapeMapping is the transformer behind Unescape. It keeps no state
// between calls, so its Reset method is the no-op provided by
// transform.NopResetter.
type unescapeMapping struct {
	transform.NopResetter
}

// TODO: Be more specific. Only check if it's the starting character in any
//       valid escape sequence.

// ishex reports whether c is an ASCII hexadecimal digit of either case.
func ishex(c byte) bool {
	return ('0' <= c && c <= '9') ||
		('a' <= c && c <= 'f') ||
		('A' <= c && c <= 'F')
}

// shouldUnescape reports whether the two hex digits at the start of s form
// one of the recognized escape codes: 20, 22, 26, 27, 2f, 3a, 3c, 3e, 40 or
// 5c, with the letters accepted in either case. The codes are enumerated by
// hand rather than decoded and looked up.
func shouldUnescape(s []byte) bool {
	switch s[0] {
	case '2':
		switch s[1] {
		case '0', '2', '6', '7', 'f', 'F':
			return true
		}
	case '3':
		switch s[1] {
		case 'a', 'A', 'c', 'C', 'e', 'E':
			return true
		}
	case '4':
		return s[1] == '0'
	case '5':
		return s[1] == 'c' || s[1] == 'C'
	}
	return false
}

// unhex returns the numeric value of the ASCII hexadecimal digit c. Bytes
// that are not hex digits yield 0.
func unhex(c byte) byte {
	if '0' <= c && c <= '9' {
		return c - '0'
	}
	if 'a' <= c && c <= 'f' {
		return c - 'a' + 10
	}
	if 'A' <= c && c <= 'F' {
		return c - 'A' + 10
	}
	return 0
}

// Span reports how many leading bytes of src would be left untouched by
// Transform.
//
// It stops with transform.ErrEndOfSpan at the first backslash that starts a
// recognized escape sequence. If src ends in a backslash, or a backslash and
// one hex digit, and atEOF is false, the sequence may be completed by later
// input, so the index of that backslash is returned with
// transform.ErrShortSrc. Backslashes that cannot start a recognized sequence
// are ordinary bytes and are included in the span.
func (unescapeMapping) Span(src []byte, atEOF bool) (n int, err error) {
	for n < len(src) {
		if src[n] != '\\' {
			n++
			continue
		}

		rest := src[n:]
		switch {
		case len(rest) >= 3:
			if shouldUnescape(rest[1:3]) {
				return n, transform.ErrEndOfSpan
			}
		case !atEOF && (len(rest) == 1 || ishex(rest[1])):
			// Not enough input yet to decide whether this is an escape
			// sequence.
			return n, transform.ErrShortSrc
		}

		// A literal backslash. Advance by one byte only, so that a backslash
		// directly after this one is examined in its own right.
		n++
	}
	return n, nil
}

// Transform copies src to dst, replacing every recognized escape sequence (a
// backslash followed by one of the two digit hex codes accepted by
// shouldUnescape) with the single byte it encodes. Backslashes that do not
// start a recognized sequence are copied through unchanged.
//
// It returns the number of bytes written to dst and consumed from src. If dst
// fills up, transform.ErrShortDst is returned and no input is consumed beyond
// what has been written. If src ends in a backslash, or a backslash and one
// hex digit, and atEOF is false, everything before that backslash is written
// and transform.ErrShortSrc is returned so the caller can supply the rest of
// the possible sequence. At EOF such a trailing fragment is copied verbatim.
func (t unescapeMapping) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
	for nSrc < len(src) {
		// Length of the run of bytes in front of the next backslash.
		run := bytes.IndexByte(src[nSrc:], '\\')
		if run == -1 {
			run = len(src) - nSrc
		}

		n := copy(dst[nDst:], src[nSrc:nSrc+run])
		nDst += n
		nSrc += n
		if n != run {
			return nDst, nSrc, transform.ErrShortDst
		}
		if nSrc == len(src) {
			break
		}

		// src[nSrc] is a backslash. Work out what it stands for and how much
		// input that accounts for before touching dst, so that nothing is
		// consumed unless its output has been written.
		rest := src[nSrc:]
		out, consumed := byte('\\'), 1
		switch {
		case len(rest) >= 3:
			if shouldUnescape(rest[1:3]) {
				out, consumed = unhex(rest[1])<<4|unhex(rest[2]), 3
			}
		case !atEOF && (len(rest) == 1 || ishex(rest[1])):
			// Not enough input yet to decide whether this is an escape
			// sequence.
			return nDst, nSrc, transform.ErrShortSrc
		}

		if nDst == len(dst) {
			return nDst, nSrc, transform.ErrShortDst
		}
		dst[nDst] = out
		nDst++
		nSrc += consumed
	}
	return nDst, nSrc, nil
}

A jid/escape_test.go => jid/escape_test.go +112 -0
@@ 0,0 1,112 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package jid

import (
	"fmt"
	"testing"

	"golang.org/x/text/transform"
)

var _ transform.SpanningTransformer = (*escapeMapping)(nil)

// allescaped is the escaped form of every character in escape, in the same
// order.
const allescaped = `\20\22\26\27\2f\3a\3c\3e\40\5c`

// escapeTestCases is the table used by TestEscape. For each entry, unescaped
// is the input, escaped and err are the expected result of transforming it,
// and span and spanErr are the expected result of calling Span on it with the
// given atEOF.
var escapeTestCases = [...]struct {
	unescaped, escaped string
	atEOF              bool
	span               int
	err, spanErr       error
}{
	0: {escape, allescaped, true, 0, nil, transform.ErrEndOfSpan},
	1: {escape, allescaped, false, 0, nil, transform.ErrEndOfSpan},
	2: {`nothingtodohere`, `nothingtodohere`, true, 15, nil, nil},
	3: {`nothingtodohere`, `nothingtodohere`, false, 15, nil, nil},
	4: {"", "", true, 0, nil, nil},
	5: {"", "", false, 0, nil, nil},
	6: {`a `, `a\20`, true, 1, nil, transform.ErrEndOfSpan},
}

// unescapeTestCases is the table used by TestUnescape. For each entry, escaped
// is the input, unescaped and err are the expected output and error of
// Transform, and span and spanErr are the expected result of Span; both calls
// are made with the given atEOF.
var unescapeTestCases = [...]struct {
	escaped, unescaped string
	atEOF              bool
	span               int
	err, spanErr       error
}{
	0: {allescaped, escape, true, 0, nil, transform.ErrEndOfSpan},
	1: {`a\20`, `a `, true, 1, nil, transform.ErrEndOfSpan},
	2: {`a\`, `a\`, true, 2, nil, nil},
	3: {`a\`, `a`, false, 1, transform.ErrShortSrc, transform.ErrShortSrc},
	4: {`nothingtodohere`, `nothingtodohere`, true, 15, nil, nil},
	5: {`nothingtodohere`, `nothingtodohere`, false, 15, nil, nil},
	6: {`a\a\20`, `a\a `, false, 3, nil, transform.ErrEndOfSpan},
	7: {`aa\2`, `aa\2`, true, 4, nil, nil},
	8: {`aa\2`, `aa`, false, 2, transform.ErrShortSrc, transform.ErrShortSrc},
}

// TestUnescape runs every entry of unescapeTestCases through
// Unescape.Transform and Unescape.Span, each as its own subtest. The output is
// only compared when the error matches the expectation.
func TestUnescape(t *testing.T) {
	for i, tc := range unescapeTestCases {
		t.Run(fmt.Sprintf("Transform/%d", i), func(t *testing.T) {
			out := make([]byte, 100)
			nDst, _, err := Unescape.Transform(out, []byte(tc.escaped), tc.atEOF)
			if err != tc.err {
				t.Errorf("Unexpected error, got=%v, want=%v", err, tc.err)
				return
			}
			if got := string(out[:nDst]); got != tc.unescaped {
				t.Errorf("Unescaped localpart should be `%s` but got: `%s`", tc.unescaped, got)
			}
		})
		t.Run(fmt.Sprintf("Span/%d", i), func(t *testing.T) {
			n, err := Unescape.Span([]byte(tc.escaped), tc.atEOF)
			if err != tc.spanErr {
				t.Errorf("Unexpected error, got=%v, want=%v", err, tc.spanErr)
				return
			}
			if n != tc.span {
				t.Errorf("Unexpected span, got=%d, want=%d", n, tc.span)
			}
		})
	}
}

// TestEscape runs every entry of escapeTestCases through transform.String
// with the Escape transformer and through Escape.Span, each as its own
// subtest. The output is only compared when the error matches the expectation.
func TestEscape(t *testing.T) {
	for i, tc := range escapeTestCases {
		t.Run(fmt.Sprintf("Transform/%d", i), func(t *testing.T) {
			got, _, err := transform.String(Escape, tc.unescaped)
			if err != tc.err {
				t.Errorf("Unexpected error, got=%v, want=%v", err, tc.err)
				return
			}
			if got != tc.escaped {
				t.Errorf("Escaped localpart should be `%s` but got: `%s`", tc.escaped, got)
			}
		})
		t.Run(fmt.Sprintf("Span/%d", i), func(t *testing.T) {
			n, err := Escape.Span([]byte(tc.unescaped), tc.atEOF)
			if err != tc.spanErr {
				t.Errorf("Unexpected error, got=%v, want=%v", err, tc.spanErr)
				return
			}
			if n != tc.span {
				t.Errorf("Unexpected span, got=%d, want=%d", n, tc.span)
			}
		})
	}
}

// TODO: Malloc tests may be flakey under GCC until it improves its escape
//       analysis.

// TestEscapeMallocs checks that Escape.Transform does not allocate when it is
// handed a destination buffer large enough for the whole result.
func TestEscapeMallocs(t *testing.T) {
	src := []byte(escape)
	// Every byte of src is escapable and expands to three bytes; anything
	// smaller would make the measured call stop early with ErrShortDst.
	dst := make([]byte, 3*len(src))

	if n := testing.AllocsPerRun(1000, func() { Escape.Transform(dst, src, true) }); n > 0 {
		t.Errorf("got %f allocs, want 0", n)
	}
}

// TestUnescapeMallocs checks that Unescape.Transform does not allocate when
// it is handed a preallocated destination buffer.
func TestUnescapeMallocs(t *testing.T) {
	src := []byte(allescaped)
	// Every three byte escape sequence in src collapses to a single byte.
	dst := make([]byte, len(src)/3)

	if n := testing.AllocsPerRun(1000, func() { Unescape.Transform(dst, src, true) }); n > 0 {
		t.Errorf("got %f allocs, want 0", n)
	}
}

M jid/jid.go => jid/jid.go +0 -94
@@ 16,100 16,6 @@ import (
	"golang.org/x/text/secure/precis"
)

// escape lists every character for which shouldEscape reports true.
const escape = ` "&'/:<>@\`

// shouldEscape reports whether c is one of the characters that Escape
// replaces with an escape sequence.
func shouldEscape(c byte) bool {
	switch c {
	case ' ', '"', '&', '\'', '/', ':', '<', '>', '@', '\\':
		return true
	}
	return false
}

// shouldUnescape reports whether the two hex digits at the start of s form
// one of the recognized escape codes: 20, 22, 26, 27, 2f, 3a, 3c, 3e, 40 or
// 5c, with the letters accepted in either case. The codes are enumerated by
// hand rather than decoded and looked up.
func shouldUnescape(s string) bool {
	switch s[0] {
	case '2':
		switch s[1] {
		case '0', '2', '6', '7', 'f', 'F':
			return true
		}
	case '3':
		switch s[1] {
		case 'a', 'A', 'c', 'C', 'e', 'E':
			return true
		}
	case '4':
		return s[1] == '0'
	case '5':
		return s[1] == 'c' || s[1] == 'C'
	}
	return false
}

// unhex returns the numeric value of the ASCII hexadecimal digit c. Bytes
// that are not hex digits yield 0.
func unhex(c byte) byte {
	if c >= '0' && c <= '9' {
		return c - '0'
	}
	if c >= 'a' && c <= 'f' {
		return c - 'a' + 10
	}
	if c >= 'A' && c <= 'F' {
		return c - 'A' + 10
	}
	return 0
}

// BUG(ssw): Unescape does not fail on invalid escape codes.

// Unescape returns an unescaped version of the specified localpart using the
// escaping mechanism defined in XEP-0106: JID Escaping. It only unescapes
// sequences documented in XEP-0106 and does not guarantee that the resulting
// localpart is well formed.
//
// If s contains no recognized escape sequence it is returned unchanged.
func Unescape(s string) string {
	// First pass: count the escape sequences so the output can be sized
	// exactly.
	n := 0
	for i := 0; i < len(s); i++ {
		// Fewer than three bytes remain, so no complete sequence can start
		// here or later.
		if len(s) < i+3 {
			break
		}
		if s[i] == '\\' && shouldUnescape(s[i+1:i+3]) {
			n++
			// Skip the two hex digits that were just matched.
			i += 2
		}
	}

	if n == 0 {
		return s
	}

	// Each sequence shrinks from three bytes to one.
	t := make([]byte, len(s)-2*n)
	j := 0
	// Second pass: decode the sequences and copy every other byte verbatim.
	for i := 0; i < len(s); i++ {
		if s[i] == '\\' && len(s) > i+2 && shouldUnescape(s[i+1:i+3]) {
			t[j] = unhex(s[i+1])<<4 | unhex(s[i+2])
			i += 2
		} else {
			t[j] = s[i]
		}
		j++
	}
	return string(t)
}

// Escape returns s with every byte accepted by shouldEscape replaced by a
// backslash and the two lower case hex digits of its value, following the
// escaping mechanism defined in XEP-0106: JID Escaping. None of the JID
// methods apply it; callers must escape a localpart themselves before
// constructing a JID. If nothing needs escaping, s is returned unchanged.
func Escape(s string) string {
	// Count the bytes that need escaping so the output can be sized exactly.
	var escapes int
	for i := 0; i < len(s); i++ {
		if shouldEscape(s[i]) {
			escapes++
		}
	}
	if escapes == 0 {
		return s
	}

	const hexDigits = "0123456789abcdef"

	// Each escaped byte grows from one byte to three.
	out := make([]byte, 0, len(s)+2*escapes)
	for i := 0; i < len(s); i++ {
		c := s[i]
		if !shouldEscape(c) {
			out = append(out, c)
			continue
		}
		out = append(out, '\\', hexDigits[c>>4], hexDigits[c&15])
	}
	return string(out)
}

// JID represents an XMPP address (Jabber ID) comprising a localpart,
// domainpart, and resourcepart. All parts of a JID are guaranteed to be valid
// UTF-8 and will be represented in their canonical form which gives comparison

M jid/jid_test.go => jid/jid_test.go +0 -31
@@ 172,37 172,6 @@ func TestCopy(t *testing.T) {
	}
}

// allescaped is the escaped form of every character in escape, in the same
// order.
const allescaped = `\20\22\26\27\2f\3a\3c\3e\40\5c`

// TestEscape checks Escape against a table of inputs: one made up of every
// escapable character, one with nothing to escape, and the empty string.
func TestEscape(t *testing.T) {
	for _, test := range []struct {
		unescaped, escaped string
	}{
		{escape, allescaped},
		{`nothingtodohere`, `nothingtodohere`},
		{"", ""},
	} {
		if e := Escape(test.unescaped); e != test.escaped {
			t.Errorf("Escaped localpart should be `%s` but got: `%s`", test.escaped, e)
		}
	}
}

// TestUnescape checks Unescape against a table of inputs: every recognized
// escape sequence, a mix of upper and lower case sequences with unrecognized
// and truncated ones that must pass through untouched, an input with nothing
// to unescape, and the empty string.
func TestUnescape(t *testing.T) {
	for _, test := range []struct {
		escaped, unescaped string
	}{
		{allescaped, escape},
		{`\20\3c\3C\aa\\\`, ` <<\aa\\\`},
		{"nothingtodohere", "nothingtodohere"},
		{"", ""},
	} {
		if u := Unescape(test.escaped); u != test.unescaped {
			t.Errorf("Unescaped localpart should be `%s` but got: `%s`", test.unescaped, u)
		}
	}
}

func TestMarshalXML(t *testing.T) {
	// Test default marshaling
	j := MustParse("feste@shakespeare.lit")