diff --git a/src/core/dataset_samplers/abstract_sampler.py b/src/core/dataset_samplers/abstract_sampler.py
new file mode 100644
index 0000000..0368b45
--- /dev/null
+++ b/src/core/dataset_samplers/abstract_sampler.py
@@ -0,0 +1,47 @@
+from abc import ABC, abstractmethod
+
+from datasets import Dataset
+from pydraconf import PydraConfig
+
+
+class AbstractDatasetSamplerConfig(PydraConfig):
+    """Configuration shared by every dataset sampler."""
+
+    # Number of highest-scoring rows kept by ``create_sample``.
+    top_k: int
+
+
+class AbstractDatasetSampler(ABC):
+    """Base class for samplers that keep the ``top_k`` best-scoring rows.
+
+    Subclasses implement ``_score_row``; ``create_sample`` does the ranking.
+    """
+
+    def __init__(self, config: AbstractDatasetSamplerConfig):
+        self.config = config
+
+    @abstractmethod
+    def _score_row(self, row: dict) -> float:
+        """Return the ranking score of one row (higher is kept first).
+
+        ``create_sample`` calls this through ``DataFrame.apply(axis=1)``, so
+        ``row`` arrives as a pandas ``Series`` that is indexed by column name.
+        """
+
+    def create_sample(self, ds: Dataset) -> Dataset:
+        """Return the ``config.top_k`` rows of ``ds`` with the highest score.
+
+        The result carries an extra ``score`` column and is ordered from the
+        highest score to the lowest.
+        """
+        df = ds.to_pandas()
+        df["score"] = df.apply(self._score_row, axis=1)
+        # A stable sort makes the order of tied rows deterministic.
+        df = df.sort_values("score", ascending=False, kind="stable")
+
+        sampled_df = df.head(self.config.top_k)
+
+        # After sorting, the index is no longer a RangeIndex; without
+        # preserve_index=False it would be stored as a stray
+        # ``__index_level_0__`` column in the returned dataset.
+        return Dataset.from_pandas(sampled_df, preserve_index=False)
diff --git a/src/core/dataset_samplers/entropy_gain_sampler.py b/src/core/dataset_samplers/entropy_gain_sampler.py
new file mode 100644
index 0000000..204f0cc
--- /dev/null
+++ b/src/core/dataset_samplers/entropy_gain_sampler.py
@@ -0,0 +1,21 @@
+from typing import override
+
+from core.dataset_samplers.abstract_sampler import AbstractDatasetSampler, AbstractDatasetSamplerConfig
+
+
+class EntropyGainSamplerConfig(AbstractDatasetSamplerConfig):
+    """Configuration for ``EntropyGainSampler``; adds no fields of its own."""
+
+
+class EntropyGainSampler(AbstractDatasetSampler):
+    """Rank rows by how far the student entropy exceeds the teacher entropy."""
+
+    @override
+    def _score_row(self, row: dict) -> float:
+        """Return ``student_entropy - teacher_entropy``, floored at zero."""
+        return max(row["student_entropy"] - row["teacher_entropy"], 0.0)
+
+
+# Backward-compatible alias for the original class name, which was
+# copy-pasted from the teacher-entropy sampler module.
+TeacherEntropySamplerConfig = EntropyGainSamplerConfig
diff --git a/src/core/dataset_samplers/entropy_ratio_sampler.py b/src/core/dataset_samplers/entropy_ratio_sampler.py
new file mode 100644
index 0000000..e4034a7
--- /dev/null
+++ b/src/core/dataset_samplers/entropy_ratio_sampler.py
@@ -0,0 +1,25 @@
+from typing import override
+
+from core.dataset_samplers.abstract_sampler import AbstractDatasetSampler, AbstractDatasetSamplerConfig
+
+
+class EntropyRatioSamplerConfig(AbstractDatasetSamplerConfig):
+    """Configuration for ``EntropyRatioSampler``; adds no fields of its own."""
+
+
+class EntropyRatioSampler(AbstractDatasetSampler):
+    """Rank rows by the ratio of student entropy to teacher entropy."""
+
+    # Keeps the denominator away from zero when the teacher entropy is 0.
+    _EPS = 1e-8
+
+    @override
+    def _score_row(self, row: dict) -> float:
+        """Return ``student_entropy / (teacher_entropy + _EPS)``."""
+        return row["student_entropy"] / (row["teacher_entropy"] + self._EPS)
+
+
+# Backward-compatible aliases for the original class names, which were
+# copy-pasted from the teacher-entropy and entropy-gain sampler modules.
+TeacherEntropySamplerConfig = EntropyRatioSamplerConfig
+EntropyGainSampler = EntropyRatioSampler
diff --git a/src/core/dataset_samplers/student_entropy_sampler.py b/src/core/dataset_samplers/student_entropy_sampler.py
new file mode 100644
index 0000000..22c2d40
--- /dev/null
+++ b/src/core/dataset_samplers/student_entropy_sampler.py
@@ -0,0 +1,16 @@
+from typing import override
+
+from core.dataset_samplers.abstract_sampler import AbstractDatasetSampler, AbstractDatasetSamplerConfig
+
+
+class StudentEntropySamplerConfig(AbstractDatasetSamplerConfig):
+    """Configuration for ``StudentEntropySampler``; adds no fields of its own."""
+
+
+class StudentEntropySampler(AbstractDatasetSampler):
+    """Rank rows by the student entropy alone."""
+
+    @override
+    def _score_row(self, row: dict) -> float:
+        """Return the row's ``student_entropy`` value as its score."""
+        return row["student_entropy"]
diff --git a/src/core/dataset_samplers/teacher_entropy_sampler.py b/src/core/dataset_samplers/teacher_entropy_sampler.py
new file mode 100644
index 0000000..f0646fe
--- /dev/null
+++ b/src/core/dataset_samplers/teacher_entropy_sampler.py
@@ -0,0 +1,16 @@
+from typing import override
+
+from core.dataset_samplers.abstract_sampler import AbstractDatasetSampler, AbstractDatasetSamplerConfig
+
+
+class TeacherEntropySamplerConfig(AbstractDatasetSamplerConfig):
+    """Configuration for ``TeacherEntropySampler``; adds no fields of its own."""
+
+
+class TeacherEntropySampler(AbstractDatasetSampler):
+    """Rank rows by the teacher entropy alone."""
+
+    @override
+    def _score_row(self, row: dict) -> float:
+        """Return the row's ``teacher_entropy`` value as its score."""
+        return row["teacher_entropy"]
diff --git a/src/core/datasets/abstract_dataset_adapter.py b/src/core/datasets/abstract_dataset_adapter.py
index 3bf81fb..de64298 100644
--- a/src/core/datasets/abstract_dataset_adapter.py
+++ b/src/core/datasets/abstract_dataset_adapter.py
@@ -3,6 +3,7 @@
 from datasets import Dataset, load_dataset, load_from_disk
 from pydantic import BaseModel
 
+from core.dataset_samplers.abstract_sampler import AbstractDatasetSampler
 from core.datasets.base_dataset import BaseDataset
 
 
@@ -16,8 +17,9 @@ class TokenizedRow(BaseModel):
 
 
 class AbstractDatasetAdapter[D: BaseDataset](ABC):
-    def __init__(self, dataset: D):
+    def __init__(self, dataset: D, dataset_sampler: AbstractDatasetSampler | None = None):
         self.dataset = dataset
+        self.dataset_sampler = dataset_sampler
 
     @abstractmethod
     def process_row(self, row: dict) -> TokenizedRow: ...
@@ -36,6 +38,10 @@ def _load_ds(self, path_override: str | None = None) -> Dataset:
 
     def process_dataset(self, path_override: str | None = None) -> Dataset:
         ds = self._load_ds(path_override)
+        # Optionally narrow the dataset to the sampler's top-k rows before the rows are processed.
+        if self.dataset_sampler is not None:
+            ds = self.dataset_sampler.create_sample(ds)
+
         processed_ds = ds.map(
             lambda row: self.process_row(row).model_dump(),
             num_proc=4,