diff --git a/AGENTS.md b/AGENTS.md index 61b7638..2d0bc16 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -59,6 +59,7 @@ Write `HyperBench` when referring to the project, repository, organization, or p - `make check` - **Typing:** Add and preserve type annotations. Run `make typecheck` for changes that touch typed code. - **Imports:** Keep imports at module top level unless a delayed import is necessary. Use `TYPE_CHECKING` guards for type-only or heavyweight imports. +- **Validation:** Use explicit validation functions for argument checks. Avoid `assert` and do not validate types at runtime. - **Runtime checks:** Do not use `assert` for library-facing validation. Raise explicit exceptions instead. - **Public APIs:** Avoid changing public signatures without a clear reason and matching docs/tests updates. - **Scope:** Keep changes narrow. Do not mix behavioral edits with unrelated refactors. diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 3d536a5..677609f 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -10,6 +10,8 @@ NodeSpaceFiller, NodeSpaceSetting, is_transductive_setting, + validate_node_space_setting, + validate_split_ratios, ) from hyperbench.data.hif import HIFLoader, HIFProcessor @@ -146,8 +148,9 @@ def enrich_node_features( Args: enricher: An instance of NodeEnricher to generate structural node features from hypergraph topology. enrichment_mode: How to combine generated features with existing ``hdata.x``. - ``concatenate`` appends new features as additional columns. + ``concatenate`` appends new features to the existing ones as additional columns. ``replace`` substitutes ``hdata.x`` entirely. + Defaults to ``replace`` if not provided. """ self.hdata = self.hdata.enrich_node_features(enricher, enrichment_mode) @@ -192,13 +195,15 @@ def enrich_hyperedge_attr( enricher: HyperedgeEnricher, enrichment_mode: EnrichmentMode | None = None, ) -> None: - """Enrich hyperedge features using the provided hyperedge feature enricher. + """ + Enrich hyperedge features using the provided hyperedge feature enricher. Args: - enricher: An instance of HyperedgeEnricher to generate structural hyperedge features from hypergraph topology. - enrichment_mode: How to combine generated features with existing ``hdata.hyperedge_attr``. - ``concatenate`` appends new features as additional columns. + enricher: An instance of HyperedgeEnricher to generate structural hyperedge attributes from hypergraph topology. + enrichment_mode: How to combine generated attributes with existing ``hdata.hyperedge_attr``. + ``concatenate`` appends new attributes to the existing ones as additional columns. ``replace`` substitutes ``hdata.hyperedge_attr`` entirely. + Defaults to ``replace`` if not provided. """ self.hdata = self.hdata.enrich_hyperedge_attr(enricher, enrichment_mode) @@ -207,13 +212,15 @@ def enrich_hyperedge_weights( enricher: HyperedgeEnricher, enrichment_mode: EnrichmentMode | None = None, ) -> None: - """Enrich hyperedge weights using the provided hyperedge weight enricher. + """ + Enrich hyperedge weights using the provided hyperedge weight enricher. Args: - enricher: An instance of HyperedgeEnricher to generate structural hyperedge features from hypergraph topology. - enrichment_mode: How to combine generated features with existing ``hdata.hyperedge_weights``. - ``concatenate`` appends new features as additional columns. + enricher: An instance of HyperedgeEnricher to generate structural hyperedge weights from hypergraph topology. + enrichment_mode: How to combine generated weights with existing ``hdata.hyperedge_weights``. + ``concatenate`` appends new weights to the existing ones as additional columns. ``replace`` substitutes ``hdata.hyperedge_weights`` entirely. + Defaults to ``replace`` if not provided. """ self.hdata = self.hdata.enrich_hyperedge_weights(enricher, enrichment_mode) @@ -325,7 +332,8 @@ def split_with_ratios( seed: int | None = None, node_space_setting: NodeSpaceSetting = "transductive", ) -> tuple[list[Dataset], list[float]]: - """Split the dataset and return the final hyperedge ratios. + """ + Split the dataset and return the final hyperedge ratios. Final ratios are computed from split hyperedge counts after ratio boundaries and any transductive rebalancing have been applied. @@ -350,12 +358,9 @@ def split_with_ratios( hyperedges, or a transductive first split cannot cover the full node space. """ - # Allow small imprecision in sum of ratios, but raise error if it's significant - # Example: ratios = [0.8, 0.1, 0.1] -> sum = 1.0 (valid) - # ratios = [0.8, 0.1, 0.05] -> sum = 0.95 (invalid, raises ValueError) - # ratios = [0.8, 0.1, 0.1, 0.0000001] -> sum = 1.0000001 (valid, allows small imprecision) - if abs(sum(ratios) - 1.0) > 1e-6: - raise ValueError(f"Split ratios must sum to 1.0, got {sum(ratios)}.") + validate_node_space_setting(node_space_setting) + validate_split_ratios(ratios) + device = self.hdata.device hyperedge_splitter = HyperedgeIDSplitter(self.hdata) diff --git a/hyperbench/data/enricher.py b/hyperbench/data/enricher.py index 5f88802..c630802 100644 --- a/hyperbench/data/enricher.py +++ b/hyperbench/data/enricher.py @@ -1,6 +1,6 @@ -import warnings import random import torch +import warnings from abc import ABC, abstractmethod from torch import Tensor, optim @@ -8,6 +8,13 @@ from torch_geometric.nn import Node2Vec as PyGNode2Vec from hyperbench.types import EdgeIndex, HyperedgeIndex from hyperbench.models import VilLain +from hyperbench.utils import ( + validate_is_between, + validate_is_finite, + validate_is_finite_when_provided, + validate_is_non_negative, + validate_is_positive, +) EnrichmentMode: TypeAlias = Literal["concatenate", "replace"] @@ -63,6 +70,8 @@ def __init__( self.weight_decay = weight_decay self.verbose = verbose + self.__validate() + def _empty_features(self, hyperedge_index: Tensor) -> Tensor: """ Return an empty feature matrix on the same device as ``hyperedge_index``. @@ -147,6 +156,28 @@ def _train(self, hyperedge_index: Tensor): return model + def __validate(self) -> None: + validate_is_positive("num_features", self.embedding_dim) + validate_is_non_negative("num_nodes", self.num_nodes) + validate_is_non_negative("num_hyperedges", self.num_hyperedges) + + if self.labels_per_subspace < 2: + raise ValueError( + f"'labels_per_subspace' must be at least 2, got {self.labels_per_subspace}." + ) + + validate_is_positive("training_steps", self.training_steps) + validate_is_positive("generation_steps", self.generation_steps) + validate_is_finite("tau", self.tau) + validate_is_positive("tau", self.tau) + validate_is_finite("eps", self.eps) + validate_is_positive("eps", self.eps) + validate_is_positive("num_epochs", self.num_epochs) + validate_is_positive("learning_rate", self.learning_rate) + validate_is_non_negative("weight_decay", self.weight_decay) + validate_is_finite("learning_rate", self.learning_rate) + validate_is_finite("weight_decay", self.weight_decay) + class Enricher(ABC): """ @@ -322,8 +353,9 @@ def __init__( beta: float | None = None, ): super().__init__(cache_dir=cache_dir) - if alpha < 0.0 or alpha > 1.0: - raise ValueError("Alpha must be between 0.0 and 1.0.") + + validate_is_between("alpha", alpha, 0.0, 1.0) + validate_is_finite_when_provided("beta", beta) self.alpha = alpha self.beta = beta @@ -407,12 +439,6 @@ def __init__( verbose: bool = False, ): super().__init__(cache_dir=cache_dir) - if walk_length < context_size: - raise ValueError( - f"Expected walk_length >= context_size, got " - f"walk_length={walk_length}, context_size={context_size}." - ) - self.embedding_dim = num_features self.walk_length = walk_length self.context_size = context_size @@ -428,6 +454,8 @@ def __init__( self.sparse = sparse self.verbose = verbose + self.__validate() + def enrich(self, hyperedge_index: Tensor) -> Tensor: """ Compute Node2Vec embeddings from the clique expansion of the hypergraph. @@ -519,6 +547,28 @@ def enrich(self, hyperedge_index: Tensor) -> Tensor: # Detach node embeddings from computation graph and return them return x.detach().to(device) + def __validate(self) -> None: + validate_is_positive("num_features", self.embedding_dim) + validate_is_positive("walk_length", self.walk_length) + validate_is_positive("context_size", self.context_size) + if self.walk_length < self.context_size: + raise ValueError( + "Expected walk_length >= context_size, got " + f"walk_length={self.walk_length}, context_size={self.context_size}." + ) + + validate_is_positive("num_walks_per_node", self.num_walks_per_node) + validate_is_finite("p", self.p) + validate_is_positive("p", self.p) + validate_is_finite("q", self.q) + validate_is_positive("q", self.q) + validate_is_positive("num_negative_samples", self.num_negative_samples) + validate_is_non_negative("num_nodes", self.num_nodes) + validate_is_positive("num_epochs", self.num_epochs) + validate_is_finite("learning_rate", self.learning_rate) + validate_is_positive("learning_rate", self.learning_rate) + validate_is_positive("batch_size", self.batch_size) + class LaplacianPositionalEncodingEnricher(NodeEnricher): """ @@ -540,6 +590,10 @@ def __init__( cache_dir: str | None = None, ): super().__init__(cache_dir=cache_dir) + + validate_is_positive("num_features", num_features) + validate_is_non_negative("num_nodes", num_nodes) + self.num_features = num_features self.num_nodes = num_nodes diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index 60036f4..276842d 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -74,6 +74,9 @@ def process_hypergraph(cls, hypergraph: HIFHypergraph) -> HData: # Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(hypergraph.nodes)} + if len(node_id_to_idx) != num_nodes: + raise ValueError("HIF node IDs must be unique.") + # Initialize edge_set only with edges that have incidences, so that # we avoid inflating edge count due to isolated nodes/missing incidences hyperedge_id_to_idx: dict[Any, int] = {} @@ -84,6 +87,11 @@ def process_hypergraph(cls, hypergraph: HIFHypergraph) -> HData: for incidence in hypergraph.incidences: node_id = incidence.get("node", 0) hyperedge_id = incidence.get("edge", 0) + if node_id not in node_id_to_idx: + raise ValueError( + f"Incidence references unknown node id {node_id!r}; " + "all incidence nodes must be declared in the HIF nodes list." + ) if hyperedge_id not in hyperedge_id_to_idx: # Hyperedges start from 0 and are assigned IDs in the order they are first encountered in incidences @@ -152,8 +160,7 @@ def __process_hyperedge_attr( hyperedge_id_to_idx: dict[Any, int], num_hyperedges: int, ) -> Tensor | None: - # hyperedge-attr: shape [num_hyperedges, num_hyperedge_attributes] - hyperedge_attr = None + hyperedge_attr = None # shape [num_hyperedges, num_hyperedge_attributes] has_hyperedges = hypergraph.hyperedges is not None and len(hypergraph.hyperedges) > 0 has_any_hyperedge_attrs = has_hyperedges and any( "attrs" in edge for edge in hypergraph.hyperedges @@ -234,7 +241,7 @@ def __process_hyperedge_weights( edge_attrs = hyperedge_id_to_attrs.get(edge_id, {}) weights.append(float(edge_attrs.get("weight", 1.0))) - return torch.tensor(weights, dtype=torch.float) + return torch.tensor(weights, dtype=torch.float) # shape [num_hyperedges,] class HIFLoader: diff --git a/hyperbench/data/loader.py b/hyperbench/data/loader.py index 9c70553..b99a5a9 100644 --- a/hyperbench/data/loader.py +++ b/hyperbench/data/loader.py @@ -83,10 +83,7 @@ def collate(self, batch: list[HData]) -> HData: collated_x = self.__cached_dataset_hdata.x[node_ids] collated_y = self.__cached_dataset_hdata.y[hyperedge_ids] - - collated_global_node_ids = None - if self.__cached_dataset_hdata.global_node_ids is not None: - collated_global_node_ids = self.__cached_dataset_hdata.global_node_ids[node_ids] + collated_global_node_ids = self.__cached_dataset_hdata.global_node_ids[node_ids] collated_hyperedge_attr = None if self.__cached_dataset_hdata.hyperedge_attr is not None: diff --git a/hyperbench/data/negative_sampling_scheduler.py b/hyperbench/data/negative_sampling_scheduler.py index d76d352..68c0fa8 100644 --- a/hyperbench/data/negative_sampling_scheduler.py +++ b/hyperbench/data/negative_sampling_scheduler.py @@ -1,15 +1,13 @@ -from enum import Enum -from typing import Any +from typing import Any, Literal, TypeAlias from hyperbench.types import HData from hyperbench.data import NegativeSampler -class NegativeSamplingSchedule(Enum): - """When to run negative sampling during training.""" - - FIRST_EPOCH = "first_epoch" # Only at epoch 0, cached for all subsequent epochs - EVERY_N_EPOCHS = "every_n_epochs" # Every N epochs (N provided separately) - EVERY_EPOCH = "every_epoch" # Negatives generated every epoch +NegativeSamplingSchedule: TypeAlias = Literal[ + "first_epoch", # Only at epoch 0, cached for all subsequent epochs + "every_n_epochs", # Every N epochs (N provided separately) + "every_epoch", # Negatives generated every epoch +] class NegativeSamplingScheduler: @@ -21,14 +19,15 @@ class NegativeSamplingScheduler: Args: negative_sampler: An instance of a ``NegativeSampler`` that defines how to sample negatives. - negative_sampling_schedule: An instance of ``NegativeSamplingSchedule`` that specifies the schedule for sampling negatives. - negative_sampling_every_n: An integer specifying the interval for sampling negatives when the schedule is set to ``EVERY_N_EPOCHS``. This parameter is ignored for other schedules. + negative_sampling_schedule: Literal string specifying the schedule for sampling negatives. + negative_sampling_every_n: An integer specifying the interval for sampling negatives + when the schedule is set to ``"every_n_epochs"``. This parameter is ignored for other schedules. """ def __init__( self, negative_sampler: NegativeSampler, - negative_sampling_schedule: NegativeSamplingSchedule = NegativeSamplingSchedule.EVERY_EPOCH, + negative_sampling_schedule: NegativeSamplingSchedule = "every_epoch", negative_sampling_every_n: int = 1, ) -> None: self.negative_sampler = negative_sampler @@ -56,13 +55,24 @@ def should_sample(self, epoch: int) -> bool: Returns: should_sample: True if negatives should be resampled for the current epoch, False otherwise. """ + if epoch < 0: + raise ValueError(f"Epoch must be non-negative, got {epoch}.") + match self.negative_sampling_schedule: - case NegativeSamplingSchedule.EVERY_N_EPOCHS: + case "every_n_epochs": + if self.negative_sampling_every_n <= 0: + raise ValueError( + f"negative_sampling_every_n must be positive, got {self.negative_sampling_every_n}." + ) return epoch % self.negative_sampling_every_n == 0 - case NegativeSamplingSchedule.FIRST_EPOCH: + case "first_epoch": return epoch == 0 - case _: # Defaults to NegativeSamplingSchedule.EVERY_EPOCH + case "every_epoch": return True + case _: + raise ValueError( + f"Unsupported negative sampling schedule: {self.negative_sampling_schedule!r}." + ) def sample(self, batch: HData, epoch: int) -> HData: """ diff --git a/hyperbench/data/sampler.py b/hyperbench/data/sampler.py index be496e9..e855b65 100644 --- a/hyperbench/data/sampler.py +++ b/hyperbench/data/sampler.py @@ -49,6 +49,7 @@ def _normalize_index(self, index: int | list[int], size: int) -> list[int]: Raises: ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of sampleable items). + TypeError: If the index is not an integer or a list of integers. """ if isinstance(index, list): if len(index) < 1: @@ -57,7 +58,15 @@ def _normalize_index(self, index: int | list[int], size: int) -> list[int]: raise ValueError( f"Index list length ({len(index)}) cannot exceed the number of sampleable items ({size})." ) + for id in index: + if not isinstance(id, int) or isinstance(id, bool): + raise TypeError("Index list must contain only integers.") + return list(set(index)) + + if not isinstance(index, int) or isinstance(index, bool): + raise TypeError("Index must be an integer or a list of integers.") + return [index] def _sample_hyperedge_index( @@ -244,10 +253,15 @@ def create_sampler_from_strategy(strategy: SamplingStrategy) -> BaseSampler: strategy: An instance of SamplingStrategy enum indicating which sampling strategy to use. Returns: - sampler: An instance of a subclass of BaseSampler corresponding to the specified strategy. If strategy is not recognized, defaults to ``HyperedgeSampler``. + sampler: An instance of a subclass of BaseSampler corresponding to the specified strategy. + + Raises: + ValueError: If ``strategy`` is not a supported `SamplingStrategy`. """ match strategy: case SamplingStrategy.NODE: return NodeSampler() - case _: + case SamplingStrategy.HYPEREDGE: return HyperedgeSampler() + case _: + raise ValueError(f"Unsupported sampling strategy: {strategy!r}.") diff --git a/hyperbench/data/splitter.py b/hyperbench/data/splitter.py index 0bd26dc..d4386dd 100644 --- a/hyperbench/data/splitter.py +++ b/hyperbench/data/splitter.py @@ -4,6 +4,11 @@ from typing import cast from torch import Tensor from hyperbench.types import HData +from hyperbench.utils import ( + create_seeded_torch_generator, + validate_is_non_empty, + validate_split_ratios, +) class Splitter(ABC): @@ -60,6 +65,13 @@ def ensure_split_covers_all_nodes( Raises: ValueError: If one or more nodes do not appear in any hyperedge of the source hypergraph. """ + validate_is_non_empty("hyperedge_ids_by_split", hyperedge_ids_by_split) + if split_idx < 0 or split_idx >= len(hyperedge_ids_by_split): + raise ValueError( + f"split_idx must reference an existing split, got {split_idx} " + f"for {len(hyperedge_ids_by_split)} splits." + ) + required_node_ids = torch.arange(self.hdata.num_nodes, device=self.hdata.device) available_node_ids = self.hdata.hyperedge_index[0].unique() missing_from_hypergraph_mask = torch.logical_not( @@ -134,9 +146,7 @@ def get_hyperedge_ids_permutation(self, shuffle: bool | None, seed: int | None) # Shuffle hyperedge IDs if shuffle is requested, otherwise keep original order for deterministic splits if shuffle: - generator = torch.Generator(device=device) - if seed is not None: - generator.manual_seed(seed) + generator = create_seeded_torch_generator(device=device, seed=seed) random_hyperedge_ids_permutation = torch.randperm( n=num_hyperedges, @@ -185,6 +195,8 @@ def split(self, to_split: Tensor, ratios: list[float]) -> tuple[list[Tensor], li hyperedge_ids_by_split: The updated hyperedge IDs for each split. ratios: The final ratios of hyperedges in each split after rebalancing. """ + validate_split_ratios(ratios) + # Cumulative floor boundaries keep early splits from over-consuming hyperedges. # The last split absorbs any rounding remainder. num_hyperedges = int(to_split.size(0)) diff --git a/hyperbench/hlp/common.py b/hyperbench/hlp/common.py index 08c7b34..870a222 100644 --- a/hyperbench/hlp/common.py +++ b/hyperbench/hlp/common.py @@ -20,8 +20,8 @@ class HlpModule(L.LightningModule): metrics: Optional ``MetricCollection`` of torchmetrics to compute during evaluation. Cloned per stage (train, val, test) for independent state accumulation. negative_sampler: Optional negative sampler. If ``None``, no negative sampling is performed. - negative_sampling_schedule: When to perform negative sampling during training. Defaults to ``EVERY_EPOCH``. - negative_sampling_every_n: If using ``EVERY_N_EPOCHS`` schedule, how many epochs between negative sampling runs. Defaults to ``1``. + negative_sampling_schedule: When to perform negative sampling during training. Defaults to ``"every_epoch"``. + negative_sampling_every_n: If using ``"every_n_epochs"`` schedule, how many epochs between negative sampling runs. Defaults to ``1``. """ def __init__( @@ -31,7 +31,7 @@ def __init__( encoder: nn.Module | None = None, metrics: MetricCollection | None = None, negative_sampler: NegativeSampler | None = None, - negative_sampling_schedule: NegativeSamplingSchedule = NegativeSamplingSchedule.EVERY_EPOCH, + negative_sampling_schedule: NegativeSamplingSchedule = "every_epoch", negative_sampling_every_n: int = 1, ): super().__init__() diff --git a/hyperbench/integration_tests/common.py b/hyperbench/integration_tests/common.py index 72cc7f7..6c47016 100644 --- a/hyperbench/integration_tests/common.py +++ b/hyperbench/integration_tests/common.py @@ -1,10 +1,10 @@ -from typing import Literal -from collections.abc import Sequence import lightning as L import torch +from collections.abc import Sequence from functools import cache from pathlib import Path +from typing import Literal from torchmetrics import MetricCollection from torchmetrics.classification import ( BinaryAUROC, @@ -13,7 +13,6 @@ BinaryPrecision, BinaryRecall, ) - from hyperbench.data import ( Dataset, DataLoader, @@ -24,27 +23,12 @@ ) from hyperbench.train import MultiModelTrainer from hyperbench.types import ModelConfig, HData -from torch import Generator +from hyperbench.utils import create_seeded_torch_generator SEED = 42 -def __create_seeded_torch_generator( - device: torch.device, - seed: int | None, -) -> Generator | None: - - if seed is None: - return None - generator = Generator(device=device) - generator.manual_seed(seed) - return generator - - -generator = __create_seeded_torch_generator(device=torch.device("cpu"), seed=SEED) - - @cache def _cached_split_dataset( sampling_strategy: SamplingStrategy, @@ -52,6 +36,7 @@ def _cached_split_dataset( node_space_setting: Literal["transductive", "inductive"] = "transductive", ) -> tuple[Dataset, Dataset, Dataset]: if dataset is None: + generator = create_seeded_torch_generator(device=torch.device("cpu"), seed=SEED) x = torch.randn((100, 4), generator=generator) # 100 nodes with 4 features each hyperedge_index = torch.cat( # 200 hyperedges, each connecting 5 nodes [ @@ -166,27 +151,26 @@ def loaders( batch_size: int = 1, sample_full_hypergraph: bool = False, ) -> tuple[DataLoader, DataLoader, DataLoader]: - train_loader = DataLoader( train_dataset, batch_size=batch_size, sample_full_hypergraph=sample_full_hypergraph, shuffle=False, - generator=generator, + generator=create_seeded_torch_generator(device=torch.device("cpu"), seed=SEED), ) val_loader = DataLoader( val_dataset, batch_size=batch_size, sample_full_hypergraph=sample_full_hypergraph, shuffle=False, - generator=generator, + generator=create_seeded_torch_generator(device=torch.device("cpu"), seed=SEED), ) tests_loader = DataLoader( test_dataset, batch_size=batch_size, sample_full_hypergraph=sample_full_hypergraph, shuffle=False, - generator=generator, + generator=create_seeded_torch_generator(device=torch.device("cpu"), seed=SEED), ) return train_loader, val_loader, tests_loader diff --git a/hyperbench/integration_tests/data/enricher_integration_test.py b/hyperbench/integration_tests/data/enricher_integration_test.py index 6f0f757..095e2b4 100644 --- a/hyperbench/integration_tests/data/enricher_integration_test.py +++ b/hyperbench/integration_tests/data/enricher_integration_test.py @@ -3,15 +3,15 @@ from hyperbench.data import ( list_datasets, get_dataset_by_name, - Node2VecEnricher, FillValueHyperedgeAttrsEnricher, + LaplacianPositionalEncodingEnricher, + Node2VecEnricher, VilLainHyperedgeAttrsEnricher, VilLainEnricher, ) from hyperbench.integration_tests.common import ( split_dataset, ) -from hyperbench.data import LaplacianPositionalEncodingEnricher excluded_dataset = [ @@ -78,7 +78,7 @@ def test_n2v_node_enricher(dataset_name): walk_length=5, num_walks_per_node=2, num_negative_samples=1, - num_nodes=dataset.hdata.num_nodes, + num_nodes=to_enrich_dataset.hdata.num_nodes, num_epochs=3, learning_rate=0.01, batch_size=128, @@ -104,8 +104,8 @@ def test_villain_node_enricher(dataset_name): to_enrich_dataset.enrich_node_features( enricher=VilLainEnricher( num_features=NUM_FEATURES, - num_nodes=dataset.hdata.num_nodes, - num_hyperedges=dataset.hdata.num_hyperedges, + num_nodes=to_enrich_dataset.hdata.num_nodes, + num_hyperedges=to_enrich_dataset.hdata.num_hyperedges, labels_per_subspace=2, training_steps=2, generation_steps=4, @@ -150,8 +150,8 @@ def test_villain_hyperedge_attrs_enricher(dataset_name): to_enrich_dataset.enrich_hyperedge_attr( enricher=VilLainHyperedgeAttrsEnricher( num_features=NUM_FEATURES, - num_nodes=dataset.hdata.num_nodes, - num_hyperedges=dataset.hdata.num_hyperedges, + num_nodes=to_enrich_dataset.hdata.num_nodes, + num_hyperedges=to_enrich_dataset.hdata.num_hyperedges, labels_per_subspace=2, training_steps=2, generation_steps=4, diff --git a/hyperbench/models/nhp.py b/hyperbench/models/nhp.py index 549f41a..89de057 100644 --- a/hyperbench/models/nhp.py +++ b/hyperbench/models/nhp.py @@ -188,8 +188,10 @@ def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor: max_embeddings = incidence_aggregator.pool("max") min_embeddings = incidence_aggregator.pool("min") hyperedge_embeddings = max_embeddings - min_embeddings - case _: + case "mean": hyperedge_embeddings = incidence_aggregator.pool("mean") + case _: + raise ValueError(f"Invalid aggregation method: {self.aggregation}") # Decode: linear projection to scalar score per hyperedge # shape: (num_hyperedges, 1) -> squeeze -> (num_hyperedges,) diff --git a/hyperbench/models/node2vec.py b/hyperbench/models/node2vec.py index 52ef52b..b7d40d8 100644 --- a/hyperbench/models/node2vec.py +++ b/hyperbench/models/node2vec.py @@ -56,7 +56,7 @@ def __init__( super().__init__() if walk_length < context_size: raise ValueError( - f"Expected walk_length >= context_size, got " + "Expected walk_length >= context_size, got " f"walk_length={walk_length}, context_size={context_size}." ) diff --git a/hyperbench/nn/aggregator.py b/hyperbench/nn/aggregator.py index 7bfb617..8be3808 100644 --- a/hyperbench/nn/aggregator.py +++ b/hyperbench/nn/aggregator.py @@ -1,7 +1,6 @@ from torch import Tensor from typing import Literal from torch_geometric.utils import scatter - from hyperbench.types import HyperedgeIndex from hyperbench.utils import maxmin_scatter diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index a042691..3a23a76 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -2,7 +2,7 @@ import torch import re -from typing import cast +from typing import Any, cast from unittest.mock import patch, MagicMock from hyperbench.types import HData from hyperbench.data import ( @@ -28,7 +28,7 @@ def mock_hdata() -> HData: def mock_hdata_with_hyperedge_attr() -> HData: x = torch.ones((3, 1), dtype=torch.float) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_attr = torch.ones((3, 1), dtype=torch.float) + hyperedge_attr = torch.ones((2, 1), dtype=torch.float) return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_attr=hyperedge_attr) @@ -36,7 +36,7 @@ def mock_hdata_with_hyperedge_attr() -> HData: def mock_hdata_with_hyperedge_weights() -> HData: x = torch.ones((3, 1), dtype=torch.float) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_weights = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float) + hyperedge_weights = torch.tensor([1.0, 2.0], dtype=torch.float) return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) @@ -47,14 +47,6 @@ def mock_hdata_isolated_hyperedges() -> HData: return HData(x=x, hyperedge_index=hyperedge_index) -@pytest.fixture -def mock_hdata_three_nodes_weighted() -> HData: - x = torch.ones((3, 1), dtype=torch.float) - hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_weights = torch.tensor([[1.0], [2.0]], dtype=torch.float) - return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) - - @pytest.fixture def mock_hdata_four_nodes() -> HData: x = torch.ones((4, 1), dtype=torch.float) @@ -75,19 +67,12 @@ def mock_hdata_transductive_split() -> HData: return HData(x=x, hyperedge_index=hyperedge_index) -@pytest.fixture -def mock_hdata_no_hyperedge_attr() -> HData: - x = torch.ones((2, 1), dtype=torch.float) - hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) - return HData(x=x, hyperedge_index=hyperedge_index) - - @pytest.fixture def mock_hdata_multiple_hyperedge_attrs() -> HData: x = torch.ones((4, 1), dtype=torch.float) hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]], dtype=torch.long) - hyperedge_weights = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) - hyperedge_attr = torch.ones((4, 1), dtype=torch.float) + hyperedge_weights = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float) + hyperedge_attr = torch.ones((3, 1), dtype=torch.float) return HData( x=x, hyperedge_index=hyperedge_index, @@ -106,7 +91,7 @@ def mock_hdata_transductive_multiple_hyperedges_attrs() -> HData: ], dtype=torch.long, ) - hyperedge_weights = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) + hyperedge_weights = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float) hyperedge_attr = torch.ones((3, 1), dtype=torch.float) return HData( x=x, @@ -116,14 +101,6 @@ def mock_hdata_transductive_multiple_hyperedges_attrs() -> HData: ) -@pytest.fixture -def mock_hdata_two_hyperedge_attrs_weighted() -> HData: - x = torch.ones((3, 1), dtype=torch.float) - hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_weights = torch.tensor([[1.0, 2.0], [3.0, 0.1]], dtype=torch.float) - return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) - - @pytest.fixture def mock_negative_sampler() -> tuple[NegativeSampler, MagicMock]: sampler = new_mock_negative_sampler() @@ -168,31 +145,40 @@ def test_dataset_process_no_incidences(mock_hdata_isolated_hyperedges): assert dataset.hdata.hyperedge_attr is None -def test_dataset_process_with_edge_attributes(mock_hdata_two_hyperedge_attrs_weighted): - with patch.object( - HIFLoader, "load_by_name", return_value=mock_hdata_two_hyperedge_attrs_weighted - ): +def test_dataset_process_with_hyperedge_attributes(mock_hdata_with_hyperedge_attr): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_with_hyperedge_attr): dataset = AlgebraDataset() assert dataset.hdata is not None - assert dataset.hdata.x.shape[0] == 3 + assert dataset.hdata.hyperedge_attr is not None assert dataset.hdata.hyperedge_index.shape[0] == 2 - assert dataset.hdata.hyperedge_index.shape[1] == 3 + assert torch.allclose(dataset.hdata.hyperedge_attr, torch.ones((2, 1), dtype=torch.float)) + + +def test_dataset_process_without_hyperedge_attributes(mock_hdata): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata): + dataset = AlgebraDataset() + + assert dataset.hdata is not None assert dataset.hdata.hyperedge_attr is None + + +def test_dataset_process_with_hyperedge_weights(mock_hdata_with_hyperedge_weights): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_with_hyperedge_weights): + dataset = AlgebraDataset() + + assert dataset.hdata is not None assert dataset.hdata.hyperedge_weights is not None - assert dataset.hdata.hyperedge_weights.shape == (2, 2) - assert torch.allclose(dataset.hdata.hyperedge_weights[0], torch.tensor([1.0, 2.0])) - assert torch.allclose(dataset.hdata.hyperedge_weights[1], torch.tensor([3.0, 0.1])) + assert dataset.hdata.hyperedge_weights.shape == (2,) + assert torch.allclose(dataset.hdata.hyperedge_weights, torch.tensor([1.0, 2.0])) -def test_dataset_process_without_edge_attributes(mock_hdata_no_hyperedge_attr): - with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_hyperedge_attr): +def test_dataset_process_without_hyperedge_weights(mock_hdata): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata): dataset = AlgebraDataset() assert dataset.hdata is not None - assert dataset.hdata.hyperedge_index.shape[0] == 2 - assert dataset.hdata.hyperedge_index.shape[1] == 2 - assert dataset.hdata.hyperedge_attr is None + assert dataset.hdata.hyperedge_weights is None def test_dataset_process_hyperedge_index_in_correct_format(mock_hdata_four_nodes): @@ -319,14 +305,17 @@ def test_getitem_when_list_index_provided( pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), ], ) -def test_getitem_with_edge_attr(mock_hdata_three_nodes_weighted, strategy): - with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_three_nodes_weighted): +def test_getitem_with_hyperedge_attr(mock_hdata_with_hyperedge_attr, strategy): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_with_hyperedge_attr): dataset = AlgebraDataset(sampling_strategy=strategy) data = dataset[0] assert data.hyperedge_index.shape == (2, 2) assert data.num_hyperedges == 1 + + # Even though the original hypergraph has edge attributes, __getitem__ should return hyperedge_attr as None + # as the hyperedge attributes are handled by the loader's collate function during batching assert data.hyperedge_attr is None @@ -337,8 +326,8 @@ def test_getitem_with_edge_attr(mock_hdata_three_nodes_weighted, strategy): pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), ], ) -def test_getitem_without_edge_attr(mock_hdata_no_hyperedge_attr, strategy): - with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_hyperedge_attr): +def test_getitem_without_hyperedge_attr(mock_hdata, strategy): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata): dataset = AlgebraDataset(sampling_strategy=strategy) data = dataset[0] @@ -354,7 +343,7 @@ def test_getitem_without_edge_attr(mock_hdata_no_hyperedge_attr, strategy): pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], id="hyperedge_strategy"), ], ) -def test_getitem_with_multiple_edges_attr(mock_hdata_multiple_hyperedge_attrs, strategy, index): +def test_getitem_with_multiple_hyperedge_attr(mock_hdata_multiple_hyperedge_attrs, strategy, index): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_multiple_hyperedge_attrs): dataset = AlgebraDataset(sampling_strategy=strategy) @@ -366,6 +355,42 @@ def test_getitem_with_multiple_edges_attr(mock_hdata_multiple_hyperedge_attrs, s assert data.hyperedge_attr is None +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + ], +) +def test_getitem_with_hyperedge_weights(mock_hdata_with_hyperedge_weights, strategy): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_with_hyperedge_weights): + dataset = AlgebraDataset(sampling_strategy=strategy) + + data = dataset[0] + + assert data.hyperedge_index.shape == (2, 2) + assert data.num_hyperedges == 1 + + # Even though the original hypergraph has edge attributes, __getitem__ should return hyperedge_weights as None + # as the hyperedge weights are handled by the loader's collate function during batching + assert data.hyperedge_weights is None + + +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + ], +) +def test_getitem_without_hyperedge_weights(mock_hdata, strategy): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata): + dataset = AlgebraDataset(sampling_strategy=strategy) + + data = dataset[0] + assert data.hyperedge_weights is None + + @pytest.mark.parametrize( "strategy, expected_len", [ @@ -450,7 +475,7 @@ def test_enrich_hyperedge_attr_replace(mock_hdata): dataset = Dataset.from_hdata(mock_hdata) enricher = MagicMock(spec=HyperedgeEnricher) - enriched_x = torch.randn(3, 4) + enriched_x = torch.randn(2, 4) enricher.enrich.return_value = enriched_x dataset.enrich_hyperedge_attr(enricher) @@ -468,7 +493,7 @@ def test_enrich_hyperedge_attr_concatenate(mock_hdata_with_hyperedge_attr): original_hyperedge_attr = original_hyperedge_attr.clone() enricher = MagicMock(spec=HyperedgeEnricher) - enriched_x = torch.randn(3, 4) + enriched_x = torch.randn(2, 4) enricher.enrich.return_value = enriched_x dataset.enrich_hyperedge_attr(enricher, enrichment_mode="concatenate") @@ -478,14 +503,14 @@ def test_enrich_hyperedge_attr_concatenate(mock_hdata_with_hyperedge_attr): hyperedge_attr = dataset.hdata.hyperedge_attr assert hyperedge_attr is not None assert torch.equal(hyperedge_attr, expected_x) - assert hyperedge_attr.shape == (3, 5) # 1 original + 4 enriched + assert hyperedge_attr.shape == (2, 5) # 1 original + 4 enriched def test_enrich_hyperedge_weights_replace(mock_hdata): dataset = Dataset.from_hdata(mock_hdata) enricher = MagicMock(spec=HyperedgeEnricher) - enriched_weights = torch.randn(3) + enriched_weights = torch.randn(2) enricher.enrich.return_value = enriched_weights dataset.enrich_hyperedge_weights(enricher) @@ -496,24 +521,27 @@ def test_enrich_hyperedge_weights_replace(mock_hdata): assert torch.equal(hyperedge_weights, enriched_weights) -def test_enrich_hyperedge_weights_concatenate(mock_hdata_with_hyperedge_weights): +def test_enrich_hyperedge_weights_concatenate( + mock_hdata_with_hyperedge_weights, +): dataset = Dataset.from_hdata(mock_hdata_with_hyperedge_weights) - original_weights = dataset.hdata.hyperedge_weights - assert original_weights is not None - original_weights = original_weights.clone() + dataset.hdata.hyperedge_index = torch.tensor([[0, 1, 2, 0], [0, 0, 1, 2]]) + dataset.hdata.num_hyperedges = 3 + dataset.hdata.y = torch.ones(3, dtype=torch.float) enricher = MagicMock(spec=HyperedgeEnricher) - enriched_weights = torch.randn(3) + enriched_weights = torch.tensor([3.0]) enricher.enrich.return_value = enriched_weights dataset.enrich_hyperedge_weights(enricher, enrichment_mode="concatenate") - enricher.enrich.assert_called_once_with(mock_hdata_with_hyperedge_weights.hyperedge_index) - expected_weights = torch.cat([original_weights, enriched_weights], dim=0) - hyperedge_weights = dataset.hdata.hyperedge_weights - assert hyperedge_weights is not None - assert torch.equal(hyperedge_weights, expected_weights) - assert hyperedge_weights.shape == (6,) # 3 original + 3 enriched + enricher.enrich.assert_called_once() + enriched_hyperedge_index = enricher.enrich.call_args.args[0] + + assert torch.equal(enriched_hyperedge_index, dataset.hdata.hyperedge_index) + assert dataset.hdata.hyperedge_weights is not None + assert torch.equal(dataset.hdata.hyperedge_weights, torch.tensor([1.0, 2.0, 3.0])) + assert dataset.hdata.hyperedge_weights.shape == (3,) @pytest.mark.parametrize( @@ -610,10 +638,10 @@ def test_split_three_way(mock_hdata_multiple_hyperedge_attrs): dataset = AlgebraDataset() splits = dataset.split([0.5, 0.25, 0.25], node_space_setting="inductive") - total_edges = sum(split.hdata.num_hyperedges for split in splits) + total_hyperedges = sum(split.hdata.num_hyperedges for split in splits) assert len(splits) == 3 - assert total_edges == dataset.hdata.num_hyperedges + assert total_hyperedges == dataset.hdata.num_hyperedges for split in splits: assert split.hdata.x is not None @@ -632,10 +660,10 @@ def test_split_transductive_three_way( dataset = AlgebraDataset() splits = dataset.split([0.5, 0.25, 0.25], node_space_setting="transductive") - total_edges = sum(split.hdata.num_hyperedges for split in splits) + total_hyperedges = sum(split.hdata.num_hyperedges for split in splits) assert len(splits) == 3 - assert total_edges == dataset.hdata.num_hyperedges + assert total_hyperedges == dataset.hdata.num_hyperedges for split in splits: assert split.hdata.x is not None @@ -662,10 +690,43 @@ def test_split_raises_when_ratios_do_not_sum_to_one(mock_hdata_four_nodes): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): dataset = AlgebraDataset() - with pytest.raises(ValueError, match=re.escape("Split ratios must sum to 1.0")): + with pytest.raises(ValueError, match=re.escape("'ratios' must sum to 1.0")): dataset.split([0.8, 0.1, 0.05]) +@pytest.mark.parametrize( + "ratios, expected_exception, expected_message", + [ + pytest.param([], ValueError, "'ratios' cannot be empty.", id="empty"), + pytest.param([0.5, 0.0, 0.5], ValueError, "'ratios' must be positive, got 0.0.", id="zero"), + pytest.param( + [0.5, float("inf")], ValueError, "'ratios' must be finite, got inf.", id="infinite" + ), + ], +) +def test_split_validates_ratio_values( + mock_hdata_four_nodes, ratios, expected_exception, expected_message +): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): + dataset = AlgebraDataset() + + with pytest.raises(expected_exception, match=re.escape(expected_message)): + dataset.split(cast(Any, ratios)) + + +def test_split_raises_on_invalid_node_space_setting(mock_hdata_four_nodes): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): + dataset = AlgebraDataset() + + with pytest.raises( + ValueError, + match=re.escape( + "node_space_setting must be one of 'transductive' or 'inductive', got 'semi'." + ), + ): + dataset.split([0.5, 0.5], node_space_setting=cast(Any, "semi")) + + def test_split_raises_when_a_split_has_zero_hyperedges(mock_hdata_four_nodes): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): dataset = AlgebraDataset() @@ -737,10 +798,10 @@ def test_split_with_shuffle_when_no_seed_provided( dataset = AlgebraDataset() splits = dataset.split([0.5, 0.5], shuffle=True, node_space_setting="inductive") - total_edges = sum(split.hdata.num_hyperedges for split in splits) + total_hyperedges = sum(split.hdata.num_hyperedges for split in splits) assert len(splits) == 2 - assert total_edges == dataset.hdata.num_hyperedges + assert total_hyperedges == dataset.hdata.num_hyperedges for split in splits: assert split.hdata.x is not None @@ -759,10 +820,10 @@ def test_split_transductive_with_shuffle_when_no_seed_provided( dataset = AlgebraDataset() splits = dataset.split([0.5, 0.5], shuffle=True, node_space_setting="transductive") - total_edges = sum(split.hdata.num_hyperedges for split in splits) + total_hyperedges = sum(split.hdata.num_hyperedges for split in splits) assert len(splits) == 2 - assert total_edges == dataset.hdata.num_hyperedges + assert total_hyperedges == dataset.hdata.num_hyperedges for split in splits: assert split.hdata.x is not None @@ -770,7 +831,7 @@ def test_split_transductive_with_shuffle_when_no_seed_provided( assert split.hdata.num_hyperedges > 0 -def test_split_preserves_edge_attr(mock_hdata_multiple_hyperedge_attrs): +def test_split_preserves_hyperedge_attr(mock_hdata_multiple_hyperedge_attrs): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_multiple_hyperedge_attrs): dataset = AlgebraDataset() @@ -781,7 +842,7 @@ def test_split_preserves_edge_attr(mock_hdata_multiple_hyperedge_attrs): assert split.hdata.hyperedge_attr.shape[0] == split.hdata.num_hyperedges -def test_split_transductive_preserves_edge_attr( +def test_split_transductive_preserves_hyperedge_attr( mock_hdata_transductive_multiple_hyperedges_attrs, ): with patch.object( @@ -798,7 +859,7 @@ def test_split_transductive_preserves_edge_attr( assert split.hdata.hyperedge_attr.shape[0] == split.hdata.num_hyperedges -def test_split_without_edge_attr(mock_hdata_four_nodes): +def test_split_without_hyperedge_attr(mock_hdata_four_nodes): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): dataset = AlgebraDataset() @@ -808,7 +869,7 @@ def test_split_without_edge_attr(mock_hdata_four_nodes): assert split.hdata.hyperedge_attr is None -def test_split_transductive_without_edge_attr(mock_hdata_transductive_split): +def test_split_transductive_without_hyperedge_attr(mock_hdata_transductive_split): with patch.object( HIFLoader, "load_by_name", @@ -822,6 +883,58 @@ def test_split_transductive_without_edge_attr(mock_hdata_transductive_split): assert split.hdata.hyperedge_attr is None +def test_split_preserves_hyperedge_weights(mock_hdata_multiple_hyperedge_attrs): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_multiple_hyperedge_attrs): + dataset = AlgebraDataset() + + splits = dataset.split([0.5, 0.5], node_space_setting="inductive") + + for split in splits: + assert split.hdata.hyperedge_weights is not None + assert split.hdata.hyperedge_weights.shape[0] == split.hdata.num_hyperedges + + +def test_split_transductive_preserves_hyperedge_weights( + mock_hdata_transductive_multiple_hyperedges_attrs, +): + with patch.object( + HIFLoader, + "load_by_name", + return_value=mock_hdata_transductive_multiple_hyperedges_attrs, + ): + dataset = AlgebraDataset() + + splits = dataset.split([0.5, 0.5], node_space_setting="transductive") + + for split in splits: + assert split.hdata.hyperedge_weights is not None + assert split.hdata.hyperedge_weights.shape[0] == split.hdata.num_hyperedges + + +def test_split_without_hyperedge_weights(mock_hdata_four_nodes): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): + dataset = AlgebraDataset() + + splits = dataset.split([0.5, 0.5], node_space_setting="inductive") + + for split in splits: + assert split.hdata.hyperedge_weights is None + + +def test_split_transductive_without_hyperedge_weights(mock_hdata_transductive_split): + with patch.object( + HIFLoader, + "load_by_name", + return_value=mock_hdata_transductive_split, + ): + dataset = AlgebraDataset() + + splits = dataset.split([0.5, 0.5], node_space_setting="transductive") + + for split in splits: + assert split.hdata.hyperedge_weights is None + + def test_to_device(mock_hdata): device = torch.device("cpu") @@ -1078,30 +1191,6 @@ def test_enrich_node_features_from_dataset(): assert torch.equal(target_dataset.hdata.x, torch.tensor([[3.0, 30.0], [1.0, 10.0]])) -def test_enrich_node_features_from_propagates_hdata_validation_errors(): - source_dataset = Dataset.from_hdata( - HData( - x=torch.tensor([[1.0], [2.0]]), - hyperedge_index=torch.tensor([[0, 1], [0, 0]]), - global_node_ids=torch.tensor([10, 20]), - ) - ) - target_dataset = Dataset.from_hdata( - HData( - x=torch.tensor([[0.0]]), - hyperedge_index=torch.tensor([[0], [0]]), - global_node_ids=torch.tensor([10]), - ) - ) - target_dataset.hdata.global_node_ids = None - - with pytest.raises( - ValueError, - match=re.escape("Both HData instances must define global_node_ids to align node features."), - ): - target_dataset.enrich_node_features_from(source_dataset) - - def test_enrich_node_features_from_dataset_with_fill_value(): source_dataset = Dataset.from_hdata( HData( diff --git a/hyperbench/tests/data/enricher_test.py b/hyperbench/tests/data/enricher_test.py index 40800e8..8ee7f5e 100644 --- a/hyperbench/tests/data/enricher_test.py +++ b/hyperbench/tests/data/enricher_test.py @@ -3,6 +3,7 @@ import re from pathlib import Path +from collections.abc import Callable from unittest.mock import patch from torch import Tensor from hyperbench.data import ( @@ -17,7 +18,6 @@ VilLainEnricher, VilLainHyperedgeAttrsEnricher, ) - from hyperbench.data.enricher import Enricher, _VilLainTrainer from hyperbench.tests.mock.mock import new_mock_pyg_node2vec, new_mock_villain @@ -92,10 +92,15 @@ def test_fill_value_hyperedge_attrs_enricher_returns_empty_attrs_for_empty_input ], ) def test_ab_hyperedge_weights_enricher_rejects_invalid_alpha(alpha: float) -> None: - with pytest.raises(ValueError, match=re.escape("Alpha must be between 0.0 and 1.0.")): + with pytest.raises(ValueError, match=re.escape("'alpha' must be between 0.0 and 1.0")): ABHyperedgeWeightsEnricher(alpha=alpha) +def test_ab_hyperedge_weights_enricher_rejects_non_finite_beta() -> None: + with pytest.raises(ValueError, match=re.escape("'beta' must be finite when provided")): + ABHyperedgeWeightsEnricher(beta=float("inf")) + + def test_ab_hyperedge_weights_enricher_counts_nodes_per_hyperedge( mock_two_hyperedge_index: Tensor, ) -> None: @@ -137,6 +142,74 @@ def test_node2vec_enricher_rejects_context_larger_than_walk_length() -> None: Node2VecEnricher(num_features=4, walk_length=2, context_size=3) +@pytest.mark.parametrize( + ("build_invalid_enricher", "expected_message"), + [ + pytest.param( + lambda: Node2VecEnricher(num_features=0), + "'num_features' must be positive", + id="features", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, walk_length=0), + "'walk_length' must be positive", + id="walk_length", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, context_size=0), + "'context_size' must be positive", + id="context_size", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, num_walks_per_node=0), + "'num_walks_per_node' must be positive", + id="walks_per_node", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, p=0.0), + "'p' must be positive", + id="p", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, q=0.0), + "'q' must be positive", + id="q", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, num_negative_samples=0), + "'num_negative_samples' must be positive", + id="negative_samples", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, num_nodes=-1), + "'num_nodes' must be non-negative", + id="num_nodes", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, num_epochs=0), + "'num_epochs' must be positive", + id="epochs", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, learning_rate=float("nan")), + "'learning_rate' must be finite", + id="learning_rate_finite", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, batch_size=0), + "'batch_size' must be positive", + id="batch_size", + ), + ], +) +def test_node2vec_enricher_rejects_invalid_params( + build_invalid_enricher: Callable[[], object], + expected_message: str, +) -> None: + with pytest.raises(ValueError, match=re.escape(expected_message)): + build_invalid_enricher() + + def test_node2vec_enricher_returns_empty_features_when_no_nodes() -> None: hyperedge_index = torch.zeros((2, 0), dtype=torch.long) enricher = Node2VecEnricher(num_features=3) @@ -283,6 +356,29 @@ def test_laplacian_positional_encoding_enricher_uses_explicit_num_nodes() -> Non assert result.shape == (4, 2) +@pytest.mark.parametrize( + ("build_invalid_enricher", "expected_message"), + [ + pytest.param( + lambda: LaplacianPositionalEncodingEnricher(num_features=0), + "'num_features' must be positive", + id="features", + ), + pytest.param( + lambda: LaplacianPositionalEncodingEnricher(num_features=3, num_nodes=-1), + "'num_nodes' must be non-negative", + id="num_nodes", + ), + ], +) +def test_laplacian_positional_encoding_enricher_rejects_invalid_semantic_params( + build_invalid_enricher: Callable[[], object], + expected_message: str, +) -> None: + with pytest.raises(ValueError, match=re.escape(expected_message)): + build_invalid_enricher() + + def test_villain_trainer_resolves_explicit_and_inferred_counts( mock_two_hyperedge_index: Tensor, ) -> None: @@ -302,6 +398,94 @@ def test_villain_trainer_falls_back_to_inferred_counts( assert trainer._num_hyperedges(mock_two_hyperedge_index) == 2 +@pytest.mark.parametrize( + ("build_invalid_enricher", "expected_message"), + [ + pytest.param( + lambda: VilLainEnricher(num_features=0), + "'num_features' must be positive", + id="features", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, num_nodes=-1), + "'num_nodes' must be non-negative", + id="num_nodes", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, num_hyperedges=-1), + "'num_hyperedges' must be non-negative", + id="num_hyperedges", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, labels_per_subspace=1), + "'labels_per_subspace' must be at least 2", + id="labels_per_subspace", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, training_steps=0), + "'training_steps' must be positive", + id="training_steps", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, generation_steps=0), + "'generation_steps' must be positive", + id="generation_steps", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, tau=float("nan")), + "'tau' must be finite", + id="tau_finite", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, tau=0.0), + "'tau' must be positive", + id="tau_positive", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, eps=float("nan")), + "'eps' must be finite", + id="eps_finite", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, eps=0.0), + "'eps' must be positive", + id="eps_positive", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, num_epochs=0), + "'num_epochs' must be positive", + id="num_epochs", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, learning_rate=float("nan")), + "'learning_rate' must be finite", + id="learning_rate_finite", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, learning_rate=0.0), + "'learning_rate' must be positive", + id="learning_rate_positive", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, weight_decay=float("nan")), + "'weight_decay' must be finite", + id="weight_decay_finite", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, weight_decay=-0.1), + "'weight_decay' must be non-negative", + id="weight_decay_non_negative", + ), + ], +) +def test_villain_node_enricher_rejects_invalid_params( + build_invalid_enricher: Callable[[], object], + expected_message: str, +) -> None: + with pytest.raises(ValueError, match=re.escape(expected_message)): + build_invalid_enricher() + + def test_villain_node_enricher_returns_empty_features_when_no_nodes() -> None: hyperedge_index = torch.zeros((2, 0), dtype=torch.long) enricher = VilLainEnricher(num_features=3) @@ -313,6 +497,94 @@ def test_villain_node_enricher_returns_empty_features_when_no_nodes() -> None: assert result.device == hyperedge_index.device +@pytest.mark.parametrize( + ("build_invalid_enricher", "expected_message"), + [ + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=0), + "'num_features' must be positive", + id="features", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, num_nodes=-1), + "'num_nodes' must be non-negative", + id="num_nodes", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, num_hyperedges=-1), + "'num_hyperedges' must be non-negative", + id="num_hyperedges", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, labels_per_subspace=1), + "'labels_per_subspace' must be at least 2", + id="labels_per_subspace", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, training_steps=0), + "'training_steps' must be positive", + id="training_steps", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, generation_steps=0), + "'generation_steps' must be positive", + id="generation_steps", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, tau=float("nan")), + "'tau' must be finite", + id="tau_finite", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, tau=0.0), + "'tau' must be positive", + id="tau_positive", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, eps=float("nan")), + "'eps' must be finite", + id="eps_finite", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, eps=0.0), + "'eps' must be positive", + id="eps_positive", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, num_epochs=0), + "'num_epochs' must be positive", + id="num_epochs", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, learning_rate=float("nan")), + "'learning_rate' must be finite", + id="learning_rate_finite", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, learning_rate=0.0), + "'learning_rate' must be positive", + id="learning_rate_positive", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, weight_decay=float("nan")), + "'weight_decay' must be finite", + id="weight_decay_finite", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, weight_decay=-0.1), + "'weight_decay' must be non-negative", + id="weight_decay_non_negative", + ), + ], +) +def test_villain_hyperedge_attrs_enricher_rejects_invalid_params( + build_invalid_enricher: Callable[[], object], + expected_message: str, +) -> None: + with pytest.raises(ValueError, match=re.escape(expected_message)): + build_invalid_enricher() + + def test_villain_hyperedge_attrs_enricher_returns_empty_attrs_when_no_hyperedges() -> None: hyperedge_index = torch.zeros((2, 0), dtype=torch.long) enricher = VilLainHyperedgeAttrsEnricher(num_features=3) diff --git a/hyperbench/tests/data/hif_test.py b/hyperbench/tests/data/hif_test.py index 07f7769..6658d88 100644 --- a/hyperbench/tests/data/hif_test.py +++ b/hyperbench/tests/data/hif_test.py @@ -202,6 +202,30 @@ def test_transform_node_attrs_adds_padding_zero_when_attr_keys_padding(): assert torch.allclose(result, torch.tensor([0.0, 2.5])) +def test_process_hypergraph_rejects_duplicate_node_ids(): + hypergraph = HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0"}, {"node": "0"}], + hyperedges=[{"edge": "0"}], + incidences=[{"node": "0", "edge": "0"}], + ) + + with pytest.raises(ValueError, match="HIF node IDs must be unique"): + HIFProcessor.process_hypergraph(hypergraph) + + +def test_process_hypergraph_rejects_incidences_with_unknown_node_ids(): + hypergraph = HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0"}], + hyperedges=[{"edge": "0"}], + incidences=[{"node": "missing", "edge": "0"}], + ) + + with pytest.raises(ValueError, match="Incidence references unknown node id"): + HIFProcessor.process_hypergraph(hypergraph) + + def test_load_from_url_rejects_invalid_url(): with pytest.raises(ValueError, match="Invalid URL"): HIFLoader.load_from_url("not-a-url") diff --git a/hyperbench/tests/data/loader_test.py b/hyperbench/tests/data/loader_test.py index 9057f62..5e56c11 100644 --- a/hyperbench/tests/data/loader_test.py +++ b/hyperbench/tests/data/loader_test.py @@ -13,7 +13,7 @@ def mock_dataset_single_sample(): x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) hyperedge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]]) hyperedge_attr = torch.tensor([[0.5], [0.7]]) - hyperedge_weights = torch.tensor([[0.8], [0.9]]) + hyperedge_weights = torch.tensor([0.8, 0.9]) hdata = HData( x=x, hyperedge_index=hyperedge_index, @@ -36,7 +36,7 @@ def mock_dataset_single_sample_with_weights(): x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) hyperedge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]]) hyperedge_attr = torch.tensor([[0.5], [0.7]]) - hyperedge_weights = torch.tensor([[0.8], [0.9]]) + hyperedge_weights = torch.tensor([0.8, 0.9]) hdata = HData( x=x, hyperedge_index=hyperedge_index, @@ -122,7 +122,7 @@ def test_collate_single_sample_with_weights(mock_dataset_single_sample_with_weig expected_x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) expected_hyperedge_attr = torch.tensor([[0.5], [0.7]]) expected_hyeredge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]]) - expected_hyperedge_weights = torch.tensor([[0.8], [0.9]]) + expected_hyperedge_weights = torch.tensor([0.8, 0.9]) assert torch.equal(batched.x, expected_x) assert batched.hyperedge_index.shape == (2, 4) @@ -308,8 +308,7 @@ def test_collate_when_dataset_no_hyperedge_attr_presence(): def test_collate_when_dataset_has_no_global_node_ids(): x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) - hdata = HData(x=x, hyperedge_index=hyperedge_index) - hdata.global_node_ids = None + hdata = HData(x=x, hyperedge_index=hyperedge_index, global_node_ids=None) sample0 = HData.from_hyperedge_index(torch.tensor([[0, 1], [0, 0]])) sample1 = HData.from_hyperedge_index(torch.tensor([[2], [1]])) diff --git a/hyperbench/tests/data/negative_sampler_test.py b/hyperbench/tests/data/negative_sampler_test.py index ee777c2..06e68ba 100644 --- a/hyperbench/tests/data/negative_sampler_test.py +++ b/hyperbench/tests/data/negative_sampler_test.py @@ -33,10 +33,32 @@ def mock_hdata_no_attr() -> HData: hyperedge_index=torch.tensor([[0, 1, 2], [0, 0, 1]]), hyperedge_attr=None, num_nodes=3, + num_hyperedges=2, + ) + + +@pytest.fixture +def mock_hdata_with_weights() -> HData: + return HData( + x=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + hyperedge_index=torch.tensor([[0, 1, 2], [0, 1, 2]]), + hyperedge_weights=torch.tensor([0.5, 0.7, 0.9]), + num_nodes=3, num_hyperedges=3, ) +@pytest.fixture +def mock_hdata_no_weights() -> HData: + return HData( + x=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + hyperedge_index=torch.tensor([[0, 1, 2], [0, 0, 1]]), + hyperedge_weights=None, + num_nodes=3, + num_hyperedges=2, + ) + + @pytest.fixture def mock_clique_hdata() -> HData: return HData( @@ -101,7 +123,7 @@ def test_random_negative_sampler_sample_too_many_nodes(mock_hdata_with_attr): sampler.sample(mock_hdata_with_attr) -def test_random_negative_sampler_with_edge_attr(mock_hdata_with_attr): +def test_random_negative_sampler_with_hyperedge_attr(mock_hdata_with_attr): sampler = RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=2) result = sampler.sample(mock_hdata_with_attr) @@ -118,7 +140,7 @@ def test_random_negative_sampler_with_edge_attr(mock_hdata_with_attr): assert result.hyperedge_attr.shape[0] == 2 -def test_random_negative_sampler_sample_no_edge_attr(mock_hdata_no_attr): +def test_random_negative_sampler_sample_no_hyperedge_attr(mock_hdata_no_attr): sampler = RandomNegativeSampler(num_negative_samples=1, num_nodes_per_sample=2) result = sampler.sample(mock_hdata_no_attr) @@ -128,10 +150,43 @@ def test_random_negative_sampler_sample_no_edge_attr(mock_hdata_no_attr): assert ( result.hyperedge_index.shape[1] == 2 ) # 1 negative hyperedge * 2 nodes per negative hyperedge - assert 3 in result.hyperedge_index[1] # New hyperedge ID (3) should be present + assert 2 in result.hyperedge_index[1] # New hyperedge ID (2) should be present assert result.hyperedge_attr is None +def test_random_negative_sampler_with_hyperedge_weights(mock_hdata_with_weights): + sampler = RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=2) + result = sampler.sample(mock_hdata_with_weights) + + assert result.num_hyperedges == 2 + assert result.x.shape[0] <= mock_hdata_with_weights.x.shape[0] + assert result.hyperedge_index.shape[0] == 2 + assert ( + result.hyperedge_index.shape[1] == 4 + ) # 2 negative hyperedges * 2 nodes per negative hyperedge + assert ( + 3 in result.hyperedge_index[1] and 4 in result.hyperedge_index[1] + ) # New hyperedge IDs (3, 4) should be present + + # Weights for new hyperedges are added only if the sampler is configured + # with a hyperedge_weights_enricher, otherwise they default to None + assert result.hyperedge_weights is None + + +def test_random_negative_sampler_sample_no_hyperedge_weights(mock_hdata_no_weights): + sampler = RandomNegativeSampler(num_negative_samples=1, num_nodes_per_sample=2) + result = sampler.sample(mock_hdata_no_weights) + + assert result.num_hyperedges == 1 + assert result.x.shape[0] <= mock_hdata_no_weights.x.shape[0] + assert result.hyperedge_index.shape[0] == 2 + assert ( + result.hyperedge_index.shape[1] == 2 + ) # 1 negative hyperedge * 2 nodes per negative hyperedge + assert 2 in result.hyperedge_index[1] # New hyperedge ID (2) should be present + assert result.hyperedge_weights is None + + def test_random_negative_sampler_sample_with_seed_is_reproducible(mock_hdata_with_attr): sampler = RandomNegativeSampler(num_negative_samples=3, num_nodes_per_sample=2) diff --git a/hyperbench/tests/data/negative_sampling_scheduler_test.py b/hyperbench/tests/data/negative_sampling_scheduler_test.py index ffab739..8c99f54 100644 --- a/hyperbench/tests/data/negative_sampling_scheduler_test.py +++ b/hyperbench/tests/data/negative_sampling_scheduler_test.py @@ -1,6 +1,8 @@ import pytest import torch +import re +from typing import Any, cast from unittest.mock import MagicMock from hyperbench.data import NegativeSampler, NegativeSamplingSchedule, NegativeSamplingScheduler from hyperbench.types import HData @@ -34,7 +36,7 @@ def mock_sampler(mock_negative_hdata): def test_config_returns_scheduler_parameters(mock_sampler): - schedule = NegativeSamplingSchedule.EVERY_N_EPOCHS + schedule: NegativeSamplingSchedule = "every_n_epochs" scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, negative_sampling_schedule=schedule, @@ -50,7 +52,7 @@ def test_config_returns_scheduler_parameters(mock_sampler): def test_sample_caches_result_across_non_sampling_epochs(mock_sampler, mock_batch): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.FIRST_EPOCH, + negative_sampling_schedule="first_epoch", ) # Epoch 0: should sample @@ -65,7 +67,7 @@ def test_sample_caches_result_across_non_sampling_epochs(mock_sampler, mock_batc def test_sample_delegates_to_negative_sampler(mock_sampler, mock_batch, mock_negative_hdata): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.EVERY_EPOCH, + negative_sampling_schedule="every_epoch", ) result = scheduler.sample(mock_batch, epoch=0) @@ -77,7 +79,7 @@ def test_sample_delegates_to_negative_sampler(mock_sampler, mock_batch, mock_neg def test_sample_raises_when_cache_is_empty(mock_sampler, mock_batch): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.EVERY_N_EPOCHS, + negative_sampling_schedule="every_n_epochs", negative_sampling_every_n=5, ) @@ -89,7 +91,7 @@ def test_sample_raises_when_cache_is_empty(mock_sampler, mock_batch): def test_sample_resamples_on_every_n_epoch(mock_sampler, mock_batch): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.EVERY_N_EPOCHS, + negative_sampling_schedule="every_n_epochs", negative_sampling_every_n=3, ) @@ -114,7 +116,7 @@ def test_sample_resamples_on_every_n_epoch(mock_sampler, mock_batch): def test_should_sample_every_epoch(mock_sampler, epoch, expected_should_sample): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.EVERY_EPOCH, + negative_sampling_schedule="every_epoch", ) assert scheduler.should_sample(epoch) == expected_should_sample @@ -137,7 +139,7 @@ def test_should_sample_every_epoch(mock_sampler, epoch, expected_should_sample): def test_should_sample_every_n_epochs(mock_sampler, epoch, every_n, expected_should_sample): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.EVERY_N_EPOCHS, + negative_sampling_schedule="every_n_epochs", negative_sampling_every_n=every_n, ) @@ -156,7 +158,51 @@ def test_should_sample_every_n_epochs(mock_sampler, epoch, every_n, expected_sho def test_should_sample_first_epoch(mock_sampler, epoch, expected_should_sample): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.FIRST_EPOCH, + negative_sampling_schedule="first_epoch", ) assert scheduler.should_sample(epoch) == expected_should_sample + + +def test_should_sample_rejects_invalid_epoch(mock_sampler): + scheduler = NegativeSamplingScheduler(negative_sampler=mock_sampler) + + with pytest.raises(ValueError, match=re.escape("Epoch must be non-negative, got -1.")): + scheduler.should_sample(epoch=-1) + + +def test_should_sample_rejects_unsupported_schedule(mock_sampler): + scheduler = NegativeSamplingScheduler( + negative_sampler=mock_sampler, + negative_sampling_schedule=cast(Any, "sometimes"), + ) + + with pytest.raises( + ValueError, + match=re.escape("Unsupported negative sampling schedule: 'sometimes'."), + ): + scheduler.should_sample(epoch=0) + + +@pytest.mark.parametrize( + "every_n, expected_exception, expected_message", + [ + pytest.param( + 0, ValueError, "negative_sampling_every_n must be positive, got 0.", id="zero" + ), + pytest.param( + -1, ValueError, "negative_sampling_every_n must be positive, got -1.", id="negative" + ), + ], +) +def test_should_sample_rejects_invalid_every_n( + mock_sampler, every_n, expected_exception, expected_message +): + scheduler = NegativeSamplingScheduler( + negative_sampler=mock_sampler, + negative_sampling_schedule="every_n_epochs", + negative_sampling_every_n=cast(Any, every_n), + ) + + with pytest.raises(expected_exception, match=re.escape(expected_message)): + scheduler.should_sample(epoch=0) diff --git a/hyperbench/tests/data/sampler_test.py b/hyperbench/tests/data/sampler_test.py index 3ddc82f..7afd1a0 100644 --- a/hyperbench/tests/data/sampler_test.py +++ b/hyperbench/tests/data/sampler_test.py @@ -2,6 +2,7 @@ import pytest import torch +from typing import Any, cast from hyperbench.data import ( BaseSampler, HyperedgeSampler, @@ -48,6 +49,14 @@ def test_create_sampler_from_strategy_node(): assert isinstance(sampler, NodeSampler) +def test_create_sampler_from_strategy_rejects_unsupported_strategy(): + with pytest.raises( + ValueError, + match=re.escape("Unsupported sampling strategy: 'edge'."), + ): + create_sampler_from_strategy(cast(Any, "edge")) + + def test_hyperedge_sampling_single_index(mock_four_node_two_hyperedge_hdata): sampler = HyperedgeSampler() result = sampler.sample(0, mock_four_node_two_hyperedge_hdata) @@ -123,6 +132,42 @@ def test_sample_empty_index_raises(mock_four_node_two_hyperedge_hdata, sampler): sampler.sample([], mock_four_node_two_hyperedge_hdata) +@pytest.mark.parametrize( + "index", + [ + pytest.param("0", id="string_index"), + pytest.param(True, id="bool_index"), + ], +) +def test_sample_rejects_non_integer_index(mock_four_node_two_hyperedge_hdata, index): + sampler = HyperedgeSampler() + + with pytest.raises( + TypeError, + match=re.escape("Index must be an integer or a list of integers."), + ): + sampler.sample(cast(Any, index), mock_four_node_two_hyperedge_hdata) + + +@pytest.mark.parametrize( + "index", + [ + pytest.param([0, "1"], id="list_with_string"), + pytest.param([0, False], id="list_with_bool"), + ], +) +def test_sample_rejects_index_list_with_non_integer_items( + mock_four_node_two_hyperedge_hdata, index +): + sampler = HyperedgeSampler() + + with pytest.raises( + TypeError, + match=re.escape("Index list must contain only integers."), + ): + sampler.sample(cast(Any, index), mock_four_node_two_hyperedge_hdata) + + @pytest.mark.parametrize( "sampler, label", [ diff --git a/hyperbench/tests/data/splitter_test.py b/hyperbench/tests/data/splitter_test.py index 1310c42..d863365 100644 --- a/hyperbench/tests/data/splitter_test.py +++ b/hyperbench/tests/data/splitter_test.py @@ -2,6 +2,7 @@ import pytest import re +from typing import Any, cast from hyperbench.data import HyperedgeIDSplitter from hyperbench.types import HData @@ -63,6 +64,26 @@ def test_split_uses_cumulative_floor_boundaries_and_last_split_absorbs_remainder assert final_ratios == [0.4, 0.2, 0.4] +@pytest.mark.parametrize( + "ratios, expected_exception, expected_message", + [ + pytest.param([], ValueError, "'ratios' cannot be empty.", id="empty"), + pytest.param([0.5, 0.0, 0.5], ValueError, "'ratios' must be positive, got 0.0.", id="zero"), + pytest.param( + [0.5, float("inf")], ValueError, "'ratios' must be finite, got inf.", id="infinite" + ), + ], +) +def test_split_validates_ratio_values( + mock_hdata_five_hyperedges, ratios, expected_exception, expected_message +): + hyperedge_ids = torch.arange(5) + splitter = HyperedgeIDSplitter(mock_hdata_five_hyperedges) + + with pytest.raises(expected_exception, match=re.escape(expected_message)): + splitter.split(hyperedge_ids, cast(Any, ratios)) + + def test_ensure_split_covers_all_nodes_moves_best_covering_hyperedge_into_first_split(): x = torch.ones((4, 1), dtype=torch.float) hyperedge_index = torch.tensor( @@ -89,6 +110,25 @@ def test_ensure_split_covers_all_nodes_moves_best_covering_hyperedge_into_first_ assert final_ratios == [0.6, 0.4] +def test_ensure_split_covers_all_nodes_rejects_empty_splits(mock_hdata_five_hyperedges): + splitter = HyperedgeIDSplitter(mock_hdata_five_hyperedges) + + with pytest.raises(ValueError, match="'hyperedge_ids_by_split' cannot be empty"): + splitter.ensure_split_covers_all_nodes([]) + + +def test_ensure_split_covers_all_nodes_rejects_invalid_split_idx( + mock_hdata_five_hyperedges, +): + splitter = HyperedgeIDSplitter(mock_hdata_five_hyperedges) + + with pytest.raises(ValueError, match="split_idx must reference an existing split"): + splitter.ensure_split_covers_all_nodes( + [torch.tensor([0], dtype=torch.long)], + split_idx=1, + ) + + def test_ensure_split_covers_all_nodes_raises_when_a_node_is_missing_from_hypergraph(): x = torch.ones((4, 1), dtype=torch.float) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) diff --git a/hyperbench/tests/train/latex_logger_test.py b/hyperbench/tests/train/latex_logger_test.py index 06b9efa..367f3fb 100644 --- a/hyperbench/tests/train/latex_logger_test.py +++ b/hyperbench/tests/train/latex_logger_test.py @@ -1,8 +1,7 @@ import pytest -from textwrap import dedent - -from hyperbench.train.latex_logger import ( +from textwrap import dedent +from hyperbench.train import ( LaTexTableConfig, LaTexTableLogger, colorize_metric_value, @@ -35,6 +34,17 @@ def test_latex_logger_basics(tmp_path, mock_option_configs): assert logger.experiment_name == "exp1" +def test_latex_logger_rejects_negative_precision(tmp_path, mock_option_configs): + with pytest.raises(ValueError, match="'precision' must be non-negative"): + LaTexTableLogger( + save_dir=str(tmp_path), + model_name="model_a", + experiment_name="negative_precision", + precision=-1, + options=mock_option_configs, + ) + + def test_latex_logger_log_hyperparams_is_noop(tmp_path, mock_option_configs): logger = LaTexTableLogger( @@ -122,6 +132,17 @@ def test_save_comparison_tables_no_val_results(tmp_path, mock_option_configs): assert (tmp_path / "comparison" / "test.tex").exists() +def test_colorize_metric_value_rejects_invalid_sort_order(): + with pytest.raises(ValueError, match="'sort_order' must be 'asc' or 'des'"): + colorize_metric_value( + metric="test_auc", + value=0.8, + text="0.8000", + metric_bounds={"test_auc": (0.1, 0.9)}, + sort_order="invalid", + ) + + def test_save_comparison_tables_only_val_results(tmp_path, mock_option_configs): logger = LaTexTableLogger( save_dir=str(tmp_path), diff --git a/hyperbench/tests/train/markdown_logger_test.py b/hyperbench/tests/train/markdown_logger_test.py index d619b9f..f73c71e 100644 --- a/hyperbench/tests/train/markdown_logger_test.py +++ b/hyperbench/tests/train/markdown_logger_test.py @@ -1,7 +1,7 @@ -from textwrap import dedent import pytest -from hyperbench.train.markdown_logger import MarkdownTableLogger +from textwrap import dedent +from hyperbench.train import MarkdownTableLogger def test_markdown_table_logger_basic_functions(tmp_path): @@ -21,6 +21,16 @@ def test_markdown_table_logger_basic_functions(tmp_path): assert logger.experiment_name == experiment_name +def test_markdown_table_logger_rejects_negative_precision(tmp_path): + with pytest.raises(ValueError, match="'precision' must be non-negative"): + MarkdownTableLogger( + save_dir=str(tmp_path), + model_name="model_a", + experiment_name="negative_precision", + precision=-1, + ) + + def test_log_hyperparams_is_noop(tmp_path): experiment_name = "exp_hparams" diff --git a/hyperbench/tests/train/trainer_test.py b/hyperbench/tests/train/trainer_test.py index 6d79627..f7d2bb7 100644 --- a/hyperbench/tests/train/trainer_test.py +++ b/hyperbench/tests/train/trainer_test.py @@ -1,5 +1,5 @@ -import re import pytest +import re from unittest.mock import MagicMock, patch from hyperbench.train import MultiModelTrainer @@ -50,6 +50,11 @@ def test_trainer_initialization( assert config.trainer is not None +def test_trainer_initialization_rejects_invalid_tensorboard_port(mock_model_configs): + with pytest.raises(ValueError, match="'tensorboard_port' must be non-negative"): + MultiModelTrainer(mock_model_configs, tensorboard_port=-1) + + @patch("hyperbench.train.trainer.L.Trainer") @patch("hyperbench.train.trainer.CSVLogger") @patch("hyperbench.train.trainer.MarkdownTableLogger") @@ -70,6 +75,20 @@ def test_trainer_initialization_with_initialized_trainer( assert config.trainer is not None +@patch("hyperbench.train.trainer.L.Trainer") +@patch("hyperbench.train.trainer.CSVLogger") +@patch("hyperbench.train.trainer.MarkdownTableLogger") +@patch("hyperbench.train.trainer.LaTexTableLogger") +def test_trainer_initialization_with_no_models( + mock_latex_logger_cls, + mock_md_logger_cls, + mock_csv_logger_cls, + mock_trainer_cls, +): + with pytest.raises(ValueError, match=re.escape("'model_configs' cannot be empty.")): + MultiModelTrainer(model_configs=[]) + + @patch("hyperbench.train.trainer.L.Trainer") @patch("hyperbench.train.trainer.CSVLogger") @patch("hyperbench.train.trainer.MarkdownTableLogger") @@ -152,14 +171,6 @@ def test_models_property_returns_models( assert len(models) == len(mock_model_configs) -@patch("hyperbench.train.trainer.L.Trainer") -def test_models_property_returns_empty_when_no_models(mock_trainer_cls): - multi_model_trainer = MultiModelTrainer([]) - models = multi_model_trainer.models - - assert len(models) == 0 - - @patch("hyperbench.train.trainer.L.Trainer") @patch("hyperbench.train.trainer.CSVLogger") @patch("hyperbench.train.trainer.MarkdownTableLogger") @@ -273,22 +284,6 @@ def test_fit_all_calls_fit( config.trainer.fit.assert_called_once() -@patch("hyperbench.train.trainer.L.Trainer") -@patch("hyperbench.train.trainer.CSVLogger") -@patch("hyperbench.train.trainer.MarkdownTableLogger") -@patch("hyperbench.train.trainer.LaTexTableLogger") -def test_fit_all_with_no_models( - mock_latex_logger_cls, - mock_md_logger_cls, - mock_csv_logger_cls, - mock_trainer_cls, -): - multi_model_trainer = MultiModelTrainer([]) - - with pytest.raises(ValueError, match=re.escape("No models to fit.")): - multi_model_trainer.fit_all(verbose=False) - - @patch("hyperbench.train.trainer.L.Trainer", return_value=None) @patch("hyperbench.train.trainer.CSVLogger") @patch("hyperbench.train.trainer.MarkdownTableLogger") @@ -424,22 +419,6 @@ def test_test_all_calls_test_and_returns_results( config.trainer.test.assert_called_once() -@patch("hyperbench.train.trainer.L.Trainer") -@patch("hyperbench.train.trainer.CSVLogger") -@patch("hyperbench.train.trainer.MarkdownTableLogger") -@patch("hyperbench.train.trainer.LaTexTableLogger") -def test_test_all_with_no_models( - mock_latex_logger_cls, - mock_md_logger_cls, - mock_csv_logger_cls, - mock_trainer_cls, -): - multi_model_trainer = MultiModelTrainer([]) - - with pytest.raises(ValueError, match=re.escape("No models to test.")): - multi_model_trainer.test_all(verbose=False) - - @patch("hyperbench.train.trainer.L.Trainer", return_value=None) @patch("hyperbench.train.trainer.CSVLogger") @patch("hyperbench.train.trainer.MarkdownTableLogger") diff --git a/hyperbench/tests/types/graph_test.py b/hyperbench/tests/types/graph_test.py index 7d9e9b7..326e38a 100644 --- a/hyperbench/tests/types/graph_test.py +++ b/hyperbench/tests/types/graph_test.py @@ -1,8 +1,8 @@ -import re - import pytest import torch +import re +from collections.abc import Callable from hyperbench.types import EdgeIndex, Graph @@ -606,13 +606,6 @@ def test_add_selfloops_adds_unit_edge_weights(): assert weighted_edges == expected_weighted_edges -def test_add_selfloops_raises_on_empty_edge_index(): - edge_index = EdgeIndex(torch.tensor([[], []], dtype=torch.long)) - - with pytest.raises(ValueError, match="Edge index must have at least one edge"): - edge_index.add_selfloops() - - def test_add_selfloops_does_not_duplicate_selfloops(): edge_index = EdgeIndex(torch.tensor([[0, 1, 1], [1, 2, 1]])) edge_index.add_selfloops() @@ -696,6 +689,71 @@ def test_get_sparse_adjacency_matrix_empty_edge_index(): assert torch.all(dense_adj_matrix == 0) +@pytest.mark.parametrize( + "call_with_num_nodes", + [ + pytest.param( + lambda edge_index, num_nodes: edge_index.add_selfloops(num_nodes=num_nodes), + id="add_selfloops", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.get_sparse_adjacency_matrix( + num_nodes=num_nodes + ), + id="get_sparse_adjacency_matrix", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.get_sparse_identity_matrix( + num_nodes=num_nodes + ), + id="get_sparse_identity_matrix", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.get_sparse_normalized_degree_matrix( + num_nodes=num_nodes + ), + id="get_sparse_normalized_degree_matrix", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.get_sparse_normalized_laplacian( + num_nodes=num_nodes + ), + id="get_sparse_normalized_laplacian", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.get_sparse_normalized_gcn_laplacian( + num_nodes=num_nodes + ), + id="get_sparse_normalized_gcn_laplacian", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.remove_duplicate_edges(num_nodes=num_nodes), + id="remove_duplicate_edges", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.to_undirected(num_nodes=num_nodes), + id="to_undirected", + ), + ], +) +@pytest.mark.parametrize( + ("num_nodes", "expected_message"), + [ + pytest.param(-1, "'num_nodes' must be non-negative", id="negative"), + pytest.param(1, "'num_nodes' is too small for the edge index", id="too_small"), + ], +) +def test_edge_index_methods_reject_invalid_num_nodes( + call_with_num_nodes: Callable[[EdgeIndex, int], object], + num_nodes: int, + expected_message: str, +): + edge_index = EdgeIndex(torch.tensor([[0, 1], [1, 2]], dtype=torch.long)) + + with pytest.raises(ValueError, match=expected_message): + call_with_num_nodes(edge_index, num_nodes) + + @pytest.mark.parametrize( "edge_index, num_nodes, expected_entries", [ diff --git a/hyperbench/tests/types/hdata_test.py b/hyperbench/tests/types/hdata_test.py index 57e2be4..dd8d5ca 100644 --- a/hyperbench/tests/types/hdata_test.py +++ b/hyperbench/tests/types/hdata_test.py @@ -1,10 +1,9 @@ -import re import pytest import torch +import re from unittest.mock import MagicMock -from typing import cast -from torch import Tensor +from typing import Any, cast from hyperbench import utils from hyperbench.data import HyperedgeEnricher, NegativeSampler, NodeEnricher, RandomNegativeSampler from hyperbench.types import HData @@ -91,20 +90,32 @@ def sample(data: HData, seed: int | None = None) -> HData: "explicit_num_nodes, expected_num_nodes", [ pytest.param(None, 7, id="inferred_from_x"), - pytest.param(10, 10, id="explicit_overrides_x"), - pytest.param(0, 0, id="explicit_zero"), + pytest.param(10, 10, id="explicit_allows_isolated_nodes"), ], ) def test_init_num_nodes(explicit_num_nodes, expected_num_nodes): hyperedge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6], [0, 0, 0, 0, 0, 0, 0]]) - num_nodes = hyperedge_index[0].size(0) - x = torch.randn(num_nodes, 3) + x = torch.randn(expected_num_nodes, 3) data = HData(x=x, hyperedge_index=hyperedge_index, num_nodes=explicit_num_nodes) assert data.num_nodes == expected_num_nodes +def test_init_raises_when_num_nodes_is_too_small_for_hyperedge_index(): + x = torch.randn(2, 2) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 0]]) + + with pytest.raises( + ValueError, + match=re.escape( + "num_nodes is too small for hyperedge_index. " + "Got num_nodes=2, but hyperedge_index contains 3 unique node IDs." + ), + ): + HData(x=x, hyperedge_index=hyperedge_index, num_nodes=2) + + @pytest.mark.parametrize( "hyperedge_index, explicit_num_hyperedges, expected_num_hyperedges", [ @@ -135,6 +146,19 @@ def test_init_num_hyperedges(hyperedge_index, explicit_num_hyperedges, expected_ assert data.num_hyperedges == expected_num_hyperedges +def test_init_raises_when_num_hyperedges_is_too_small_for_hyperedge_index(): + x = torch.randn(2, 2) + + with pytest.raises( + ValueError, + match=re.escape( + "num_hyperedges is too small for hyperedge_index. " + "Got num_hyperedges=2, but hyperedge_index contains 3 unique hyperedge IDs." + ), + ): + HData(x=x, hyperedge_index=torch.tensor([[0, 1, 0], [0, 1, 2]]), num_hyperedges=2) + + def test_init_default_y_is_ones(): x = torch.randn(3, 2) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) @@ -170,6 +194,167 @@ def test_init_hyperedge_attr_defaults_to_none(): assert data.hyperedge_attr is None +@pytest.mark.parametrize( + "kwargs, expected_message", + [ + pytest.param( + {"x": torch.randn(3), "hyperedge_index": torch.tensor([[0, 1], [0, 0]])}, + "x must be a 2D tensor, got shape (3,).", + id="x_not_2d", + ), + pytest.param( + {"x": torch.randn(3, 2), "hyperedge_index": torch.tensor([0, 1])}, + "hyperedge_index must have shape (2, num_incidences), got (2,).", + id="hyperedge_index_not_2d", + ), + pytest.param( + {"x": torch.randn(3, 2), "hyperedge_index": torch.tensor([[0, 1, 2]])}, + "hyperedge_index must have shape (2, num_incidences), got (1, 3).", + id="hyperedge_index_wrong_rows", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0.0, 1.0], [0.0, 0.0]]), + }, + "hyperedge_index must have dtype torch.long, got torch.float32.", + id="hyperedge_index_not_long", + ), + pytest.param( + {"x": torch.randn(3, 2), "hyperedge_index": torch.tensor([[-1, 1], [0, 0]])}, + "hyperedge_index cannot contain negative node or hyperedge IDs.", + id="hyperedge_index_negative_id", + ), + pytest.param( + { + "x": torch.randn(2, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 0]]), + }, + ( + "x must have one feature row per node, or be 'torch.empty((0, 0))' if there are no " + "nodes. Got x.shape=(2, 2) but num_nodes=3." + ), + id="x_rows_do_not_match_num_nodes", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 0]]), + "global_node_ids": torch.tensor([[0, 1, 2]]), + }, + "global_node_ids must be a 1D tensor, got shape (1, 3).", + id="global_node_ids_not_1d", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 0]]), + "global_node_ids": torch.tensor([0.0, 1.0, 2.0]), + }, + "global_node_ids must have dtype torch.long, got torch.float32.", + id="global_node_ids_not_long", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 0]]), + "global_node_ids": torch.tensor([0, 1]), + }, + "global_node_ids must have one entry per node. Got size=2 but num_nodes=3.", + id="global_node_ids_wrong_length", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "y": torch.tensor([[1.0, 0.0]]), + }, + "y must be a 1D tensor, got shape (1, 2).", + id="y_not_1d", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "y": torch.tensor([1.0]), + }, + "y must have one entry per hyperedge. Got 1 entries but num_hyperedges=2.", + id="y_wrong_length", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "hyperedge_weights": torch.tensor([[0.25, 0.75]]), + }, + "hyperedge_weights must be a 1D tensor, got shape (1, 2).", + id="hyperedge_weights_not_1d", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "hyperedge_weights": torch.tensor([0.25]), + }, + ( + "hyperedge_weights must have one entry per hyperedge. " + "Got size=1 but num_hyperedges=2." + ), + id="hyperedge_weights_wrong_length", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "hyperedge_attr": torch.tensor([1.0, 2.0]), + }, + "hyperedge_attr must be a 2D tensor, got shape (2,).", + id="hyperedge_attr_not_2d", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "hyperedge_attr": torch.randn(1, 4), + }, + "hyperedge_attr must have one row per hyperedge. Got size=1 but num_hyperedges=2.", + id="hyperedge_attr_wrong_rows", + ), + ], +) +def test_init_validates_input_values(kwargs, expected_message): + with pytest.raises(ValueError, match=re.escape(expected_message)): + HData(**kwargs) + + +@pytest.mark.parametrize( + "kwargs, expected_message", + [ + pytest.param( + { + "x": torch.empty((0, 1)), + "hyperedge_index": torch.empty((2, 0), dtype=torch.long), + "num_nodes": -1, + }, + "'num_nodes' must be non-negative, got -1.", + id="negative_num_nodes", + ), + pytest.param( + { + "x": torch.empty((0, 1)), + "hyperedge_index": torch.empty((2, 0), dtype=torch.long), + "num_hyperedges": -1, + }, + "'num_hyperedges' must be non-negative, got -1.", + id="negative_num_hyperedges", + ), + ], +) +def test_init_validates_non_negative_number_of_nodes_and_hyperedges(kwargs, expected_message): + with pytest.raises(ValueError, match=re.escape(expected_message)): + HData(**kwargs) + + def test_repr_contains_class_name_and_fields(mock_hdata): r = repr(mock_hdata) @@ -192,17 +377,23 @@ def test_empty_returns_empty_hdata(): data = HData.empty() assert data.x is not None - assert isinstance(data.x, Tensor) assert data.x.shape == (0, 0) assert data.hyperedge_index is not None - assert isinstance(data.hyperedge_index, Tensor) assert data.hyperedge_index.shape == (2, 0) assert data.hyperedge_attr is None + assert data.hyperedge_weights is None + assert data.num_nodes == 0 assert data.num_hyperedges == 0 + assert data.global_node_ids is not None + assert data.global_node_ids.shape == (0,) + + assert data.y is not None + assert data.y.shape == (0,) + @pytest.mark.parametrize( "hyperedge_index, expected_num_nodes, expected_num_hyperedges", @@ -376,7 +567,7 @@ def test_hdata_to_mps_handles_none_hyperedge_attr(mock_hdata): def test_cat_same_node_space_raises_on_empty_list(): - with pytest.raises(ValueError, match=re.escape("At least one instance is required.")): + with pytest.raises(ValueError, match=re.escape("'hdatas' cannot be empty.")): HData.cat_same_node_space([]) @@ -441,25 +632,87 @@ def test_cat_same_node_space_concatenates_labels(): def test_cat_same_node_space_uses_largest_x_when_not_provided(): x_large = torch.randn(3, 1) x_small = torch.randn(2, 1) - hdata1 = HData(x=x_large, hyperedge_index=torch.tensor([[0, 1, 2], [0, 0, 0]])) + global_node_ids = torch.tensor([10, 20, 30]) + hdata1 = HData( + x=x_large, + hyperedge_index=torch.tensor([[0, 1, 2], [0, 0, 0]]), + global_node_ids=global_node_ids, + ) hdata2 = HData(x=x_small, hyperedge_index=torch.tensor([[0, 2], [1, 1]])) + expected_hyperedge_index = torch.tensor([[0, 1, 2, 0, 2], [0, 0, 0, 1, 1]]) result = HData.cat_same_node_space([hdata1, hdata2]) assert torch.equal(result.x, x_large) + assert result.global_node_ids is not None + assert torch.equal(result.global_node_ids, global_node_ids) assert torch.equal(result.hyperedge_index, expected_hyperedge_index) -def test_cat_same_node_space_uses_provided_x(): +def test_cat_same_node_space_uses_provided_x_and_global_node_ids(): x = torch.randn(2, 4) hdata1 = HData(x=x, hyperedge_index=torch.tensor([[0, 1], [0, 0]])) hdata2 = HData(x=x, hyperedge_index=torch.tensor([[2, 3], [1, 1]])) custom_x = torch.randn(4, 4) - result = HData.cat_same_node_space([hdata1, hdata2], x=custom_x) + custom_global_node_ids = torch.tensor([10, 20, 30, 40]) + + result = HData.cat_same_node_space( + hdatas=[hdata1, hdata2], + x=custom_x, + global_node_ids=custom_global_node_ids, + ) assert torch.equal(result.x, custom_x) + assert result.global_node_ids is not None + assert torch.equal(result.global_node_ids, custom_global_node_ids) + + +def test_cat_same_node_space_raises_when_only_x_is_provided(): + x = torch.randn(2, 4) + hdata1 = HData(x=x, hyperedge_index=torch.tensor([[0, 1], [0, 0]])) + hdata2 = HData(x=x, hyperedge_index=torch.tensor([[2, 3], [1, 1]])) + + with pytest.raises( + ValueError, + match=re.escape( + "If x is provided, global_node_ids must also be provided to ensure consistency." + ), + ): + HData.cat_same_node_space([hdata1, hdata2], x=torch.randn(4, 4)) + + +def test_cat_same_node_space_raises_when_only_global_node_ids_are_provided(): + x = torch.randn(2, 4) + hdata1 = HData(x=x, hyperedge_index=torch.tensor([[0, 1], [0, 0]])) + hdata2 = HData(x=x, hyperedge_index=torch.tensor([[2, 3], [1, 1]])) + + with pytest.raises( + ValueError, + match=re.escape( + "If global_node_ids is provided, x must also be provided to ensure consistency." + ), + ): + HData.cat_same_node_space([hdata1, hdata2], global_node_ids=torch.arange(4)) + + +def test_cat_same_node_space_validates_global_node_ids_alignment(): + x = torch.randn(2, 4) + hdata1 = HData(x=x, hyperedge_index=torch.tensor([[0, 1], [0, 0]])) + hdata2 = HData(x=x, hyperedge_index=torch.tensor([[2, 3], [1, 1]])) + + with pytest.raises( + ValueError, + match=re.escape( + "global_node_ids must have one entry per node. Got size=3 but num_nodes=4." + ), + ): + HData.cat_same_node_space( + [hdata1, hdata2], + x=torch.randn(4, 4), + global_node_ids=torch.arange(3), + ) def test_cat_same_node_space_concatenates_hyperedge_attr(): @@ -531,15 +784,20 @@ def test_cat_same_node_space_does_not_share_mutable_storage_with_inputs( __assert_mutating_result_keeps_source_tensors_unchanged(other_hdata, result) -def test_cat_same_node_space_clones_provided_x(hdata_with_all_mutable_tensors): - hdata = hdata_with_all_mutable_tensors - custom_x = torch.full_like(hdata.x, 9.0) +def test_cat_same_node_space_clones_provided_x_and_global_node_ids(hdata_with_all_mutable_tensors): + custom_x = torch.full_like(hdata_with_all_mutable_tensors.x, 9.0) + custom_global_node_ids = torch.arange(hdata_with_all_mutable_tensors.num_nodes) + 100 + original_custom_x = custom_x.clone() + original_custom_global_node_ids = custom_global_node_ids.clone() - result = HData.cat_same_node_space([hdata], x=custom_x) + result = HData.cat_same_node_space( + hdatas=[hdata_with_all_mutable_tensors], x=custom_x, global_node_ids=custom_global_node_ids + ) result.x.flatten()[0].add_(1) assert torch.equal(custom_x, original_custom_x) + assert torch.equal(custom_global_node_ids, original_custom_global_node_ids) def test_add_negative_samples_combines_positive_and_negative_hyperedges(mock_negative_sampler): @@ -682,6 +940,25 @@ def test_split_transductive_counts( assert torch.equal(result.hyperedge_index, expected_hyperedge_index) +def test_split_raises_on_invalid_node_space_setting(): + hdata = HData( + x=torch.randn(2, 1), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + ) + + with pytest.raises( + ValueError, + match=re.escape( + "node_space_setting must be one of 'transductive' or 'inductive', got 'semi'." + ), + ): + HData.split( + hdata, + split_hyperedge_ids=torch.tensor([0]), + node_space_setting=cast(Any, "semi"), + ) + + def test_split_inductive_subsets_node_features(): x = torch.tensor([[10.0], [20.0], [30.0], [40.0], [50.0]]) hyperedge_index = torch.tensor([[0, 1, 3, 4], [0, 0, 1, 1]]) @@ -723,7 +1000,7 @@ def test_split_subsets_labels(): pytest.param( "inductive", torch.tensor([1]), - torch.arange(2), + torch.tensor([2, 3]), id="inductive", ), ], @@ -733,8 +1010,7 @@ def test_split_handles_none_global_node_ids( ): x = torch.tensor([[10.0], [20.0], [30.0], [40.0]]) hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) - hdata = HData(x=x, hyperedge_index=hyperedge_index) - hdata.global_node_ids = None + hdata = HData(x=x, hyperedge_index=hyperedge_index, global_node_ids=None) result = HData.split( hdata, @@ -774,8 +1050,12 @@ def test_split_transductive_keeps_full_x_and_global_node_ids(): def test_split_transductive_handles_none_global_node_ids(): x = torch.tensor([[10.0], [20.0], [30.0], [40.0], [50.0]]) hyperedge_index = torch.tensor([[0, 2, 3, 4], [0, 0, 1, 1]]) - hdata = HData(x=x, hyperedge_index=hyperedge_index, y=torch.tensor([1.0, 0.0])) - hdata.global_node_ids = None + hdata = HData( + x=x, + hyperedge_index=hyperedge_index, + y=torch.tensor([1.0, 0.0]), + global_node_ids=None, + ) result = HData.split( hdata, @@ -911,12 +1191,34 @@ def test_enrich_node_features_concatenate(mock_hdata): assert result.x.shape == (5, 7) # 4 original + 3 enriched +@pytest.mark.parametrize( + "enrich_method", + [ + pytest.param("enrich_node_features", id="node_features"), + pytest.param("enrich_hyperedge_weights", id="hyperedge_weights"), + pytest.param("enrich_hyperedge_attr", id="hyperedge_attr"), + ], +) +def test_enrich_rejects_invalid_enrichment_mode(mock_hdata, enrich_method): + enricher_spec = NodeEnricher if enrich_method == "enrich_node_features" else HyperedgeEnricher + enricher = MagicMock(spec=enricher_spec) + + with pytest.raises( + ValueError, + match=re.escape( + "enrichment_mode must be one of 'replace', 'concatenate', or None, got 'append'." + ), + ): + getattr(mock_hdata, enrich_method)(enricher, enrichment_mode=cast(Any, "append")) + + enricher.enrich.assert_not_called() + + @pytest.mark.parametrize( "enrichment_mode", [ pytest.param("replace", id="replace"), pytest.param("concatenate", id="concatenate"), - pytest.param(None, id="none_enrichment_mode_defaults_to_replace"), ], ) def test_enrich_node_features_replace_preserves_global_node_ids(mock_hdata, enrichment_mode): @@ -958,37 +1260,6 @@ def test_enrich_node_features_from_aligns_by_global_node_ids(): assert torch.equal(result.y, target_hdata.y) -@pytest.mark.parametrize( - "missing_side", - [ - pytest.param("source", id="source_missing_global_node_ids"), - pytest.param("target", id="target_missing_global_node_ids"), - ], -) -def test_enrich_node_features_from_raises_without_global_node_ids(missing_side): - source_hdata = HData( - x=torch.tensor([[1.0], [2.0]]), - hyperedge_index=torch.tensor([[0, 1], [0, 0]]), - global_node_ids=torch.tensor([10, 20]), - ) - target_hdata = HData( - x=torch.tensor([[0.0]]), - hyperedge_index=torch.tensor([[0], [0]]), - global_node_ids=torch.tensor([10]), - ) - - if missing_side == "source": - source_hdata.global_node_ids = None - else: - target_hdata.global_node_ids = None - - with pytest.raises( - ValueError, - match=re.escape("Both HData instances must define global_node_ids to align node features."), - ): - target_hdata.enrich_node_features_from(source_hdata) - - def test_enrich_node_features_from_raises_when_source_rows_do_not_match_global_node_ids(): source_hdata = HData( x=torch.empty((0, 0)), @@ -1177,6 +1448,30 @@ def test_enrich_methods_do_not_share_mutable_storage_with_source(hdata_with_all_ ) +def test_enrich_node_features_from_raises_on_invalid_node_space_setting(): + source_hdata = HData( + x=torch.tensor([[1.0]]), + hyperedge_index=torch.tensor([[0], [0]]), + global_node_ids=torch.tensor([10]), + ) + target_hdata = HData( + x=torch.tensor([[0.0]]), + hyperedge_index=torch.tensor([[0], [0]]), + global_node_ids=torch.tensor([10]), + ) + + with pytest.raises( + ValueError, + match=re.escape( + "node_space_setting must be one of 'transductive' or 'inductive', got 'semi'." + ), + ): + target_hdata.enrich_node_features_from( + source_hdata, + node_space_setting=cast(Any, "semi"), + ) + + def test_enrich_hyperedge_weights_replace(): x = torch.tensor([[1.0], [2.0], [3.0]]) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) @@ -1223,6 +1518,29 @@ def test_enrich_hyperedge_weights_concatenate(): assert torch.equal(utils.to_non_empty_edgeattr(result.hyperedge_weights), enriched_weights) +def test_enrich_hyperedge_weights_concatenate_after_hyperedge_index_expansion(): + x = torch.tensor([[1.0], [2.0], [3.0]]) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) + hdata = HData( + x=x, + hyperedge_index=hyperedge_index, + hyperedge_weights=torch.tensor([0.1, 0.2]), + ) + hdata.hyperedge_index = torch.tensor([[0, 1, 2, 0], [0, 0, 1, 2]]) + hdata.num_hyperedges = 3 + hdata.y = torch.ones(3, dtype=torch.float) + + enricher = MagicMock(spec=HyperedgeEnricher) + enricher.enrich.return_value = torch.tensor([0.7]) + + result = hdata.enrich_hyperedge_weights(enricher, enrichment_mode="concatenate") + + enricher.enrich.assert_called_once_with(hdata.hyperedge_index) + assert torch.equal( + utils.to_non_empty_edgeattr(result.hyperedge_weights), torch.tensor([0.1, 0.2, 0.7]) + ) + + def test_enrich_hyperedge_attr_replace(): x = torch.tensor([[1.0], [2.0], [3.0]]) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) @@ -1306,8 +1624,12 @@ def test_get_device_if_all_consistent_handles_none_global_node_ids(): x = torch.randn(3, 2) hyperedge_index = torch.tensor([[0, 1], [0, 0]]) hyperedge_attr = torch.randn(1, 4) - hdata = HData(x=x, hyperedge_index=hyperedge_index, hyperedge_attr=hyperedge_attr) - hdata.global_node_ids = None + hdata = HData( + x=x, + hyperedge_index=hyperedge_index, + hyperedge_attr=hyperedge_attr, + global_node_ids=None, + ) assert hdata.get_device_if_all_consistent() == torch.device("cpu") @@ -1643,16 +1965,24 @@ def test_remove_hyperedges_with_fewer_than_k_nodes_keeps_none_hyperedge_attr(): assert result.hyperedge_attr is None +def test_remove_hyperedges_with_fewer_than_k_nodes_rejects_invalid_k(): + x = torch.randn(2, 1) + hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) + hdata = HData(x=x, hyperedge_index=hyperedge_index) + + with pytest.raises(ValueError, match="'k' must be positive"): + hdata.remove_hyperedges_with_fewer_than_k_nodes(k=0) + + def test_remove_hyperedges_with_fewer_than_k_nodes_handles_none_global_node_ids(): x = torch.randn(5, 2) hyperedge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 1, 1]]) - hdata = HData(x=x, hyperedge_index=hyperedge_index) - hdata.global_node_ids = None + hdata = HData(x=x, hyperedge_index=hyperedge_index, global_node_ids=None) result = hdata.remove_hyperedges_with_fewer_than_k_nodes(k=3) assert result.global_node_ids is not None - assert torch.equal(result.global_node_ids, torch.arange(result.num_nodes)) + assert torch.equal(result.global_node_ids, torch.tensor([2, 3, 4])) def test_remove_hyperedges_with_fewer_than_k_nodes_subsets_hyperedge_weights(): diff --git a/hyperbench/tests/types/hypergraph_test.py b/hyperbench/tests/types/hypergraph_test.py index 77a5b18..63cc25e 100644 --- a/hyperbench/tests/types/hypergraph_test.py +++ b/hyperbench/tests/types/hypergraph_test.py @@ -1,9 +1,9 @@ -import json -import re - import pytest import torch +import json +import re +from typing import Any, cast from unittest.mock import patch from hyperbench.types import HIFHypergraph, Hypergraph, HyperedgeIndex from hyperbench.tests import MOCK_BASE_PATH @@ -503,6 +503,12 @@ def test_hifhypergraph_stats_returns_correct_statistics(): [2, 3], id="second_of_two_hyperedges", ), + pytest.param( + torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]), + 10, + [], + id="non_existent_hyperedge_id", + ), ], ) def test_hyperedge_index_nodes_in(hyperedge_index_tensor, hyperedge_id, expected_nodes): @@ -510,8 +516,10 @@ def test_hyperedge_index_nodes_in(hyperedge_index_tensor, hyperedge_id, expected assert hyperedge_index.nodes_in(hyperedge_id) == expected_nodes -def _edge_index_to_edge_set(edge_index: torch.Tensor) -> set[tuple[int, int]]: - return set(zip(edge_index[0].tolist(), edge_index[1].tolist(), strict=True)) +def test_hyperedge_index_nodes_in_raises_when_hyperedge_id_is_negative(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1, 2], [0, 0, 0]], dtype=torch.long)) + with pytest.raises(ValueError, match=re.escape("'hyperedge_id' must be non-negative, got -1")): + hyperedge_index.nodes_in(-1) @pytest.mark.parametrize( @@ -559,7 +567,9 @@ def test_reduce_with_clique_expansion_returns_expected_edges( assert result.dtype == torch.long assert result.shape[0] == 2 - assert _edge_index_to_edge_set(result) == expected_edges + + edges = set(zip(result[0].tolist(), result[1].tolist(), strict=True)) + assert edges == expected_edges def test_reduce_with_clique_expansion_matches_specialized_reducer(): @@ -572,6 +582,16 @@ def test_reduce_with_clique_expansion_matches_specialized_reducer(): assert torch.equal(result, expected) +def test_reduce_rejects_unsupported_strategy(): + hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) + + with pytest.raises( + ValueError, + match=re.escape("Unsupported reduction strategy: fancy_expansion. "), + ): + HyperedgeIndex(hyperedge_index).reduce(cast(Any, "fancy_expansion")) + + @pytest.mark.parametrize( "hyperedge_index_tensor, expected_num_edges", [ @@ -658,6 +678,13 @@ def test_get_clique_expansion_adjacency(hyperedge_index_tensor, num_nodes, expec assert result == expected_adjacency +def test_get_clique_expansion_adjacency_rejects_num_nodes_too_small(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 2], [0, 0]], dtype=torch.long)) + + with pytest.raises(ValueError, match="'num_nodes' is too small for the hyperedge index"): + hyperedge_index.get_clique_expansion_adjacency_list(num_nodes=2) + + @pytest.mark.parametrize( "x, hyperedge_index, with_mediators, expected_num_edges", [ @@ -1026,6 +1053,13 @@ def test_remove_hyperedges_with_fewer_than_k_nodes_returns_self(): assert result is hyperedge_index +def test_remove_hyperedges_with_fewer_than_k_nodes_rejects_invalid_k(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1], [0, 0]], dtype=torch.long)) + + with pytest.raises(ValueError, match="'k' must be positive"): + hyperedge_index.remove_hyperedges_with_fewer_than_k_nodes(0) + + @pytest.mark.parametrize( "dropout", [ @@ -1283,18 +1317,42 @@ def test_get_sparse_incidence_matrix_sums_duplicate_incidences_when_coalesced(): assert torch.allclose(incidence_matrix.to_dense(), expected_incidence_matrix, atol=1e-6) -def test_get_sparse_incidence_matrix_rejects_explicit_num_nodes_too_small(): - hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1, 2], [0, 0, 0]], dtype=torch.long)) +@pytest.mark.parametrize( + ("kwargs", "expected_message"), + [ + pytest.param({"num_nodes": 3}, "'num_nodes' is too small", id="nodes"), + pytest.param( + {"num_hyperedges": 1}, + "'num_hyperedges' is too small", + id="hyperedges", + ), + ], +) +def test_get_sparse_incidence_matrix_rejects_explicit_dimensions_too_small( + kwargs, expected_message +): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], dtype=torch.long)) - with pytest.raises(ValueError, match="num_nodes is too small"): - hyperedge_index.get_sparse_incidence_matrix(num_nodes=2) + with pytest.raises(ValueError, match=expected_message): + hyperedge_index.get_sparse_incidence_matrix(**kwargs) -def test_get_sparse_incidence_matrix_rejects_explicit_num_hyperedges_too_small(): - hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], dtype=torch.long)) +@pytest.mark.parametrize( + ("kwargs", "expected_message"), + [ + pytest.param({"num_nodes": -1}, "'num_nodes' must be non-negative", id="nodes"), + pytest.param( + {"num_hyperedges": -1}, + "'num_hyperedges' must be non-negative", + id="hyperedges", + ), + ], +) +def test_get_sparse_incidence_matrix_rejects_negative_dimensions(kwargs, expected_message): + hyperedge_index = HyperedgeIndex(torch.zeros((2, 0), dtype=torch.long)) - with pytest.raises(ValueError, match="num_hyperedges is too small"): - hyperedge_index.get_sparse_incidence_matrix(num_hyperedges=1) + with pytest.raises(ValueError, match=expected_message): + hyperedge_index.get_sparse_incidence_matrix(**kwargs) def test_get_sparse_symnormalized_node_degree_matrix_is_expected_diagonal(): @@ -1373,6 +1431,30 @@ def test_get_sparse_normalized_node_degree_matrix_zeroes_isolated_nodes_for_nega assert torch.allclose(node_degree_matrix.to_dense(), torch.diag(expected_diagonal), atol=1e-6) +def test_get_sparse_normalized_node_degree_matrix_rejects_mismatched_num_nodes(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1], [0, 0]], dtype=torch.long)) + incidence_matrix = hyperedge_index.get_sparse_incidence_matrix(num_nodes=3) + + with pytest.raises(ValueError, match="'num_nodes' must match the incidence matrix dimension"): + hyperedge_index.get_sparse_normalized_node_degree_matrix( + incidence_matrix, + power=-1, + num_nodes=2, + ) + + +def test_get_sparse_normalized_node_degree_matrix_rejects_negative_num_nodes(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1], [0, 0]], dtype=torch.long)) + incidence_matrix = hyperedge_index.get_sparse_incidence_matrix() + + with pytest.raises(ValueError, match="'num_nodes' must be non-negative"): + hyperedge_index.get_sparse_normalized_node_degree_matrix( + incidence_matrix, + power=-1, + num_nodes=-1, + ) + + def test_get_sparse_rownormalized_node_degree_matrix_is_expected_diagonal(): hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1, 0, 2], [0, 0, 1, 1]])) incidence_matrix = hyperedge_index.get_sparse_incidence_matrix() # shape (3,2) @@ -1427,3 +1509,28 @@ def test_get_sparse_normalized_hyperedge_degree_matrix_infers_num_hyperedges(): ) assert hyperedge_degree_matrix.shape == (2, 2) + + +def test_get_sparse_normalized_hyperedge_degree_matrix_rejects_mismatched_num_hyperedges(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1], [0, 0]], dtype=torch.long)) + incidence_matrix = hyperedge_index.get_sparse_incidence_matrix(num_hyperedges=2) + + with pytest.raises( + ValueError, + match="'num_hyperedges' must match the incidence matrix dimension", + ): + hyperedge_index.get_sparse_normalized_hyperedge_degree_matrix( + incidence_matrix, + num_hyperedges=1, + ) + + +def test_get_sparse_normalized_hyperedge_degree_matrix_rejects_negative_num_hyperedges(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1], [0, 0]], dtype=torch.long)) + incidence_matrix = hyperedge_index.get_sparse_incidence_matrix() + + with pytest.raises(ValueError, match="'num_hyperedges' must be non-negative"): + hyperedge_index.get_sparse_normalized_hyperedge_degree_matrix( + incidence_matrix, + num_hyperedges=-1, + ) diff --git a/hyperbench/tests/utils/data_utils_test.py b/hyperbench/tests/utils/data_utils_test.py index 9422614..7e5c508 100644 --- a/hyperbench/tests/utils/data_utils_test.py +++ b/hyperbench/tests/utils/data_utils_test.py @@ -37,35 +37,39 @@ def test_clone_optional_tensor_with_tensor_does_not_share_storage(): assert tensor[0, 0] == 1.0 -def test_empty_edgeindex(): +def test_empty_hyperedgeindex(): result = empty_hyperedgeindex() assert result.shape == (2, 0) - assert result.dtype == torch.float32 + assert result.dtype == torch.long def test_empty_edgeattr_zero_edges(): result = empty_edgeattr(num_edges=0) assert result.shape == (0, 0) + assert result.dtype == torch.float def test_empty_edgeattr_single_edge(): result = empty_edgeattr(num_edges=1) assert result.shape == (1, 0) + assert result.dtype == torch.float def test_empty_edgeattr_with_edges(): result = empty_edgeattr(num_edges=5) assert result.shape == (5, 0) + assert result.dtype == torch.float def test_to_non_empty_edgeattr_with_none(): result = to_non_empty_edgeattr(edge_attr=None) assert result.shape == (0, 0) + assert result.dtype == torch.float def test_to_non_empty_edgeattr_with_tensor(): @@ -74,6 +78,7 @@ def test_to_non_empty_edgeattr_with_tensor(): assert torch.equal(result, edge_attr) assert result.shape == (3, 1) + assert result.dtype == torch.float def test_to_non_empty_edgeattr_with_empty_tensor(): @@ -82,6 +87,7 @@ def test_to_non_empty_edgeattr_with_empty_tensor(): assert torch.equal(result, edge_attr) assert result.shape == (0, 3) + assert result.dtype == torch.float def test_to_non_empty_edgeattr_with_multi_dimensional(): @@ -90,12 +96,14 @@ def test_to_non_empty_edgeattr_with_multi_dimensional(): assert torch.equal(result, edge_attr) assert result.shape == (2, 3) + assert result.dtype == torch.float def test_empty_nodefeatures(): result = empty_nodefeatures() assert result.shape == (0, 0) + assert result.dtype == torch.float @pytest.mark.parametrize( diff --git a/hyperbench/tests/utils/node_utils_test.py b/hyperbench/tests/utils/node_utils_test.py index 76b03ad..b3cee37 100644 --- a/hyperbench/tests/utils/node_utils_test.py +++ b/hyperbench/tests/utils/node_utils_test.py @@ -1,10 +1,13 @@ import pytest import torch +import re +from typing import Any, cast from hyperbench.utils import ( assign_hyperedge_label_to_nodes, is_inductive_setting, is_transductive_setting, + validate_node_space_setting, ) @@ -89,3 +92,32 @@ def test_is_inductive_setting(node_space_setting, expected): ) def test_is_transductive_setting(node_space_setting, expected): assert is_transductive_setting(node_space_setting) == expected + + +@pytest.mark.parametrize( + "node_space_setting", + [ + pytest.param("inductive", id="inductive"), + pytest.param("transductive", id="transductive"), + ], +) +def test_validate_node_space_setting_accepts_supported_values(node_space_setting): + validate_node_space_setting(node_space_setting) + + +@pytest.mark.parametrize( + "node_space_setting", + [ + pytest.param("semi", id="unsupported_string"), + pytest.param(None, id="none"), + ], +) +def test_validate_node_space_setting_rejects_unsupported_values(node_space_setting): + with pytest.raises( + ValueError, + match=re.escape( + "node_space_setting must be one of 'transductive' or 'inductive', " + f"got {node_space_setting!r}." + ), + ): + validate_node_space_setting(cast(Any, node_space_setting)) diff --git a/hyperbench/train/__init__.py b/hyperbench/train/__init__.py index 390b13d..30f3a97 100644 --- a/hyperbench/train/__init__.py +++ b/hyperbench/train/__init__.py @@ -1,6 +1,6 @@ import logging -from .latex_logger import LaTexTableLogger +from .latex_logger import LaTexTableConfig, LaTexTableLogger, colorize_metric_value from .markdown_logger import MarkdownTableLogger @@ -9,7 +9,9 @@ logging.getLogger("lightning.pytorch").setLevel(logging.ERROR) __all__ = [ + "LaTexTableConfig", "LaTexTableLogger", "MarkdownTableLogger", "MultiModelTrainer", + "colorize_metric_value", ] diff --git a/hyperbench/train/latex_logger.py b/hyperbench/train/latex_logger.py index b2999aa..13582d5 100644 --- a/hyperbench/train/latex_logger.py +++ b/hyperbench/train/latex_logger.py @@ -2,6 +2,7 @@ from typing import Any, ClassVar, TypedDict from collections.abc import Mapping from typing_extensions import NotRequired +from hyperbench.utils import validate_is_non_negative from lightning.pytorch.loggers import Logger @@ -35,6 +36,10 @@ def colorize_metric_value( if bounds is None: return text + normalized_sort_order = sort_order.lower() + if normalized_sort_order not in ("asc", "des"): + raise ValueError(f"'sort_order' must be 'asc' or 'des', got {sort_order!r}.") + min_metric_value, max_metric_value = bounds if max_metric_value == min_metric_value: quality = 1.0 @@ -44,7 +49,7 @@ def colorize_metric_value( ) # 0..1, low->high quality = ( (1.0 - normalized_metric_value) - if sort_order.lower() == "asc" + if normalized_sort_order == "asc" else normalized_metric_value ) @@ -107,6 +112,8 @@ def __init__( options: LaTexTableConfig | None = None, ) -> None: super().__init__() + validate_is_non_negative("precision", precision) + self.__save_dir = save_dir self.__model_name = model_name self.__experiment_name = experiment_name diff --git a/hyperbench/train/markdown_logger.py b/hyperbench/train/markdown_logger.py index ed69946..3e5e056 100644 --- a/hyperbench/train/markdown_logger.py +++ b/hyperbench/train/markdown_logger.py @@ -4,6 +4,7 @@ from typing import Any, ClassVar from lightning.pytorch.loggers import Logger from collections.abc import Mapping +from hyperbench.utils import validate_is_non_negative class MarkdownTableLogger(Logger): @@ -35,6 +36,8 @@ def __init__( precision: int = 4, ) -> None: super().__init__() + validate_is_non_negative("precision", precision) + self.__save_dir = save_dir self.__model_name = model_name self.__experiment_name = experiment_name @@ -161,7 +164,8 @@ def __build_comparison_table( results: Mapping[str, Mapping[str, float]], precision: int = 4, ) -> str: - """Build a markdown comparison table from model results. + """ + Build a markdown comparison table from model results. Examples: Input: diff --git a/hyperbench/train/trainer.py b/hyperbench/train/trainer.py index 95987cb..5a5031d 100644 --- a/hyperbench/train/trainer.py +++ b/hyperbench/train/trainer.py @@ -9,8 +9,7 @@ from pathlib import Path from typing import Any -from collections.abc import Mapping -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from lightning.pytorch.accelerators import Accelerator from lightning.pytorch.callbacks import Callback from lightning.pytorch.loggers import CSVLogger, Logger @@ -18,6 +17,7 @@ from lightning.pytorch.strategies import Strategy from hyperbench.data import DataLoader from hyperbench.types import CkptStrategy, ModelConfig, TestResult +from hyperbench.utils import validate_is_non_empty, validate_is_non_negative from hyperbench.train.markdown_logger import MarkdownTableLogger from hyperbench.train.latex_logger import LaTexTableLogger @@ -165,13 +165,17 @@ def __init__( auto_wait: bool = False, **kwargs, ) -> None: + self.auto_wait = auto_wait + self.__tensorboard_process: subprocess.Popen | None = None + validate_is_non_negative("tensorboard_port", tensorboard_port) + self.model_configs = model_configs + validate_is_non_empty("model_configs", self.model_configs) + self.log_dir = self.__setup_logdir(default_root_dir, experiment_name) self.auto_start_tensorboard = auto_start_tensorboard - self.auto_wait = auto_wait self.tensorboard_port = tensorboard_port - self.__tensorboard_process: subprocess.Popen | None = None for model_config in model_configs: if model_config.trainer is None: @@ -238,9 +242,6 @@ def fit_all( ckpt_path: CkptStrategy | None = None, verbose: bool = True, ) -> None: - if len(self.model_configs) < 1: - raise ValueError("No models to fit.") - for i, config in enumerate(self.model_configs): if not config.is_trainable: if verbose: @@ -281,11 +282,7 @@ def test_all( verbose: bool = True, verbose_loop: bool = True, ) -> Mapping[str, TestResult]: - if len(self.model_configs) < 1: - raise ValueError("No models to test.") - test_results: dict[str, TestResult] = {} - for i, config in enumerate(self.model_configs): if config.trainer is None: raise ValueError(f"Trainer not defined for model {config.full_model_name()}.") diff --git a/hyperbench/types/graph.py b/hyperbench/types/graph.py index def2f27..4e2ba40 100644 --- a/hyperbench/types/graph.py +++ b/hyperbench/types/graph.py @@ -3,7 +3,8 @@ import torch from torch import Tensor -from hyperbench import utils +from typing import cast +from hyperbench.utils import sparse_dropout, validate_is_non_negative class Graph: @@ -69,7 +70,8 @@ def remove_selfloops(self) -> Graph: # -> edges without self-loops = [[0, 1], # [2, 3]] no_selfloop_mask = edges_tensor[:, 0] != edges_tensor[:, 1] - self.edges = edges_tensor[no_selfloop_mask].tolist() + edges_without_selfloops: list[list[int]] = edges_tensor[no_selfloop_mask].tolist() + self.edges = edges_without_selfloops # Example: edge_weights = [0.5, 1.0, 0.8], no_selfloop_mask = [True, False, True] # -> edge_weights without self-loops = [0.5, 0.8] @@ -124,7 +126,7 @@ def smoothing_with_laplacian_matrix( x: The smoothed feature matrix. Size ``(num_nodes, C)``. """ if drop_rate > 0.0: - laplacian_matrix = utils.sparse_dropout(laplacian_matrix, drop_rate) + laplacian_matrix = sparse_dropout(laplacian_matrix, drop_rate) return laplacian_matrix.matmul(x) @@ -225,8 +227,8 @@ def add_selfloops( Raises: ValueError: If the input edge index has no edges (i.e., ``shape (2, 0)``). """ - if self.__edge_index.size(1) < 1: - raise ValueError("Edge index must have at least one edge to add self-loops.") + num_selfloop_nodes = self.num_nodes if num_nodes is None else num_nodes + self.__validate_num_nodes(num_selfloop_nodes) device = self.__edge_index.device src, dest = self.__edge_index[0], self.__edge_index[1] @@ -248,7 +250,6 @@ def add_selfloops( # -> dest = [1, 0, 3, 0, 1, 2, 3, 4, 5] # -> edge_index_with_selfloops = [[0, 1, 2, 0, 1, 2, 3, 4, 5], # [1, 0, 3, 0, 1, 2, 3, 4, 5]] - num_selfloop_nodes = self.num_nodes if num_nodes is None else num_nodes selfloop_indices = torch.arange(num_selfloop_nodes, device=device) src = torch.cat([src, selfloop_indices]) dest = torch.cat([dest, selfloop_indices]) @@ -302,9 +303,11 @@ def get_sparse_adjacency_matrix( Returns: adjacency: The sparse adjacency matrix of shape ``(num_nodes, num_nodes)``. """ + num_nodes = self.num_nodes if num_nodes is None else num_nodes + self.__validate_num_nodes(num_nodes) + device = self.__edge_index.device src, dest = self.__edge_index - num_nodes = self.num_nodes if num_nodes is None else num_nodes # Example: edge_index = [[0, 1, 2, 3], # [1, 0, 3, 2]] @@ -350,8 +353,10 @@ def get_sparse_identity_matrix(self, num_nodes: int | None = None) -> Tensor: Returns: identity: The sparse identity matrix I of shape ``(num_nodes, num_nodes)``. """ - device = self.__edge_index.device num_nodes = self.num_nodes if num_nodes is None else num_nodes + self.__validate_num_nodes(num_nodes) + + device = self.__edge_index.device # Example: num_nodes = 3 # -> identity_indices = [[0, 1, 2], @@ -387,9 +392,10 @@ def get_sparse_normalized_degree_matrix( Returns: degree_matrix: The sparse normalized degree matrix D^-1/2 of shape ``(num_nodes, num_nodes)``. """ - device = self.__edge_index.device - num_nodes = self.num_nodes if num_nodes is None else num_nodes + self.__validate_num_nodes(num_nodes) + + device = self.__edge_index.device adj_matrix = self.get_sparse_adjacency_matrix( num_nodes=num_nodes, use_edge_weights=use_edge_weights @@ -440,9 +446,10 @@ def get_sparse_normalized_laplacian( Returns: laplacian: The sparse symmetric normalized Laplacian matrix of shape ``(num_nodes, num_nodes)``. """ - self.to_undirected(with_selfloops=False) - num_nodes = self.num_nodes if num_nodes is None else num_nodes + self.__validate_num_nodes(num_nodes) + + self.to_undirected(with_selfloops=False, num_nodes=num_nodes) degree_matrix = self.get_sparse_normalized_degree_matrix(num_nodes) adj_matrix = self.get_sparse_adjacency_matrix(num_nodes) @@ -479,9 +486,10 @@ def get_sparse_normalized_gcn_laplacian( Returns: laplacian: The sparse symmetrically normalized Laplacian matrix of shape ``(num_nodes, num_nodes)``. """ - self.to_undirected(with_selfloops=True, num_nodes=num_nodes) - num_nodes = self.num_nodes if num_nodes is None else num_nodes + self.__validate_num_nodes(num_nodes) + + self.to_undirected(with_selfloops=True, num_nodes=num_nodes) degree_matrix = self.get_sparse_normalized_degree_matrix( num_nodes=num_nodes, use_edge_weights=use_edge_weights @@ -505,7 +513,8 @@ def remove_selfloops(self) -> EdgeIndex: # -> edge_index = [[0, 2, 3], # [1, 3, 2]], shape (2, |E'| = 3) keep_mask = self.__edge_index[0] != self.__edge_index[1] - self.__edge_index = self.__edge_index[:, keep_mask] + edge_index_without_selfloops: Tensor = self.__edge_index[:, keep_mask] + self.__edge_index = edge_index_without_selfloops if self.__edge_weights is not None: self.__edge_weights = self.__edge_weights[keep_mask] return self @@ -522,6 +531,9 @@ def remove_duplicate_edges(self, num_nodes: int | None = None) -> EdgeIndex: Returns: edge_index: This `EdgeIndex` instance with duplicate edges removed. """ + num_nodes = self.num_nodes if num_nodes is None else num_nodes + self.__validate_num_nodes(num_nodes) + # Example: edge_index = [[0, 1, 2, 2, 0, 3, 2], # [1, 0, 3, 2, 1, 2, 2]], shape (2, |E| = 7) # -> after torch.unique(..., dim=1): @@ -530,7 +542,10 @@ def remove_duplicate_edges(self, num_nodes: int | None = None) -> EdgeIndex: # Note: we call contiguous() to ensure that the resulting tensor is contiguous in memory, # which can improve performance for subsequent operations that require contiguous tensors. if self.__edge_weights is None: - self.__edge_index = torch.unique(self.__edge_index, dim=1).contiguous() + edge_index_without_duplicate_edges = cast( + Tensor, torch.unique(self.__edge_index, dim=1) + ) + self.__edge_index = edge_index_without_duplicate_edges.contiguous() return self # No edges to process, just ensure tensors are contiguous @@ -549,7 +564,6 @@ def remove_duplicate_edges(self, num_nodes: int | None = None) -> EdgeIndex: # -> edge_index = [[0, 1], # [1, 2]] # -> edge_weights = [3.0, 3.0] (weights of duplicate edges are summed) - num_nodes = self.num_nodes if num_nodes is None else num_nodes coalesced = torch.sparse_coo_tensor( self.__edge_index, self.__edge_weights, @@ -576,9 +590,10 @@ def to_undirected( Returns: edge_index: This `EdgeIndex` instance converted to undirected. """ - device = self.__edge_index.device num_nodes = self.num_nodes if num_nodes is None else num_nodes + self.__validate_num_nodes(num_nodes) + device = self.__edge_index.device orig_src, orig_dest = self.__edge_index[0], self.__edge_index[1] # Encode each directed edge (u, v) as a unique scalar key u * num_nodes + v. @@ -641,7 +656,7 @@ def to_undirected( # In this way, we don't do the duplicate edge removal twice, which would be redundant and inefficient self.add_selfloops(num_nodes=num_nodes, with_duplicate_removal=False) - self.remove_duplicate_edges() + self.remove_duplicate_edges(num_nodes=num_nodes) return self @@ -659,3 +674,15 @@ def __validate_edge_weights(self, edge_weights: Tensor | None) -> None: "edge_weights must have the same number of entries as edge_index columns. " f"Got {edge_weights.size(0)} edge weights but {self.__edge_index.size(1)} edge columns." ) + + def __validate_num_nodes(self, num_nodes: int) -> None: + validate_is_non_negative("num_nodes", num_nodes) + + if self.num_edges < 1: + return + + if self.max_node_id >= num_nodes: + raise ValueError( + "'num_nodes' is too small for the edge index. " + f"Got num_nodes={num_nodes}, but max node id is {self.max_node_id}." + ) diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index 1713748..f09373c 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -9,11 +9,16 @@ NodeSpaceFiller, NodeSpaceSetting, clone_optional_tensor, + create_seeded_torch_generator, empty_hyperedgeindex, empty_nodefeatures, is_inductive_setting, is_transductive_setting, to_0based_ids, + validate_is_non_empty, + validate_is_non_negative, + validate_is_positive, + validate_node_space_setting, ) from hyperbench.types.hypergraph import HyperedgeIndex @@ -63,28 +68,28 @@ def __init__( y: Tensor | None = None, ): self.x: Tensor = x - self.hyperedge_index: Tensor = hyperedge_index + self.__validate_x_and_hyperedge_index_type_and_dim() self.hyperedge_weights: Tensor | None = hyperedge_weights - self.hyperedge_attr: Tensor | None = hyperedge_attr hyperedge_index_wrapper = HyperedgeIndex(hyperedge_index) - self.num_nodes: int = ( num_nodes if num_nodes is not None # There should never be isolated nodes when HData is created by Dataset - # as each isolted node gets its own self-loop hyperedge + # as each isolated node gets its own self-loop hyperedge else hyperedge_index_wrapper.num_nodes_if_isolated_exist(num_nodes=x.size(0)) ) + validate_is_non_negative("num_nodes", self.num_nodes) self.num_hyperedges: int = ( num_hyperedges if num_hyperedges is not None else hyperedge_index_wrapper.num_hyperedges ) + validate_is_non_negative("num_hyperedges", self.num_hyperedges) - self.global_node_ids: Tensor | None = ( + self.global_node_ids = ( # torch.arange is to handle isolated nodes, as they are already considered # when computing self.num_nodes via num_nodes_if_isolated_exist global_node_ids if global_node_ids is not None else torch.arange(self.num_nodes) @@ -96,6 +101,8 @@ def __init__( else torch.ones((self.num_hyperedges,), dtype=torch.float, device=self.x.device) ) + self.__validate() + self.device = self.get_device_if_all_consistent() def __repr__(self) -> str: @@ -114,16 +121,28 @@ def __repr__(self) -> str: ) @classmethod - def cat_same_node_space(cls, hdatas: Sequence[HData], x: Tensor | None = None) -> HData: + def cat_same_node_space( + cls, + hdatas: Sequence[HData], + x: Tensor | None = None, + global_node_ids: Tensor | None = None, + ) -> HData: """ Concatenate `HData` instances that share the same node space, meaning nodes with the same ID in different instances are the same node. This is useful when combining positive and negative hyperedges that reference the same set of nodes. Notes: - - ``x`` is derived from the instance with the largest number of nodes, if not provided explicitly. If there are conflicting features for the same node ID across instances, the features from the instance with the largest number of nodes will be used. + - ``x`` is derived from the instance with the largest number of nodes, if not provided explicitly. + If there are conflicting features for the same node ID across instances, + the features from the instance with the largest number of nodes will be used. + If ``global_node_ids`` is provided explicitly, ``x`` must also be provided to ensure consistency. - ``hyperedge_index`` is the concatenation of all input hyperedge indices. - - ``hyperedge_weights`` is the concatenation of all input hyperedge weights, if present. If some instances have hyperedge weights and others do not, the resulting ``hyperedge_weights`` will be set to ``None``. - - ``hyperedge_attr`` is the concatenation of all input hyperedge attributes, if present. If some instances have hyperedge attributes and others do not, the resulting ``hyperedge_attr`` will be set to ``None``. + - ``hyperedge_weights`` is the concatenation of all input hyperedge weights, if present. + If some instances have hyperedge weights and others do not, the resulting ``hyperedge_weights`` will be set to ``None``. + - ``hyperedge_attr`` is the concatenation of all input hyperedge attributes, if present. + If some instances have hyperedge attributes and others do not, the resulting ``hyperedge_attr`` will be set to ``None``. + - ``global_node_ids`` is derived from the instance with the largest number of nodes, if not provided explicitly. + If ``x`` is provided explicitly, ``global_node_ids`` must be provided explicitly as well to ensure consistency. - ``y`` is the concatenation of all input labels. Examples: @@ -138,25 +157,30 @@ def cat_same_node_space(cls, hdatas: Sequence[HData], x: Tensor | None = None) - hdatas: One or more `HData` instances sharing the same node space. x: Optional node feature matrix to use for the resulting `HData`. If ``None``, the node features from the instance with the largest number of nodes will be used. + If ``global_node_ids`` is provided explicitly, ``x`` must also be provided to ensure consistency. + global_node_ids: Optional global node IDs for the resulting `HData`. + If ``None``, the global node IDs from the instance with the largest number of nodes will be used. + If ``x`` is provided explicitly, ``global_node_ids`` must also be provided to ensure consistency. + If ``x`` is provided and there is no need for ``global_node_ids`` to preserve access to the canonical node space, + it is recommended to use arbitrary global node IDs that are consistent with the feature rows of ``x``. + For example, ``global_node_ids=torch.arange(x.size(0))``). Returns: hdata: A new `HData` with shared nodes and concatenated hyperedges. Raises: - ValueError: If the node counts do not match across inputs. + ValueError: If no HData instances are provided, if there are overlapping hyperedge IDs across instances, + or if ``x`` and ``global_node_ids`` are not both provided when one of them is provided. """ - if len(hdatas) < 1: - raise ValueError("At least one instance is required.") - - joint_hyperedge_ids = torch.cat([hdata.hyperedge_index[1].unique() for hdata in hdatas]) - unique_joint_hyperedge_ids = joint_hyperedge_ids.unique() - if unique_joint_hyperedge_ids.size(0) != joint_hyperedge_ids.size(0): - raise ValueError( - "Overlapping hyperedge IDs found across instances. Ensure each instance uses distinct hyperedge IDs." - ) + cls.__validate_can_perform_cat_same_node_space(hdatas, x, global_node_ids) hdata_with_largest_node_space = max(hdatas, key=lambda hdata: hdata.num_nodes) - new_x = (x if x is not None else hdata_with_largest_node_space.x).clone() + new_x = (x.clone() if x is not None else hdata_with_largest_node_space.x).clone() + new_global_node_ids = ( + global_node_ids.clone() + if global_node_ids is not None + else hdata_with_largest_node_space.global_node_ids.clone() + ) new_y = torch.cat([hdata.y for hdata in hdatas], dim=0) new_hyperedge_index = torch.cat([hdata.hyperedge_index for hdata in hdatas], dim=1) @@ -181,7 +205,7 @@ def cat_same_node_space(cls, hdatas: Sequence[HData], x: Tensor | None = None) - hyperedge_attr=new_hyperedge_attr, num_nodes=new_x.size(0), num_hyperedges=new_y.size(0), - global_node_ids=clone_optional_tensor(hdata_with_largest_node_space.global_node_ids), + global_node_ids=new_global_node_ids, y=new_y, ) @@ -213,7 +237,7 @@ def empty(cls) -> HData: hyperedge_attr=None, num_nodes=0, num_hyperedges=0, - global_node_ids=torch.empty(size=(0, 0), dtype=torch.long), + global_node_ids=None, y=None, ) @@ -282,7 +306,12 @@ def split( Returns: hdata: The splitted instance with remapped node and hyperedge IDs. + + Raises: + ValueError: If ``node_space_setting`` is not ``"transductive"`` or ``"inductive"``. """ + cls.__validate_node_space_setting_value(node_space_setting) + # Mask to keep only incidences belonging to selected hyperedges # Example: hyperedge_index = [[0, 0, 1, 2, 3, 4], # [0, 0, 0, 1, 2, 2]] @@ -343,12 +372,10 @@ def split( .item ) - split_global_node_ids = None - if hdata.global_node_ids is not None: - split_global_node_ids = hdata.global_node_ids[split_unique_node_ids] - + split_x = hdata.x[split_unique_node_ids] + split_global_node_ids = hdata.global_node_ids[split_unique_node_ids] return cls( - x=hdata.x[split_unique_node_ids], + x=split_x, hyperedge_index=split_hyperedge_index.clone(), hyperedge_weights=split_hyperedge_weights, hyperedge_attr=split_hyperedge_attr, @@ -361,7 +388,7 @@ def split( def enrich_node_features( self, enricher: NodeEnricher, - enrichment_mode: EnrichmentMode | None = None, + enrichment_mode: EnrichmentMode | None = "replace", ) -> HData: """ Enrich node features using the provided node feature enricher. @@ -371,7 +398,9 @@ def enrich_node_features( enrichment_mode: How to combine generated features with existing ``hdata.x``. ``concatenate`` appends new features as additional columns. ``replace`` substitutes ``hdata.x`` entirely. + Defaults to ``replace`` if not provided. """ + self.__validate_enrichment_mode(enrichment_mode) enriched_features = enricher.enrich(self.hyperedge_index) match enrichment_mode: @@ -435,10 +464,6 @@ def enrich_node_features_from( """ source_global_node_ids = hdata_with_features.global_node_ids source_x = hdata_with_features.x - if self.global_node_ids is None or source_global_node_ids is None: - raise ValueError( - "Both HData instances must define global_node_ids to align node features." - ) if source_x.size(0) != source_global_node_ids.size(0): raise ValueError( "Expected hdata_with_features.x rows to align with hdata_with_features.global_node_ids." @@ -509,16 +534,22 @@ def enrich_node_features_from( def enrich_hyperedge_weights( self, enricher: HyperedgeEnricher, - enrichment_mode: EnrichmentMode | None = None, + enrichment_mode: EnrichmentMode | None = "replace", ) -> HData: - """Enrich hyperedge weights using the provided hyperedge weight enricher. + """ + Enrich hyperedge weights using the provided hyperedge weight enricher. Args: enricher: An instance of HyperedgeEnricher to generate hyperedge weights from hypergraph topology. enrichment_mode: How to combine generated weights with existing ``hdata.hyperedge_weights``. ``concatenate`` appends new weights to the existing 1D tensor. ``replace`` substitutes ``hdata.hyperedge_weights`` entirely. + Defaults to ``replace`` if not provided. + + Returns: + hdata: A new `HData` with enriched hyperedge weights. """ + self.__validate_enrichment_mode(enrichment_mode) enriched_weights = enricher.enrich(self.hyperedge_index) match enrichment_mode: @@ -545,7 +576,7 @@ def enrich_hyperedge_weights( def enrich_hyperedge_attr( self, enricher: HyperedgeEnricher, - enrichment_mode: EnrichmentMode | None = None, + enrichment_mode: EnrichmentMode | None = "replace", ) -> HData: """ Enrich hyperedge features using the provided hyperedge feature enricher. @@ -555,7 +586,9 @@ def enrich_hyperedge_attr( enrichment_mode: How to combine generated features with existing ``hdata.hyperedge_attr``. ``concatenate`` appends new features as additional columns. ``replace`` substitutes ``hdata.hyperedge_attr`` entirely. + Defaults to ``replace`` if not provided. """ + self.__validate_enrichment_mode(enrichment_mode) enriched_features = enricher.enrich(self.hyperedge_index) match enrichment_mode: @@ -590,27 +623,32 @@ def get_device_if_all_consistent(self) -> torch.device: Raises: ValueError: If tensors are on different devices. """ - devices = {self.x.device, self.hyperedge_index.device, self.y.device} - if self.global_node_ids is not None: - devices.add(self.global_node_ids.device) + devices = { + self.x.device, + self.hyperedge_index.device, + self.global_node_ids.device, + self.y.device, + } + if self.hyperedge_attr is not None: devices.add(self.hyperedge_attr.device) if self.hyperedge_weights is not None: devices.add(self.hyperedge_weights.device) + if len(devices) > 1: raise ValueError(f"Inconsistent device placement: {devices}") return devices.pop() if len(devices) == 1 else torch.device("cpu") def remove_hyperedges_with_fewer_than_k_nodes(self, k: int) -> HData: + validate_is_positive("k", k) + hyperedge_index_wrapper = HyperedgeIndex( self.hyperedge_index.clone() ).remove_hyperedges_with_fewer_than_k_nodes(k) x = self.x[hyperedge_index_wrapper.node_ids] - global_node_ids = None - if self.global_node_ids is not None: - global_node_ids = self.global_node_ids[hyperedge_index_wrapper.node_ids] + global_node_ids = self.global_node_ids[hyperedge_index_wrapper.node_ids] y = self.y[hyperedge_index_wrapper.hyperedge_ids] hyperedge_attr = None @@ -657,10 +695,7 @@ def shuffle(self, seed: int | None = None) -> HData: Returns: hdata: A new `HData` instance with hyperedge IDs, ``y``, and ``hyperedge_attr`` permuted. """ - generator = torch.Generator(device=self.device) - if seed is not None: - generator.manual_seed(seed) - + generator = create_seeded_torch_generator(device=self.device, seed=seed) permutation = torch.randperm(self.num_hyperedges, generator=generator, device=self.device) # permutation[new_id] = old_id, so y[permutation] puts old labels into new slots @@ -775,11 +810,21 @@ def with_y_to(self, value: float) -> HData: ) def with_y_ones(self) -> HData: - """Return a copy of this instance with a y attribute of all ones.""" + """ + Return a copy of this instance with a y attribute of all ones. + + Returns: + hdata: A new `HData` instance with the same attributes except for y, which is set to a tensor of ones. + """ return self.with_y_to(1.0) def with_y_zeros(self) -> HData: - """Return a copy of this instance with a y attribute of all zeros.""" + """ + Return a copy of this instance with a y attribute of all zeros. + + Returns: + hdata: A new `HData` instance with the same attributes except for y, which is set to a tensor of zeros. + """ return self.with_y_to(0.0) def stats(self) -> dict[str, Any]: @@ -883,6 +928,31 @@ def stats(self) -> dict[str, Any]: "distribution_hyperedge_size_hist": distribution_hyperedge_size_hist, } + @classmethod + def __validate_can_perform_cat_same_node_space( + cls, + hdatas: Sequence[HData], + x: Tensor | None, + global_node_ids: Tensor | None, + ) -> None: + validate_is_non_empty("hdatas", hdatas) + + if x is not None and global_node_ids is None: + raise ValueError( + "If x is provided, global_node_ids must also be provided to ensure consistency." + ) + if x is None and global_node_ids is not None: + raise ValueError( + "If global_node_ids is provided, x must also be provided to ensure consistency." + ) + + joint_hyperedge_ids = torch.cat([hdata.hyperedge_index[1].unique() for hdata in hdatas]) + unique_joint_hyperedge_ids = joint_hyperedge_ids.unique() + if unique_joint_hyperedge_ids.size(0) != joint_hyperedge_ids.size(0): + raise ValueError( + "Overlapping hyperedge IDs found across instances. Ensure each instance uses distinct hyperedge IDs." + ) + def __to_fill_features( self, fill_value: NodeSpaceFiller | None, @@ -915,14 +985,137 @@ def __to_fill_features( ) return fill_features + def __validate(self) -> None: + self.__validate_x() + self.__validate_hyperedge_index() + self.__validate_hyperedge_attr() + self.__validate_hyperedge_weights() + self.__validate_global_node_ids() + self.__validate_labels() + + def __validate_hyperedge_attr(self) -> None: + if self.hyperedge_attr is None: + return + + if self.hyperedge_attr.dim() != 2: + raise ValueError( + f"hyperedge_attr must be a 2D tensor, got shape {tuple(self.hyperedge_attr.shape)}." + ) + if self.hyperedge_attr.size(0) != self.num_hyperedges: + raise ValueError( + "hyperedge_attr must have one row per hyperedge. " + f"Got size={self.hyperedge_attr.size(0)} but num_hyperedges={self.num_hyperedges}." + ) + + def __validate_hyperedge_index(self) -> None: + if self.hyperedge_index.dtype != torch.long: + raise ValueError( + f"hyperedge_index must have dtype torch.long, got {self.hyperedge_index.dtype}." + ) + if self.hyperedge_index.numel() > 0 and bool((self.hyperedge_index < 0).any()): + raise ValueError("hyperedge_index cannot contain negative node or hyperedge IDs.") + + unique_node_count = self.hyperedge_index[0].unique().size(0) + if unique_node_count > self.num_nodes: + raise ValueError( + "num_nodes is too small for hyperedge_index. " + f"Got num_nodes={self.num_nodes}, but hyperedge_index contains " + f"{unique_node_count} unique node IDs." + ) + + unique_hyperedge_count = self.hyperedge_index[1].unique().size(0) + if unique_hyperedge_count > self.num_hyperedges: + raise ValueError( + "num_hyperedges is too small for hyperedge_index. " + f"Got num_hyperedges={self.num_hyperedges}, but hyperedge_index contains " + f"{unique_hyperedge_count} unique hyperedge IDs." + ) + + def __validate_hyperedge_weights(self) -> None: + if self.hyperedge_weights is None: + return + + if self.hyperedge_weights.dim() != 1: + raise ValueError( + f"hyperedge_weights must be a 1D tensor, got shape {tuple(self.hyperedge_weights.shape)}." + ) + if self.hyperedge_weights.size(0) != self.num_hyperedges: + raise ValueError( + "hyperedge_weights must have one entry per hyperedge. " + f"Got size={self.hyperedge_weights.size(0)} but num_hyperedges={self.num_hyperedges}." + ) + + def __validate_global_node_ids(self) -> None: + if self.global_node_ids.dim() != 1: + raise ValueError( + f"global_node_ids must be a 1D tensor, got shape {tuple(self.global_node_ids.shape)}." + ) + if self.global_node_ids.size(0) != self.num_nodes: + raise ValueError( + "global_node_ids must have one entry per node. " + f"Got size={self.global_node_ids.size(0)} but num_nodes={self.num_nodes}." + ) + + if self.global_node_ids.dtype != torch.long: + raise ValueError( + f"global_node_ids must have dtype torch.long, got {self.global_node_ids.dtype}." + ) + + def __validate_labels(self) -> None: + if self.y.dim() != 1: + raise ValueError(f"y must be a 1D tensor, got shape {tuple(self.y.shape)}.") + if self.y.size(0) != self.num_hyperedges: + raise ValueError( + "y must have one entry per hyperedge. " + f"Got {self.y.size(0)} entries but num_hyperedges={self.num_hyperedges}." + ) + + def __validate_x(self) -> None: + if self.x.size(0) not in (0, self.num_nodes): + raise ValueError( + "x must have one feature row per node, or be 'torch.empty((0, 0))' if there are no nodes. " + f"Got x.shape={tuple(self.x.shape)} but num_nodes={self.num_nodes}." + ) + def __validate_node_space_setting( self, node_space_setting: NodeSpaceSetting, fill_value: NodeSpaceFiller | None, ) -> None: + validate_node_space_setting(node_space_setting) + if is_transductive_setting(node_space_setting) and fill_value is not None: raise ValueError( "fill_value cannot be provided when node_space_setting='transductive'." ) if is_inductive_setting(node_space_setting) and fill_value is None: raise ValueError("fill_value must be provided when node_space_setting='inductive'.") + + @staticmethod + def __validate_enrichment_mode(enrichment_mode: EnrichmentMode | None) -> None: + if enrichment_mode is None or enrichment_mode in ("replace", "concatenate"): + return + + raise ValueError( + f"enrichment_mode must be one of 'replace', 'concatenate', or None, got {enrichment_mode!r}." + ) + + @staticmethod + def __validate_node_space_setting_value(node_space_setting: NodeSpaceSetting) -> None: + if is_transductive_setting(node_space_setting) or is_inductive_setting(node_space_setting): + return + + raise ValueError( + "node_space_setting must be one of 'transductive' or 'inductive', " + f"got {node_space_setting!r}." + ) + + def __validate_x_and_hyperedge_index_type_and_dim(self) -> None: + if self.x.dim() != 2: + raise ValueError(f"x must be a 2D tensor, got shape {tuple(self.x.shape)}.") + + if self.hyperedge_index.dim() != 2 or self.hyperedge_index.size(0) != 2: + raise ValueError( + "hyperedge_index must have shape (2, num_incidences), got " + f"{tuple(self.hyperedge_index.shape)}." + ) diff --git a/hyperbench/types/hypergraph.py b/hyperbench/types/hypergraph.py index 39abd2e..64c938e 100644 --- a/hyperbench/types/hypergraph.py +++ b/hyperbench/types/hypergraph.py @@ -4,8 +4,14 @@ from itertools import combinations from torch import Tensor -from typing import Any, Literal, TypeAlias -from hyperbench.utils import sparse_dropout, to_0based_ids, create_seeded_torch_generator +from typing import Any, Literal, TypeAlias, cast +from hyperbench.utils import ( + create_seeded_torch_generator, + sparse_dropout, + to_0based_ids, + validate_is_non_negative, + validate_is_positive, +) from hyperbench.types.graph import EdgeIndex, Graph @@ -224,6 +230,8 @@ def neighbors_of(self, node: int) -> Neighborhood: Returns: neighbors: A set of neighbor node IDs (excluding the node itself). """ + validate_is_non_negative("node", node) + neighbors: Neighborhood = set() for hyperedge in self.hyperedges: if node in hyperedge: @@ -444,7 +452,16 @@ def num_incidences(self) -> int: return self.__hyperedge_index.size(1) def nodes_in(self, hyperedge_id: int) -> list[int]: - """Return the list of node IDs that belong to the given hyperedge.""" + """ + Return the list of node IDs that belong to the given hyperedge. + + Args: + hyperedge_id: The ID of the hyperedge to query. + + Returns: + node_ids: A list of node IDs that belong to the specified hyperedge. + """ + validate_is_non_negative("hyperedge_id", hyperedge_id) return self.__hyperedge_index[0, self.__hyperedge_index[1] == hyperedge_id].tolist() def num_nodes_if_isolated_exist(self, num_nodes: int) -> int: @@ -474,6 +491,8 @@ def get_clique_expansion_adjacency_list(self, num_nodes: int | None = None) -> l adjacency: A list where ``adjacency[node_id]`` is the set of nodes adjacent to ``node_id``. """ num_nodes = num_nodes if num_nodes is not None else self.num_nodes + self.__validate_num_nodes(num_nodes) + adjacency_list: list[set[int]] = [set() for _ in range(num_nodes)] for hyperedge_id in self.hyperedge_ids.tolist(): @@ -513,24 +532,12 @@ def get_sparse_incidence_matrix( Raises: ValueError: If the provided dimensions cannot contain the raw node or hyperedge IDs. """ - device = self.__hyperedge_index.device num_nodes = num_nodes if num_nodes is not None else self.num_nodes num_hyperedges = num_hyperedges if num_hyperedges is not None else self.num_hyperedges - if self.num_incidences > 0: - max_node_id = int(self.all_node_ids.max().item()) - if max_node_id >= num_nodes: - raise ValueError( - "num_nodes is too small for the hyperedge index. " - f"Got num_nodes={num_nodes}, but max node id is {max_node_id}." - ) - max_hyperedge_id = int(self.all_hyperedge_ids.max().item()) - if max_hyperedge_id >= num_hyperedges: - raise ValueError( - "num_hyperedges is too small for the hyperedge index. " - f"Got num_hyperedges={num_hyperedges}, " - f"but max hyperedge id is {max_hyperedge_id}." - ) + self.__validate_num_nodes(num_nodes) + self.__validate_num_hyperedges(num_hyperedges) + device = self.__hyperedge_index.device incidence_values = torch.ones(self.num_incidences, dtype=torch.float, device=device) incidence_indices = torch.stack([self.all_node_ids, self.all_hyperedge_ids], dim=0) incidence_matrix = torch.sparse_coo_tensor( @@ -557,8 +564,15 @@ def get_sparse_normalized_node_degree_matrix( Returns: degree_matrix: The sparse diagonal matrix of shape ``(num_nodes, num_nodes)``. """ + num_nodes = num_nodes if num_nodes is not None else int(incidence_matrix.size(0)) + self.__validate_num_nodes(num_nodes) + self.__validate_degree_matrix_dimension( + name="num_nodes", + value=num_nodes, + expected=int(incidence_matrix.size(0)), + ) + device = self.__hyperedge_index.device - num_nodes = num_nodes if num_nodes is not None else self.num_nodes degrees = torch.sparse.sum(incidence_matrix, dim=1).to_dense() normalized_degrees = degrees.pow(power) @@ -654,8 +668,17 @@ def get_sparse_normalized_hyperedge_degree_matrix( Returns: degree_matrix: The sparse diagonal matrix D_e^-1 of shape ``(num_hyperedges, num_hyperedges)``. """ + num_hyperedges = ( + num_hyperedges if num_hyperedges is not None else int(incidence_matrix.size(1)) + ) + self.__validate_num_hyperedges(num_hyperedges) + self.__validate_degree_matrix_dimension( + name="num_hyperedges", + value=num_hyperedges, + expected=int(incidence_matrix.size(1)), + ) + device = self.__hyperedge_index.device - num_hyperedges = num_hyperedges if num_hyperedges is not None else self.num_hyperedges # Example: hyperedge_index = [[0, 1, 2, 0], # [0, 0, 0, 1]] @@ -710,6 +733,8 @@ def get_sparse_hgnn_smoothing_matrix( """ num_nodes = num_nodes if num_nodes is not None else self.num_nodes num_hyperedges = num_hyperedges if num_hyperedges is not None else self.num_hyperedges + self.__validate_num_nodes(num_nodes) + self.__validate_num_hyperedges(num_hyperedges) incidence_matrix = self.get_sparse_incidence_matrix(num_nodes, num_hyperedges) node_degree_matrix = self.get_sparse_symnormalized_node_degree_matrix( @@ -756,6 +781,8 @@ def get_sparse_hgnnp_smoothing_matrix( """ num_nodes = num_nodes if num_nodes is not None else self.num_nodes num_hyperedges = num_hyperedges if num_hyperedges is not None else self.num_hyperedges + self.__validate_num_nodes(num_nodes) + self.__validate_num_hyperedges(num_hyperedges) incidence_matrix = self.get_sparse_incidence_matrix(num_nodes, num_hyperedges) node_degree_matrix = self.get_sparse_rownormalized_node_degree_matrix( @@ -788,8 +815,13 @@ def reduce(self, strategy: Literal["clique_expansion"], **kwargs) -> Tensor: edge_index: The edge index of the reduced graph. Size ``(2, num_edges)``. """ match strategy: - case _: + case "clique_expansion": return self.reduce_to_edge_index_on_clique_expansion(**kwargs) + case _: + raise ValueError( + f"Unsupported reduction strategy: {strategy}. " + "Supported strategies: ['clique_expansion']" + ) def reduce_to_edge_index_on_clique_expansion( self, @@ -811,6 +843,9 @@ def reduce_to_edge_index_on_clique_expansion( Returns: edge_index: The edge index of the clique-expanded graph. Size ``(2, |E'|)``. """ + self.__validate_num_nodes(num_nodes) + self.__validate_num_hyperedges(num_hyperedges) + incidence_matrix = self.get_sparse_incidence_matrix( num_nodes=num_nodes, num_hyperedges=num_hyperedges, @@ -926,7 +961,10 @@ def remove_duplicate_edges(self) -> HyperedgeIndex: # Note: we need to call contiguous() after torch.unique() to ensure # the resulting tensor is contiguous in memory, which is important for efficient indexing # and further operations (e.g., searchsorted) - self.__hyperedge_index = torch.unique(self.__hyperedge_index, dim=1).contiguous() + hyperedge_index_without_duplicates = cast( + Tensor, torch.unique(self.__hyperedge_index, dim=1) + ) + self.__hyperedge_index = hyperedge_index_without_duplicates.contiguous() return self def remove_hyperedges_with_fewer_than_k_nodes(self, k: int) -> HyperedgeIndex: @@ -957,6 +995,8 @@ def remove_hyperedges_with_fewer_than_k_nodes(self, k: int) -> HyperedgeIndex: Returns: hyperedge_index: A new `HyperedgeIndex` instance with hyperedges containing fewer than k nodes. """ + validate_is_positive("k", k) + _, idx_to_hyperedge_id, num_nodes_per_hyperedge = torch.unique( self.all_hyperedge_ids, return_inverse=True, @@ -998,3 +1038,41 @@ def to_0based( self.__hyperedge_index[1] = to_0based_ids(self.all_hyperedge_ids, hyperedge_ids_to_rebase) return self + + def __validate_num_hyperedges(self, num_hyperedges: int | None) -> None: + if num_hyperedges is None: + return + validate_is_non_negative("num_hyperedges", num_hyperedges) + + if self.all_hyperedge_ids.numel() < 1: + return + + max_hyperedge_id = int(self.all_hyperedge_ids.max().item()) + if max_hyperedge_id >= num_hyperedges: + raise ValueError( + f"'num_hyperedges' is too small for the hyperedge index. " + f"Got num_hyperedges={num_hyperedges}, but max hyperedge id is {max_hyperedge_id}." + ) + + def __validate_num_nodes(self, num_nodes: int | None) -> None: + if num_nodes is None: + return + validate_is_non_negative("num_nodes", num_nodes) + + if self.all_node_ids.numel() < 1: + return + + max_node_id = int(self.all_node_ids.max().item()) + if max_node_id >= num_nodes: + raise ValueError( + f"'num_nodes' is too small for the hyperedge index. " + f"Got num_nodes={num_nodes}, but max node id is {max_node_id}." + ) + + def __validate_degree_matrix_dimension(self, name: str, value: int, expected: int) -> None: + validate_is_non_negative(name, value) + if value != expected: + raise ValueError( + f"'{name}' must match the incidence matrix dimension. " + f"Got {name}={value}, but expected {expected}." + ) diff --git a/hyperbench/utils/__init__.py b/hyperbench/utils/__init__.py index 9e17e99..ba40dea 100644 --- a/hyperbench/utils/__init__.py +++ b/hyperbench/utils/__init__.py @@ -5,6 +5,13 @@ empty_nodefeatures, to_non_empty_edgeattr, to_0based_ids, + validate_is_between, + validate_is_finite, + validate_is_finite_when_provided, + validate_is_non_empty, + validate_is_non_negative, + validate_is_positive, + validate_split_ratios, ) from .hif_utils import ( @@ -32,6 +39,7 @@ assign_hyperedge_label_to_nodes, is_inductive_setting, is_transductive_setting, + validate_node_space_setting, ) from .random_utils import create_seeded_torch_generator @@ -87,6 +95,14 @@ "validate_hif_data", "validate_hif_json", "validate_http_url", + "validate_is_between", + "validate_is_finite", + "validate_is_finite_when_provided", + "validate_is_non_empty", + "validate_is_non_negative", + "validate_is_positive", + "validate_node_space_setting", + "validate_split_ratios", "write_dataset_to_disk_as_zst", "write_zst_file_to_disk", ] diff --git a/hyperbench/utils/data_utils.py b/hyperbench/utils/data_utils.py index 66c49ef..a929453 100644 --- a/hyperbench/utils/data_utils.py +++ b/hyperbench/utils/data_utils.py @@ -1,5 +1,7 @@ +import math import torch +from collections.abc import Sequence from torch import Tensor @@ -8,15 +10,15 @@ def clone_optional_tensor(tensor: Tensor | None) -> Tensor | None: def empty_nodefeatures() -> Tensor: - return torch.empty((0, 0)) + return torch.empty((0, 0), dtype=torch.float) def empty_hyperedgeindex() -> Tensor: - return torch.empty((2, 0)) + return torch.empty((2, 0), dtype=torch.long) def empty_edgeattr(num_edges: int) -> Tensor: - return torch.empty((num_edges, 0)) + return torch.empty((num_edges, 0), dtype=torch.float) def to_non_empty_edgeattr(edge_attr: Tensor | None) -> Tensor: @@ -53,3 +55,54 @@ def to_0based_ids(original_ids: Tensor, ids_to_rebase: Tensor | None = None) -> ids_to_keep = original_ids[keep_mask] sorted_unique_ids_to_rebase = ids_to_rebase.unique(sorted=True) return torch.searchsorted(sorted_unique_ids_to_rebase, ids_to_keep) + + +def validate_is_between( + name: str, + value: int | float, + min_value: int | float, + max_value: int | float, +) -> None: + if not math.isfinite(value) or value < min_value or value > max_value: + raise ValueError(f"{name!r} must be between {min_value} and {max_value}, got {value}.") + + +def validate_is_finite(name: str, value: int | float) -> None: + if not math.isfinite(value): + raise ValueError(f"{name!r} must be finite, got {value}.") + + +def validate_is_finite_when_provided(name: str, value: int | float | None) -> None: + if value is not None and not math.isfinite(value): + raise ValueError(f"{name!r} must be finite when provided, got {value}.") + + +def validate_is_non_negative(name: str, value: int | float) -> None: + if value < 0: + raise ValueError(f"{name!r} must be non-negative, got {value}.") + + +def validate_is_positive(name: str, value: int | float) -> None: + if value <= 0: + raise ValueError(f"{name!r} must be positive, got {value}.") + + +def validate_is_non_empty(name: str, value: Sequence) -> None: + if len(value) < 1: + raise ValueError(f"{name!r} cannot be empty.") + + +def validate_split_ratios(ratios: list[int | float]) -> None: + validate_is_non_empty("ratios", ratios) + + for ratio in ratios: + validate_is_finite("ratios", ratio) + validate_is_positive("ratios", ratio) + + # Allow small imprecision in sum of ratios, but raise error if it's significant + # Example: ratios = [0.8, 0.1, 0.1] -> sum = 1.0 (valid) + # ratios = [0.8, 0.1, 0.05] -> sum = 0.95 (invalid, raises ValueError) + # ratios = [0.8, 0.1, 0.1, 0.0000001] -> sum = 1.0000001 (valid, allows small imprecision) + ratio_sum = float(sum(ratios)) + if abs(ratio_sum - 1.0) > 1e-6: + raise ValueError(f"'ratios' must sum to 1.0, got {ratio_sum}.") diff --git a/hyperbench/utils/node_utils.py b/hyperbench/utils/node_utils.py index ac3b50a..9e1203f 100644 --- a/hyperbench/utils/node_utils.py +++ b/hyperbench/utils/node_utils.py @@ -26,3 +26,21 @@ def is_inductive_setting(node_space_setting: NodeSpaceSetting | None) -> bool: def is_transductive_setting(node_space_setting: NodeSpaceSetting | None) -> bool: return node_space_setting == "transductive" + + +def validate_node_space_setting(node_space_setting: NodeSpaceSetting) -> None: + """ + Validate that the node space setting is one of the supported values. + + Args: + node_space_setting: The node space setting to validate, which should be either "inductive" or "transductive". + + Raises: + ValueError: If the node space setting is not one of the supported values. + """ + if is_transductive_setting(node_space_setting) or is_inductive_setting(node_space_setting): + return + + raise ValueError( + f"node_space_setting must be one of 'transductive' or 'inductive', got {node_space_setting!r}." + )