~samwhited/xmpp

37830e60d0b2a2b046e94a905ddd3324eea13004 — Sam Whited 3 years ago 917f3ba textutil
jid: use textutil for escaping/unescaping

benchmark                        old ns/op     new ns/op     delta
BenchmarkEscapeTransform-4       64.7          172           +165.84%
BenchmarkUnescapeTransform-4     183           312           +70.49%
BenchmarkEscapeBytes-4           556           361           -35.07%
BenchmarkUnescapeBytes-4         257           388           +50.97%

benchmark                        old allocs     new allocs     delta
BenchmarkEscapeTransform-4       0              0              +0.00%
BenchmarkUnescapeTransform-4     0              0              +0.00%
BenchmarkEscapeBytes-4           3              3              +0.00%
BenchmarkUnescapeBytes-4         2              2              +0.00%

benchmark                        old bytes     new bytes     delta
BenchmarkEscapeTransform-4       0             0             +0.00%
BenchmarkUnescapeTransform-4     0             0             +0.00%
BenchmarkEscapeBytes-4           96            96            +0.00%
BenchmarkUnescapeBytes-4         48            48            +0.00%
3 files changed, 147 insertions(+), 214 deletions(-)

M jid/benchmark_test.go
M jid/escape.go
M jid/escape_test.go
M jid/benchmark_test.go => jid/benchmark_test.go +19 -3
@@ 71,8 71,8 @@ func BenchmarkString(b *testing.B) {
	}
}

func BenchmarkEscape(b *testing.B) {
	src := []byte(escape)
func BenchmarkEscapeTransform(b *testing.B) {
	src := []byte(EscapedChars)
	dst := make([]byte, len(src)+18)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {


@@ 80,7 80,7 @@ func BenchmarkEscape(b *testing.B) {
	}
}

func BenchmarkUnescape(b *testing.B) {
func BenchmarkUnescapeTransform(b *testing.B) {
	src := []byte(allescaped)
	dst := make([]byte, len(src)/3)
	b.ResetTimer()


@@ 88,3 88,19 @@ func BenchmarkUnescape(b *testing.B) {
		Unescape.Transform(dst, src, true)
	}
}

// BenchmarkEscapeBytes measures Escape.Bytes when it is handed a slice holding
// every character listed in EscapedChars.
func BenchmarkEscapeBytes(b *testing.B) {
	input := []byte(EscapedChars)
	b.ResetTimer() // exclude the conversion above from the measurement
	for n := 0; n < b.N; n++ {
		_ = Escape.Bytes(input)
	}
}

// BenchmarkUnescapeBytes measures Unescape.Bytes when it is handed the
// allescaped fixture.
func BenchmarkUnescapeBytes(b *testing.B) {
	input := []byte(allescaped)
	b.ResetTimer() // exclude the conversion above from the measurement
	for n := 0; n < b.N; n++ {
		_ = Unescape.Bytes(input)
	}
}

M jid/escape.go => jid/escape.go +124 -205
@@ 5,239 5,158 @@
package jid

import (
	"bytes"
	"unicode/utf8"

	"golang.org/x/text/transform"
	"github.com/mpvl/textutil"
)

// Transformer implements the transform.Transformer and
// transform.SpanningTransformer interfaces.
//
// It is a thin wrapper: Reset, Transform and Span forward to the wrapped
// transformer, while Bytes and String are convenience helpers built on the
// transform package.
type Transformer struct {
	// t is the underlying transformer that all interface methods delegate to.
	t transform.SpanningTransformer
}

// Reset implements the transform.Transformer interface by calling Reset on the
// wrapped transformer.
func (t Transformer) Reset() { t.t.Reset() }

// Transform implements the transform.Transformer interface. The arguments are
// passed to the wrapped transformer unchanged and its results are returned
// as-is.
func (t Transformer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
	return t.t.Transform(dst, src, atEOF)
}

// Span implements the transform.SpanningTransformer interface. The arguments
// are passed to the wrapped transformer unchanged and its results are returned
// as-is.
func (t Transformer) Span(src []byte, atEOF bool) (n int, err error) {
	return t.t.Span(src, atEOF)
}

// Bytes runs t over b via transform.Bytes and returns the resulting byte
// slice. The consumed-byte count and the error reported by transform.Bytes are
// discarded.
func (t Transformer) Bytes(b []byte) []byte {
	out, _, _ := transform.Bytes(t, b)
	return out
}

// String runs t over s via transform.String and returns the resulting string.
// The consumed-byte count and the error reported by transform.String are
// discarded.
func (t Transformer) String(s string) string {
	out, _, _ := transform.String(t, s)
	return out
}

var (
	// Escape is a transform that maps escapable runes to their escaped form as
	// defined in XEP-0106: JID Escaping.
	Escape Transformer = Transformer{escapeMapping{}}
	Escape = textutil.NewTransformerFromFunc(escape)

	// Unescape is a transform that maps valid escape sequences to their unescaped
	// form as defined in XEP-0106: JID Escaping.
	Unescape Transformer = Transformer{unescapeMapping{}}
	Unescape = textutil.NewTransformerFromFunc(unescape)
)

const escape = ` "&'/:<>@\`

// escapeMapping is the transformer wrapped by Escape. It carries no state of
// its own; the embedded transform.NopResetter supplies a no-op Reset method.
type escapeMapping struct {
	transform.NopResetter
}

func (escapeMapping) Span(src []byte, atEOF bool) (n int, err error) {
	switch idx := bytes.IndexAny(src, escape); idx {
	case -1:
		return len(src), nil
// EscapedChars is a string composed of all the characters that will be escaped
// or unescaped by the transformers in this package (in no particular order).
const EscapedChars = ` "&'/:<>@\`

func escape(s textutil.State) {
	switch r, _ := s.ReadRune(); r {
	case ' ':
		s.WriteString(`\20`)
	case '"':
		s.WriteString(`\22`)
	case '&':
		s.WriteString(`\26`)
	case '\'':
		s.WriteString(`\27`)
	case '/':
		s.WriteString(`\2f`)
	case ':':
		s.WriteString(`\3a`)
	case '<':
		s.WriteString(`\3c`)
	case '>':
		s.WriteString(`\3e`)
	case '@':
		s.WriteString(`\40`)
	case '\\':
		s.WriteString(`\5c`)
	default:
		return idx, transform.ErrEndOfSpan
		s.WriteRune(r)
	}
}

// Transform implements the transform.Transformer interface. It copies src to
// dst, replacing every byte that appears in escape with a backslash followed
// by the byte's two-digit lowercase hexadecimal value (for example ' ' becomes
// `\20`).
//
// nDst and nSrc report how many bytes were written to dst and consumed from
// src. transform.ErrShortDst is returned when dst cannot hold the next piece
// of output. An escape sequence is written only if all three of its bytes fit,
// so a byte is never counted as consumed while its escape is truncated. atEOF
// is not consulted: each input byte is handled independently of what follows.
func (t escapeMapping) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
	const hexDigits = "0123456789abcdef"

	for nSrc < len(src) {
		// idx is the length of the run of bytes, starting at nSrc, that can be
		// copied verbatim.
		idx := bytes.IndexAny(src[nSrc:], escape)
		if idx == -1 {
			// Nothing left to escape; the run extends to the end of src.
			idx = len(src) - nSrc
		}

		n := copy(dst[nDst:], src[nSrc:nSrc+idx])
		nDst += n
		nSrc += n
		// Compare against the run length itself. The previous check used
		// idx-nSrc after nSrc had already been advanced, which wrongly
		// reported a short destination whenever nSrc was non-zero.
		if n != idx {
			return nDst, nSrc, transform.ErrShortDst
		}
		if nSrc == len(src) {
			break
		}

		// src[nSrc] must be escaped. Refuse to emit a partial sequence so the
		// caller can retry from this byte with a larger dst.
		if len(dst)-nDst < 3 {
			return nDst, nSrc, transform.ErrShortDst
		}
		c := src[nSrc]
		dst[nDst] = '\\'
		dst[nDst+1] = hexDigits[c>>4]
		dst[nDst+2] = hexDigits[c&15]
		nDst += 3
		nSrc++
	}
	return nDst, nSrc, nil
}

// unescapeMapping is the transformer wrapped by Unescape. It carries no state
// of its own; the embedded transform.NopResetter supplies a no-op Reset
// method.
type unescapeMapping struct {
	transform.NopResetter
}

// TODO: Be more specific. Only check if it's the starting character in any
//       valid escape sequence.

func ishex(c byte) bool {
	switch {
	case '0' <= c && c <= '9':
		return true
	case 'a' <= c && c <= 'f':
		return true
	case 'A' <= c && c <= 'F':
		return true
// fmt.Printf("% x", EscapedChars):
// 20 22 26 27 2f 3a 3c 3e 40 5c
func unescape(s textutil.State) {
	if r, _ := s.ReadRune(); r != '\\' {
		s.WriteRune(r)
		return
	}
	return false
}

// shouldUnescape reports whether the two bytes at the start of s are the hex
// digits of one of the recognized escape sequences: 20, 22, 26, 27, 2f, 3a,
// 3c, 3e, 40 or 5c. Letter digits are accepted in either case.
//
// The accepted pairs are spelled out explicitly rather than decoded, which
// keeps the check cheap. The second byte is only inspected when the first one
// is '2', '3', '4' or '5'.
func shouldUnescape(s []byte) bool {
	switch s[0] {
	case '2':
		switch s[1] {
		case '0', '2', '6', '7', 'f', 'F':
			return true
		}
	case '3':
		switch s[1] {
		case 'a', 'A', 'c', 'C', 'e', 'E':
			return true
		}
	case '4':
		return s[1] == '0'
	case '5':
		return s[1] == 'c' || s[1] == 'C'
	}
	return false
}

// unhex converts a single ASCII hexadecimal digit, in either case, to its
// numeric value. Any byte that is not a hexadecimal digit yields 0.
func unhex(c byte) byte {
	if c >= '0' && c <= '9' {
		return c - '0'
	}
	if c >= 'a' && c <= 'f' {
		return 10 + (c - 'a')
	}
	if c >= 'A' && c <= 'F' {
		return 10 + (c - 'A')
	}
	return 0
}
	// TODO: There's probably a better way to do this than generate a giant
	// switch/case tree.

func (unescapeMapping) Span(src []byte, atEOF bool) (n int, err error) {
	for n < len(src) {
		if src[n] != '\\' {
			n++
			continue
	r, n := s.ReadRune()
	switch r {
	case utf8.RuneError:
		if n == 0 {
			s.WriteRune('\\')
			return
		}

		switch n {
		case len(src) - 1:
			// The last character is the escape char.
			if atEOF {
				return len(src), nil
			}
			return n, transform.ErrShortSrc
		case len(src) - 2:
			if atEOF || !ishex(src[n+1]) {
				return len(src), nil
		s.WriteRune(r)
		return
	case '2':
		switch r2, n := s.ReadRune(); r2 {
		case utf8.RuneError:
			if n == 0 {
				s.WriteRune('\\')
				s.WriteRune(r)
				return
			}
			return n, transform.ErrShortSrc
		}

		if shouldUnescape(src[n+1 : n+3]) {
			// unhex(s[n+1])<<4 | unhex(s[n+2])
			return n, transform.ErrEndOfSpan
			s.WriteRune(r)
			s.WriteRune(r2)
			return
		case '0':
			s.WriteRune(' ')
		case '2':
			s.WriteRune('"')
		case '6':
			s.WriteRune('&')
		case '7':
			s.WriteRune('\'')
		case 'f':
			s.WriteRune('/')
		default:
			s.WriteRune('\\')
			s.WriteRune(r)
			s.WriteRune(r2)
		}
		n++
	}
	return
}

func (t unescapeMapping) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
	const all = `\20\22\26\27\2f\3a\3c\3e\40\5c`
	for nSrc < len(src) {
		idx := bytes.IndexRune(src[nSrc:], '\\')

		switch {
		case idx == -1 || (idx == len(src[nSrc:])-1 && atEOF):
			// No unescape sequence exists, or the escape sequence is at the end but
			// there aren't enough following characters to make it valid, so copy to
			// the end.
			n := copy(dst[nDst:], src[nSrc:])
			nDst += n
			nSrc += n
			if nSrc < len(src) {
				return nDst, nSrc, transform.ErrShortDst
	case '3':
		switch r2, _ := s.ReadRune(); r2 {
		case utf8.RuneError:
			if n == 0 {
				s.WriteRune('\\')
				s.WriteRune(r)
				return
			}
			s.WriteRune(r)
			s.WriteRune(r2)
			return
		case idx == len(src[nSrc:])-1:
			// The last character is the escape char and this isn't the EOF
			n := copy(dst[nDst:], src[nSrc:nSrc+idx])
			nDst += n
			nSrc += n
			if n != idx {
				return nDst, nSrc, transform.ErrShortDst
			}
			return nDst, nSrc, transform.ErrShortSrc
		case idx == len(src[nSrc:])-2:
			if atEOF || !ishex(src[nSrc+idx+1]) {
				n := copy(dst[nDst:], src[nSrc:])
				nDst += n
				nSrc += n
				if nSrc < len(src) {
					return nDst, nSrc, transform.ErrShortDst
				}
		case 'a':
			s.WriteRune(':')
		case 'c':
			s.WriteRune('<')
		case 'e':
			s.WriteRune('>')
		default:
			s.WriteRune('\\')
			s.WriteRune(r)
			s.WriteRune(r2)
		}
	case '4':
		r2, n := s.ReadRune()
		switch r2 {
		case utf8.RuneError:
			if n == 0 {
				s.WriteRune('\\')
				s.WriteRune(r)
				return
			}
			n := copy(dst[nDst:], src[nSrc:nSrc+idx])
			nDst += n
			nSrc += n
			if n != idx {
				return nDst, nSrc, transform.ErrShortDst
			}
			return nDst, nSrc, transform.ErrShortSrc
			s.WriteRune(r)
			s.WriteRune(r2)
		case '0':
			s.WriteRune('@')
		default:
			s.WriteRune('\\')
			s.WriteRune(r)
			s.WriteRune(r2)
		}

		if shouldUnescape(src[nSrc+idx+1 : nSrc+idx+3]) {
			n := copy(dst[nDst:], src[nSrc:nSrc+idx])
			nDst += n
			nSrc += n
			if n != idx {
				return nDst, nSrc, transform.ErrShortDst
			}
	case '5':
		r2, _ := s.ReadRune()
		switch r2 {
		case utf8.RuneError:
			if n == 0 {
				n++
			}
			n = copy(dst[nDst:], []byte{
				unhex(src[nSrc+n])<<4 | unhex(src[nSrc+n+1]),
			})
			nDst += n
			nSrc += 3
			if n != 1 {
				return nDst, nSrc, transform.ErrShortDst
				s.WriteRune('\\')
				s.WriteRune(r)
				return
			}
			continue
		}
		n := copy(dst[nDst:], src[nSrc:nSrc+idx+1])
		nDst += n
		nSrc += n
		if n != idx+1 {
			return nDst, nSrc, transform.ErrShortDst
			s.WriteRune(r)
			s.WriteRune(r2)
		case 'c':
			s.WriteRune('\\')
		default:
			s.WriteRune('\\')
			s.WriteRune(r)
			s.WriteRune(r2)
		}
	default:
		s.WriteRune('\\')
		s.WriteRune(r)
	}
	return
}

M jid/escape_test.go => jid/escape_test.go +4 -6
@@ 11,8 11,6 @@ import (
	"golang.org/x/text/transform"
)

var _ transform.SpanningTransformer = (*escapeMapping)(nil)

const allescaped = `\20\22\26\27\2f\3a\3c\3e\40\5c`

var escapeTestCases = [...]struct {


@@ 21,8 19,8 @@ var escapeTestCases = [...]struct {
	span               int
	err, spanErr       error
}{
	0: {escape, allescaped, true, 0, nil, transform.ErrEndOfSpan},
	1: {escape, allescaped, false, 0, nil, transform.ErrEndOfSpan},
	0: {EscapedChars, allescaped, true, 0, nil, transform.ErrEndOfSpan},
	1: {EscapedChars, allescaped, false, 0, nil, transform.ErrEndOfSpan},
	2: {`nothingtodohere`, `nothingtodohere`, true, 15, nil, nil},
	3: {`nothingtodohere`, `nothingtodohere`, false, 15, nil, nil},
	4: {"", "", true, 0, nil, nil},


@@ 36,7 34,7 @@ var unescapeTestCases = [...]struct {
	span               int
	err, spanErr       error
}{
	0: {allescaped, escape, true, 0, nil, transform.ErrEndOfSpan},
	0: {allescaped, EscapedChars, true, 0, nil, transform.ErrEndOfSpan},
	1: {`a\20`, `a `, true, 1, nil, transform.ErrEndOfSpan},
	2: {`a\`, `a\`, true, 2, nil, nil},
	3: {`a\`, `a`, false, 1, transform.ErrShortSrc, transform.ErrShortSrc},


@@ 94,7 92,7 @@ func TestEscape(t *testing.T) {
//       analysis.

func TestEscapeMallocs(t *testing.T) {
	src := []byte(escape)
	src := []byte(EscapedChars)
	dst := make([]byte, len(src)+18)

	if n := testing.AllocsPerRun(1000, func() { Escape.Transform(dst, src, true) }); n > 0 {