M src/groceries/sqldb.py => src/groceries/sqldb.py +15 -4
@@ 18,34 18,45 @@ class SqliteDB:
def _init(self):
with self._connect() as conn:
conn.execute(
+ "CREATE TABLE lists ("
+ "id INTEGER PRIMARY KEY, "
+ "name TEXT NOT NULL UNIQUE)"
+ )
+ conn.execute(
"CREATE TABLE items ("
+ "list_id INTEGER NOT NULL REFERENCES lists, "
"name TEXT, "
"section TEXT, "
"priority INTEGER, "
"shop TEXT)"
)
+ conn.execute("CREATE INDEX list_index ON items(list_id)")
+ conn.execute("INSERT INTO lists VALUES (1, 'Default')")
conn.commit()
def insert(self, item):
with self._connect() as conn:
conn.execute(
"INSERT INTO items "
- "(name, section, priority, shop) "
- "VALUES (?, ?, ?, ?)",
+ "(list_id, name, section, priority, shop) "
+ "VALUES (1, ?, ?, ?, ?)",
tuple(item),
)
conn.commit()
def select_all(self):
with self._connect() as conn:
- c = conn.execute("SELECT * FROM items")
+ c = conn.execute("SELECT name, section, priority, shop FROM items")
rows = c.fetchall()
return [Item(*row) for row in rows]
def select(self, item_id):
with self._connect() as conn:
- c = conn.execute("SELECT * FROM items WHERE name=?", (item_id,))
+ c = conn.execute(
+ "SELECT name, section, priority, shop FROM items WHERE name=?",
+ (item_id,),
+ )
row = c.fetchone()
return Item(*row) if row else None
M tests/conftest.py => tests/conftest.py +3 -2
@@ 32,6 32,7 @@ def groceries_file(tmpdir, items):
def create(conn):
conn.execute(
"CREATE TABLE items ("
+ "list_id INTEGER, "
"name TEXT, "
"section TEXT, "
"priority INTEGER, "
@@ 43,8 44,8 @@ def insert(conn, items):
for item in items:
conn.execute(
"INSERT INTO items "
- "(name, section, priority, shop) "
- "VALUES (?, ?, ?, ?)",
+ "(list_id, name, section, priority, shop) "
+ "VALUES (1, ?, ?, ?, ?)",
tuple(item),
)
M tests/test_sqldb.py => tests/test_sqldb.py +3 -3
@@ 24,7 24,7 @@ def test_insert_into_new_db(new_db):
item = Item("gooseberries", "produce", 2, "farmstall")
new_db.insert(item)
- query = "SELECT * FROM items WHERE name=?"
+ query = "SELECT name, section, priority, shop FROM items WHERE name=?"
conn = sqlite3.connect(new_db._path)
rows = conn.execute(query, (item.name,)).fetchall()
conn.close()
@@ 41,7 41,7 @@ def test_insert_into_existing_db(sql_db, items):
sql_db.insert(item)
conn = sqlite3.connect(sql_db._path)
- rows = conn.execute("SELECT * FROM items").fetchall()
+ rows = conn.execute("SELECT name, section, priority, shop FROM items").fetchall()
conn.close()
assert items + new_items == [Item(*r) for r in rows]
@@ 68,7 68,7 @@ def test_update(sql_db, items):
sql_db.update(name, item)
conn = sqlite3.connect(sql_db._path)
- rows = conn.execute("SELECT * FROM items").fetchall()
+ rows = conn.execute("SELECT name, section, priority, shop FROM items").fetchall()
conn.close()
assert items[:-2] + list(reversed(updates.values())) == [Item(*r) for r in rows]