~mna/luafn

211cf875927372f0b6fcf6e290c2b0fe640fce56 — Martin Angers 4 months ago 286d80d
Add zip/unzip, collectarray, collectkv, pack, select, tests\
M .luacheckrc => .luacheckrc +1 -1
@@ 1,2 1,2 @@
std = 'lua53'
std = 'lua54'
files['*_test.lua'].allow_defined_top = true

M fn.lua => fn.lua +148 -21
@@ 1,5 1,12 @@
local M = {}

-- Runs a single step of the iterator, returning all the values it produced as
-- an array with field "n" set. The control variable should be modified
-- accordingly by the caller.
local function step_iterator(it, inv, ctl)
	return table.pack(it(inv, ctl))
end

-- Return an iterator that generates all integers starting at from
-- and ending at to, inclusive.
function M.fromto(from, to)


@@ 60,16 67,16 @@ function M.pipe(...)
  end
end

-- Filter iterator it by keeping only items that satisfy predicate p.
-- Filter iterator "it" by keeping only items that satisfy predicate p.
-- Return a new iterator that applies the filter.
-- If it is nil, returns a partially-applied function with the predicate
-- If "it" is nil, returns a partially-applied function with the predicate
-- set.
function M.filter(p, it, inv, ctl)
  if it == nil then return M.partial(M.filter, p) end

  return function()
    while true do
      local res = table.pack(it(inv, ctl))
      local res = step_iterator(it, inv, ctl)
      ctl = res[1]
      if ctl == nil then return nil end



@@ 84,13 91,13 @@ end
-- returned values instead of the original ones. Note that returning
-- nil from f as first value end the iterator.
-- Return a new iterator that applies the map.
-- If it is nil, returns a partially-applied function with the map
-- If "it" is nil, returns a partially-applied function with the map
-- function set.
function M.map(f, it, inv, ctl)
  if it == nil then return M.partial(M.map, f) end

  return function()
    local res = table.pack(it(inv, ctl))
    local res = step_iterator(it, inv, ctl)
    ctl = res[1]
    if ctl == nil then return nil end



@@ 101,7 108,7 @@ end
-- Reduce iterator it by calling fn on each iteration with the
-- accumulator cumul and all values returned for this iteration.
-- Return the final value of the accumulator.
-- If it is nil, returns a partially-applied function with the
-- If "it" is nil, returns a partially-applied function with the
-- reduce function and, if provided, the accumulator value.
function M.reduce(f, cumul, it, inv, ctl)
  if it == nil then


@@ 110,7 117,7 @@ function M.reduce(f, cumul, it, inv, ctl)
  end

  while true do
    local res = table.pack(it(inv, ctl))
    local res = step_iterator(it, inv, ctl)
    ctl = res[1]
    if ctl == nil then return cumul end
    cumul = f(cumul, table.unpack(res, 1, res.n))


@@ 119,7 126,7 @@ end

-- Take the first n results of iterator it.
-- Return a new iterator that takes at most those first n results.
-- If it is nil, returns a partially-applied function with the n
-- If "it" is nil, returns a partially-applied function with the n
-- value set.
function M.taken(n, it, inv, ctl)
  if it == nil then return M.partial(M.taken, n) end


@@ 128,7 135,7 @@ function M.taken(n, it, inv, ctl)
    if n <= 0 then return nil end

    n = n - 1
    local res = table.pack(it(inv, ctl))
    local res = step_iterator(it, inv, ctl)
    ctl = res[1]
    if ctl == nil then return nil end
    return table.unpack(res, 1, res.n)


@@ 138,7 145,7 @@ end
-- Take the iterator's it results while the predicate p returns true.
-- The predicate is called with the values of each iteration.
-- Return a new iterator that applies the take while condition.
-- If it is nil, returns a partially-applied function with the predicate
-- If "it" is nil, returns a partially-applied function with the predicate
-- p set.
function M.takewhile(p, it, inv, ctl)
  if it == nil then return M.partial(M.takewhile, p) end


@@ 147,7 154,7 @@ function M.takewhile(p, it, inv, ctl)
  return function()
    if stop then return nil end

    local res = table.pack(it(inv, ctl))
    local res = step_iterator(it, inv, ctl)
    ctl = res[1]
    if ctl == nil then return nil end



@@ 160,7 167,7 @@ end

-- Skip the first n results of iterator it.
-- Return a new iterator that skips those first n results.
-- If it is nil, returns a partially-applied function with the n
-- If "it" is nil, returns a partially-applied function with the n
-- value set.
function M.skipn(n, it, inv, ctl)
  if it == nil then return M.partial(M.skipn, n) end


@@ 171,7 178,7 @@ function M.skipn(n, it, inv, ctl)
      n = n - 1
      if ctl == nil then return nil end
    end
    local res = table.pack(it(inv, ctl))
    local res = step_iterator(it, inv, ctl)
    ctl = res[1]
    if ctl == nil then return nil end
    return table.unpack(res, 1, res.n)


@@ 181,7 188,7 @@ end
-- Skip the iterator's it results while the predicate p returns true.
-- The predicate is called with the values of each iteration.
-- Return a new iterator that applies the skip while condition.
-- If it is nil, returns a partially-applied function with the predicate
-- If "it" is nil, returns a partially-applied function with the predicate
-- p set.
function M.skipwhile(p, it, inv, ctl)
  if it == nil then return M.partial(M.skipwhile, p) end


@@ 189,7 196,7 @@ function M.skipwhile(p, it, inv, ctl)
  local skipping = true
  return function()
    while skipping do
      local res = table.pack(it(inv, ctl))
      local res = step_iterator(it, inv, ctl)
      ctl = res[1]
      if ctl == nil then return nil end
      if not p(table.unpack(res, 1, res.n)) then


@@ 198,7 205,7 @@ function M.skipwhile(p, it, inv, ctl)
      end
    end

    local res = table.pack(it(inv, ctl))
    local res = step_iterator(it, inv, ctl)
    ctl = res[1]
    if ctl == nil then return nil end
    return table.unpack(res, 1, res.n)


@@ 210,14 217,14 @@ end
-- iteration that returned true and all its values.
-- It returns false as the only value if the iteration is completed
-- without p returning true.
-- If it is nil, returns a partially-applied function with the predicate
-- If "it" is nil, returns a partially-applied function with the predicate
-- p set.
function M.any(p, it, inv, ctl)
  if it == nil then return M.partial(M.any, p) end

  local ix = 0
  while true do
    local res = table.pack(it(inv, ctl))
    local res = step_iterator(it, inv, ctl)
    ctl = res[1]
    if ctl == nil then return false end
    ix = ix + 1


@@ 233,14 240,14 @@ end
-- iteration that returned false and all its values.
-- It returns true as the only value if the iteration is completed without
-- p returning false.
-- If it is nil, returns a partially-applied function with the predicate
-- If "it" is nil, returns a partially-applied function with the predicate
-- p set.
function M.all(p, it, inv, ctl)
  if it == nil then return M.partial(M.all, p) end

  local ix = 0
  while true do
    local res = table.pack(it(inv, ctl))
    local res = step_iterator(it, inv, ctl)
    ctl = res[1]
    if ctl == nil then return true end
    ix = ix + 1


@@ 271,7 278,7 @@ function M.concat(...)
      if t == nil then return nil end

      local it, inv, ctl = t[1], t[2], t[3]
      local res = table.pack(it(inv, ctl))
      local res = step_iterator(it, inv, ctl)
      t[3] = res[1]

      if t[3] == nil then


@@ 283,6 290,126 @@ function M.concat(...)
  end
end

-- Zip returns an iterator that returns the first value of all iterators at
-- each iteration step. All iterators are iterated together at the same time,
-- and iteration ends when the first iterator ends. If other iterators end
-- earlier, the nil value is returned for this iterator.
--
-- The arguments must be provided as a list of tables, each table an array
-- containing the iterator tuple (see documentation for concat for more
-- details).
function M.zip(...)
  local its = table.pack(...)

  return function()
		local res = {n=0}
		for i, it in ipairs(its) do
			local step = {}

			if not it.done then
				step = step_iterator(it[1], it[2], it[3])
				it[3] = step[1]
				if it[3] == nil then
					if i == 1 then return nil end
					it.done = true
				end
			end

			res.n = res.n + 1
			res[res.n] = step[1]
		end
		return table.unpack(res, 1, res.n)
  end
end

-- Unzip takes a single iterator and returns a new iterator that produces a
-- single value on each iteration. The original iterator advances only when
-- all its returned values for a given step have been returned as single-value
-- iteration steps. Note that any nil value in the values returned by the
-- original iterator will stop the new iterator early, as nil are possible
-- only when not the first return value in a Lua iterator (otherwise they
-- indicate the end of iteration).
function M.unzip(it, inv, ctl)
	local remain = {n=0}
	return function()
		while true do
			if remain.n > 0 then
				remain.n = remain.n - 1
				return table.remove(remain, 1)
			end

			local res = step_iterator(it, inv, ctl)
			ctl = res[1]
			if ctl == nil then return nil end
			remain = res
		end
	end
end

-- Select takes a single iterator and returns a new iterator that produces
-- the value(s) of the original iterator specified by n, which can be:
--     * A number, indicating the 1-based index of the value to select
--     * An array, indicating the 1-based indices of the values to select,
--       returned in the array's order.
--     * A function that will receive the original iterator's values and
--       return its returned values instead (same as map).
--
-- It is a specialized form of map. If "it" is nil, returns a partially-applied
-- function with "n" set.
function M.select(n, it, inv, ctl)
	local typ = type(n)
	if typ == 'function' then
		return M.map(n, it, inv, ctl)
	end
	if typ == 'number' then
		n = {n}
	end
	return M.map(function(...)
		local res = {}
		for _, ix in ipairs(n) do
			table.insert(res, (select(ix, ...)))
		end
		return table.unpack(res, 1, #n)
	end, it, inv, ctl)
end

-- Collects the first value of the iterator into an array, appending to t on
-- each iteration. If t is nil, a new table is created. If "it" is nil, returns
-- a partially-applied function with "t" set. This function consumes the
-- iterator and returns t, it is a special case of reduce.
--
-- To collect multiple values from the iterator in an array, pipe from pack,
-- select or map.
function M.collectarray(t, it, inv, ctl)
	return M.reduce(function(cumul, v)
		table.insert(cumul, v)
		return cumul
	end, t or {}, it, inv, ctl)
end

-- Collects the first two values of the iterator in a table, the first value
-- being used as the key and the second as the value. If t is nil, a new table
-- is created. If "it" is nil, returns a partially-applied function with "t"
-- set. This function consumes the iterator and returns t, it is a special case
-- of reduce.
--
-- To rearrange order of the iterator's values, see select.
function M.collectkv(t, it, inv, ctl)
	return M.reduce(function(cumul, k, v)
		cumul[k] = v
		return cumul
	end, t or {}, it, inv, ctl)
end

-- Packs takes an iterator and returns a new iterator that packs all values
-- from the original iterator into an array and returns that array as iteration
-- value instead. It is a specialized form of map.
function M.pack(it, inv, ctl)
	return M.map(function(...)
		return table.pack(...)
	end, it, inv, ctl)
end

-- Callmethod calls the method m on table t, passing the args
-- an any additional arguments received after t. The args
-- parameter is treated as a "packed" table, it is unpacked when

M fn_test.lua => fn_test.lua +6 -0
@@ 3,18 3,24 @@ local lu = require 'luaunit'
TestAll = require 'tests.all'
TestAny = require 'tests.any'
TestCallMethod = require 'tests.callmethod'
TestCollectArray = require 'tests.collectarray'
TestCollectKV = require 'tests.collectkv'
TestConcat = require 'tests.concat'
TestFilter = require 'tests.filter'
TestFromTo = require 'tests.fromto'
TestMap = require 'tests.map'
TestPack = require 'tests.pack'
TestPartial = require 'tests.partial'
TestPartialTrail = require 'tests.partialtrail'
TestPipe = require 'tests.pipe'
TestReduce = require 'tests.reduce'
TestSelect = require 'tests.select'
TestSkipn = require 'tests.skipn'
TestSkipWhile = require 'tests.skipwhile'
TestTaken = require 'tests.taken'
TestTakeWhile = require 'tests.takewhile'
TestUnzip = require 'tests.unzip'
TestUseCases = require 'tests.usecases'
TestZip = require 'tests.zip'

os.exit(lu.LuaUnit.run())

A tests/collectarray.lua => tests/collectarray.lua +30 -0
@@ 0,0 1,30 @@
local lu = require 'luaunit'
local fn = require 'fn'

local M = {}

function M.test_empty()
	local got = fn.collectarray(nil, fn.fromto(5, 4))
	lu.assertEquals(got, {})
end

function M.test_simple()
	local got = fn.collectarray(nil, fn.fromto(1, 4))
	lu.assertEquals(got, {1, 2, 3, 4})
end

function M.test_mult_piped()
	local ar = {'a', 'b', 'c'}
	local pfn = fn.pipe(fn.collectarray({}))
	local got = pfn(ipairs(ar))
	lu.assertEquals(got, {1, 2, 3})
end

function M.test_append()
	local ar = {'a', 'b', 'c'}
	local pfn = fn.pipe(fn.collectarray({10, 20}))
	local got = pfn(ipairs(ar))
	lu.assertEquals(got, {10, 20, 1, 2, 3})
end

return M

A tests/collectkv.lua => tests/collectkv.lua +39 -0
@@ 0,0 1,39 @@
local lu = require 'luaunit'
local fn = require 'fn'

local M = {}

function M.test_empty()
	local got = fn.collectkv(nil, fn.fromto(5, 4))
	lu.assertEquals(got, {})
end

function M.test_simple_noval()
	local got = fn.collectkv(nil, fn.fromto(1, 4))
	lu.assertEquals(got, {})
end

function M.test_simple()
	local t = {x=1, y=2, z=3}
	local got = fn.collectkv(nil, pairs(t))
	lu.assertEquals(got, {x=1, y=2, z=3})
end

function M.test_many()
	local s = "from=world, to=Lua, last=val"
	local want = {from = '=', to = '=', last = '='}
	local got = fn.collectkv({}, string.gmatch(s, "(%w+)(=)(%w+)"))
	lu.assertEquals(got, want)
end

function M.test_pipe_append()
	local s = "from=world, to=Lua, last=val"
	local t = {other = 1}
	local pfn = fn.pipe(fn.select({3, 1}), fn.collectkv(t))
	local got = pfn(string.gmatch(s, "(%w+)(=)(%w+)"))

	local want = {world = 'from', Lua = 'to', val = 'last', other = 1}
	lu.assertEquals(got, want)
end

return M

A tests/pack.lua => tests/pack.lua +22 -0
@@ 0,0 1,22 @@
local lu = require 'luaunit'
local fn = require 'fn'

local M = {}

function M.test_empty()
	local got = fn.collectarray({}, fn.pack(fn.fromto(5, 4)))
	lu.assertEquals(got, {})
end

function M.test_simple()
	local got = fn.collectarray({}, fn.pack(fn.fromto(1, 4)))
	lu.assertEquals(got, {{1, n=1}, {2, n=1}, {3, n=1}, {4, n=1}})
end

function M.test_many()
	local ar = {1, 2}
	local got = fn.collectarray({}, fn.pack(ipairs(ar)))
	lu.assertEquals(got, {{1, 1, n=2}, {2, 2, n=2}})
end

return M

A tests/select.lua => tests/select.lua +50 -0
@@ 0,0 1,50 @@
local lu = require 'luaunit'
local fn = require 'fn'

local M = {}

function M.test_empty()
	local got = fn.collectarray({}, fn.select(2, fn.fromto(5, 4)))
	lu.assertEquals(got, {})
end

function M.test_single_select_outbounds()
	local got = fn.collectarray({}, fn.select(2, fn.fromto(1, 4)))
	lu.assertEquals(got, {})
end

function M.test_select_one()
	local s = "from=world, to=Lua, last=val"

	local want = {'world', 'Lua', 'val'}
	local got = fn.collectarray({}, fn.select(3, string.gmatch(s, "(%w+)(=)(%w+)")))
	lu.assertEquals(got, want)
end

function M.test_select_negative()
	local s = "from=world, to=Lua, last=val"

	local want = {'=', '=', '='}
	local got = fn.collectarray({}, fn.select(-2, string.gmatch(s, "(%w+)(=)(%w+)")))
	lu.assertEquals(got, want)
end

function M.test_select_many()
	local s = "from=world, to=Lua, last=val"

	local want = {{'from', 'world', n=2}, {'to', 'Lua', n=2}, {'last', 'val', n=2}}
	local got = fn.collectarray({}, fn.pack(fn.select({1, 3}, string.gmatch(s, "(%w+)(=)(%w+)"))))
	lu.assertEquals(got, want)
end

function M.test_select_function()
	local s = "from=world, to=Lua, last=val"

	local want = {'morf', 'ot', 'tsal'}
	local got = fn.collectarray({}, fn.select(function(v)
		return string.reverse(v)
	end, string.gmatch(s, "(%w+)(=)(%w+)")))
	lu.assertEquals(got, want)
end

return M

A tests/unzip.lua => tests/unzip.lua +31 -0
@@ 0,0 1,31 @@
local testing = require 'tests.testing'
local lu = require 'luaunit'
local fn = require 'fn'

local M = {}

function M.test_empty()
  testing.forstats({count = 0, sum = 0},
    fn.unzip(fn.fromto(5, 4)))
end

function M.test_simple()
  testing.forstats({count = 4, sum = 10},
    fn.unzip(fn.fromto(1, 4)))
end

function M.test_dual()
	local want = {1, 'a', 2, 'b', 3, 'c'}
	local got = fn.collectarray({}, fn.unzip(ipairs({'a', 'b', 'c'})))
	lu.assertEquals(got, want)
end

function M.test_multiple()
	local s = "from=world, to=Lua, last=val"

	local want = {'from', '=', 'world', 'to', '=', 'Lua', 'last', '=', 'val'}
	local got = fn.collectarray({}, fn.unzip(string.gmatch(s, "(%w+)(=)(%w+)")))
	lu.assertEquals(got, want)
end

return M

A tests/zip.lua => tests/zip.lua +45 -0
@@ 0,0 1,45 @@
local testing = require 'tests.testing'
local lu = require 'luaunit'
local fn = require 'fn'

local M = {}

function M.test_none()
  testing.forstats({count = 0, sum = 0},
    fn.zip())
end

function M.test_empty()
  testing.forstats({count = 0, sum = 0},
    fn.zip({fn.fromto(5, 4)}))
end

function M.test_single()
  testing.forstats({count = 4, sum = 10},
    fn.zip({fn.fromto(1, 4)}))
end

function M.test_dual()
	local want = {{1, 10, n=2}, {2, 11, n=2}, {3, 12, n=2}, {4, 13, n=2}}
  local pipefn = fn.pipe(
		fn.zip,
		fn.pack,
		fn.collectarray({})
	)
	local got = pipefn({fn.fromto(1, 4)}, {fn.fromto(10, 20)})
	lu.assertEquals(got, want)
end

function M.test_multiple()
	local want = {{1, 10, 100, n=3}, {2, nil, 101, n=3}, {3, nil, nil, n=3}}
  local pipefn = fn.pipe(
		fn.zip,
		fn.pack,
		fn.collectarray({})
	)
	local got = pipefn({fn.fromto(1, 3)}, {fn.fromto(10, 10)}, {fn.fromto(100, 101)})
	lu.assertEquals(got, want)
end

return M