~amirouche/python-okvs

a2e9bbcf834b25da5c2ad0b71c4b2b63ec77d5fc — Amirouche 1 year, 2 months ago e6f84f3 v0.1.0
initial version
4 files changed, 706 insertions(+), 0 deletions(-)

A okdb/__init__.py
A okdb/nstore2.py
A pyproject.toml
A tests.py
A okdb/__init__.py => okdb/__init__.py +272 -0
@@ 0,0 1,272 @@
import sqlite3
import struct


# https://sqlite.org/lang_createtable.html
SQL_MAKE_DB = """CREATE TABLE IF NOT EXISTS db (key BLOB, value BLOB);"""
# https://sqlite.org/lang_createindex.html
SQL_MAKE_DB_INDEX_ASC = """CREATE UNIQUE INDEX IF NOT EXISTS asc ON db (key ASC);"""
SQL_MAKE_DB_INDEX_DESC = """CREATE UNIQUE INDEX IF NOT EXISTS desc ON db (key DESC);"""


def open(path):
    """Connect to the SQLite database at PATH and make sure the schema exists.

    The `db` table and its two unique key indices are created inside a
    single transaction; the statements are no-ops when they already exist.
    Returns the open sqlite3 connection.
    """
    connection = sqlite3.connect(path)
    schema = (SQL_MAKE_DB, SQL_MAKE_DB_INDEX_ASC, SQL_MAKE_DB_INDEX_DESC)
    # `with connection` commits on success and rolls back on error.
    with connection:
        for statement in schema:
            connection.execute(statement)
    return connection


def close(cnx):
    """Close the connection CNX; returns None."""
    connection = cnx
    connection.close()


def txn(cnx, func, *args, **kwargs):
    """Run FUNC(cnx, *args, **kwargs) inside a transaction.

    The transaction is committed only when FUNC returns normally; any
    exception (including non-`Exception` ones such as KeyboardInterrupt)
    rolls it back and is re-raised.

    Returns whatever FUNC returned.
    """
    try:
        out = func(cnx, *args, **kwargs)
    except BaseException:
        # Never commit partial work, whatever interrupted FUNC.
        cnx.rollback()
        raise
    cnx.commit()
    return out


SQL_QUERY_GET = """SELECT value FROM db WHERE key=?"""


def get(cnx, key):
    """Return the value stored under KEY, or None when KEY is absent."""
    row = cnx.execute(SQL_QUERY_GET, (key,)).fetchone()
    if row is None:
        return None
    return row[0]


# INSERT OR REPLACE so that writing an existing key overwrites its value
# instead of violating the unique index on `key`.
SQL_QUERY_SET = """INSERT OR REPLACE INTO db (key, value) VALUES(?, ?)"""


def set(cnx, key, value):
    """Associate VALUE with KEY, replacing any previous value.

    Nothing is committed here; the caller controls the transaction.
    """
    cnx.execute(SQL_QUERY_SET, (key, value))


SQL_QUERY_E = """SELECT key FROM db WHERE key = ?"""


def e(cnx, key):
    """Return KEY itself when it is stored in the database, else None."""
    row = cnx.execute(SQL_QUERY_E, (key,)).fetchone()
    if row is None:
        return None
    return row[0]


# Order descending and keep one row so the result is the *nearest* smaller
# key, not an arbitrary one.
SQL_QUERY_LT = """SELECT key FROM db WHERE key < ? ORDER BY key DESC LIMIT 1"""


def lt(cnx, key):
    """Return the greatest stored key strictly less than KEY, or None."""
    cursor = cnx.execute(SQL_QUERY_LT, (key,))
    out = cursor.fetchone()
    return out[0] if out is not None else None


# Order ascending and keep one row so the result is guaranteed to be the
# *nearest* greater key rather than depending on the scan order.
SQL_QUERY_GT = """SELECT key FROM db WHERE key > ? ORDER BY key ASC LIMIT 1"""


def gt(cnx, key):
    """Return the smallest stored key strictly greater than KEY, or None."""
    cursor = cnx.execute(SQL_QUERY_GT, (key,))
    out = cursor.fetchone()
    return out[0] if out is not None else None


def near(cnx, key):
    """Locate KEY or a neighbouring key.

    Returns a pair (position, found): position is 0 when KEY itself is
    stored, -1 when a smaller key was found, 1 when a greater key was
    found.  Probes are tried in that order.  Returns (None, None) when
    none of the probes finds a key.
    """
    probes = ((0, e), (-1, lt), (1, gt))
    for position, probe in probes:
        found = probe(cnx, key)
        if found is not None:
            return position, found
    return None, None


SQL_QUERY_ASC = """
SELECT key, value
FROM db WHERE key >= ? AND key < ?
ORDER BY key ASC
LIMIT ?
OFFSET ?
"""


SQL_QUERY_DESC = """
SELECT key, value
FROM db WHERE key >= ? AND key < ?
ORDER BY key DESC
LIMIT ?
OFFSET ?
"""


def query(cnx, key, other, limit=-1, offset=0):
    """Generate (key, value) rows between KEY and OTHER.

    When KEY < OTHER the rows come in ascending key order; otherwise the
    bounds are swapped and rows come in descending order.  In both cases
    the smaller bound is inclusive and the larger bound is exclusive.
    LIMIT and OFFSET are passed straight to SQLite (-1 means no limit).
    """
    if key < other:
        statement, low, high = SQL_QUERY_ASC, key, other
    else:
        statement, low, high = SQL_QUERY_DESC, other, key

    yield from cnx.execute(statement, (low, high, limit, offset))


SQL_QUERY_COUNT = """SELECT COUNT(key) FROM db WHERE key >= ? AND key < ?"""


def count(cnx, key, other):
    """Return how many stored keys fall in the range [KEY, OTHER)."""
    (total,) = cnx.execute(SQL_QUERY_COUNT, (key, other)).fetchone()
    return total


# SUM() yields NULL over zero rows and skips NULL values, so each sum is
# coalesced to 0 separately; otherwise an empty range, or a range whose
# values are all NULL, would make the whole expression NULL.
SQL_QUERY_SIZE = """
SELECT COALESCE(SUM(length(key)), 0) + COALESCE(SUM(length(value)), 0)
FROM db WHERE key >= ? AND key < ?
"""


def size(cnx, key, other):
    """Return the total length of keys plus values in the range [KEY, OTHER).

    Rows whose value is NULL contribute only their key length; an empty
    range yields 0.
    """
    cursor = cnx.execute(SQL_QUERY_SIZE, (key, other))
    out = cursor.fetchone()[0]
    return out


SQL_QUERY_DELETE = """DELETE FROM db WHERE key = ?"""
SQL_QUERY_DELETE_RANGE = """DELETE FROM db WHERE key >= ? AND key < ?"""


def delete(cnx, key, other=None):
    """Delete a single KEY, or every key in [KEY, OTHER) when OTHER is given."""
    if other is not None:
        cnx.execute(SQL_QUERY_DELETE_RANGE, (key, other))
        return
    cnx.execute(SQL_QUERY_DELETE, (key,))


# pack and unpack

# Largest unsigned value that fits in 0, 1, ..., 8 bytes.
_size_limits = tuple(2 ** (8 * width) - 1 for width in range(9))

# Type codes, listed by value.  The code is the first byte of every
# encoded element.
NULL_CODE = 0x00
BYTES_CODE = 0x01
FALSE_CODE = 0x02
TRUE_CODE = 0x03
INTEGER_NEGATIVE_CODE = 0x04
INTEGER_ZERO = 0x05
INTEGER_POSITIVE_CODE = 0x06
STRING_CODE = 0x07
NESTED_CODE = 0x08

# Largest integer representable as an unsigned 64-bit big-endian value.
INTEGER_MAX = (1 << 64) - 1


def _find_terminator(v, pos):
    # Finds the start of the next terminator [\x00]![\xff] or the end of v
    while True:
        pos = v.find(b"\x00", pos)
        if pos < 0:
            return len(v)
        if pos + 1 == len(v) or v[pos + 1:pos + 2] != b"\xff":
            return pos
        pos += 2


def _decode(v, pos):
    """Decode the single element of V that starts at POS.

    Returns a pair (value, next_pos) where next_pos is the offset of the
    byte following the decoded element.  Raises ValueError for an unknown
    type code.
    """
    code = v[pos]
    body = pos + 1  # first byte after the type code

    # Elements made of the type code alone.
    if code == NULL_CODE:
        return None, body
    if code == FALSE_CODE:
        return False, body
    if code == TRUE_CODE:
        return True, body
    if code == INTEGER_ZERO:
        return 0, body

    # Zero-terminated payloads; embedded zeros are escaped as \x00\xFF.
    if code in (BYTES_CODE, STRING_CODE):
        stop = _find_terminator(v, body)
        payload = v[body:stop].replace(b"\x00\xFF", b"\x00")
        if code == STRING_CODE:
            payload = payload.decode("utf-8")
        return payload, stop + 1

    # Fixed-width 8-byte big-endian integers.
    if code in (INTEGER_NEGATIVE_CODE, INTEGER_POSITIVE_CODE):
        stop = body + 8
        (number,) = struct.unpack(">Q", v[body:stop])
        if code == INTEGER_NEGATIVE_CODE:
            number -= INTEGER_MAX
        return number, stop

    if code == NESTED_CODE:
        members = []
        cursor = body
        total = len(v)
        while cursor < total:
            if v[cursor] != 0x00:
                member, cursor = _decode(v, cursor)
                members.append(member)
            elif cursor + 1 < total and v[cursor + 1] == 0xFF:
                # Inside a nested element, \x00\xFF stands for None.
                members.append(None)
                cursor += 2
            else:
                # A lone \x00 terminates the nested element.
                break
        return tuple(members), cursor + 1

    raise ValueError("Unknown data type in DB: " + repr(v))


def _encode(value, nested=False):
    """Encode a single VALUE into its byte representation.

    Supported types are None, bool, bytes, str, int and tuple/list (encoded
    recursively).  NESTED is true when VALUE is an element of a tuple/list;
    it only affects None, which is then written as \\x00\\xFF so that it
    cannot be mistaken for the nested terminator.

    Raises ValueError for any other type.  The integer type is checked
    before comparing with zero so that zero-valued non-integers such as 0.0
    are rejected like every other float instead of being silently stored
    as the integer 0.  Integers are packed as unsigned 64-bit values;
    struct.error propagates for magnitudes above INTEGER_MAX.
    """
    if value is None:
        if nested:
            return bytes((NULL_CODE, 0xFF))
        return bytes((NULL_CODE,))
    if isinstance(value, bool):
        # bool is a subclass of int, so it must be tested before int.
        return bytes((TRUE_CODE if value else FALSE_CODE,))
    if isinstance(value, bytes):
        return bytes((BYTES_CODE,)) + value.replace(b"\x00", b"\x00\xFF") + b"\x00"
    if isinstance(value, str):
        return (
            bytes((STRING_CODE,))
            + value.encode("utf-8").replace(b"\x00", b"\x00\xFF")
            + b"\x00"
        )
    if isinstance(value, int):
        if value == 0:
            return bytes((INTEGER_ZERO,))
        if value > 0:
            return bytes((INTEGER_POSITIVE_CODE,)) + struct.pack('>Q', value)
        # Shift negatives into the unsigned range; values closer to zero
        # get larger encodings, which preserves ordering.
        return bytes((INTEGER_NEGATIVE_CODE,)) + struct.pack('>Q', INTEGER_MAX + value)
    if isinstance(value, (tuple, list)):
        child_bytes = [_encode(child, True) for child in value]
        return b''.join([bytes((NESTED_CODE,))] + child_bytes + [bytes((0x00,))])
    raise ValueError("Unsupported data type: {}".format(type(value)))


def pack(t):
    """Encode every element of the iterable T and concatenate the results."""
    return b"".join(map(_encode, t))


def unpack(key):
    """Decode the bytes KEY into a tuple of its elements."""
    decoded = []
    cursor, end = 0, len(key)
    while cursor < end:
        element, cursor = _decode(key, cursor)
        decoded.append(element)
    return tuple(decoded)


def next_prefix(x):
    """Return X with trailing \\xff bytes dropped and its last byte incremented.

    Raises IndexError when nothing remains after dropping the \\xff bytes.
    """
    trimmed = x.rstrip(b"\xff")
    last = trimmed[-1]
    head = trimmed[:-1]
    return head + bytes((last + 1,))

A okdb/nstore2.py => okdb/nstore2.py +188 -0
@@ 0,0 1,188 @@
# Compute the minimal set of indices required to bind any n-pattern
# in one hop.
#
# Based on https://stackoverflow.com/a/55148433/140837
#
# Taken from hoply.
import itertools
from math import factorial

from immutables import Map

import okdb as db


def bc(n, k):
    """Return the binomial coefficient "N choose K".

    Returns 0 when K > N.  K == N yields 1, as does K == 0.
    """
    if k > n:
        return 0
    return factorial(n) // factorial(k) // factorial(n - k)


def stringify(iterable):
    """Concatenate the str() of every item of ITERABLE."""
    return "".join(map(str, iterable))


def combinations(tab):
    """Return every non-empty combination of TAB, stringified, by size."""
    size = len(tab)
    out = [
        stringify(subset)
        for width in range(1, size + 1)
        for subset in itertools.combinations(tab, width)
    ]
    assert len(out) == 2 ** size - 1
    return out


def ok(solutions, tab):
    """Check that SOLUTIONS of TAB is a correct solution

    Every non-empty combination of TAB must have at least one ordering
    that is a prefix of one of SOLUTIONS (an iterable of strings that is
    traversed once per combination).  Returns True when that holds;
    otherwise prints the first failing combination and returns False.
    """
    for combination in combinations(tab):
        orderings = ["".join(x) for x in itertools.permutations(combination)]
        covered = any(
            solution.startswith(ordering)
            for solution in solutions
            for ordering in orderings
        )
        if not covered:
            print("failed with combination={}".format(combination))
            return False
    return True


def _compute_indices(n):
    """Generate one permutation of range(N) per combination of N // 2 positions.

    For each combination, every position is flagged as chosen or not.
    Adjacent (not chosen, chosen) pairs are removed one at a time: the
    chosen member goes to the front group, the other to the back group.
    The permutation is front + remaining positions + back.
    """
    positions = list(range(n))
    for chosen in itertools.combinations(positions, n // 2):
        flagged = [(p, p in chosen) for p in positions]
        front = []
        back = []
        paired = True
        while paired:
            paired = False
            for left, right in zip(flagged, flagged[1:]):
                if not left[1] and right[1]:
                    front.append(right[0])
                    back.append(left[0])
                    flagged.remove(right)
                    flagged.remove(left)
                    # The list changed: restart the scan from the start.
                    paired = True
                    break
        middle = [p for p, _ in flagged]
        yield tuple(front + middle + back)


def compute_indices(n):
    """Return the list of permutations produced by _compute_indices(N)."""
    return [index for index in _compute_indices(n)]


# Taken from hoply/hoply.py

class Variable:
    """A named placeholder used in query patterns."""

    __slots__ = ("name",)

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        # Wrap the name in a 1-tuple: a bare `%` operand that happens to be
        # a tuple would be unpacked as several format arguments.
        return "<var %r>" % (self.name,)


# Short alias for building patterns.
var = Variable


def stringify(list):
    """Concatenate the str() of every element of LIST (any iterable)."""
    pieces = [str(element) for element in list]
    return "".join(pieces)


def is_permutation_prefix(combination, index):
    """Return True when INDEX starts with some ordering of COMBINATION.

    The elements are compared directly instead of through concatenated
    strings: string prefixes are ambiguous once an element has more than
    one digit ("10..." starts with "1"), and this also avoids enumerating
    every permutation of COMBINATION.
    """
    wanted = sorted(combination)
    head = list(index)[:len(wanted)]
    return sorted(head) == wanted


def init(name, prefix, n):
    """Build the dict describing an nstore2 of N-item tuples.

    PREFIX is stored as given; the code using it concatenates it with
    lists to build keys.  The indices are computed once here with
    compute_indices(N).
    """
    return {
        "type": "dbext",
        "subtype": "nstore2",
        "name": name,
        "prefix": prefix,
        "indices": compute_indices(n),
    }


def add(cnx, nstore2, *items, value=None):
    """Store ITEMS once per index of NSTORE2, each time with VALUE.

    VALUE is keyword-only: a trailing positional argument becomes part of
    ITEMS.
    """
    for subspace, index in enumerate(nstore2["indices"]):
        # Reorder the items according to this index's permutation.
        reordered = [items[position] for position in index]
        packed = db.pack(nstore2["prefix"] + [subspace] + reordered)
        db.set(cnx, packed, value)


def delete(cnx, nstore2, *items):
    """Remove ITEMS from every index of NSTORE2."""
    for subspace, index in enumerate(nstore2["indices"]):
        # Rebuild the same key that `add` wrote for this index.
        reordered = [items[position] for position in index]
        packed = db.pack(nstore2["prefix"] + [subspace] + reordered)
        db.delete(cnx, packed)


def get(cnx, nstore2, *items):
    """Return the value stored with ITEMS, or None when ITEMS is absent.

    The lookup uses subspace 0 with ITEMS in the order given.
    """
    lookup = nstore2["prefix"] + [0] + list(items)
    return db.get(cnx, db.pack(lookup))


def _from(cnx, nstore2, *pattern, seed=Map()):  # seed is immutable
    """Generate the bindings of every stored tuple matching PATTERN.

    PATTERN mixes concrete items and Variable instances.  The first index
    of NSTORE2 whose leading positions are exactly the concrete positions
    of PATTERN is scanned by prefix.  Each matching row yields SEED
    extended with one binding per variable of PATTERN.

    Raises Exception when no index is suitable.
    """
    concrete = tuple(
        position
        for position, item in enumerate(pattern)
        if not isinstance(item, Variable)
    )
    # find the first index suitable for the query
    for subspace, index in enumerate(nstore2["indices"]):
        if is_permutation_prefix(concrete, index):
            break
    else:
        raise Exception("Oops! Mathematics failed!")
    # `index` variable holds the permutation suitable for the
    # query. `subspace` is the "prefix" of that index.
    prefix = [pattern[i] for i in index if not isinstance(pattern[i], Variable)]
    prefix = db.pack(nstore2["prefix"] + [subspace] + prefix)
    for key, _ in db.query(cnx, prefix, db.next_prefix(prefix)):
        # Drop the store prefix and the subspace element.
        items = db.unpack(key)[len(nstore2["prefix"]) + 1:]
        # re-order the items
        items = tuple(items[index.index(i)] for i in range(len(pattern)))
        bindings = seed
        for item, stored in zip(pattern, items):
            if isinstance(item, Variable):
                bindings = bindings.set(item.name, stored)
        yield bindings


def _where(cnx, nstore2, iterator, pattern):
    """Refine every bindings of ITERATOR with the matches of PATTERN.

    For each incoming bindings, variables of PATTERN that are already
    bound are replaced by their value; the resulting pattern is queried
    with _from, seeded with those bindings.
    """
    for bindings in iterator:
        bound = []
        for item in pattern:
            if not isinstance(item, Variable):
                # concrete items are kept as is
                bound.append(item)
                continue
            try:
                resolved = bindings[item.name]
            except KeyError:
                # still unbound: keep the variable itself
                resolved = item
            bound.append(resolved)
        yield from _from(cnx, nstore2, *bound, seed=bindings)


def query(cnx, nstore2, *patterns):
    """Return an iterator over the bindings that satisfy all PATTERNS.

    The first pattern seeds the results; each following pattern refines
    them.  At least one pattern is required.
    """
    first = patterns[0]
    remaining = patterns[1:]

    results = _from(cnx, nstore2, *first)
    for pattern in remaining:
        results = _where(cnx, nstore2, results, pattern)
    return results

A pyproject.toml => pyproject.toml +15 -0
@@ 0,0 1,15 @@
[tool.poetry]
name = "okdb"
version = "0.1.0"
description = "Simple database with which I am productive"
authors = ["Amirouche <amirouche@hyper.dev>"]
license = "MIT"

[tool.poetry.dependencies]
python = "^3.7"
# okdb/nstore2.py imports `Map` from immutables at module import time.
immutables = "*"

[tool.poetry.dev-dependencies]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

A tests.py => tests.py +231 -0
@@ 0,0 1,231 @@
import os
from tempfile import mkstemp
from contextlib import contextmanager

import okdb as db
from okdb import nstore2


@contextmanager
def tempfile():
    """Yield the path of a fresh temporary file and remove it afterwards.

    The file is created in the platform's default temporary directory and
    is removed even when the body of the `with` block raises.
    """
    fd, filename = mkstemp(suffix=".sqlite3", prefix="lm4k-tests-")
    os.close(fd)
    try:
        yield filename
    finally:
        os.remove(filename)


def test_db_open():
    """Opening then closing a fresh database file must not raise."""
    with tempfile() as filepath:
        connection = db.open(filepath)
        db.close(connection)


def test_db_get():
    """A key that was never set reads back as None."""
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            assert db.get(cnx, b'\x42') is None
        finally:
            # Close even when the assertion fails.
            db.close(cnx)

def test_db_setget():
    """A value that was set can be read back."""
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            assert db.set(cnx, b'\x42', b'\x42') is None
            assert db.get(cnx, b'\x42') == b'\x42'
        finally:
            # Close even when an assertion fails.
            db.close(cnx)


def test_db_query():
    """Querying an empty database yields nothing."""
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            assert len(list(db.query(cnx, b'\x00', b'\xFF'))) == 0
        finally:
            # Close even when the assertion fails.
            db.close(cnx)


def test_db_setquery():
    """An ascending range query returns the stored pairs in key order."""
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            expected = [
                (b'\x01', b''),
                (b'\x02', b''),
                (b'\x03', b''),
            ]
            for k, v in expected:
                db.set(cnx, k, v)
            assert list(db.query(cnx, b'\x00', b'\xFF')) == expected
        finally:
            # Close even when the assertion fails.
            db.close(cnx)

def test_db_setquery_desc():
    """Swapping the bounds returns the stored pairs in descending key order."""
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            expected = [
                (b'\x01', b''),
                (b'\x02', b''),
                (b'\x03', b''),
            ]
            for k, v in expected:
                db.set(cnx, k, v)
            assert list(db.query(cnx, b'\xFF', b'\x00')) == list(reversed(expected))
        finally:
            # Close even when the assertion fails.
            db.close(cnx)


def test_db_query_limit_offset_desc():
    """LIMIT and OFFSET apply to a descending query whose upper bound is exclusive."""
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            expected = [
                (b'\x01', b''),
                (b'\x02', b''),
                (b'\x03', b''),
            ]
            for k, v in expected:
                db.set(cnx, k, v)
            # b'\x03' is excluded, so the descending rows are \x02, \x01;
            # skipping one leaves \x01.
            out = db.query(cnx, b'\x03', b'\x00', limit=1, offset=1)
            assert list(out) == [(b'\x01', b'')]
        finally:
            # Close even when the assertion fails.
            db.close(cnx)


def test_db_count():
    """count reports every key inside the range."""
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            expected = [
                (b'\x01', b''),
                (b'\x02', b''),
                (b'\x03', b''),
            ]
            for k, v in expected:
                db.set(cnx, k, v)
            assert db.count(cnx, b'\x00', b'\xFF') == 3
        finally:
            # Close even when the assertion fails.
            db.close(cnx)


def test_db_size():
    """size adds up key and value lengths inside the range."""
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            expected = [
                (b'\x01', b'\x01'),
                (b'\x02', b'\x01'),
                (b'\x03', b'\x01'),
            ]
            for k, v in expected:
                db.set(cnx, k, v)
            assert db.size(cnx, b'\x00', b'\xFF') == 6
        finally:
            # Close even when the assertion fails.
            db.close(cnx)


def test_db_count2():
    """count includes the lower bound and excludes the upper bound."""
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            expected = [
                (b'\x00', b''),
                (b'\x01', b''),
                (b'\x02', b''),
                (b'\x03', b''),
                (b'\xFF', b''),
            ]
            for k, v in expected:
                db.set(cnx, k, v)
            assert db.count(cnx, b'\x01', b'\xFF') == 3
        finally:
            # Close even when the assertion fails.
            db.close(cnx)


def test_db_size2():
    """size includes the lower bound and excludes the upper bound."""
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            expected = [
                (b'\x00', b'\x01'),
                (b'\x01', b'\x01'),
                (b'\x02', b'\x01'),
                (b'\x03', b'\x01'),
                (b'\xFF', b'\x01'),
            ]
            for k, v in expected:
                db.set(cnx, k, v)
            assert db.size(cnx, b'\x01', b'\xFF') == 6
        finally:
            # Close even when the assertion fails.
            db.close(cnx)


def test_db_delete():
    """delete removes a single key, or a whole range when given two bounds."""
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            kvs = [
                (b'\x00', b'\x01'),
                (b'\x01', b'\x01'),
                (b'\x02', b'\x01'),
                (b'\x03', b'\x01'),
            ]
            for k, v in kvs:
                db.set(cnx, k, v)
            assert db.count(cnx, b'\x00', b'\xFF') == 4
            # WHEN a single key is deleted
            db.delete(cnx, b'\x00')
            # THEN
            assert db.count(cnx, b'\x00', b'\xFF') == 3
            # WHEN the whole range is deleted
            db.delete(cnx, b'\x00', b'\xFF')
            # THEN
            assert db.count(cnx, b'\x00', b'\xFF') == 0
        finally:
            # Close even when an assertion fails.
            db.close(cnx)


def test_db_packunpack():
    """unpack(pack(t)) gives back t for every supported element type."""
    expected = (None, True, False, 42, 0, -42, "abcdef", b'\xc0ff33', (1337, None))
    packed = db.pack(expected)
    roundtrip = db.unpack(packed)
    assert roundtrip == expected


def test_db_sortedpack():
    """Packed bytes sort in the same order as the tuples they encode."""
    pairs = (
        ((-100,), (0,)),
        ((0,), (100,)),
        ((42,), (43,)),
        ((-42,), (42,)),
        ((-52,), (-42,)),
        (("car",), ("care",)),
        (("abc",), ("xyz",)),
        ((b'\x00',), (b'\xFF',)),
        ((None, (b'\x00',)), (None, (b'\xFF',))),
    )
    for smaller, larger in pairs:
        # sanity check on the fixture itself
        assert smaller < larger
        assert db.pack(smaller) < db.pack(larger)


def test_nstore2_get():
    """Items that were never added read back as None."""
    ns = nstore2.init("test", [b'\x01'], 3)
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            out = nstore2.get(cnx, ns, 1, 2, 3)
            assert out is None
        finally:
            # The connection was previously left open.
            db.close(cnx)


def test_nstore2_add_get():
    """The value given to add is returned by get for the same items."""
    ns = nstore2.init("test", [b'\x01'], 3)
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            expected = b'\x42'
            nstore2.add(cnx, ns, 1, 2, 3, value=expected)
            out = nstore2.get(cnx, ns, 1, 2, 3)
            assert out == expected
        finally:
            # The connection was previously left open.
            db.close(cnx)

def test_nstore2_add_delete_get():
    """Items that were added then deleted read back as None."""
    ns = nstore2.init("test", [b'\x01'], 3)
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            # `value` is keyword-only; passed positionally it would be
            # swallowed as a fourth item.
            nstore2.add(cnx, ns, 1, 2, 3, value=b'\x42')
            nstore2.delete(cnx, ns, 1, 2, 3)
            out = nstore2.get(cnx, ns, 1, 2, 3)
            assert out is None
        finally:
            # The connection was previously left open.
            db.close(cnx)


def test_nstore2_query():
    """A pattern with one variable binds it to every matching third item."""
    ns = nstore2.init("test", [b'\x01'], 3)
    with tempfile() as filepath:
        cnx = db.open(filepath)
        try:
            # `value` is keyword-only; passed positionally it would be
            # swallowed as a fourth item.
            for i in range(10):
                nstore2.add(cnx, ns, 0, 0, i, value=b'\x42')
            for i in range(10):
                nstore2.add(cnx, ns, 1, 1, i, value=b'\x42')
            for i in range(10):
                nstore2.add(cnx, ns, 2, 2, i, value=b'\x42')
            out = nstore2.query(cnx, ns, (1, 1, nstore2.var('i')))
            out = [x.get('i') for x in out]
            expected = list(range(10))
            assert out == expected
        finally:
            # The connection was previously left open.
            db.close(cnx)