@@ 1,25 1,39 @@
import asyncio
import re
-from typing import TYPE_CHECKING, Any, Optional, Union
-
-from slidge import BaseSession, LegacyMUC
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Concatenate,
+ Optional,
+ ParamSpec,
+ TypeVar,
+ Union,
+)
+
+from slidge import BaseSession
from slidge.util.types import PseudoPresenceShow, ResourceDict
from . import events
-from .api import get_client_from_registration_form
+from .api import MattermostException, get_client_from_registration_form
from .util import emojize_single
from .websocket import Websocket
if TYPE_CHECKING:
from .contact import Contact, Roster
from .gateway import Gateway
- from .group import Bookmarks
+ from .group import MUC, Bookmarks
-Recipient = Union["Contact", "LegacyMUC"]
+Recipient = Union["Contact", "MUC"]
+P = ParamSpec("P")
+T = TypeVar("T", bound=Awaitable)
-def lock(meth):
+def lock(
+ meth: Callable[Concatenate["Session", P], T]
+) -> Callable[Concatenate["Session", P], T]:
async def wrapped(self, *a, **k):
async with self.send_lock:
return await meth(self, *a, **k)
@@ 27,6 41,24 @@ def lock(meth):
return wrapped
+def catch_expired_session(
+ meth: Callable[Concatenate["Session", P], T]
+) -> Callable[Concatenate["Session", P], T]:
+ async def wrapped(self, *a, **k):
+ try:
+ return await meth(self, *a, **k)
+ except MattermostException as e:
+ if e.is_expired_session:
+ await self.logout()
+ await self.renew_token()
+ await self.login()
+ return await meth(self, *a, **k)
+ else:
+ raise
+
+ return wrapped
+
+
class Session(BaseSession[str, Recipient]):
contacts: "Roster"
bookmarks: "Bookmarks"
@@ 73,8 105,28 @@ class Session(BaseSession[str, Recipient]):
else:
return True
+ async def renew_token(self):
+ self.update_token(
+ await self.input(
+ "Your mattermost token has expired, please provide a new one."
+ )
+ )
+
+ def update_token(self, token: str):
+ self.user.registration_form["token"] = token
+ self.user.commit()
+ self.mm_client = get_client_from_registration_form(
+ self.user.registration_form, self.xmpp.cache
+ )
+
async def login(self):
- await self.mm_client.login()
+ try:
+ await self.mm_client.login()
+ except MattermostException as e:
+ if not e.is_expired_session:
+ raise
+ await self.renew_token()
+ return await self.mm_client.login()
self.contacts.user_legacy_id = (await self.mm_client.me).username
t1 = self._ws_task = asyncio.create_task(self.ws.connect(self.on_mm_event))
t2 = self._update_status_task = asyncio.create_task(
@@ 255,7 307,10 @@ class Session(BaseSession[str, Recipient]):
await muc.get_participant_by_mm_user_id(event.user_id)
async def logout(self):
- pass
+ if self._ws_task is not None:
+ self._ws_task.cancel()
+ if self._update_status_task is not None:
+ self._update_status_task.cancel()
@staticmethod
async def __get_channel_id(chat: Recipient):
@@ 264,6 319,7 @@ class Session(BaseSession[str, Recipient]):
else:
return await chat.direct_channel_id() # type:ignore
+ @catch_expired_session
@lock
async def send_text(self, chat: Recipient, text: str, thread=None, **k):
channel_id = await self.__get_channel_id(chat)
@@ 271,6 327,7 @@ class Session(BaseSession[str, Recipient]):
self.messages_waiting_for_echo.add(msg_id)
return msg_id
+ @catch_expired_session
@lock
async def send_file(
self, chat: Recipient, url: str, http_response, thread=None, **k
@@ 283,37 340,46 @@ class Session(BaseSession[str, Recipient]):
self.messages_waiting_for_echo.add(msg_id)
return msg_id
+ @catch_expired_session
async def active(self, c: Recipient, thread=None):
pass
+ @catch_expired_session
async def inactive(self, c: Recipient, thread=None):
pass
+ @catch_expired_session
async def composing(self, c: Recipient, thread=None):
channel_id = await self.__get_channel_id(c)
await self.ws.user_typing(channel_id) # type:ignore
+ @catch_expired_session
async def paused(self, c: Recipient, thread=None):
# no equivalent in MM, seems to have an automatic timeout in clients
pass
+ @catch_expired_session
async def displayed(self, c: Recipient, legacy_msg_id: Any, thread=None):
channel_id = await self.__get_channel_id(c)
f = self.view_events[channel_id] = asyncio.Event()
await self.mm_client.view_channel(channel_id)
await f.wait()
+ @catch_expired_session
@lock
async def correct(self, c: Recipient, text: str, legacy_msg_id: Any, thread=None):
await self.mm_client.update_post(legacy_msg_id, text)
self.messages_waiting_for_echo.add(legacy_msg_id)
+ @catch_expired_session
async def search(self, form_values: dict[str, str]):
pass
+ @catch_expired_session
async def retract(self, c: Recipient, legacy_msg_id: Any, thread=None):
await self.mm_client.delete_post(legacy_msg_id)
+ @catch_expired_session
async def react(
self, c: Recipient, legacy_msg_id: Any, emojis: list[str], thread=None
):
@@ 338,6 404,7 @@ class Session(BaseSession[str, Recipient]):
if i == user_id
}
+ @catch_expired_session
async def presence(
self,
resource: str,