Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 77 additions & 50 deletions sentry_sdk/integrations/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
MessageStreamManager,
MessageStream,
AsyncMessageStreamManager,
AsyncMessageStream,
)

from anthropic.types import (
Expand All @@ -60,7 +61,15 @@
raise DidNotEnable("Anthropic not installed")

if TYPE_CHECKING:
from typing import Any, AsyncIterator, Iterator, Optional, Union, Callable
from typing import (
Any,
AsyncIterator,
Iterator,
Optional,
Union,
Callable,
Awaitable,
)
from sentry_sdk.tracing import Span
from sentry_sdk._types import TextPart

Expand All @@ -71,7 +80,6 @@
TextBlockParam,
ToolUnionParam,
)
from anthropic.lib.streaming import AsyncMessageStream


class _RecordedUsage:
Expand All @@ -95,6 +103,7 @@ def setup_once() -> None:

"""
client.messages.create(stream=True) can return an instance of the Stream class, which implements the iterator protocol.
Analogously, the function can return an AsyncStream, which implements the asynchronous iterator protocol.
The private _iterator variable and the close() method are patched. During iteration over the _iterator generator,
information from intercepted events is accumulated and used to populate output attributes on the AI Client Span.

Expand All @@ -108,9 +117,11 @@ def setup_once() -> None:
Stream.close = _wrap_close(Stream.close)

AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create)
AsyncStream.close = _wrap_async_close(AsyncStream.close)

"""
client.messages.stream() can return an instance of the MessageStream class, which implements the iterator protocol.
Analogously, the function can return an AsyncMessageStream, which implements the asynchronous iterator protocol.
The private _iterator variable and the close() method are patched. During iteration over the _iterator generator,
information from intercepted events is accumulated and used to populate output attributes on the AI Client Span.

Expand All @@ -137,6 +148,11 @@ def setup_once() -> None:
)
)

# Before https://github.com/anthropics/anthropic-sdk-python/commit/b1a1c0354a9aca450a7d512fdbdeb59c0ead688a
# AsyncMessageStream inherits from AsyncStream, so patching Stream is sufficient on these versions.
if not issubclass(AsyncMessageStream, AsyncStream):
AsyncMessageStream.close = _wrap_async_close(AsyncMessageStream.close)


def _capture_exception(exc: "Any") -> None:
set_span_errored()
Expand Down Expand Up @@ -475,20 +491,13 @@ def _wrap_synchronous_message_iterator(


async def _wrap_asynchronous_message_iterator(
stream: "Union[Stream, MessageStream]",
iterator: "AsyncIterator[Union[RawMessageStreamEvent, MessageStreamEvent]]",
span: "Span",
integration: "AnthropicIntegration",
) -> "AsyncIterator[Union[RawMessageStreamEvent, MessageStreamEvent]]":
"""
Sets information received while iterating the response stream on the AI Client Span.
Responsible for closing the AI Client Span.
Responsible for closing the AI Client Span, unless the span has already been closed in the close() patch.
"""
model = None
usage = _RecordedUsage()
content_blocks: "list[str]" = []
response_id = None
finish_reason = None

try:
async for event in iterator:
# Message and content types are aliases for corresponding Raw* types, introduced in
Expand All @@ -507,44 +516,21 @@ async def _wrap_asynchronous_message_iterator(
yield event
continue

(
model,
usage,
content_blocks,
response_id,
finish_reason,
) = _collect_ai_data(
event,
model,
usage,
content_blocks,
response_id,
finish_reason,
)
_accumulate_event_data(stream, event)
yield event
finally:
with capture_internal_exceptions():
# Anthropic's input_tokens excludes cached/cache_write tokens.
# Normalize to total input tokens for correct cost calculations.
total_input = (
usage.input_tokens
+ (usage.cache_read_input_tokens or 0)
+ (usage.cache_write_input_tokens or 0)
)

_set_output_data(
span=span,
integration=integration,
model=model,
input_tokens=total_input,
output_tokens=usage.output_tokens,
cache_read_input_tokens=usage.cache_read_input_tokens,
cache_write_input_tokens=usage.cache_write_input_tokens,
content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
finish_span=True,
response_id=response_id,
finish_reason=finish_reason,
)
if hasattr(stream, "_span"):
_finish_streaming_span(
span=stream._span,
integration=stream._integration,
model=stream._model,
usage=stream._usage,
content_blocks=stream._content_blocks,
response_id=stream._response_id,
finish_reason=stream._finish_reason,
)
del stream._span


def _set_output_data(
Expand Down Expand Up @@ -643,9 +629,15 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
return result

if isinstance(result, AsyncStream):
result._span = span
result._integration = integration

_initialize_data_accumulation_state(result)
result._iterator = _wrap_asynchronous_message_iterator(
result._iterator, span, integration
result,
result._iterator,
)

return result

with capture_internal_exceptions():
Expand Down Expand Up @@ -864,6 +856,38 @@ async def _sentry_patched_create_async(*args: "Any", **kwargs: "Any") -> "Any":
return _sentry_patched_create_async


def _wrap_async_close(
    f: "Callable[..., Awaitable[None]]",
) -> "Callable[..., Awaitable[None]]":
    """
    Wraps `AsyncStream.close()` so the AI Client Span attached to the stream is
    finalized at close time, unless the `finally` block in
    `_wrap_asynchronous_message_iterator()` already finished it (in which case
    the `_span` attribute has been removed and this wrapper is a no-op).
    """

    async def close(self: "AsyncStream") -> None:
        # A present `_span` attribute means the span is still open and it is
        # our job to finish it before delegating to the real close().
        if hasattr(self, "_span"):
            if hasattr(self, "_model"):
                # Accumulation state exists: the stream was iterated, so emit
                # the collected output data while finishing the span.
                _finish_streaming_span(
                    span=self._span,
                    integration=self._integration,
                    model=self._model,
                    usage=self._usage,
                    content_blocks=self._content_blocks,
                    response_id=self._response_id,
                    finish_reason=self._finish_reason,
                )
            else:
                # The stream was closed before any iteration happened: there is
                # no output data to report, so just exit the span cleanly.
                self._span.__exit__(None, None, None)

            # Mark the span as handled so the iterator's finally block (or a
            # second close()) does not try to finish it again.
            del self._span

        return await f(self)

    return close


def _wrap_message_stream(f: "Any") -> "Any":
"""
Attaches user-provided arguments to the returned context manager.
Expand Down Expand Up @@ -1020,10 +1044,13 @@ async def _sentry_patched_aenter(
tools=self._tools,
)

stream._span = span
stream._integration = integration

_initialize_data_accumulation_state(stream)
stream._iterator = _wrap_asynchronous_message_iterator(
iterator=stream._iterator,
span=span,
integration=integration,
stream,
stream._iterator,
)

return stream
Expand Down
Loading
Loading