~mna/tcheck

28bc9bb8f5650c0d996f5ed593a162de2c4429b5 — Martin Angers 3 years ago
initial commit
5 files changed, 170 insertions(+), 0 deletions(-)

A .gitignore
A .luacheckrc
A .luacov
A tcheck.lua
A test/main.lua
A  => .gitignore +11 -0
@@ 1,11 @@
# locally-installed lua modules
/lua_modules/

# output files generated by tools
*.out

# rocks packages
*.src.rock

# environment configuration
.env*

A  => .luacheckrc +3 -0
@@ 1,3 @@
std = 'lua53'
files['test/main.lua'].allow_defined_top = true
files['test/main.lua'].global = false

A  => .luacov +3 -0
@@ 1,3 @@
modules = {
  ["tcheck"] = "tcheck.lua"
}

A  => tcheck.lua +78 -0
@@ 1,78 @@
local typefn = {
  ['nil'] = type,
  ['number'] = type,
  ['string'] = type,
  ['boolean'] = type,
  ['table'] = type,
  ['function'] = type,
  ['thread'] = type,
  ['userdata'] = type,
  ['file'] = io.type,
  ['closed file'] = io.type,
  ['integer'] = math.type,
  ['float'] = math.type,
}

local function fieldtype(v)
  local mt = getmetatable(v)
  local t

  if mt then
    t = mt.__name
    if not t then
      t = mt.__type
    end
  end
  return t
end

local M = {}
setmetatable(M, {__call = function(m, ...) return m.check(...) end})

-- Checks that the values conform to the types provided as first argument.
-- The types argument can be a single string for a single value, or an array
-- of strings.
--
-- The type string can be any of the type built-in function values, any of the
-- math.type or io.type values, or a value to match with the __name or __type
-- field of the value's metadata.
--
-- To support multiple types for a value, the pipe '|' character is used to
-- separate different type names. As a special case, '*' can be specified
-- to mean any non-nil value, it cannot be combined with any other type.
--
-- Returns the array of types that did match for each value.
function M.check(types, ...)
  if type(types) == 'string' then
    types = {types}
  end

  -- do not validate more values than we have types for
  local matches = {}
  for i, ts in ipairs(types) do
    local v = select(i, ...)
    local ok = false

    if ts == '*' then
      ok = v ~= nil
      table.insert(matches, type(v))
    else
      for t in string.gmatch(ts, '([^|]+)') do
        local fn = typefn[t] or fieldtype
        local got = fn(v)
        if got == t then
          ok = true
          table.insert(matches, got)
          break
        end
      end
    end

    if not ok then
      error(string.format('bad argument #%d (%s expected, got %s)', i, ts, type(v)))
    end
  end
  return matches
end

return M

A  => test/main.lua +75 -0
@@ 1,75 @@
local lu = require 'luaunit'
local tcheck = require 'tcheck'

TestTcheck = {}
function TestTcheck.test_single_ok()
  local got = tcheck('string', 'a')
  lu.assertEquals(#got, 1)
  lu.assertEquals(got[1], 'string')
end

function TestTcheck.test_single_fail()
  lu.assertErrorMsgContains('bad argument #1', function()
    tcheck('string', 1)
  end)
end

function TestTcheck.test_many_ok()
  local got = tcheck({'string', 'number', 'boolean'}, 'a', 1.23, true)
  lu.assertEquals(#got, 3)
  lu.assertEquals(got, {'string', 'number', 'boolean'})
end

function TestTcheck.test_many_fail()
  lu.assertErrorMsgContains('bad argument #2', function()
    tcheck({'string', 'number', 'boolean'}, 'a', '4', true)
  end)
end

function TestTcheck.test_single_choice_ok()
  local got = tcheck('string|nil|number', 4)
  lu.assertEquals(#got, 1)
  lu.assertEquals(got[1], 'number')
end

function TestTcheck.test_single_choice_fail()
  lu.assertErrorMsgContains('bad argument #1', function()
    tcheck('string|nil|number', {})
  end)
end

function TestTcheck.test_many_extra()
  local got = tcheck('string|nil|number', 4, 'a', true)
  lu.assertEquals(#got, 1)
  lu.assertEquals(got[1], 'number')
end

function TestTcheck.test_any_ok()
  local got = tcheck('*', 4)
  lu.assertEquals(#got, 1)
  lu.assertEquals(got[1], 'number')
end

function TestTcheck.test_any_fail()
  lu.assertErrorMsgContains('bad argument #1', function()
    tcheck('*', nil)
  end)
end

local Class = {__name = 'Class'}
function TestTcheck.test_many_mt_name()
  local o = {}
  setmetatable(o, Class)
  local got = tcheck({'string|file|integer', 'boolean|Class'}, 3, o)
  lu.assertEquals(got, {'integer', 'Class'})
end

local Type = {__type = 'Type'}
function TestTcheck.test_many_mt_type()
  local o = {}
  setmetatable(o, Type)
  local got = tcheck({'string|file|integer', 'boolean|Class|Type'}, io.input(), o)
  lu.assertEquals(got, {'file', 'Type'})
end

os.exit(lu.LuaUnit.run())