From 183134918922894021305af98b0b3b1cc7d21f17 Mon Sep 17 00:00:00 2001
From: Andrey Goncharov
Date: Thu, 26 Mar 2026 22:09:07 +0100
Subject: [PATCH 1/3] [WIP] Resampling trainer

---
 src/core/training/resampling_trainer.py | 38 +++++++++++++++++++++++++
 1 file changed, 38 insertions(+)
 create mode 100644 src/core/training/resampling_trainer.py

diff --git a/src/core/training/resampling_trainer.py b/src/core/training/resampling_trainer.py
new file mode 100644
index 0000000..cd36897
--- /dev/null
+++ b/src/core/training/resampling_trainer.py
@@ -0,0 +1,38 @@
+from pydraconf import PydraConfig
+
+from core.complexity_estimation.complexity_estimation_runner import ComplexityEstimationRunner
+from core.training.lora_trainer import LoRATrainer, LoRATrainerConfig, LoRATrainingArgs
+from core.utils.logger import logger
+
+
+class ResamplingTrainerConfig(PydraConfig):
+    training_args: LoRATrainingArgs
+
+
+class ResamplingTrainer:
+    def __init__(self, config: ResamplingTrainerConfig, tokenizer):
+        self.config = config
+        self.tokenizer = tokenizer
+
+    def train(self):
+        for epoch in range(self.config.training_args.num_train_epochs):
+            logger.info(f"Epoch {epoch + 1}/{self.config.training_args.num_train_epochs}...")
+
+            # TODO: build tmp dataset with teacher entropy
+
+            logger.info("Estimating complexity...")
+            ComplexityEstimationRunner().estimate()
+
+            logger.info("Training...")
+            trainer = LoRATrainer(
+                config=LoRATrainerConfig(
+                    out_path=f"{self.config.out_path}/epoch_{epoch + 1}",
+                    model_id=self.config.model_id,
+                    train_dataset=self.config.train_dataset,
+                    training_args=self.config.training_args,
+                    lora_training_args=self.config.lora_training_args,
+                ),
+                tokenizer=self.tokenizer,
+            )
+            trainer.train()
+            trainer.unload()

From 8ad38e329b219eb8cfe1c6b766ec5bdc36a0bb88 Mon Sep 17 00:00:00 2001
From: Andrey Goncharov
Date: Fri, 27 Mar 2026 18:59:26 +0100
Subject: [PATCH 2/3] Add dataset aggregator

---
 src/core/datasets/abstract_dataset_adapter.py |  2 +-
 src/core/datasets/base_dataset.py             | 14 ++++-
 src/core/datasets/base_dataset_aggregator.py  | 25 ++++++++
 src/core/training/base_trainer.py             | 14 +++--
 src/core/training/lora_trainer.py             |  6 +-
 src/core/training/resampling_trainer.py       | 59 +++++++++----------
 .../mmlu/qwen_3b/student_entropy.py           | 38 ++++++++++++
 7 files changed, 119 insertions(+), 39 deletions(-)
 create mode 100644 src/core/datasets/base_dataset_aggregator.py
 create mode 100644 src/experiments/train_pipeline/mmlu/qwen_3b/student_entropy.py

diff --git a/src/core/datasets/abstract_dataset_adapter.py b/src/core/datasets/abstract_dataset_adapter.py
index de64298..f69670d 100644
--- a/src/core/datasets/abstract_dataset_adapter.py
+++ b/src/core/datasets/abstract_dataset_adapter.py
@@ -31,7 +31,7 @@ def _load_ds(self, path_override: str | None = None) -> Dataset:
 
         ds = load_dataset(
             "parquet",
-            data_files={"default": self.dataset.config.path},
+            data_files={"default": self.dataset.processed_path},
         )
 
         return ds["default"]
diff --git a/src/core/datasets/base_dataset.py b/src/core/datasets/base_dataset.py
index 1c3b4d0..e990569 100644
--- a/src/core/datasets/base_dataset.py
+++ b/src/core/datasets/base_dataset.py
@@ -3,9 +3,11 @@
 from pydraconf import PydraConfig
 from transformers import PreTrainedTokenizer
 
+from core.datasets.base_dataset_aggregator import BaseDatasetAggregator
+
 
 class BaseDatasetConfig(PydraConfig):
-    path: str
+    path: str | BaseDatasetAggregator
     dataset_id: str
 
 
@@ -13,6 +15,7 @@ class BaseDataset[C: BaseDatasetConfig](ABC):
     def __init__(self, tokenizer: PreTrainedTokenizer, config: C):
         self.tokenizer = tokenizer
         self.config = config
+        self._path = config.path if isinstance(config.path, str) else None
 
     @abstractmethod
     def system_prompt(self, row: dict) -> str: ...
@@ -26,3 +29,12 @@ def row_id(self, row: dict) -> str: ...
     @property
     def dataset_id(self) -> str:
         return self.config.dataset_id
+
+    @property
+    def processed_path(self) -> str:
+        if self._path is None:
+            assert isinstance(self.config.path, BaseDatasetAggregator)
+            self._path = self.config.path.aggregate()
+
+        assert self._path is not None
+        return self._path
diff --git a/src/core/datasets/base_dataset_aggregator.py b/src/core/datasets/base_dataset_aggregator.py
new file mode 100644
index 0000000..187ec16
--- /dev/null
+++ b/src/core/datasets/base_dataset_aggregator.py
@@ -0,0 +1,25 @@
+from abc import ABC, abstractmethod
+
+import pandas as pd
+from pydraconf import PydraConfig
+
+from core.utils.logger import logger
+
+
+class BaseDatasetAggregatorConfig(PydraConfig):
+    in_paths: list[str]
+    out_path: str
+
+
+class BaseDatasetAggregator(ABC):
+    def __init__(self, config: BaseDatasetAggregatorConfig):
+        self.config = config
+
+    @abstractmethod
+    def _merge(self, dfs: list[pd.DataFrame]) -> pd.DataFrame: ...
+
+    def aggregate(self) -> str:
+        logger.info(f"Aggregating datasets from {self.config.in_paths} into {self.config.out_path}...")
+        dfs = [pd.read_parquet(path) for path in self.config.in_paths]
+        self._merge(dfs).to_parquet(self.config.out_path)
+        return self.config.out_path
diff --git a/src/core/training/base_trainer.py b/src/core/training/base_trainer.py
index 79e6033..0cdd122 100644
--- a/src/core/training/base_trainer.py
+++ b/src/core/training/base_trainer.py
@@ -5,7 +5,6 @@
 from typing import Any
 
 import torch
-
 from pydantic import BaseModel
 from pydraconf import PydraConfig
 from transformers import (
@@ -15,6 +14,7 @@
     PreTrainedTokenizer,
     Seq2SeqTrainingArguments,
 )
+from transformers.modeling_utils import PreTrainedModel
 from transformers.trainer_seq2seq import Seq2SeqTrainer
 
 from core.datasets.abstract_dataset_adapter import AbstractDatasetAdapter
@@ -59,7 +59,7 @@ class BaseTrainer[TConfig: BaseTrainerConfig[Any] = BaseTrainerConfig]:
     def __init__(self, config: TConfig, tokenizer: PreTrainedTokenizer | None = None):
         self.config = config
         self._tokenizer: PreTrainedTokenizer | None = tokenizer
-        self._model: AutoModelForCausalLM | None = None
+        self._model: PreTrainedModel | None = None
 
     def train(self):
         if not self._directory_is_empty(self.config.out_path, self.config.training_args.num_train_epochs):
@@ -71,7 +71,8 @@ def train(self):
         logger.info(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
 
         train_ds = self._prepare_data()
-        self._run_training(train_ds)
+        trainer = self._build_trainer(train_ds)
+        self._run_training(trainer)
 
         return get_last_checkpoint_dir(self.config.out_path)
 
@@ -89,7 +90,7 @@ def tokenizer(self):
         return self._tokenizer
 
     @property
-    def model(self):
+    def model(self) -> PreTrainedModel:
         if not self._model:
             self._model = AutoModelForCausalLM.from_pretrained(self.config.model_id)
 
@@ -125,7 +126,7 @@ def _prepare_data(self):
 
         return train_ds
 
-    def _run_training(self, train_ds):
+    def _build_trainer(self, train_ds):
         trainer = Seq2SeqTrainer(
             model=self.model,
             args=self.training_args,
@@ -137,6 +138,9 @@ def _run_training(self, train_ds):
         if self.config.save_schedule is not None:
             trainer.add_callback(SaveByScheduleCallback(schedule=self.config.save_schedule))
 
+        return trainer
+
+    def _run_training(self, trainer):
         has_checkpoint = get_last_checkpoint_dir(self.config.out_path) is not None
         logger.info(f"Has checkpoint: {has_checkpoint}")
         trainer.train(resume_from_checkpoint=has_checkpoint)
diff --git a/src/core/training/lora_trainer.py b/src/core/training/lora_trainer.py
index 8169ff6..25d7846 100644
--- a/src/core/training/lora_trainer.py
+++ b/src/core/training/lora_trainer.py
@@ -3,6 +3,7 @@
 from peft import LoraConfig, TaskType, get_peft_model
 from pydantic import BaseModel, Field
 from transformers import AutoModelForCausalLM
+from transformers.modeling_utils import PreTrainedModel
 
 from core.training.base_trainer import BaseTrainer, BaseTrainerConfig, BaseTrainingArgs
 
@@ -42,10 +43,10 @@ class LoRATrainerConfig(BaseTrainerConfig[LoRATrainingArgs]):
     lora_training_args: LoRASpecificTrainingArgs = Field(default_factory=LoRASpecificTrainingArgs)
 
 
-class LoRATrainer(BaseTrainer[LoRATrainerConfig]):
+class LoRATrainer[TConfig: LoRATrainerConfig = LoRATrainerConfig](BaseTrainer[TConfig]):
     @property
     @override
-    def model(self):
+    def model(self) -> PreTrainedModel:
         if not self._model:
             model = AutoModelForCausalLM.from_pretrained(self.config.model_id)
             peft_config = LoraConfig(
@@ -60,4 +61,5 @@ def model(self):
             )
 
             self._model = get_peft_model(model, peft_config)
+        assert self._model is not None
         return self._model
diff --git a/src/core/training/resampling_trainer.py b/src/core/training/resampling_trainer.py
index cd36897..4861432 100644
--- a/src/core/training/resampling_trainer.py
+++ b/src/core/training/resampling_trainer.py
@@ -1,38 +1,37 @@
-from pydraconf import PydraConfig
-
-from core.complexity_estimation.complexity_estimation_runner import ComplexityEstimationRunner
-from core.training.lora_trainer import LoRATrainer, LoRATrainerConfig, LoRATrainingArgs
+from typing import override
+
+from core.complexity_estimation.complexity_estimation_runner import (
+    BaseComplexityEstimator,
+    ComplexityEstimationRunner,
+    ComplexityEstimationRunnerConfig,
+    ModelGenerateConfig,
+    QADatasetAdapter,
+)
+from core.training.lora_trainer import LoRATrainer, LoRATrainerConfig
 from core.utils.logger import logger
 
 
-class ResamplingTrainerConfig(PydraConfig):
-    training_args: LoRATrainingArgs
-
+class ResamplingTrainerConfig(LoRATrainerConfig):
+    complexity_evaluation_dataset: QADatasetAdapter
+    complexity_estimator: BaseComplexityEstimator
 
-class ResamplingTrainer:
-    def __init__(self, config: ResamplingTrainerConfig, tokenizer):
-        self.config = config
-        self.tokenizer = tokenizer
 
-    def train(self):
-        for epoch in range(self.config.training_args.num_train_epochs):
-            logger.info(f"Epoch {epoch + 1}/{self.config.training_args.num_train_epochs}...")
+class ResamplingTrainer(LoRATrainer[ResamplingTrainerConfig]):
+    def _estimate_complexity_for_epoch(self, epoch: int):
+        logger.info(f"Estimating complexity for epoch {epoch + 1}...")
 
-            # TODO: build tmp dataset with teacher entropy
+        ComplexityEstimationRunner(
+            config=ComplexityEstimationRunnerConfig(
+                out_path=self._path_for_epoch(epoch).as_posix(),
+                answer_field_name="estimation_phase_answer",
+                answer_correctness_field_name="estimation_phase_answer_correctness",
+                generate_config=ModelGenerateConfig(max_new_tokens=1),
+            ),
+            complexity_estimator=self.config.complexity_estimator,
+        ).estimate(dataset_adapter=self.config.complexity_evaluation_dataset, model=self._trainer.model)
 
-            logger.info("Estimating complexity...")
-            ComplexityEstimationRunner().estimate()
+    @override
+    def _prepare_data(self): ...
- logger.info("Training...") - trainer = LoRATrainer( - config=LoRATrainerConfig( - out_path=f"{self.config.out_path}/epoch_{epoch + 1}", - model_id=self.config.model_id, - train_dataset=self.config.train_dataset, - training_args=self.config.training_args, - lora_training_args=self.config.lora_training_args, - ), - tokenizer=self.tokenizer, - ) - trainer.train() - trainer.unload() + @override + def _build_trainer(self, train_ds): ... diff --git a/src/experiments/train_pipeline/mmlu/qwen_3b/student_entropy.py b/src/experiments/train_pipeline/mmlu/qwen_3b/student_entropy.py new file mode 100644 index 0000000..2366423 --- /dev/null +++ b/src/experiments/train_pipeline/mmlu/qwen_3b/student_entropy.py @@ -0,0 +1,38 @@ +from pathlib import Path + +from transformers import AutoTokenizer + +from core.datasets.causal_dataset_adapter import CausalDatasetAdapter +from core.datasets.mmlu.mmlu_reasoning_response_dataset import MMLUReasoningResponseDataset +from core.datasets.qa_dataset import QADatasetConfig +from core.training.lora_trainer import LoRATrainingArgs +from core.training.resampling_trainer import ResamplingTrainer, ResamplingTrainerConfig + +MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct" + +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token + +trainer = ResamplingTrainer( + config=ResamplingTrainerConfig( + training_args=LoRATrainingArgs( + num_train_epochs=20, + per_device_train_batch_size=32, + ), + out_path=Path(__file__) + .parent.joinpath("../../../../artifacts/train_pipeline/mmlu/qwen_3b/student_entropy") + .as_posix(), + model_id=MODEL_NAME, + train_dataset=CausalDatasetAdapter( + dataset=MMLUReasoningResponseDataset( + config=QADatasetConfig( + path=Path(__file__).parent.joinpath("../../../../../data/source/mmlu_pro_stem.parquet").as_posix(), + dataset_id="mmlu_qwen_3b_student_entropy", + ), + tokenizer=tokenizer, + ) + ), + save_schedule=[1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20], + ), + tokenizer=tokenizer, +) From 35b5bedd73b8d1ec42356c40c6f6044363c17d43 Mon Sep 17 00:00:00 2001 From: Andrey Goncharov Date: Sun, 29 Mar 2026 14:25:58 +0100 Subject: [PATCH 3/3] Draft up Resampling trainer --- src/core/training/resampling_trainer.py | 58 ++++++++++++++++++++----- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/src/core/training/resampling_trainer.py b/src/core/training/resampling_trainer.py index 4861432..c4e0c44 100644 --- a/src/core/training/resampling_trainer.py +++ b/src/core/training/resampling_trainer.py @@ -1,5 +1,9 @@ +from pathlib import Path from typing import override +from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments +from transformers.modeling_utils import PreTrainedModel + from core.complexity_estimation.complexity_estimation_runner import ( BaseComplexityEstimator, ComplexityEstimationRunner, @@ -11,27 +15,59 @@ from core.utils.logger import logger -class ResamplingTrainerConfig(LoRATrainerConfig): - complexity_evaluation_dataset: QADatasetAdapter - complexity_estimator: BaseComplexityEstimator +class EstimateComplexityCallback(TrainerCallback): + def __init__( + self, + complexity_evaluation_dataset: QADatasetAdapter, + complexity_estimator: BaseComplexityEstimator, + complexity_estimation_runner_generation_config: ModelGenerateConfig, + out_path: Path, + ) -> None: + super().__init__() + self._complexity_evaluation_dataset = complexity_evaluation_dataset + self._complexity_estimator = complexity_estimator + 
+        self._complexity_estimation_runner_generation_config = complexity_estimation_runner_generation_config
+        self._out_path = out_path
 
 
-class ResamplingTrainer(LoRATrainer[ResamplingTrainerConfig]):
-    def _estimate_complexity_for_epoch(self, epoch: int):
-        logger.info(f"Estimating complexity for epoch {epoch + 1}...")
+    @override
+    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
+        logger.info(f"Estimating complexity for epoch {state.epoch}...")
+
+        model: PreTrainedModel = kwargs["model"]
 
         ComplexityEstimationRunner(
             config=ComplexityEstimationRunnerConfig(
-                out_path=self._path_for_epoch(epoch).as_posix(),
+                out_path=self._out_path.as_posix(),
                 answer_field_name="estimation_phase_answer",
                 answer_correctness_field_name="estimation_phase_answer_correctness",
-                generate_config=ModelGenerateConfig(max_new_tokens=1),
+                generate_config=self._complexity_estimation_runner_generation_config,
             ),
-            complexity_estimator=self.config.complexity_estimator,
-        ).estimate(dataset_adapter=self.config.complexity_evaluation_dataset, model=self._trainer.model)
+            complexity_estimator=self._complexity_estimator,
+        ).estimate(dataset_adapter=self._complexity_evaluation_dataset, model=model)
+
 
+class ResamplingTrainerConfig(LoRATrainerConfig):
+    complexity_evaluation_dataset: QADatasetAdapter
+    complexity_estimator: BaseComplexityEstimator
+    complexity_estimation_runner_generation_config: ModelGenerateConfig
+
+
+class ResamplingTrainer(LoRATrainer[ResamplingTrainerConfig]):
     @override
     def _prepare_data(self): ...
 
     @override
-    def _build_trainer(self, train_ds): ...
+    def _build_trainer(self, train_ds):
+        trainer = super()._build_trainer(train_ds)
+
+        trainer.add_callback(
+            EstimateComplexityCallback(
+                complexity_evaluation_dataset=self.config.complexity_evaluation_dataset,
+                complexity_estimator=self.config.complexity_estimator,
+                complexity_estimation_runner_generation_config=self.config.complexity_estimation_runner_generation_config,
+                out_path=Path(__file__).parent.joinpath("TODO"),
+            )
+        )
+
+        return trainer
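
Note on PATCH 2/3 (commentary after the series, not part of the patches): the aggregator contract is that aggregate() merges the configured parquet shards, writes the result, and returns the merged file's path, which BaseDataset.processed_path then caches and hands to load_dataset(). A minimal standalone sketch of that contract, assuming only pandas with a parquet engine installed; ConcatAggregatorDemo and the shard/file names are hypothetical stand-ins, not code from the series:

from pathlib import Path

import pandas as pd


class ConcatAggregatorDemo:
    """Mirrors the BaseDatasetAggregator shape: _merge is the extension point."""

    def __init__(self, in_paths: list[str], out_path: str) -> None:
        self.in_paths = in_paths
        self.out_path = out_path

    def _merge(self, dfs: list[pd.DataFrame]) -> pd.DataFrame:
        # Simplest conceivable merge strategy: row-wise concatenation.
        return pd.concat(dfs, ignore_index=True)

    def aggregate(self) -> str:
        dfs = [pd.read_parquet(path) for path in self.in_paths]
        self._merge(dfs).to_parquet(self.out_path)
        # Returning the output path (rather than None) is what lets
        # processed_path cache it and feed it straight into load_dataset(...).
        return self.out_path


if __name__ == "__main__":
    work_dir = Path("aggregator_demo")
    work_dir.mkdir(exist_ok=True)
    # Two tiny shards stand in for per-epoch complexity dumps.
    pd.DataFrame({"question": ["q1"], "entropy": [0.2]}).to_parquet(work_dir / "shard_0.parquet")
    pd.DataFrame({"question": ["q2"], "entropy": [0.8]}).to_parquet(work_dir / "shard_1.parquet")
    aggregator = ConcatAggregatorDemo(
        in_paths=[str(work_dir / "shard_0.parquet"), str(work_dir / "shard_1.parquet")],
        out_path=str(work_dir / "merged.parquet"),
    )
    print(pd.read_parquet(aggregator.aggregate()))  # one file, two rows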
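
Note on PATCH 3/3: EstimateComplexityCallback leans on the Hugging Face Trainer behavior of forwarding the live model to every callback hook through **kwargs, so on_epoch_begin can run inference against the current (PEFT-wrapped) weights. A sketch of just that hand-off; InspectModelCallback is hypothetical, and its body stands in for the project's ComplexityEstimationRunner:

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class InspectModelCallback(TrainerCallback):
    def on_epoch_begin(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
    ) -> None:
        # The Trainer forwards the live model (plus tokenizer, optimizer, ...)
        # to every callback hook via **kwargs.
        model = kwargs["model"]
        # state.epoch is a float: 0.0 when the first epoch starts.
        n_params = sum(p.numel() for p in model.parameters())
        print(f"epoch {state.epoch}: training a model with {n_params} parameters")


# Registration mirrors ResamplingTrainer._build_trainer from the patch:
#     trainer = super()._build_trainer(train_ds)
#     trainer.add_callback(InspectModelCallback())

One caveat worth flagging in review: on_epoch_begin also fires before the first training step, so the first complexity estimate measures the not-yet-tuned adapter.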