M matteridge/api.py => matteridge/api.py +27 -10
@@ 290,17 290,23 @@ class MattermostClient:
async def get_user(self, user_id: str) -> User:
user = await self._get_user(user_id)
+ assert user.id
+ assert user.username
self._cache.add_user(self.base_url, user.id, user.username)
return user
async def get_user_by_username(self, username: str) -> User:
user = await self._get_user_by_username(username)
+ assert user.id
+ assert user.username
self._cache.add_user(self.base_url, user.id, user.username)
return user
async def get_users_by_ids(self, user_ids: list[str]) -> list[User]:
users = await self._get_users_by_ids(user_ids)
for u in users:
+ assert u.id
+ assert u.username
self._cache.add_user(self.base_url, u.id, u.username)
return users
@@ 321,7 327,9 @@ class MattermostClient:
async def get_other_username_from_direct_channel_id(
self, channel_id: str
) -> Optional[str]:
- cached = self._cache.get_by_direct_channel_id(self.base_url, channel_id)
+ cached = self._cache.get_user_by_direct_channel_id(
+ self.base_url, await self.mm_id, channel_id
+ )
if not cached:
return None
if not cached.username:
@@ 332,8 340,13 @@ class MattermostClient:
assert channel.name
for user_id in channel.name.split("__"):
if user_id != await self.mm_id:
- self._cache.add_user(
- self.base_url, user_id, direct_channel_id=channel.id
+ cached_user = self._cache.get_by_user_id(self.base_url, user_id)
+ if cached_user is None:
+ username = await self.get_username_by_user_id(user_id)
+ self._cache.add_user(self.base_url, user_id, username)
+ assert channel.id
+ self._cache.add_direct_channel(
+ self.base_url, await self.mm_id, user_id, channel.id
)
return user_id
raise ValueError("This is not a direct channel", channel)
@@ 434,14 447,18 @@ class MattermostClient:
return r.id
async def get_direct_channel_id(self, user_id: str) -> str:
- cached = self._cache.get_by_user_id(self.base_url, user_id)
- if cached and cached.direct_channel_id:
- return cached.direct_channel_id
+ cached = self._cache.get_direct_channel_id(
+ self.base_url, await self.mm_id, user_id
+ )
+ if cached:
+ return cached
direct_channel = await self.create_direct_channel([await self.mm_id, user_id])
if not direct_channel or not direct_channel.id:
raise RuntimeError("Could not create direct channel")
- self._cache.add_user(
- self.base_url, user_id, direct_channel_id=direct_channel.id
+ username = await self.get_username_by_user_id(user_id)
+ self._cache.add_user(self.base_url, user_id, username)
+ self._cache.add_direct_channel(
+ self.base_url, await self.mm_id, user_id, direct_channel.id
)
return direct_channel.id
@@ 454,7 471,7 @@ class MattermostClient:
async def get_latest_post_id_for_channel(
self, channel_id: str
) -> Optional[Union[str, Unset]]:
- cache = self._cache.msg_id_get(channel_id)
+ cache = self._cache.msg_id_get(await self.mm_id, channel_id)
if cache is not None:
return cache
@@ 464,7 481,7 @@ class MattermostClient:
else:
return None
if post.id:
- self._cache.msg_id_store(channel_id, post.id)
+ self._cache.msg_id_store(await self.mm_id, channel_id, post.id)
return last.id
async def get_posts_for_channel(
M matteridge/cache.py => matteridge/cache.py +83 -58
@@ 2,9 2,7 @@ import logging
import sqlite3
from os import PathLike
from pathlib import Path
-from typing import NamedTuple, Optional, Union
-
-from mattermost_api_reference_client.types import Unset
+from typing import NamedTuple, Optional
SCHEMA = """
CREATE TABLE server(
@@ 15,36 13,43 @@ CREATE TABLE server(
CREATE TABLE user(
id INTEGER PRIMARY KEY,
server_id INTEGER NON NULL,
- user_id TEXT,
+ user_id TEXT NON NULL,
+ username TEXT NON NULL,
- direct_channel_id TEXT,
- username TEXT,
FOREIGN KEY(server_id) REFERENCES server(id),
UNIQUE (server_id, user_id),
- UNIQUE (server_id, direct_channel_id),
UNIQUE (server_id, username)
);
+CREATE TABLE direct_channel(
+ id INTEGER PRIMARY KEY,
+ server_id INTEGER NON NULL,
+ me INTEGER NON NULL,
+ them INTEGER NON NULL,
+ direct_channel_id TEXT NON NULL,
+
+ FOREIGN KEY(server_id) REFERENCES server(id),
+ FOREIGN KEY(me) REFERENCES user(id),
+ FOREIGN KEY(them) REFERENCES user(id),
+ UNIQUE(me, them, direct_channel_id)
+);
+
CREATE INDEX user_server_id ON user(server_id);
CREATE INDEX user_user_id ON user(user_id);
-CREATE INDEX user_direct_channel_id ON user(direct_channel_id);
CREATE INDEX user_username ON user(username);
"""
class MattermostUser(NamedTuple):
user_id: str
- username: Optional[str]
- direct_channel_id: Optional[str]
+ username: str
-def factory(
- _cursor: sqlite3.Cursor, row: tuple[str, Optional[str], Optional[str]]
-) -> MattermostUser:
+def user_factory(_cursor: sqlite3.Cursor, row: tuple[str, str]) -> MattermostUser:
return MattermostUser(*row)
-ORDER = "SELECT user_id, username, direct_channel_id FROM user WHERE "
+ORDER_USER = "SELECT user_id, username FROM user WHERE "
SERVER = "(SELECT id FROM server WHERE server = ?)"
@@ 54,9 59,11 @@ class Cache:
self.con = con = sqlite3.connect(filename)
- self.cur = cur = self.con.cursor()
- self.__last_msg_id = dict[str, str]()
- cur.row_factory = factory # type:ignore
+ self.user_cur = self.con.cursor()
+ self.user_cur.row_factory = user_factory # type:ignore
+
+ # (slidge_user_id, channel_id) → message_id
+ self.__last_msg_id = dict[tuple[str, str], str]()
if exists:
log.debug("File exists")
@@ 67,64 74,82 @@ class Cache:
def add_server(self, server: str):
with self.con:
- try:
- self.con.execute("INSERT INTO server(server) VALUES(?)", (server,))
- except sqlite3.IntegrityError:
- pass
+ self.con.execute(
+ "INSERT OR IGNORE INTO server(server) VALUES(?)", (server,)
+ )
- def add_user(
- self,
- server: str,
- user_id: Optional[Union[str, Unset]] = None,
- username: Optional[Union[str, Unset]] = None,
- direct_channel_id: Optional[Union[str, Unset]] = None,
- ):
- keys = []
- values = []
- if user_id:
- keys.append("user_id")
- values.append(user_id)
- if username:
- keys.append("username")
- values.append(username)
- if direct_channel_id:
- values.append(direct_channel_id)
- keys.append("direct_channel_id")
- if not keys:
- raise TypeError("No info")
- question_marks = ",".join("?" * (len(keys)))
+ def add_user(self, server: str, user_id: str, username: str):
with self.con:
- conflict_str = ",".join(f"{key}=?" for key in keys)
- keys_str = ",".join(keys)
query = (
- f"INSERT INTO user(server_id, {keys_str}) "
- f"VALUES({SERVER},{question_marks}) "
- f"ON CONFLICT DO UPDATE SET {conflict_str}"
+ f"INSERT OR IGNORE INTO user(server_id, user_id, username) "
+ f"VALUES({SERVER}, ?, ?)"
)
- values = [server, *values, *values]
+ values = [server, user_id, username]
log.debug("Query: %s -> %s", query, values)
self.con.execute(query, values)
def __get(self, server: str, key: str, value: str) -> MattermostUser:
- query = ORDER + f"{key} = ? AND server_id = {SERVER}"
+ query = ORDER_USER + f"{key} = ? AND server_id = {SERVER}"
with self.con:
- res = self.cur.execute(query, (value, server))
+ res = self.user_cur.execute(query, (value, server))
return res.fetchone()
- def get_by_user_id(self, server: str, user_id: str):
+ def get_by_user_id(self, server: str, user_id: str) -> MattermostUser:
return self.__get(server, "user_id", user_id)
- def get_by_direct_channel_id(self, server: str, direct_channel_id: str):
- return self.__get(server, "direct_channel_id", direct_channel_id)
+ def get_user_by_direct_channel_id(
+ self, server: str, slidge_user_id: str, direct_channel_id: str
+ ) -> MattermostUser:
+ with self.con:
+ res = self.user_cur.execute(
+ "SELECT user_id, username FROM user WHERE "
+ "id = (SELECT them FROM direct_channel WHERE direct_channel_id = ? "
+ "AND server_id = (SELECT id FROM server WHERE server = ?) "
+ "AND me = (SELECT id FROM user WHERE user_id = ?))",
+ (direct_channel_id, server, slidge_user_id),
+ )
+ return res.fetchone()
- def get_by_username(self, server: str, username: str):
+ def add_direct_channel(
+ self,
+ server: str,
+ slidge_user_id: str,
+ other_user_id: str,
+ direct_channel_id: str,
+ ):
+ with self.con:
+ self.con.execute(
+ "INSERT OR IGNORE INTO direct_channel(server_id, me, them, direct_channel_id) "
+ "VALUES ((SELECT id FROM server WHERE server = ?),"
+ "(SELECT id FROM user WHERE user_id = ?),"
+ "(SELECT id FROM user WHERE user_id = ?),"
+ "?)",
+ (server, slidge_user_id, other_user_id, direct_channel_id),
+ )
+
+ def get_direct_channel_id(
+ self, server: str, slidge_user_id: str, other_user_id: str
+ ) -> Optional[str]:
+ with self.con:
+ row = self.con.execute(
+ "SELECT direct_channel_id FROM direct_channel WHERE "
+ "server_id = (SELECT id FROM server where server = ?) "
+ "AND me = (SELECT id FROM user where user_id = ?) "
+ "AND them = (SELECT id FROM user where user_id = ?)",
+ (server, slidge_user_id, other_user_id),
+ ).fetchone()
+ if row is None:
+ return None
+ return row[0]
+
+ def get_by_username(self, server: str, username: str) -> MattermostUser:
return self.__get(server, "username", username)
- def msg_id_get(self, channel_id: str) -> Optional[str]:
- return self.__last_msg_id.get(channel_id)
+ def msg_id_get(self, slidge_user_id: str, channel_id: str) -> Optional[str]:
+ return self.__last_msg_id.get((slidge_user_id, channel_id))
- def msg_id_store(self, channel_id: str, post_id: str):
- self.__last_msg_id[channel_id] = post_id
+ def msg_id_store(self, slidge_user_id: str, channel_id: str, post_id: str):
+ self.__last_msg_id[(slidge_user_id, channel_id)] = post_id
log = logging.getLogger(__name__)
M matteridge/gateway.py => matteridge/gateway.py +1 -1
@@ 46,7 46,7 @@ class Gateway(BaseGateway):
def __init__(self):
super().__init__()
- self.cache = Cache(global_config.HOME_DIR / "mm_client_cache.sqlite")
+ self.cache = Cache(global_config.HOME_DIR / "mm_client_cache_v2.sqlite")
if not logging.getLogger().isEnabledFor(logging.DEBUG):
logging.getLogger("httpx").setLevel(logging.WARNING)
M matteridge/group.py => matteridge/group.py +6 -2
@@ 126,8 126,12 @@ class MUC(LegacyMUC[str, str, Participant, str]):
async for post in self.session.mm_client.get_posts_for_channel(
self.legacy_id, before=oldest_message_id
):
- if i == 0 and not self.xmpp.cache.msg_id_get(self.legacy_id):
- self.xmpp.cache.msg_id_store(self.legacy_id, post.id)
+ if i == 0 and not self.xmpp.cache.msg_id_get(
+ await self.session.mm_client.mm_id, self.legacy_id
+ ):
+ self.xmpp.cache.msg_id_store(
+ await self.session.mm_client.mm_id, self.legacy_id, post.id
+ )
if now - datetime.fromtimestamp(post.create_at / 1000) > timedelta(
days=global_config.MAM_MAX_DAYS
):
M matteridge/session.py => matteridge/session.py +3 -1
@@ 183,7 183,9 @@ class Session(BaseSession[str, Recipient]):
self.log.debug("Post: %s", post)
user_id = post.user_id
- self.xmpp.cache.msg_id_store(post.channel_id, post.id)
+ self.xmpp.cache.msg_id_store(
+ await self.mm_client.mm_id, post.channel_id, post.id
+ )
if await self.is_waiting_for_echo(post.id):
return
M tests/test_cache.py => tests/test_cache.py +23 -10
@@ 1,19 1,32 @@
from matteridge.cache import Cache
-def test_cache(tmp_path):
+def test_user(tmp_path):
+ c = Cache(tmp_path / "test.sql")
+
+ c.add_server("example.com")
+ c.add_server("example.com")
+
+ c.add_user("example.com", "id", "name")
+ c.add_user("example.com", "id", "name")
+
+ assert c.get_by_user_id("example.com", "id").username == "name"
+ assert c.get_by_username("example.com", "name").user_id == "id"
+
+
+def test_direct_channel(tmp_path):
c = Cache(tmp_path / "test.sql")
c.add_server("example.com")
- c.add_user("example.com", "id", direct_channel_id="channel")
- assert c.get_by_user_id("example.com", "id").direct_channel_id == "channel"
- assert c.get_by_direct_channel_id("example.com", "channel").user_id == "id"
- assert c.get_by_username("example.com", "username") is None
- c.add_user("example.com", "id", "username")
+ c.add_user("example.com", "me", "myname")
+ c.add_user("example.com", "them", "theirname")
+ c.add_direct_channel("example.com", "me", "them", "channel")
+ c.add_direct_channel("example.com", "me", "them", "channel")
- assert c.get_by_username("example.com", "username").user_id == "id"
- assert c.get_by_user_id("example.com", "id").direct_channel_id == "channel"
- assert c.get_by_direct_channel_id("example.com", "channel").user_id == "id"
+ assert (
+ c.get_user_by_direct_channel_id("example.com", "me", "channel").username
+ == "theirname"
+ )
- assert c.get_by_user_id("example.com", "XXX") is None
+ assert c.get_direct_channel_id("example.com", "me", "them") == "channel"