~homeworkprod/byceps

ref: bcef32cd2dc329bcf81057794883008f5c040ac8 byceps/byceps/database.py -rw-r--r-- 3.2 KiB
bcef32cd — Jochen Kupperschmidt Work around Jinja 3.0.0 bug with `for` inside of `set` block 5 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
byceps.database
~~~~~~~~~~~~~~~

Database utilities.

:Copyright: 2006-2021 Jochen Kupperschmidt
:License: Revised BSD (see `LICENSE` file for details)
"""

from __future__ import annotations
from typing import Any, Callable, Iterable, Optional, TypeVar
import uuid

from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.schema import Table

from flask_sqlalchemy import BaseQuery, Pagination, SQLAlchemy
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Query


# Type variables for the input (`F`) and output (`T`) of a mapper.
F = TypeVar('F')
T = TypeVar('T')

# A callable that turns a value of type `F` into a value of type `T`.
Mapper = Callable[[F], T]


# Module-level Flask-SQLAlchemy object; its sessions are configured
# with autoflush disabled.
db = SQLAlchemy(session_options={'autoflush': False})


# Make the PostgreSQL `JSONB` type reachable as `db.JSONB`.
db.JSONB = JSONB


class Uuid(UUID):
    """PostgreSQL `UUID` type that is always created with `as_uuid=True`.

    The constructor takes no arguments, so the `as_uuid` setting does
    not have to be repeated at every place the type is used.
    """

    def __init__(self):
        # Always pass `as_uuid=True` on to the base type.
        super(Uuid, self).__init__(as_uuid=True)


# Make the custom UUID type reachable as `db.Uuid`.
db.Uuid = Uuid


def generate_uuid() -> uuid.UUID:
    """Create and return a new random (version 4) UUID.

    UUID stands for Universally Unique IDentifier.
    """
    random_id = uuid.uuid4()
    return random_id


def paginate(
    query: Query,
    page: int,
    per_page: int,
    *,
    item_mapper: Optional[Mapper] = None,
) -> Pagination:
    """Fetch one page of results from `query`.

    A `page` below 1 is treated as page 1.

    If `item_mapper` is given, it is applied to every fetched item
    before the items are put into the returned `Pagination` object.

    Raise `ValueError` if `per_page` is less than 1.
    """
    # Page numbers start at 1; clamp anything lower.
    page = max(page, 1)

    if per_page < 1:
        raise ValueError('The number of items per page must be positive.')

    skipped = per_page * (page - 1)
    rows = query.limit(per_page).offset(skipped).all()

    # A first page that is not full already contains every item, so the
    # separate count query can be skipped in that case.
    row_count = len(rows)
    is_complete_result = page == 1 and row_count < per_page
    if is_complete_result:
        total = row_count
    else:
        total = query.order_by(None).count()

    if item_mapper is not None:
        rows = list(map(item_mapper, rows))

    # Intentionally pass no query object.
    return Pagination(None, page, per_page, total, rows)


def insert_ignore_on_conflict(table: Table, values: dict[str, Any]) -> None:
    """Insert a row built from `values` into `table` and commit.

    The statement does nothing if it conflicts with the table's primary
    key constraint, so `values` has to include the primary key.
    """
    statement = insert(table).values(**values)
    statement = statement.on_conflict_do_nothing(constraint=table.primary_key)

    session = db.session
    session.execute(statement)
    session.commit()


def upsert(
    table: Table, identifier: dict[str, Any], replacement: dict[str, Any]
) -> None:
    """Insert the record identified by `identifier`, or update it with
    the values in `replacement` if it already exists, then commit.
    """
    execute_upsert(table=table, identifier=identifier, replacement=replacement)
    db.session.commit()


def upsert_many(
    table: Table,
    identifiers: Iterable[dict[str, Any]],
    replacement: dict[str, Any],
) -> None:
    """Insert or update one record per entry in `identifiers`, each with
    the same `replacement` values.

    All UPSERTs are executed first; the session is committed once
    afterwards.
    """
    for record_identifier in identifiers:
        execute_upsert(
            table=table,
            identifier=record_identifier,
            replacement=replacement,
        )

    db.session.commit()


def execute_upsert(
    table: Table, identifier: dict[str, Any], replacement: dict[str, Any]
) -> None:
    """Run an UPSERT in the current session without committing it."""
    db.session.execute(_build_upsert_query(table, identifier, replacement))


def _build_upsert_query(
    table: Table, identifier: dict[str, Any], replacement: dict[str, Any]
) -> Insert:
    """Build the INSERT ... ON CONFLICT DO UPDATE statement for an UPSERT.

    The inserted row combines `identifier` and `replacement`; keys in
    `replacement` win over equal keys in `identifier`. On a conflict
    with the table's primary key constraint, the columns in
    `replacement` are updated instead.
    """
    row = {**identifier, **replacement}
    statement = insert(table).values(**row)

    return statement.on_conflict_do_update(
        constraint=table.primary_key,
        set_=replacement,
    )