~naglis/aio_mpv_ipc

b021c37da812a501872253ef4bdc7bff85af1345 — Naglis Jonaitis 9 months ago 1b6c840
Add a lock around writer
2 files changed, 35 insertions(+), 15 deletions(-)

M aio_mpv_ipc/_client.py
M tests/test_client.py
M aio_mpv_ipc/_client.py => aio_mpv_ipc/_client.py +21 -15
@@ 5,6 5,7 @@ import json
import logging
import pathlib
import typing
import weakref

from aio_mpv_ipc.exceptions import (
    AioMpvIPCException,


@@ 83,6 84,7 @@ class MpvClient:
    ) -> None:
        self._socket_path = socket_path
        self._reader, self._writer = None, None
        self._writer_lock = asyncio.Lock()
        self._max_connect_attempts = max_connect_attempts
        self._connect_sleep_timeout = connect_sleep_timeout
        self._event_queue_timeout = event_queue_timeout


@@ 91,10 93,12 @@ class MpvClient:

        self._request_counter = 0

        self._futures: typing.Dict[str, asyncio.Future] = {}
        self._futures: "weakref.WeakValueDictionary[str, asyncio.Future]" = (
            weakref.WeakValueDictionary()
        )
        self._event_subscribers: typing.Dict[
            str, typing.Set[asyncio.Queue]
        ] = collections.defaultdict(set)
            str, "weakref.WeakSet[asyncio.Queue]"
        ] = collections.defaultdict(weakref.WeakSet)

    @property
    def socket_path(self) -> pathlib.Path:


@@ 132,10 136,6 @@ class MpvClient:
        return self

    async def __aexit__(self, _, __, ___):
        if self._writer is not None:
            logger.debug("Closing mpv JSON IPC socket writer")
            self._writer.close()

        if self._poll_task is not None:
            logger.debug("Cancelling mpv polling task")
            self._poll_task.cancel()


@@ 147,8 147,12 @@ class MpvClient:

        if self._futures:
            logger.debug("Cancelling pending futures")
            for fut in self._futures.values():
                fut.cancel()
            for request_id in list(self._futures):
                self._futures.pop(request_id).cancel()

        if self._writer is not None:
            logger.debug("Closing mpv JSON IPC socket writer")
            self._writer.close()

    async def _poll(self):
        while not self._reader.at_eof():


@@ 215,13 219,15 @@ class MpvClient:
        logger.debug("mpv IPC payload (nowait=%r): %s", nowait, payload)

        data = self._json_dumps(payload)
        self._writer.write(data.encode(encoding=IPC_ENCODING))
        self._writer.write(NEWLINE)

        try:
            await self._writer.drain()
        except ConnectionError as exc:
            raise IPCConnectionException("IPC connection to mpv failed") from exc
        async with self._writer_lock:
            self._writer.write(data.encode(encoding=IPC_ENCODING))
            self._writer.write(NEWLINE)

            try:
                await self._writer.drain()
            except ConnectionError as exc:
                raise IPCConnectionException("IPC connection to mpv failed") from exc

        if not nowait:
            return await fut

M tests/test_client.py => tests/test_client.py +14 -0
@@ 193,3 193,17 @@ async def test_ipc_after_mpv_exit():
                match=r"IPC connection to mpv failed",
            ):
                await client.ipc("client_name")


@pytest.mark.mpv
@pytest.mark.asyncio
async def test_client_context_manager_cleanup():
    with tmp_socket() as socket_file:
        mpv = aio_mpv_ipc.start_mpv(
            socket_file.name, *MPV_ARGS, terminate=True, mpv_path=MPV_PATH
        )
        client = aio_mpv_ipc.MpvClient(socket_file.name)
        async with mpv as mpv, client:
            pass

        assert client._poll_task is None