Skip to content
1 change: 1 addition & 0 deletions src/core/prompts/mmlu_single_token_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def single_token_sys_prompt_with_thinking(subject: str | None = None):
return sys_msg



def single_token_sys_prompt_with_fallback_for_unknown_answers(subject: str | None = None):
if subject is not None:
sys_msg = f"The following are multiple choice questions about {subject}."
Expand Down
129 changes: 129 additions & 0 deletions src/core/training/branch_b_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the difference between src/experiments/distill/train_branches/train_cleaned_b_full20_prompt1.py and this script? Why do we need both?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

train_cleaned_b_full20_prompt1.py is the entry point

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. Could you move the main script to core/... then? And keep the entry point in experiments

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ок

Shared orchestration for Branch B reasoning training and evaluation.

1. Train with LoRATrainer using preprocessed data + MMLUReasoningResponseDataset
2. Evaluate all checkpoints post-training with MultiCheckpointEvaluator
"""

from pathlib import Path

from transformers import AutoTokenizer

from core.datasets.causal_dataset_adapter import CausalDatasetAdapter
from core.datasets.mmlu.mmlu_cot_response_dataset import MMLUCoTResponseDataset
from core.datasets.mmlu.mmlu_reasoning_response_dataset import MMLUReasoningResponseDataset
from core.datasets.qa_dataset import QADatasetConfig
from core.datasets.qa_dataset_adapter import QADatasetAdapter
from core.evaluation.multi_checkpoint_evaluator import (
GenerationConfig,
MultiCheckpointEvaluator,
MultiCheckpointEvaluatorConfig,
)
from core.training.lora_trainer import LoRATrainer, LoRATrainerConfig, LoRATrainingArgs, LoRASpecificTrainingArgs
from core.utils.logger import logger

MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
PROJECT_ROOT = Path(__file__).resolve().parents[4]


def run_branch_b_training(
    *,
    prompt_id: int = 1,
    eval_split_dir: str = "data/out/splits/single_token_entropy/mmlu/qwen_3b",
    eval_groups: int = 6,
    per_device_train_batch_size: int = 1,
    num_train_epochs: int = 20,
    cot_eval_max_new_tokens: int = 8192,
    cot_eval_max_batch_size: int = 64,
    run_tag: str = "",
):
    """Run the full Branch B pipeline: LoRA SFT, then per-checkpoint CoT eval.

    Args:
        prompt_id: Which prepared prompt variant to train on.
        eval_split_dir: Repo-relative dir holding groupN_test.parquet splits.
        eval_groups: Number of eval split groups to evaluate against.
        per_device_train_batch_size: Trainer micro-batch size per device.
        num_train_epochs: Total SFT epochs; a checkpoint is always saved at
            the final epoch in addition to the fixed milestones.
        cot_eval_max_new_tokens: Generation budget for the CoT evaluation.
        cot_eval_max_batch_size: Max generation batch size during evaluation.
        run_tag: Optional suffix appended to the output directory name.

    Raises:
        FileNotFoundError: if the prepared training parquet is missing.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    tok.pad_token = tok.eos_token
    # Custom attributes read by the reasoning datasets to delimit CoT spans.
    tok.thinking_start_token = "<think>"
    tok.thinking_end_token = "</think>"

    prepared_parquet = (
        PROJECT_ROOT
        / f"data/out/distillation/mmlu_branch_b_cleaned_prompt{prompt_id}_prepared.parquet"
    )
    if not prepared_parquet.exists():
        raise FileNotFoundError(
            f"Preprocessed data not found: {prepared_parquet}. "
            f"Run prepare_cleaned_b_data.py first."
        )

    suffix = f"_{run_tag}" if run_tag else ""
    out_path = str(
        PROJECT_ROOT / f"artifacts/sft_distill/branch_b_cleaned_prompt{prompt_id}{suffix}"
    )
    # Checkpoint at fixed milestone epochs (capped by the run length) and
    # always at the final epoch.
    milestones = (1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20)
    save_schedule = sorted(
        {e for e in milestones if e <= num_train_epochs} | {num_train_epochs}
    )

    # --- Training ---
    logger.info(f"Training: prompt_id={prompt_id}, epochs={num_train_epochs}, out={out_path}")

    train_dataset = CausalDatasetAdapter(
        dataset=MMLUReasoningResponseDataset(
            tokenizer=tok,
            config=QADatasetConfig(
                path=str(prepared_parquet),
                dataset_id=f"distill_branch_b_prompt{prompt_id}",
            ),
        )
    )
    trainer = LoRATrainer(
        config=LoRATrainerConfig(
            out_path=out_path,
            model_id=MODEL_NAME,
            train_dataset=train_dataset,
            training_args=LoRATrainingArgs(
                num_train_epochs=num_train_epochs,
                per_device_train_batch_size=per_device_train_batch_size,
                warmup_ratio=0.06,
                torch_compile=False,
            ),
            lora_training_args=LoRASpecificTrainingArgs(
                r=16,
                alpha=32,
                lora_dropout=0.05,
                use_rslora=True,
            ),
            save_schedule=save_schedule,
        ),
        tokenizer=tok,
    )
    trainer.train()
    trainer.unload()

    # --- CoT Evaluation (post-training) ---
    logger.info("Starting post-training CoT evaluation...")

    eval_split_root = PROJECT_ROOT / eval_split_dir
    eval_adapters = []
    for group_idx in range(eval_groups):
        eval_adapters.append(
            QADatasetAdapter(
                dataset=MMLUCoTResponseDataset(
                    tokenizer=tok,
                    config=QADatasetConfig(
                        path=str(eval_split_root / f"group{group_idx}_test.parquet"),
                        dataset_id=f"mmlu_cot_response_group{group_idx}_test",
                    ),
                )
            )
        )

    cot_evaluator = MultiCheckpointEvaluator(
        config=MultiCheckpointEvaluatorConfig(
            checkpoints_dir=out_path,
            eval_dataset=eval_adapters,
            base_model_id=MODEL_NAME,
            generation=GenerationConfig(
                max_new_tokens=cot_eval_max_new_tokens,
                max_batch_size=cot_eval_max_batch_size,
                attn_implementation="sdpa",
            ),
            summary_filename="summary_cot.json",
        ),
        tokenizer=tok,
    )
    cot_evaluator.evaluate_all()
2 changes: 2 additions & 0 deletions src/core/training/lora_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,7 @@ def model(self):
use_rslora=self.config.lora_training_args.use_rslora,
)
self._model = get_peft_model(model, peft_config)
if self.config.training_args.gradient_checkpointing:
self._model.enable_input_require_grads()

return self._model
90 changes: 90 additions & 0 deletions src/experiments/distill/train_branches/prepare_cleaned_b_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
Preprocess distillation Branch B data to flat MMLU format.

Reads raw distillation parquet (nested input/output schema),
filters out eval questions and answer-leaked rows,
converts to flat MMLU schema compatible with MMLUReasoningResponseDataset.

Usage:
uv run python src/experiments/distill/train_branches/prepare_cleaned_b_data.py
"""

import re
from pathlib import Path

import pandas as pd
import pyarrow.parquet as pq

PROJECT_ROOT = Path(__file__).resolve().parents[4]

ANSWER_LEAK_RE = re.compile(
"|".join([
r"\bcorrect answer\b", r"\bthe answer is\b", r"\banswer is\b",
r"\banswer:\b", r"\bcorrect option\b", r"\bcorrect choice\b",
r"\b[a-j]\s+is\s+correct\b", r"\[\[\s*[a-jA-J]\s*\]\]",
]),
flags=re.IGNORECASE,
)


def collect_eval_question_ids(eval_split_dir: str, groups: int) -> set[str]:
    """Gather every question_id appearing in the held-out eval split files."""
    split_root = PROJECT_ROOT / eval_split_dir
    ids: set[str] = set()
    for group_idx in range(groups):
        parquet_path = split_root / f"group{group_idx}_test.parquet"
        table = pq.read_table(parquet_path, columns=["question_id"])
        for record in table.to_pylist():
            ids.add(str(record["question_id"]))
    return ids


def main():
    """Filter raw Branch B distillation parquets into flat MMLU-format files.

    For each prompt id, drops rows whose question appears in the eval splits,
    rows with an empty chain-of-thought, and rows whose chain-of-thought leaks
    the answer, then writes survivors to *_prepared.parquet.
    """
    eval_split_dir = "data/out/splits/single_token_entropy/mmlu/qwen_3b"
    eval_groups = 6

    eval_ids = collect_eval_question_ids(eval_split_dir, eval_groups)
    print(f"Eval question IDs to exclude: {len(eval_ids)}")

    for prompt_id in [1, 2, 3]:
        raw_path = PROJECT_ROOT / f"data/out/distillation/mmlu_synth_gptoss_b_t0_8_cleaned_32b_prompt{prompt_id}.parquet"
        if not raw_path.exists():
            print(f"Skipping prompt {prompt_id}: not found")
            continue

        df = pd.read_parquet(raw_path)
        total = len(df)

        rows = []
        for _, row in df.iterrows():
            inp = row["input"]
            out = row["output"]
            qid = str(inp["question_id"])

            # Never train on questions reserved for evaluation.
            if qid in eval_ids:
                continue

            # Drop empty CoTs and CoTs that reveal the gold answer.
            thinking = str(out.get("thinking") or "").strip()
            if not thinking or ANSWER_LEAK_RE.search(thinking):
                continue

            opts_dict = inp["options"]
            # Sorted by key so option order is deterministic.
            # NOTE(review): lexicographic sort — assumes keys are single
            # letters ("A".."J"); confirm numeric keys never occur.
            opts_list = [opts_dict[k] for k in sorted(opts_dict.keys())]

            rows.append({
                "question": inp["question"],
                # Stringified list; presumably the downstream dataset parses
                # this back — verify against MMLUReasoningResponseDataset.
                "options": str(opts_list),
                "answer": inp["gold"],
                "thinking": thinking,
                "base_cluster": inp.get("subject", ""),
                "question_id": inp["question_id"],
            })

        out_df = pd.DataFrame(rows)
        out_path = PROJECT_ROOT / f"data/out/distillation/mmlu_branch_b_cleaned_prompt{prompt_id}_prepared.parquet"
        out_df.to_parquet(out_path, index=False)

        filtered = total - len(out_df)
        # Guard against an empty input parquet (total == 0) which previously
        # raised ZeroDivisionError in the summary line.
        pct = (filtered / total * 100) if total else 0.0
        print(f"Prompt {prompt_id}: {total} -> {len(out_df)} rows ({filtered} filtered, {pct:.1f}%)")


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Full 20-epoch SFT run for Branch B prompt 1.

Usage:
CUDA_VISIBLE_DEVICES=0,2 uv run torchrun --nproc_per_node=2 src/experiments/distill/train_branches/train_cleaned_b_full20_prompt1.py
"""

from core.training.branch_b_training import run_branch_b_training


def main():
    """Kick off the full 20-epoch Branch B SFT run for prompt 1."""
    run_config = dict(
        prompt_id=1,
        eval_split_dir="data/out/splits/single_token_entropy/mmlu/qwen_3b",
        eval_groups=6,
        per_device_train_batch_size=2,
        num_train_epochs=20,
        run_tag="full20_eval6_v2",
    )
    run_branch_b_training(**run_config)


if __name__ == "__main__":
    main()