Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion burr/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@
# specific language governing permissions and limitations
# under the License.

from burr.core.action import Action, Condition, Result, action, default, expr, type_eraser, when
from burr.core.action import (
Action,
Condition,
Result,
action,
capture_as,
default,
expr,
type_eraser,
when,
)
from burr.core.application import (
Application,
ApplicationBuilder,
Expand All @@ -28,6 +38,7 @@
__all__ = [
"action",
"Action",
"capture_as",
"Application",
"ApplicationBuilder",
"ApplicationGraph",
Expand Down
106 changes: 103 additions & 3 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import inspect
import sys
import textwrap
import traceback
import types
import typing
from collections.abc import AsyncIterator
Expand Down Expand Up @@ -379,6 +380,82 @@ def tags(self) -> list[str]:
"""
return []

@property
def on_error(self) -> Optional[Callable[[State, Exception], State]]:
"""Returns the error handler associated with this action, if any.

An error handler is any callable ``(State, Exception) -> State``. If it is set
and the action raises, the application will call it to produce a new state
(suppressing the exception) instead of re-raising. See
:py:func:`capture_as <burr.core.action.capture_as>` for a built-in handler.

:return: The error handler callable, or None if not set
"""
return getattr(self, "_on_error", None)


class capture_as:
"""Error handler that suppresses an exception and records a JSON-serializable
summary of it into the given state field.

Use this with the ``on_error`` parameter of the :py:func:`@action <burr.core.action.action>`
decorator, or with :py:meth:`ApplicationBuilder.with_error_handling
<burr.core.application.ApplicationBuilder.with_error_handling>`:

.. code-block:: python

@action(reads=[], writes=["output"], on_error=capture_as("error"))
def flaky(state: State) -> tuple[dict, State]:
result = {"output": call_some_api(...)}
return result, state.update(**result)

@action(reads=["error"], writes=[])
def handler(state: State) -> tuple[dict, State]:
... # inspect/reset state["error"], then route onward

app = (
ApplicationBuilder()
.with_actions(flaky=flaky, handler=handler)
# the capture field MUST be initialized -- expr() raises on a missing key,
# so seed ``error`` to None so the wildcard condition is safe to evaluate
# on every step before any exception has occurred.
.with_state(error=None)
.with_transitions(("*", "handler", expr("error is not None")))
.with_entrypoint("flaky")
.build()
)

Note the captured field need NOT appear in the action's ``writes`` -- the handler
writes it directly, bypassing reducer write-validation. When the action raises,
``state["error"]`` is set to a JSON-serializable dict like::

{"type": "ValueError", "message": "...", "traceback": "..."}

which the wildcard transition above then routes on.
"""

def __init__(self, field: str, include_traceback: bool = True):
""":param field: State field to write the error record to.
:param include_traceback: Whether to include the formatted traceback string.
"""
self.field = field
self.include_traceback = include_traceback

def __call__(self, state: State, exception: Exception) -> State:
record = {"type": type(exception).__name__, "message": str(exception)}
if self.include_traceback:
# Formatting must never fail -- a handler that raises would mask the
# original exception in the calling context. Fall back to a safe
# placeholder if traceback formatting blows up (e.g. corrupted
# traceback objects or pathological custom exceptions).
try:
record["traceback"] = "".join(
traceback.format_exception(type(exception), exception, exception.__traceback__)
)
except Exception as e: # noqa: BLE001
record["traceback"] = f"<Traceback formatting failed: {e}>"
return state.update(**{self.field: record})


class Condition(Function):
KEY = "PROCEED"
Expand Down Expand Up @@ -806,6 +883,7 @@ def __init__(
originating_fn: Optional[Callable] = None,
schema: ActionSchema = DEFAULT_SCHEMA,
tags: Optional[List[str]] = None,
on_error: Optional[Callable[[State, Exception], State]] = None,
):
"""Instantiates a function-based action with the given function, reads, and writes.
The function must take in a state and return a tuple of (result, new_state).
Expand All @@ -815,6 +893,8 @@ def __init__(
:param writes: Keys that the function writes to the state
:param bound_params: Prior bound parameters
:param input_spec: Specification for inputs. Will derive from function if not provided.
:param on_error: Optional error handler ``(State, Exception) -> State``. If set and the
action raises, the exception is suppressed and the returned state is used instead.
"""
super(FunctionBasedAction, self).__init__()
self._originating_fn = originating_fn if originating_fn is not None else fn
Expand All @@ -834,6 +914,7 @@ def __init__(
)
self._schema = schema
self._tags = tags if tags is not None else []
self._on_error = on_error

@property
def fn(self) -> Callable:
Expand Down Expand Up @@ -877,6 +958,7 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
originating_fn=self._originating_fn,
schema=self._schema,
tags=self._tags,
on_error=self._on_error,
)

def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]:
Expand Down Expand Up @@ -1473,7 +1555,13 @@ def pydantic(
tags=tags,
)

def __init__(self, reads: List[str], writes: List[str], tags: Optional[List[str]] = None):
def __init__(
self,
reads: List[str],
writes: List[str],
tags: Optional[List[str]] = None,
on_error: Optional[Callable[[State, Exception], State]] = None,
):
"""Decorator to create a function-based action. This is user-facing.
Note that, in the future, with typed state, we may not need this for
all cases.
Expand All @@ -1484,17 +1572,24 @@ def __init__(self, reads: List[str], writes: List[str], tags: Optional[List[str]

:param reads: Items to read from the state
:param writes: Items to write to the state
:param tags: Optional list of tags to associate with this action
:param on_error: Optional error handler ``(State, Exception) -> State``. If set and the
action raises, the exception is suppressed and the returned state is used instead.
See :py:func:`capture_as <burr.core.action.capture_as>` for a built-in handler.
:return: The decorator to assign the function as an action
"""
self.reads = reads
self.writes = writes
self.tags = tags
self.on_error = on_error

def __call__(self, fn) -> FunctionRepresentingAction:
setattr(
fn,
FunctionBasedAction.ACTION_FUNCTION,
FunctionBasedAction(fn, self.reads, self.writes, tags=self.tags),
FunctionBasedAction(
fn, self.reads, self.writes, tags=self.tags, on_error=self.on_error
),
)
setattr(fn, "bind", types.MethodType(bind, fn))
return fn
Expand Down Expand Up @@ -1537,7 +1632,12 @@ def pydantic(
tags=tags,
)

def __init__(self, reads: List[str], writes: List[str], tags: Optional[List[str]] = None):
def __init__(
self,
reads: List[str],
writes: List[str],
tags: Optional[List[str]] = None,
):
"""Decorator to create a streaming function-based action. This is user-facing.

If parameters are not bound, they will be interpreted as inputs and must be passed in at runtime.
Expand Down
111 changes: 92 additions & 19 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,7 @@ def __init__(
parallel_executor_factory: Optional[Executor] = None,
state_persister: Union[BaseStateSaver, LifecycleAdapter, None] = None,
state_initializer: Union[BaseStateLoader, LifecycleAdapter, None] = None,
global_error_handler: Optional[Callable[[State, Exception], State]] = None,
):
"""Instantiates an Application. This is an internal API -- use the builder!

Expand Down Expand Up @@ -900,6 +901,7 @@ def __init__(
self._spawning_parent_pointer = spawning_parent_pointer
self._state_initializer = state_initializer
self._state_persister = state_persister
self._global_error_handler = global_error_handler
self._adapter_set.call_all_lifecycle_hooks_sync(
"post_application_create",
state=self._state,
Expand Down Expand Up @@ -985,9 +987,29 @@ def _step(
new_state = self._update_internal_state_value(new_state, next_action)
self._set_state(new_state)
except Exception as e:
exc = e
logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs))
raise e
# Resolve the effective error handler: per-action takes precedence
# over the builder-level global handler.
handler = next_action.on_error or self._global_error_handler
if handler is not None:
# The handler writes the captured record via state.update, bypassing
# the normal reducer -- so it is NOT subject to writes-validation.
# The handler itself must never mask the original exception: if it
# raises, surface the original exception with the handler failure as
# its cause rather than silently swapping which error propagates.
try:
new_state = handler(self._state, e)
except Exception as handler_exc:
logger.exception(f"Error handler for {next_action.name} failed")
raise e from handler_exc
new_state = self._update_internal_state_value(new_state, next_action)
self._set_state(new_state)
result = None
exc = e # keep for post_run_step hook observability
# DO NOT re-raise; flow continues. get_next_node routes via the captured state field.
else:
exc = e
logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs))
raise e
finally:
if _run_hooks:
self._adapter_set.call_all_lifecycle_hooks_sync(
Expand Down Expand Up @@ -1094,12 +1116,19 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True
return None
if inputs is None:
inputs = {}
# Process inputs before the pre_run_step hook so that hooks in the async path
# observe the same processed inputs (dependency factory injection + validation
# applied) as the sync path does. We process once here and run the action inline
# (rather than delegating to _step) so the async lifecycle dispatcher still fires
# and the dependency factory -- which constructs fresh tracer/context objects per
# call -- is not invoked twice for a single step.
action_inputs = self._process_inputs(inputs, next_action)
if _run_hooks:
await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
"pre_run_step",
action=next_action,
state=self._state,
inputs=inputs,
inputs=action_inputs,
sequence_id=self.sequence_id,
app_id=self._uid,
partition_key=self._partition_key,
Expand All @@ -1109,18 +1138,20 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True
new_state = self._state
try:
if not next_action.is_async():
# we can just delegate to the synchronous version, it will block the event loop,
# but that's safer than assuming its OK to launch a thread
# TODO -- add an option/configuration to launch a thread (yikes, not super safe, but for a pure function
# which this is supposed to be its OK).
# this delegates hooks to the synchronous version, so we'll call all of them as well
# In this case we allow the self._step to do input processing
return self._step(
inputs=inputs, _run_hooks=False
) # Skip hooks as we already ran all of them/will run all of them in this function's finally
# In this case we want to process inputs because we run the function directly
action_inputs = self._process_inputs(inputs, next_action)
if next_action.single_step:
# The action is synchronous -- running it here will block the event loop,
# but that's safer than assuming it's OK to launch a thread for what is
# supposed to be a pure function. We still drive it through this async
# method so the async lifecycle hooks fire and inputs are processed once.
if next_action.single_step:
result, new_state = _run_single_step_action(
next_action, self._state, action_inputs
)
else:
result = _run_function(
next_action, self._state, action_inputs, name=next_action.name
)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
elif next_action.single_step:
result, new_state = await _arun_single_step_action(
next_action, self._state, inputs=action_inputs
)
Expand All @@ -1135,9 +1166,29 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True
new_state = self._update_internal_state_value(new_state, next_action)
self._set_state(new_state)
except Exception as e:
exc = e
logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs))
raise e
# Resolve the effective error handler: per-action takes precedence
# over the builder-level global handler.
handler = next_action.on_error or self._global_error_handler
if handler is not None:
# The handler writes the captured record via state.update, bypassing
# the normal reducer -- so it is NOT subject to writes-validation.
# The handler itself must never mask the original exception: if it
# raises, surface the original exception with the handler failure as
# its cause rather than silently swapping which error propagates.
try:
new_state = handler(self._state, e)
except Exception as handler_exc:
logger.exception(f"Error handler for {next_action.name} failed")
raise e from handler_exc
new_state = self._update_internal_state_value(new_state, next_action)
self._set_state(new_state)
result = None
exc = e # keep for post_run_step hook observability
# DO NOT re-raise; flow continues. get_next_node routes via the captured state field.
else:
exc = e
logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs))
raise e
finally:
if _run_hooks:
await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
Expand Down Expand Up @@ -2222,6 +2273,7 @@ def __init__(self):
self.parallel_executor_factory = None
self.state_persister = None
self._is_async: bool = False
self._error_handler: Optional[Callable[[State, Exception], State]] = None

def with_identifiers(
self, app_id: str = None, partition_key: str = None, sequence_id: int = None
Expand Down Expand Up @@ -2419,6 +2471,26 @@ def with_hooks(self, *adapters: LifecycleAdapter) -> "ApplicationBuilder[StateTy
self.lifecycle_adapters.extend(adapters)
return self

def with_error_handling(
self, handler: Callable[[State, Exception], State]
) -> "ApplicationBuilder[StateType]":
"""Sets a global error handler for the application. This is any callable
``(State, Exception) -> State``. If an action raises and does not have its own
per-action ``on_error`` handler, this global handler is invoked: the exception
is suppressed and the returned state is used instead. Per-action ``on_error``
takes precedence over this global handler.

See :py:func:`capture_as <burr.core.action.capture_as>` for a built-in handler
that records a JSON-serializable summary of the exception into a state field,
which you can then route on via a wildcard transition, e.g.
``("*", "handler", expr("error is not None"))``.

:param handler: Error handler callable ``(State, Exception) -> State``
:return: The application builder for future chaining.
"""
self._error_handler = handler
return self

def with_tracker(
self,
tracker: Union[Literal["local"], "TrackingClient"] = "local",
Expand Down Expand Up @@ -2782,6 +2854,7 @@ def _build_common(self) -> Application:
parallel_executor_factory=self.parallel_executor_factory,
state_persister=self.state_persister,
state_initializer=self.state_initializer,
global_error_handler=self._error_handler,
)

def build(self) -> Application[StateType]:
Expand Down
Loading
Loading