diff --git a/sentry_sdk/integrations/anthropic.py b/sentry_sdk/integrations/anthropic.py index 0592e337de..1dcef160de 100644 --- a/sentry_sdk/integrations/anthropic.py +++ b/sentry_sdk/integrations/anthropic.py @@ -43,6 +43,7 @@ MessageStreamManager, MessageStream, AsyncMessageStreamManager, + AsyncMessageStream, ) from anthropic.types import ( @@ -60,7 +61,15 @@ raise DidNotEnable("Anthropic not installed") if TYPE_CHECKING: - from typing import Any, AsyncIterator, Iterator, Optional, Union, Callable + from typing import ( + Any, + AsyncIterator, + Iterator, + Optional, + Union, + Callable, + Awaitable, + ) from sentry_sdk.tracing import Span from sentry_sdk._types import TextPart @@ -71,7 +80,6 @@ TextBlockParam, ToolUnionParam, ) - from anthropic.lib.streaming import AsyncMessageStream class _RecordedUsage: @@ -95,6 +103,7 @@ def setup_once() -> None: """ client.messages.create(stream=True) can return an instance of the Stream class, which implements the iterator protocol. + Analogously, the function can return an AsyncStream, which implements the asynchronous iterator protocol. The private _iterator variable and the close() method are patched. During iteration over the _iterator generator, information from intercepted events is accumulated and used to populate output attributes on the AI Client Span. @@ -108,9 +117,11 @@ def setup_once() -> None: Stream.close = _wrap_close(Stream.close) AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create) + AsyncStream.close = _wrap_async_close(AsyncStream.close) """ client.messages.stream() can return an instance of the MessageStream class, which implements the iterator protocol. + Analogously, the function can return an AsyncMessageStream, which implements the asynchronous iterator protocol. The private _iterator variable and the close() method are patched. During iteration over the _iterator generator, information from intercepted events is accumulated and used to populate output attributes on the AI Client Span. @@ -137,6 +148,11 @@ def setup_once() -> None: ) ) + # Before https://github.com/anthropics/anthropic-sdk-python/commit/b1a1c0354a9aca450a7d512fdbdeb59c0ead688a + # AsyncMessageStream inherits from AsyncStream, so patching Stream is sufficient on these versions. + if not issubclass(AsyncMessageStream, AsyncStream): + AsyncMessageStream.close = _wrap_async_close(AsyncMessageStream.close) + def _capture_exception(exc: "Any") -> None: set_span_errored() @@ -475,20 +491,13 @@ def _wrap_synchronous_message_iterator( async def _wrap_asynchronous_message_iterator( + stream: "Union[Stream, MessageStream]", iterator: "AsyncIterator[Union[RawMessageStreamEvent, MessageStreamEvent]]", - span: "Span", - integration: "AnthropicIntegration", ) -> "AsyncIterator[Union[RawMessageStreamEvent, MessageStreamEvent]]": """ Sets information received while iterating the response stream on the AI Client Span. - Responsible for closing the AI Client Span. + Responsible for closing the AI Client Span, unless the span has already been closed in the close() patch. """ - model = None - usage = _RecordedUsage() - content_blocks: "list[str]" = [] - response_id = None - finish_reason = None - try: async for event in iterator: # Message and content types are aliases for corresponding Raw* types, introduced in @@ -507,44 +516,21 @@ async def _wrap_asynchronous_message_iterator( yield event continue - ( - model, - usage, - content_blocks, - response_id, - finish_reason, - ) = _collect_ai_data( - event, - model, - usage, - content_blocks, - response_id, - finish_reason, - ) + _accumulate_event_data(stream, event) yield event finally: with capture_internal_exceptions(): - # Anthropic's input_tokens excludes cached/cache_write tokens. - # Normalize to total input tokens for correct cost calculations. - total_input = ( - usage.input_tokens - + (usage.cache_read_input_tokens or 0) - + (usage.cache_write_input_tokens or 0) - ) - - _set_output_data( - span=span, - integration=integration, - model=model, - input_tokens=total_input, - output_tokens=usage.output_tokens, - cache_read_input_tokens=usage.cache_read_input_tokens, - cache_write_input_tokens=usage.cache_write_input_tokens, - content_blocks=[{"text": "".join(content_blocks), "type": "text"}], - finish_span=True, - response_id=response_id, - finish_reason=finish_reason, - ) + if hasattr(stream, "_span"): + _finish_streaming_span( + span=stream._span, + integration=stream._integration, + model=stream._model, + usage=stream._usage, + content_blocks=stream._content_blocks, + response_id=stream._response_id, + finish_reason=stream._finish_reason, + ) + del stream._span def _set_output_data( @@ -643,9 +629,15 @@ def _sentry_patched_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A return result if isinstance(result, AsyncStream): + result._span = span + result._integration = integration + + _initialize_data_accumulation_state(result) result._iterator = _wrap_asynchronous_message_iterator( - result._iterator, span, integration + result, + result._iterator, ) + return result with capture_internal_exceptions(): @@ -864,6 +856,38 @@ async def _sentry_patched_create_async(*args: "Any", **kwargs: "Any") -> "Any": return _sentry_patched_create_async +def _wrap_async_close( + f: "Callable[..., Awaitable[None]]", +) -> "Callable[..., Awaitable[None]]": + """ + Closes the AI Client Span, unless the finally block in `_wrap_asynchronous_message_iterator()` runs first. + """ + + async def close(self: "AsyncStream") -> None: + if not hasattr(self, "_span"): + return await f(self) + + if not hasattr(self, "_model"): + self._span.__exit__(None, None, None) + del self._span + return await f(self) + + _finish_streaming_span( + span=self._span, + integration=self._integration, + model=self._model, + usage=self._usage, + content_blocks=self._content_blocks, + response_id=self._response_id, + finish_reason=self._finish_reason, + ) + del self._span + + return await f(self) + + return close + + def _wrap_message_stream(f: "Any") -> "Any": """ Attaches user-provided arguments to the returned context manager. @@ -1020,10 +1044,13 @@ async def _sentry_patched_aenter( tools=self._tools, ) + stream._span = span + stream._integration = integration + + _initialize_data_accumulation_state(stream) stream._iterator = _wrap_asynchronous_message_iterator( - iterator=stream._iterator, - span=span, - integration=integration, + stream, + stream._iterator, ) return stream diff --git a/tests/integrations/anthropic/test_anthropic.py b/tests/integrations/anthropic/test_anthropic.py index 2139d74a1b..e8555a52f1 100644 --- a/tests/integrations/anthropic/test_anthropic.py +++ b/tests/integrations/anthropic/test_anthropic.py @@ -779,6 +779,110 @@ async def test_streaming_create_message_async( assert span["data"][SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS] == ["max_tokens"] +@pytest.mark.asyncio +async def test_streaming_create_message_async_close( + sentry_init, + capture_events, + get_model_response, + async_iterator, + server_side_event_chunks, +): + client = AsyncAnthropic(api_key="z") + + response = get_model_response( + async_iterator( + server_side_event_chunks( + [ + MessageStartEvent( + message=EXAMPLE_MESSAGE, + type="message_start", + ), + ContentBlockStartEvent( + type="content_block_start", + index=0, + content_block=TextBlock(type="text", text=""), + ), + ContentBlockDeltaEvent( + delta=TextDelta(text="Hi", type="text_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockDeltaEvent( + delta=TextDelta(text="!", type="text_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockDeltaEvent( + delta=TextDelta(text=" I'm Claude!", type="text_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockStopEvent(type="content_block_stop", index=0), + MessageDeltaEvent( + delta=Delta(stop_reason="max_tokens"), + usage=MessageDeltaUsage(output_tokens=10), + type="message_delta", + ), + ] + ) + ) + ) + + sentry_init( + integrations=[AnthropicIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + messages = [ + { + "role": "user", + "content": "Hello, Claude", + } + ] + + with mock.patch.object( + client._client, + "send", + return_value=response, + ) as _: + with start_transaction(name="anthropic"): + messages = await client.messages.create( + max_tokens=1024, messages=messages, model="model", stream=True + ) + + for _ in range(4): + await messages.__anext__() + await messages.close() + + assert len(events) == 1 + (event,) = events + + assert event["type"] == "transaction" + assert event["transaction"] == "anthropic" + + span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT) + + assert span["op"] == OP.GEN_AI_CHAT + assert span["description"] == "chat model" + assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic" + assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat" + assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model" + + assert ( + span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES] + == '[{"role": "user", "content": "Hello, Claude"}]' + ) + assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi!" + + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20 + assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30 + assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True + assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL" + + @pytest.mark.asyncio @pytest.mark.parametrize( "send_default_pii, include_prompts", @@ -900,6 +1004,118 @@ async def test_stream_message_async( assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL" +@pytest.mark.asyncio +async def test_stream_messages_async_close( + sentry_init, + capture_events, + get_model_response, + async_iterator, + server_side_event_chunks, +): + client = AsyncAnthropic(api_key="z") + + response = get_model_response( + async_iterator( + server_side_event_chunks( + [ + MessageStartEvent( + message=EXAMPLE_MESSAGE, + type="message_start", + ), + ContentBlockStartEvent( + type="content_block_start", + index=0, + content_block=TextBlock(type="text", text=""), + ), + ContentBlockDeltaEvent( + delta=TextDelta(text="Hi", type="text_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockDeltaEvent( + delta=TextDelta(text="!", type="text_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockDeltaEvent( + delta=TextDelta(text=" I'm Claude!", type="text_delta"), + index=0, + type="content_block_delta", + ), + ContentBlockStopEvent(type="content_block_stop", index=0), + MessageDeltaEvent( + delta=Delta(stop_reason="max_tokens"), + usage=MessageDeltaUsage(output_tokens=10), + type="message_delta", + ), + ] + ) + ) + ) + + sentry_init( + integrations=[AnthropicIntegration(include_prompts=True)], + traces_sample_rate=1.0, + send_default_pii=True, + ) + events = capture_events() + + messages = [ + { + "role": "user", + "content": "Hello, Claude", + } + ] + + with mock.patch.object( + client._client, + "send", + return_value=response, + ) as _: + with start_transaction(name="anthropic"): + async with client.messages.stream( + max_tokens=1024, + messages=messages, + model="model", + ) as stream: + for _ in range(4): + await stream.__anext__() + + # New versions add TextEvent, so consume one more event. + if TextEvent is not None and isinstance( + await stream.__anext__(), TextEvent + ): + await stream.__anext__() + + await stream.close() + + assert len(events) == 1 + (event,) = events + + assert event["type"] == "transaction" + assert event["transaction"] == "anthropic" + + span = next(span for span in event["spans"] if span["op"] == OP.GEN_AI_CHAT) + + assert span["op"] == OP.GEN_AI_CHAT + assert span["description"] == "chat model" + assert span["data"][SPANDATA.GEN_AI_SYSTEM] == "anthropic" + assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "chat" + assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "model" + + assert ( + span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES] + == '[{"role": "user", "content": "Hello, Claude"}]' + ) + assert span["data"][SPANDATA.GEN_AI_RESPONSE_TEXT] == "Hi!" + + assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10 + assert span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 20 + assert span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 30 + assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True + assert span["data"][SPANDATA.GEN_AI_RESPONSE_ID] == "msg_01XFDUDYJgAACzvnptvVoYEL" + + @pytest.mark.skipif( ANTHROPIC_VERSION < (0, 27), reason="Versions <0.27.0 do not include InputJSONDelta, which was introduced in >=0.27.0 along with a new message delta type for tool calling.",