~sirn/fanboi2

ref: 834edf0edc5dd633c0ecea16231b6ed2d728476d fanboi2/fanboi2/tests/__init__.py -rw-r--r-- 3.1 KiB
834edf0eKridsada Thanabulpong Massive cleanup in preparation for 0.30 (#25) 3 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os

from sqlalchemy.engine import create_engine
from sqlalchemy.orm import sessionmaker

from ..models import make_history_event


# Connection URL for the test database: read from the
# POSTGRESQL_TEST_DATABASE environment variable when it is set, otherwise
# the local default below is used.
DATABASE_URL = os.environ.get(
    'POSTGRESQL_TEST_DATABASE',
    'postgresql://fanboi2:@localhost:5432/fanboi2_test')

# Module-level engine and session factory shared by the test mixins in this
# module. The factory is created unbound; ModelSessionMixin binds each
# session it creates to a per-test connection.
engine = create_engine(DATABASE_URL)
dbmaker = sessionmaker()
# NOTE(review): presumably attaches history-tracking hooks to sessions
# produced by dbmaker -- confirm against ..models.make_history_event.
make_history_event(dbmaker)


def make_cache_region(store=None):
    """Return a dogpile cache region configured with the memory backend.

    :param store: Dictionary handed to the backend as its ``cache_dict``
        argument. A new empty dictionary is used when omitted.
    :returns: The value returned by ``make_region().configure(...)``.
    """
    from dogpile.cache import make_region

    cache_dict = {} if store is None else store
    region = make_region()
    return region.configure(
        'dogpile.cache.memory',
        arguments={'cache_dict': cache_dict},
    )


def mock_service(request, mappings=None):
    """Attach a stub ``find_service`` callable to ``request``.

    The stub looks up its ``iface`` argument first and then its ``name``
    argument in ``mappings`` and returns the first match.

    :param request: Object that receives the ``find_service`` attribute.
    :param mappings: Mapping from interface objects and/or names to the
        service to return. Defaults to an empty mapping.
    :returns: The same ``request`` object, for convenient chaining.
    """
    # Build the empty default per call rather than sharing one mutable
    # default dictionary between every invocation.
    services = {} if mappings is None else mappings

    def _find_service(iface=None, name=None):
        # Interface takes precedence over name when both are registered.
        for key in (iface, name):
            if key in services:
                return services[key]
        return None

    request.find_service = _find_service
    return request


class DummyRedis(object):
    """In-memory stand-in for the handful of Redis commands used in tests.

    Values are kept in ``_store`` and expiry times in ``_expire``. Expiry is
    only recorded; keys are never actually removed when it elapses.
    """

    def __init__(self):
        self._store = {}
        self._expire = {}

    def get(self, key):
        """Return the value stored under ``key``, or ``None`` if absent."""
        return self._store.get(key)

    def set(self, key, value):
        """Store ``value`` under ``key``.

        Values exposing ``.encode`` (i.e. text) are stored as UTF-8 bytes;
        anything else is stored unchanged.
        """
        try:
            value = value.encode('utf-8')
        except AttributeError:
            # Not text-like: keep the value exactly as given.
            pass
        self._store[key] = value

    def setnx(self, key, value):
        """Store ``value`` under ``key`` only if ``key`` does not exist yet.

        Existence, not truthiness, decides whether to write, so an existing
        falsy value such as ``0`` or ``b''`` is left untouched.
        """
        if not self.exists(key):
            self.set(key, value)

    def exists(self, key):
        """Return ``True`` if ``key`` currently holds a value."""
        return key in self._store

    def expire(self, key, time):
        """Record ``time`` as the expiry of ``key``."""
        self._expire[key] = time

    def ttl(self, key):
        """Return the recorded expiry of ``key``, or ``0`` if none was set."""
        return self._expire.get(key, 0)

    def _reset(self):
        """Drop all stored values and expiry records."""
        self._store = {}
        self._expire = {}


class DummyAsyncResult(object):
    """Simple result object exposing ``id``, ``status``, ``state`` and ``get``.

    :param id_: Identifier returned by :attr:`id`.
    :param status: Status name; reported upper-cased by :attr:`status`.
    :param result: Value returned by :meth:`get`.
    """

    def __init__(self, id_, status, result=None):
        self._id, self._status, self._result = id_, status, result

    def get(self):
        """Return the result value given at construction time."""
        return self._result

    @property
    def id(self):
        """Identifier given at construction time."""
        return self._id

    @property
    def status(self):
        """Upper-cased status name."""
        return self._status.upper()

    @property
    def state(self):
        """Attribute of ``celery.states`` named after :attr:`status`."""
        # Imported lazily so celery is only needed when state is read.
        from celery import states
        state_name = self.status
        return getattr(states, state_name)


class ModelTransactionEngineMixin(object):
    """Test mixin that runs each test inside a transaction that is rolled back.

    ``setUp`` opens a connection on the module-level ``engine`` and begins a
    transaction on it, exposing them as ``self.connection`` and ``self.tx``.
    ``tearDown`` rolls the transaction back and closes the connection.
    """

    def setUp(self):
        """Open the per-test connection and begin its transaction."""
        super(ModelTransactionEngineMixin, self).setUp()
        self.connection = engine.connect()
        self.tx = self.connection.begin()

    def tearDown(self):
        """Roll back the per-test transaction and close the connection.

        The cleanup is nested in ``try``/``finally`` so that the connection
        is always closed, even when the parent ``tearDown`` or the rollback
        raises; any such exception still propagates to the caller.
        """
        try:
            super(ModelTransactionEngineMixin, self).tearDown()
        finally:
            try:
                self.tx.rollback()
            finally:
                self.connection.close()


class ModelSessionMixin(ModelTransactionEngineMixin, object):
    """Test mixin that provides ``self.dbsession`` on the per-test connection.

    Builds on ModelTransactionEngineMixin: the session is bound to
    ``self.connection``, on which the parent ``setUp`` has already begun a
    transaction.
    """

    def setUp(self):
        """Create the session, create the tables and open a nested transaction."""
        super(ModelSessionMixin, self).setUp()
        from sqlalchemy import event
        from ..models import Base
        # Bind both the session and the metadata to the per-test connection
        # so that create_all() and all session work are issued on it.
        self.dbsession = dbmaker(bind=self.connection)
        Base.metadata.bind = self.connection
        Base.metadata.create_all()
        self.dbsession.begin_nested()

        # Whenever a nested transaction whose parent is not itself nested
        # ends, expire all loaded state and open a new nested transaction.
        @event.listens_for(self.dbsession, "after_transaction_end")
        def restart_savepoint(dbsession, transaction):
            if transaction.nested and not transaction._parent.nested:
                dbsession.expire_all()
                dbsession.begin_nested()

    def tearDown(self):
        """Run the parent teardown, then close the session."""
        # NOTE(review): the parent tearDown rolls back and closes the
        # connection before the session is closed here. SQLAlchemy's
        # documented external-transaction recipe closes the session first;
        # confirm this ordering is intentional.
        super(ModelSessionMixin, self).tearDown()
        self.dbsession.close()

    def _make(self, model_obj):
        """Add ``model_obj`` to the session and return it (no flush or commit)."""
        self.dbsession.add(model_obj)
        return model_obj