Skip to content
Open
5 changes: 4 additions & 1 deletion src/maxtext/input_pipeline/data_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from maxtext.input_pipeline import tokenizer


TOOLS_COLUMN = "tools"


def parse_and_keep_features(dataset, config, data_columns, tokenize):
"""Parse arrayrecord features or keep specified columns for other formats."""
if config.grain_file_type in ("arrayrecord", "tfrecord"):
Expand Down Expand Up @@ -57,7 +60,7 @@ def validate_and_configure_sft_columns(data_columns, tokenizer_model, chat_templ
if chat_template and hasattr(tokenizer_model, "chat_template"):
tokenizer_model.chat_template = chat_template

supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]]
supported_columns = [["prompt", "completion"], ["messages"], ["messages", TOOLS_COLUMN], ["question", "answer"]]
assert any(
set(data_columns) == set(supported) for supported in supported_columns
), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}"
Expand Down
25 changes: 17 additions & 8 deletions src/maxtext/input_pipeline/grain_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,26 +268,33 @@ def dpo_preprocessing_pipeline(

def _format_chat_template_grain(element, data_columns, tokenizer_model):
"""Grain-compatible mapping function to format raw columns into conversational messages."""
tools_column_name = data_processing_utils.TOOLS_COLUMN if data_processing_utils.TOOLS_COLUMN in data_columns else None
primary_columns = [c for c in data_columns if c != data_processing_utils.TOOLS_COLUMN]

# Convert raw columns to conversational messages
if "messages" in data_columns:
if "messages" in primary_columns:
messages = element["messages"]
elif set(data_columns) == {"prompt", "completion"}:
if isinstance(messages, (str, bytes)):
messages = json.loads(messages)
elif set(primary_columns) == {"prompt", "completion"}:
messages = [{"role": "user", "content": element["prompt"]}, {"role": "assistant", "content": element["completion"]}]
elif set(data_columns) == {"question", "answer"}:
elif set(primary_columns) == {"question", "answer"}:
messages = [{"role": "user", "content": element["question"]}, {"role": "assistant", "content": element["answer"]}]
else:
# Fallback if it's already a single string
messages = element[data_columns[0]]
messages = element[primary_columns[0]]

assert all(
hasattr(m, "__contains__") and "role" in m and "content" in m for m in messages
), f"SFT requires a conversational format. Expected dicts with 'role' and 'content', but got: {messages}"

# Assign the standardized messages back to the primary column
element[data_columns[0]] = messages
element[primary_columns[0]] = messages

return input_pipeline_utils.apply_chat_template(
element, tokenizer_model=tokenizer_model, data_column_name=data_columns[0]
element, tokenizer_model=tokenizer_model,
data_column_name=primary_columns[0],
tools_column_name=tools_column_name,
)


Expand Down Expand Up @@ -318,6 +325,8 @@ def sft_preprocessing_pipeline(
data_columns, tokenizer_model, getattr(config, "chat_template", None)
)

primary_columns = [c for c in data_columns if c != data_processing_utils.TOOLS_COLUMN]

dataset = dataset.map(
functools.partial(_format_chat_template_grain, data_columns=data_columns, tokenizer_model=tokenizer_model)
)
Expand All @@ -326,14 +335,14 @@ def sft_preprocessing_pipeline(
dataset = dataset.map(
functools.partial(
_tokenize_sft_chunks,
text_column_name=data_columns[0],
text_column_name=primary_columns[0],
tokenizer_model=tokenizer_model,
)
)

dataset = dataset.map(
input_pipeline_utils.SFTPromptMasking(
text_column_name=data_columns[0],
text_column_name=primary_columns[0],
completion_only=config.sft_train_on_completion_only,
max_target_length=config.max_target_length,
unk_id=pad_id,
Expand Down
30 changes: 25 additions & 5 deletions src/maxtext/input_pipeline/hf_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Input pipeline using Huggingface datasets."""

import json
from typing import Optional

import ml_collections
Expand Down Expand Up @@ -251,14 +252,27 @@ def preprocessing_pipeline(
if use_sft:
data_processing_utils.validate_and_configure_sft_columns(data_column_names, tokenizer, chat_template)

# Separate auxiliary "tools" column from primary data columns
tools_column_name = data_processing_utils.TOOLS_COLUMN if data_processing_utils.TOOLS_COLUMN in data_column_names else None
if tools_column_name:
data_column_names = [c for c in data_column_names if c != tools_column_name]

# convert instruction dataset to conversational format
# currently only works for Q&A datasets
dataset, data_column_names = instruction_data_processing.convert_to_conversational_format(
dataset=dataset, data_columns=data_column_names, chat_template_path=chat_template_path
)
assert input_pipeline_utils.is_conversational(
dataset.features, data_column_names
), "Dataset is not in conversational format."
# Deserialize JSON string columns if needed
_json_deserialized = False
for col in data_column_names:
if isinstance(dataset.features.get(col), datasets.Value) and dataset.features[col].dtype == "string":
dataset = dataset.map(lambda x, c=col: {c: json.loads(x[c]) if isinstance(x[c], str) else x[c]})
_json_deserialized = True

if not _json_deserialized:
assert input_pipeline_utils.is_conversational(
dataset.features, data_column_names
), "Dataset is not in conversational format."

if len(data_column_names) > 1:
combined_column_name = "messages"
Expand All @@ -271,12 +285,18 @@ def preprocessing_pipeline(
remove_columns=data_column_names,
features=dataset_features,
)
data_column_names = [combined_column_name]

data_column_names = list(dataset.features.keys())
dataset = dataset.map(
input_pipeline_utils.apply_chat_template,
fn_kwargs={"tokenizer_model": tokenizer, "data_column_name": data_column_names[0]},
fn_kwargs={
"tokenizer_model": tokenizer,
"data_column_name": data_column_names[0],
"tools_column_name": tools_column_name,
},
)
if tools_column_name:
dataset = dataset.remove_columns([tools_column_name])

pad_id = _get_pad_id(tokenizer)

Expand Down
42 changes: 29 additions & 13 deletions src/maxtext/input_pipeline/input_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Operations used by Grain"""

import dataclasses
import json
import warnings
from threading import current_thread
from typing import Any, Iterable, TYPE_CHECKING
Expand Down Expand Up @@ -240,17 +241,17 @@ def verify_chat_template_generation_prompt_logic(tokenizer_model):
dummy_msgs = [{"role": "system", "content": "System message"}, {"role": "user", "content": "Test message"}]

try:
prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True, enable_thinking=True)
except TemplateError:
max_logging.info(
"Tokenizer failed to apply chat template with 'system' role. "
"Falling back to 'user' role only for chat template verification."
)
dummy_msgs.pop(0)
prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True, enable_thinking=True)
prompt_wo_gen_ids = _extract_token_ids(prompt_wo_gen_tokens)

prompt_w_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=True, tokenize=True)
prompt_w_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=True, tokenize=True, enable_thinking=True)
prompt_w_gen_ids = _extract_token_ids(prompt_w_gen_tokens)

if prompt_w_gen_ids[: len(prompt_wo_gen_ids)] != prompt_wo_gen_ids:
Expand All @@ -259,7 +260,7 @@ def verify_chat_template_generation_prompt_logic(tokenizer_model):
assistant_prefix = prompt_w_gen_ids[len(prompt_wo_gen_ids) :]
full_turn_tokens = _extract_token_ids(
tokenizer_model.apply_chat_template(
dummy_msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=True
dummy_msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=True, enable_thinking=True
)
)
full_turn_ids = _extract_token_ids(full_turn_tokens)
Expand All @@ -277,7 +278,7 @@ def verify_chat_template_generation_prompt_logic(tokenizer_model):
)


def _get_completion_in_chat_template(tokenizer_model, round_msgs):
def _get_completion_in_chat_template(tokenizer_model, round_msgs, tools=None):
"""
Calculates the completion part of a conversation turn when formatted with a chat template.

Expand All @@ -291,9 +292,10 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs):
Returns:
A string representing the completion formatted by the chat template.
"""
prompt_completion_tokens = tokenizer_model.apply_chat_template(round_msgs, add_generation_prompt=False, tokenize=True)
tools_kwargs = {"tools": tools} if tools is not None else {}
prompt_completion_tokens = tokenizer_model.apply_chat_template(round_msgs, add_generation_prompt=False, tokenize=True, enable_thinking=True, **tools_kwargs)
# include generation_prompt as part of the prompt tokens
prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True)
prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True, enable_thinking=True, **tools_kwargs)

prompt_completion_ids = _extract_token_ids(prompt_completion_tokens)
prompt_ids = _extract_token_ids(prompt_tokens)
Expand All @@ -303,7 +305,7 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs):
return completion_in_chat_template


def apply_chat_template(example, tokenizer_model, data_column_name):
def apply_chat_template(example, tokenizer_model, data_column_name, tools_column_name=None):
"""Formats conversational data by applying the tokenizer's chat template
and identifying prompt/completion segments for SFT masking.

Expand All @@ -327,25 +329,39 @@ def apply_chat_template(example, tokenizer_model, data_column_name):
messages = []
is_prompt = []
round_msgs = []
conversation = example[data_column_name]
if isinstance(conversation, str):
conversation = json.loads(conversation)
tools = example.get(tools_column_name) if tools_column_name else None
if isinstance(tools, str):
tools = json.loads(tools)
tools_kwargs = {"tools": tools} if tools is not None else {}
try:
for idx, message in enumerate(example[data_column_name]):
for idx, message in enumerate(conversation):
if message["role"] == "system":
if idx != 0:
raise ValueError(f"System message found at index {idx}. System messages must be at index 0.")
round_msgs.append(message)
elif message["role"] == "user":
round_msgs.append(message)
prompt_in_chat_template = tokenizer_model.apply_chat_template(
round_msgs, add_generation_prompt=True, tokenize=False
round_msgs, add_generation_prompt=True, tokenize=False, enable_thinking=True, **tools_kwargs
)
messages.append(prompt_in_chat_template)
is_prompt.append(True)
elif message["role"] == "tool":
round_msgs.append(message)
elif message["role"] == "assistant":
if not round_msgs:
raise ValueError(f"Assistant message at index {idx} with no preceding context.")
round_msgs.append(message)
messages.append(_get_completion_in_chat_template(tokenizer_model, round_msgs))
messages.append(_get_completion_in_chat_template(tokenizer_model, round_msgs, tools=tools))
is_prompt.append(False)
# Round ended, clearing the buffer.
round_msgs.clear()
# Clear round only when the next message starts a new user turn or conversation ends
# This preserves context for consecutive assistant/tool messages
next_idx = idx + 1
if next_idx >= len(conversation) or conversation[next_idx]["role"] == "user":
round_msgs.clear()
except ValueError as e:
max_logging.log(f"Unable to apply chat template: {e}")
raise e
Expand Down