const builtin = @import("builtin");
const std = @import("std");
const fs = std.fs;
const io = std.io;
const math = std.math;
const mem = std.mem;
const meta = std.meta;
const warn = std.debug.warn;
/// Errors for malformed input that `Reader` detects itself, beyond whatever
/// the underlying stream can return (e.g. an encoded integer too large for
/// its destination).
pub const ReadError = error{
    Overflow,
};
/// Decodes values of statically known types from a byte stream.
///
/// Integers and floats are fixed-width little-endian. Enum tags, union tags,
/// slice lengths and hash map counts are base-128 varints. Slices and hash
/// maps are allocated from an internal arena. Everything this reader returns
/// stays valid until `deinit` is called, and `deinit` frees all of it.
pub const Reader = struct {
    arena: std.heap.ArenaAllocator,

    const Self = @This();

    /// Maximum number of bytes in the varint encoding of a u64 (ceil(64 / 7)).
    const max_varint_len = 10;

    /// Creates a reader whose allocations come from an arena backed by `allocator`.
    pub fn init(allocator: *mem.Allocator) Self {
        return .{
            .arena = std.heap.ArenaAllocator.init(allocator),
        };
    }

    /// Frees every slice and hash map this reader has returned.
    pub fn deinit(self: *Self) void {
        self.arena.deinit();
    }

    /// Reads one value of type `T` from `reader`.
    /// Types without an encoding are rejected at compile time.
    pub fn read(self: *Self, comptime T: type, reader: anytype) !T {
        return switch (@typeInfo(T)) {
            .Int => self.readInt(T, reader),
            .Float => self.readFloat(T, reader),
            .Bool => self.readBool(reader),
            .Struct => if (comptime isHashMap(T))
                self.readHashMap(T, reader)
            else
                self.readStruct(T, reader),
            .Enum => self.readEnum(T, reader),
            .Optional => self.readOptional(T, reader),
            .Array => self.readArray(T, reader),
            .Union => self.readUnion(T, reader),
            .Pointer => self.readPointer(T, reader),
            else => @compileError("unsupported type " ++ @typeName(T)),
        };
    }

    /// Like `read`, but a `void` payload (e.g. a bare union tag) consumes no bytes.
    fn readAllowVoid(self: *Self, comptime T: type, reader: anytype) !T {
        return switch (T) {
            void => {},
            else => self.read(T, reader),
        };
    }

    /// Decodes a zigzag-encoded signed varint: the low bit of the unsigned
    /// varint is the sign, and the remaining bits are the magnitude, inverted
    /// for negatives.
    fn readVarInt(self: *Self, reader: anytype) !i64 {
        const ux = try self.readVarUint(reader);
        // ux >> 1 is at most 2^63 - 1, so it always fits in i64.
        const x = @intCast(i64, ux >> 1);
        return if (ux & 1 != 0) ~x else x;
    }

    /// Decodes an unsigned base-128 varint. Each byte carries 7 bits, lowest
    /// group first; the high bit set means another byte follows.
    /// Returns `ReadError.Overflow` if the encoding does not fit in a u64.
    fn readVarUint(self: *Self, reader: anytype) !u64 {
        var x: u64 = 0;
        var s: u6 = 0;
        var i: usize = 0;
        while (i < max_varint_len) : (i += 1) {
            const b = try reader.readByte();
            if (b < 0x80) {
                // The final (10th) byte may only supply bit 63.
                if (i == max_varint_len - 1 and 1 < b)
                    return ReadError.Overflow;
                return x | @as(u64, b & 0x7f) << s;
            }
            x |= @as(u64, b & 0x7f) << s;
            // A continuation bit on the 10th byte would need a shift past 63.
            if (@addWithOverflow(@TypeOf(s), s, 7, &s))
                return ReadError.Overflow;
        }
        // Not reached: the 10th byte always returns above. Kept so every path returns.
        return ReadError.Overflow;
    }

    /// Reads a fixed-width little-endian integer of 8, 16, 32 or 64 bits.
    fn readInt(self: *Self, comptime T: type, reader: anytype) !T {
        const type_info = @typeInfo(T);
        return switch (type_info.Int.bits) {
            8, 16, 32, 64 => reader.readIntLittle(T),
            else => @compileError("unsupported integer type " ++ @typeName(T)),
        };
    }

    /// Reads an f32/f64 from its little-endian IEEE-754 bit pattern.
    fn readFloat(self: *Self, comptime T: type, reader: anytype) !T {
        const bits = @typeInfo(T).Float.bits;
        return switch (bits) {
            32, 64 => @bitCast(T, try reader.readIntLittle(meta.Int(.unsigned, bits))),
            else => @compileError("unsupported float type " ++ @typeName(T)),
        };
    }

    /// Reads one byte; any nonzero value decodes as `true`.
    fn readBool(self: *Self, reader: anytype) !bool {
        return 0 != try reader.readByte();
    }

    /// Reads each field of `T` in declaration order.
    fn readStruct(self: *Self, comptime T: type, reader: anytype) !T {
        const ti = @typeInfo(T).Struct;
        var s: T = undefined;
        if (ti.fields.len < 1)
            @compileError("structs must have 1 or more fields");
        inline for (ti.fields) |f|
            @field(s, f.name) = try self.read(f.field_type, reader);
        return s;
    }

    /// Reads an enum encoded as a varint of its integer tag value.
    /// Values outside the tag type, or not naming a field, yield `error.InvalidEnumTag`.
    fn readEnum(self: *Self, comptime T: type, reader: anytype) !T {
        const TT = @TagType(T);
        const raw = try self.readVarUint(reader);
        // The value comes from the stream, so it is range-checked instead of
        // passed to @intCast (which panics or is UB when out of range).
        const tag = math.cast(TT, raw) catch return error.InvalidEnumTag;
        return meta.intToEnum(T, tag);
    }

    /// Reads a presence byte (0 means null), followed by the payload when present.
    fn readOptional(self: *Self, comptime T: type, reader: anytype) !T {
        if (0x0 == try reader.readByte())
            return null;
        const type_info = @typeInfo(T);
        return @as(T, try self.read(type_info.Optional.child, reader));
    }

    /// Reads `len` consecutive elements into a fixed-size array.
    fn readArray(self: *Self, comptime T: type, reader: anytype) !T {
        const ti = @typeInfo(T).Array;
        if (ti.len < 1)
            @compileError("array length must be at least 1");
        // Left undefined, as in readStruct: every element is overwritten below.
        // mem.zeroes cannot build every supported element type (e.g. tagged unions).
        var buf: [ti.len]ti.child = undefined;
        for (buf) |*elem|
            elem.* = try self.read(ti.child, reader);
        return buf;
    }

    /// Reads a tagged union: a varint tag value, then the active field's
    /// payload (nothing for `void` payloads).
    /// An unknown tag yields `error.InvalidEnumTag`, matching `readEnum`.
    fn readUnion(self: *Self, comptime T: type, reader: anytype) !T {
        const ti = @typeInfo(T).Union;
        if (ti.tag_type == null)
            @compileError("only tagged unions are supported");
        const tag = try self.readVarUint(reader);
        inline for (ti.fields) |f| {
            if (tag == @enumToInt(@field(ti.tag_type.?, f.name))) {
                const v = try self.readAllowVoid(f.field_type, reader);
                return @unionInit(T, f.name, v);
            }
        }
        // A tag that names no field is malformed input, not a programming
        // error, so it is reported to the caller instead of panicking.
        return error.InvalidEnumTag;
    }

    /// Reads a slice as a varint length followed by that many elements.
    /// The slice is allocated from the reader's arena.
    fn readPointer(self: *Self, comptime T: type, reader: anytype) !T {
        const ti = @typeInfo(T).Pointer;
        if (ti.size != .Slice)
            @compileError("slices are the only supported pointer type");
        // The length is untrusted and u64-sized; it must fit in usize to allocate.
        const raw_len = try self.readVarUint(reader);
        const len = math.cast(usize, raw_len) catch return ReadError.Overflow;
        const buf = try self.arena.allocator.alloc(ti.child, len);
        for (buf) |*elem|
            elem.* = try self.read(ti.child, reader);
        return buf;
    }

    /// Reads a hash map as a varint entry count followed by that many
    /// key/value pairs. The map's storage comes from the reader's arena.
    fn readHashMap(self: *Self, comptime T: type, reader: anytype) !T {
        const K = HashMapKeyType(T);
        const V = HashMapValueType(T);
        if (comptime !isValidHashMapKeyType(K))
            @compileError("unsupported hashmap key type " ++ @typeName(K));
        // Untrusted count: reject values that do not fit the u32 capacity
        // argument instead of @intCast-ing them.
        const raw_count = try self.readVarUint(reader);
        var remaining = math.cast(u32, raw_count) catch return ReadError.Overflow;
        var map = T.init(&self.arena.allocator);
        if (remaining != 0) {
            try map.ensureCapacity(remaining);
            while (remaining != 0) : (remaining -= 1) {
                const key = try self.read(K, reader);
                const val = try self.read(V, reader);
                _ = map.putAssumeCapacity(key, val);
            }
        }
        return map;
    }
};
/// Encodes values into a byte stream in the format `Reader` decodes.
///
/// The writer holds no state; `init` and `deinit` exist only to mirror `Reader`.
pub const Writer = struct {
    const Self = @This();

    /// Returns a (stateless) writer.
    pub fn init() Self {
        return Self{};
    }

    /// No-op: a writer owns nothing.
    pub fn deinit(self: *Self) void {}

    /// Serializes `value` to `writer`, choosing the encoding from its type.
    /// Types without an encoding are rejected at compile time.
    pub fn write(self: *Self, value: anytype, writer: anytype) !void {
        const T = @TypeOf(value);
        switch (@typeInfo(T)) {
            .Int => return self.writeInt(T, value, writer),
            .Float => return self.writeFloat(T, value, writer),
            .Bool => return self.writeBool(value, writer),
            .Struct => if (comptime isHashMap(T))
                return self.writeHashMap(value, writer)
            else
                return self.writeStruct(value, writer),
            .Enum => return self.writeEnum(value, writer),
            .Optional => return self.writeOptional(value, writer),
            .Array => return self.writeArray(value, writer),
            .Union => return self.writeUnion(value, writer),
            .Pointer => return self.writePointer(value, writer),
            else => @compileError("unsupported type " ++ @typeName(T)),
        }
    }

    /// Like `write`, but a `void` value produces no bytes.
    fn writeAllowVoid(self: *Self, value: anytype, writer: anytype) !void {
        if (@TypeOf(value) != void)
            try self.write(value, writer);
    }

    /// Emits `value` as a base-128 varint, 7 bits per byte, lowest group first.
    /// The high bit of each byte marks a continuation.
    fn writeVarUint(self: *Self, value: u64, writer: anytype) !void {
        var rest = value;
        while (rest >= 0x80) : (rest >>= 7)
            try writer.writeByte(@truncate(u8, rest) | 0x80);
        try writer.writeByte(@truncate(u8, rest));
    }

    /// Zigzag-encodes `value` (sign in the low bit) and emits it as a varint.
    fn writeVarInt(self: *Self, value: i64, writer: anytype) !void {
        const doubled = @bitCast(u64, value) << 1;
        const zigzag = if (value < 0) ~doubled else doubled;
        return self.writeVarUint(zigzag, writer);
    }

    /// Emits a fixed-width little-endian integer of 8, 16, 32 or 64 bits.
    fn writeInt(self: *Self, comptime T: type, value: T, writer: anytype) !void {
        switch (@typeInfo(T).Int.bits) {
            8, 16, 32, 64 => try writer.writeIntLittle(T, value),
            else => @compileError("unsupported integer type " ++ @typeName(T)),
        }
    }

    /// Emits an f32/f64 as its little-endian IEEE-754 bit pattern.
    fn writeFloat(self: *Self, comptime T: type, value: T, writer: anytype) !void {
        switch (@typeInfo(T).Float.bits) {
            32 => try writer.writeIntLittle(u32, @bitCast(u32, value)),
            64 => try writer.writeIntLittle(u64, @bitCast(u64, value)),
            else => @compileError("unsupported float type " ++ @typeName(T)),
        }
    }

    /// Emits a single byte: 1 for true, 0 for false.
    fn writeBool(self: *Self, value: bool, writer: anytype) !void {
        const byte: u8 = if (value) 1 else 0;
        try writer.writeByte(byte);
    }

    /// Emits every field in declaration order.
    fn writeStruct(self: *Self, value: anytype, writer: anytype) !void {
        const fields = @typeInfo(@TypeOf(value)).Struct.fields;
        if (fields.len == 0)
            @compileError("structs must have 1 or more fields");
        inline for (fields) |field|
            try self.write(@field(value, field.name), writer);
    }

    /// Emits the enum's integer tag value as a varint.
    fn writeEnum(self: *Self, value: anytype, writer: anytype) !void {
        const tag_value = @enumToInt(value);
        try self.writeVarUint(tag_value, writer);
    }

    /// Emits a presence byte (1 or 0), then the payload when present.
    fn writeOptional(self: *Self, value: anytype, writer: anytype) !void {
        try writer.writeByte(@boolToInt(value != null));
        if (value) |payload|
            try self.write(payload, writer);
    }

    /// Emits each array element in order, with no length prefix.
    fn writeArray(self: *Self, value: anytype, writer: anytype) !void {
        if (@typeInfo(@TypeOf(value)).Array.len == 0)
            @compileError("array length must be at least 1");
        for (value) |elem|
            try self.write(elem, writer);
    }

    /// Emits a slice as a varint length followed by each element.
    fn writePointer(self: *Self, value: anytype, writer: anytype) !void {
        if (@typeInfo(@TypeOf(value)).Pointer.size != .Slice)
            @compileError("slices are the only supported pointer type");
        try self.writeVarUint(value.len, writer);
        for (value) |elem|
            try self.write(elem, writer);
    }

    /// Emits a hash map as a varint entry count followed by key/value pairs,
    /// in the map's own iteration order.
    fn writeHashMap(self: *Self, value: anytype, writer: anytype) !void {
        const Map = @TypeOf(value);
        const Key = HashMapKeyType(Map);
        if (comptime !isValidHashMapKeyType(Key))
            @compileError("unsupported hashmap key type " ++ @typeName(Key));
        try self.writeVarUint(value.count(), writer);
        if (@hasDecl(Map, "items")) {
            for (value.items()) |entry| {
                try self.write(entry.key, writer);
                try self.write(entry.value, writer);
            }
        } else {
            var it = value.iterator();
            while (it.next()) |entry| {
                try self.write(entry.key, writer);
                try self.write(entry.value, writer);
            }
        }
    }

    /// Emits a tagged union as a varint tag value followed by the active
    /// field's payload (nothing for `void` payloads).
    fn writeUnion(self: *Self, value: anytype, writer: anytype) !void {
        const info = @typeInfo(@TypeOf(value)).Union;
        if (info.tag_type == null)
            @compileError("only tagged unions are supported");
        const Tag = info.tag_type.?;
        try self.writeVarUint(@enumToInt(value), writer);
        inline for (info.fields) |field| {
            if (value == @field(Tag, field.name))
                try self.writeAllowVoid(@field(value, field.name), writer);
        }
    }
};
// Horrible hack: there is no nominal way to ask "is this a std hash map?",
// so duck-type on the declarations a std hash map provides.
fn isHashMap(comptime T: type) bool {
    // `HashMapKeyType` and `HashMapValueType` add further constraints.
    const can_iterate = @hasDecl(T, "items") or @hasDecl(T, "iterator");
    const can_fill = @hasDecl(T, "ensureCapacity") and @hasDecl(T, "putAssumeCapacity");
    return can_fill and can_iterate;
}
/// Key type of hash map `T`, read from its entry struct's `key` field.
fn HashMapKeyType(comptime T: type) type {
    const field = "key";
    return HashMapType(T, field);
}
/// Value type of hash map `T`, read from its entry struct's `value` field.
fn HashMapValueType(comptime T: type) type {
    const field = "value";
    return HashMapType(T, field);
}
/// Returns the type of field `field_name` in hash map `T`'s entry struct.
/// The entry struct is the `KV` or `Entry` declaration, whichever `T` has;
/// the name differs between std versions. A missing field is a compile
/// error that names the field, instead of falling off the end of the function.
fn HashMapType(comptime T: type, comptime field_name: []const u8) type {
    const Entry = blk: {
        if (@hasDecl(T, "KV")) {
            break :blk T.KV;
        } else if (@hasDecl(T, "Entry")) {
            break :blk T.Entry;
        } else
            @compileError("unsupported Zig version");
    };
    inline for (@typeInfo(Entry).Struct.fields) |f| {
        if (comptime mem.eql(u8, f.name, field_name))
            return f.field_type;
    }
    @compileError("hash map entry type " ++ @typeName(Entry) ++ " has no field named '" ++ field_name ++ "'");
}
/// Reports whether `T` may be used as a hash map key by this codec.
/// Integers, floats, bools and enums are accepted. Slices are accepted too,
/// since they are the closest thing Zig has to strings.
fn isValidHashMapKeyType(comptime T: type) bool {
    switch (@typeInfo(T)) {
        .Pointer => |ptr| return ptr.size == .Slice,
        .Int, .Float, .Bool, .Enum => return true,
        else => return false,
    }
}
test "read variable uint" {
    var r = Reader.init(std.testing.allocator);
    defer r.deinit();
    var one_byte = io.fixedBufferStream("\x2a");
    std.testing.expectEqual(@as(u64, 42), try r.readVarUint(one_byte.reader()));
    var two_bytes = io.fixedBufferStream("\x80\x02");
    std.testing.expectEqual(@as(u64, 0x100), try r.readVarUint(two_bytes.reader()));
}
test "read variable uint overflow 1" {
    // Ten bytes whose last byte carries more than the single remaining bit (bit 63).
    var r = Reader.init(std.testing.allocator);
    defer r.deinit();
    var fbs = io.fixedBufferStream("\xff\xff\xff\xff\xff\xff\xff\xff\xff\x02");
    std.testing.expectError(ReadError.Overflow, r.readVarUint(fbs.reader()));
}
test "read variable uint overflow 2" {
    // The continuation bit is still set on the tenth byte.
    var r = Reader.init(std.testing.allocator);
    defer r.deinit();
    var fbs = io.fixedBufferStream("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00");
    std.testing.expectError(ReadError.Overflow, r.readVarUint(fbs.reader()));
}
test "write variable uint" {
    var buf: [4]u8 = undefined;
    var fbs = io.fixedBufferStream(&buf);
    var w = Writer.init();
    defer w.deinit();
    try w.writeVarUint(42, fbs.writer());
    try w.writeVarUint(0x100, fbs.writer());
    // 42 fits in one byte; 0x100 needs a continuation byte (0x80) then 0x02.
    std.testing.expectEqualSlices(u8, &[_]u8{ 42, 0x80, 0x02 }, fbs.getWritten());
}
test "write variable int" {
    var buf: [4]u8 = undefined;
    var fbs = io.fixedBufferStream(&buf);
    var w = Writer.init();
    defer w.deinit();
    try w.writeVarInt(42, fbs.writer());
    try w.writeVarInt(-1, fbs.writer());
    // Zigzag maps a non-negative n to 2n and -1 to 1.
    std.testing.expectEqualSlices(u8, &[_]u8{ 42 << 1, 1 }, fbs.getWritten());
}
test "round trip variable uint" {
    var buf: [4]u8 = undefined;
    var output = io.fixedBufferStream(&buf);
    var w = Writer.init();
    try w.writeVarUint(0x10000, output.writer());
    var r = Reader.init(std.testing.allocator);
    defer r.deinit();
    var input = io.fixedBufferStream(output.getWritten());
    std.testing.expectEqual(@as(u64, 0x10000), try r.readVarUint(input.reader()));
}
test "round trip variable int 1" {
    var buf: [4]u8 = undefined;
    var output = io.fixedBufferStream(&buf);
    var w = Writer.init();
    try w.writeVarInt(-0x10000, output.writer());
    var r = Reader.init(std.testing.allocator);
    defer r.deinit();
    var input = io.fixedBufferStream(output.getWritten());
    std.testing.expectEqual(@as(i64, -0x10000), try r.readVarInt(input.reader()));
}
test "round trip variable int 2" {
    var buf: [4]u8 = undefined;
    var output = io.fixedBufferStream(&buf);
    var w = Writer.init();
    try w.writeVarInt(0x10000, output.writer());
    var r = Reader.init(std.testing.allocator);
    defer r.deinit();
    var input = io.fixedBufferStream(output.getWritten());
    std.testing.expectEqual(@as(i64, 0x10000), try r.readVarInt(input.reader()));
}