1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
|
import contextvars
import functools
import inspect
from contextvars import ContextVar
from typing import Optional, ParamSpec, TypeVar, NoReturn, Union
from collections.abc import Callable
from debputy.exceptions import (
UnhandledOrUnexpectedErrorFromPluginError,
DebputyRuntimeError,
)
from debputy.util import _trace_log, _is_trace_log_enabled
_current_debputy_plugin_cxt_var: ContextVar[str | None] = ContextVar(
"current_debputy_plugin",
default=None,
)
P = ParamSpec("P")
R = TypeVar("R")
def current_debputy_plugin_if_present() -> str | None:
    """Return the name of the active plugin, or ``None`` if none is set.

    The value is read from the plugin context variable of the current
    ``contextvars`` context.
    """
    active_plugin = _current_debputy_plugin_cxt_var.get()
    return active_plugin
def current_debputy_plugin_required() -> str:
    """Return the name of the active plugin, insisting that one is set.

    :return: The plugin name from ``current_debputy_plugin_if_present()``.
    :raises AssertionError: If no plugin context is active.
    """
    plugin_name = current_debputy_plugin_if_present()
    if plugin_name is not None:
        return plugin_name
    raise AssertionError(
        "current_debputy_plugin_required() was called, but no plugin was set."
    )
def wrap_plugin_code(
    plugin_name: str,
    func: Callable[P, R],
    *,
    non_debputy_exception_handling: bool | Callable[[Exception], NoReturn] = True,
) -> Callable[P, R]:
    """Wrap ``func`` so that every call runs in the context of ``plugin_name``.

    :param plugin_name: Plugin name made active (via
      ``run_in_context_of_plugin``) while ``func`` executes.
    :param func: The callable to wrap.
    :param non_debputy_exception_handling: What to do when ``func`` raises an
      exception that is not a ``DebputyRuntimeError``:

      * ``True`` (default): calls go through
        ``run_in_context_of_plugin_wrap_errors``, which converts such
        exceptions into plugin-attributed errors.
      * ``False``: calls go through ``run_in_context_of_plugin``; exceptions
        propagate unchanged.
      * a callable: it is invoked with the exception and is expected to raise
        (it is annotated ``NoReturn``). If it returns anyway, the original
        exception is re-raised rather than silently swallowed.
    :return: A wrapper carrying ``func``'s metadata (``functools.update_wrapper``).
    """
    if isinstance(non_debputy_exception_handling, bool):
        runner = (
            run_in_context_of_plugin_wrap_errors
            if non_debputy_exception_handling
            else run_in_context_of_plugin
        )

        def _plugin_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            return runner(plugin_name, func, *args, **kwargs)

        functools.update_wrapper(_plugin_wrapper, func)
        return _plugin_wrapper

    def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        try:
            return run_in_context_of_plugin(plugin_name, func, *args, **kwargs)
        except DebputyRuntimeError:
            # debputy's own errors are already meaningful; pass them through.
            raise
        except Exception as e:
            non_debputy_exception_handling(e)
            # The handler is declared NoReturn. If it returns regardless, do not
            # swallow the error (which would make the wrapper return None in
            # place of an R); re-raise the original exception instead.
            raise

    functools.update_wrapper(_wrapper, func)
    return _wrapper
def _describe_plugin_context_caller() -> str:
    """Describe the nearest stack frame outside this module, for trace logs.

    :return: ``"<filename>:<lineno> (<function>)"`` for the first frame whose
      file is not this module, or ``"[N/A]"`` if there is no such frame.
    """
    # context=0: only filename/lineno/function are used, so avoid loading
    # source lines for every frame (the default context=1 reads them).
    call_stack = inspect.stack(context=0)
    try:
        for frame_info in call_stack:
            if frame_info.filename == __file__:
                continue
            # co_qualname is only available on newer Pythons; fall back to
            # the plain function name otherwise.
            fname = getattr(frame_info.frame.f_code, "co_qualname", None)
            if fname is None:
                fname = frame_info.function
            return f"{frame_info.filename}:{frame_info.lineno} ({fname})"
        return "[N/A]"
    finally:
        # The list holds this very frame, whose locals hold the list: break the
        # reference cycle instead of keeping frames alive longer than needed.
        del call_stack


def run_in_context_of_plugin(
    plugin: str,
    func: Callable[P, R],
    *args: P.args,
    **kwargs: P.kwargs,
) -> R:
    """Call ``func(*args, **kwargs)`` with ``plugin`` as the active plugin.

    The plugin context variable is set inside a *copy* of the caller's
    ``contextvars`` context, and ``func`` runs in that copy, so the caller's
    own context value is not modified.

    :param plugin: Plugin name to make active while ``func`` runs.
    :param func: The callable to invoke.
    :return: Whatever ``func`` returns; exceptions from ``func`` propagate.
    """
    context = contextvars.copy_context()
    if _is_trace_log_enabled():
        caller = _describe_plugin_context_caller()
        _trace_log(
            f"Switching plugin context to {plugin} at {caller} (from context: {current_debputy_plugin_if_present()})"
        )
    # ContextVar.set() acts on the *current* context, so it has to be executed
    # inside the copied context via Context.run.
    context.run(_current_debputy_plugin_cxt_var.set, plugin)
    return context.run(func, *args, **kwargs)
def run_in_context_of_plugin_wrap_errors(
    plugin: str,
    func: Callable[P, R],
    *args: P.args,
    **kwargs: P.kwargs,
) -> R:
    """Like ``run_in_context_of_plugin``, but attribute unexpected errors.

    ``DebputyRuntimeError`` is re-raised unchanged. Any other ``Exception``
    raised by ``func`` is wrapped (chained with ``from``) as:

    * ``AssertionError`` when ``plugin`` is ``"debputy"`` (a bug in debputy
      itself), or
    * ``UnhandledOrUnexpectedErrorFromPluginError`` for any other plugin.

    :param plugin: Plugin name to make active while ``func`` runs.
    :param func: The callable to invoke.
    :return: Whatever ``func`` returns.
    """
    try:
        return run_in_context_of_plugin(plugin, func, *args, **kwargs)
    except DebputyRuntimeError:
        raise
    except Exception as e:
        if plugin == "debputy":
            raise AssertionError(
                "Bug in the `debputy` plugin: Unhandled exception."
            ) from e
        # Not every callable has __qualname__ (e.g. functools.partial objects).
        # Fall back to repr() so building the error message cannot itself fail
        # and mask the plugin error with an unrelated AttributeError.
        func_name = getattr(func, "__qualname__", None)
        if func_name is None:
            func_name = repr(func)
        raise UnhandledOrUnexpectedErrorFromPluginError(
            f"{func_name} from the plugin {plugin} raised exception that was not expected here."
        ) from e
|