Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion agentlightning/adapter/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import json
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, List, Optional, Sequence, TypedDict, Union, cast

Expand All @@ -12,6 +13,8 @@

from .base import TraceAdapter

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from openai.types.chat import (
ChatCompletionFunctionToolParam,
Expand Down Expand Up @@ -103,6 +106,30 @@ def group_genai_dict(data: Dict[str, Any], prefix: str) -> Union[Dict[str, Any],
return result


def _infer_missing_role(msg: Dict[str, Any]) -> Optional[str]:
"""Infer an OpenAI chat role for a prompt entry whose ``role`` field is missing.

Some tracers (notably AgentOps when re-emitting prior turns inside ``gen_ai.prompt.N.*``)
serialize the nested ``tool_calls`` / ``tool_call_id`` subtree without the sibling ``role``
key. The role is still unambiguous in those cases:

- A message carrying ``tool_calls`` can only have come from the assistant.
- A message carrying ``tool_call_id`` can only have come from a tool response.

Args:
msg: A prompt entry parsed from ``gen_ai.prompt.*`` attributes.

Returns:
``"assistant"`` or ``"tool"`` when the role can be inferred unambiguously,
otherwise ``None``.
"""
if "tool_calls" in msg:
return "assistant"
if "tool_call_id" in msg:
return "tool"
return None


def convert_to_openai_messages(prompt_completion_list: List[_RawSpanInfo]) -> Generator[OpenAIMessages, None, None]:
"""Convert raw trace payloads into OpenAI-style chat messages.

Expand Down Expand Up @@ -131,7 +158,19 @@ def convert_to_openai_messages(prompt_completion_list: List[_RawSpanInfo]) -> Ge

# Extract messages
for msg in pc_entry["prompt"]:
role = msg["role"]
role = msg.get("role")
if role is None:
# Some tracers omit ``role`` on re-emitted assistant/tool turns; recover when the
# role is unambiguous, otherwise drop just this message instead of crashing the
# whole rollout's adapter pass.
role = _infer_missing_role(msg)
if role is None:
logger.warning(
"Skipping prompt message with no 'role' and no inferable role hint: %r",
msg,
)
continue
logger.debug("Inferred missing role %r for prompt message: %r", role, msg)

if role == "assistant" and "tool_calls" in msg:
# Use the tool_calls directly
Expand Down
112 changes: 112 additions & 0 deletions tests/adapter/test_messages_adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import json
import logging
from importlib.metadata import version
from typing import Any, Dict, Optional

Expand Down Expand Up @@ -391,3 +392,114 @@ def test_trace_messages_adapter_handles_multiple_tool_calls():
]

assert adapter.adapt(spans) == expected


# Regression coverage for https://github.com/microsoft/agent-lightning/issues/425 and
# https://github.com/microsoft/agent-lightning/issues/311. Some tracers (notably AgentOps when
# re-emitting prior turns inside ``gen_ai.prompt.N.*``) drop the ``role`` key on assistant
# tool-call entries and tool-response entries. The adapter must recover instead of crashing
# the entire rollout adapter pass.
@pytest.mark.skipif(
    _skip_for_openai_lt_1_100_0,
    reason="Requires openai>=1.100.0",
)
def test_trace_messages_adapter_recovers_assistant_role_from_tool_calls() -> None:
    """Assistant/tool prompt entries missing ``role`` are recovered from their marker keys."""
    call_id = "call_AnIgZ6EdncSTDeJMn3rcKSDM"
    fn_name = "get_rmc_percentage_of_sales"
    fn_args = json.dumps({"category": "Scooter"})
    tool_result = json.dumps({"value": 0.42})

    attributes = {
        "gen_ai.prompt.0.role": "system",
        "gen_ai.prompt.0.content": "You are a financial data analyst.",
        "gen_ai.prompt.1.role": "user",
        "gen_ai.prompt.1.content": "How does Scooter compare?",
        # Re-emitted assistant tool-call entry: ``tool_calls.*`` present, ``role`` missing.
        "gen_ai.prompt.2.tool_calls.0.id": call_id,
        "gen_ai.prompt.2.tool_calls.0.name": fn_name,
        "gen_ai.prompt.2.tool_calls.0.arguments": fn_args,
        # Re-emitted tool response: ``tool_call_id`` + ``content`` present, ``role`` missing.
        "gen_ai.prompt.3.tool_call_id": call_id,
        "gen_ai.prompt.3.content": tool_result,
        "gen_ai.completion.0.role": "assistant",
        "gen_ai.completion.0.content": "Scooter sales are up 4%.",
        "gen_ai.completion.0.finish_reason": "stop",
    }
    spans = [make_span("openai.chat.completion", attributes, 0)]

    expected_messages = [
        {"content": "You are a financial data analyst.", "role": "system"},
        {"content": "How does Scooter compare?", "role": "user"},
        {
            "content": None,
            "role": "assistant",
            "tool_calls": [
                {
                    "id": call_id,
                    "type": "function",
                    "function": {"name": fn_name, "arguments": fn_args},
                }
            ],
        },
        {"content": tool_result, "role": "tool", "tool_call_id": call_id},
        {"content": "Scooter sales are up 4%.", "role": "assistant"},
    ]

    assert TraceToMessages().adapt(spans) == [{"messages": expected_messages, "tools": None}]


@pytest.mark.skipif(
    _skip_for_openai_lt_1_100_0,
    reason="Requires openai>=1.100.0",
)
def test_trace_messages_adapter_skips_unidentifiable_prompt_entry(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A prompt entry with neither ``role`` nor a role hint is dropped with a warning."""
    attributes = {
        "gen_ai.prompt.0.role": "system",
        "gen_ai.prompt.0.content": "You are helpful.",
        "gen_ai.prompt.1.role": "user",
        "gen_ai.prompt.1.content": "Hi.",
        # Garbage entry with neither ``role`` nor any role hint; must not crash.
        "gen_ai.prompt.2.unexpected": "noise",
        "gen_ai.completion.0.role": "assistant",
        "gen_ai.completion.0.content": "Hello.",
        "gen_ai.completion.0.finish_reason": "stop",
    }
    spans = [make_span("openai.chat.completion", attributes, 0)]

    with caplog.at_level(logging.WARNING, logger="agentlightning.adapter.messages"):
        result = TraceToMessages().adapt(spans)

    expected_messages = [
        {"content": "You are helpful.", "role": "system"},
        {"content": "Hi.", "role": "user"},
        {"content": "Hello.", "role": "assistant"},
    ]
    assert result == [{"messages": expected_messages, "tools": None}]

    warning_texts = [record.getMessage() for record in caplog.records]
    assert any(
        "no inferable role hint" in text for text in warning_texts
    ), "Expected a warning naming the malformed prompt entry"