2 changes: 1 addition & 1 deletion src/core/datasets/abstract_dataset_adapter.py
@@ -31,7 +31,7 @@ def _load_ds(self, path_override: str | None = None) -> Dataset:

ds = load_dataset(
"parquet",
data_files={"default": self.dataset.config.path},
data_files={"default": self.dataset.processed_path},
)
return ds["default"]

14 changes: 13 additions & 1 deletion src/core/datasets/base_dataset.py
@@ -3,16 +3,19 @@
from pydraconf import PydraConfig
from transformers import PreTrainedTokenizer

from core.datasets.base_dataset_aggregator import BaseDatasetAggregator


class BaseDatasetConfig(PydraConfig):
path: str
path: str | BaseDatasetAggregator
dataset_id: str


class BaseDataset[C: BaseDatasetConfig](ABC):
def __init__(self, tokenizer: PreTrainedTokenizer, config: C):
self.tokenizer = tokenizer
self.config = config
self._path = config.path if isinstance(config.path, str) else None

@abstractmethod
def system_prompt(self, row: dict) -> str: ...
@@ -26,3 +29,12 @@ def row_id(self, row: dict) -> str: ...
@property
def dataset_id(self) -> str:
return self.config.dataset_id

@property
def processed_path(self) -> str:
if self._path is None:
assert isinstance(self.config.path, BaseDatasetAggregator)
self._path = self.config.path.aggregate()

assert self._path is not None
return self._path
25 changes: 25 additions & 0 deletions src/core/datasets/base_dataset_aggregator.py
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod

import pandas as pd
from pydraconf import PydraConfig

from core.utils.logger import logger


class BaseDatasetAggregatorConfig(PydraConfig):
in_paths: list[str]
out_path: str


class BaseDatasetAggregator(ABC):
def __init__(self, config: BaseDatasetAggregatorConfig):
self.config = config

@abstractmethod
def _merge(self, dfs: list[pd.DataFrame]) -> pd.DataFrame: ...

def aggregate(self) -> str:
logger.info(f"Aggregating datasets from {self.config.in_paths} into {self.config.out_path}...")
dfs = [pd.read_parquet(path) for path in self.config.in_paths]
merged_df = self._merge(dfs)
merged_df.to_parquet(self.config.out_path)
return self.config.out_path
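
A minimal usage sketch of the new aggregator hook. ConcatAggregator and the parquet paths below are hypothetical illustrations, not part of this PR; the sketch assumes aggregate() returns the merged output path, as BaseDataset.processed_path expects, and constructs the config by keyword the same way the experiment script further down constructs its configs.

from typing import override

import pandas as pd

from core.datasets.base_dataset_aggregator import BaseDatasetAggregator, BaseDatasetAggregatorConfig


class ConcatAggregator(BaseDatasetAggregator):
    # Hypothetical merge strategy: stack all input frames row-wise.
    @override
    def _merge(self, dfs: list[pd.DataFrame]) -> pd.DataFrame:
        return pd.concat(dfs, ignore_index=True)


aggregator = ConcatAggregator(
    config=BaseDatasetAggregatorConfig(
        in_paths=["data/part_a.parquet", "data/part_b.parquet"],  # hypothetical inputs
        out_path="data/merged.parquet",  # hypothetical output location
    )
)

# With the base_dataset.py change above, a BaseDatasetConfig can carry this
# aggregator in its `path` field instead of a plain string; the first access to
# BaseDataset.processed_path then runs aggregate(), writes the merged parquet to
# out_path, and caches the returned path for subsequent loads.
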
14 changes: 9 additions & 5 deletions src/core/training/base_trainer.py
@@ -5,7 +5,6 @@
from typing import Any

import torch

from pydantic import BaseModel
from pydraconf import PydraConfig
from transformers import (
@@ -15,6 +14,7 @@
PreTrainedTokenizer,
Seq2SeqTrainingArguments,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.trainer_seq2seq import Seq2SeqTrainer

from core.datasets.abstract_dataset_adapter import AbstractDatasetAdapter
@@ -59,7 +59,7 @@ class BaseTrainer[TConfig: BaseTrainerConfig[Any] = BaseTrainerConfig]:
def __init__(self, config: TConfig, tokenizer: PreTrainedTokenizer | None = None):
self.config = config
self._tokenizer: PreTrainedTokenizer | None = tokenizer
self._model: AutoModelForCausalLM | None = None
self._model: PreTrainedModel | None = None

def train(self):
if not self._directory_is_empty(self.config.out_path, self.config.training_args.num_train_epochs):
@@ -71,7 +71,8 @@ def train(self):
logger.info(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)

train_ds = self._prepare_data()
self._run_training(train_ds)
trainer = self._build_trainer(train_ds)
self._run_training(trainer)

return get_last_checkpoint_dir(self.config.out_path)

@@ -89,7 +90,7 @@ def tokenizer(self):
return self._tokenizer

@property
def model(self):
def model(self) -> PreTrainedModel:
if not self._model:
self._model = AutoModelForCausalLM.from_pretrained(self.config.model_id)

@@ -125,7 +126,7 @@ def _prepare_data(self):

return train_ds

def _run_training(self, train_ds):
def _build_trainer(self, train_ds):
trainer = Seq2SeqTrainer(
model=self.model,
args=self.training_args,
@@ -137,6 +138,9 @@ def _run_training(self, train_ds):
if self.config.save_schedule is not None:
trainer.add_callback(SaveByScheduleCallback(schedule=self.config.save_schedule))

return trainer

def _run_training(self, trainer):
has_checkpoint = get_last_checkpoint_dir(self.config.out_path) is not None
logger.info(f"Has checkpoint: {has_checkpoint}")
trainer.train(resume_from_checkpoint=has_checkpoint)
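
The split of _run_training into _build_trainer plus _run_training gives subclasses a place to attach extra callbacks before training starts. A minimal sketch of that extension point, assuming the class layout in this diff; EpochLoggerCallback and LoggingTrainer are hypothetical names used only for illustration, and ResamplingTrainer below uses the same hook to schedule complexity estimation.

from typing import override

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

from core.training.base_trainer import BaseTrainer
from core.utils.logger import logger


class EpochLoggerCallback(TrainerCallback):
    # Hypothetical callback: log the epoch index whenever an epoch begins.
    @override
    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        logger.info(f"Starting epoch {state.epoch}")


class LoggingTrainer(BaseTrainer):
    # Reuse the parent's trainer construction, then attach the extra callback.
    @override
    def _build_trainer(self, train_ds):
        trainer = super()._build_trainer(train_ds)
        trainer.add_callback(EpochLoggerCallback())
        return trainer
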
6 changes: 4 additions & 2 deletions src/core/training/lora_trainer.py
@@ -3,6 +3,7 @@
from peft import LoraConfig, TaskType, get_peft_model
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import PreTrainedModel

from core.training.base_trainer import BaseTrainer, BaseTrainerConfig, BaseTrainingArgs

@@ -42,10 +43,10 @@ class LoRATrainerConfig(BaseTrainerConfig[LoRATrainingArgs]):
lora_training_args: LoRASpecificTrainingArgs = Field(default_factory=LoRASpecificTrainingArgs)


class LoRATrainer(BaseTrainer[LoRATrainerConfig]):
class LoRATrainer[TConfig: LoRATrainerConfig = LoRATrainerConfig](BaseTrainer[TConfig]):
@property
@override
def model(self):
def model(self) -> PreTrainedModel:
if not self._model:
model = AutoModelForCausalLM.from_pretrained(self.config.model_id)
peft_config = LoraConfig(
@@ -60,4 +61,5 @@ def model(self):
)
self._model = get_peft_model(model, peft_config)

assert self._model is not None
return self._model
73 changes: 73 additions & 0 deletions src/core/training/resampling_trainer.py
@@ -0,0 +1,73 @@
from pathlib import Path
from typing import override

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.modeling_utils import PreTrainedModel

from core.complexity_estimation.complexity_estimation_runner import (
BaseComplexityEstimator,
ComplexityEstimationRunner,
ComplexityEstimationRunnerConfig,
ModelGenerateConfig,
QADatasetAdapter,
)
from core.training.lora_trainer import LoRATrainer, LoRATrainerConfig
from core.utils.logger import logger


class EstimateComplexityCallback(TrainerCallback):
def __init__(
self,
complexity_evaluation_dataset: QADatasetAdapter,
complexity_estimator: BaseComplexityEstimator,
complexity_estimation_runner_generation_config: ModelGenerateConfig,
out_path: Path,
) -> None:
super().__init__()

self._complexity_evaluation_dataset = complexity_evaluation_dataset
self._complexity_estimator = complexity_estimator
self._complexity_estimation_runner_generation_config = complexity_estimation_runner_generation_config
self._out_path = out_path

@override
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
logger.info(f"Estimating complexity for epoch {state.epoch}...")

model: PreTrainedModel = kwargs["model"]

ComplexityEstimationRunner(
config=ComplexityEstimationRunnerConfig(
out_path=self._out_path.as_posix(),
answer_field_name="estimation_phase_answer",
answer_correctness_field_name="estimation_phase_answer_correctness",
generate_config=self._complexity_estimation_runner_generation_config,
),
complexity_estimator=self._complexity_estimator,
).estimate(dataset_adapter=self._complexity_evaluation_dataset, model=model)


class ResamplingTrainerConfig(LoRATrainerConfig):
complexity_evaluation_dataset: QADatasetAdapter
complexity_estimator: BaseComplexityEstimator
complexity_estimation_runner_generation_config: ModelGenerateConfig


class ResamplingTrainer(LoRATrainer[ResamplingTrainerConfig]):
@override
def _prepare_data(self): ...

@override
def _build_trainer(self, train_ds):
trainer = super()._build_trainer(train_ds)

trainer.add_callback(
EstimateComplexityCallback(
complexity_evaluation_dataset=self.config.complexity_evaluation_dataset,
complexity_estimator=self.config.complexity_estimator,
complexity_estimation_runner_generation_config=self.config.complexity_estimation_runner_generation_config,
out_path=Path(__file__).parent.joinpath("TODO"),
)
)

return trainer
38 changes: 38 additions & 0 deletions src/experiments/train_pipeline/mmlu/qwen_3b/student_entropy.py
@@ -0,0 +1,38 @@
from pathlib import Path

from transformers import AutoTokenizer

from core.datasets.causal_dataset_adapter import CausalDatasetAdapter
from core.datasets.mmlu.mmlu_reasoning_response_dataset import MMLUReasoningResponseDataset
from core.datasets.qa_dataset import QADatasetConfig
from core.training.lora_trainer import LoRATrainingArgs
from core.training.resampling_trainer import ResamplingTrainer, ResamplingTrainerConfig

MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

trainer = ResamplingTrainer(
config=ResamplingTrainerConfig(
training_args=LoRATrainingArgs(
num_train_epochs=20,
per_device_train_batch_size=32,
),
out_path=Path(__file__)
.parent.joinpath("../../../../artifacts/train_pipeline/mmlu/qwen_3b/student_entropy")
.as_posix(),
model_id=MODEL_NAME,
train_dataset=CausalDatasetAdapter(
dataset=MMLUReasoningResponseDataset(
config=QADatasetConfig(
path=Path(__file__).parent.joinpath("../../../../../data/source/mmlu_pro_stem.parquet").as_posix(),
dataset_id="mmlu_qwen_3b_student_entropy",
),
tokenizer=tokenizer,
)
),
save_schedule=[1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20],
),
tokenizer=tokenizer,
)