~turminal/hare

8eb08f665aae1f499b0e16a53c55b2996f856ca7 — Bor Grošelj Simić 6 months ago c0432c6
encoding::base64: rewrite decoder

The decoder now uses io::readall instead of doing repeated reads on its
own, no longer uses static variables for storing state between reads, and
handles short/partial reads much better.

References: https://todo.sr.ht/~sircmpwn/hare/819
Signed-off-by: Bor Grošelj Simić <bgs@turminal.net>
1 file changed, 136 insertions(+), 149 deletions(-)

M encoding/base64/base64.ha
M encoding/base64/base64.ha => encoding/base64/base64.ha +136 -149
@@ 52,7 52,6 @@ export type encoder = struct {
	iavail: size,
	obuf: [4]u8,
	oavail: size,
	err: (void | io::error),
};

const encoder_vtable: io::vtable = io::vtable {


@@ 73,7 72,6 @@ export fn newencoder(
		stream = &encoder_vtable,
		out = out,
		enc = enc,
		err = void,
		...
	};
};


@@ 83,13 81,6 @@ fn encode_writer(
	in: const []u8
) (size | io::error) = {
	let s = s: *encoder;
	match(s.err) {
	case let err: io::error =>
		s.err = void;
		return err;
	case void => void;
	};

	let i = 0z;
	for (i < len(in)) {
		let b = s.ibuf[..];


@@ 111,7 102,6 @@ fn encode_writer(
			if (i == 0) {
				return e;
			};
			s.err = e;
			return i;
		case void => void;
		};


@@ 153,13 143,6 @@ fn encode_closer(s: *io::stream) (void | io::error) = {
	let finished = false;
	defer if (finished) clear(s);

	match (s.err) {
	case let e: io::error =>
		s.err = void;
		return e;
	case void => void;
	};

	if (s.oavail > 0) {
		for (s.oavail > 0) {
			writeavail(s)?;


@@ 282,9 265,11 @@ export type decoder = struct {
	stream: io::stream,
	in: io::handle,
	enc: *encoding,
	avail: []u8, // leftover decoded output
	obuf: [3]u8, // leftover decoded output
	ibuf: [4]u8,
	iavail: u8,
	oavail: u8,
	pad: bool, // if padding was seen in a previous read
	state: (void | io::EOF | io::error),
};

const decoder_vtable: io::vtable = io::vtable {


@@ 303,7 288,6 @@ export fn newdecoder(
		stream = &decoder_vtable,
		in = in,
		enc = enc,
		state = void,
		...
	};
};


@@ 313,101 297,91 @@ fn decode_reader(
	out: []u8
) (size | io::EOF | io::error) = {
	let s = s: *decoder;
	let n = 0z;
	let l = len(out);
	match(s.state) {
	case let err: (io::EOF | io ::error) =>
		return err;
	case void => void;
	if (len(out) == 0) {
		return 0z;
	};
	if (len(s.avail) > 0) {
		n += if (l < len(s.avail)) l else len(s.avail);
		out[..n] = s.avail[0..n];
		s.avail = s.avail[n..];
		if (l == n) {
			return n;
	let n = 0z;
	if (s.oavail != 0) {
		if (len(out) <= s.oavail) {
			out[..] = s.obuf[..len(out)];
			s.obuf[..len(s.obuf) - len(out)] = s.obuf[len(out)..];
			s.oavail = s.oavail - len(out): u8;
			return len(out);
		};
	};
	static let buf: [os::BUFSZ]u8 = [0...];
	static let obuf: [os::BUFSZ / 4 * 3]u8 = [0...];
	const nn = ((l - n) / 3 + 1) * 4; // 4 extra bytes may be read.
	let nr = 0z;
	for (nr < nn) {
		match (io::read(s.in, buf[nr..])) {
		case let n: size =>
			if (n == 0) {
				break;
			};
			nr += n;
		case io::EOF =>
			s.state = io::EOF;
			break;
		case let err: io::error =>
			s.state = err;
		n = s.oavail;
		s.oavail = 0;
		out[..n] = s.obuf[..n];
		out = out[n..];
	};
	let buf: [os::BUFSZ]u8 = [0...];
	buf[..s.iavail] = s.ibuf[..s.iavail];

	let want = encodedsize(len(out));
	let nr = s.iavail: size;
	let lim = if (want > len(buf)) len(buf) else want;
	match (io::readall(s.in, buf[s.iavail..lim])) {
	case let n: size =>
		nr += n;
	case io::EOF =>
		return if (s.iavail != 0) errors::invalid
			else if (n != 0) n
			else io::EOF;
	case let err: io::error =>
		if (!(err is io::underread)) {
			return err;
		};
		nr += err: io::underread;
	};
	if (nr % 4 != 0) {
		s.state = errors::invalid;
	if (s.pad) {
		return errors::invalid;
	};
	if (nr == 0) { // io::EOF already set
		return n;
	s.iavail = nr: u8 % 4;
	s.ibuf[..s.iavail] = buf[nr - s.iavail..nr];
	nr -= s.iavail;
	if (nr == 0) {
		return 0z;
	};
	// Validating read buffer
	let valid = true;
	let np = 0; // Number of padding chars.
	let p = true; // Pad allowed in buf
	for (let i = nr; i > 0; i -= 1) {
		const ch = buf[i - 1];
		if (ch >= 128) {
			return errors::invalid;
		};
		if (ch == PADDING) {
			if(s.pad || !p) {
				valid = false;
				break;
			};
			np += 1;
		} else {
			if (s.enc.decmap[ch] == -1) {
				valid = false;
				break;
	let np = 0z; // Number of padding chars.
	for (let i = 0z; i < nr; i += 1) {
		if (buf[i] == PADDING) {
			for (i + np < nr; np += 1) {
				if (np > 2 || buf[i + np] != PADDING) {
					return errors::invalid;
				};
			};
			// Disallow padding on seeing a non-padding char
			p = false;
			s.pad = true;
			break;
		};
	};
	valid = valid && np <= 2;
	if (np > 0) {
		s.pad = true;
	};
	if (!valid) {
		s.state = errors::invalid;
		return errors::invalid;
	};
	for (let i = 0z; i < nr; i += 1) {
		if (buf[i] >= 128) {
		if (!ascii::valid(buf[i]: u32: rune) || s.enc.decmap[buf[i]] == -1) {
			return errors::invalid;
		};
		buf[i] = s.enc.decmap[buf[i]];
	};
	for (let i = 0z, j = 0z; i < nr) {
		obuf[j] = buf[i] << 2 | buf[i + 1] >> 4;
		obuf[j + 1] = buf[i + 1] << 4 | buf[i + 2] >> 2;
		obuf[j + 2] = buf[i + 2] << 6 | buf[i + 3];

	if (nr / 4 * 3 - np < len(out)) {
		out = out[..nr / 4 * 3 - np];
	};
	let i = 0z, j = 0z;
	nr -= 4;
	for (i < nr) {
		out[j    ] = buf[i    ] << 2 | buf[i + 1] >> 4;
		out[j + 1] = buf[i + 1] << 4 | buf[i + 2] >> 2;
		out[j + 2] = buf[i + 2] << 6 | buf[i + 3];

		i += 4;
		j += 3;
	};
	// Removing bytes added due to padding.
	//                         0  1  2 // np
	static const npr: [3]u8 = [0, 1, 2]; // bytes to discard
	const navl = nr / 4 * 3 - npr[np];
	const rem = if(l - n < navl) l - n else navl;
	out[n..n + rem] = obuf[..rem];
	s.avail = obuf[rem..navl];
	return n + rem;
	s.obuf = [
		buf[i    ] << 2 | buf[i + 1] >> 4,
		buf[i + 1] << 4 | buf[i + 2] >> 2,
		buf[i + 2] << 6 | buf[i + 3],
	];
	out[j..] = s.obuf[..len(out) - j];
	s.oavail = (len(s.obuf) - (len(out) - j)): u8;
	s.obuf[..s.oavail] = s.obuf[len(s.obuf) - s.oavail..];
	s.oavail -= np: u8;
	return n + len(out);
};

// Decodes a byte slice of ASCII-encoded base 64 data, using the given encoding,


@@ 416,15 390,22 @@ export fn decodeslice(
	enc: *encoding,
	in: []u8,
) ([]u8 | errors::invalid) = {
	let in = memio::fixed(in);
	let decoder = newdecoder(enc, &in);
	let out = memio::dynamic();
	match (io::copy(&out, &decoder)) {
	if (len(in) == 0) {
		return [];
	};
	if (len(in) % 4 != 0) {
		return errors::invalid;
	};
	let ins = memio::fixed(in);
	let decoder = newdecoder(enc, &ins);
	let out = alloc([0u8...], decodedsize(len(in)));
	let outs = memio::fixed(out);
	match (io::copy(&outs, &decoder)) {
	case io::error =>
		io::close(&out)!;
		free(out);
		return errors::invalid;
	case size =>
		return memio::buffer(&out);
	case let sz: size =>
		return memio::buffer(&outs)[..sz];
	};
};



@@ 441,15 422,8 @@ export fn decode(
	enc: *encoding,
	buf: []u8,
) (size | io::EOF | io::error) = {
	const enc = newdecoder(enc, in);
	match (io::readall(&enc, buf)) {
	case let ret: (size | io::EOF) =>
		io::close(&enc)?;
		return ret;
	case let err: io::error =>
		io::close(&enc): void;
		return err;
	};
	const dec = newdecoder(enc, in);
	return io::readall(&dec, buf);
};

@test fn decode() void = {


@@ 463,54 437,67 @@ export fn decode(
		("Zm9vYmE=", "fooba", &std_encoding),
		("Zm9vYmFy", "foobar", &std_encoding),
	];
	for (let i = 0z; i < len(cases); i += 1) {
		let in = memio::fixed(strings::toutf8(cases[i].0));
		let decoder = newdecoder(cases[i].2, &in);
		let decb: []u8 = io::drain(&decoder)!;
		defer free(decb);
		assert(bytes::equal(decb, strings::toutf8(cases[i].1)));

		// Testing decodestr should cover decodeslice too
		let decb = decodestr(cases[i].2, cases[i].0) as []u8;
		defer free(decb);
		assert(bytes::equal(decb, strings::toutf8(cases[i].1)));
	};
	// Repeat of the above, but with a larger buffer
	for (let i = 0z; i < len(cases); i += 1) {
		let in = memio::fixed(strings::toutf8(cases[i].0));
		let decoder = newdecoder(cases[i].2, &in);
		let decb: []u8 = io::drain(&decoder)!;
		defer free(decb);
		assert(bytes::equal(decb, strings::toutf8(cases[i].1)));
	};

	const invalid: [_]str = [
	const invalid: [_](str, *encoding) = [
		// invalid padding
		"=", "==", "===", "=====", "======",
		// invalid characters
		"@Zg=", "êg=", "êg==", "$3d==", "%3d==", "[==", "!",
		// data after padding is encountered
		"Zg==Zg==", "Zm8=Zm8=",
		("=", &std_encoding),
		("==", &std_encoding),
	        ("===", &std_encoding),
	        ("=====", &std_encoding),
	        ("======", &std_encoding),
	        // invalid characters
	        ("@Zg=", &std_encoding),
	        ("ê==", &std_encoding),
	        ("êg==", &std_encoding),
		("$3d==", &std_encoding),
		("%3d==", &std_encoding),
		("[==", &std_encoding),
		("!", &std_encoding),
	        // data after padding is encountered
	        ("Zg===", &std_encoding),
	        ("Zg====", &std_encoding),
	        ("Zg==Zg==", &std_encoding),
	        ("Zm8=Zm8=", &std_encoding),
	];
	const encodings: [_]*encoding = [&std_encoding, &url_encoding];
	for (let i = 0z; i < len(invalid); i += 1) {
		for (let enc = 0z; enc < 2; enc += 1) {
			let in = memio::fixed(strings::toutf8(invalid[i]));
			let decoder = newdecoder(encodings[enc], &in);
			let buf: [1]u8 = [0...];
	let buf: [12]u8 = [0...];
	for (let bufsz = 1z; bufsz <= 12; bufsz += 1) {
		for (let (input, expected, encoding) .. cases) {
			let in = memio::fixed(strings::toutf8(input));
			let decoder = newdecoder(encoding, &in);
			let buf = buf[..bufsz];
			let decb: []u8 = [];
			defer free(decb);
			for (true) match (io::read(&decoder, buf)!) {
			case let z: size =>
				if (z > 0) {
					append(decb, buf[..z]...);
				};
			case io::EOF =>
				break;
			};
			assert(bytes::equal(decb, strings::toutf8(expected)));

			// Testing decodestr should cover decodeslice too
			let decb = decodestr(encoding, input) as []u8;
			defer free(decb);
			assert(bytes::equal(decb, strings::toutf8(expected)));
		};

		for (let (input, encoding) .. invalid) {
			let in = memio::fixed(strings::toutf8(input));
			let decoder = newdecoder(encoding, &in);
			let buf = buf[..bufsz];
			let valid = false;
			for (true) match(io::read(&decoder, buf)) {
			case errors::invalid =>
				break;
			case size =>
				valid = true;
				void;
			case io::EOF =>
				break;
				abort();
			};
			assert(valid == false, "valid is not false");

			// Testing decodestr should cover decodeslice too
			assert(decodestr(encodings[enc], invalid[i]) is errors::invalid);
			assert(decodestr(encoding, input) is errors::invalid);
		};
	};
};