~mser/pkg.mser.at

3a46287d114c9601184b7fd69888709acf22d923 — cryzed 2 years ago f385786
Add logging to unshare-net
2 files changed, 58 insertions(+), 25 deletions(-)

M packages/unshare-net/PKGBUILD
M packages/unshare-net/unshare-net
M packages/unshare-net/PKGBUILD => packages/unshare-net/PKGBUILD +8 -8
@@ 1,14 1,14 @@
# Maintainer: Michael Serajnik <m at mser dot at>
pkgname=unshare-net
pkgver=11
pkgver=12
pkgrel=1
pkgdesc="Selectively whitelist traffic to specified IPs and domains for target applications"
arch=("any")
url="https://git.sr.ht/~mser/pkg.mser.at/tree/master/item/packages/unshare-net"
license=("AGPL3")
depends=("python")
source=("unshare-net")
sha512sums=('482312f83c3c71bf83893a608a4dd54092853c6fbe7758b74ead0ca67645922b8e1b6829c7effbf58633e887f23d0d230ee809b58f9f1e96d8a1ad10d6118922')
pkgdesc='Selectively whitelist traffic to specified IPs and domains for target applications'
arch=('any')
url='https://git.sr.ht/~mser/pkg.mser.at/tree/master/item/packages/unshare-net'
license=('AGPL3')
depends=('python' 'python-appdirs')
source=('unshare-net')
sha512sums=('69b1af86905ada0a17f3a9fd0ab03d7fa75f9415b56ecaa3a144f2741b0533216d7bd08c8b9428e94f59eff3d718186f212f7c4a8702eb33d90a6e8715fddb17')

package() {
  install -D --mode 755 "${srcdir}/unshare-net" --target-directory "${pkgdir}/usr/bin"

M packages/unshare-net/unshare-net => packages/unshare-net/unshare-net +50 -17
@@ 18,7 18,10 @@
import argparse
import collections.abc as abc
import enum
import gzip
import hashlib
import logging
import logging.handlers
import os
import pathlib
import shlex


@@ 30,36 33,37 @@ import tempfile
import time
import typing as T

import appdirs

APPLICATION_NAME = "unshare-net"
LOGGING_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
MOUNTS_PATH = pathlib.Path("/proc/self/mounts")
IPTABLES_MAX_CHAIN_NAME_LENGTH = 28

logger = logging.getLogger(APPLICATION_NAME)


class ExitCode(enum.IntEnum):
    SUCCESS = 0
    FAILURE = 1


def stderr(*args: T.Any, **kwargs: T.Any) -> None:
    kwargs["file"] = sys.stderr
    print(*args, **kwargs)


def cgroup_create(name: str, mount_path: pathlib.Path) -> pathlib.Path:
    path = mount_path / name
    stderr(f"creating: {str(path)!r}")
    logger.debug("creating: %r", str(path))
    path.mkdir()
    return path


def cgroup_remove(name: str, mount_path: pathlib.Path) -> None:
    path = mount_path / name
    stderr(f"removing {str(path)!r}")
    logger.debug("removing %r", str(path))
    path.rmdir()


def cgroup_add_process(name: str, pid: int, mount_path: pathlib.Path) -> None:
    path = mount_path / name / "cgroup.procs"
    stderr(f"{pid} -> {str(path)!r}")
    logger.debug("%d -> %r", pid, str(path))
    path.write_text(str(pid), encoding="ascii")




@@ 72,7 76,7 @@ def shell_escape(command: abc.Iterable[str]) -> str:


def run_command(command: abc.Sequence[str], **run_kwargs: T.Any) -> subprocess.CompletedProcess:
    stderr(shell_escape(command))
    logger.debug(shell_escape(command))
    return subprocess.run(command, **run_kwargs)




@@ 124,7 128,7 @@ def find_cgroup_mount_paths() -> list[pathlib.Path]:

def get_identifier() -> str:
    time_hash = hashlib.md5(str(time.time()).encode("ascii")).hexdigest()
    return f"unshare-net-{time_hash}"[:IPTABLES_MAX_CHAIN_NAME_LENGTH]
    return f"{APPLICATION_NAME}-{time_hash}"[:IPTABLES_MAX_CHAIN_NAME_LENGTH]


def get_argument_parser() -> argparse.ArgumentParser:


@@ 132,15 136,32 @@ def get_argument_parser() -> argparse.ArgumentParser:
    parser.add_argument("--allow", "-a", action="append", default=[])
    parser.add_argument("--allow-lan", "-A", action="store_true")
    parser.add_argument("--user", "-u", default=os.getenv("SUDO_USER"))
    parser.add_argument(
        "--logging-level", "-l", choices={"CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"}, default="INFO"
    )
    return parser


def get_log_path() -> pathlib.Path:
    path = pathlib.Path(appdirs.user_log_dir(APPLICATION_NAME))
    path.mkdir(parents=True, exist_ok=True)
    return path / (APPLICATION_NAME + ".log")


def rotate_log(source: str, destination: str) -> None:
    source_path = pathlib.Path(source)
    with source_path.open("rb") as source_file, gzip.open(destination + ".gz", "wb") as destination_file:
        destination_file.writelines(source_file)

    source_path.unlink()


def run(arguments: argparse.Namespace, command: list[str]) -> ExitCode:
    if not arguments.user:
        stderr("Please specify --user/-u manually! (missing $SUDO_USER)")
        logger.error("Please specify --user/-u manually! (missing $SUDO_USER)")
        return ExitCode.FAILURE
    elif arguments.user == "root":
        stderr("Not allowed to run target as root!")
        logger.error("Not allowed to run target as root!")
        return ExitCode.FAILURE

    # Try to find an existing cgroup2 mount path


@@ 150,8 171,8 @@ def run(arguments: argparse.Namespace, command: list[str]) -> ExitCode:
    if cgroup_mount_paths:
        cgroup_mount_path = cgroup_mount_paths[0]
    else:
        cgroup_mount_path = pathlib.Path(tempfile.mkdtemp(prefix="unshare-net-"))
        stderr(f"created {str(cgroup_mount_path)!r}")
        cgroup_mount_path = pathlib.Path(tempfile.mkdtemp(prefix=APPLICATION_NAME + "-"))
        logger.debug("created %r", str(cgroup_mount_path))
        # These are the mount options used by Arch Linux (systemd), so I assume they are fine
        mount("--types", "cgroup2", identifier, str(cgroup_mount_path), "--options", "defaults,nosuid,nodev,noexec")
        cgroup_mount_created = True


@@ 201,7 222,7 @@ def run(arguments: argparse.Namespace, command: list[str]) -> ExitCode:

    # We use --session-command so we can spawn a shell with job control too
    su_command = "su", arguments.user, "--session-command", *command
    stderr(shell_escape(su_command))
    logger.debug(shell_escape(su_command))
    process = subprocess.Popen(su_command)
    cgroup_add_process(identifier, process.pid, cgroup_mount_path)



@@ 228,7 249,7 @@ def run(arguments: argparse.Namespace, command: list[str]) -> ExitCode:
    cgroup_remove(identifier, cgroup_mount_path)

    if cgroup_mount_created:
        stderr(f"removing {str(cgroup_mount_path)!r}")
        logger.debug("removing %r", str(cgroup_mount_path))
        umount(cgroup_mount_path)
        cgroup_mount_path.rmdir()



@@ 245,7 266,19 @@ def get_command() -> T.Optional[list[str]]:
def main() -> None:
    parser = get_argument_parser()
    arguments, rest = parser.parse_known_args()
    parser.exit(run(arguments, get_command() or rest))

    logging.basicConfig(format=LOGGING_FORMAT)
    logger.setLevel(arguments.logging_level)
    handler = logging.handlers.TimedRotatingFileHandler(get_log_path(), when="midnight")
    handler.rotator = rotate_log
    handler.setFormatter(logging.Formatter(LOGGING_FORMAT))
    logger.addHandler(handler)

    try:
        parser.exit(run(arguments, get_command() or rest))
    except Exception:
        logger.critical("Error", exc_info=True)
        parser.exit(ExitCode.FAILURE)


if __name__ == "__main__":