
Add Branch B CoT training run #16

Open
SemyonEpanov wants to merge 10 commits into main from run-distill-branch-b

Conversation

@SemyonEpanov
Collaborator

No description provided.

Semyon Epanov added 6 commits March 10, 2026 19:08
# Conflicts:
#	src/core/training/base_trainer.py
#	src/experiments/sft_by_complexity_splits/mmlu/llama_3b.py
…istill-branch-b

# Conflicts:
#	src/core/datasets/distillation/distillation_branch_b_cot_dataset.py
#	src/core/datasets/mmlu/mmlu_cot_response_dataset.py
#	src/core/training/base_trainer.py
#	src/experiments/distill/train_branches/train_cleaned_b_full20_prompt1.py
#	src/experiments/distill/train_branches/train_cleaned_b_new.py
use_rslora=self.config.lora_training_args.use_rslora,
)
self._model = get_peft_model(model, peft_config)
if self.config.training_args.gradient_checkpointing:
Member

Why do we need it? Isn't it handled automatically by transformers?

Collaborator Author

Agree, will remove.

Collaborator Author

Removed use_cache=False (it is handled by the Trainer), but had to keep enable_input_require_grads(), because training crashes with a RuntimeError otherwise.
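For context, the failure mode described here can be reproduced with plain PyTorch: reentrant activation checkpointing detaches its inputs, so when every layer before the checkpoint is frozen (as with a LoRA base model) nothing in the recomputed graph requires grad and `backward()` raises a RuntimeError. A minimal sketch with toy stand-in modules (not the project's code); `enable_input_require_grads()` in transformers works by registering a forward hook that does the equivalent of the `requires_grad_(True)` call below:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Frozen base-model stand-in (as in LoRA, where base weights are frozen).
emb = torch.nn.Embedding(10, 4)
emb.weight.requires_grad_(False)
adapter = torch.nn.Linear(4, 4)  # trainable adapter stand-in

ids = torch.tensor([[1, 2, 3]])
hidden = emb(ids)  # requires_grad=False: the embedding is frozen

# Reentrant checkpointing detaches its inputs; with no grad-requiring
# input, the output carries no graph and backward() raises a RuntimeError.
out = checkpoint(adapter, hidden, use_reentrant=True)
try:
    out.sum().backward()
    crashed = False
except RuntimeError:
    crashed = True  # this is the crash the comment refers to

# What enable_input_require_grads() effectively does: mark the
# activations entering the checkpointed region as requiring grad.
hidden2 = emb(ids).requires_grad_(True)
out2 = checkpoint(adapter, hidden2, use_reentrant=True)
out2.sum().backward()  # the adapter now receives gradients
```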

eval_split_dir="data/out/splits/single_token_entropy/mmlu/qwen_3b",
eval_groups=6,
per_device_train_batch_size=1,
effective_train_batch_size=120,
Member

Why do you want to change effective batch size?

per_device_train_batch_size=1,
effective_train_batch_size=120,
num_train_epochs=20,
learning_rate=1e-4,
Member

Why change lr?

Collaborator Author

lr doesn't change (it is just set explicitly)

class LoRATrainingArgs(BaseTrainingArgs):
    # Sane overrides for LoRA SFT fine-tuning
    effective_train_batch_size: int = 64
    learning_rate: float = 1e-4
    warmup_ratio: float = 0.06
    weight_decay: float = 0.0

Member

In your code it is effective_train_batch_size=120, typo?

Collaborator Author

set effective_train_batch_size=64, per_device_train_batch_size=2.
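The relationship between these two settings is just gradient accumulation arithmetic: the trainer has to turn `effective_train_batch_size` into accumulation steps given the per-device batch size and device count. A hypothetical helper showing the derivation (how this repo's trainer actually computes it is an assumption):

```python
# Hypothetical helper: derive gradient accumulation steps from an
# effective (global) batch size. Not the project's actual code.
def grad_accum_steps(effective_bs: int, per_device_bs: int, num_devices: int = 1) -> int:
    world_bs = per_device_bs * num_devices  # samples per optimizer micro-step
    if effective_bs % world_bs != 0:
        raise ValueError(f"{effective_bs} is not divisible by {world_bs}")
    return effective_bs // world_bs

# The settings agreed on in this thread, on a single GPU:
print(grad_accum_steps(64, 2))  # → 32
```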

raise FileNotFoundError(f"Train parquet not found: {train_data_path}")

eval_question_ids = _collect_eval_question_ids(eval_split_dir, eval_groups)
train_row_filter = _build_train_row_filter(eval_question_ids)
Member

Instead of applying the filter dynamically, shall we preprocess the data and save it to disk? Just like with other MMLU data splits

Member

Then we could use just CausalDatasetAdapter

Collaborator Author

Ok, agree.
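The agreed change amounts to moving the filter from load time to a one-off preprocessing step: drop every training row whose question id appears in an eval split, persist the result, and let the plain dataset adapter read it. A minimal sketch of the filtering core (the `question_id` key and row shape are assumptions, not the project's actual schema):

```python
# Hypothetical sketch of the agreed preprocessing step: remove rows whose
# question id is held out for evaluation, so the saved training set can be
# loaded without any dynamic filtering.
def filter_train_rows(rows, eval_question_ids):
    """rows: iterable of dicts, each with a 'question_id' key (assumed)."""
    held_out = set(eval_question_ids)  # O(1) membership checks
    return [r for r in rows if r["question_id"] not in held_out]

rows = [
    {"question_id": 1, "text": "a"},
    {"question_id": 2, "text": "b"},
    {"question_id": 3, "text": "c"},
]
# Row 2 is held out for eval, so only rows 1 and 3 survive.
print(filter_train_rows(rows, eval_question_ids={2}))
```

In the real script the surviving rows would then be written back to parquet, mirroring the other MMLU data splits.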

@@ -0,0 +1,193 @@
"""
Member

What is the difference between src/experiments/distill/train_branches/train_cleaned_b_full20_prompt1.py and this script? Why do we need both?

Collaborator Author

train_cleaned_b_full20_prompt1.py is the entry point

Member

Oh, I see. Could you move the main script to core/... then? And keep the entry point in experiments

Collaborator Author

Ok

lambda row: self.process_row(row).model_dump(),
num_proc=4,
remove_columns=ds.column_names,
load_from_cache_file=False,
Member

Why is it needed?

Collaborator Author

Reverted; it was only used while debugging.

return (
f"Question: {question.strip()}\n\n"
f"Options:\n{opts}\n\n"
f"Answer with the option letter first, then provide reasoning inside {THINKING_START}...{THINKING_END} tags."
Member

We do not need to prompt the model to answer with reasoning tags, right? Reasoning models should use reasoning by default, meaning the prompt should come without the request to use them.

Collaborator Author

Fair point. The model will learn the answer-first + reasoning format from the training data itself.
Will simplify the prompt to a plain question format without thinking tag instructions.
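The simplified prompt described here could look like the following sketch: same question-plus-options layout as the snippet under review, minus the final instruction about thinking tags (the exact option formatting is an assumption, not the repo's code):

```python
# Hypothetical simplified prompt builder: plain question + options, with
# no instruction about THINKING_START/THINKING_END tags -- the answer-first
# + reasoning format is learned from the training data itself.
def build_prompt(question: str, options: list[str]) -> str:
    opts = "\n".join(f"{chr(ord('A') + i)}. {o}" for i, o in enumerate(options))
    return f"Question: {question.strip()}\n\nOptions:\n{opts}"

print(build_prompt("What is 2 + 2?", ["3", "4", "5", "22"]))
# → Question: What is 2 + 2?
#
#   Options:
#   A. 3
#   B. 4
#   C. 5
#   D. 22
```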

from core.prompts.thinking_markers import THINKING_START, THINKING_END


class DistillationBranchBCoTDataset(CausalDataset[CausalDatasetConfig]):
Member

Why do we need it? Could we use MMLUReasoningResponseDataset instead? Just pre-process the data to match the current format of MMLU datasets

Collaborator Author

Agreed, we'll use MMLUReasoningResponseDataset directly.

Semyon Epanov added 4 commits March 22, 2026 19:29
- Remove redundant gradient checkpointing code from LoRATrainer
- Revert load_from_cache_file=False from abstract base class
- Delete DistillationBranchBCoTDataset, use MMLUReasoningResponseDataset
- Remove single_token_sys_prompt_with_answer_first_thinking
- Add data preprocessing script (prepare_cleaned_b_data.py)
- Rewrite training orchestration (branch_b_training.py)
- Use default effective_batch_size=64, remove explicit lr=1e-4
- Delete FilteredCausalDatasetAdapter and train_cleaned_b_new.py