Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions src/art/preprocessing/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass, field
from functools import cached_property
from itertools import takewhile
import json
import math
import random
from typing import Any, Generator, cast
Expand Down Expand Up @@ -31,6 +32,40 @@ def _normalize_tools_for_chat_template(tools: Any) -> list[ChatTemplateTool] | N
return normalized_tools


def _normalize_tool_call_arguments_for_chat_template(
tokenizer: PreTrainedTokenizerBase,
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
chat_template = tokenizer.chat_template
assert isinstance(chat_template, str)
if "tool_call.arguments|items" not in chat_template:
return messages

normalized_messages: list[dict[str, Any]] = []
for message in messages:
tool_calls = message.get("tool_calls")
if tool_calls is None:
normalized_messages.append(message)
continue

assert isinstance(tool_calls, list)
normalized_tool_calls = []
for tool_call in tool_calls:
assert isinstance(tool_call, dict)
function = tool_call["function"]
assert isinstance(function, dict)
arguments_json = function["arguments"]
assert isinstance(arguments_json, str)
arguments = json.loads(arguments_json)
assert isinstance(arguments, dict)
normalized_tool_calls.append(
{**tool_call, "function": {**function, "arguments": arguments}}
)
normalized_messages.append({**message, "tool_calls": normalized_tool_calls})

return normalized_messages


@dataclass
class TokenizedResult:
advantage: float
Expand Down Expand Up @@ -223,20 +258,23 @@ def tokenize_trajectory(
if last_assistant_index == -1:
return None
messages_and_choices = history.messages_and_choices[: last_assistant_index + 1]
messages = get_messages(messages_and_choices)
messages = cast(list[dict[str, Any]], get_messages(messages_and_choices))
# Qwen3.5's chat template uses `tool_call.arguments|items`, so it needs a
# mapping here instead of the OpenAI JSON string.
messages = _normalize_tool_call_arguments_for_chat_template(tokenizer, messages)
tools = _normalize_tools_for_chat_template(history.tools)
chat = cast(
str,
tokenizer.apply_chat_template(
cast(list[dict], messages),
messages,
tools=tools,
continue_final_message=True,
tokenize=False,
),
)
original_token_ids = _apply_chat_template_token_ids(
tokenizer,
cast(list[dict[str, Any]], messages),
messages,
tools=tools,
continue_final_message=True,
)
Expand Down
95 changes: 95 additions & 0 deletions tests/unit/test_preprocessing_tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import types
from typing import cast

from openai.types.chat.chat_completion import Choice
import pytest
from transformers.tokenization_utils_base import BatchEncoding

Expand All @@ -15,6 +16,7 @@


class _FakeTokenizer:
chat_template = ""
vocab_size = 256
eos_token = "\x00"
eos_token_id = 0
Expand Down Expand Up @@ -60,6 +62,38 @@ def convert_tokens_to_ids(self, tokens):
return self.eos_token_id


class _Qwen3_5FakeTokenizer(_FakeTokenizer):
    """Fake tokenizer mimicking Qwen3.5's chat template.

    The template iterates tool-call arguments with ``|items``, so rendering
    requires every tool call's ``arguments`` to already be a mapping rather
    than the OpenAI JSON string. ``apply_chat_template`` enforces that
    precondition before delegating to the base fake.
    """

    chat_template = (
        "{% for args_name, args_value in tool_call.arguments|items %}{% endfor %}"
    )

    def apply_chat_template(
        self,
        messages,
        tools=None,
        tokenize=True,
        return_dict=None,
        **kwargs,
    ):
        del kwargs
        for message in messages:
            calls = message.get("tool_calls")
            if calls is None:
                # Messages without tool calls need no validation.
                continue
            assert isinstance(calls, list)
            for call in calls:
                assert isinstance(call, dict)
                fn = call["function"]
                assert isinstance(fn, dict)
                # The whole point of the fake: arguments must be a mapping.
                assert isinstance(fn["arguments"], dict)
        return super().apply_chat_template(
            messages,
            tools=tools,
            tokenize=tokenize,
            return_dict=return_dict,
        )


def test_tokenize_trajectory_accepts_batchencoding_chat_template_output() -> None:
tokenizer = _FakeTokenizer()
messages = cast(
Expand Down Expand Up @@ -143,3 +177,64 @@ def _labels_fn(batch):
[1] * len(expected_ids)
]
assert batch.num_trainable_tokens == len(expected_ids)


def test_tokenize_trajectory_normalizes_mapping_tool_arguments_for_chat_template() -> (
    None
):
    """JSON-string tool-call arguments are parsed into mappings before being
    handed to a Qwen3.5-style chat template that iterates them with ``|items``.
    """
    tokenizer = _Qwen3_5FakeTokenizer()
    logprobs_payload = {
        "content": [
            {
                "token": "token_id:65",
                "bytes": [65],
                "logprob": -0.1,
                "top_logprobs": [],
            }
        ],
        "refusal": None,
    }
    # Assistant turn carrying a tool call whose arguments are the OpenAI
    # JSON string form; the tokenizer's fake template requires a mapping.
    assistant_payload = {
        "content": "",
        "refusal": None,
        "role": "assistant",
        "annotations": None,
        "audio": None,
        "function_call": None,
        "tool_calls": [
            {
                "id": "call_1",
                "function": {
                    "arguments": '{"city": "San Francisco", "days": 3}',
                    "name": "lookup_weather",
                },
                "type": "function",
            }
        ],
    }
    assistant_choice = Choice.model_validate(
        {
            "finish_reason": "stop",
            "index": 0,
            "logprobs": logprobs_payload,
            "message": assistant_payload,
        }
    )
    conversation = cast(
        MessagesAndChoices,
        [{"role": "user", "content": "Weather?"}, assistant_choice],
    )
    history = History(messages_and_choices=conversation)
    trajectory = Trajectory(messages_and_choices=conversation, reward=1.0)

    result = tokenize_trajectory(
        tokenizer=tokenizer,  # type: ignore[arg-type]
        image_processor=None,
        history=history,
        advantage=1.0,
        allow_training_without_logprobs=False,
        trajectory=trajectory,
    )

    # The fake tokenizer asserts arguments arrive as dicts, so a non-None
    # result proves normalization happened.
    assert result is not None
Loading