debputy.plugin.plugin_state

src/debputy/plugin/plugin_state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import contextvars
import functools
import inspect
from contextvars import ContextVar
from typing import Optional, ParamSpec, TypeVar, NoReturn, Union
from collections.abc import Callable

from debputy.exceptions import (
    UnhandledOrUnexpectedErrorFromPluginError,
    DebputyRuntimeError,
)
from debputy.util import _trace_log, _is_trace_log_enabled

_current_debputy_plugin_cxt_var: ContextVar[str | None] = ContextVar(
    "current_debputy_plugin",
    default=None,
)

P = ParamSpec("P")
R = TypeVar("R")


def current_debputy_plugin_if_present() -> str | None:
    """Return the name of the active plugin context, or None if there is none."""
    active_plugin = _current_debputy_plugin_cxt_var.get()
    return active_plugin


def current_debputy_plugin_required() -> str:
    """Return the active plugin name, failing loudly when none is set.

    :return: The name of the plugin whose context is currently active.
    :raises AssertionError: If called outside of any plugin context.
    """
    plugin_name = current_debputy_plugin_if_present()
    if plugin_name is not None:
        return plugin_name
    raise AssertionError(
        "current_debputy_plugin_required() was called, but no plugin was set."
    )


def wrap_plugin_code(
    plugin_name: str,
    func: Callable[P, R],
    *,
    non_debputy_exception_handling: bool | Callable[[Exception], NoReturn] = True,
) -> Callable[P, R]:
    """Wrap *func* so that every call runs in the context of *plugin_name*.

    :param plugin_name: The plugin to mark as active while *func* runs.
    :param func: The callable to wrap.
    :param non_debputy_exception_handling: Controls what happens to exceptions
      that are not `DebputyRuntimeError`s (those always propagate unchanged):

        * ``True``: they are converted by `run_in_context_of_plugin_wrap_errors`.
        * ``False``: they propagate unchanged.
        * a callable: it is invoked with the exception and is expected to raise.
          If it returns anyway, the original exception is re-raised.

    :return: A wrapper carrying *func*'s metadata (via `functools.update_wrapper`).
    """
    if isinstance(non_debputy_exception_handling, bool):

        runner = run_in_context_of_plugin
        if non_debputy_exception_handling:
            runner = run_in_context_of_plugin_wrap_errors

        def _plugin_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            return runner(plugin_name, func, *args, **kwargs)

        functools.update_wrapper(_plugin_wrapper, func)
        return _plugin_wrapper

    def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        try:
            return run_in_context_of_plugin(plugin_name, func, *args, **kwargs)
        except DebputyRuntimeError:
            raise
        except Exception as e:
            non_debputy_exception_handling(e)
            # The handler is typed NoReturn. If it returns anyway, do not
            # silently hand `None` back to the caller as if it were an `R`:
            # re-raise the original exception instead.
            raise

    functools.update_wrapper(_wrapper, func)
    return _wrapper


def _describe_external_caller() -> str:
    """Describe the nearest stack frame outside this module (for trace logs)."""
    # context=0: only the file, line and function are needed, so skip loading
    # the source lines of every frame on the stack.
    call_stack = inspect.stack(0)
    try:
        for frame_info in call_stack:
            if frame_info.filename == __file__:
                continue
            # `co_qualname` only exists on newer Pythons; fall back to the plain name.
            fname = getattr(frame_info.frame.f_code, "co_qualname", None)
            if fname is None:
                fname = frame_info.function
            return f"{frame_info.filename}:{frame_info.lineno} ({fname})"
        return "[N/A]"
    finally:
        # The list holds frame objects (including this one); drop it promptly
        # rather than keeping the reference longer than necessary.
        del call_stack


def run_in_context_of_plugin(
    plugin: str,
    func: Callable[P, R],
    *args: P.args,
    **kwargs: P.kwargs,
) -> R:
    """Call ``func(*args, **kwargs)`` with *plugin* recorded as the active plugin.

    The call runs inside a copy of the current `contextvars` context, so the
    plugin marker is only visible to *func* (and what it calls). The caller's
    own context is left untouched. When trace logging is enabled, the switch
    and its call site are logged.

    NOTE: If *func* returns a generator or coroutine, that object's body runs
    later, outside this copied context.

    :return: Whatever *func* returns. Exceptions from *func* propagate unchanged.
    """
    context = contextvars.copy_context()
    if _is_trace_log_enabled():
        caller = _describe_external_caller()
        _trace_log(
            f"Switching plugin context to {plugin} at {caller} (from context: {current_debputy_plugin_if_present()})"
        )
    # Wish we could just do a regular set without wrapping it in `context.run`:
    # the set must happen inside the copied context, not the caller's.
    context.run(_current_debputy_plugin_cxt_var.set, plugin)
    return context.run(func, *args, **kwargs)


def run_in_context_of_plugin_wrap_errors(
    plugin: str,
    func: Callable[P, R],
    *args: P.args,
    **kwargs: P.kwargs,
) -> R:
    """Like `run_in_context_of_plugin`, but convert unexpected exceptions.

    `DebputyRuntimeError`s propagate unchanged. Any other `Exception` is chained
    (``raise ... from``) into:

      * an `AssertionError` when *plugin* is ``"debputy"`` (an unhandled
        exception there is treated as a bug in debputy itself), or
      * an `UnhandledOrUnexpectedErrorFromPluginError` for any other plugin.

    :return: Whatever *func* returns.
    """
    try:
        return run_in_context_of_plugin(plugin, func, *args, **kwargs)
    except DebputyRuntimeError:
        raise
    except Exception as e:
        if plugin == "debputy":
            raise AssertionError(
                "Bug in the `debputy` plugin: Unhandled exception."
            ) from e
        # Not every callable has `__qualname__` (e.g. `functools.partial`
        # objects or callable instances). Do not let building the error
        # report itself fail with an AttributeError.
        func_name = getattr(func, "__qualname__", None) or repr(func)
        raise UnhandledOrUnexpectedErrorFromPluginError(
            f"{func_name} from the plugin {plugin} raised exception that was not expected here."
        ) from e