~rootmos/lua-hack

9bf62137d6ad83d58079dd53d08cfc28c09c00f3 — Gustav Behm 1 year, 4 months ago d7ec725
Test pad, ch, maj and the sigma functions
3 files changed, 155 insertions(+), 57 deletions(-)

M sha/ref.c
M sha/sha.lua
M sha/sha.test.lua
M sha/ref.c => sha/ref.c +112 -47
@@ 8,6 8,33 @@
#include <assert.h>
#include <string.h>

static void expected_number_of_arguments(lua_State* L, int n)
{
    char buf[128];
    int m = lua_gettop(L);
    if(m == n) return;

    snprintf(buf, sizeof(buf),
             "incorret number of arguments: expected %d but got %d",
             n, m);
    lua_pushstring(L, buf);
    lua_error(L);
}

static uint32_t uint32_arg(lua_State* L, int idx)
{
    lua_Integer i = luaL_checkinteger(L, idx);

    if(i < 0 || i > UINT32_MAX) {
        char buf[128];
        snprintf(buf, sizeof(buf), "argument %d is not a 32-bit unsigned integer", idx);
        lua_pushstring(L, buf);
        lua_error(L);
    }

    return i;
}

static int openssl_sha256(lua_State* L)
{
    int n = lua_gettop(L);


@@ 75,37 102,72 @@ static int rfc_sha256(lua_State* L)
    return 1;
}

static void expected_number_of_arguments(lua_State* L, int n)
// from sha-private.h
#define SHA_Ch(x,y,z)        (((x) & (y)) ^ ((~(x)) & (z)))
#define SHA_Maj(x,y,z)       (((x) & (y)) ^ ((x) & (z)) ^ ((y) & (z)))

static int ch(lua_State* L)
{
    char buf[128];
    int m = lua_gettop(L);
    if(m == n) return;
    uint32_t x = uint32_arg(L, 1);
    uint32_t y = uint32_arg(L, 2);
    uint32_t z = uint32_arg(L, 3);

    snprintf(buf, sizeof(buf),
             "incorret number of arguments: expected %d but got %d",
             n, m);
    lua_pushstring(L, buf);
    lua_error(L);
    lua_pushinteger(L, SHA_Ch(x, y, z));
    return 1;
}

static uint32_t uint32_arg(lua_State* L, int idx)
static int maj(lua_State* L)
{
    if(!lua_isinteger(L, idx)) {
        char buf[128];
        snprintf(buf, sizeof(buf), "argument %d is not an integer", idx);
        lua_pushstring(L, buf);
        lua_error(L);
    }
    uint32_t x = uint32_arg(L, 1);
    uint32_t y = uint32_arg(L, 2);
    uint32_t z = uint32_arg(L, 3);

    lua_Integer i = lua_tointeger(L, idx);
    if(i < 0 || i > UINT32_MAX) {
        char buf[128];
        snprintf(buf, sizeof(buf), "argument %d is not a 32-bit unsigned integer", idx);
        lua_pushstring(L, buf);
        lua_error(L);
    }
    lua_pushinteger(L, SHA_Maj(x, y, z));
    return 1;
}

    return i;
// from sha224-256.c
#define SHA256_SHR(bits,word)      ((word) >> (bits))
#define SHA256_ROTL(bits,word)                         \
  (((word) << (bits)) | ((word) >> (32-(bits))))
#define SHA256_ROTR(bits,word)                         \
  (((word) >> (bits)) | ((word) << (32-(bits))))

#define SHA256_SIGMA0(word)   \
  (SHA256_ROTR( 2,word) ^ SHA256_ROTR(13,word) ^ SHA256_ROTR(22,word))
#define SHA256_SIGMA1(word)   \
  (SHA256_ROTR( 6,word) ^ SHA256_ROTR(11,word) ^ SHA256_ROTR(25,word))
#define SHA256_sigma0(word)   \
  (SHA256_ROTR( 7,word) ^ SHA256_ROTR(18,word) ^ SHA256_SHR( 3,word))
#define SHA256_sigma1(word)   \
  (SHA256_ROTR(17,word) ^ SHA256_ROTR(19,word) ^ SHA256_SHR(10,word))

static int sha256_bsig0(lua_State* L)
{
    uint32_t x = uint32_arg(L, 1);
    lua_pushinteger(L, SHA256_SIGMA0(x));
    return 1;
}

static int sha256_bsig1(lua_State* L)
{
    uint32_t x = uint32_arg(L, 1);
    lua_pushinteger(L, SHA256_SIGMA1(x));
    return 1;
}

static int sha256_ssig0(lua_State* L)
{
    uint32_t x = uint32_arg(L, 1);
    lua_pushinteger(L, SHA256_sigma0(x));
    return 1;
}

static int sha256_ssig1(lua_State* L)
{
    uint32_t x = uint32_arg(L, 1);
    lua_pushinteger(L, SHA256_sigma1(x));
    return 1;
}

static int uint32_shr(lua_State* L)


@@ 113,14 175,9 @@ static int uint32_shr(lua_State* L)
    expected_number_of_arguments(L, 2);

    uint32_t x = uint32_arg(L, 1);
    lua_Integer n = luaL_checkinteger(L, 2);

    if(!lua_isinteger(L, 2)) {
        lua_pushliteral(L, "invalid type");
        lua_error(L);
    }
    lua_Integer n = lua_tointeger(L, 2);

    uint32_t r = x >> n;
    uint32_t r = SHA256_SHR(n, x);
    lua_pushinteger(L, r);

    return 1;


@@ 149,14 206,9 @@ static int uint32_rotr(lua_State* L)
    expected_number_of_arguments(L, 2);

    uint32_t x = uint32_arg(L, 1);
    lua_Integer n = luaL_checkinteger(L, 2);

    if(!lua_isinteger(L, 2)) {
        lua_pushliteral(L, "invalid type");
        lua_error(L);
    }
    lua_Integer n = lua_tointeger(L, 2);

    uint32_t r = (x >> n) | (x << (32-n));
    uint32_t r = SHA256_ROTR(n, x);
    lua_pushinteger(L, r);

    return 1;


@@ 167,14 219,9 @@ static int uint32_rotl(lua_State* L)
    expected_number_of_arguments(L, 2);

    uint32_t x = uint32_arg(L, 1);
    lua_Integer n = luaL_checkinteger(L, 2);

    if(!lua_isinteger(L, 2)) {
        lua_pushliteral(L, "invalid type");
        lua_error(L);
    }
    lua_Integer n = lua_tointeger(L, 2);

    uint32_t r = (x << n) | (x >> (32-n));
    uint32_t r = SHA256_ROTL(n, x);
    lua_pushinteger(L, r);

    return 1;


@@ 233,8 280,10 @@ static int uint32_to_bytes(lua_State* L)
static int uint32_add(lua_State* L)
{
    expected_number_of_arguments(L, 2);
    uint32_t x = (uint32_t)luaL_checkinteger(L, 1);
    uint32_t y = (uint32_t)luaL_checkinteger(L, 2);

    uint32_t x = uint32_arg(L, 1);
    uint32_t y = uint32_arg(L, 2);

    uint32_t sum = x + y;
    lua_pushinteger(L, sum);
    return 1;


@@ 259,6 308,22 @@ int luaopen_ref_rfc(lua_State* L)
    lua_newtable(L);
    lua_pushcfunction(L, rfc_sha256);
    lua_setfield(L, -2, "sha256");

    lua_pushcfunction(L, ch);
    lua_setfield(L, -2, "ch");

    lua_pushcfunction(L, maj);
    lua_setfield(L, -2, "maj");

    lua_pushcfunction(L, sha256_bsig0);
    lua_setfield(L, -2, "sha256_bsig0");
    lua_pushcfunction(L, sha256_bsig1);
    lua_setfield(L, -2, "sha256_bsig1");
    lua_pushcfunction(L, sha256_ssig0);
    lua_setfield(L, -2, "sha256_ssig0");
    lua_pushcfunction(L, sha256_ssig1);
    lua_setfield(L, -2, "sha256_ssig1");

    return 1;
}


M sha/sha.lua => sha/sha.lua +15 -10
@@ 1,4 1,7 @@
local M = {}
local M = {
   ch = function(x, y, z) return (x & y) ~ ( ~x & z) end,
   maj = function(x, y, z) return (x & y) ~ (x & z) ~ (y & z) end,
}

local u32 = require("uint32")
local rotr = u32.rotr


@@ 25,8 28,6 @@ M.sha256 = {
      0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
      0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
   },
   ch = function(x, y, z) return (x & y) ~ ( (~x) & z) end,
   maj = function(x, y, z) return (x & y) ~ (x & z) ~ (y & z) end,
   bsig0 = function(x) return rotr(x, 2) ~ rotr(x, 13) ~ rotr(x, 22) end,
   bsig1 = function(x) return rotr(x, 6) ~ rotr(x, 11) ~ rotr(x, 25) end,
   ssig0 = function(x) return rotr(x, 7) ~ rotr(x, 18) ~ shr(x, 3) end,


@@ 70,6 71,7 @@ local function sha256_process_blocks(ctx)
      u32.from_bytes(ctx.buf:sub(57,60)),
      u32.from_bytes(ctx.buf:sub(61,64)),
   }
   assert(#W == 16)

   for i = 17,64 do
      W[i] = sum(


@@ 79,6 81,7 @@ local function sha256_process_blocks(ctx)
         W[i-16]
      )
   end
   assert(#W == 64)

   local a, b, c, d, e, f, g, h = table.unpack(ctx.H)



@@ 86,14 89,14 @@ local function sha256_process_blocks(ctx)
      local T1 = sum(
         h,
         M.sha256.bsig1(e),
         M.sha256.ch(e, f, g),
         M.ch(e, f, g),
         M.sha256.K[t],
         W[t]
      )

      local T2 = sum(
         M.sha256.bsig0(a),
         M.sha256.maj(a, b, c)
         M.maj(a, b, c)
      )

      h = g


@@ 125,21 128,23 @@ function M.sha256.update(ctx, bs)
   sha256_process_blocks(ctx)
end

local function sha256_pad(ctx)
function M.sha256.pad(buf, l)
   local k = 0
   while (ctx.L*8 + 1 + k) ~= 448 do
   while (l*8 + 1 + k) ~= 448 do
      k = k + 1
   end
   assert(k % 8 == 7)
   k = k - 7
   local n = k // 8
   ctx.buf = ctx.buf .. string.char(0x80) .. string.rep("\0", n)
   ctx.buf = ctx.buf .. "\0\0\0\0" .. u32.to_bytes(ctx.L)
   buf = buf .. string.char(0x80) .. string.rep("\0", n)
   buf = buf .. "\0\0\0\0" .. u32.to_bytes(l*8)
   assert((#buf % 64) == 0)
   return buf
end

function M.sha256.finalize(ctx)
   sha256_process_blocks(ctx)
   sha256_pad(ctx)
   ctx.buf = M.sha256.pad(ctx.buf, ctx.L)
   sha256_process_blocks(ctx)
   assert(#ctx.buf == 0)


M sha/sha.test.lua => sha/sha.test.lua +28 -0
@@ 3,6 3,8 @@ local ref = require("ref")
local L = require("sha")
local H = require("hex")

local N = 100

function test_sha256_empty()
    local ctx = L.sha256.init()
    local digest = L.sha256.finalize(ctx)


@@ 15,4 17,30 @@ function test_sha256_foo()
    lu.assertEquals(H.encode(digest), "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae")
end

function test_sha256_pad()
    msg = H.decode("6162636465")
    padded = L.sha256.pad(msg, #msg)
    lu.assertEquals(padded, H.decode("61626364658000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000028"))
end

function test_ch_and_maj()
    for _ = 1,N do
        local x = ref.uint32.fresh()
        local y = ref.uint32.fresh()
        local z = ref.uint32.fresh()
        lu.assertEquals(L.ch(x, y, z), ref.rfc.ch(x, y, z))
        lu.assertEquals(L.maj(x, y, z), ref.rfc.maj(x, y, z))
    end
end

function test_sigma()
    for _ = 1,N do
        local x = ref.uint32.fresh()
        lu.assertEquals(L.sha256.bsig0(x), ref.rfc.sha256_bsig0(x))
        lu.assertEquals(L.sha256.bsig1(x), ref.rfc.sha256_bsig1(x))
        lu.assertEquals(L.sha256.ssig0(x), ref.rfc.sha256_ssig0(x))
        lu.assertEquals(L.sha256.ssig1(x), ref.rfc.sha256_ssig1(x))
    end
end

os.exit(lu.LuaUnit.run())