~rootmos/lua-hack

7bae9e11ac1e6e6b097f18989100df0051b02469 — Gustav Behm 1 year, 10 months ago 46091e3 sha1
Add the SHA1 implementation
3 files changed, 176 insertions(+), 22 deletions(-)

M sha/.k
M sha/init.lua
M sha/test.lua
M sha/.k => sha/.k +0 -8
@@ 1,11 1,3 @@
export LUA_PATH="../?.lua;$LUA_PATH"

export LUA="/home/gustav/build/lua/lua-5.4.4/src/lua"

test() {
    make && ../test
}

go() {
    make && lua go.lua
}

M sha/init.lua => sha/init.lua +140 -14
@@ 2,11 2,14 @@ local M = {}

local u32 = require("uint32")
local rotr = u32.rotr
local rotl = u32.rotl
local shr = u32.shr
local add = u32.add
local sum = u32.sum
local xor = u32.xor
local band = u32.band
local bnot = u32.bnot
local bor = u32.bor

function M.ch(x, y, z)
   return xor(band(x, y), band(bnot(x), z))


@@ 16,6 19,10 @@ function M.maj(x, y, z)
   return xor(xor(band(x, y), band(x, z)), band(y, z))
end

function M.parity(x, y, z)
   return xor(xor(x, y), z)
end

M.sha256 = setmetatable({
   H0 = {
      u32(0x6a09,0xe667), u32(0xbb67,0xae85), u32(0x3c6e,0xf372), u32(0xa54f,0xf53a),


@@ 82,29 89,26 @@ local function sha256_process_blocks(ctx)
         W[t]
      )

      local T2 = sum(
         M.sha256.bsig0(a),
         M.maj(a, b, c)
      )
      local T2 = add(M.sha256.bsig0(a), M.maj(a, b, c))

      h = g
      g = f
      f = e
      e = sum(d, T1)
      e = add(d, T1)
      d = c
      c = b
      b = a
      a = sum(T1, T2)
      a = add(T1, T2)
   end

   ctx.H[1] = sum(a, ctx.H[1])
   ctx.H[2] = sum(b, ctx.H[2])
   ctx.H[3] = sum(c, ctx.H[3])
   ctx.H[4] = sum(d, ctx.H[4])
   ctx.H[5] = sum(e, ctx.H[5])
   ctx.H[6] = sum(f, ctx.H[6])
   ctx.H[7] = sum(g, ctx.H[7])
   ctx.H[8] = sum(h, ctx.H[8])
   ctx.H[1] = add(a, ctx.H[1])
   ctx.H[2] = add(b, ctx.H[2])
   ctx.H[3] = add(c, ctx.H[3])
   ctx.H[4] = add(d, ctx.H[4])
   ctx.H[5] = add(e, ctx.H[5])
   ctx.H[6] = add(f, ctx.H[6])
   ctx.H[7] = add(g, ctx.H[7])
   ctx.H[8] = add(h, ctx.H[8])

   ctx.buf = ctx.buf:sub(65)



@@ 164,4 168,126 @@ function M.sha256.init(buf)
   return ctx
end

M.sha1 = setmetatable({
   H0 = {
      u32(0x6745, 0x2301),
      u32(0xefcd, 0xab89),
      u32(0x98ba, 0xdcfe),
      u32(0x1032, 0x5476),
      u32(0xc3d2, 0xe1f0),
   },
}, {
   __call = function(L, ...)
      local ctx = L.init()
      for _, bs in ipairs({...}) do
         ctx:update(bs)
      end
      return ctx:finalize()
   end,
})

M.sha1.f = {}
M.sha1.K = {}
for t = 1,80 do
   if 1 <= t and t <= 20 then
      M.sha1.f[t] = M.ch
      M.sha1.K[t] = u32(0x5a82, 0x7999)
   elseif 21 <= t and t <= 40 then
      M.sha1.f[t] = M.parity
      M.sha1.K[t] = u32(0x6ed9, 0xeba1)
   elseif 41 <= t and t <= 60 then
      M.sha1.f[t] = M.maj
      M.sha1.K[t] = u32(0x8f1b, 0xbcdc)
   elseif 61 <= t and t <= 80 then
      M.sha1.f[t] = M.parity
      M.sha1.K[t] = u32(0xca62, 0xc1d6)
   end
end

local function sha1_process_blocks(ctx)
   if #ctx.buf < 64 then
      return
   end

   local W = {}
   for t = 1,16 do
      W[t] = u32.from_bytes(ctx.buf:sub(t*4-3,t*4))
   end
   assert(#W == 16)

   for t = 17,80 do
      W[t] = rotl(xor(xor(xor(W[t-3], W[t-8]), W[t-14]), W[t-16]), 1)
   end
   assert(#W == 80)

   local A, B, C, D, E = table.unpack(ctx.H)

   for t = 1,80 do
      local TEMP = sum(
         rotl(A, 5),
         M.sha1.f[t](B, C, D),
         E,
         W[t],
         M.sha1.K[t]
      )

      E = D
      D = C
      C = rotl(B, 30)
      B = A
      A = TEMP
   end

   ctx.H[1] = add(ctx.H[1], A)
   ctx.H[2] = add(ctx.H[2], B)
   ctx.H[3] = add(ctx.H[3], C)
   ctx.H[4] = add(ctx.H[4], D)
   ctx.H[5] = add(ctx.H[5], E)

   ctx.buf = ctx.buf:sub(65)

   return sha1_process_blocks(ctx)
end

function M.sha1.update(ctx, bs)
   ctx.buf = ctx.buf .. bs
   ctx.L = ctx.L + #bs
   sha1_process_blocks(ctx)
   return ctx
end

M.sha1.pad = M.sha256.pad

function M.sha1.finalize(ctx)
   sha1_process_blocks(ctx)
   ctx.buf = M.sha1.pad(ctx.buf, ctx.L)
   sha1_process_blocks(ctx)
   assert(#ctx.buf == 0)

   local digest = ""
   for i = 1,5 do
      digest = digest .. u32.to_bytes(ctx.H[i])
   end

   return digest
end

function M.sha1.init(buf)
   H = {}
   for k, v in pairs(M.sha1.H0) do
      H[k] = v
   end

   local buf = buf or ""
   local ctx = {
      H = H,
      buf = buf,
      L = #buf,
      update = M.sha1.update,
      finalize = M.sha1.finalize,
   }
   sha1_process_blocks(ctx)
   return ctx
end

return M

M sha/test.lua => sha/test.lua +36 -0
@@ 70,4 70,40 @@ if require("bits") == 64 then
    end
end

function test_sha1_empty()
    local digest = L.sha1.init():finalize()
    lu.assertEquals(H.encode(digest), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
end

function test_sha1_foo()
    local digest = L.sha1.init("foo"):finalize()
    lu.assertEquals(H.encode(digest), "0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33")
end

function test_sha1_foobar_1()
    local digest = L.sha1.init("foo"):update("bar"):finalize()
    lu.assertEquals(H.encode(digest), "8843d7f92416211de9ebb963ff4ce28125932878")
end

function test_sha1_foobar_2()
    local digest = L.sha1("foo", "bar")
    lu.assertEquals(H.encode(digest), "8843d7f92416211de9ebb963ff4ce28125932878")
end

function test_sha1_fresh()
    for _ = 1,N do
        local bs0 = F.bytestring()
        local bs1 = F.bytestring()
        local bs2 = F.bytestring()

        local ctx = L.sha1.init(bs0)
        ctx:update(bs1)
        ctx:update(bs2)
        local digest = ctx:finalize()

        lu.assertEquals(H.encode(digest), H.encode(ref.openssl.sha1(bs0, bs1, bs2)))
        lu.assertEquals(H.encode(digest), H.encode(ref.rfc.sha1(bs0, bs1, bs2)))
    end
end

os.exit(lu.LuaUnit.run())