~rootmos/lua-hack

5a0b4e0120220cfeab22898785be0c92fe6eeec5 — Gustav Behm 1 year, 1 month ago 26c8215
Add a rope-buffer structure
4 files changed, 126 insertions(+), 36 deletions(-)

A buffer.lua
A buffer.test.lua
M sha/init.lua
M sha/test.lua
A buffer.lua => buffer.lua +68 -0
@@ 0,0 1,68 @@
local M = {}

local mt = {
    __len = function(st) return st.remaining end,
}

function M.init(bs)
    local st = setmetatable({
        i = 1,
        j = 0,
        k = 0,
        total = 0,
        consumed = 0,
        remaining = 0,
        append = M.append,
        consume = M.consume,
    }, mt)

    if bs then
        st:append(bs)
    end

    return st
end

function M.append(st, bs)
    st.j = st.j + 1
    st[st.j] = bs
    st.total = st.total + #bs
    st.remaining = st.remaining + #bs
    return st
end

function M.consume(st, n)
    if st.remaining < n then
        return nil
    end

    local t = {}
    while n > 0 do
        local b = st[st.i]
        local r = #b - st.k
        if n <= r then
            table.insert(t, b:sub(st.k+1, st.k+n))
            st.k = st.k + n
            st.remaining = st.remaining - n

            if r == n then
                st[st.i] = nil
                st.i = st.i + 1
                st.k = 0
            end

            break
        else
            table.insert(t, b:sub(st.k+1))
            st[st.i] = nil
            st.i = st.i + 1
            st.k = 0
            st.remaining = st.remaining - r
            n = n - r
        end
    end

    return table.concat(t)
end

return M

A buffer.test.lua => buffer.test.lua +40 -0
@@ 0,0 1,40 @@
local lu = require("luaunit")
local B = require("buffer")

function test_empty()
    local b = B.init()
    lu.assertEquals(b:consume(0), "")
end

function test_foobar_1()
    local b = B.init("foobar")
    lu.assertEquals(b:consume(2), "fo")
    lu.assertEquals(b:consume(2), "ob")
    lu.assertEquals(b:consume(2), "ar")
    lu.assertNil(b:consume(2))
end

function test_foobar_2()
    local b = B.init("foo"):append("bar")
    lu.assertEquals(b:consume(2), "fo")
    lu.assertEquals(b:consume(2), "ob")
    lu.assertEquals(b:consume(2), "ar")
    lu.assertNil(b:consume(2))
end

function test_foobar_3()
    local b = B.init("foo"):append("bar")
    lu.assertEquals(b:consume(3), "foo")
    lu.assertEquals(b:consume(3), "bar")
    lu.assertNil(b:consume(1))
end

function test_foobar_4()
    local b = B.init("foo")
    lu.assertEquals(b:consume(3), "foo")
    b:append("bar")
    lu.assertEquals(b:consume(3), "bar")
    lu.assertNil(b:consume(1))
end

os.exit(lu.LuaUnit.run())

M sha/init.lua => sha/init.lua +18 -30
@@ 1,5 1,7 @@
local M = {}

local B = require("buffer")

local u32 = require("uint32")
local rotr = u32.rotr
local rotl = u32.rotl


@@ 58,13 60,14 @@ M.sha256 = setmetatable({
})

local function sha256_process_blocks(ctx)
   if #ctx.buf < 64 then
   local buf = ctx.buf:consume(64)
   if buf == nil then
      return
   end

   local W = {}
   for t = 1,16 do
      W[t] = u32.from_bytes(ctx.buf:sub(t*4-3,t*4))
      W[t] = u32.from_bytes(buf:sub(t*4-3,t*4))
   end
   assert(#W == 16)



@@ 110,35 113,29 @@ local function sha256_process_blocks(ctx)
   ctx.H[7] = add(g, ctx.H[7])
   ctx.H[8] = add(h, ctx.H[8])

   ctx.buf = ctx.buf:sub(65)

   return sha256_process_blocks(ctx)
end

function M.sha256.update(ctx, bs)
   ctx.buf = ctx.buf .. bs
   ctx.L = ctx.L + #bs
   ctx.buf:append(bs)
   sha256_process_blocks(ctx)
   return ctx
end

function M.sha256.pad(buf, l)
local function pad(buf)
   local l = buf.total
   local k = 0
   while (l*8 + 1 + k) % 512 ~= 448 do
      k = k + 1
   end
   assert(k % 8 == 7)
   k = k - 7
   local n = k // 8
   buf = buf .. string.char(0x80) .. string.rep("\0", n)
   buf = buf .. "\0\0\0\0" .. u32.to_bytes(u32(l*8)) -- TODO: support large messages
   assert((#buf % 64) == 0)
   return buf
   buf:append(string.char(0x80) .. string.rep("\0", n))
   buf:append("\0\0\0\0" .. u32.to_bytes(u32(l*8))) -- TODO: support large messages
end

function M.sha256.finalize(ctx)
   sha256_process_blocks(ctx)
   ctx.buf = M.sha256.pad(ctx.buf, ctx.L)
   pad(ctx.buf)
   sha256_process_blocks(ctx)
   assert(#ctx.buf == 0)



@@ 156,11 153,9 @@ function M.sha256.init(buf)
      H[k] = v
   end

   local buf = buf or ""
   local ctx = {
      H = H,
      buf = buf,
      L = #buf,
      buf = B.init(buf),
      update = M.sha256.update,
      finalize = M.sha256.finalize,
   }


@@ 205,13 200,14 @@ for t = 1,80 do
end

local function sha1_process_blocks(ctx)
   if #ctx.buf < 64 then
   local buf = ctx.buf:consume(64)
   if buf == nil then
      return
   end

   local W = {}
   for t = 1,16 do
      W[t] = u32.from_bytes(ctx.buf:sub(t*4-3,t*4))
      W[t] = u32.from_bytes(buf:sub(t*4-3,t*4))
   end
   assert(#W == 16)



@@ 244,23 240,17 @@ local function sha1_process_blocks(ctx)
   ctx.H[4] = add(ctx.H[4], D)
   ctx.H[5] = add(ctx.H[5], E)

   ctx.buf = ctx.buf:sub(65)

   return sha1_process_blocks(ctx)
end

function M.sha1.update(ctx, bs)
   ctx.buf = ctx.buf .. bs
   ctx.L = ctx.L + #bs
   ctx.buf:append(bs)
   sha1_process_blocks(ctx)
   return ctx
end

M.sha1.pad = M.sha256.pad

function M.sha1.finalize(ctx)
   sha1_process_blocks(ctx)
   ctx.buf = M.sha1.pad(ctx.buf, ctx.L)
   pad(ctx.buf)
   sha1_process_blocks(ctx)
   assert(#ctx.buf == 0)



@@ 278,11 268,9 @@ function M.sha1.init(buf)
      H[k] = v
   end

   local buf = buf or ""
   local ctx = {
      H = H,
      buf = buf,
      L = #buf,
      buf = B.init(buf),
      update = M.sha1.update,
      finalize = M.sha1.finalize,
   }

M sha/test.lua => sha/test.lua +0 -6
@@ 26,12 26,6 @@ function test_sha256_foobar_2()
    lu.assertEquals(H.encode(digest), "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2")
end

function test_sha256_pad()
    msg = H.decode("6162636465")
    padded = L.sha256.pad(msg, #msg)
    lu.assertEquals(padded, H.decode("61626364658000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000028"))
end

function test_sha256_fresh()
    for _ = 1,N do
        local bs0 = F.bytestring()