~samwhited/xmpp

ref: a2ffaf710170d692a9289c2b428a4efa90c8e703 xmpp/jid/escape.go -rw-r--r-- 5.8 KiB
a2ffaf71Sam Whited all: update deps 3 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.

package jid

import (
	"bytes"

	"golang.org/x/text/transform"
)

// Transformer implements the transform.Transformer and
// transform.SpanningTransformer interfaces.
//
// It wraps another spanning transformer, forwards Reset, Transform, and Span
// to it, and adds the Bytes and String convenience methods.
type Transformer struct {
	// t is the wrapped transformer that performs the actual work.
	t transform.SpanningTransformer
}

// Reset implements the transform.Transformer interface by resetting the
// wrapped transformer.
func (t Transformer) Reset() {
	t.t.Reset()
}

// Transform implements the transform.Transformer interface. The call is
// forwarded unchanged to the wrapped transformer and its results are returned
// as-is.
func (t Transformer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
	nDst, nSrc, err = t.t.Transform(dst, src, atEOF)
	return nDst, nSrc, err
}

// Span implements the transform.SpanningTransformer interface. The call is
// forwarded unchanged to the wrapped transformer and its results are returned
// as-is.
func (t Transformer) Span(src []byte, atEOF bool) (n int, err error) {
	n, err = t.t.Span(src, atEOF)
	return n, err
}

// Bytes returns a new byte slice with the result of applying t to b.
//
// The byte count and error reported by transform.Bytes are discarded; only
// the transformed output is returned.
func (t Transformer) Bytes(b []byte) []byte {
	out, _, _ := transform.Bytes(t, b)
	return out
}

// String returns a string with the result of applying t to s.
//
// The byte count and error reported by transform.String are discarded; only
// the transformed output is returned.
func (t Transformer) String(s string) string {
	out, _, _ := transform.String(t, s)
	return out
}

var (
	// Escape is a Transformer that rewrites every escapable character into its
	// escaped form, as defined in XEP-0106: JID Escaping.
	Escape = Transformer{t: escapeMapping{}}

	// Unescape is a Transformer that rewrites every valid escape sequence back
	// into its unescaped form, as defined in XEP-0106: JID Escaping.
	Unescape = Transformer{t: unescapeMapping{}}
)

// escape lists the bytes that escapeMapping replaces with a backslash followed
// by the byte's value as two hexadecimal digits: space, double quote,
// ampersand, apostrophe, slash, colon, less-than, greater-than, at sign, and
// backslash.
const escape = ` "&'/:<>@\`

// escapeMapping is the transformer wrapped by Escape. It carries no state of
// its own; the embedded transform.NopResetter supplies its Reset method.
type escapeMapping struct {
	transform.NopResetter
}

// Span implements the transform.SpanningTransformer interface. It reports how
// many leading bytes of src contain nothing that needs escaping: all of src
// when no byte from escape is present, otherwise the offset of the first such
// byte together with transform.ErrEndOfSpan.
func (escapeMapping) Span(src []byte, atEOF bool) (n int, err error) {
	if first := bytes.IndexAny(src, escape); first >= 0 {
		return first, transform.ErrEndOfSpan
	}
	return len(src), nil
}

// Transform implements the transform.Transformer interface. It copies src to
// dst, replacing every byte listed in escape with a backslash followed by the
// byte's value written as two lowercase hexadecimal digits.
//
// When dst is too small Transform returns transform.ErrShortDst. In that case
// nDst and nSrc only account for input whose output has been written in full:
// an escape sequence is never split across calls, and no source byte is
// consumed unless its complete output fit into dst.
func (t escapeMapping) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
	const hexDigits = "0123456789abcdef"

	for nSrc < len(src) {
		idx := bytes.IndexAny(src[nSrc:], escape)
		if idx == -1 {
			// Nothing left to escape: copy the remainder verbatim.
			n := copy(dst[nDst:], src[nSrc:])
			nDst += n
			nSrc += n
			if nSrc < len(src) {
				return nDst, nSrc, transform.ErrShortDst
			}
			return nDst, nSrc, nil
		}

		// Copy the literal run that precedes the byte needing escaping. idx is
		// relative to the old nSrc, so compare it against the copied length
		// directly rather than against the already advanced nSrc.
		n := copy(dst[nDst:], src[nSrc:nSrc+idx])
		nDst += n
		nSrc += n
		if n != idx {
			return nDst, nSrc, transform.ErrShortDst
		}

		// Only consume the byte if the whole three byte escape sequence fits;
		// writing a partial sequence would corrupt the output because the
		// caller resumes after the bytes reported in nSrc.
		if len(dst)-nDst < 3 {
			return nDst, nSrc, transform.ErrShortDst
		}
		c := src[nSrc]
		dst[nDst] = '\\'
		dst[nDst+1] = hexDigits[c>>4]
		dst[nDst+2] = hexDigits[c&15]
		nDst += 3
		nSrc++
	}
	return nDst, nSrc, nil
}

// unescapeMapping is the transformer wrapped by Unescape. It carries no state
// of its own; the embedded transform.NopResetter supplies its Reset method.
type unescapeMapping struct {
	transform.NopResetter
}

// TODO: Be more specific. ishex only needs to accept bytes that can start a
//       valid escape sequence ('2', '3', '4', or '5').

// ishex reports whether c is an ASCII hexadecimal digit (0-9, a-f, or A-F).
func ishex(c byte) bool {
	isDigit := '0' <= c && c <= '9'
	isLower := 'a' <= c && c <= 'f'
	isUpper := 'A' <= c && c <= 'F'
	return isDigit || isLower || isUpper
}

// shouldUnescape reports whether the two bytes at the start of s are the
// hexadecimal digits of an escape sequence that gets unescaped: 20, 22, 26,
// 27, 2f, 3a, 3c, 3e, 40, or 5c (letters in either case).
//
// The accepted pairs are spelled out by hand instead of being decoded and
// looked up; the set is not expected to change. s[1] is only read when s[0]
// is one of '2', '3', '4', or '5'.
func shouldUnescape(s []byte) bool {
	switch s[0] {
	case '2':
		switch s[1] {
		case '0', '2', '6', '7', 'f', 'F':
			return true
		}
	case '3':
		switch s[1] {
		case 'a', 'A', 'c', 'C', 'e', 'E':
			return true
		}
	case '4':
		return s[1] == '0'
	case '5':
		return s[1] == 'c' || s[1] == 'C'
	}
	return false
}

// unhex returns the numeric value (0-15) of the ASCII hexadecimal digit c.
// Bytes that are not hexadecimal digits yield 0.
func unhex(c byte) byte {
	if '0' <= c && c <= '9' {
		return c - '0'
	}
	if 'a' <= c && c <= 'f' {
		return 10 + (c - 'a')
	}
	if 'A' <= c && c <= 'F' {
		return 10 + (c - 'A')
	}
	return 0
}

// Span implements the transform.SpanningTransformer interface. It returns the
// number of leading bytes of src that Transform would copy unchanged.
//
// The scan stops with transform.ErrEndOfSpan at the first backslash that
// starts an escape sequence accepted by shouldUnescape. When src ends before a
// backslash has two following bytes, the result depends on atEOF: at EOF the
// backslash is literal and all of src is spanned; otherwise
// transform.ErrShortSrc is returned at the backslash, unless the single
// following byte is not a hexadecimal digit, in which case the backslash
// cannot start an escape sequence and the scan continues with that following
// byte (which may itself be a backslash awaiting more input).
func (unescapeMapping) Span(src []byte, atEOF bool) (n int, err error) {
	for n < len(src) {
		if src[n] != '\\' {
			n++
			continue
		}

		// rest counts the backslash and every byte after it.
		rest := len(src) - n
		switch {
		case rest >= 3:
			if shouldUnescape(src[n+1 : n+3]) {
				return n, transform.ErrEndOfSpan
			}
		case atEOF:
			// Too few bytes remain to form an escape sequence and no more
			// input is coming, so the tail is literal.
			return len(src), nil
		case rest == 2 && !ishex(src[n+1]):
			// This backslash is literal. Fall out of the switch so that the
			// following byte is examined on the next iteration instead of
			// being spanned blindly.
		default:
			return n, transform.ErrShortSrc
		}
		n++
	}
	return n, nil
}

// Transform implements the transform.Transformer interface. It copies src to
// dst, replacing every backslash that is followed by a two digit sequence
// accepted by shouldUnescape with the single byte those digits encode. Any
// other backslash is copied through unchanged.
//
// When src ends before a backslash has two following bytes, the result
// depends on atEOF: at EOF the tail is copied verbatim; otherwise
// transform.ErrShortSrc is returned with nSrc positioned at the backslash,
// unless the single following byte is not a hexadecimal digit, in which case
// the backslash is literal and processing continues with that following byte.
//
// When dst is too small Transform returns transform.ErrShortDst. nDst and nSrc
// then only account for input whose output has been written, so no escape
// sequence is consumed without its decoded byte having been stored.
func (t unescapeMapping) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
	for nSrc < len(src) {
		idx := bytes.IndexRune(src[nSrc:], '\\')
		if idx == -1 {
			// No backslash left: the remainder is one literal run.
			idx = len(src) - nSrc
		}

		// Copy the literal run that precedes the backslash (or the end).
		n := copy(dst[nDst:], src[nSrc:nSrc+idx])
		nDst += n
		nSrc += n
		if n != idx {
			return nDst, nSrc, transform.ErrShortDst
		}
		if nSrc == len(src) {
			break
		}

		// src[nSrc] is a backslash; rest counts it and every byte after it.
		rest := len(src) - nSrc
		switch {
		case rest >= 3:
			if shouldUnescape(src[nSrc+1 : nSrc+3]) {
				// Make sure the decoded byte fits before consuming the
				// sequence, then decode the two digits that directly follow
				// the backslash.
				if nDst == len(dst) {
					return nDst, nSrc, transform.ErrShortDst
				}
				dst[nDst] = unhex(src[nSrc+1])<<4 | unhex(src[nSrc+2])
				nDst++
				nSrc += 3
				continue
			}
		case atEOF:
			// Too few bytes remain to form an escape sequence and no more
			// input is coming, so the backslash is literal.
		case rest == 2 && !ishex(src[nSrc+1]):
			// The following byte cannot be part of an escape sequence, so
			// this backslash is literal. The following byte is re-examined on
			// the next iteration because it may itself be a backslash.
		default:
			return nDst, nSrc, transform.ErrShortSrc
		}

		// Emit the literal backslash.
		if nDst == len(dst) {
			return nDst, nSrc, transform.ErrShortDst
		}
		dst[nDst] = '\\'
		nDst++
		nSrc++
	}
	return nDst, nSrc, nil
}