~nicoco/matteridge

83b8cf96d6009cb16e0064ac15479788318d3d1a — nicoco 1 year, 13 days ago 592064e
feat: prompt for new token when session expires
1 files changed, 76 insertions(+), 9 deletions(-)

M matteridge/session.py
M matteridge/session.py => matteridge/session.py +76 -9
@@ 1,25 1,39 @@
import asyncio
import re
from typing import TYPE_CHECKING, Any, Optional, Union

from slidge import BaseSession, LegacyMUC
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Concatenate,
    Optional,
    ParamSpec,
    TypeVar,
    Union,
)

from slidge import BaseSession
from slidge.util.types import PseudoPresenceShow, ResourceDict

from . import events
from .api import get_client_from_registration_form
from .api import MattermostException, get_client_from_registration_form
from .util import emojize_single
from .websocket import Websocket

if TYPE_CHECKING:
    from .contact import Contact, Roster
    from .gateway import Gateway
    from .group import Bookmarks
    from .group import MUC, Bookmarks


Recipient = Union["Contact", "LegacyMUC"]
Recipient = Union["Contact", "MUC"]
P = ParamSpec("P")
T = TypeVar("T", bound=Awaitable)


def lock(meth):
def lock(
    meth: Callable[Concatenate["Session", P], T]
) -> Callable[Concatenate["Session", P], T]:
    async def wrapped(self, *a, **k):
        async with self.send_lock:
            return await meth(self, *a, **k)


@@ 27,6 41,24 @@ def lock(meth):
    return wrapped


def catch_expired_session(
    meth: Callable[Concatenate["Session", P], T]
) -> Callable[Concatenate["Session", P], T]:
    async def wrapped(self, *a, **k):
        try:
            return await meth(self, *a, **k)
        except MattermostException as e:
            if e.is_expired_session:
                await self.logout()
                await self.renew_token()
                await self.login()
                return await meth(self, *a, **k)
            else:
                raise

    return wrapped


class Session(BaseSession[str, Recipient]):
    contacts: "Roster"
    bookmarks: "Bookmarks"


@@ 73,8 105,28 @@ class Session(BaseSession[str, Recipient]):
        else:
            return True

    async def renew_token(self):
        self.update_token(
            await self.input(
                "Your mattermost token has expired, please provide a new one."
            )
        )

    def update_token(self, token: str):
        self.user.registration_form["token"] = token
        self.user.commit()
        self.mm_client = get_client_from_registration_form(
            self.user.registration_form, self.xmpp.cache
        )

    async def login(self):
        await self.mm_client.login()
        try:
            await self.mm_client.login()
        except MattermostException as e:
            if not e.is_expired_session:
                raise
            await self.renew_token()
            return await self.mm_client.login()
        self.contacts.user_legacy_id = (await self.mm_client.me).username
        t1 = self._ws_task = asyncio.create_task(self.ws.connect(self.on_mm_event))
        t2 = self._update_status_task = asyncio.create_task(


@@ 255,7 307,10 @@ class Session(BaseSession[str, Recipient]):
        await muc.get_participant_by_mm_user_id(event.user_id)

    async def logout(self):
        pass
        if self._ws_task is not None:
            self._ws_task.cancel()
        if self._update_status_task is not None:
            self._update_status_task.cancel()

    @staticmethod
    async def __get_channel_id(chat: Recipient):


@@ 264,6 319,7 @@ class Session(BaseSession[str, Recipient]):
        else:
            return await chat.direct_channel_id()  # type:ignore

    @catch_expired_session
    @lock
    async def send_text(self, chat: Recipient, text: str, thread=None, **k):
        channel_id = await self.__get_channel_id(chat)


@@ 271,6 327,7 @@ class Session(BaseSession[str, Recipient]):
        self.messages_waiting_for_echo.add(msg_id)
        return msg_id

    @catch_expired_session
    @lock
    async def send_file(
        self, chat: Recipient, url: str, http_response, thread=None, **k


@@ 283,37 340,46 @@ class Session(BaseSession[str, Recipient]):
        self.messages_waiting_for_echo.add(msg_id)
        return msg_id

    @catch_expired_session
    async def active(self, c: Recipient, thread=None):
        pass

    @catch_expired_session
    async def inactive(self, c: Recipient, thread=None):
        pass

    @catch_expired_session
    async def composing(self, c: Recipient, thread=None):
        channel_id = await self.__get_channel_id(c)
        await self.ws.user_typing(channel_id)  # type:ignore

    @catch_expired_session
    async def paused(self, c: Recipient, thread=None):
        # no equivalent in MM, seems to have an automatic timeout in clients
        pass

    @catch_expired_session
    async def displayed(self, c: Recipient, legacy_msg_id: Any, thread=None):
        channel_id = await self.__get_channel_id(c)
        f = self.view_events[channel_id] = asyncio.Event()
        await self.mm_client.view_channel(channel_id)
        await f.wait()

    @catch_expired_session
    @lock
    async def correct(self, c: Recipient, text: str, legacy_msg_id: Any, thread=None):
        await self.mm_client.update_post(legacy_msg_id, text)
        self.messages_waiting_for_echo.add(legacy_msg_id)

    @catch_expired_session
    async def search(self, form_values: dict[str, str]):
        pass

    @catch_expired_session
    async def retract(self, c: Recipient, legacy_msg_id: Any, thread=None):
        await self.mm_client.delete_post(legacy_msg_id)

    @catch_expired_session
    async def react(
        self, c: Recipient, legacy_msg_id: Any, emojis: list[str], thread=None
    ):


@@ 338,6 404,7 @@ class Session(BaseSession[str, Recipient]):
            if i == user_id
        }

    @catch_expired_session
    async def presence(
        self,
        resource: str,