diff --git a/src/core/prompts/mmlu_single_token_answer.py b/src/core/prompts/mmlu_single_token_answer.py
index 7ad0608..a0f1c5d 100644
--- a/src/core/prompts/mmlu_single_token_answer.py
+++ b/src/core/prompts/mmlu_single_token_answer.py
@@ -26,6 +26,7 @@ def single_token_sys_prompt_with_thinking(subject: str | None = None):
     return sys_msg
 
+
 def single_token_sys_prompt_with_fallback_for_unknown_answers(subject: str | None = None):
     if subject is not None:
         sys_msg = f"The following are multiple choice questions about {subject}."
 
diff --git a/src/core/training/branch_b_training.py b/src/core/training/branch_b_training.py
new file mode 100644
index 0000000..5d1ba05
--- /dev/null
+++ b/src/core/training/branch_b_training.py
@@ -0,0 +1,129 @@
+"""
+Shared orchestration for Branch B reasoning training and evaluation.
+
+1. Train with LoRATrainer using preprocessed data + MMLUReasoningResponseDataset
+2. Evaluate all checkpoints post-training with MultiCheckpointEvaluator
+"""
+
+from pathlib import Path
+
+from transformers import AutoTokenizer
+
+from core.datasets.causal_dataset_adapter import CausalDatasetAdapter
+from core.datasets.mmlu.mmlu_cot_response_dataset import MMLUCoTResponseDataset
+from core.datasets.mmlu.mmlu_reasoning_response_dataset import MMLUReasoningResponseDataset
+from core.datasets.qa_dataset import QADatasetConfig
+from core.datasets.qa_dataset_adapter import QADatasetAdapter
+from core.evaluation.multi_checkpoint_evaluator import (
+    GenerationConfig,
+    MultiCheckpointEvaluator,
+    MultiCheckpointEvaluatorConfig,
+)
+from core.training.lora_trainer import LoRATrainer, LoRATrainerConfig, LoRATrainingArgs, LoRASpecificTrainingArgs
+from core.utils.logger import logger
+
+MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
+PROJECT_ROOT = Path(__file__).resolve().parents[4]
+
+
+def run_branch_b_training(
+    *,
+    prompt_id: int = 1,
+    eval_split_dir: str = "data/out/splits/single_token_entropy/mmlu/qwen_3b",
+    eval_groups: int = 6,
+    per_device_train_batch_size: int = 1,
+    num_train_epochs: int = 20,
+    cot_eval_max_new_tokens: int = 8192,
+    cot_eval_max_batch_size: int = 64,
+    run_tag: str = "",
+):
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.thinking_start_token = "<think>"
+    tokenizer.thinking_end_token = "</think>"
+
+    train_data_path = (
+        PROJECT_ROOT
+        / f"data/out/distillation/mmlu_branch_b_cleaned_prompt{prompt_id}_prepared.parquet"
+    )
+    if not train_data_path.exists():
+        raise FileNotFoundError(
+            f"Preprocessed data not found: {train_data_path}. "
+            f"Run prepare_cleaned_b_data.py first."
+        )
+
+    run_suffix = f"_{run_tag}" if run_tag else ""
+    out_path = str(
+        PROJECT_ROOT / f"artifacts/sft_distill/branch_b_cleaned_prompt{prompt_id}{run_suffix}"
+    )
+    save_schedule = sorted(
+        set(e for e in [1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20] if e <= num_train_epochs)
+        | {num_train_epochs}
+    )
+
+    # --- Training ---
+    logger.info(f"Training: prompt_id={prompt_id}, epochs={num_train_epochs}, out={out_path}")
+
+    trainer = LoRATrainer(
+        config=LoRATrainerConfig(
+            out_path=out_path,
+            model_id=MODEL_NAME,
+            train_dataset=CausalDatasetAdapter(
+                dataset=MMLUReasoningResponseDataset(
+                    tokenizer=tokenizer,
+                    config=QADatasetConfig(
+                        path=str(train_data_path),
+                        dataset_id=f"distill_branch_b_prompt{prompt_id}",
+                    ),
+                )
+            ),
+            training_args=LoRATrainingArgs(
+                num_train_epochs=num_train_epochs,
+                per_device_train_batch_size=per_device_train_batch_size,
+                warmup_ratio=0.06,
+                torch_compile=False,
+            ),
+            lora_training_args=LoRASpecificTrainingArgs(
+                r=16,
+                alpha=32,
+                lora_dropout=0.05,
+                use_rslora=True,
+            ),
+            save_schedule=save_schedule,
+        ),
+        tokenizer=tokenizer,
+    )
+    trainer.train()
+    trainer.unload()
+
+    # --- CoT Evaluation (post-training) ---
+    logger.info("Starting post-training CoT evaluation...")
+
+    eval_split_root = PROJECT_ROOT / eval_split_dir
+
+    cot_evaluator = MultiCheckpointEvaluator(
+        config=MultiCheckpointEvaluatorConfig(
+            checkpoints_dir=out_path,
+            eval_dataset=[
+                QADatasetAdapter(
+                    dataset=MMLUCoTResponseDataset(
+                        tokenizer=tokenizer,
+                        config=QADatasetConfig(
+                            path=str(eval_split_root / f"group{j}_test.parquet"),
+                            dataset_id=f"mmlu_cot_response_group{j}_test",
+                        ),
+                    )
+                )
+                for j in range(eval_groups)
+            ],
+            base_model_id=MODEL_NAME,
+            generation=GenerationConfig(
+                max_new_tokens=cot_eval_max_new_tokens,
+                max_batch_size=cot_eval_max_batch_size,
+                attn_implementation="sdpa",
+            ),
+            summary_filename="summary_cot.json",
+        ),
+        tokenizer=tokenizer,
+    )
+    cot_evaluator.evaluate_all()
diff --git a/src/core/training/lora_trainer.py b/src/core/training/lora_trainer.py
index 8169ff6..541c1ef 100644
--- a/src/core/training/lora_trainer.py
+++ b/src/core/training/lora_trainer.py
@@ -59,5 +59,7 @@ def model(self):
             use_rslora=self.config.lora_training_args.use_rslora,
         )
         self._model = get_peft_model(model, peft_config)
+        if self.config.training_args.gradient_checkpointing:
+            self._model.enable_input_require_grads()
         return self._model
 
diff --git a/src/experiments/distill/train_branches/prepare_cleaned_b_data.py b/src/experiments/distill/train_branches/prepare_cleaned_b_data.py
new file mode 100644
index 0000000..cce3d2e
--- /dev/null
+++ b/src/experiments/distill/train_branches/prepare_cleaned_b_data.py
@@ -0,0 +1,90 @@
+"""
+Preprocess distillation Branch B data to flat MMLU format.
+
+Reads raw distillation parquet (nested input/output schema),
+filters out eval questions and answer-leaked rows,
+converts to flat MMLU schema compatible with MMLUReasoningResponseDataset.
+
+Usage:
+    uv run python src/experiments/distill/train_branches/prepare_cleaned_b_data.py
+"""
+
+import re
+from pathlib import Path
+
+import pandas as pd
+import pyarrow.parquet as pq
+
+PROJECT_ROOT = Path(__file__).resolve().parents[4]
+
+ANSWER_LEAK_RE = re.compile(
+    "|".join([
+        r"\bcorrect answer\b", r"\bthe answer is\b", r"\banswer is\b",
+        r"\banswer:", r"\bcorrect option\b", r"\bcorrect choice\b",
+        r"\b[a-j]\s+is\s+correct\b", r"\[\[\s*[a-jA-J]\s*\]\]",
+    ]),
+    flags=re.IGNORECASE,
+)
+
+
+def collect_eval_question_ids(eval_split_dir: str, groups: int) -> set[str]:
+    split_root = PROJECT_ROOT / eval_split_dir
+    question_ids: set[str] = set()
+    for g in range(groups):
+        path = split_root / f"group{g}_test.parquet"
+        rows = pq.read_table(path, columns=["question_id"]).to_pylist()
+        question_ids.update(str(r["question_id"]) for r in rows)
+    return question_ids
+
+
+def main():
+    eval_split_dir = "data/out/splits/single_token_entropy/mmlu/qwen_3b"
+    eval_groups = 6
+
+    eval_ids = collect_eval_question_ids(eval_split_dir, eval_groups)
+    print(f"Eval question IDs to exclude: {len(eval_ids)}")
+
+    for prompt_id in [1, 2, 3]:
+        raw_path = PROJECT_ROOT / f"data/out/distillation/mmlu_synth_gptoss_b_t0_8_cleaned_32b_prompt{prompt_id}.parquet"
+        if not raw_path.exists():
+            print(f"Skipping prompt {prompt_id}: not found")
+            continue
+
+        df = pd.read_parquet(raw_path)
+        total = len(df)
+
+        rows = []
+        for _, row in df.iterrows():
+            inp = row["input"]
+            out = row["output"]
+            qid = str(inp["question_id"])
+
+            if qid in eval_ids:
+                continue
+
+            thinking = str(out.get("thinking") or "").strip()
+            if not thinking or ANSWER_LEAK_RE.search(thinking):
+                continue
+
+            opts_dict = inp["options"]
+            opts_list = [opts_dict[k] for k in sorted(opts_dict.keys())]
+
+            rows.append({
+                "question": inp["question"],
+                "options": str(opts_list),
+                "answer": inp["gold"],
+                "thinking": thinking,
+                "base_cluster": inp.get("subject", ""),
+                "question_id": inp["question_id"],
+            })
+
+        out_df = pd.DataFrame(rows)
+        out_path = PROJECT_ROOT / f"data/out/distillation/mmlu_branch_b_cleaned_prompt{prompt_id}_prepared.parquet"
+        out_df.to_parquet(out_path, index=False)
+
+        filtered = total - len(out_df)
+        print(f"Prompt {prompt_id}: {total} -> {len(out_df)} rows ({filtered} filtered, {filtered/total*100:.1f}%)")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/experiments/distill/train_branches/train_cleaned_b_full20_prompt1.py b/src/experiments/distill/train_branches/train_cleaned_b_full20_prompt1.py
new file mode 100644
index 0000000..5c0774b
--- /dev/null
+++ b/src/experiments/distill/train_branches/train_cleaned_b_full20_prompt1.py
@@ -0,0 +1,23 @@
+"""
+Full 20-epoch SFT run for Branch B prompt 1.
+
+Usage:
+    CUDA_VISIBLE_DEVICES=0,2 uv run torchrun --nproc_per_node=2 src/experiments/distill/train_branches/train_cleaned_b_full20_prompt1.py
+"""
+
+from core.training.branch_b_training import run_branch_b_training
+
+
+def main():
+    run_branch_b_training(
+        prompt_id=1,
+        eval_split_dir="data/out/splits/single_token_entropy/mmlu/qwen_3b",
+        eval_groups=6,
+        per_device_train_batch_size=2,
+        num_train_epochs=20,
+        run_tag="full20_eval6_v2",
+    )
+
+
+if __name__ == "__main__":
+    main()
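
A few notes on the logic above, with small illustrative sketches; none of the sketch code below is part of the diff itself.

First, the save_schedule expression in branch_b_training.py keeps every milestone epoch that fits within num_train_epochs and always unions in the final epoch, so shorter runs still checkpoint at the end. A minimal standalone sketch of the same arithmetic (the helper name is invented for illustration):

    # Illustrative helper, not in the diff: mirrors the save_schedule
    # expression in run_branch_b_training.
    def make_save_schedule(num_train_epochs: int) -> list[int]:
        milestones = [1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20]
        return sorted({e for e in milestones if e <= num_train_epochs} | {num_train_epochs})

    assert make_save_schedule(20) == [1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20]
    # The final epoch is saved even when it is not a milestone:
    assert make_save_schedule(7) == [1, 2, 3, 4, 5, 6, 7]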
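
Second, the two-line addition in lora_trainer.py matters because LoRA freezes the base model's weights: the embedding output then does not require grad, and with gradient checkpointing enabled the recomputed segments have nothing to backpropagate through. PreTrainedModel.enable_input_require_grads is the stock transformers remedy; a simplified sketch of its effect (not the library source):

    # Simplified sketch of what enable_input_require_grads does: register a
    # forward hook that forces embedding outputs to require grad, so
    # checkpointed segments still take part in backprop under frozen weights.
    def enable_input_require_grads_sketch(model):
        def make_inputs_require_grad(module, args, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)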
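
Finally, a quick self-contained check of the answer-leak filter in prepare_cleaned_b_data.py, using the same ANSWER_LEAK_RE as above; the sample thinking strings are invented:

    import re

    ANSWER_LEAK_RE = re.compile(
        "|".join([
            r"\bcorrect answer\b", r"\bthe answer is\b", r"\banswer is\b",
            r"\banswer:", r"\bcorrect option\b", r"\bcorrect choice\b",
            r"\b[a-j]\s+is\s+correct\b", r"\[\[\s*[a-jA-J]\s*\]\]",
        ]),
        flags=re.IGNORECASE,
    )

    # Dropped: the chain of thought states the answer outright.
    assert ANSWER_LEAK_RE.search("So the correct answer is B.")
    assert ANSWER_LEAK_RE.search("Therefore b is correct.")
    assert ANSWER_LEAK_RE.search("Answer: B")
    assert ANSWER_LEAK_RE.search("Final verdict: [[ C ]]")
    # Kept: the reasoning never names the answer.
    assert not ANSWER_LEAK_RE.search("Comparing the options, the second definition fits best.")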