From c681cb7713ab0dc27aa8317d4bfd9898c29fcd88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ahlert?= Date: Tue, 2 Jun 2026 13:16:17 -0300 Subject: [PATCH 1/2] fix(core): 4-tier transition resolution + wildcard ('*') source MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces wildcard transition sources so a graph can route from ANY action to a handler, e.g. ("*", "handler", expr("error is not None")). Resolves the transition-precedence bug that blocked issue #30: an action's own default transition would shadow a guarded wildcard route, so the route never fired. get_next_node now resolves in 4 tiers, first match wins: 1. source-specific non-default transitions (insertion order) 2. wildcard non-default transitions (insertion order) 3. source-specific default transition 4. wildcard default transition This is a core-routing semantic change: in default-FIRST graphs a previously dead non-default transition becomes live. For the common default-LAST graphs (and graphs without wildcards) resolution is identical to before; covered by a regression test. The wildcard source is carried by a sentinel Result().with_name("*") used only as a transition from_; it is never added to the action map, so introspection (visualize, ApplicationModel serialization) is unaffected. Part of #30. Signed-off-by: André Ahlert --- burr/core/graph.py | 52 ++++++++++++++++++++++----- tests/core/test_graph.py | 77 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 9 deletions(-) diff --git a/burr/core/graph.py b/burr/core/graph.py index fcbc9c5cb..1b2b324cb 100644 --- a/burr/core/graph.py +++ b/burr/core/graph.py @@ -22,7 +22,7 @@ import pathlib from typing import Any, Callable, List, Literal, Optional, Set, Tuple, Union -from burr.core.action import Action, Condition, create_action, default +from burr.core.action import Action, Condition, Result, create_action, default from burr.core.state import State from burr.core.validation import BASE_ERROR_MESSAGE, assert_set @@ -59,7 +59,9 @@ def _validate_transitions( ): exhausted = {} # items for which we have seen a default transition for from_, to, condition in transitions: - if from_ not in actions: + # "*" is a wildcard source -- it routes from ANY action, so it is not + # required to be a declared action. The target, however, must be real. + if from_ != "*" and from_ not in actions: raise ValueError( f"Transition source: `{from_}` not found in actions! " f"Please add to actions using with_actions({from_}=...)" @@ -69,7 +71,9 @@ def _validate_transitions( f"Transition target: `{to}` not found in actions! " f"Please add to actions using with_actions({to}=...)" ) - if condition.name == "default": # we have seen a default transition + # Skip default-exhaustion bookkeeping for the wildcard source -- it does not + # shadow per-source defaults (see get_next_node's 4-tier resolution). + if from_ != "*" and condition.name == "default": # we have seen a default transition if from_ in exhausted: raise ValueError( f"Transition `{from_}` -> `{to}` is redundant -- " @@ -150,13 +154,36 @@ def _create_action_tag_map(actions: List[Action]) -> dict[str, List[Action]]: def get_next_node( self, prior_step: Optional[str], state: State, entrypoint: str ) -> Optional[Action]: - """Gives the next node to execute given state + prior step.""" + """Gives the next node to execute given state + prior step. + + Resolution uses a 4-tier precedence so that guarded wildcard ("*") transitions + can route errors from any action while leaving common default-last graphs + unchanged. A condition is "default" iff ``condition.name == "default"``. The + tiers, evaluated in order with first match winning: + + 1. source-specific non-default transitions (insertion order) + 2. wildcard non-default transitions (insertion order) + 3. source-specific default transition + 4. wildcard default transition + """ if prior_step is None: return self._action_map[entrypoint] - possibilities = self._adjacency_map[prior_step] - for next_action, condition in possibilities: - if condition.run(state)[Condition.KEY]: - return self._action_map[next_action] + src = self._adjacency_map[prior_step] + wild = self._adjacency_map.get("*", []) + + def _is_default(condition: Condition) -> bool: + return condition.name == "default" + + tiers = ( + [(t, c) for t, c in src if not _is_default(c)], + [(t, c) for t, c in wild if not _is_default(c)], + [(t, c) for t, c in src if _is_default(c)], + [(t, c) for t, c in wild if _is_default(c)], + ) + for tier in tiers: + for next_action, condition in tier: + if condition.run(state)[Condition.KEY]: + return self._action_map[next_action] return None def get_action(self, action_name: str) -> Optional[Action]: @@ -372,11 +399,18 @@ def build(self) -> Graph: actions_by_name = {action.name: action for action in self.actions} all_actions = set(actions_by_name.keys()) _validate_transitions(self.transitions, all_actions) + + # Sentinel action used solely as the `from_` of wildcard ("*") transitions. + # It is never added to `actions`, so `_action_map` is unaffected; it exists + # only so the Transition dataclass and adjacency map (which key off + # `from_.name`) can carry the "*" source. We reuse Result as a concrete, + # no-op Action so it behaves like a real Action if anything inspects it. + wildcard_source = Result().with_name("*") return Graph( actions=self.actions, transitions=[ Transition( - from_=actions_by_name[from_], + from_=wildcard_source if from_ == "*" else actions_by_name[from_], to=actions_by_name[to], condition=condition, ) diff --git a/tests/core/test_graph.py b/tests/core/test_graph.py index 22cdda7b6..ead73896e 100644 --- a/tests/core/test_graph.py +++ b/tests/core/test_graph.py @@ -204,3 +204,80 @@ def test_get_actions_by_tag(): assert len(graph.get_actions_by_tag("tag3")) == 1 with pytest.raises(ValueError, match="not found"): graph.get_actions_by_tag("tag4") + + +def _flaky_handler_graph(): + """Source 'flaky' has BOTH a default ('flaky'->'next') AND a wildcard guarded + error route ('*'->'handler', error is not None). Plus 'next' and 'handler' + actions so targets are real.""" + return ( + GraphBuilder() + .with_actions( + flaky=Result("count"), + next=Result("count"), + handler=Result("count"), + ) + .with_transitions( + ("flaky", "next"), # default + ("*", "handler", Condition.expr("error is not None")), + ) + .build() + ) + + +def test_wildcard_guarded_route_beats_source_default_when_error_set(): + """THE crux test: with error set, a guarded wildcard wins over the source's own default.""" + graph = _flaky_handler_graph() + state = State({"count": 0, "error": {"type": "ValueError"}}) + assert graph.get_next_node("flaky", state, entrypoint="flaky").name == "handler" + + +def test_wildcard_guarded_route_falls_through_to_source_default_when_error_unset(): + graph = _flaky_handler_graph() + state = State({"count": 0, "error": None}) + assert graph.get_next_node("flaky", state, entrypoint="flaky").name == "next" + + +def test_specific_guarded_beats_wildcard_guarded(): + """A source-specific guarded transition wins over a wildcard guarded transition.""" + graph = ( + GraphBuilder() + .with_actions( + a=Result("count"), + specific=Result("count"), + wild=Result("count"), + ) + .with_transitions( + ("a", "specific", Condition.expr("error is not None")), + ("*", "wild", Condition.expr("error is not None")), + ) + .build() + ) + state = State({"count": 0, "error": {"type": "ValueError"}}) + assert graph.get_next_node("a", state, entrypoint="a").name == "specific" + + +def test_default_last_regression_resolves_identically(): + """Existing guarded + default-last graph (no wildcards) resolves as before.""" + graph = ( + GraphBuilder() + .with_actions(counter=base_counter_action, result=Result("count")) + .with_transitions( + ("counter", "counter", Condition.expr("count < 10")), + ("counter", "result"), + ) + .build() + ) + assert ( + graph.get_next_node("counter", State({"count": 0}), entrypoint="counter").name == "counter" + ) + assert ( + graph.get_next_node("counter", State({"count": 10}), entrypoint="counter").name == "result" + ) + + +def test__validate_transitions_accepts_wildcard(): + assert _validate_transitions( + [("*", "handler", Condition.expr("error is not None"))], + {"flaky", "handler"}, + ) From f0f1a6613a4873a190defe8b563035d24b565ac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Ahlert?= Date: Tue, 2 Jun 2026 13:16:27 -0300 Subject: [PATCH 2/2] feat(core): exception transitions via on_error capture (#30) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the ability to suppress an action exception and route to an error-handling action instead of breaking control flow. * capture_as(field, include_traceback=True): an on_error handler that writes a JSON-serializable record {type, message, traceback} into a state field. Exported from burr.core. * @action(..., on_error=handler): per-action error handler, any callable (State, Exception) -> State. * ApplicationBuilder.with_error_handling(handler): builder-level global handler applied to actions without their own on_error. Per-action wins. When an action raises, the effective handler suppresses the exception, writes captured state, and execution continues; a wildcard transition ("*", handler, expr("error is not None")) then routes to the handler. The captured field bypasses reducer write-validation and need not be in the action's declared writes. A failing handler never masks the original exception (re-raised with the handler error as cause). The capture field must be seeded (e.g. .with_state(error=None)) since expr() raises on a missing key; documented in the capture_as example. Also restructures _astep so a SYNC action driven through the async path reports the real result/state to the async post_run_step hook instead of stale values (previously delegated to _step and discarded). Behavior is locked by a regression test. Streaming error handling (@streaming_action on_error during incremental consumption) is intentionally out of scope and left as follow-up. Closes #30. Signed-off-by: André Ahlert --- burr/core/__init__.py | 13 ++- burr/core/action.py | 106 ++++++++++++++++++- burr/core/application.py | 111 ++++++++++++++++---- tests/core/test_action.py | 59 +++++++++++ tests/core/test_application.py | 186 +++++++++++++++++++++++++++++++++ 5 files changed, 452 insertions(+), 23 deletions(-) diff --git a/burr/core/__init__.py b/burr/core/__init__.py index aa2f75a46..3c127f5e3 100644 --- a/burr/core/__init__.py +++ b/burr/core/__init__.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. -from burr.core.action import Action, Condition, Result, action, default, expr, type_eraser, when +from burr.core.action import ( + Action, + Condition, + Result, + action, + capture_as, + default, + expr, + type_eraser, + when, +) from burr.core.application import ( Application, ApplicationBuilder, @@ -28,6 +38,7 @@ __all__ = [ "action", "Action", + "capture_as", "Application", "ApplicationBuilder", "ApplicationGraph", diff --git a/burr/core/action.py b/burr/core/action.py index 08771d411..10b79b42c 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -22,6 +22,7 @@ import inspect import sys import textwrap +import traceback import types import typing from collections.abc import AsyncIterator @@ -379,6 +380,82 @@ def tags(self) -> list[str]: """ return [] + @property + def on_error(self) -> Optional[Callable[[State, Exception], State]]: + """Returns the error handler associated with this action, if any. + + An error handler is any callable ``(State, Exception) -> State``. If it is set + and the action raises, the application will call it to produce a new state + (suppressing the exception) instead of re-raising. See + :py:func:`capture_as ` for a built-in handler. + + :return: The error handler callable, or None if not set + """ + return getattr(self, "_on_error", None) + + +class capture_as: + """Error handler that suppresses an exception and records a JSON-serializable + summary of it into the given state field. + + Use this with the ``on_error`` parameter of the :py:func:`@action ` + decorator, or with :py:meth:`ApplicationBuilder.with_error_handling + `: + + .. code-block:: python + + @action(reads=[], writes=["output"], on_error=capture_as("error")) + def flaky(state: State) -> tuple[dict, State]: + result = {"output": call_some_api(...)} + return result, state.update(**result) + + @action(reads=["error"], writes=[]) + def handler(state: State) -> tuple[dict, State]: + ... # inspect/reset state["error"], then route onward + + app = ( + ApplicationBuilder() + .with_actions(flaky=flaky, handler=handler) + # the capture field MUST be initialized -- expr() raises on a missing key, + # so seed ``error`` to None so the wildcard condition is safe to evaluate + # on every step before any exception has occurred. + .with_state(error=None) + .with_transitions(("*", "handler", expr("error is not None"))) + .with_entrypoint("flaky") + .build() + ) + + Note the captured field need NOT appear in the action's ``writes`` -- the handler + writes it directly, bypassing reducer write-validation. When the action raises, + ``state["error"]`` is set to a JSON-serializable dict like:: + + {"type": "ValueError", "message": "...", "traceback": "..."} + + which the wildcard transition above then routes on. + """ + + def __init__(self, field: str, include_traceback: bool = True): + """:param field: State field to write the error record to. + :param include_traceback: Whether to include the formatted traceback string. + """ + self.field = field + self.include_traceback = include_traceback + + def __call__(self, state: State, exception: Exception) -> State: + record = {"type": type(exception).__name__, "message": str(exception)} + if self.include_traceback: + # Formatting must never fail -- a handler that raises would mask the + # original exception in the calling context. Fall back to a safe + # placeholder if traceback formatting blows up (e.g. corrupted + # traceback objects or pathological custom exceptions). + try: + record["traceback"] = "".join( + traceback.format_exception(type(exception), exception, exception.__traceback__) + ) + except Exception as e: # noqa: BLE001 + record["traceback"] = f"" + return state.update(**{self.field: record}) + class Condition(Function): KEY = "PROCEED" @@ -806,6 +883,7 @@ def __init__( originating_fn: Optional[Callable] = None, schema: ActionSchema = DEFAULT_SCHEMA, tags: Optional[List[str]] = None, + on_error: Optional[Callable[[State, Exception], State]] = None, ): """Instantiates a function-based action with the given function, reads, and writes. The function must take in a state and return a tuple of (result, new_state). @@ -815,6 +893,8 @@ def __init__( :param writes: Keys that the function writes to the state :param bound_params: Prior bound parameters :param input_spec: Specification for inputs. Will derive from function if not provided. + :param on_error: Optional error handler ``(State, Exception) -> State``. If set and the + action raises, the exception is suppressed and the returned state is used instead. """ super(FunctionBasedAction, self).__init__() self._originating_fn = originating_fn if originating_fn is not None else fn @@ -834,6 +914,7 @@ def __init__( ) self._schema = schema self._tags = tags if tags is not None else [] + self._on_error = on_error @property def fn(self) -> Callable: @@ -877,6 +958,7 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction": originating_fn=self._originating_fn, schema=self._schema, tags=self._tags, + on_error=self._on_error, ) def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]: @@ -1473,7 +1555,13 @@ def pydantic( tags=tags, ) - def __init__(self, reads: List[str], writes: List[str], tags: Optional[List[str]] = None): + def __init__( + self, + reads: List[str], + writes: List[str], + tags: Optional[List[str]] = None, + on_error: Optional[Callable[[State, Exception], State]] = None, + ): """Decorator to create a function-based action. This is user-facing. Note that, in the future, with typed state, we may not need this for all cases. @@ -1484,17 +1572,24 @@ def __init__(self, reads: List[str], writes: List[str], tags: Optional[List[str] :param reads: Items to read from the state :param writes: Items to write to the state + :param tags: Optional list of tags to associate with this action + :param on_error: Optional error handler ``(State, Exception) -> State``. If set and the + action raises, the exception is suppressed and the returned state is used instead. + See :py:func:`capture_as ` for a built-in handler. :return: The decorator to assign the function as an action """ self.reads = reads self.writes = writes self.tags = tags + self.on_error = on_error def __call__(self, fn) -> FunctionRepresentingAction: setattr( fn, FunctionBasedAction.ACTION_FUNCTION, - FunctionBasedAction(fn, self.reads, self.writes, tags=self.tags), + FunctionBasedAction( + fn, self.reads, self.writes, tags=self.tags, on_error=self.on_error + ), ) setattr(fn, "bind", types.MethodType(bind, fn)) return fn @@ -1537,7 +1632,12 @@ def pydantic( tags=tags, ) - def __init__(self, reads: List[str], writes: List[str], tags: Optional[List[str]] = None): + def __init__( + self, + reads: List[str], + writes: List[str], + tags: Optional[List[str]] = None, + ): """Decorator to create a streaming function-based action. This is user-facing. If parameters are not bound, they will be interpreted as inputs and must be passed in at runtime. diff --git a/burr/core/application.py b/burr/core/application.py index 25bce4a10..f82e30f74 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -849,6 +849,7 @@ def __init__( parallel_executor_factory: Optional[Executor] = None, state_persister: Union[BaseStateSaver, LifecycleAdapter, None] = None, state_initializer: Union[BaseStateLoader, LifecycleAdapter, None] = None, + global_error_handler: Optional[Callable[[State, Exception], State]] = None, ): """Instantiates an Application. This is an internal API -- use the builder! @@ -900,6 +901,7 @@ def __init__( self._spawning_parent_pointer = spawning_parent_pointer self._state_initializer = state_initializer self._state_persister = state_persister + self._global_error_handler = global_error_handler self._adapter_set.call_all_lifecycle_hooks_sync( "post_application_create", state=self._state, @@ -985,9 +987,29 @@ def _step( new_state = self._update_internal_state_value(new_state, next_action) self._set_state(new_state) except Exception as e: - exc = e - logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) - raise e + # Resolve the effective error handler: per-action takes precedence + # over the builder-level global handler. + handler = next_action.on_error or self._global_error_handler + if handler is not None: + # The handler writes the captured record via state.update, bypassing + # the normal reducer -- so it is NOT subject to writes-validation. + # The handler itself must never mask the original exception: if it + # raises, surface the original exception with the handler failure as + # its cause rather than silently swapping which error propagates. + try: + new_state = handler(self._state, e) + except Exception as handler_exc: + logger.exception(f"Error handler for {next_action.name} failed") + raise e from handler_exc + new_state = self._update_internal_state_value(new_state, next_action) + self._set_state(new_state) + result = None + exc = e # keep for post_run_step hook observability + # DO NOT re-raise; flow continues. get_next_node routes via the captured state field. + else: + exc = e + logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) + raise e finally: if _run_hooks: self._adapter_set.call_all_lifecycle_hooks_sync( @@ -1094,12 +1116,19 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True return None if inputs is None: inputs = {} + # Process inputs before the pre_run_step hook so that hooks in the async path + # observe the same processed inputs (dependency factory injection + validation + # applied) as the sync path does. We process once here and run the action inline + # (rather than delegating to _step) so the async lifecycle dispatcher still fires + # and the dependency factory -- which constructs fresh tracer/context objects per + # call -- is not invoked twice for a single step. + action_inputs = self._process_inputs(inputs, next_action) if _run_hooks: await self._adapter_set.call_all_lifecycle_hooks_sync_and_async( "pre_run_step", action=next_action, state=self._state, - inputs=inputs, + inputs=action_inputs, sequence_id=self.sequence_id, app_id=self._uid, partition_key=self._partition_key, @@ -1109,18 +1138,20 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True new_state = self._state try: if not next_action.is_async(): - # we can just delegate to the synchronous version, it will block the event loop, - # but that's safer than assuming its OK to launch a thread - # TODO -- add an option/configuration to launch a thread (yikes, not super safe, but for a pure function - # which this is supposed to be its OK). - # this delegates hooks to the synchronous version, so we'll call all of them as well - # In this case we allow the self._step to do input processing - return self._step( - inputs=inputs, _run_hooks=False - ) # Skip hooks as we already ran all of them/will run all of them in this function's finally - # In this case we want to process inputs because we run the function directly - action_inputs = self._process_inputs(inputs, next_action) - if next_action.single_step: + # The action is synchronous -- running it here will block the event loop, + # but that's safer than assuming it's OK to launch a thread for what is + # supposed to be a pure function. We still drive it through this async + # method so the async lifecycle hooks fire and inputs are processed once. + if next_action.single_step: + result, new_state = _run_single_step_action( + next_action, self._state, action_inputs + ) + else: + result = _run_function( + next_action, self._state, action_inputs, name=next_action.name + ) + new_state = _run_reducer(next_action, self._state, result, next_action.name) + elif next_action.single_step: result, new_state = await _arun_single_step_action( next_action, self._state, inputs=action_inputs ) @@ -1135,9 +1166,29 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True new_state = self._update_internal_state_value(new_state, next_action) self._set_state(new_state) except Exception as e: - exc = e - logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) - raise e + # Resolve the effective error handler: per-action takes precedence + # over the builder-level global handler. + handler = next_action.on_error or self._global_error_handler + if handler is not None: + # The handler writes the captured record via state.update, bypassing + # the normal reducer -- so it is NOT subject to writes-validation. + # The handler itself must never mask the original exception: if it + # raises, surface the original exception with the handler failure as + # its cause rather than silently swapping which error propagates. + try: + new_state = handler(self._state, e) + except Exception as handler_exc: + logger.exception(f"Error handler for {next_action.name} failed") + raise e from handler_exc + new_state = self._update_internal_state_value(new_state, next_action) + self._set_state(new_state) + result = None + exc = e # keep for post_run_step hook observability + # DO NOT re-raise; flow continues. get_next_node routes via the captured state field. + else: + exc = e + logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) + raise e finally: if _run_hooks: await self._adapter_set.call_all_lifecycle_hooks_sync_and_async( @@ -2222,6 +2273,7 @@ def __init__(self): self.parallel_executor_factory = None self.state_persister = None self._is_async: bool = False + self._error_handler: Optional[Callable[[State, Exception], State]] = None def with_identifiers( self, app_id: str = None, partition_key: str = None, sequence_id: int = None @@ -2419,6 +2471,26 @@ def with_hooks(self, *adapters: LifecycleAdapter) -> "ApplicationBuilder[StateTy self.lifecycle_adapters.extend(adapters) return self + def with_error_handling( + self, handler: Callable[[State, Exception], State] + ) -> "ApplicationBuilder[StateType]": + """Sets a global error handler for the application. This is any callable + ``(State, Exception) -> State``. If an action raises and does not have its own + per-action ``on_error`` handler, this global handler is invoked: the exception + is suppressed and the returned state is used instead. Per-action ``on_error`` + takes precedence over this global handler. + + See :py:func:`capture_as ` for a built-in handler + that records a JSON-serializable summary of the exception into a state field, + which you can then route on via a wildcard transition, e.g. + ``("*", "handler", expr("error is not None"))``. + + :param handler: Error handler callable ``(State, Exception) -> State`` + :return: The application builder for future chaining. + """ + self._error_handler = handler + return self + def with_tracker( self, tracker: Union[Literal["local"], "TrackingClient"] = "local", @@ -2782,6 +2854,7 @@ def _build_common(self) -> Application: parallel_executor_factory=self.parallel_executor_factory, state_persister=self.state_persister, state_initializer=self.state_initializer, + global_error_handler=self._error_handler, ) def build(self) -> Application[StateType]: diff --git a/tests/core/test_action.py b/tests/core/test_action.py index 367ee58e8..4fb39c638 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -1184,3 +1184,62 @@ def test_exported_from_burr_core(self): from burr.core import type_eraser as te assert te is type_eraser + + +def test_capture_as_returns_serializable_dict(): + import json + + from burr.core.action import capture_as + + handler = capture_as("error") + state = State() + try: + raise ValueError("boom") + except ValueError as e: + new_state = handler(state, e) + record = new_state["error"] + assert record["type"] == "ValueError" + assert record["message"] == "boom" + assert "traceback" in record + # must be JSON-serializable -- no raw Exception object + json.dumps(record) + + +def test_capture_as_include_traceback_toggle(): + from burr.core.action import capture_as + + state = State() + try: + raise RuntimeError("nope") + except RuntimeError as e: + with_tb = capture_as("error", include_traceback=True)(state, e) + without_tb = capture_as("error", include_traceback=False)(state, e) + assert "traceback" in with_tb["error"] + assert "traceback" not in without_tb["error"] + + +def test_action_decorator_accepts_on_error_and_attaches_it(): + from burr.core.action import FunctionBasedAction, capture_as + + handler = capture_as("error") + + @action(reads=[], writes=["error"], on_error=handler) + def my_action(state: State) -> Tuple[dict, State]: + return {}, state + + fba = getattr(my_action, FunctionBasedAction.ACTION_FUNCTION) + assert fba.on_error is handler + + +def test_bind_preserves_on_error(): + from burr.core.action import FunctionBasedAction, capture_as + + handler = capture_as("error") + + @action(reads=[], writes=["error"], on_error=handler) + def my_action(state: State, z: int) -> Tuple[dict, State]: + return {}, state + + bound = my_action.bind(z=2) + fba = getattr(bound, FunctionBasedAction.ACTION_FUNCTION) + assert fba.on_error is handler diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 9313cefc9..fec15da93 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -4299,3 +4299,189 @@ def noop(state: State) -> State: app = builder.build() assert app.state["x"] == 100 + + +@action(reads=["error"], writes=[]) +def _error_route_handler(state: State) -> Tuple[dict, State]: + return {}, state + + +def _build_flaky_app(builder_handler=None, action_handler=None): + """Builds an app where 'flaky' raises, routes via wildcard to 'handler' + when error is set. flaky declares writes=['count'] but never writes it + (it raises); the error field is written by the on_error handler and must + NOT trip writes-validation.""" + from burr.core.action import capture_as # noqa: F401 + + @action(reads=[], writes=["count"], on_error=action_handler) + def flaky(state: State) -> Tuple[dict, State]: + raise ValueError("boom") + + b = ( + ApplicationBuilder() + .with_actions(flaky=flaky, handler=_error_route_handler) + .with_transitions(("*", "handler", expr("error is not None"))) + .with_entrypoint("flaky") + .with_state(error=None) + ) + if builder_handler is not None: + b = b.with_error_handling(builder_handler) + return b.build() + + +def test_on_error_capture_as_suppresses_and_routes_to_wildcard_handler_sync(): + import json + + from burr.core.action import capture_as + + app = _build_flaky_app(action_handler=capture_as("error")) + last_action, result, state = app.run(halt_after=["handler"]) + assert last_action.name == "handler" + assert isinstance(state["error"], dict) + assert state["error"]["type"] == "ValueError" + assert state["error"]["message"] == "boom" + # JSON-serializable -- no raw Exception object + json.dumps(state["error"]) + + +async def test_on_error_capture_as_suppresses_and_routes_to_wildcard_handler_async(): + import json + + from burr.core.action import capture_as + + @action(reads=[], writes=["count"], on_error=capture_as("error")) + async def flaky(state: State) -> Tuple[dict, State]: + await asyncio.sleep(0) + raise ValueError("boom") + + app = ( + ApplicationBuilder() + .with_actions(flaky=flaky, handler=_error_route_handler) + .with_transitions(("*", "handler", expr("error is not None"))) + .with_entrypoint("flaky") + .with_state(error=None) + .build() + ) + last_action, result, state = await app.arun(halt_after=["handler"]) + assert last_action.name == "handler" + assert isinstance(state["error"], dict) + assert state["error"]["type"] == "ValueError" + json.dumps(state["error"]) + + +def test_builder_with_error_handling_applies_when_no_per_action_handler(): + from burr.core.action import capture_as + + app = _build_flaky_app(builder_handler=capture_as("error")) + last_action, result, state = app.run(halt_after=["handler"]) + assert last_action.name == "handler" + assert state["error"]["type"] == "ValueError" + + +def test_captured_field_not_in_declared_writes_still_works(): + """flaky declares writes=['count'] (not 'error'); capturing 'error' must + bypass writes-validation and not raise.""" + from burr.core.action import capture_as + + app = _build_flaky_app(action_handler=capture_as("error")) + last_action, result, state = app.run(halt_after=["handler"]) + assert "error" in state + assert last_action.name == "handler" + + +def test_action_without_error_handler_still_raises(): + @action(reads=[], writes=["count"]) + def flaky(state: State) -> Tuple[dict, State]: + raise ValueError("boom") + + app = ( + ApplicationBuilder() + .with_actions(flaky=flaky, handler=_error_route_handler) + .with_transitions(("*", "handler", expr("error is not None"))) + .with_entrypoint("flaky") + .with_state(error=None) + .build() + ) + with pytest.raises(ValueError, match="boom"): + app.run(halt_after=["handler"]) + + +def test_per_action_on_error_takes_precedence_over_global_handler(): + """When both a per-action on_error and a builder-level global handler are set, + the per-action handler wins.""" + from burr.core.action import capture_as + + @action(reads=[], writes=["count"], on_error=capture_as("error")) + def flaky(state: State) -> Tuple[dict, State]: + raise ValueError("boom") + + app = ( + ApplicationBuilder() + .with_actions(flaky=flaky, handler=_error_route_handler) + .with_transitions(("*", "handler", expr("error is not None"))) + .with_entrypoint("flaky") + .with_state(error=None, global_error=None) + .with_error_handling(capture_as("global_error")) + .build() + ) + last_action, result, state = app.run(halt_after=["handler"]) + assert last_action.name == "handler" + assert isinstance(state["error"], dict) # per-action handler wrote this + assert state["global_error"] is None # global handler did NOT run + + +def test_wildcard_error_condition_safe_on_success_path(): + """The documented wildcard pattern ('*', handler, expr('error is not None')) must + NOT crash on a normal (non-raising) step and must NOT hijack the success path. + ``error`` is seeded to None so the wildcard condition evaluates to False on every + step and flow follows the action's own default transition.""" + + @action(reads=[], writes=["output"]) + def ok(state: State) -> Tuple[dict, State]: + return {"output": 1}, state.update(output=1) + + @action(reads=["output"], writes=[]) + def done(state: State) -> Tuple[dict, State]: + return {}, state + + app = ( + ApplicationBuilder() + .with_actions(ok=ok, done=done, handler=_error_route_handler) + .with_transitions( + ("ok", "done"), # source default + ("*", "handler", expr("error is not None")), # guarded wildcard + ) + .with_entrypoint("ok") + .with_state(error=None) + .build() + ) + last_action, result, state = app.run(halt_after=["done", "handler"]) + assert last_action.name == "done" # wildcard did not hijack the success path + assert state["error"] is None + + +async def test_astep_sync_action_post_run_step_receives_correct_result_and_state(): + """Regression for the _astep restructure: running a SYNC action through + arun/_astep must report the real result and updated state to the async + post_run_step hook. The prior `return self._step(...)` delegation left + result/state stale on this path.""" + tracker = ActionTrackerAsync() + + @action(reads=[], writes=["count"]) + def sync_inc(state: State) -> Tuple[dict, State]: + return {"count": 1}, state.update(count=1) + + app = ( + ApplicationBuilder() + .with_actions(sync_inc=sync_inc) + .with_entrypoint("sync_inc") + .with_state(count=0) + .with_hooks(tracker) + .build() + ) + await app.astep() + assert len(tracker.post_called) == 1 + name, kwargs = tracker.post_called[0] + assert name == "sync_inc" + assert kwargs["result"] == {"count": 1} + assert kwargs["state"]["count"] == 1