Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/core/dataset_samplers/abstract_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from abc import ABC

from datasets import Dataset
from pydraconf import PydraConfig

from core.datasets.abstract_dataset_adapter import abstractmethod


class AbstractDatasetSamplerConfig(PydraConfig):
    """Base configuration shared by dataset samplers."""

    # Number of highest-scoring rows the sampler keeps.
    top_k: int


class AbstractDatasetSampler(ABC):
    """Base class for samplers that keep the best-scoring rows of a dataset.

    Subclasses implement :meth:`_score_row`. :meth:`create_sample` scores every
    row, orders the rows by descending score and keeps the first
    ``config.top_k`` of them.
    """

    def __init__(self, config: AbstractDatasetSamplerConfig):
        """Store the sampler configuration (``top_k`` is read when sampling)."""
        self.config = config

    # NOTE(review): this module takes `abstractmethod` from
    # core.datasets.abstract_dataset_adapter, which itself imports
    # AbstractDatasetSampler from this module; importing it from `abc`
    # directly would avoid that import cycle.
    @abstractmethod
    def _score_row(self, row: dict) -> float:
        """Return the ranking score of a single row; higher scores rank first."""

    def create_sample(self, ds: Dataset) -> Dataset:
        """Return the ``config.top_k`` highest-scoring rows of ``ds``.

        The returned dataset carries an extra ``score`` column holding the
        value produced by :meth:`_score_row` for each kept row.

        Args:
            ds: Dataset to sample from.

        Returns:
            A new dataset ordered by descending score, or ``ds`` itself when it
            contains no rows.
        """
        if ds.num_rows == 0:
            # Nothing to score or rank; hand the empty dataset back unchanged.
            return ds

        df = ds.to_pandas()
        df["score"] = df.apply(self._score_row, axis=1)
        ranked = df.sort_values("score", ascending=False)
        top_rows = ranked.head(self.config.top_k)

        # Sorting leaves a shuffled, non-Range index behind. Without
        # preserve_index=False it would be stored in the result as an extra
        # `__index_level_0__` column.
        return Dataset.from_pandas(top_rows, preserve_index=False)
13 changes: 13 additions & 0 deletions src/core/dataset_samplers/entropy_gain_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import override

from core.dataset_samplers.abstract_sampler import AbstractDatasetSampler, AbstractDatasetSamplerConfig


class EntropyGainSamplerConfig(AbstractDatasetSamplerConfig):
    """Configuration for ``EntropyGainSampler``; adds no fields of its own."""


# Backward-compatible alias: this config was originally declared under the
# teacher-entropy sampler's config name, so keep that name importable.
TeacherEntropySamplerConfig = EntropyGainSamplerConfig


class EntropyGainSampler(AbstractDatasetSampler):
    """Sampler scoring rows by how far student entropy exceeds teacher entropy."""

    @override
    def _score_row(self, row: dict) -> float:
        """Return ``student_entropy - teacher_entropy``, floored at zero."""
        gain = row["student_entropy"] - row["teacher_entropy"]
        return max(gain, 0)
15 changes: 15 additions & 0 deletions src/core/dataset_samplers/entropy_ratio_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import override

from core.dataset_samplers.abstract_sampler import AbstractDatasetSampler, AbstractDatasetSamplerConfig


class EntropyRatioSamplerConfig(AbstractDatasetSamplerConfig):
    """Configuration for the entropy-ratio sampler; adds no fields of its own."""


# Backward-compatible alias: this config was originally declared under the
# teacher-entropy sampler's config name, so keep that name importable.
TeacherEntropySamplerConfig = EntropyRatioSamplerConfig


class EntropyRatioSampler(AbstractDatasetSampler):
    """Sampler scoring rows by the ratio of student entropy to teacher entropy."""

    # Added to the denominator so a teacher entropy of zero does not divide by zero.
    _EPS = 1e-8

    @override
    def _score_row(self, row: dict) -> float:
        """Return ``student_entropy / (teacher_entropy + _EPS)`` for ``row``."""
        return row["student_entropy"] / (row["teacher_entropy"] + self._EPS)


# Backward-compatible alias: the class was originally declared in this module
# as `EntropyGainSampler` although it computes a ratio, not a gain. Keep the
# old name importable for existing callers.
EntropyGainSampler = EntropyRatioSampler
13 changes: 13 additions & 0 deletions src/core/dataset_samplers/student_entropy_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import override

from core.dataset_samplers.abstract_sampler import AbstractDatasetSampler, AbstractDatasetSamplerConfig


class StudentEntropySamplerConfig(AbstractDatasetSamplerConfig):
    """Configuration for ``StudentEntropySampler``; adds no fields of its own."""


class StudentEntropySampler(AbstractDatasetSampler):
    """Sampler that uses each row's ``student_entropy`` as its score."""

    @override
    def _score_row(self, row: dict) -> float:
        """Return the row's ``student_entropy`` value unchanged."""
        student_entropy = row["student_entropy"]
        return student_entropy
13 changes: 13 additions & 0 deletions src/core/dataset_samplers/teacher_entropy_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import override

from core.dataset_samplers.abstract_sampler import AbstractDatasetSampler, AbstractDatasetSamplerConfig


class TeacherEntropySamplerConfig(AbstractDatasetSamplerConfig):
    """Configuration for ``TeacherEntropySampler``; adds no fields of its own."""


class TeacherEntropySampler(AbstractDatasetSampler):
    """Sampler that uses each row's ``teacher_entropy`` as its score."""

    @override
    def _score_row(self, row: dict) -> float:
        """Return the row's ``teacher_entropy`` value unchanged."""
        teacher_entropy = row["teacher_entropy"]
        return teacher_entropy
7 changes: 6 additions & 1 deletion src/core/datasets/abstract_dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datasets import Dataset, load_dataset, load_from_disk
from pydantic import BaseModel

from core.dataset_samplers.abstract_sampler import AbstractDatasetSampler
from core.datasets.base_dataset import BaseDataset


Expand All @@ -16,8 +17,9 @@ class TokenizedRow(BaseModel):


class AbstractDatasetAdapter[D: BaseDataset](ABC):
def __init__(self, dataset: D):
def __init__(self, dataset: D, dataset_sampler: AbstractDatasetSampler | None = None):
self.dataset = dataset
self.dataset_sampler = dataset_sampler

@abstractmethod
def process_row(self, row: dict) -> TokenizedRow: ...
Expand All @@ -36,6 +38,9 @@ def _load_ds(self, path_override: str | None = None) -> Dataset:
def process_dataset(self, path_override: str | None = None) -> Dataset:
ds = self._load_ds(path_override)

if self.dataset_sampler is not None:
ds = self.dataset_sampler.create_sample(ds)

processed_ds = ds.map(
lambda row: self.process_row(row).model_dump(),
num_proc=4,
Expand Down