~nicoco/matteridge

f4abac279381c24b28fd28ce9ec4997dd359d48e — nicoco 7 months ago a20283c
fix: multi-user usage

Our cache was poorly designed, and broke
in a multi-user context…
M matteridge/api.py => matteridge/api.py +27 -10
@@ 290,17 290,23 @@ class MattermostClient:

    async def get_user(self, user_id: str) -> User:
        user = await self._get_user(user_id)
        assert user.id
        assert user.username
        self._cache.add_user(self.base_url, user.id, user.username)
        return user

    async def get_user_by_username(self, username: str) -> User:
        user = await self._get_user_by_username(username)
        assert user.id
        assert user.username
        self._cache.add_user(self.base_url, user.id, user.username)
        return user

    async def get_users_by_ids(self, user_ids: list[str]) -> list[User]:
        users = await self._get_users_by_ids(user_ids)
        for u in users:
            assert u.id
            assert u.username
            self._cache.add_user(self.base_url, u.id, u.username)
        return users



@@ 321,7 327,9 @@ class MattermostClient:
    async def get_other_username_from_direct_channel_id(
        self, channel_id: str
    ) -> Optional[str]:
        cached = self._cache.get_by_direct_channel_id(self.base_url, channel_id)
        cached = self._cache.get_user_by_direct_channel_id(
            self.base_url, await self.mm_id, channel_id
        )
        if not cached:
            return None
        if not cached.username:


@@ 332,8 340,13 @@ class MattermostClient:
        assert channel.name
        for user_id in channel.name.split("__"):
            if user_id != await self.mm_id:
                self._cache.add_user(
                    self.base_url, user_id, direct_channel_id=channel.id
                cached_user = self._cache.get_by_user_id(self.base_url, user_id)
                if cached_user is None:
                    username = await self.get_username_by_user_id(user_id)
                    self._cache.add_user(self.base_url, user_id, username)
                assert channel.id
                self._cache.add_direct_channel(
                    self.base_url, await self.mm_id, user_id, channel.id
                )
                return user_id
        raise ValueError("This is not a direct channel", channel)


@@ 434,14 447,18 @@ class MattermostClient:
        return r.id

    async def get_direct_channel_id(self, user_id: str) -> str:
        cached = self._cache.get_by_user_id(self.base_url, user_id)
        if cached and cached.direct_channel_id:
            return cached.direct_channel_id
        cached = self._cache.get_direct_channel_id(
            self.base_url, await self.mm_id, user_id
        )
        if cached:
            return cached
        direct_channel = await self.create_direct_channel([await self.mm_id, user_id])
        if not direct_channel or not direct_channel.id:
            raise RuntimeError("Could not create direct channel")
        self._cache.add_user(
            self.base_url, user_id, direct_channel_id=direct_channel.id
        username = await self.get_username_by_user_id(user_id)
        self._cache.add_user(self.base_url, user_id, username)
        self._cache.add_direct_channel(
            self.base_url, await self.mm_id, user_id, direct_channel.id
        )
        return direct_channel.id



@@ 454,7 471,7 @@ class MattermostClient:
    async def get_latest_post_id_for_channel(
        self, channel_id: str
    ) -> Optional[Union[str, Unset]]:
        cache = self._cache.msg_id_get(channel_id)
        cache = self._cache.msg_id_get(await self.mm_id, channel_id)
        if cache is not None:
            return cache



@@ 464,7 481,7 @@ class MattermostClient:
        else:
            return None
        if post.id:
            self._cache.msg_id_store(channel_id, post.id)
            self._cache.msg_id_store(await self.mm_id, channel_id, post.id)
        return last.id

    async def get_posts_for_channel(

M matteridge/cache.py => matteridge/cache.py +83 -58
@@ 2,9 2,7 @@ import logging
import sqlite3
from os import PathLike
from pathlib import Path
from typing import NamedTuple, Optional, Union

from mattermost_api_reference_client.types import Unset
from typing import NamedTuple, Optional

SCHEMA = """
CREATE TABLE server(


@@ 15,36 13,43 @@ CREATE TABLE server(
CREATE TABLE user(
  id INTEGER PRIMARY KEY,
  server_id INTEGER NON NULL,
  user_id TEXT,
  user_id TEXT NON NULL,
  username TEXT NON NULL,

  direct_channel_id TEXT,
  username TEXT,
  FOREIGN KEY(server_id) REFERENCES server(id),
  UNIQUE (server_id, user_id),
  UNIQUE (server_id, direct_channel_id),
  UNIQUE (server_id, username)
);

CREATE TABLE direct_channel(
  id INTEGER PRIMARY KEY,
  server_id INTEGER NON NULL,
  me INTEGER NON NULL,
  them INTEGER NON NULL,
  direct_channel_id TEXT NON NULL,

  FOREIGN KEY(server_id) REFERENCES server(id),
  FOREIGN KEY(me) REFERENCES user(id),
  FOREIGN KEY(them) REFERENCES user(id),
  UNIQUE(me, them, direct_channel_id)
);

CREATE INDEX user_server_id ON user(server_id);
CREATE INDEX user_user_id ON user(user_id);
CREATE INDEX user_direct_channel_id ON user(direct_channel_id);
CREATE INDEX user_username ON user(username);
"""


class MattermostUser(NamedTuple):
    user_id: str
    username: Optional[str]
    direct_channel_id: Optional[str]
    username: str


def factory(
    _cursor: sqlite3.Cursor, row: tuple[str, Optional[str], Optional[str]]
) -> MattermostUser:
def user_factory(_cursor: sqlite3.Cursor, row: tuple[str, str]) -> MattermostUser:
    return MattermostUser(*row)


ORDER = "SELECT user_id, username, direct_channel_id FROM user WHERE "
ORDER_USER = "SELECT user_id, username FROM user WHERE "
SERVER = "(SELECT id FROM server WHERE server = ?)"




@@ 54,9 59,11 @@ class Cache:

        self.con = con = sqlite3.connect(filename)

        self.cur = cur = self.con.cursor()
        self.__last_msg_id = dict[str, str]()
        cur.row_factory = factory  # type:ignore
        self.user_cur = self.con.cursor()
        self.user_cur.row_factory = user_factory  # type:ignore

        # (slidge_user_id, channel_id) → message_id
        self.__last_msg_id = dict[tuple[str, str], str]()

        if exists:
            log.debug("File exists")


@@ 67,64 74,82 @@ class Cache:

    def add_server(self, server: str):
        with self.con:
            try:
                self.con.execute("INSERT INTO server(server) VALUES(?)", (server,))
            except sqlite3.IntegrityError:
                pass
            self.con.execute(
                "INSERT OR IGNORE INTO server(server) VALUES(?)", (server,)
            )

    def add_user(
        self,
        server: str,
        user_id: Optional[Union[str, Unset]] = None,
        username: Optional[Union[str, Unset]] = None,
        direct_channel_id: Optional[Union[str, Unset]] = None,
    ):
        keys = []
        values = []
        if user_id:
            keys.append("user_id")
            values.append(user_id)
        if username:
            keys.append("username")
            values.append(username)
        if direct_channel_id:
            values.append(direct_channel_id)
            keys.append("direct_channel_id")
        if not keys:
            raise TypeError("No info")
        question_marks = ",".join("?" * (len(keys)))
    def add_user(self, server: str, user_id: str, username: str):
        with self.con:
            conflict_str = ",".join(f"{key}=?" for key in keys)
            keys_str = ",".join(keys)
            query = (
                f"INSERT INTO user(server_id, {keys_str}) "
                f"VALUES({SERVER},{question_marks}) "
                f"ON CONFLICT DO UPDATE SET {conflict_str}"
                f"INSERT OR IGNORE INTO user(server_id, user_id, username) "
                f"VALUES({SERVER}, ?, ?)"
            )
            values = [server, *values, *values]
            values = [server, user_id, username]
            log.debug("Query: %s -> %s", query, values)
            self.con.execute(query, values)

    def __get(self, server: str, key: str, value: str) -> MattermostUser:
        query = ORDER + f"{key} = ? AND server_id = {SERVER}"
        query = ORDER_USER + f"{key} = ? AND server_id = {SERVER}"
        with self.con:
            res = self.cur.execute(query, (value, server))
            res = self.user_cur.execute(query, (value, server))
            return res.fetchone()

    def get_by_user_id(self, server: str, user_id: str):
    def get_by_user_id(self, server: str, user_id: str) -> MattermostUser:
        return self.__get(server, "user_id", user_id)

    def get_by_direct_channel_id(self, server: str, direct_channel_id: str):
        return self.__get(server, "direct_channel_id", direct_channel_id)
    def get_user_by_direct_channel_id(
        self, server: str, slidge_user_id: str, direct_channel_id: str
    ) -> MattermostUser:
        with self.con:
            res = self.user_cur.execute(
                "SELECT user_id, username FROM user WHERE "
                "id = (SELECT them FROM direct_channel WHERE direct_channel_id = ? "
                "AND server_id = (SELECT id FROM server WHERE server = ?) "
                "AND me = (SELECT id FROM user WHERE user_id = ?))",
                (direct_channel_id, server, slidge_user_id),
            )
            return res.fetchone()

    def get_by_username(self, server: str, username: str):
    def add_direct_channel(
        self,
        server: str,
        slidge_user_id: str,
        other_user_id: str,
        direct_channel_id: str,
    ):
        with self.con:
            self.con.execute(
                "INSERT OR IGNORE INTO direct_channel(server_id, me, them, direct_channel_id) "
                "VALUES ((SELECT id FROM server WHERE server = ?),"
                "(SELECT id FROM user WHERE user_id = ?),"
                "(SELECT id FROM user WHERE user_id = ?),"
                "?)",
                (server, slidge_user_id, other_user_id, direct_channel_id),
            )

    def get_direct_channel_id(
        self, server: str, slidge_user_id: str, other_user_id: str
    ) -> Optional[str]:
        with self.con:
            row = self.con.execute(
                "SELECT direct_channel_id FROM direct_channel WHERE "
                "server_id = (SELECT id FROM server where server = ?) "
                "AND me = (SELECT id FROM user where user_id = ?) "
                "AND them = (SELECT id FROM user where user_id = ?)",
                (server, slidge_user_id, other_user_id),
            ).fetchone()
            if row is None:
                return None
            return row[0]

    def get_by_username(self, server: str, username: str) -> MattermostUser:
        return self.__get(server, "username", username)

    def msg_id_get(self, channel_id: str) -> Optional[str]:
        return self.__last_msg_id.get(channel_id)
    def msg_id_get(self, slidge_user_id: str, channel_id: str) -> Optional[str]:
        return self.__last_msg_id.get((slidge_user_id, channel_id))

    def msg_id_store(self, channel_id: str, post_id: str):
        self.__last_msg_id[channel_id] = post_id
    def msg_id_store(self, slidge_user_id: str, channel_id: str, post_id: str):
        self.__last_msg_id[(slidge_user_id, channel_id)] = post_id


log = logging.getLogger(__name__)

M matteridge/gateway.py => matteridge/gateway.py +1 -1
@@ 46,7 46,7 @@ class Gateway(BaseGateway):

    def __init__(self):
        super().__init__()
        self.cache = Cache(global_config.HOME_DIR / "mm_client_cache.sqlite")
        self.cache = Cache(global_config.HOME_DIR / "mm_client_cache_v2.sqlite")
        if not logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.getLogger("httpx").setLevel(logging.WARNING)


M matteridge/group.py => matteridge/group.py +6 -2
@@ 126,8 126,12 @@ class MUC(LegacyMUC[str, str, Participant, str]):
        async for post in self.session.mm_client.get_posts_for_channel(
            self.legacy_id, before=oldest_message_id
        ):
            if i == 0 and not self.xmpp.cache.msg_id_get(self.legacy_id):
                self.xmpp.cache.msg_id_store(self.legacy_id, post.id)
            if i == 0 and not self.xmpp.cache.msg_id_get(
                await self.session.mm_client.mm_id, self.legacy_id
            ):
                self.xmpp.cache.msg_id_store(
                    await self.session.mm_client.mm_id, self.legacy_id, post.id
                )
            if now - datetime.fromtimestamp(post.create_at / 1000) > timedelta(
                days=global_config.MAM_MAX_DAYS
            ):

M matteridge/session.py => matteridge/session.py +3 -1
@@ 183,7 183,9 @@ class Session(BaseSession[str, Recipient]):
        self.log.debug("Post: %s", post)
        user_id = post.user_id

        self.xmpp.cache.msg_id_store(post.channel_id, post.id)
        self.xmpp.cache.msg_id_store(
            await self.mm_client.mm_id, post.channel_id, post.id
        )

        if await self.is_waiting_for_echo(post.id):
            return

M tests/test_cache.py => tests/test_cache.py +23 -10
@@ 1,19 1,32 @@
from matteridge.cache import Cache


def test_cache(tmp_path):
def test_user(tmp_path):
    c = Cache(tmp_path / "test.sql")

    c.add_server("example.com")
    c.add_server("example.com")

    c.add_user("example.com", "id", "name")
    c.add_user("example.com", "id", "name")

    assert c.get_by_user_id("example.com", "id").username == "name"
    assert c.get_by_username("example.com", "name").user_id == "id"


def test_direct_channel(tmp_path):
    c = Cache(tmp_path / "test.sql")

    c.add_server("example.com")
    c.add_user("example.com", "id", direct_channel_id="channel")
    assert c.get_by_user_id("example.com", "id").direct_channel_id == "channel"
    assert c.get_by_direct_channel_id("example.com", "channel").user_id == "id"
    assert c.get_by_username("example.com", "username") is None

    c.add_user("example.com", "id", "username")
    c.add_user("example.com", "me", "myname")
    c.add_user("example.com", "them", "theirname")
    c.add_direct_channel("example.com", "me", "them", "channel")
    c.add_direct_channel("example.com", "me", "them", "channel")

    assert c.get_by_username("example.com", "username").user_id == "id"
    assert c.get_by_user_id("example.com", "id").direct_channel_id == "channel"
    assert c.get_by_direct_channel_id("example.com", "channel").user_id == "id"
    assert (
        c.get_user_by_direct_channel_id("example.com", "me", "channel").username
        == "theirname"
    )

    assert c.get_by_user_id("example.com", "XXX") is None
    assert c.get_direct_channel_id("example.com", "me", "them") == "channel"