From eed414aea82e98f41b37dea863f658975387541a Mon Sep 17 00:00:00 2001 From: Tiziano Date: Mon, 25 May 2026 15:38:37 +0200 Subject: [PATCH 01/15] fix: add validation --- hyperbench/data/dataset.py | 6 +- hyperbench/data/hif.py | 5 +- hyperbench/tests/data/dataset_test.py | 49 ++- hyperbench/tests/data/loader_test.py | 6 +- hyperbench/tests/types/hdata_test.py | 359 +++++++++++++++++++++- hyperbench/tests/utils/data_utils_test.py | 12 +- hyperbench/types/hdata.py | 229 +++++++++++++- hyperbench/utils/data_utils.py | 6 +- 8 files changed, 609 insertions(+), 63 deletions(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 3d536a5a..9203a38f 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -210,9 +210,9 @@ def enrich_hyperedge_weights( """Enrich hyperedge weights using the provided hyperedge weight enricher. Args: - enricher: An instance of HyperedgeEnricher to generate structural hyperedge features from hypergraph topology. - enrichment_mode: How to combine generated features with existing ``hdata.hyperedge_weights``. - ``concatenate`` appends new features as additional columns. + enricher: An instance of HyperedgeEnricher to generate structural hyperedge weights from hypergraph topology. + enrichment_mode: How to combine generated weights with existing ``hdata.hyperedge_weights``. + ``concatenate`` appends new weights to the existing 1D tensor. ``replace`` substitutes ``hdata.hyperedge_weights`` entirely. 
""" self.hdata = self.hdata.enrich_hyperedge_weights(enricher, enrichment_mode) diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index 60036f46..3b5312f8 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -152,8 +152,7 @@ def __process_hyperedge_attr( hyperedge_id_to_idx: dict[Any, int], num_hyperedges: int, ) -> Tensor | None: - # hyperedge-attr: shape [num_hyperedges, num_hyperedge_attributes] - hyperedge_attr = None + hyperedge_attr = None # shape [num_hyperedges, num_hyperedge_attributes] has_hyperedges = hypergraph.hyperedges is not None and len(hypergraph.hyperedges) > 0 has_any_hyperedge_attrs = has_hyperedges and any( "attrs" in edge for edge in hypergraph.hyperedges @@ -234,7 +233,7 @@ def __process_hyperedge_weights( edge_attrs = hyperedge_id_to_attrs.get(edge_id, {}) weights.append(float(edge_attrs.get("weight", 1.0))) - return torch.tensor(weights, dtype=torch.float) + return torch.tensor(weights, dtype=torch.float) # shape [num_hyperedges,] class HIFLoader: diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index a0426911..547c0527 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -28,7 +28,7 @@ def mock_hdata() -> HData: def mock_hdata_with_hyperedge_attr() -> HData: x = torch.ones((3, 1), dtype=torch.float) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_attr = torch.ones((3, 1), dtype=torch.float) + hyperedge_attr = torch.ones((2, 1), dtype=torch.float) return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_attr=hyperedge_attr) @@ -36,7 +36,7 @@ def mock_hdata_with_hyperedge_attr() -> HData: def mock_hdata_with_hyperedge_weights() -> HData: x = torch.ones((3, 1), dtype=torch.float) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_weights = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float) + hyperedge_weights = torch.tensor([1.0, 2.0], dtype=torch.float) 
return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) @@ -51,7 +51,7 @@ def mock_hdata_isolated_hyperedges() -> HData: def mock_hdata_three_nodes_weighted() -> HData: x = torch.ones((3, 1), dtype=torch.float) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_weights = torch.tensor([[1.0], [2.0]], dtype=torch.float) + hyperedge_weights = torch.tensor([1.0, 2.0], dtype=torch.float) return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) @@ -86,8 +86,8 @@ def mock_hdata_no_hyperedge_attr() -> HData: def mock_hdata_multiple_hyperedge_attrs() -> HData: x = torch.ones((4, 1), dtype=torch.float) hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]], dtype=torch.long) - hyperedge_weights = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) - hyperedge_attr = torch.ones((4, 1), dtype=torch.float) + hyperedge_weights = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float) + hyperedge_attr = torch.ones((3, 1), dtype=torch.float) return HData( x=x, hyperedge_index=hyperedge_index, @@ -106,7 +106,7 @@ def mock_hdata_transductive_multiple_hyperedges_attrs() -> HData: ], dtype=torch.long, ) - hyperedge_weights = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) + hyperedge_weights = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float) hyperedge_attr = torch.ones((3, 1), dtype=torch.float) return HData( x=x, @@ -120,7 +120,7 @@ def mock_hdata_transductive_multiple_hyperedges_attrs() -> HData: def mock_hdata_two_hyperedge_attrs_weighted() -> HData: x = torch.ones((3, 1), dtype=torch.float) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_weights = torch.tensor([[1.0, 2.0], [3.0, 0.1]], dtype=torch.float) + hyperedge_weights = torch.tensor([1.0, 3.0], dtype=torch.float) return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) @@ -180,9 +180,8 @@ def 
test_dataset_process_with_edge_attributes(mock_hdata_two_hyperedge_attrs_wei assert dataset.hdata.hyperedge_index.shape[1] == 3 assert dataset.hdata.hyperedge_attr is None assert dataset.hdata.hyperedge_weights is not None - assert dataset.hdata.hyperedge_weights.shape == (2, 2) - assert torch.allclose(dataset.hdata.hyperedge_weights[0], torch.tensor([1.0, 2.0])) - assert torch.allclose(dataset.hdata.hyperedge_weights[1], torch.tensor([3.0, 0.1])) + assert dataset.hdata.hyperedge_weights.shape == (2,) + assert torch.allclose(dataset.hdata.hyperedge_weights, torch.tensor([1.0, 3.0])) def test_dataset_process_without_edge_attributes(mock_hdata_no_hyperedge_attr): @@ -450,7 +449,7 @@ def test_enrich_hyperedge_attr_replace(mock_hdata): dataset = Dataset.from_hdata(mock_hdata) enricher = MagicMock(spec=HyperedgeEnricher) - enriched_x = torch.randn(3, 4) + enriched_x = torch.randn(2, 4) enricher.enrich.return_value = enriched_x dataset.enrich_hyperedge_attr(enricher) @@ -468,7 +467,7 @@ def test_enrich_hyperedge_attr_concatenate(mock_hdata_with_hyperedge_attr): original_hyperedge_attr = original_hyperedge_attr.clone() enricher = MagicMock(spec=HyperedgeEnricher) - enriched_x = torch.randn(3, 4) + enriched_x = torch.randn(2, 4) enricher.enrich.return_value = enriched_x dataset.enrich_hyperedge_attr(enricher, enrichment_mode="concatenate") @@ -478,14 +477,14 @@ def test_enrich_hyperedge_attr_concatenate(mock_hdata_with_hyperedge_attr): hyperedge_attr = dataset.hdata.hyperedge_attr assert hyperedge_attr is not None assert torch.equal(hyperedge_attr, expected_x) - assert hyperedge_attr.shape == (3, 5) # 1 original + 4 enriched + assert hyperedge_attr.shape == (2, 5) # 1 original + 4 enriched def test_enrich_hyperedge_weights_replace(mock_hdata): dataset = Dataset.from_hdata(mock_hdata) enricher = MagicMock(spec=HyperedgeEnricher) - enriched_weights = torch.randn(3) + enriched_weights = torch.randn(2) enricher.enrich.return_value = enriched_weights 
dataset.enrich_hyperedge_weights(enricher) @@ -496,24 +495,24 @@ def test_enrich_hyperedge_weights_replace(mock_hdata): assert torch.equal(hyperedge_weights, enriched_weights) -def test_enrich_hyperedge_weights_concatenate(mock_hdata_with_hyperedge_weights): +def test_enrich_hyperedge_weights_concatenate( + mock_hdata_with_hyperedge_weights, +): dataset = Dataset.from_hdata(mock_hdata_with_hyperedge_weights) - original_weights = dataset.hdata.hyperedge_weights - assert original_weights is not None - original_weights = original_weights.clone() + dataset.hdata.hyperedge_index = torch.tensor([[0, 1, 2, 0], [0, 0, 1, 2]]) + dataset.hdata.num_hyperedges = 3 + dataset.hdata.y = torch.ones(3, dtype=torch.float) enricher = MagicMock(spec=HyperedgeEnricher) - enriched_weights = torch.randn(3) + enriched_weights = torch.tensor([3.0]) enricher.enrich.return_value = enriched_weights dataset.enrich_hyperedge_weights(enricher, enrichment_mode="concatenate") - enricher.enrich.assert_called_once_with(mock_hdata_with_hyperedge_weights.hyperedge_index) - expected_weights = torch.cat([original_weights, enriched_weights], dim=0) - hyperedge_weights = dataset.hdata.hyperedge_weights - assert hyperedge_weights is not None - assert torch.equal(hyperedge_weights, expected_weights) - assert hyperedge_weights.shape == (6,) # 3 original + 3 enriched + enricher.enrich.assert_called_once_with(dataset.hdata.hyperedge_index) + assert dataset.hdata.hyperedge_weights is not None + assert torch.equal(dataset.hdata.hyperedge_weights, torch.tensor([1.0, 2.0, 3.0])) + assert dataset.hdata.hyperedge_weights.shape == (3,) @pytest.mark.parametrize( diff --git a/hyperbench/tests/data/loader_test.py b/hyperbench/tests/data/loader_test.py index 9057f621..6aa9c79d 100644 --- a/hyperbench/tests/data/loader_test.py +++ b/hyperbench/tests/data/loader_test.py @@ -13,7 +13,7 @@ def mock_dataset_single_sample(): x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) hyperedge_index = torch.tensor([[0, 1, 1, 2], 
[0, 0, 1, 1]]) hyperedge_attr = torch.tensor([[0.5], [0.7]]) - hyperedge_weights = torch.tensor([[0.8], [0.9]]) + hyperedge_weights = torch.tensor([0.8, 0.9]) hdata = HData( x=x, hyperedge_index=hyperedge_index, @@ -36,7 +36,7 @@ def mock_dataset_single_sample_with_weights(): x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) hyperedge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]]) hyperedge_attr = torch.tensor([[0.5], [0.7]]) - hyperedge_weights = torch.tensor([[0.8], [0.9]]) + hyperedge_weights = torch.tensor([0.8, 0.9]) hdata = HData( x=x, hyperedge_index=hyperedge_index, @@ -122,7 +122,7 @@ def test_collate_single_sample_with_weights(mock_dataset_single_sample_with_weig expected_x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) expected_hyperedge_attr = torch.tensor([[0.5], [0.7]]) expected_hyeredge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]]) - expected_hyperedge_weights = torch.tensor([[0.8], [0.9]]) + expected_hyperedge_weights = torch.tensor([0.8, 0.9]) assert torch.equal(batched.x, expected_x) assert batched.hyperedge_index.shape == (2, 4) diff --git a/hyperbench/tests/types/hdata_test.py b/hyperbench/tests/types/hdata_test.py index 57e2be40..715d5d35 100644 --- a/hyperbench/tests/types/hdata_test.py +++ b/hyperbench/tests/types/hdata_test.py @@ -3,7 +3,7 @@ import torch from unittest.mock import MagicMock -from typing import cast +from typing import Any, cast from torch import Tensor from hyperbench import utils from hyperbench.data import HyperedgeEnricher, NegativeSampler, NodeEnricher, RandomNegativeSampler @@ -91,20 +91,32 @@ def sample(data: HData, seed: int | None = None) -> HData: "explicit_num_nodes, expected_num_nodes", [ pytest.param(None, 7, id="inferred_from_x"), - pytest.param(10, 10, id="explicit_overrides_x"), - pytest.param(0, 0, id="explicit_zero"), + pytest.param(10, 10, id="explicit_allows_isolated_nodes"), ], ) def test_init_num_nodes(explicit_num_nodes, expected_num_nodes): hyperedge_index = torch.tensor([[0, 1, 2, 
3, 4, 5, 6], [0, 0, 0, 0, 0, 0, 0]]) - num_nodes = hyperedge_index[0].size(0) - x = torch.randn(num_nodes, 3) + x = torch.randn(expected_num_nodes, 3) data = HData(x=x, hyperedge_index=hyperedge_index, num_nodes=explicit_num_nodes) assert data.num_nodes == expected_num_nodes +def test_init_raises_when_num_nodes_is_too_small_for_hyperedge_index(): + x = torch.randn(2, 2) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 0]]) + + with pytest.raises( + ValueError, + match=re.escape( + "num_nodes is too small for hyperedge_index. " + "Got num_nodes=2, but hyperedge_index contains 3 unique node IDs." + ), + ): + HData(x=x, hyperedge_index=hyperedge_index, num_nodes=2) + + @pytest.mark.parametrize( "hyperedge_index, explicit_num_hyperedges, expected_num_hyperedges", [ @@ -135,6 +147,19 @@ def test_init_num_hyperedges(hyperedge_index, explicit_num_hyperedges, expected_ assert data.num_hyperedges == expected_num_hyperedges +def test_init_raises_when_num_hyperedges_is_too_small_for_hyperedge_index(): + x = torch.randn(2, 2) + + with pytest.raises( + ValueError, + match=re.escape( + "num_hyperedges is too small for hyperedge_index. " + "Got num_hyperedges=2, but hyperedge_index contains 3 unique hyperedge IDs." 
+ ), + ): + HData(x=x, hyperedge_index=torch.tensor([[0, 1, 0], [0, 1, 2]]), num_hyperedges=2) + + def test_init_default_y_is_ones(): x = torch.randn(3, 2) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) @@ -170,6 +195,241 @@ def test_init_hyperedge_attr_defaults_to_none(): assert data.hyperedge_attr is None +@pytest.mark.parametrize( + "kwargs, expected_message", + [ + pytest.param( + {"x": torch.randn(3), "hyperedge_index": torch.tensor([[0, 1], [0, 0]])}, + "x must be a 2D tensor, got shape (3,).", + id="x_not_2d", + ), + pytest.param( + {"x": torch.randn(3, 2), "hyperedge_index": torch.tensor([0, 1])}, + "hyperedge_index must have shape (2, num_incidences), got (2,).", + id="hyperedge_index_not_2d", + ), + pytest.param( + {"x": torch.randn(3, 2), "hyperedge_index": torch.tensor([[0, 1, 2]])}, + "hyperedge_index must have shape (2, num_incidences), got (1, 3).", + id="hyperedge_index_wrong_rows", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0.0, 1.0], [0.0, 0.0]]), + }, + "hyperedge_index must have dtype torch.long, got torch.float32.", + id="hyperedge_index_not_long", + ), + pytest.param( + {"x": torch.randn(3, 2), "hyperedge_index": torch.tensor([[-1, 1], [0, 0]])}, + "hyperedge_index cannot contain negative node or hyperedge IDs.", + id="hyperedge_index_negative_id", + ), + pytest.param( + { + "x": torch.randn(2, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 0]]), + }, + ( + "x must have one feature row per node, or be 'torch.empty((0, 0))' if there are no " + "nodes. Got x.shape=(2, 2) but num_nodes=3." 
+ ), + id="x_rows_do_not_match_num_nodes", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 0]]), + "global_node_ids": torch.tensor([[0, 1, 2]]), + }, + "global_node_ids must be a 1D tensor, got shape (1, 3).", + id="global_node_ids_not_1d", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 0]]), + "global_node_ids": torch.tensor([0.0, 1.0, 2.0]), + }, + "global_node_ids must have dtype torch.long, got torch.float32.", + id="global_node_ids_not_long", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 0]]), + "global_node_ids": torch.tensor([0, 1]), + }, + "global_node_ids must have one entry per node. Got size=2 but num_nodes=3.", + id="global_node_ids_wrong_length", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "y": torch.tensor([[1.0, 0.0]]), + }, + "y must be a 1D tensor, got shape (1, 2).", + id="y_not_1d", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "y": torch.tensor([1.0]), + }, + "y must have one entry per hyperedge. Got 1 entries but num_hyperedges=2.", + id="y_wrong_length", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "hyperedge_weights": torch.tensor([[0.25, 0.75]]), + }, + "hyperedge_weights must be a 1D tensor, got shape (1, 2).", + id="hyperedge_weights_not_1d", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "hyperedge_weights": torch.tensor([0.25]), + }, + ( + "hyperedge_weights must have one entry per hyperedge. " + "Got size=1 but num_hyperedges=2." 
+ ), + id="hyperedge_weights_wrong_length", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "hyperedge_attr": torch.tensor([1.0, 2.0]), + }, + "hyperedge_attr must be a 2D tensor, got shape (2,).", + id="hyperedge_attr_not_2d", + ), + pytest.param( + { + "x": torch.randn(3, 2), + "hyperedge_index": torch.tensor([[0, 1, 2], [0, 0, 1]]), + "hyperedge_attr": torch.randn(1, 4), + }, + "hyperedge_attr must have one row per hyperedge. Got size=1 but num_hyperedges=2.", + id="hyperedge_attr_wrong_rows", + ), + ], +) +def test_init_validates_input_values(kwargs, expected_message): + with pytest.raises(ValueError, match=re.escape(expected_message)): + HData(**kwargs) + + +@pytest.mark.parametrize( + "kwargs, expected_message", + [ + pytest.param( + {"x": cast(Any, [[1.0], [2.0]]), "hyperedge_index": torch.tensor([[0], [0]])}, + "x must be a torch.Tensor.", + id="x_not_tensor", + ), + pytest.param( + {"x": torch.randn(2, 1), "hyperedge_index": cast(Any, [[0], [0]])}, + "hyperedge_index must be a torch.Tensor.", + id="hyperedge_index_not_tensor", + ), + pytest.param( + { + "x": torch.randn(2, 1), + "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), + "hyperedge_attr": cast(Any, [[1.0]]), + }, + "hyperedge_attr must be a torch.Tensor.", + id="hyperedge_attr_not_tensor", + ), + pytest.param( + { + "x": torch.randn(2, 1), + "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), + "hyperedge_weights": cast(Any, [1.0]), + }, + "hyperedge_weights must be a torch.Tensor.", + id="hyperedge_weights_not_tensor", + ), + pytest.param( + { + "x": torch.randn(2, 1), + "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), + "global_node_ids": cast(Any, [0, 1]), + }, + "global_node_ids must be a torch.Tensor.", + id="global_node_ids_not_tensor", + ), + pytest.param( + { + "x": torch.randn(2, 1), + "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), + "y": cast(Any, [1.0]), + }, + "y must be a torch.Tensor.", + id="y_not_tensor", + ), + 
pytest.param( + { + "x": torch.randn(2, 1), + "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), + "num_nodes": cast(Any, 2.0), + }, + "num_nodes must be an int.", + id="num_nodes_not_int", + ), + pytest.param( + { + "x": torch.randn(2, 1), + "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), + "num_hyperedges": cast(Any, True), + }, + "num_hyperedges must be an int.", + id="num_hyperedges_bool", + ), + ], +) +def test_init_validates_runtime_types(kwargs, expected_message): + with pytest.raises(TypeError, match=re.escape(expected_message)): + HData(**kwargs) + + +@pytest.mark.parametrize( + "kwargs, expected_message", + [ + pytest.param( + { + "x": torch.empty((0, 1)), + "hyperedge_index": torch.empty((2, 0), dtype=torch.long), + "num_nodes": -1, + }, + "num_nodes must be non-negative, got -1.", + id="negative_num_nodes", + ), + pytest.param( + { + "x": torch.empty((0, 1)), + "hyperedge_index": torch.empty((2, 0), dtype=torch.long), + "num_hyperedges": -1, + }, + "num_hyperedges must be non-negative, got -1.", + id="negative_num_hyperedges", + ), + ], +) +def test_init_validates_non_negative_number_of_nodes_and_hyperedges(kwargs, expected_message): + with pytest.raises(ValueError, match=re.escape(expected_message)): + HData(**kwargs) + + def test_repr_contains_class_name_and_fields(mock_hdata): r = repr(mock_hdata) @@ -441,25 +701,85 @@ def test_cat_same_node_space_concatenates_labels(): def test_cat_same_node_space_uses_largest_x_when_not_provided(): x_large = torch.randn(3, 1) x_small = torch.randn(2, 1) - hdata1 = HData(x=x_large, hyperedge_index=torch.tensor([[0, 1, 2], [0, 0, 0]])) + global_node_ids = torch.tensor([10, 20, 30]) + hdata1 = HData( + x=x_large, + hyperedge_index=torch.tensor([[0, 1, 2], [0, 0, 0]]), + global_node_ids=global_node_ids, + ) hdata2 = HData(x=x_small, hyperedge_index=torch.tensor([[0, 2], [1, 1]])) expected_hyperedge_index = torch.tensor([[0, 1, 2, 0, 2], [0, 0, 0, 1, 1]]) result = HData.cat_same_node_space([hdata1, hdata2]) assert 
torch.equal(result.x, x_large) + assert result.global_node_ids is not None + assert torch.equal(result.global_node_ids, global_node_ids) assert torch.equal(result.hyperedge_index, expected_hyperedge_index) -def test_cat_same_node_space_uses_provided_x(): +def test_cat_same_node_space_uses_provided_x_and_global_node_ids(): x = torch.randn(2, 4) hdata1 = HData(x=x, hyperedge_index=torch.tensor([[0, 1], [0, 0]])) hdata2 = HData(x=x, hyperedge_index=torch.tensor([[2, 3], [1, 1]])) custom_x = torch.randn(4, 4) - result = HData.cat_same_node_space([hdata1, hdata2], x=custom_x) + custom_global_node_ids = torch.tensor([10, 20, 30, 40]) + result = HData.cat_same_node_space( + hdatas=[hdata1, hdata2], + x=custom_x, + global_node_ids=custom_global_node_ids, + ) assert torch.equal(result.x, custom_x) + assert result.global_node_ids is not None + assert torch.equal(result.global_node_ids, custom_global_node_ids) + + +def test_cat_same_node_space_raises_when_only_x_is_provided(): + x = torch.randn(2, 4) + hdata1 = HData(x=x, hyperedge_index=torch.tensor([[0, 1], [0, 0]])) + hdata2 = HData(x=x, hyperedge_index=torch.tensor([[2, 3], [1, 1]])) + + with pytest.raises( + ValueError, + match=re.escape( + "If x is provided, global_node_ids must also be provided to ensure consistency." + ), + ): + HData.cat_same_node_space([hdata1, hdata2], x=torch.randn(4, 4)) + + +def test_cat_same_node_space_raises_when_only_global_node_ids_are_provided(): + x = torch.randn(2, 4) + hdata1 = HData(x=x, hyperedge_index=torch.tensor([[0, 1], [0, 0]])) + hdata2 = HData(x=x, hyperedge_index=torch.tensor([[2, 3], [1, 1]])) + + with pytest.raises( + ValueError, + match=re.escape( + "If global_node_ids is provided, x must also be provided to ensure consistency." 
+ ), + ): + HData.cat_same_node_space([hdata1, hdata2], global_node_ids=torch.arange(4)) + + +def test_cat_same_node_space_validates_global_node_ids_alignment(): + x = torch.randn(2, 4) + hdata1 = HData(x=x, hyperedge_index=torch.tensor([[0, 1], [0, 0]])) + hdata2 = HData(x=x, hyperedge_index=torch.tensor([[2, 3], [1, 1]])) + + with pytest.raises( + ValueError, + match=re.escape( + "global_node_ids must have one entry per node. Got size=3 but num_nodes=4." + ), + ): + HData.cat_same_node_space( + [hdata1, hdata2], + x=torch.randn(4, 4), + global_node_ids=torch.arange(3), + ) def test_cat_same_node_space_concatenates_hyperedge_attr(): @@ -1223,6 +1543,29 @@ def test_enrich_hyperedge_weights_concatenate(): assert torch.equal(utils.to_non_empty_edgeattr(result.hyperedge_weights), enriched_weights) +def test_enrich_hyperedge_weights_concatenate_after_hyperedge_index_expansion(): + x = torch.tensor([[1.0], [2.0], [3.0]]) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) + hdata = HData( + x=x, + hyperedge_index=hyperedge_index, + hyperedge_weights=torch.tensor([0.1, 0.2]), + ) + hdata.hyperedge_index = torch.tensor([[0, 1, 2, 0], [0, 0, 1, 2]]) + hdata.num_hyperedges = 3 + hdata.y = torch.ones(3, dtype=torch.float) + + enricher = MagicMock(spec=HyperedgeEnricher) + enricher.enrich.return_value = torch.tensor([0.7]) + + result = hdata.enrich_hyperedge_weights(enricher, enrichment_mode="concatenate") + + enricher.enrich.assert_called_once_with(hdata.hyperedge_index) + assert torch.equal( + utils.to_non_empty_edgeattr(result.hyperedge_weights), torch.tensor([0.1, 0.2, 0.7]) + ) + + def test_enrich_hyperedge_attr_replace(): x = torch.tensor([[1.0], [2.0], [3.0]]) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) diff --git a/hyperbench/tests/utils/data_utils_test.py b/hyperbench/tests/utils/data_utils_test.py index 94226147..7e5c5082 100644 --- a/hyperbench/tests/utils/data_utils_test.py +++ b/hyperbench/tests/utils/data_utils_test.py @@ -37,35 +37,39 @@ def 
test_clone_optional_tensor_with_tensor_does_not_share_storage(): assert tensor[0, 0] == 1.0 -def test_empty_edgeindex(): +def test_empty_hyperedgeindex(): result = empty_hyperedgeindex() assert result.shape == (2, 0) - assert result.dtype == torch.float32 + assert result.dtype == torch.long def test_empty_edgeattr_zero_edges(): result = empty_edgeattr(num_edges=0) assert result.shape == (0, 0) + assert result.dtype == torch.float def test_empty_edgeattr_single_edge(): result = empty_edgeattr(num_edges=1) assert result.shape == (1, 0) + assert result.dtype == torch.float def test_empty_edgeattr_with_edges(): result = empty_edgeattr(num_edges=5) assert result.shape == (5, 0) + assert result.dtype == torch.float def test_to_non_empty_edgeattr_with_none(): result = to_non_empty_edgeattr(edge_attr=None) assert result.shape == (0, 0) + assert result.dtype == torch.float def test_to_non_empty_edgeattr_with_tensor(): @@ -74,6 +78,7 @@ def test_to_non_empty_edgeattr_with_tensor(): assert torch.equal(result, edge_attr) assert result.shape == (3, 1) + assert result.dtype == torch.float def test_to_non_empty_edgeattr_with_empty_tensor(): @@ -82,6 +87,7 @@ def test_to_non_empty_edgeattr_with_empty_tensor(): assert torch.equal(result, edge_attr) assert result.shape == (0, 3) + assert result.dtype == torch.float def test_to_non_empty_edgeattr_with_multi_dimensional(): @@ -90,12 +96,14 @@ def test_to_non_empty_edgeattr_with_multi_dimensional(): assert torch.equal(result, edge_attr) assert result.shape == (2, 3) + assert result.dtype == torch.float def test_empty_nodefeatures(): result = empty_nodefeatures() assert result.shape == (0, 0) + assert result.dtype == torch.float @pytest.mark.parametrize( diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index 17137482..9b88187b 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -63,26 +63,24 @@ def __init__( y: Tensor | None = None, ): self.x: Tensor = x - self.hyperedge_index: Tensor = 
hyperedge_index + self.__validate_required_tensors_type_and_dim() self.hyperedge_weights: Tensor | None = hyperedge_weights - self.hyperedge_attr: Tensor | None = hyperedge_attr hyperedge_index_wrapper = HyperedgeIndex(hyperedge_index) - self.num_nodes: int = ( num_nodes if num_nodes is not None # There should never be isolated nodes when HData is created by Dataset - # as each isolted node gets its own self-loop hyperedge + # as each isolated node gets its own self-loop hyperedge else hyperedge_index_wrapper.num_nodes_if_isolated_exist(num_nodes=x.size(0)) ) - self.num_hyperedges: int = ( num_hyperedges if num_hyperedges is not None else hyperedge_index_wrapper.num_hyperedges ) + self.__validate_number_of_nodes_and_hyperedges() self.global_node_ids: Tensor | None = ( # torch.arange is to handle isolated nodes, as they are already considered @@ -96,6 +94,8 @@ def __init__( else torch.ones((self.num_hyperedges,), dtype=torch.float, device=self.x.device) ) + self.__validate() + self.device = self.get_device_if_all_consistent() def __repr__(self) -> str: @@ -114,16 +114,28 @@ def __repr__(self) -> str: ) @classmethod - def cat_same_node_space(cls, hdatas: Sequence[HData], x: Tensor | None = None) -> HData: + def cat_same_node_space( + cls, + hdatas: Sequence[HData], + x: Tensor | None = None, + global_node_ids: Tensor | None = None, + ) -> HData: """ Concatenate `HData` instances that share the same node space, meaning nodes with the same ID in different instances are the same node. This is useful when combining positive and negative hyperedges that reference the same set of nodes. Notes: - - ``x`` is derived from the instance with the largest number of nodes, if not provided explicitly. If there are conflicting features for the same node ID across instances, the features from the instance with the largest number of nodes will be used. + - ``x`` is derived from the instance with the largest number of nodes, if not provided explicitly. 
+ If there are conflicting features for the same node ID across instances, + the features from the instance with the largest number of nodes will be used. + If ``global_node_ids`` is provided explicitly, ``x`` must also be provided to ensure consistency. - ``hyperedge_index`` is the concatenation of all input hyperedge indices. - - ``hyperedge_weights`` is the concatenation of all input hyperedge weights, if present. If some instances have hyperedge weights and others do not, the resulting ``hyperedge_weights`` will be set to ``None``. - - ``hyperedge_attr`` is the concatenation of all input hyperedge attributes, if present. If some instances have hyperedge attributes and others do not, the resulting ``hyperedge_attr`` will be set to ``None``. + - ``hyperedge_weights`` is the concatenation of all input hyperedge weights, if present. + If some instances have hyperedge weights and others do not, the resulting ``hyperedge_weights`` will be set to ``None``. + - ``hyperedge_attr`` is the concatenation of all input hyperedge attributes, if present. + If some instances have hyperedge attributes and others do not, the resulting ``hyperedge_attr`` will be set to ``None``. + - ``global_node_ids`` is derived from the instance with the largest number of nodes, if not provided explicitly. + If ``x`` is provided explicitly, ``global_node_ids`` must be provided explicitly as well to ensure consistency. - ``y`` is the concatenation of all input labels. Examples: @@ -138,12 +150,20 @@ def cat_same_node_space(cls, hdatas: Sequence[HData], x: Tensor | None = None) - hdatas: One or more `HData` instances sharing the same node space. x: Optional node feature matrix to use for the resulting `HData`. If ``None``, the node features from the instance with the largest number of nodes will be used. + If ``global_node_ids`` is provided explicitly, ``x`` must also be provided to ensure consistency. + global_node_ids: Optional global node IDs for the resulting `HData`. 
+ If ``None``, the global node IDs from the instance with the largest number of nodes will be used. + If ``x`` is provided explicitly, ``global_node_ids`` must also be provided to ensure consistency. + If ``x`` is provided and there is no need for ``global_node_ids`` to preserve access to the canonical node space, + it is recommended to use arbitrary global node IDs that are consistent with the feature rows of ``x``. + For example, ``global_node_ids=torch.arange(x.size(0))``. Returns: hdata: A new `HData` with shared nodes and concatenated hyperedges. Raises: - ValueError: If the node counts do not match across inputs. + ValueError: If no HData instances are provided, if there are overlapping hyperedge IDs across instances, + or if ``x`` and ``global_node_ids`` are not both provided when one of them is provided. """ if len(hdatas) < 1: raise ValueError("At least one instance is required.") @@ -155,8 +175,14 @@ def cat_same_node_space(cls, hdatas: Sequence[HData], x: Tensor | None = None) - "Overlapping hyperedge IDs found across instances. Ensure each instance uses distinct hyperedge IDs." 
) + cls.__validate_can_perform_cat_same_node_space(hdatas, x, global_node_ids) hdata_with_largest_node_space = max(hdatas, key=lambda hdata: hdata.num_nodes) - new_x = (x if x is not None else hdata_with_largest_node_space.x).clone() + new_x = (x.clone() if x is not None else hdata_with_largest_node_space.x).clone() + new_global_node_ids = ( + global_node_ids.clone() + if global_node_ids is not None + else hdata_with_largest_node_space.global_node_ids.clone() + ) new_y = torch.cat([hdata.y for hdata in hdatas], dim=0) new_hyperedge_index = torch.cat([hdata.hyperedge_index for hdata in hdatas], dim=1) @@ -181,7 +207,7 @@ def cat_same_node_space(cls, hdatas: Sequence[HData], x: Tensor | None = None) - hyperedge_attr=new_hyperedge_attr, num_nodes=new_x.size(0), num_hyperedges=new_y.size(0), - global_node_ids=clone_optional_tensor(hdata_with_largest_node_space.global_node_ids), + global_node_ids=new_global_node_ids, y=new_y, ) @@ -213,7 +239,7 @@ def empty(cls) -> HData: hyperedge_attr=None, num_nodes=0, num_hyperedges=0, - global_node_ids=torch.empty(size=(0, 0), dtype=torch.long), + global_node_ids=torch.empty(size=(0,), dtype=torch.long), y=None, ) @@ -511,13 +537,17 @@ def enrich_hyperedge_weights( enricher: HyperedgeEnricher, enrichment_mode: EnrichmentMode | None = None, ) -> HData: - """Enrich hyperedge weights using the provided hyperedge weight enricher. + """ + Enrich hyperedge weights using the provided hyperedge weight enricher. Args: enricher: An instance of HyperedgeEnricher to generate hyperedge weights from hypergraph topology. enrichment_mode: How to combine generated weights with existing ``hdata.hyperedge_weights``. ``concatenate`` appends new weights to the existing 1D tensor. ``replace`` substitutes ``hdata.hyperedge_weights`` entirely. + + Returns: + hdata: A new `HData` with enriched hyperedge weights. 
""" enriched_weights = enricher.enrich(self.hyperedge_index) @@ -775,11 +805,21 @@ def with_y_to(self, value: float) -> HData: ) def with_y_ones(self) -> HData: - """Return a copy of this instance with a y attribute of all ones.""" + """ + Return a copy of this instance with a y attribute of all ones. + + Returns: + hdata: A new `HData` instance with the same attributes except for y, which is set to a tensor of ones. + """ return self.with_y_to(1.0) def with_y_zeros(self) -> HData: - """Return a copy of this instance with a y attribute of all zeros.""" + """ + Return a copy of this instance with a y attribute of all zeros. + + Returns: + hdata: A new `HData` instance with the same attributes except for y, which is set to a tensor of zeros. + """ return self.with_y_to(0.0) def stats(self) -> dict[str, Any]: @@ -883,6 +923,32 @@ def stats(self) -> dict[str, Any]: "distribution_hyperedge_size_hist": distribution_hyperedge_size_hist, } + @classmethod + def __validate_can_perform_cat_same_node_space( + cls, + hdatas: Sequence[HData], + x: Tensor | None, + global_node_ids: Tensor | None, + ) -> None: + if len(hdatas) < 1: + raise ValueError("At least one instance is required.") + + if x is not None and global_node_ids is None: + raise ValueError( + "If x is provided, global_node_ids must also be provided to ensure consistency." + ) + if x is None and global_node_ids is not None: + raise ValueError( + "If global_node_ids is provided, x must also be provided to ensure consistency." + ) + + joint_hyperedge_ids = torch.cat([hdata.hyperedge_index[1].unique() for hdata in hdatas]) + unique_joint_hyperedge_ids = joint_hyperedge_ids.unique() + if unique_joint_hyperedge_ids.size(0) != joint_hyperedge_ids.size(0): + raise ValueError( + "Overlapping hyperedge IDs found across instances. Ensure each instance uses distinct hyperedge IDs." 
+ ) + def __to_fill_features( self, fill_value: NodeSpaceFiller | None, @@ -915,6 +981,110 @@ def __to_fill_features( ) return fill_features + def __validate(self) -> None: + self.__validate_node_features() + self.__validate_hyperedge_index() + self.__validate_hyperedge_attr() + self.__validate_hyperedge_weights() + self.__validate_global_node_ids() + self.__validate_labels() + + def __validate_hyperedge_attr(self) -> None: + if self.hyperedge_attr is None: + return + + if not isinstance(self.hyperedge_attr, Tensor): + raise TypeError("hyperedge_attr must be a torch.Tensor.") + + if self.hyperedge_attr.dim() != 2: + raise ValueError( + f"hyperedge_attr must be a 2D tensor, got shape {tuple(self.hyperedge_attr.shape)}." + ) + if self.hyperedge_attr.size(0) != self.num_hyperedges: + raise ValueError( + "hyperedge_attr must have one row per hyperedge. " + f"Got size={self.hyperedge_attr.size(0)} but num_hyperedges={self.num_hyperedges}." + ) + + def __validate_hyperedge_index(self) -> None: + if self.hyperedge_index.dtype != torch.long: + raise ValueError( + f"hyperedge_index must have dtype torch.long, got {self.hyperedge_index.dtype}." + ) + if self.hyperedge_index.numel() > 0 and bool((self.hyperedge_index < 0).any()): + raise ValueError("hyperedge_index cannot contain negative node or hyperedge IDs.") + + unique_node_count = self.hyperedge_index[0].unique().size(0) + if unique_node_count > self.num_nodes: + raise ValueError( + "num_nodes is too small for hyperedge_index. " + f"Got num_nodes={self.num_nodes}, but hyperedge_index contains " + f"{unique_node_count} unique node IDs." + ) + + unique_hyperedge_count = self.hyperedge_index[1].unique().size(0) + if unique_hyperedge_count > self.num_hyperedges: + raise ValueError( + "num_hyperedges is too small for hyperedge_index. " + f"Got num_hyperedges={self.num_hyperedges}, but hyperedge_index contains " + f"{unique_hyperedge_count} unique hyperedge IDs." 
+ ) + + def __validate_hyperedge_weights(self) -> None: + if self.hyperedge_weights is None: + return + + if not isinstance(self.hyperedge_weights, Tensor): + raise TypeError("hyperedge_weights must be a torch.Tensor.") + + if self.hyperedge_weights.dim() != 1: + raise ValueError( + f"hyperedge_weights must be a 1D tensor, got shape {tuple(self.hyperedge_weights.shape)}." + ) + if self.hyperedge_weights.size(0) != self.num_hyperedges: + raise ValueError( + "hyperedge_weights must have one entry per hyperedge. " + f"Got size={self.hyperedge_weights.size(0)} but num_hyperedges={self.num_hyperedges}." + ) + + def __validate_global_node_ids(self) -> None: + if not isinstance(self.global_node_ids, Tensor): + raise TypeError("global_node_ids must be a torch.Tensor.") + + if self.global_node_ids.dim() != 1: + raise ValueError( + f"global_node_ids must be a 1D tensor, got shape {tuple(self.global_node_ids.shape)}." + ) + if self.global_node_ids.size(0) != self.num_nodes: + raise ValueError( + "global_node_ids must have one entry per node. " + f"Got size={self.global_node_ids.size(0)} but num_nodes={self.num_nodes}." + ) + + if self.global_node_ids.dtype != torch.long: + raise ValueError( + f"global_node_ids must have dtype torch.long, got {self.global_node_ids.dtype}." + ) + + def __validate_labels(self) -> None: + if not isinstance(self.y, Tensor): + raise TypeError("y must be a torch.Tensor.") + + if self.y.dim() != 1: + raise ValueError(f"y must be a 1D tensor, got shape {tuple(self.y.shape)}.") + if self.y.size(0) != self.num_hyperedges: + raise ValueError( + "y must have one entry per hyperedge. " + f"Got {self.y.size(0)} entries but num_hyperedges={self.num_hyperedges}." + ) + + def __validate_node_features(self) -> None: + if self.x.size(0) not in (0, self.num_nodes): + raise ValueError( + "x must have one feature row per node, or be 'torch.empty((0, 0))' if there are no nodes. " + f"Got x.shape={tuple(self.x.shape)} but num_nodes={self.num_nodes}." 
+ ) + def __validate_node_space_setting( self, node_space_setting: NodeSpaceSetting, @@ -926,3 +1096,30 @@ def __validate_node_space_setting( ) if is_inductive_setting(node_space_setting) and fill_value is None: raise ValueError("fill_value must be provided when node_space_setting='inductive'.") + + def __validate_number_of_nodes_and_hyperedges(self) -> None: + # Check on bool as bool is a subclass of int + if not isinstance(self.num_nodes, int) or isinstance(self.num_nodes, bool): + raise TypeError("num_nodes must be an int.") + if self.num_nodes < 0: + raise ValueError(f"num_nodes must be non-negative, got {self.num_nodes}.") + + if not isinstance(self.num_hyperedges, int) or isinstance(self.num_hyperedges, bool): + raise TypeError("num_hyperedges must be an int.") + if self.num_hyperedges < 0: + raise ValueError(f"num_hyperedges must be non-negative, got {self.num_hyperedges}.") + + def __validate_required_tensors_type_and_dim(self) -> None: + if not isinstance(self.x, Tensor): + raise TypeError("x must be a torch.Tensor.") + if not isinstance(self.hyperedge_index, Tensor): + raise TypeError("hyperedge_index must be a torch.Tensor.") + + if self.x.dim() != 2: + raise ValueError(f"x must be a 2D tensor, got shape {tuple(self.x.shape)}.") + + if self.hyperedge_index.dim() != 2 or self.hyperedge_index.size(0) != 2: + raise ValueError( + "hyperedge_index must have shape (2, num_incidences), got " + f"{tuple(self.hyperedge_index.shape)}." 
+ ) diff --git a/hyperbench/utils/data_utils.py b/hyperbench/utils/data_utils.py index 66c49ef5..e94fefaa 100644 --- a/hyperbench/utils/data_utils.py +++ b/hyperbench/utils/data_utils.py @@ -8,15 +8,15 @@ def clone_optional_tensor(tensor: Tensor | None) -> Tensor | None: def empty_nodefeatures() -> Tensor: - return torch.empty((0, 0)) + return torch.empty((0, 0), dtype=torch.float) def empty_hyperedgeindex() -> Tensor: - return torch.empty((2, 0)) + return torch.empty((2, 0), dtype=torch.long) def empty_edgeattr(num_edges: int) -> Tensor: - return torch.empty((num_edges, 0)) + return torch.empty((num_edges, 0), dtype=torch.float) def to_non_empty_edgeattr(edge_attr: Tensor | None) -> Tensor: From ac5fae1a5217da3ce609d46fffaedfab5e99dbfc Mon Sep 17 00:00:00 2001 From: Tiziano Date: Mon, 25 May 2026 15:40:27 +0200 Subject: [PATCH 02/15] docs: fix docstring --- hyperbench/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 9203a38f..bef0fbdd 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -212,7 +212,7 @@ def enrich_hyperedge_weights( Args: enricher: An instance of HyperedgeEnricher to generate structural hyperedge weights from hypergraph topology. enrichment_mode: How to combine generated weights with existing ``hdata.hyperedge_weights``. - ``concatenate`` appends new weights to the existing 1D tensor. + ``concatenate`` appends new weights to the existing ones. ``replace`` substitutes ``hdata.hyperedge_weights`` entirely. 
""" self.hdata = self.hdata.enrich_hyperedge_weights(enricher, enrichment_mode) From 7fb680acca69b940b7860624435b18021c7f21c2 Mon Sep 17 00:00:00 2001 From: Tiziano Date: Mon, 25 May 2026 15:40:57 +0200 Subject: [PATCH 03/15] docs: fix docstring --- hyperbench/data/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index bef0fbdd..9debcc4f 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -195,8 +195,8 @@ def enrich_hyperedge_attr( """Enrich hyperedge features using the provided hyperedge feature enricher. Args: - enricher: An instance of HyperedgeEnricher to generate structural hyperedge features from hypergraph topology. - enrichment_mode: How to combine generated features with existing ``hdata.hyperedge_attr``. + enricher: An instance of HyperedgeEnricher to generate structural hyperedge attributes from hypergraph topology. + enrichment_mode: How to combine generated attributes with existing ``hdata.hyperedge_attr``. ``concatenate`` appends new features as additional columns. ``replace`` substitutes ``hdata.hyperedge_attr`` entirely. """ From d64f56ca12e6c3ac46cdc0369e92fcc4473bdb07 Mon Sep 17 00:00:00 2001 From: Tiziano Date: Mon, 25 May 2026 15:41:54 +0200 Subject: [PATCH 04/15] docs: fix docstring --- hyperbench/data/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 9debcc4f..d549cef6 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -146,7 +146,7 @@ def enrich_node_features( Args: enricher: An instance of NodeEnricher to generate structural node features from hypergraph topology. enrichment_mode: How to combine generated features with existing ``hdata.x``. - ``concatenate`` appends new features as additional columns. + ``concatenate`` appends new features to the existing ones as additional columns. 
``replace`` substitutes ``hdata.x`` entirely. """ self.hdata = self.hdata.enrich_node_features(enricher, enrichment_mode) @@ -197,7 +197,7 @@ def enrich_hyperedge_attr( Args: enricher: An instance of HyperedgeEnricher to generate structural hyperedge attributes from hypergraph topology. enrichment_mode: How to combine generated attributes with existing ``hdata.hyperedge_attr``. - ``concatenate`` appends new features as additional columns. + ``concatenate`` appends new attributes to the existing ones as additional columns. ``replace`` substitutes ``hdata.hyperedge_attr`` entirely. """ self.hdata = self.hdata.enrich_hyperedge_attr(enricher, enrichment_mode) @@ -212,7 +212,7 @@ def enrich_hyperedge_weights( Args: enricher: An instance of HyperedgeEnricher to generate structural hyperedge weights from hypergraph topology. enrichment_mode: How to combine generated weights with existing ``hdata.hyperedge_weights``. - ``concatenate`` appends new weights to the existing ones. + ``concatenate`` appends new weights to the existing ones as additional columns. ``replace`` substitutes ``hdata.hyperedge_weights`` entirely. 
""" self.hdata = self.hdata.enrich_hyperedge_weights(enricher, enrichment_mode) From 2fc679b23f54fb2d9d10e85b913079b29effeb75 Mon Sep 17 00:00:00 2001 From: Tiziano Date: Mon, 25 May 2026 15:54:13 +0200 Subject: [PATCH 05/15] fix: add missing unit tests for Dataset --- hyperbench/tests/data/dataset_test.py | 117 ++++++++++++++++++++------ 1 file changed, 89 insertions(+), 28 deletions(-) diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 547c0527..c0dbea22 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -75,13 +75,6 @@ def mock_hdata_transductive_split() -> HData: return HData(x=x, hyperedge_index=hyperedge_index) -@pytest.fixture -def mock_hdata_no_hyperedge_attr() -> HData: - x = torch.ones((2, 1), dtype=torch.float) - hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) - return HData(x=x, hyperedge_index=hyperedge_index) - - @pytest.fixture def mock_hdata_multiple_hyperedge_attrs() -> HData: x = torch.ones((4, 1), dtype=torch.float) @@ -168,7 +161,7 @@ def test_dataset_process_no_incidences(mock_hdata_isolated_hyperedges): assert dataset.hdata.hyperedge_attr is None -def test_dataset_process_with_edge_attributes(mock_hdata_two_hyperedge_attrs_weighted): +def test_dataset_process_with_hyperedge_attributes(mock_hdata_two_hyperedge_attrs_weighted): with patch.object( HIFLoader, "load_by_name", return_value=mock_hdata_two_hyperedge_attrs_weighted ): @@ -184,16 +177,32 @@ def test_dataset_process_with_edge_attributes(mock_hdata_two_hyperedge_attrs_wei assert torch.allclose(dataset.hdata.hyperedge_weights, torch.tensor([1.0, 3.0])) -def test_dataset_process_without_edge_attributes(mock_hdata_no_hyperedge_attr): - with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_hyperedge_attr): +def test_dataset_process_without_hyperedge_attributes(mock_hdata): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata): dataset = 
AlgebraDataset() assert dataset.hdata is not None - assert dataset.hdata.hyperedge_index.shape[0] == 2 - assert dataset.hdata.hyperedge_index.shape[1] == 2 assert dataset.hdata.hyperedge_attr is None +def test_dataset_process_with_hyperedge_weights(mock_hdata_with_hyperedge_weights): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_with_hyperedge_weights): + dataset = AlgebraDataset() + + assert dataset.hdata is not None + assert dataset.hdata.hyperedge_weights is not None + assert dataset.hdata.hyperedge_weights.shape == (2,) + assert torch.allclose(dataset.hdata.hyperedge_weights, torch.tensor([1.0, 2.0])) + + +def test_dataset_process_without_hyperedge_weights(mock_hdata): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata): + dataset = AlgebraDataset() + + assert dataset.hdata is not None + assert dataset.hdata.hyperedge_weights is None + + def test_dataset_process_hyperedge_index_in_correct_format(mock_hdata_four_nodes): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): dataset = AlgebraDataset() @@ -318,7 +327,7 @@ def test_getitem_when_list_index_provided( pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), ], ) -def test_getitem_with_edge_attr(mock_hdata_three_nodes_weighted, strategy): +def test_getitem_with_hyperedge_attr(mock_hdata_three_nodes_weighted, strategy): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_three_nodes_weighted): dataset = AlgebraDataset(sampling_strategy=strategy) @@ -336,8 +345,8 @@ def test_getitem_with_edge_attr(mock_hdata_three_nodes_weighted, strategy): pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), ], ) -def test_getitem_without_edge_attr(mock_hdata_no_hyperedge_attr, strategy): - with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_hyperedge_attr): +def test_getitem_without_hyperedge_attr(mock_hdata, strategy): + with patch.object(HIFLoader, "load_by_name", 
return_value=mock_hdata): dataset = AlgebraDataset(sampling_strategy=strategy) data = dataset[0] @@ -353,7 +362,7 @@ def test_getitem_without_edge_attr(mock_hdata_no_hyperedge_attr, strategy): pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], id="hyperedge_strategy"), ], ) -def test_getitem_with_multiple_edges_attr(mock_hdata_multiple_hyperedge_attrs, strategy, index): +def test_getitem_with_multiple_hyperedge_attr(mock_hdata_multiple_hyperedge_attrs, strategy, index): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_multiple_hyperedge_attrs): dataset = AlgebraDataset(sampling_strategy=strategy) @@ -609,10 +618,10 @@ def test_split_three_way(mock_hdata_multiple_hyperedge_attrs): dataset = AlgebraDataset() splits = dataset.split([0.5, 0.25, 0.25], node_space_setting="inductive") - total_edges = sum(split.hdata.num_hyperedges for split in splits) + total_hyperedges = sum(split.hdata.num_hyperedges for split in splits) assert len(splits) == 3 - assert total_edges == dataset.hdata.num_hyperedges + assert total_hyperedges == dataset.hdata.num_hyperedges for split in splits: assert split.hdata.x is not None @@ -631,10 +640,10 @@ def test_split_transductive_three_way( dataset = AlgebraDataset() splits = dataset.split([0.5, 0.25, 0.25], node_space_setting="transductive") - total_edges = sum(split.hdata.num_hyperedges for split in splits) + total_hyperedges = sum(split.hdata.num_hyperedges for split in splits) assert len(splits) == 3 - assert total_edges == dataset.hdata.num_hyperedges + assert total_hyperedges == dataset.hdata.num_hyperedges for split in splits: assert split.hdata.x is not None @@ -736,10 +745,10 @@ def test_split_with_shuffle_when_no_seed_provided( dataset = AlgebraDataset() splits = dataset.split([0.5, 0.5], shuffle=True, node_space_setting="inductive") - total_edges = sum(split.hdata.num_hyperedges for split in splits) + total_hyperedges = sum(split.hdata.num_hyperedges for split in splits) assert len(splits) == 2 - assert total_edges 
== dataset.hdata.num_hyperedges + assert total_hyperedges == dataset.hdata.num_hyperedges for split in splits: assert split.hdata.x is not None @@ -758,10 +767,10 @@ def test_split_transductive_with_shuffle_when_no_seed_provided( dataset = AlgebraDataset() splits = dataset.split([0.5, 0.5], shuffle=True, node_space_setting="transductive") - total_edges = sum(split.hdata.num_hyperedges for split in splits) + total_hyperedges = sum(split.hdata.num_hyperedges for split in splits) assert len(splits) == 2 - assert total_edges == dataset.hdata.num_hyperedges + assert total_hyperedges == dataset.hdata.num_hyperedges for split in splits: assert split.hdata.x is not None @@ -769,7 +778,7 @@ def test_split_transductive_with_shuffle_when_no_seed_provided( assert split.hdata.num_hyperedges > 0 -def test_split_preserves_edge_attr(mock_hdata_multiple_hyperedge_attrs): +def test_split_preserves_hyperedge_attr(mock_hdata_multiple_hyperedge_attrs): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_multiple_hyperedge_attrs): dataset = AlgebraDataset() @@ -780,7 +789,7 @@ def test_split_preserves_edge_attr(mock_hdata_multiple_hyperedge_attrs): assert split.hdata.hyperedge_attr.shape[0] == split.hdata.num_hyperedges -def test_split_transductive_preserves_edge_attr( +def test_split_transductive_preserves_hyperedge_attr( mock_hdata_transductive_multiple_hyperedges_attrs, ): with patch.object( @@ -797,7 +806,7 @@ def test_split_transductive_preserves_edge_attr( assert split.hdata.hyperedge_attr.shape[0] == split.hdata.num_hyperedges -def test_split_without_edge_attr(mock_hdata_four_nodes): +def test_split_without_hyperedge_attr(mock_hdata_four_nodes): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): dataset = AlgebraDataset() @@ -807,7 +816,7 @@ def test_split_without_edge_attr(mock_hdata_four_nodes): assert split.hdata.hyperedge_attr is None -def test_split_transductive_without_edge_attr(mock_hdata_transductive_split): +def 
test_split_transductive_without_hyperedge_attr(mock_hdata_transductive_split): with patch.object( HIFLoader, "load_by_name", @@ -821,6 +830,58 @@ def test_split_transductive_without_edge_attr(mock_hdata_transductive_split): assert split.hdata.hyperedge_attr is None +def test_split_preserves_hyperedge_weights(mock_hdata_multiple_hyperedge_attrs): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_multiple_hyperedge_attrs): + dataset = AlgebraDataset() + + splits = dataset.split([0.5, 0.5], node_space_setting="inductive") + + for split in splits: + assert split.hdata.hyperedge_weights is not None + assert split.hdata.hyperedge_weights.shape[0] == split.hdata.num_hyperedges + + +def test_split_transductive_preserves_hyperedge_weights( + mock_hdata_transductive_multiple_hyperedges_attrs, +): + with patch.object( + HIFLoader, + "load_by_name", + return_value=mock_hdata_transductive_multiple_hyperedges_attrs, + ): + dataset = AlgebraDataset() + + splits = dataset.split([0.5, 0.5], node_space_setting="transductive") + + for split in splits: + assert split.hdata.hyperedge_weights is not None + assert split.hdata.hyperedge_weights.shape[0] == split.hdata.num_hyperedges + + +def test_split_without_hyperedge_weights(mock_hdata_four_nodes): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): + dataset = AlgebraDataset() + + splits = dataset.split([0.5, 0.5], node_space_setting="inductive") + + for split in splits: + assert split.hdata.hyperedge_weights is None + + +def test_split_transductive_without_hyperedge_weights(mock_hdata_transductive_split): + with patch.object( + HIFLoader, + "load_by_name", + return_value=mock_hdata_transductive_split, + ): + dataset = AlgebraDataset() + + splits = dataset.split([0.5, 0.5], node_space_setting="transductive") + + for split in splits: + assert split.hdata.hyperedge_weights is None + + def test_to_device(mock_hdata): device = torch.device("cpu") From 
af0f422a17fc6f4e45b318df1ecac89d5408f185 Mon Sep 17 00:00:00 2001 From: Tiziano Date: Mon, 25 May 2026 16:05:59 +0200 Subject: [PATCH 06/15] fix: clean unit tests for Dataset --- hyperbench/tests/data/dataset_test.py | 73 +++++++++++++++++---------- 1 file changed, 45 insertions(+), 28 deletions(-) diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index c0dbea22..c6b22449 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -47,14 +47,6 @@ def mock_hdata_isolated_hyperedges() -> HData: return HData(x=x, hyperedge_index=hyperedge_index) -@pytest.fixture -def mock_hdata_three_nodes_weighted() -> HData: - x = torch.ones((3, 1), dtype=torch.float) - hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_weights = torch.tensor([1.0, 2.0], dtype=torch.float) - return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) - - @pytest.fixture def mock_hdata_four_nodes() -> HData: x = torch.ones((4, 1), dtype=torch.float) @@ -109,14 +101,6 @@ def mock_hdata_transductive_multiple_hyperedges_attrs() -> HData: ) -@pytest.fixture -def mock_hdata_two_hyperedge_attrs_weighted() -> HData: - x = torch.ones((3, 1), dtype=torch.float) - hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_weights = torch.tensor([1.0, 3.0], dtype=torch.float) - return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) - - @pytest.fixture def mock_negative_sampler() -> tuple[NegativeSampler, MagicMock]: sampler = new_mock_negative_sampler() @@ -161,20 +145,14 @@ def test_dataset_process_no_incidences(mock_hdata_isolated_hyperedges): assert dataset.hdata.hyperedge_attr is None -def test_dataset_process_with_hyperedge_attributes(mock_hdata_two_hyperedge_attrs_weighted): - with patch.object( - HIFLoader, "load_by_name", return_value=mock_hdata_two_hyperedge_attrs_weighted - ): +def 
test_dataset_process_with_hyperedge_attributes(mock_hdata_with_hyperedge_attr): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_with_hyperedge_attr): dataset = AlgebraDataset() assert dataset.hdata is not None - assert dataset.hdata.x.shape[0] == 3 + assert dataset.hdata.hyperedge_attr is not None assert dataset.hdata.hyperedge_index.shape[0] == 2 - assert dataset.hdata.hyperedge_index.shape[1] == 3 - assert dataset.hdata.hyperedge_attr is None - assert dataset.hdata.hyperedge_weights is not None - assert dataset.hdata.hyperedge_weights.shape == (2,) - assert torch.allclose(dataset.hdata.hyperedge_weights, torch.tensor([1.0, 3.0])) + assert torch.allclose(dataset.hdata.hyperedge_attr, torch.ones((2, 1), dtype=torch.float)) def test_dataset_process_without_hyperedge_attributes(mock_hdata): @@ -327,14 +305,17 @@ def test_getitem_when_list_index_provided( pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), ], ) -def test_getitem_with_hyperedge_attr(mock_hdata_three_nodes_weighted, strategy): - with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_three_nodes_weighted): +def test_getitem_with_hyperedge_attr(mock_hdata_with_hyperedge_attr, strategy): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_with_hyperedge_attr): dataset = AlgebraDataset(sampling_strategy=strategy) data = dataset[0] assert data.hyperedge_index.shape == (2, 2) assert data.num_hyperedges == 1 + + # Even though the original hypergraph has edge attributes, __getitem__ should return hyperedge_attr as None + # as the hyperedge attributes are handled by the loader's collate function during batching assert data.hyperedge_attr is None @@ -374,6 +355,42 @@ def test_getitem_with_multiple_hyperedge_attr(mock_hdata_multiple_hyperedge_attr assert data.hyperedge_attr is None +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, 
id="hyperedge_strategy"), + ], +) +def test_getitem_with_hyperedge_weights(mock_hdata_with_hyperedge_weights, strategy): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_with_hyperedge_weights): + dataset = AlgebraDataset(sampling_strategy=strategy) + + data = dataset[0] + + assert data.hyperedge_index.shape == (2, 2) + assert data.num_hyperedges == 1 + + # Even though the original hypergraph has edge attributes, __getitem__ should return hyperedge_weights as None + # as the hyperedge weights are handled by the loader's collate function during batching + assert data.hyperedge_weights is None + + +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + ], +) +def test_getitem_without_hyperedge_weights(mock_hdata, strategy): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata): + dataset = AlgebraDataset(sampling_strategy=strategy) + + data = dataset[0] + assert data.hyperedge_weights is None + + @pytest.mark.parametrize( "strategy, expected_len", [ From e131219d25c1087c3c877a7d34a756c9618ba97b Mon Sep 17 00:00:00 2001 From: Tiziano Date: Mon, 25 May 2026 16:13:07 +0200 Subject: [PATCH 07/15] fix: apply minor fixes --- hyperbench/tests/types/hdata_test.py | 2 ++ hyperbench/types/hdata.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/hyperbench/tests/types/hdata_test.py b/hyperbench/tests/types/hdata_test.py index 715d5d35..4dcc2e49 100644 --- a/hyperbench/tests/types/hdata_test.py +++ b/hyperbench/tests/types/hdata_test.py @@ -708,6 +708,7 @@ def test_cat_same_node_space_uses_largest_x_when_not_provided(): global_node_ids=global_node_ids, ) hdata2 = HData(x=x_small, hyperedge_index=torch.tensor([[0, 2], [1, 1]])) + expected_hyperedge_index = torch.tensor([[0, 1, 2, 0, 2], [0, 0, 0, 1, 1]]) result = HData.cat_same_node_space([hdata1, hdata2]) @@ -725,6 +726,7 @@ def 
test_cat_same_node_space_uses_provided_x_and_global_node_ids(): custom_x = torch.randn(4, 4) custom_global_node_ids = torch.tensor([10, 20, 30, 40]) + result = HData.cat_same_node_space( hdatas=[hdata1, hdata2], x=custom_x, diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index 9b88187b..6b1b3a22 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -64,7 +64,7 @@ def __init__( ): self.x: Tensor = x self.hyperedge_index: Tensor = hyperedge_index - self.__validate_required_tensors_type_and_dim() + self.__validate_x_and_hyperedge_index_type_and_dim() self.hyperedge_weights: Tensor | None = hyperedge_weights self.hyperedge_attr: Tensor | None = hyperedge_attr @@ -982,7 +982,7 @@ def __to_fill_features( return fill_features def __validate(self) -> None: - self.__validate_node_features() + self.__validate_x() self.__validate_hyperedge_index() self.__validate_hyperedge_attr() self.__validate_hyperedge_weights() @@ -1078,7 +1078,7 @@ def __validate_labels(self) -> None: f"Got {self.y.size(0)} entries but num_hyperedges={self.num_hyperedges}." ) - def __validate_node_features(self) -> None: + def __validate_x(self) -> None: if self.x.size(0) not in (0, self.num_nodes): raise ValueError( "x must have one feature row per node, or be 'torch.empty((0, 0))' if there are no nodes. 
" @@ -1109,7 +1109,7 @@ def __validate_number_of_nodes_and_hyperedges(self) -> None: if self.num_hyperedges < 0: raise ValueError(f"num_hyperedges must be non-negative, got {self.num_hyperedges}.") - def __validate_required_tensors_type_and_dim(self) -> None: + def __validate_x_and_hyperedge_index_type_and_dim(self) -> None: if not isinstance(self.x, Tensor): raise TypeError("x must be a torch.Tensor.") if not isinstance(self.hyperedge_index, Tensor): From 87642be6d5cca5f3a727fe3be36321db6af38890 Mon Sep 17 00:00:00 2001 From: Tiziano Date: Mon, 25 May 2026 16:17:38 +0200 Subject: [PATCH 08/15] fix: HData.empty() returns None global_node_ids --- hyperbench/tests/types/hdata_test.py | 11 ++++++++--- hyperbench/types/hdata.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/hyperbench/tests/types/hdata_test.py b/hyperbench/tests/types/hdata_test.py index 4dcc2e49..49dca383 100644 --- a/hyperbench/tests/types/hdata_test.py +++ b/hyperbench/tests/types/hdata_test.py @@ -4,7 +4,6 @@ from unittest.mock import MagicMock from typing import Any, cast -from torch import Tensor from hyperbench import utils from hyperbench.data import HyperedgeEnricher, NegativeSampler, NodeEnricher, RandomNegativeSampler from hyperbench.types import HData @@ -452,17 +451,23 @@ def test_empty_returns_empty_hdata(): data = HData.empty() assert data.x is not None - assert isinstance(data.x, Tensor) assert data.x.shape == (0, 0) assert data.hyperedge_index is not None - assert isinstance(data.hyperedge_index, Tensor) assert data.hyperedge_index.shape == (2, 0) assert data.hyperedge_attr is None + assert data.hyperedge_weights is None + assert data.num_nodes == 0 assert data.num_hyperedges == 0 + assert data.global_node_ids is not None + assert data.global_node_ids.shape == (0,) + + assert data.y is not None + assert data.y.shape == (0,) + @pytest.mark.parametrize( "hyperedge_index, expected_num_nodes, expected_num_hyperedges", diff --git a/hyperbench/types/hdata.py 
b/hyperbench/types/hdata.py index 6b1b3a22..b2526908 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -239,7 +239,7 @@ def empty(cls) -> HData: hyperedge_attr=None, num_nodes=0, num_hyperedges=0, - global_node_ids=torch.empty(size=(0,), dtype=torch.long), + global_node_ids=None, y=None, ) From fa0bcc23c6c213b5a55228fd85e5a92d5ec3b86b Mon Sep 17 00:00:00 2001 From: Tiziano Date: Mon, 25 May 2026 18:21:55 +0200 Subject: [PATCH 09/15] fix: remove None for global_node_ids --- hyperbench/data/dataset.py | 35 ++- hyperbench/data/loader.py | 5 +- .../data/negative_sampling_scheduler.py | 38 ++-- hyperbench/data/sampler.py | 18 +- hyperbench/hlp/common.py | 6 +- hyperbench/models/nhp.py | 4 +- hyperbench/tests/data/dataset_test.py | 61 +++--- hyperbench/tests/data/loader_test.py | 3 +- .../data/negative_sampling_scheduler_test.py | 62 +++++- hyperbench/tests/data/sampler_test.py | 45 ++++ hyperbench/tests/types/hdata_test.py | 200 ++++++++---------- hyperbench/tests/types/hypergraph_test.py | 11 + hyperbench/tests/utils/node_utils_test.py | 32 +++ hyperbench/types/hdata.py | 90 ++++---- hyperbench/types/hypergraph.py | 7 +- hyperbench/utils/__init__.py | 2 + hyperbench/utils/node_utils.py | 18 ++ 17 files changed, 412 insertions(+), 225 deletions(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index d549cef6..fb325ebc 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math import torch from typing import TYPE_CHECKING, Any @@ -10,6 +11,7 @@ NodeSpaceFiller, NodeSpaceSetting, is_transductive_setting, + validate_node_space_setting, ) from hyperbench.data.hif import HIFLoader, HIFProcessor @@ -148,6 +150,7 @@ def enrich_node_features( enrichment_mode: How to combine generated features with existing ``hdata.x``. ``concatenate`` appends new features to the existing ones as additional columns. ``replace`` substitutes ``hdata.x`` entirely. 
+ Defaults to ``replace`` if not provided. """ self.hdata = self.hdata.enrich_node_features(enricher, enrichment_mode) @@ -199,6 +202,7 @@ def enrich_hyperedge_attr( enrichment_mode: How to combine generated attributes with existing ``hdata.hyperedge_attr``. ``concatenate`` appends new attributes to the existing ones as additional columns. ``replace`` substitutes ``hdata.hyperedge_attr`` entirely. + Defaults to ``replace`` if not provided. """ self.hdata = self.hdata.enrich_hyperedge_attr(enricher, enrichment_mode) @@ -214,6 +218,7 @@ def enrich_hyperedge_weights( enrichment_mode: How to combine generated weights with existing ``hdata.hyperedge_weights``. ``concatenate`` appends new weights to the existing ones as additional columns. ``replace`` substitutes ``hdata.hyperedge_weights`` entirely. + Defaults to ``replace`` if not provided. """ self.hdata = self.hdata.enrich_hyperedge_weights(enricher, enrichment_mode) @@ -325,7 +330,8 @@ def split_with_ratios( seed: int | None = None, node_space_setting: NodeSpaceSetting = "transductive", ) -> tuple[list[Dataset], list[float]]: - """Split the dataset and return the final hyperedge ratios. + """ + Split the dataset and return the final hyperedge ratios. Final ratios are computed from split hyperedge counts after ratio boundaries and any transductive rebalancing have been applied. @@ -350,12 +356,8 @@ def split_with_ratios( hyperedges, or a transductive first split cannot cover the full node space. 
""" - # Allow small imprecision in sum of ratios, but raise error if it's significant - # Example: ratios = [0.8, 0.1, 0.1] -> sum = 1.0 (valid) - # ratios = [0.8, 0.1, 0.05] -> sum = 0.95 (invalid, raises ValueError) - # ratios = [0.8, 0.1, 0.1, 0.0000001] -> sum = 1.0000001 (valid, allows small imprecision) - if abs(sum(ratios) - 1.0) > 1e-6: - raise ValueError(f"Split ratios must sum to 1.0, got {sum(ratios)}.") + validate_node_space_setting(node_space_setting) + self.__validate_split_ratios(ratios) device = self.hdata.device hyperedge_splitter = HyperedgeIDSplitter(self.hdata) @@ -445,3 +447,22 @@ def stats(self) -> dict[str, Any]: """ return self.hdata.stats() + + @staticmethod + def __validate_split_ratios(ratios: list[float]) -> None: + if len(ratios) < 1: + raise ValueError("Split ratios cannot be empty.") + + for ratio in ratios: + if not math.isfinite(float(ratio)): + raise ValueError(f"Split ratios must be finite, got {ratio}.") + if ratio <= 0.0: + raise ValueError(f"Split ratios must be positive, got {ratio}.") + + # Allow small imprecision in sum of ratios, but raise error if it's significant + # Example: ratios = [0.8, 0.1, 0.1] -> sum = 1.0 (valid) + # ratios = [0.8, 0.1, 0.05] -> sum = 0.95 (invalid, raises ValueError) + # ratios = [0.8, 0.1, 0.1, 0.0000001] -> sum = 1.0000001 (valid, allows small imprecision) + ratio_sum = float(sum(ratios)) + if abs(ratio_sum - 1.0) > 1e-6: + raise ValueError(f"Split ratios must sum to 1.0, got {ratio_sum}.") diff --git a/hyperbench/data/loader.py b/hyperbench/data/loader.py index 9c705537..b99a5a97 100644 --- a/hyperbench/data/loader.py +++ b/hyperbench/data/loader.py @@ -83,10 +83,7 @@ def collate(self, batch: list[HData]) -> HData: collated_x = self.__cached_dataset_hdata.x[node_ids] collated_y = self.__cached_dataset_hdata.y[hyperedge_ids] - - collated_global_node_ids = None - if self.__cached_dataset_hdata.global_node_ids is not None: - collated_global_node_ids = 
self.__cached_dataset_hdata.global_node_ids[node_ids] + collated_global_node_ids = self.__cached_dataset_hdata.global_node_ids[node_ids] collated_hyperedge_attr = None if self.__cached_dataset_hdata.hyperedge_attr is not None: diff --git a/hyperbench/data/negative_sampling_scheduler.py b/hyperbench/data/negative_sampling_scheduler.py index d76d3526..68c0fa8d 100644 --- a/hyperbench/data/negative_sampling_scheduler.py +++ b/hyperbench/data/negative_sampling_scheduler.py @@ -1,15 +1,13 @@ -from enum import Enum -from typing import Any +from typing import Any, Literal, TypeAlias from hyperbench.types import HData from hyperbench.data import NegativeSampler -class NegativeSamplingSchedule(Enum): - """When to run negative sampling during training.""" - - FIRST_EPOCH = "first_epoch" # Only at epoch 0, cached for all subsequent epochs - EVERY_N_EPOCHS = "every_n_epochs" # Every N epochs (N provided separately) - EVERY_EPOCH = "every_epoch" # Negatives generated every epoch +NegativeSamplingSchedule: TypeAlias = Literal[ + "first_epoch", # Only at epoch 0, cached for all subsequent epochs + "every_n_epochs", # Every N epochs (N provided separately) + "every_epoch", # Negatives generated every epoch +] class NegativeSamplingScheduler: @@ -21,14 +19,15 @@ class NegativeSamplingScheduler: Args: negative_sampler: An instance of a ``NegativeSampler`` that defines how to sample negatives. - negative_sampling_schedule: An instance of ``NegativeSamplingSchedule`` that specifies the schedule for sampling negatives. - negative_sampling_every_n: An integer specifying the interval for sampling negatives when the schedule is set to ``EVERY_N_EPOCHS``. This parameter is ignored for other schedules. + negative_sampling_schedule: Literal string specifying the schedule for sampling negatives. + negative_sampling_every_n: An integer specifying the interval for sampling negatives + when the schedule is set to ``"every_n_epochs"``. This parameter is ignored for other schedules. 
""" def __init__( self, negative_sampler: NegativeSampler, - negative_sampling_schedule: NegativeSamplingSchedule = NegativeSamplingSchedule.EVERY_EPOCH, + negative_sampling_schedule: NegativeSamplingSchedule = "every_epoch", negative_sampling_every_n: int = 1, ) -> None: self.negative_sampler = negative_sampler @@ -56,13 +55,24 @@ def should_sample(self, epoch: int) -> bool: Returns: should_sample: True if negatives should be resampled for the current epoch, False otherwise. """ + if epoch < 0: + raise ValueError(f"Epoch must be non-negative, got {epoch}.") + match self.negative_sampling_schedule: - case NegativeSamplingSchedule.EVERY_N_EPOCHS: + case "every_n_epochs": + if self.negative_sampling_every_n <= 0: + raise ValueError( + f"negative_sampling_every_n must be positive, got {self.negative_sampling_every_n}." + ) return epoch % self.negative_sampling_every_n == 0 - case NegativeSamplingSchedule.FIRST_EPOCH: + case "first_epoch": return epoch == 0 - case _: # Defaults to NegativeSamplingSchedule.EVERY_EPOCH + case "every_epoch": return True + case _: + raise ValueError( + f"Unsupported negative sampling schedule: {self.negative_sampling_schedule!r}." + ) def sample(self, batch: HData, epoch: int) -> HData: """ diff --git a/hyperbench/data/sampler.py b/hyperbench/data/sampler.py index be496e98..e855b658 100644 --- a/hyperbench/data/sampler.py +++ b/hyperbench/data/sampler.py @@ -49,6 +49,7 @@ def _normalize_index(self, index: int | list[int], size: int) -> list[int]: Raises: ValueError: If the provided index is invalid (e.g., empty list or list length exceeds number of sampleable items). + TypeError: If the index is not an integer or a list of integers. """ if isinstance(index, list): if len(index) < 1: @@ -57,7 +58,15 @@ def _normalize_index(self, index: int | list[int], size: int) -> list[int]: raise ValueError( f"Index list length ({len(index)}) cannot exceed the number of sampleable items ({size})." 
) + for id in index: + if not isinstance(id, int) or isinstance(id, bool): + raise TypeError("Index list must contain only integers.") + return list(set(index)) + + if not isinstance(index, int) or isinstance(index, bool): + raise TypeError("Index must be an integer or a list of integers.") + return [index] def _sample_hyperedge_index( @@ -244,10 +253,15 @@ def create_sampler_from_strategy(strategy: SamplingStrategy) -> BaseSampler: strategy: An instance of SamplingStrategy enum indicating which sampling strategy to use. Returns: - sampler: An instance of a subclass of BaseSampler corresponding to the specified strategy. If strategy is not recognized, defaults to ``HyperedgeSampler``. + sampler: An instance of a subclass of BaseSampler corresponding to the specified strategy. + + Raises: + ValueError: If ``strategy`` is not a supported `SamplingStrategy`. """ match strategy: case SamplingStrategy.NODE: return NodeSampler() - case _: + case SamplingStrategy.HYPEREDGE: return HyperedgeSampler() + case _: + raise ValueError(f"Unsupported sampling strategy: {strategy!r}.") diff --git a/hyperbench/hlp/common.py b/hyperbench/hlp/common.py index 08c7b345..870a222e 100644 --- a/hyperbench/hlp/common.py +++ b/hyperbench/hlp/common.py @@ -20,8 +20,8 @@ class HlpModule(L.LightningModule): metrics: Optional ``MetricCollection`` of torchmetrics to compute during evaluation. Cloned per stage (train, val, test) for independent state accumulation. negative_sampler: Optional negative sampler. If ``None``, no negative sampling is performed. - negative_sampling_schedule: When to perform negative sampling during training. Defaults to ``EVERY_EPOCH``. - negative_sampling_every_n: If using ``EVERY_N_EPOCHS`` schedule, how many epochs between negative sampling runs. Defaults to ``1``. + negative_sampling_schedule: When to perform negative sampling during training. Defaults to ``"every_epoch"``. 
+ negative_sampling_every_n: If using ``"every_n_epochs"`` schedule, how many epochs between negative sampling runs. Defaults to ``1``. """ def __init__( @@ -31,7 +31,7 @@ def __init__( encoder: nn.Module | None = None, metrics: MetricCollection | None = None, negative_sampler: NegativeSampler | None = None, - negative_sampling_schedule: NegativeSamplingSchedule = NegativeSamplingSchedule.EVERY_EPOCH, + negative_sampling_schedule: NegativeSamplingSchedule = "every_epoch", negative_sampling_every_n: int = 1, ): super().__init__() diff --git a/hyperbench/models/nhp.py b/hyperbench/models/nhp.py index 549f41af..89de057e 100644 --- a/hyperbench/models/nhp.py +++ b/hyperbench/models/nhp.py @@ -188,8 +188,10 @@ def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor: max_embeddings = incidence_aggregator.pool("max") min_embeddings = incidence_aggregator.pool("min") hyperedge_embeddings = max_embeddings - min_embeddings - case _: + case "mean": hyperedge_embeddings = incidence_aggregator.pool("mean") + case _: + raise ValueError(f"Invalid aggregation method: {self.aggregation}") # Decode: linear projection to scalar score per hyperedge # shape: (num_hyperedges, 1) -> squeeze -> (num_hyperedges,) diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index c6b22449..73598791 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -2,7 +2,7 @@ import torch import re -from typing import cast +from typing import Any, cast from unittest.mock import patch, MagicMock from hyperbench.types import HData from hyperbench.data import ( @@ -691,6 +691,41 @@ def test_split_raises_when_ratios_do_not_sum_to_one(mock_hdata_four_nodes): dataset.split([0.8, 0.1, 0.05]) +@pytest.mark.parametrize( + "ratios, expected_exception, expected_message", + [ + pytest.param([], ValueError, "Split ratios cannot be empty.", id="empty"), + pytest.param( + [0.5, 0.0, 0.5], ValueError, "Split ratios must be positive, got 
0.0.", id="zero" + ), + pytest.param( + [0.5, float("inf")], ValueError, "Split ratios must be finite, got inf.", id="infinite" + ), + ], +) +def test_split_validates_ratio_values( + mock_hdata_four_nodes, ratios, expected_exception, expected_message +): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): + dataset = AlgebraDataset() + + with pytest.raises(expected_exception, match=re.escape(expected_message)): + dataset.split(cast(Any, ratios)) + + +def test_split_raises_on_invalid_node_space_setting(mock_hdata_four_nodes): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): + dataset = AlgebraDataset() + + with pytest.raises( + ValueError, + match=re.escape( + "node_space_setting must be one of 'transductive' or 'inductive', got 'semi'." + ), + ): + dataset.split([0.5, 0.5], node_space_setting=cast(Any, "semi")) + + def test_split_raises_when_a_split_has_zero_hyperedges(mock_hdata_four_nodes): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): dataset = AlgebraDataset() @@ -1155,30 +1190,6 @@ def test_enrich_node_features_from_dataset(): assert torch.equal(target_dataset.hdata.x, torch.tensor([[3.0, 30.0], [1.0, 10.0]])) -def test_enrich_node_features_from_propagates_hdata_validation_errors(): - source_dataset = Dataset.from_hdata( - HData( - x=torch.tensor([[1.0], [2.0]]), - hyperedge_index=torch.tensor([[0, 1], [0, 0]]), - global_node_ids=torch.tensor([10, 20]), - ) - ) - target_dataset = Dataset.from_hdata( - HData( - x=torch.tensor([[0.0]]), - hyperedge_index=torch.tensor([[0], [0]]), - global_node_ids=torch.tensor([10]), - ) - ) - target_dataset.hdata.global_node_ids = None - - with pytest.raises( - ValueError, - match=re.escape("Both HData instances must define global_node_ids to align node features."), - ): - target_dataset.enrich_node_features_from(source_dataset) - - def test_enrich_node_features_from_dataset_with_fill_value(): source_dataset = 
Dataset.from_hdata( HData( diff --git a/hyperbench/tests/data/loader_test.py b/hyperbench/tests/data/loader_test.py index 6aa9c79d..5e56c110 100644 --- a/hyperbench/tests/data/loader_test.py +++ b/hyperbench/tests/data/loader_test.py @@ -308,8 +308,7 @@ def test_collate_when_dataset_no_hyperedge_attr_presence(): def test_collate_when_dataset_has_no_global_node_ids(): x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) - hdata = HData(x=x, hyperedge_index=hyperedge_index) - hdata.global_node_ids = None + hdata = HData(x=x, hyperedge_index=hyperedge_index, global_node_ids=None) sample0 = HData.from_hyperedge_index(torch.tensor([[0, 1], [0, 0]])) sample1 = HData.from_hyperedge_index(torch.tensor([[2], [1]])) diff --git a/hyperbench/tests/data/negative_sampling_scheduler_test.py b/hyperbench/tests/data/negative_sampling_scheduler_test.py index ffab739b..8c99f54b 100644 --- a/hyperbench/tests/data/negative_sampling_scheduler_test.py +++ b/hyperbench/tests/data/negative_sampling_scheduler_test.py @@ -1,6 +1,8 @@ import pytest import torch +import re +from typing import Any, cast from unittest.mock import MagicMock from hyperbench.data import NegativeSampler, NegativeSamplingSchedule, NegativeSamplingScheduler from hyperbench.types import HData @@ -34,7 +36,7 @@ def mock_sampler(mock_negative_hdata): def test_config_returns_scheduler_parameters(mock_sampler): - schedule = NegativeSamplingSchedule.EVERY_N_EPOCHS + schedule: NegativeSamplingSchedule = "every_n_epochs" scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, negative_sampling_schedule=schedule, @@ -50,7 +52,7 @@ def test_config_returns_scheduler_parameters(mock_sampler): def test_sample_caches_result_across_non_sampling_epochs(mock_sampler, mock_batch): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.FIRST_EPOCH, + negative_sampling_schedule="first_epoch", ) # 
Epoch 0: should sample @@ -65,7 +67,7 @@ def test_sample_caches_result_across_non_sampling_epochs(mock_sampler, mock_batc def test_sample_delegates_to_negative_sampler(mock_sampler, mock_batch, mock_negative_hdata): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.EVERY_EPOCH, + negative_sampling_schedule="every_epoch", ) result = scheduler.sample(mock_batch, epoch=0) @@ -77,7 +79,7 @@ def test_sample_delegates_to_negative_sampler(mock_sampler, mock_batch, mock_neg def test_sample_raises_when_cache_is_empty(mock_sampler, mock_batch): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.EVERY_N_EPOCHS, + negative_sampling_schedule="every_n_epochs", negative_sampling_every_n=5, ) @@ -89,7 +91,7 @@ def test_sample_raises_when_cache_is_empty(mock_sampler, mock_batch): def test_sample_resamples_on_every_n_epoch(mock_sampler, mock_batch): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.EVERY_N_EPOCHS, + negative_sampling_schedule="every_n_epochs", negative_sampling_every_n=3, ) @@ -114,7 +116,7 @@ def test_sample_resamples_on_every_n_epoch(mock_sampler, mock_batch): def test_should_sample_every_epoch(mock_sampler, epoch, expected_should_sample): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.EVERY_EPOCH, + negative_sampling_schedule="every_epoch", ) assert scheduler.should_sample(epoch) == expected_should_sample @@ -137,7 +139,7 @@ def test_should_sample_every_epoch(mock_sampler, epoch, expected_should_sample): def test_should_sample_every_n_epochs(mock_sampler, epoch, every_n, expected_should_sample): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.EVERY_N_EPOCHS, + 
negative_sampling_schedule="every_n_epochs", negative_sampling_every_n=every_n, ) @@ -156,7 +158,51 @@ def test_should_sample_every_n_epochs(mock_sampler, epoch, every_n, expected_sho def test_should_sample_first_epoch(mock_sampler, epoch, expected_should_sample): scheduler = NegativeSamplingScheduler( negative_sampler=mock_sampler, - negative_sampling_schedule=NegativeSamplingSchedule.FIRST_EPOCH, + negative_sampling_schedule="first_epoch", ) assert scheduler.should_sample(epoch) == expected_should_sample + + +def test_should_sample_rejects_invalid_epoch(mock_sampler): + scheduler = NegativeSamplingScheduler(negative_sampler=mock_sampler) + + with pytest.raises(ValueError, match=re.escape("Epoch must be non-negative, got -1.")): + scheduler.should_sample(epoch=-1) + + +def test_should_sample_rejects_unsupported_schedule(mock_sampler): + scheduler = NegativeSamplingScheduler( + negative_sampler=mock_sampler, + negative_sampling_schedule=cast(Any, "sometimes"), + ) + + with pytest.raises( + ValueError, + match=re.escape("Unsupported negative sampling schedule: 'sometimes'."), + ): + scheduler.should_sample(epoch=0) + + +@pytest.mark.parametrize( + "every_n, expected_exception, expected_message", + [ + pytest.param( + 0, ValueError, "negative_sampling_every_n must be positive, got 0.", id="zero" + ), + pytest.param( + -1, ValueError, "negative_sampling_every_n must be positive, got -1.", id="negative" + ), + ], +) +def test_should_sample_rejects_invalid_every_n( + mock_sampler, every_n, expected_exception, expected_message +): + scheduler = NegativeSamplingScheduler( + negative_sampler=mock_sampler, + negative_sampling_schedule="every_n_epochs", + negative_sampling_every_n=cast(Any, every_n), + ) + + with pytest.raises(expected_exception, match=re.escape(expected_message)): + scheduler.should_sample(epoch=0) diff --git a/hyperbench/tests/data/sampler_test.py b/hyperbench/tests/data/sampler_test.py index 3ddc82f9..7afd1a09 100644 --- 
a/hyperbench/tests/data/sampler_test.py +++ b/hyperbench/tests/data/sampler_test.py @@ -2,6 +2,7 @@ import pytest import torch +from typing import Any, cast from hyperbench.data import ( BaseSampler, HyperedgeSampler, @@ -48,6 +49,14 @@ def test_create_sampler_from_strategy_node(): assert isinstance(sampler, NodeSampler) +def test_create_sampler_from_strategy_rejects_unsupported_strategy(): + with pytest.raises( + ValueError, + match=re.escape("Unsupported sampling strategy: 'edge'."), + ): + create_sampler_from_strategy(cast(Any, "edge")) + + def test_hyperedge_sampling_single_index(mock_four_node_two_hyperedge_hdata): sampler = HyperedgeSampler() result = sampler.sample(0, mock_four_node_two_hyperedge_hdata) @@ -123,6 +132,42 @@ def test_sample_empty_index_raises(mock_four_node_two_hyperedge_hdata, sampler): sampler.sample([], mock_four_node_two_hyperedge_hdata) +@pytest.mark.parametrize( + "index", + [ + pytest.param("0", id="string_index"), + pytest.param(True, id="bool_index"), + ], +) +def test_sample_rejects_non_integer_index(mock_four_node_two_hyperedge_hdata, index): + sampler = HyperedgeSampler() + + with pytest.raises( + TypeError, + match=re.escape("Index must be an integer or a list of integers."), + ): + sampler.sample(cast(Any, index), mock_four_node_two_hyperedge_hdata) + + +@pytest.mark.parametrize( + "index", + [ + pytest.param([0, "1"], id="list_with_string"), + pytest.param([0, False], id="list_with_bool"), + ], +) +def test_sample_rejects_index_list_with_non_integer_items( + mock_four_node_two_hyperedge_hdata, index +): + sampler = HyperedgeSampler() + + with pytest.raises( + TypeError, + match=re.escape("Index list must contain only integers."), + ): + sampler.sample(cast(Any, index), mock_four_node_two_hyperedge_hdata) + + @pytest.mark.parametrize( "sampler, label", [ diff --git a/hyperbench/tests/types/hdata_test.py b/hyperbench/tests/types/hdata_test.py index 49dca383..0a9b6047 100644 --- a/hyperbench/tests/types/hdata_test.py +++ 
b/hyperbench/tests/types/hdata_test.py @@ -1,6 +1,6 @@ -import re import pytest import torch +import re from unittest.mock import MagicMock from typing import Any, cast @@ -327,80 +327,6 @@ def test_init_validates_input_values(kwargs, expected_message): HData(**kwargs) -@pytest.mark.parametrize( - "kwargs, expected_message", - [ - pytest.param( - {"x": cast(Any, [[1.0], [2.0]]), "hyperedge_index": torch.tensor([[0], [0]])}, - "x must be a torch.Tensor.", - id="x_not_tensor", - ), - pytest.param( - {"x": torch.randn(2, 1), "hyperedge_index": cast(Any, [[0], [0]])}, - "hyperedge_index must be a torch.Tensor.", - id="hyperedge_index_not_tensor", - ), - pytest.param( - { - "x": torch.randn(2, 1), - "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), - "hyperedge_attr": cast(Any, [[1.0]]), - }, - "hyperedge_attr must be a torch.Tensor.", - id="hyperedge_attr_not_tensor", - ), - pytest.param( - { - "x": torch.randn(2, 1), - "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), - "hyperedge_weights": cast(Any, [1.0]), - }, - "hyperedge_weights must be a torch.Tensor.", - id="hyperedge_weights_not_tensor", - ), - pytest.param( - { - "x": torch.randn(2, 1), - "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), - "global_node_ids": cast(Any, [0, 1]), - }, - "global_node_ids must be a torch.Tensor.", - id="global_node_ids_not_tensor", - ), - pytest.param( - { - "x": torch.randn(2, 1), - "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), - "y": cast(Any, [1.0]), - }, - "y must be a torch.Tensor.", - id="y_not_tensor", - ), - pytest.param( - { - "x": torch.randn(2, 1), - "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), - "num_nodes": cast(Any, 2.0), - }, - "num_nodes must be an int.", - id="num_nodes_not_int", - ), - pytest.param( - { - "x": torch.randn(2, 1), - "hyperedge_index": torch.tensor([[0, 1], [0, 0]]), - "num_hyperedges": cast(Any, True), - }, - "num_hyperedges must be an int.", - id="num_hyperedges_bool", - ), - ], -) -def test_init_validates_runtime_types(kwargs, 
expected_message): - with pytest.raises(TypeError, match=re.escape(expected_message)): - HData(**kwargs) - - @pytest.mark.parametrize( "kwargs, expected_message", [ @@ -1009,6 +935,25 @@ def test_split_transductive_counts( assert torch.equal(result.hyperedge_index, expected_hyperedge_index) +def test_split_raises_on_invalid_node_space_setting(): + hdata = HData( + x=torch.randn(2, 1), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + ) + + with pytest.raises( + ValueError, + match=re.escape( + "node_space_setting must be one of 'transductive' or 'inductive', got 'semi'." + ), + ): + HData.split( + hdata, + split_hyperedge_ids=torch.tensor([0]), + node_space_setting=cast(Any, "semi"), + ) + + def test_split_inductive_subsets_node_features(): x = torch.tensor([[10.0], [20.0], [30.0], [40.0], [50.0]]) hyperedge_index = torch.tensor([[0, 1, 3, 4], [0, 0, 1, 1]]) @@ -1050,7 +995,7 @@ def test_split_subsets_labels(): pytest.param( "inductive", torch.tensor([1]), - torch.arange(2), + torch.tensor([2, 3]), id="inductive", ), ], @@ -1060,8 +1005,7 @@ def test_split_handles_none_global_node_ids( ): x = torch.tensor([[10.0], [20.0], [30.0], [40.0]]) hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) - hdata = HData(x=x, hyperedge_index=hyperedge_index) - hdata.global_node_ids = None + hdata = HData(x=x, hyperedge_index=hyperedge_index, global_node_ids=None) result = HData.split( hdata, @@ -1101,8 +1045,12 @@ def test_split_transductive_keeps_full_x_and_global_node_ids(): def test_split_transductive_handles_none_global_node_ids(): x = torch.tensor([[10.0], [20.0], [30.0], [40.0], [50.0]]) hyperedge_index = torch.tensor([[0, 2, 3, 4], [0, 0, 1, 1]]) - hdata = HData(x=x, hyperedge_index=hyperedge_index, y=torch.tensor([1.0, 0.0])) - hdata.global_node_ids = None + hdata = HData( + x=x, + hyperedge_index=hyperedge_index, + y=torch.tensor([1.0, 0.0]), + global_node_ids=None, + ) result = HData.split( hdata, @@ -1238,12 +1186,34 @@ def 
test_enrich_node_features_concatenate(mock_hdata): assert result.x.shape == (5, 7) # 4 original + 3 enriched +@pytest.mark.parametrize( + "enrich_method", + [ + pytest.param("enrich_node_features", id="node_features"), + pytest.param("enrich_hyperedge_weights", id="hyperedge_weights"), + pytest.param("enrich_hyperedge_attr", id="hyperedge_attr"), + ], +) +def test_enrich_rejects_invalid_enrichment_mode(mock_hdata, enrich_method): + enricher_spec = NodeEnricher if enrich_method == "enrich_node_features" else HyperedgeEnricher + enricher = MagicMock(spec=enricher_spec) + + with pytest.raises( + ValueError, + match=re.escape( + "enrichment_mode must be one of 'replace', 'concatenate', or None, got 'append'." + ), + ): + getattr(mock_hdata, enrich_method)(enricher, enrichment_mode=cast(Any, "append")) + + enricher.enrich.assert_not_called() + + @pytest.mark.parametrize( "enrichment_mode", [ pytest.param("replace", id="replace"), pytest.param("concatenate", id="concatenate"), - pytest.param(None, id="none_enrichment_mode_defaults_to_replace"), ], ) def test_enrich_node_features_replace_preserves_global_node_ids(mock_hdata, enrichment_mode): @@ -1285,37 +1255,6 @@ def test_enrich_node_features_from_aligns_by_global_node_ids(): assert torch.equal(result.y, target_hdata.y) -@pytest.mark.parametrize( - "missing_side", - [ - pytest.param("source", id="source_missing_global_node_ids"), - pytest.param("target", id="target_missing_global_node_ids"), - ], -) -def test_enrich_node_features_from_raises_without_global_node_ids(missing_side): - source_hdata = HData( - x=torch.tensor([[1.0], [2.0]]), - hyperedge_index=torch.tensor([[0, 1], [0, 0]]), - global_node_ids=torch.tensor([10, 20]), - ) - target_hdata = HData( - x=torch.tensor([[0.0]]), - hyperedge_index=torch.tensor([[0], [0]]), - global_node_ids=torch.tensor([10]), - ) - - if missing_side == "source": - source_hdata.global_node_ids = None - else: - target_hdata.global_node_ids = None - - with pytest.raises( - ValueError, - 
match=re.escape("Both HData instances must define global_node_ids to align node features."), - ): - target_hdata.enrich_node_features_from(source_hdata) - - def test_enrich_node_features_from_raises_when_source_rows_do_not_match_global_node_ids(): source_hdata = HData( x=torch.empty((0, 0)), @@ -1504,6 +1443,30 @@ def test_enrich_methods_do_not_share_mutable_storage_with_source(hdata_with_all_ ) +def test_enrich_node_features_from_raises_on_invalid_node_space_setting(): + source_hdata = HData( + x=torch.tensor([[1.0]]), + hyperedge_index=torch.tensor([[0], [0]]), + global_node_ids=torch.tensor([10]), + ) + target_hdata = HData( + x=torch.tensor([[0.0]]), + hyperedge_index=torch.tensor([[0], [0]]), + global_node_ids=torch.tensor([10]), + ) + + with pytest.raises( + ValueError, + match=re.escape( + "node_space_setting must be one of 'transductive' or 'inductive', got 'semi'." + ), + ): + target_hdata.enrich_node_features_from( + source_hdata, + node_space_setting=cast(Any, "semi"), + ) + + def test_enrich_hyperedge_weights_replace(): x = torch.tensor([[1.0], [2.0], [3.0]]) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) @@ -1656,8 +1619,12 @@ def test_get_device_if_all_consistent_handles_none_global_node_ids(): x = torch.randn(3, 2) hyperedge_index = torch.tensor([[0, 1], [0, 0]]) hyperedge_attr = torch.randn(1, 4) - hdata = HData(x=x, hyperedge_index=hyperedge_index, hyperedge_attr=hyperedge_attr) - hdata.global_node_ids = None + hdata = HData( + x=x, + hyperedge_index=hyperedge_index, + hyperedge_attr=hyperedge_attr, + global_node_ids=None, + ) assert hdata.get_device_if_all_consistent() == torch.device("cpu") @@ -1996,13 +1963,12 @@ def test_remove_hyperedges_with_fewer_than_k_nodes_keeps_none_hyperedge_attr(): def test_remove_hyperedges_with_fewer_than_k_nodes_handles_none_global_node_ids(): x = torch.randn(5, 2) hyperedge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 1, 1]]) - hdata = HData(x=x, hyperedge_index=hyperedge_index) - hdata.global_node_ids 
= None + hdata = HData(x=x, hyperedge_index=hyperedge_index, global_node_ids=None) result = hdata.remove_hyperedges_with_fewer_than_k_nodes(k=3) assert result.global_node_ids is not None - assert torch.equal(result.global_node_ids, torch.arange(result.num_nodes)) + assert torch.equal(result.global_node_ids, torch.tensor([2, 3, 4])) def test_remove_hyperedges_with_fewer_than_k_nodes_subsets_hyperedge_weights(): diff --git a/hyperbench/tests/types/hypergraph_test.py b/hyperbench/tests/types/hypergraph_test.py index 77a5b181..d63b24a8 100644 --- a/hyperbench/tests/types/hypergraph_test.py +++ b/hyperbench/tests/types/hypergraph_test.py @@ -4,6 +4,7 @@ import pytest import torch +from typing import Any, cast from unittest.mock import patch from hyperbench.types import HIFHypergraph, Hypergraph, HyperedgeIndex from hyperbench.tests import MOCK_BASE_PATH @@ -572,6 +573,16 @@ def test_reduce_with_clique_expansion_matches_specialized_reducer(): assert torch.equal(result, expected) +def test_reduce_rejects_unsupported_strategy(): + hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) + + with pytest.raises( + ValueError, + match=re.escape("Unsupported reduction strategy: fancy_expansion. 
"), + ): + HyperedgeIndex(hyperedge_index).reduce(cast(Any, "fancy_expansion")) + + @pytest.mark.parametrize( "hyperedge_index_tensor, expected_num_edges", [ diff --git a/hyperbench/tests/utils/node_utils_test.py b/hyperbench/tests/utils/node_utils_test.py index 76b03add..b3cee375 100644 --- a/hyperbench/tests/utils/node_utils_test.py +++ b/hyperbench/tests/utils/node_utils_test.py @@ -1,10 +1,13 @@ import pytest import torch +import re +from typing import Any, cast from hyperbench.utils import ( assign_hyperedge_label_to_nodes, is_inductive_setting, is_transductive_setting, + validate_node_space_setting, ) @@ -89,3 +92,32 @@ def test_is_inductive_setting(node_space_setting, expected): ) def test_is_transductive_setting(node_space_setting, expected): assert is_transductive_setting(node_space_setting) == expected + + +@pytest.mark.parametrize( + "node_space_setting", + [ + pytest.param("inductive", id="inductive"), + pytest.param("transductive", id="transductive"), + ], +) +def test_validate_node_space_setting_accepts_supported_values(node_space_setting): + validate_node_space_setting(node_space_setting) + + +@pytest.mark.parametrize( + "node_space_setting", + [ + pytest.param("semi", id="unsupported_string"), + pytest.param(None, id="none"), + ], +) +def test_validate_node_space_setting_rejects_unsupported_values(node_space_setting): + with pytest.raises( + ValueError, + match=re.escape( + "node_space_setting must be one of 'transductive' or 'inductive', " + f"got {node_space_setting!r}." 
+ ), + ): + validate_node_space_setting(cast(Any, node_space_setting)) diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index b2526908..92d28254 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -14,6 +14,7 @@ is_inductive_setting, is_transductive_setting, to_0based_ids, + validate_node_space_setting, ) from hyperbench.types.hypergraph import HyperedgeIndex @@ -82,7 +83,7 @@ def __init__( ) self.__validate_number_of_nodes_and_hyperedges() - self.global_node_ids: Tensor | None = ( + self.global_node_ids = ( # torch.arange is to handle isolated nodes, as they are already considered # when computing self.num_nodes via num_nodes_if_isolated_exist global_node_ids if global_node_ids is not None else torch.arange(self.num_nodes) @@ -308,7 +309,12 @@ def split( Returns: hdata: The splitted instance with remapped node and hyperedge IDs. + + Raises: + ValueError: If ``node_space_setting`` is not ``"transductive"`` or ``"inductive"``. """ + cls.__validate_node_space_setting_value(node_space_setting) + # Mask to keep only incidences belonging to selected hyperedges # Example: hyperedge_index = [[0, 0, 1, 2, 3, 4], # [0, 0, 0, 1, 2, 2]] @@ -369,12 +375,10 @@ def split( .item ) - split_global_node_ids = None - if hdata.global_node_ids is not None: - split_global_node_ids = hdata.global_node_ids[split_unique_node_ids] - + split_x = hdata.x[split_unique_node_ids] + split_global_node_ids = hdata.global_node_ids[split_unique_node_ids] return cls( - x=hdata.x[split_unique_node_ids], + x=split_x, hyperedge_index=split_hyperedge_index.clone(), hyperedge_weights=split_hyperedge_weights, hyperedge_attr=split_hyperedge_attr, @@ -387,7 +391,7 @@ def split( def enrich_node_features( self, enricher: NodeEnricher, - enrichment_mode: EnrichmentMode | None = None, + enrichment_mode: EnrichmentMode | None = "replace", ) -> HData: """ Enrich node features using the provided node feature enricher. 
@@ -397,7 +401,9 @@ def enrich_node_features( enrichment_mode: How to combine generated features with existing ``hdata.x``. ``concatenate`` appends new features as additional columns. ``replace`` substitutes ``hdata.x`` entirely. + Defaults to ``replace`` if not provided. """ + self.__validate_enrichment_mode(enrichment_mode) enriched_features = enricher.enrich(self.hyperedge_index) match enrichment_mode: @@ -461,10 +467,6 @@ def enrich_node_features_from( """ source_global_node_ids = hdata_with_features.global_node_ids source_x = hdata_with_features.x - if self.global_node_ids is None or source_global_node_ids is None: - raise ValueError( - "Both HData instances must define global_node_ids to align node features." - ) if source_x.size(0) != source_global_node_ids.size(0): raise ValueError( "Expected hdata_with_features.x rows to align with hdata_with_features.global_node_ids." @@ -535,7 +537,7 @@ def enrich_node_features_from( def enrich_hyperedge_weights( self, enricher: HyperedgeEnricher, - enrichment_mode: EnrichmentMode | None = None, + enrichment_mode: EnrichmentMode | None = "replace", ) -> HData: """ Enrich hyperedge weights using the provided hyperedge weight enricher. @@ -545,10 +547,12 @@ def enrich_hyperedge_weights( enrichment_mode: How to combine generated weights with existing ``hdata.hyperedge_weights``. ``concatenate`` appends new weights to the existing 1D tensor. ``replace`` substitutes ``hdata.hyperedge_weights`` entirely. + Defaults to ``replace`` if not provided. Returns: hdata: A new `HData` with enriched hyperedge weights. 
""" + self.__validate_enrichment_mode(enrichment_mode) enriched_weights = enricher.enrich(self.hyperedge_index) match enrichment_mode: @@ -575,7 +579,7 @@ def enrich_hyperedge_weights( def enrich_hyperedge_attr( self, enricher: HyperedgeEnricher, - enrichment_mode: EnrichmentMode | None = None, + enrichment_mode: EnrichmentMode | None = "replace", ) -> HData: """ Enrich hyperedge features using the provided hyperedge feature enricher. @@ -585,7 +589,9 @@ def enrich_hyperedge_attr( enrichment_mode: How to combine generated features with existing ``hdata.hyperedge_attr``. ``concatenate`` appends new features as additional columns. ``replace`` substitutes ``hdata.hyperedge_attr`` entirely. + Defaults to ``replace`` if not provided. """ + self.__validate_enrichment_mode(enrichment_mode) enriched_features = enricher.enrich(self.hyperedge_index) match enrichment_mode: @@ -620,13 +626,18 @@ def get_device_if_all_consistent(self) -> torch.device: Raises: ValueError: If tensors are on different devices. 
""" - devices = {self.x.device, self.hyperedge_index.device, self.y.device} - if self.global_node_ids is not None: - devices.add(self.global_node_ids.device) + devices = { + self.x.device, + self.hyperedge_index.device, + self.global_node_ids.device, + self.y.device, + } + if self.hyperedge_attr is not None: devices.add(self.hyperedge_attr.device) if self.hyperedge_weights is not None: devices.add(self.hyperedge_weights.device) + if len(devices) > 1: raise ValueError(f"Inconsistent device placement: {devices}") @@ -638,9 +649,7 @@ def remove_hyperedges_with_fewer_than_k_nodes(self, k: int) -> HData: ).remove_hyperedges_with_fewer_than_k_nodes(k) x = self.x[hyperedge_index_wrapper.node_ids] - global_node_ids = None - if self.global_node_ids is not None: - global_node_ids = self.global_node_ids[hyperedge_index_wrapper.node_ids] + global_node_ids = self.global_node_ids[hyperedge_index_wrapper.node_ids] y = self.y[hyperedge_index_wrapper.hyperedge_ids] hyperedge_attr = None @@ -993,9 +1002,6 @@ def __validate_hyperedge_attr(self) -> None: if self.hyperedge_attr is None: return - if not isinstance(self.hyperedge_attr, Tensor): - raise TypeError("hyperedge_attr must be a torch.Tensor.") - if self.hyperedge_attr.dim() != 2: raise ValueError( f"hyperedge_attr must be a 2D tensor, got shape {tuple(self.hyperedge_attr.shape)}." @@ -1034,9 +1040,6 @@ def __validate_hyperedge_weights(self) -> None: if self.hyperedge_weights is None: return - if not isinstance(self.hyperedge_weights, Tensor): - raise TypeError("hyperedge_weights must be a torch.Tensor.") - if self.hyperedge_weights.dim() != 1: raise ValueError( f"hyperedge_weights must be a 1D tensor, got shape {tuple(self.hyperedge_weights.shape)}." 
@@ -1048,9 +1051,6 @@ def __validate_hyperedge_weights(self) -> None: ) def __validate_global_node_ids(self) -> None: - if not isinstance(self.global_node_ids, Tensor): - raise TypeError("global_node_ids must be a torch.Tensor.") - if self.global_node_ids.dim() != 1: raise ValueError( f"global_node_ids must be a 1D tensor, got shape {tuple(self.global_node_ids.shape)}." @@ -1067,9 +1067,6 @@ def __validate_global_node_ids(self) -> None: ) def __validate_labels(self) -> None: - if not isinstance(self.y, Tensor): - raise TypeError("y must be a torch.Tensor.") - if self.y.dim() != 1: raise ValueError(f"y must be a 1D tensor, got shape {tuple(self.y.shape)}.") if self.y.size(0) != self.num_hyperedges: @@ -1090,6 +1087,8 @@ def __validate_node_space_setting( node_space_setting: NodeSpaceSetting, fill_value: NodeSpaceFiller | None, ) -> None: + validate_node_space_setting(node_space_setting) + if is_transductive_setting(node_space_setting) and fill_value is not None: raise ValueError( "fill_value cannot be provided when node_space_setting='transductive'." @@ -1097,24 +1096,33 @@ def __validate_node_space_setting( if is_inductive_setting(node_space_setting) and fill_value is None: raise ValueError("fill_value must be provided when node_space_setting='inductive'.") + @staticmethod + def __validate_enrichment_mode(enrichment_mode: EnrichmentMode | None) -> None: + if enrichment_mode is None or enrichment_mode in ("replace", "concatenate"): + return + + raise ValueError( + f"enrichment_mode must be one of 'replace', 'concatenate', or None, got {enrichment_mode!r}." + ) + + @staticmethod + def __validate_node_space_setting_value(node_space_setting: NodeSpaceSetting) -> None: + if is_transductive_setting(node_space_setting) or is_inductive_setting(node_space_setting): + return + + raise ValueError( + "node_space_setting must be one of 'transductive' or 'inductive', " + f"got {node_space_setting!r}." 
+ ) + def __validate_number_of_nodes_and_hyperedges(self) -> None: # Check on bool as bool is a subclass of int - if not isinstance(self.num_nodes, int) or isinstance(self.num_nodes, bool): - raise TypeError("num_nodes must be an int.") if self.num_nodes < 0: raise ValueError(f"num_nodes must be non-negative, got {self.num_nodes}.") - - if not isinstance(self.num_hyperedges, int) or isinstance(self.num_hyperedges, bool): - raise TypeError("num_hyperedges must be an int.") if self.num_hyperedges < 0: raise ValueError(f"num_hyperedges must be non-negative, got {self.num_hyperedges}.") def __validate_x_and_hyperedge_index_type_and_dim(self) -> None: - if not isinstance(self.x, Tensor): - raise TypeError("x must be a torch.Tensor.") - if not isinstance(self.hyperedge_index, Tensor): - raise TypeError("hyperedge_index must be a torch.Tensor.") - if self.x.dim() != 2: raise ValueError(f"x must be a 2D tensor, got shape {tuple(self.x.shape)}.") diff --git a/hyperbench/types/hypergraph.py b/hyperbench/types/hypergraph.py index 39abd2e1..3cd3133d 100644 --- a/hyperbench/types/hypergraph.py +++ b/hyperbench/types/hypergraph.py @@ -788,8 +788,13 @@ def reduce(self, strategy: Literal["clique_expansion"], **kwargs) -> Tensor: edge_index: The edge index of the reduced graph. Size ``(2, num_edges)``. """ match strategy: - case _: + case "clique_expansion": return self.reduce_to_edge_index_on_clique_expansion(**kwargs) + case _: + raise ValueError( + f"Unsupported reduction strategy: {strategy}. 
" + "Supported strategies: ['clique_expansion']" + ) def reduce_to_edge_index_on_clique_expansion( self, diff --git a/hyperbench/utils/__init__.py b/hyperbench/utils/__init__.py index 9e17e994..58711c04 100644 --- a/hyperbench/utils/__init__.py +++ b/hyperbench/utils/__init__.py @@ -32,6 +32,7 @@ assign_hyperedge_label_to_nodes, is_inductive_setting, is_transductive_setting, + validate_node_space_setting, ) from .random_utils import create_seeded_torch_generator @@ -87,6 +88,7 @@ "validate_hif_data", "validate_hif_json", "validate_http_url", + "validate_node_space_setting", "write_dataset_to_disk_as_zst", "write_zst_file_to_disk", ] diff --git a/hyperbench/utils/node_utils.py b/hyperbench/utils/node_utils.py index ac3b50a3..9e1203fb 100644 --- a/hyperbench/utils/node_utils.py +++ b/hyperbench/utils/node_utils.py @@ -26,3 +26,21 @@ def is_inductive_setting(node_space_setting: NodeSpaceSetting | None) -> bool: def is_transductive_setting(node_space_setting: NodeSpaceSetting | None) -> bool: return node_space_setting == "transductive" + + +def validate_node_space_setting(node_space_setting: NodeSpaceSetting) -> None: + """ + Validate that the node space setting is one of the supported values. + + Args: + node_space_setting: The node space setting to validate, which should be either "inductive" or "transductive". + + Raises: + ValueError: If the node space setting is not one of the supported values. + """ + if is_transductive_setting(node_space_setting) or is_inductive_setting(node_space_setting): + return + + raise ValueError( + f"node_space_setting must be one of 'transductive' or 'inductive', got {node_space_setting!r}." 
+ ) From 9d4075ae57d9c22d9b3886b9a55f774bcce54c56 Mon Sep 17 00:00:00 2001 From: Tiziano Date: Tue, 26 May 2026 11:55:25 +0200 Subject: [PATCH 10/15] fix: rebase --- hyperbench/tests/data/dataset_test.py | 5 ++++- hyperbench/tests/types/hdata_test.py | 13 +++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 73598791..58b88417 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -535,7 +535,10 @@ def test_enrich_hyperedge_weights_concatenate( dataset.enrich_hyperedge_weights(enricher, enrichment_mode="concatenate") - enricher.enrich.assert_called_once_with(dataset.hdata.hyperedge_index) + enricher.enrich.assert_called_once() + enriched_hyperedge_index = enricher.enrich.call_args.args[0] + + assert torch.equal(enriched_hyperedge_index, dataset.hdata.hyperedge_index) assert dataset.hdata.hyperedge_weights is not None assert torch.equal(dataset.hdata.hyperedge_weights, torch.tensor([1.0, 2.0, 3.0])) assert dataset.hdata.hyperedge_weights.shape == (3,) diff --git a/hyperbench/tests/types/hdata_test.py b/hyperbench/tests/types/hdata_test.py index 0a9b6047..a519e3ae 100644 --- a/hyperbench/tests/types/hdata_test.py +++ b/hyperbench/tests/types/hdata_test.py @@ -784,15 +784,20 @@ def test_cat_same_node_space_does_not_share_mutable_storage_with_inputs( __assert_mutating_result_keeps_source_tensors_unchanged(other_hdata, result) -def test_cat_same_node_space_clones_provided_x(hdata_with_all_mutable_tensors): - hdata = hdata_with_all_mutable_tensors - custom_x = torch.full_like(hdata.x, 9.0) +def test_cat_same_node_space_clones_provided_x_and_global_node_ids(hdata_with_all_mutable_tensors): + custom_x = torch.full_like(hdata_with_all_mutable_tensors.x, 9.0) + custom_global_node_ids = torch.arange(hdata_with_all_mutable_tensors.num_nodes) + 100 + original_custom_x = custom_x.clone() + original_custom_global_node_ids = 
custom_global_node_ids.clone() - result = HData.cat_same_node_space([hdata], x=custom_x) + result = HData.cat_same_node_space( + hdatas=[hdata_with_all_mutable_tensors], x=custom_x, global_node_ids=custom_global_node_ids + ) result.x.flatten()[0].add_(1) assert torch.equal(custom_x, original_custom_x) + assert torch.equal(custom_global_node_ids, original_custom_global_node_ids) def test_add_negative_samples_combines_positive_and_negative_hyperedges(mock_negative_sampler): From adfb7af017b509698f2b7ee1cb59e832cce13a29 Mon Sep 17 00:00:00 2001 From: Tiziano Date: Tue, 26 May 2026 12:03:24 +0200 Subject: [PATCH 11/15] fix: rebase --- hyperbench/types/hypergraph.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/hyperbench/types/hypergraph.py b/hyperbench/types/hypergraph.py index 3cd3133d..6706ba0a 100644 --- a/hyperbench/types/hypergraph.py +++ b/hyperbench/types/hypergraph.py @@ -527,8 +527,7 @@ def get_sparse_incidence_matrix( if max_hyperedge_id >= num_hyperedges: raise ValueError( "num_hyperedges is too small for the hyperedge index. " - f"Got num_hyperedges={num_hyperedges}, " - f"but max hyperedge id is {max_hyperedge_id}." + f"Got num_hyperedges={num_hyperedges}, but max hyperedge id is {max_hyperedge_id}." 
) incidence_values = torch.ones(self.num_incidences, dtype=torch.float, device=device) From 4a06e94ebe57c64bb7ce0caa404ed72b8cc573f7 Mon Sep 17 00:00:00 2001 From: Tiziano Date: Tue, 26 May 2026 13:31:56 +0200 Subject: [PATCH 12/15] fix: add more validation --- hyperbench/data/dataset.py | 24 +- hyperbench/data/enricher.py | 72 ++++- hyperbench/data/hif.py | 8 + hyperbench/data/splitter.py | 10 + hyperbench/models/node2vec.py | 2 +- hyperbench/tests/data/dataset_test.py | 10 +- hyperbench/tests/data/enricher_test.py | 276 +++++++++++++++++- hyperbench/tests/data/hif_test.py | 24 ++ .../tests/data/negative_sampler_test.py | 61 +++- hyperbench/tests/data/splitter_test.py | 40 +++ hyperbench/tests/train/latex_logger_test.py | 27 +- .../tests/train/markdown_logger_test.py | 14 +- hyperbench/tests/train/trainer_test.py | 61 ++-- hyperbench/tests/types/graph_test.py | 67 +++++ hyperbench/tests/types/hdata_test.py | 15 +- hyperbench/tests/types/hypergraph_test.py | 108 ++++++- hyperbench/train/__init__.py | 4 +- hyperbench/train/latex_logger.py | 9 +- hyperbench/train/markdown_logger.py | 6 +- hyperbench/train/trainer.py | 19 +- hyperbench/types/graph.py | 27 +- hyperbench/types/hdata.py | 30 +- hyperbench/types/hypergraph.py | 92 ++++-- hyperbench/utils/__init__.py | 14 + hyperbench/utils/data_utils.py | 53 ++++ 25 files changed, 918 insertions(+), 155 deletions(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index fb325ebc..33e25f1c 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -1,6 +1,5 @@ from __future__ import annotations -import math import torch from typing import TYPE_CHECKING, Any @@ -12,6 +11,7 @@ NodeSpaceSetting, is_transductive_setting, validate_node_space_setting, + validate_split_ratios, ) from hyperbench.data.hif import HIFLoader, HIFProcessor @@ -357,7 +357,8 @@ def split_with_ratios( node space. 
""" validate_node_space_setting(node_space_setting) - self.__validate_split_ratios(ratios) + validate_split_ratios(ratios) + device = self.hdata.device hyperedge_splitter = HyperedgeIDSplitter(self.hdata) @@ -447,22 +448,3 @@ def stats(self) -> dict[str, Any]: """ return self.hdata.stats() - - @staticmethod - def __validate_split_ratios(ratios: list[float]) -> None: - if len(ratios) < 1: - raise ValueError("Split ratios cannot be empty.") - - for ratio in ratios: - if not math.isfinite(float(ratio)): - raise ValueError(f"Split ratios must be finite, got {ratio}.") - if ratio <= 0.0: - raise ValueError(f"Split ratios must be positive, got {ratio}.") - - # Allow small imprecision in sum of ratios, but raise error if it's significant - # Example: ratios = [0.8, 0.1, 0.1] -> sum = 1.0 (valid) - # ratios = [0.8, 0.1, 0.05] -> sum = 0.95 (invalid, raises ValueError) - # ratios = [0.8, 0.1, 0.1, 0.0000001] -> sum = 1.0000001 (valid, allows small imprecision) - ratio_sum = float(sum(ratios)) - if abs(ratio_sum - 1.0) > 1e-6: - raise ValueError(f"Split ratios must sum to 1.0, got {ratio_sum}.") diff --git a/hyperbench/data/enricher.py b/hyperbench/data/enricher.py index 5f888023..c6308022 100644 --- a/hyperbench/data/enricher.py +++ b/hyperbench/data/enricher.py @@ -1,6 +1,6 @@ -import warnings import random import torch +import warnings from abc import ABC, abstractmethod from torch import Tensor, optim @@ -8,6 +8,13 @@ from torch_geometric.nn import Node2Vec as PyGNode2Vec from hyperbench.types import EdgeIndex, HyperedgeIndex from hyperbench.models import VilLain +from hyperbench.utils import ( + validate_is_between, + validate_is_finite, + validate_is_finite_when_provided, + validate_is_non_negative, + validate_is_positive, +) EnrichmentMode: TypeAlias = Literal["concatenate", "replace"] @@ -63,6 +70,8 @@ def __init__( self.weight_decay = weight_decay self.verbose = verbose + self.__validate() + def _empty_features(self, hyperedge_index: Tensor) -> Tensor: """ Return an 
empty feature matrix on the same device as ``hyperedge_index``. @@ -147,6 +156,28 @@ def _train(self, hyperedge_index: Tensor): return model + def __validate(self) -> None: + validate_is_positive("num_features", self.embedding_dim) + validate_is_non_negative("num_nodes", self.num_nodes) + validate_is_non_negative("num_hyperedges", self.num_hyperedges) + + if self.labels_per_subspace < 2: + raise ValueError( + f"'labels_per_subspace' must be at least 2, got {self.labels_per_subspace}." + ) + + validate_is_positive("training_steps", self.training_steps) + validate_is_positive("generation_steps", self.generation_steps) + validate_is_finite("tau", self.tau) + validate_is_positive("tau", self.tau) + validate_is_finite("eps", self.eps) + validate_is_positive("eps", self.eps) + validate_is_positive("num_epochs", self.num_epochs) + validate_is_positive("learning_rate", self.learning_rate) + validate_is_non_negative("weight_decay", self.weight_decay) + validate_is_finite("learning_rate", self.learning_rate) + validate_is_finite("weight_decay", self.weight_decay) + class Enricher(ABC): """ @@ -322,8 +353,9 @@ def __init__( beta: float | None = None, ): super().__init__(cache_dir=cache_dir) - if alpha < 0.0 or alpha > 1.0: - raise ValueError("Alpha must be between 0.0 and 1.0.") + + validate_is_between("alpha", alpha, 0.0, 1.0) + validate_is_finite_when_provided("beta", beta) self.alpha = alpha self.beta = beta @@ -407,12 +439,6 @@ def __init__( verbose: bool = False, ): super().__init__(cache_dir=cache_dir) - if walk_length < context_size: - raise ValueError( - f"Expected walk_length >= context_size, got " - f"walk_length={walk_length}, context_size={context_size}." 
- ) - self.embedding_dim = num_features self.walk_length = walk_length self.context_size = context_size @@ -428,6 +454,8 @@ def __init__( self.sparse = sparse self.verbose = verbose + self.__validate() + def enrich(self, hyperedge_index: Tensor) -> Tensor: """ Compute Node2Vec embeddings from the clique expansion of the hypergraph. @@ -519,6 +547,28 @@ def enrich(self, hyperedge_index: Tensor) -> Tensor: # Detach node embeddings from computation graph and return them return x.detach().to(device) + def __validate(self) -> None: + validate_is_positive("num_features", self.embedding_dim) + validate_is_positive("walk_length", self.walk_length) + validate_is_positive("context_size", self.context_size) + if self.walk_length < self.context_size: + raise ValueError( + "Expected walk_length >= context_size, got " + f"walk_length={self.walk_length}, context_size={self.context_size}." + ) + + validate_is_positive("num_walks_per_node", self.num_walks_per_node) + validate_is_finite("p", self.p) + validate_is_positive("p", self.p) + validate_is_finite("q", self.q) + validate_is_positive("q", self.q) + validate_is_positive("num_negative_samples", self.num_negative_samples) + validate_is_non_negative("num_nodes", self.num_nodes) + validate_is_positive("num_epochs", self.num_epochs) + validate_is_finite("learning_rate", self.learning_rate) + validate_is_positive("learning_rate", self.learning_rate) + validate_is_positive("batch_size", self.batch_size) + class LaplacianPositionalEncodingEnricher(NodeEnricher): """ @@ -540,6 +590,10 @@ def __init__( cache_dir: str | None = None, ): super().__init__(cache_dir=cache_dir) + + validate_is_positive("num_features", num_features) + validate_is_non_negative("num_nodes", num_nodes) + self.num_features = num_features self.num_nodes = num_nodes diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index 3b5312f8..276842d9 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -74,6 +74,9 @@ def process_hypergraph(cls, 
hypergraph: HIFHypergraph) -> HData: # Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(hypergraph.nodes)} + if len(node_id_to_idx) != num_nodes: + raise ValueError("HIF node IDs must be unique.") + # Initialize edge_set only with edges that have incidences, so that # we avoid inflating edge count due to isolated nodes/missing incidences hyperedge_id_to_idx: dict[Any, int] = {} @@ -84,6 +87,11 @@ def process_hypergraph(cls, hypergraph: HIFHypergraph) -> HData: for incidence in hypergraph.incidences: node_id = incidence.get("node", 0) hyperedge_id = incidence.get("edge", 0) + if node_id not in node_id_to_idx: + raise ValueError( + f"Incidence references unknown node id {node_id!r}; " + "all incidence nodes must be declared in the HIF nodes list." + ) if hyperedge_id not in hyperedge_id_to_idx: # Hyperedges start from 0 and are assigned IDs in the order they are first encountered in incidences diff --git a/hyperbench/data/splitter.py b/hyperbench/data/splitter.py index 0bd26dcd..6b453a61 100644 --- a/hyperbench/data/splitter.py +++ b/hyperbench/data/splitter.py @@ -4,6 +4,7 @@ from typing import cast from torch import Tensor from hyperbench.types import HData +from hyperbench.utils import validate_is_non_empty, validate_split_ratios class Splitter(ABC): @@ -60,6 +61,13 @@ def ensure_split_covers_all_nodes( Raises: ValueError: If one or more nodes do not appear in any hyperedge of the source hypergraph. """ + validate_is_non_empty("hyperedge_ids_by_split", hyperedge_ids_by_split) + if split_idx < 0 or split_idx >= len(hyperedge_ids_by_split): + raise ValueError( + f"split_idx must reference an existing split, got {split_idx} " + f"for {len(hyperedge_ids_by_split)} splits." 
+ ) + required_node_ids = torch.arange(self.hdata.num_nodes, device=self.hdata.device) available_node_ids = self.hdata.hyperedge_index[0].unique() missing_from_hypergraph_mask = torch.logical_not( @@ -185,6 +193,8 @@ def split(self, to_split: Tensor, ratios: list[float]) -> tuple[list[Tensor], li hyperedge_ids_by_split: The updated hyperedge IDs for each split. ratios: The final ratios of hyperedges in each split after rebalancing. """ + validate_split_ratios(ratios) + # Cumulative floor boundaries keep early splits from over-consuming hyperedges. # The last split absorbs any rounding remainder. num_hyperedges = int(to_split.size(0)) diff --git a/hyperbench/models/node2vec.py b/hyperbench/models/node2vec.py index 52ef52b8..b7d40d88 100644 --- a/hyperbench/models/node2vec.py +++ b/hyperbench/models/node2vec.py @@ -56,7 +56,7 @@ def __init__( super().__init__() if walk_length < context_size: raise ValueError( - f"Expected walk_length >= context_size, got " + "Expected walk_length >= context_size, got " f"walk_length={walk_length}, context_size={context_size}." 
) diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 58b88417..3a23a769 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -690,19 +690,17 @@ def test_split_raises_when_ratios_do_not_sum_to_one(mock_hdata_four_nodes): with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_nodes): dataset = AlgebraDataset() - with pytest.raises(ValueError, match=re.escape("Split ratios must sum to 1.0")): + with pytest.raises(ValueError, match=re.escape("'ratios' must sum to 1.0")): dataset.split([0.8, 0.1, 0.05]) @pytest.mark.parametrize( "ratios, expected_exception, expected_message", [ - pytest.param([], ValueError, "Split ratios cannot be empty.", id="empty"), + pytest.param([], ValueError, "'ratios' cannot be empty.", id="empty"), + pytest.param([0.5, 0.0, 0.5], ValueError, "'ratios' must be positive, got 0.0.", id="zero"), pytest.param( - [0.5, 0.0, 0.5], ValueError, "Split ratios must be positive, got 0.0.", id="zero" - ), - pytest.param( - [0.5, float("inf")], ValueError, "Split ratios must be finite, got inf.", id="infinite" + [0.5, float("inf")], ValueError, "'ratios' must be finite, got inf.", id="infinite" ), ], ) diff --git a/hyperbench/tests/data/enricher_test.py b/hyperbench/tests/data/enricher_test.py index 40800e85..8ee7f5e4 100644 --- a/hyperbench/tests/data/enricher_test.py +++ b/hyperbench/tests/data/enricher_test.py @@ -3,6 +3,7 @@ import re from pathlib import Path +from collections.abc import Callable from unittest.mock import patch from torch import Tensor from hyperbench.data import ( @@ -17,7 +18,6 @@ VilLainEnricher, VilLainHyperedgeAttrsEnricher, ) - from hyperbench.data.enricher import Enricher, _VilLainTrainer from hyperbench.tests.mock.mock import new_mock_pyg_node2vec, new_mock_villain @@ -92,10 +92,15 @@ def test_fill_value_hyperedge_attrs_enricher_returns_empty_attrs_for_empty_input ], ) def 
test_ab_hyperedge_weights_enricher_rejects_invalid_alpha(alpha: float) -> None: - with pytest.raises(ValueError, match=re.escape("Alpha must be between 0.0 and 1.0.")): + with pytest.raises(ValueError, match=re.escape("'alpha' must be between 0.0 and 1.0")): ABHyperedgeWeightsEnricher(alpha=alpha) +def test_ab_hyperedge_weights_enricher_rejects_non_finite_beta() -> None: + with pytest.raises(ValueError, match=re.escape("'beta' must be finite when provided")): + ABHyperedgeWeightsEnricher(beta=float("inf")) + + def test_ab_hyperedge_weights_enricher_counts_nodes_per_hyperedge( mock_two_hyperedge_index: Tensor, ) -> None: @@ -137,6 +142,74 @@ def test_node2vec_enricher_rejects_context_larger_than_walk_length() -> None: Node2VecEnricher(num_features=4, walk_length=2, context_size=3) +@pytest.mark.parametrize( + ("build_invalid_enricher", "expected_message"), + [ + pytest.param( + lambda: Node2VecEnricher(num_features=0), + "'num_features' must be positive", + id="features", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, walk_length=0), + "'walk_length' must be positive", + id="walk_length", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, context_size=0), + "'context_size' must be positive", + id="context_size", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, num_walks_per_node=0), + "'num_walks_per_node' must be positive", + id="walks_per_node", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, p=0.0), + "'p' must be positive", + id="p", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, q=0.0), + "'q' must be positive", + id="q", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, num_negative_samples=0), + "'num_negative_samples' must be positive", + id="negative_samples", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, num_nodes=-1), + "'num_nodes' must be non-negative", + id="num_nodes", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, 
num_epochs=0), + "'num_epochs' must be positive", + id="epochs", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, learning_rate=float("nan")), + "'learning_rate' must be finite", + id="learning_rate_finite", + ), + pytest.param( + lambda: Node2VecEnricher(num_features=3, batch_size=0), + "'batch_size' must be positive", + id="batch_size", + ), + ], +) +def test_node2vec_enricher_rejects_invalid_params( + build_invalid_enricher: Callable[[], object], + expected_message: str, +) -> None: + with pytest.raises(ValueError, match=re.escape(expected_message)): + build_invalid_enricher() + + def test_node2vec_enricher_returns_empty_features_when_no_nodes() -> None: hyperedge_index = torch.zeros((2, 0), dtype=torch.long) enricher = Node2VecEnricher(num_features=3) @@ -283,6 +356,29 @@ def test_laplacian_positional_encoding_enricher_uses_explicit_num_nodes() -> Non assert result.shape == (4, 2) +@pytest.mark.parametrize( + ("build_invalid_enricher", "expected_message"), + [ + pytest.param( + lambda: LaplacianPositionalEncodingEnricher(num_features=0), + "'num_features' must be positive", + id="features", + ), + pytest.param( + lambda: LaplacianPositionalEncodingEnricher(num_features=3, num_nodes=-1), + "'num_nodes' must be non-negative", + id="num_nodes", + ), + ], +) +def test_laplacian_positional_encoding_enricher_rejects_invalid_semantic_params( + build_invalid_enricher: Callable[[], object], + expected_message: str, +) -> None: + with pytest.raises(ValueError, match=re.escape(expected_message)): + build_invalid_enricher() + + def test_villain_trainer_resolves_explicit_and_inferred_counts( mock_two_hyperedge_index: Tensor, ) -> None: @@ -302,6 +398,94 @@ def test_villain_trainer_falls_back_to_inferred_counts( assert trainer._num_hyperedges(mock_two_hyperedge_index) == 2 +@pytest.mark.parametrize( + ("build_invalid_enricher", "expected_message"), + [ + pytest.param( + lambda: VilLainEnricher(num_features=0), + "'num_features' must be positive", + 
id="features", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, num_nodes=-1), + "'num_nodes' must be non-negative", + id="num_nodes", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, num_hyperedges=-1), + "'num_hyperedges' must be non-negative", + id="num_hyperedges", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, labels_per_subspace=1), + "'labels_per_subspace' must be at least 2", + id="labels_per_subspace", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, training_steps=0), + "'training_steps' must be positive", + id="training_steps", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, generation_steps=0), + "'generation_steps' must be positive", + id="generation_steps", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, tau=float("nan")), + "'tau' must be finite", + id="tau_finite", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, tau=0.0), + "'tau' must be positive", + id="tau_positive", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, eps=float("nan")), + "'eps' must be finite", + id="eps_finite", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, eps=0.0), + "'eps' must be positive", + id="eps_positive", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, num_epochs=0), + "'num_epochs' must be positive", + id="num_epochs", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, learning_rate=float("nan")), + "'learning_rate' must be finite", + id="learning_rate_finite", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, learning_rate=0.0), + "'learning_rate' must be positive", + id="learning_rate_positive", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, weight_decay=float("nan")), + "'weight_decay' must be finite", + id="weight_decay_finite", + ), + pytest.param( + lambda: VilLainEnricher(num_features=3, weight_decay=-0.1), + "'weight_decay' must be non-negative", + 
id="weight_decay_non_negative", + ), + ], +) +def test_villain_node_enricher_rejects_invalid_params( + build_invalid_enricher: Callable[[], object], + expected_message: str, +) -> None: + with pytest.raises(ValueError, match=re.escape(expected_message)): + build_invalid_enricher() + + def test_villain_node_enricher_returns_empty_features_when_no_nodes() -> None: hyperedge_index = torch.zeros((2, 0), dtype=torch.long) enricher = VilLainEnricher(num_features=3) @@ -313,6 +497,94 @@ def test_villain_node_enricher_returns_empty_features_when_no_nodes() -> None: assert result.device == hyperedge_index.device +@pytest.mark.parametrize( + ("build_invalid_enricher", "expected_message"), + [ + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=0), + "'num_features' must be positive", + id="features", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, num_nodes=-1), + "'num_nodes' must be non-negative", + id="num_nodes", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, num_hyperedges=-1), + "'num_hyperedges' must be non-negative", + id="num_hyperedges", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, labels_per_subspace=1), + "'labels_per_subspace' must be at least 2", + id="labels_per_subspace", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, training_steps=0), + "'training_steps' must be positive", + id="training_steps", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, generation_steps=0), + "'generation_steps' must be positive", + id="generation_steps", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, tau=float("nan")), + "'tau' must be finite", + id="tau_finite", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, tau=0.0), + "'tau' must be positive", + id="tau_positive", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, eps=float("nan")), 
+ "'eps' must be finite", + id="eps_finite", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, eps=0.0), + "'eps' must be positive", + id="eps_positive", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, num_epochs=0), + "'num_epochs' must be positive", + id="num_epochs", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, learning_rate=float("nan")), + "'learning_rate' must be finite", + id="learning_rate_finite", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, learning_rate=0.0), + "'learning_rate' must be positive", + id="learning_rate_positive", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, weight_decay=float("nan")), + "'weight_decay' must be finite", + id="weight_decay_finite", + ), + pytest.param( + lambda: VilLainHyperedgeAttrsEnricher(num_features=3, weight_decay=-0.1), + "'weight_decay' must be non-negative", + id="weight_decay_non_negative", + ), + ], +) +def test_villain_hyperedge_attrs_enricher_rejects_invalid_params( + build_invalid_enricher: Callable[[], object], + expected_message: str, +) -> None: + with pytest.raises(ValueError, match=re.escape(expected_message)): + build_invalid_enricher() + + def test_villain_hyperedge_attrs_enricher_returns_empty_attrs_when_no_hyperedges() -> None: hyperedge_index = torch.zeros((2, 0), dtype=torch.long) enricher = VilLainHyperedgeAttrsEnricher(num_features=3) diff --git a/hyperbench/tests/data/hif_test.py b/hyperbench/tests/data/hif_test.py index 07f77691..6658d888 100644 --- a/hyperbench/tests/data/hif_test.py +++ b/hyperbench/tests/data/hif_test.py @@ -202,6 +202,30 @@ def test_transform_node_attrs_adds_padding_zero_when_attr_keys_padding(): assert torch.allclose(result, torch.tensor([0.0, 2.5])) +def test_process_hypergraph_rejects_duplicate_node_ids(): + hypergraph = HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0"}, {"node": "0"}], + 
hyperedges=[{"edge": "0"}], + incidences=[{"node": "0", "edge": "0"}], + ) + + with pytest.raises(ValueError, match="HIF node IDs must be unique"): + HIFProcessor.process_hypergraph(hypergraph) + + +def test_process_hypergraph_rejects_incidences_with_unknown_node_ids(): + hypergraph = HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0"}], + hyperedges=[{"edge": "0"}], + incidences=[{"node": "missing", "edge": "0"}], + ) + + with pytest.raises(ValueError, match="Incidence references unknown node id"): + HIFProcessor.process_hypergraph(hypergraph) + + def test_load_from_url_rejects_invalid_url(): with pytest.raises(ValueError, match="Invalid URL"): HIFLoader.load_from_url("not-a-url") diff --git a/hyperbench/tests/data/negative_sampler_test.py b/hyperbench/tests/data/negative_sampler_test.py index ee777c25..06e68ba4 100644 --- a/hyperbench/tests/data/negative_sampler_test.py +++ b/hyperbench/tests/data/negative_sampler_test.py @@ -33,10 +33,32 @@ def mock_hdata_no_attr() -> HData: hyperedge_index=torch.tensor([[0, 1, 2], [0, 0, 1]]), hyperedge_attr=None, num_nodes=3, + num_hyperedges=2, + ) + + +@pytest.fixture +def mock_hdata_with_weights() -> HData: + return HData( + x=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + hyperedge_index=torch.tensor([[0, 1, 2], [0, 1, 2]]), + hyperedge_weights=torch.tensor([0.5, 0.7, 0.9]), + num_nodes=3, num_hyperedges=3, ) +@pytest.fixture +def mock_hdata_no_weights() -> HData: + return HData( + x=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + hyperedge_index=torch.tensor([[0, 1, 2], [0, 0, 1]]), + hyperedge_weights=None, + num_nodes=3, + num_hyperedges=2, + ) + + @pytest.fixture def mock_clique_hdata() -> HData: return HData( @@ -101,7 +123,7 @@ def test_random_negative_sampler_sample_too_many_nodes(mock_hdata_with_attr): sampler.sample(mock_hdata_with_attr) -def test_random_negative_sampler_with_edge_attr(mock_hdata_with_attr): +def test_random_negative_sampler_with_hyperedge_attr(mock_hdata_with_attr): 
sampler = RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=2) result = sampler.sample(mock_hdata_with_attr) @@ -118,7 +140,7 @@ def test_random_negative_sampler_with_edge_attr(mock_hdata_with_attr): assert result.hyperedge_attr.shape[0] == 2 -def test_random_negative_sampler_sample_no_edge_attr(mock_hdata_no_attr): +def test_random_negative_sampler_sample_no_hyperedge_attr(mock_hdata_no_attr): sampler = RandomNegativeSampler(num_negative_samples=1, num_nodes_per_sample=2) result = sampler.sample(mock_hdata_no_attr) @@ -128,10 +150,43 @@ def test_random_negative_sampler_sample_no_edge_attr(mock_hdata_no_attr): assert ( result.hyperedge_index.shape[1] == 2 ) # 1 negative hyperedge * 2 nodes per negative hyperedge - assert 3 in result.hyperedge_index[1] # New hyperedge ID (3) should be present + assert 2 in result.hyperedge_index[1] # New hyperedge ID (2) should be present assert result.hyperedge_attr is None +def test_random_negative_sampler_with_hyperedge_weights(mock_hdata_with_weights): + sampler = RandomNegativeSampler(num_negative_samples=2, num_nodes_per_sample=2) + result = sampler.sample(mock_hdata_with_weights) + + assert result.num_hyperedges == 2 + assert result.x.shape[0] <= mock_hdata_with_weights.x.shape[0] + assert result.hyperedge_index.shape[0] == 2 + assert ( + result.hyperedge_index.shape[1] == 4 + ) # 2 negative hyperedges * 2 nodes per negative hyperedge + assert ( + 3 in result.hyperedge_index[1] and 4 in result.hyperedge_index[1] + ) # New hyperedge IDs (3, 4) should be present + + # Weights for new hyperedges are added only if the sampler is configured + # with a hyperedge_weights_enricher, otherwise they default to None + assert result.hyperedge_weights is None + + +def test_random_negative_sampler_sample_no_hyperedge_weights(mock_hdata_no_weights): + sampler = RandomNegativeSampler(num_negative_samples=1, num_nodes_per_sample=2) + result = sampler.sample(mock_hdata_no_weights) + + assert result.num_hyperedges == 1 + assert 
result.x.shape[0] <= mock_hdata_no_weights.x.shape[0] + assert result.hyperedge_index.shape[0] == 2 + assert ( + result.hyperedge_index.shape[1] == 2 + ) # 1 negative hyperedge * 2 nodes per negative hyperedge + assert 2 in result.hyperedge_index[1] # New hyperedge ID (2) should be present + assert result.hyperedge_weights is None + + def test_random_negative_sampler_sample_with_seed_is_reproducible(mock_hdata_with_attr): sampler = RandomNegativeSampler(num_negative_samples=3, num_nodes_per_sample=2) diff --git a/hyperbench/tests/data/splitter_test.py b/hyperbench/tests/data/splitter_test.py index 1310c42a..d8633650 100644 --- a/hyperbench/tests/data/splitter_test.py +++ b/hyperbench/tests/data/splitter_test.py @@ -2,6 +2,7 @@ import pytest import re +from typing import Any, cast from hyperbench.data import HyperedgeIDSplitter from hyperbench.types import HData @@ -63,6 +64,26 @@ def test_split_uses_cumulative_floor_boundaries_and_last_split_absorbs_remainder assert final_ratios == [0.4, 0.2, 0.4] +@pytest.mark.parametrize( + "ratios, expected_exception, expected_message", + [ + pytest.param([], ValueError, "'ratios' cannot be empty.", id="empty"), + pytest.param([0.5, 0.0, 0.5], ValueError, "'ratios' must be positive, got 0.0.", id="zero"), + pytest.param( + [0.5, float("inf")], ValueError, "'ratios' must be finite, got inf.", id="infinite" + ), + ], +) +def test_split_validates_ratio_values( + mock_hdata_five_hyperedges, ratios, expected_exception, expected_message +): + hyperedge_ids = torch.arange(5) + splitter = HyperedgeIDSplitter(mock_hdata_five_hyperedges) + + with pytest.raises(expected_exception, match=re.escape(expected_message)): + splitter.split(hyperedge_ids, cast(Any, ratios)) + + def test_ensure_split_covers_all_nodes_moves_best_covering_hyperedge_into_first_split(): x = torch.ones((4, 1), dtype=torch.float) hyperedge_index = torch.tensor( @@ -89,6 +110,25 @@ def test_ensure_split_covers_all_nodes_moves_best_covering_hyperedge_into_first_ assert 
final_ratios == [0.6, 0.4] +def test_ensure_split_covers_all_nodes_rejects_empty_splits(mock_hdata_five_hyperedges): + splitter = HyperedgeIDSplitter(mock_hdata_five_hyperedges) + + with pytest.raises(ValueError, match="'hyperedge_ids_by_split' cannot be empty"): + splitter.ensure_split_covers_all_nodes([]) + + +def test_ensure_split_covers_all_nodes_rejects_invalid_split_idx( + mock_hdata_five_hyperedges, +): + splitter = HyperedgeIDSplitter(mock_hdata_five_hyperedges) + + with pytest.raises(ValueError, match="split_idx must reference an existing split"): + splitter.ensure_split_covers_all_nodes( + [torch.tensor([0], dtype=torch.long)], + split_idx=1, + ) + + def test_ensure_split_covers_all_nodes_raises_when_a_node_is_missing_from_hypergraph(): x = torch.ones((4, 1), dtype=torch.float) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) diff --git a/hyperbench/tests/train/latex_logger_test.py b/hyperbench/tests/train/latex_logger_test.py index 06b9efae..367f3fb9 100644 --- a/hyperbench/tests/train/latex_logger_test.py +++ b/hyperbench/tests/train/latex_logger_test.py @@ -1,8 +1,7 @@ import pytest -from textwrap import dedent - -from hyperbench.train.latex_logger import ( +from textwrap import dedent +from hyperbench.train import ( LaTexTableConfig, LaTexTableLogger, colorize_metric_value, @@ -35,6 +34,17 @@ def test_latex_logger_basics(tmp_path, mock_option_configs): assert logger.experiment_name == "exp1" +def test_latex_logger_rejects_negative_precision(tmp_path, mock_option_configs): + with pytest.raises(ValueError, match="'precision' must be non-negative"): + LaTexTableLogger( + save_dir=str(tmp_path), + model_name="model_a", + experiment_name="negative_precision", + precision=-1, + options=mock_option_configs, + ) + + def test_latex_logger_log_hyperparams_is_noop(tmp_path, mock_option_configs): logger = LaTexTableLogger( @@ -122,6 +132,17 @@ def test_save_comparison_tables_no_val_results(tmp_path, mock_option_configs): assert (tmp_path / 
"comparison" / "test.tex").exists() +def test_colorize_metric_value_rejects_invalid_sort_order(): + with pytest.raises(ValueError, match="'sort_order' must be 'asc' or 'des'"): + colorize_metric_value( + metric="test_auc", + value=0.8, + text="0.8000", + metric_bounds={"test_auc": (0.1, 0.9)}, + sort_order="invalid", + ) + + def test_save_comparison_tables_only_val_results(tmp_path, mock_option_configs): logger = LaTexTableLogger( save_dir=str(tmp_path), diff --git a/hyperbench/tests/train/markdown_logger_test.py b/hyperbench/tests/train/markdown_logger_test.py index d619b9f1..f73c71ed 100644 --- a/hyperbench/tests/train/markdown_logger_test.py +++ b/hyperbench/tests/train/markdown_logger_test.py @@ -1,7 +1,7 @@ -from textwrap import dedent import pytest -from hyperbench.train.markdown_logger import MarkdownTableLogger +from textwrap import dedent +from hyperbench.train import MarkdownTableLogger def test_markdown_table_logger_basic_functions(tmp_path): @@ -21,6 +21,16 @@ def test_markdown_table_logger_basic_functions(tmp_path): assert logger.experiment_name == experiment_name +def test_markdown_table_logger_rejects_negative_precision(tmp_path): + with pytest.raises(ValueError, match="'precision' must be non-negative"): + MarkdownTableLogger( + save_dir=str(tmp_path), + model_name="model_a", + experiment_name="negative_precision", + precision=-1, + ) + + def test_log_hyperparams_is_noop(tmp_path): experiment_name = "exp_hparams" diff --git a/hyperbench/tests/train/trainer_test.py b/hyperbench/tests/train/trainer_test.py index 6d79627f..f7d2bb7d 100644 --- a/hyperbench/tests/train/trainer_test.py +++ b/hyperbench/tests/train/trainer_test.py @@ -1,5 +1,5 @@ -import re import pytest +import re from unittest.mock import MagicMock, patch from hyperbench.train import MultiModelTrainer @@ -50,6 +50,11 @@ def test_trainer_initialization( assert config.trainer is not None +def test_trainer_initialization_rejects_invalid_tensorboard_port(mock_model_configs): + with 
pytest.raises(ValueError, match="'tensorboard_port' must be non-negative"): + MultiModelTrainer(mock_model_configs, tensorboard_port=-1) + + @patch("hyperbench.train.trainer.L.Trainer") @patch("hyperbench.train.trainer.CSVLogger") @patch("hyperbench.train.trainer.MarkdownTableLogger") @@ -70,6 +75,20 @@ def test_trainer_initialization_with_initialized_trainer( assert config.trainer is not None +@patch("hyperbench.train.trainer.L.Trainer") +@patch("hyperbench.train.trainer.CSVLogger") +@patch("hyperbench.train.trainer.MarkdownTableLogger") +@patch("hyperbench.train.trainer.LaTexTableLogger") +def test_trainer_initialization_with_no_models( + mock_latex_logger_cls, + mock_md_logger_cls, + mock_csv_logger_cls, + mock_trainer_cls, +): + with pytest.raises(ValueError, match=re.escape("'model_configs' cannot be empty.")): + MultiModelTrainer(model_configs=[]) + + @patch("hyperbench.train.trainer.L.Trainer") @patch("hyperbench.train.trainer.CSVLogger") @patch("hyperbench.train.trainer.MarkdownTableLogger") @@ -152,14 +171,6 @@ def test_models_property_returns_models( assert len(models) == len(mock_model_configs) -@patch("hyperbench.train.trainer.L.Trainer") -def test_models_property_returns_empty_when_no_models(mock_trainer_cls): - multi_model_trainer = MultiModelTrainer([]) - models = multi_model_trainer.models - - assert len(models) == 0 - - @patch("hyperbench.train.trainer.L.Trainer") @patch("hyperbench.train.trainer.CSVLogger") @patch("hyperbench.train.trainer.MarkdownTableLogger") @@ -273,22 +284,6 @@ def test_fit_all_calls_fit( config.trainer.fit.assert_called_once() -@patch("hyperbench.train.trainer.L.Trainer") -@patch("hyperbench.train.trainer.CSVLogger") -@patch("hyperbench.train.trainer.MarkdownTableLogger") -@patch("hyperbench.train.trainer.LaTexTableLogger") -def test_fit_all_with_no_models( - mock_latex_logger_cls, - mock_md_logger_cls, - mock_csv_logger_cls, - mock_trainer_cls, -): - multi_model_trainer = MultiModelTrainer([]) - - with 
pytest.raises(ValueError, match=re.escape("No models to fit.")): - multi_model_trainer.fit_all(verbose=False) - - @patch("hyperbench.train.trainer.L.Trainer", return_value=None) @patch("hyperbench.train.trainer.CSVLogger") @patch("hyperbench.train.trainer.MarkdownTableLogger") @@ -424,22 +419,6 @@ def test_test_all_calls_test_and_returns_results( config.trainer.test.assert_called_once() -@patch("hyperbench.train.trainer.L.Trainer") -@patch("hyperbench.train.trainer.CSVLogger") -@patch("hyperbench.train.trainer.MarkdownTableLogger") -@patch("hyperbench.train.trainer.LaTexTableLogger") -def test_test_all_with_no_models( - mock_latex_logger_cls, - mock_md_logger_cls, - mock_csv_logger_cls, - mock_trainer_cls, -): - multi_model_trainer = MultiModelTrainer([]) - - with pytest.raises(ValueError, match=re.escape("No models to test.")): - multi_model_trainer.test_all(verbose=False) - - @patch("hyperbench.train.trainer.L.Trainer", return_value=None) @patch("hyperbench.train.trainer.CSVLogger") @patch("hyperbench.train.trainer.MarkdownTableLogger") diff --git a/hyperbench/tests/types/graph_test.py b/hyperbench/tests/types/graph_test.py index 7d9e9b7a..8463eafa 100644 --- a/hyperbench/tests/types/graph_test.py +++ b/hyperbench/tests/types/graph_test.py @@ -1,5 +1,7 @@ import re +from collections.abc import Callable + import pytest import torch @@ -696,6 +698,71 @@ def test_get_sparse_adjacency_matrix_empty_edge_index(): assert torch.all(dense_adj_matrix == 0) +@pytest.mark.parametrize( + "call_with_num_nodes", + [ + pytest.param( + lambda edge_index, num_nodes: edge_index.add_selfloops(num_nodes=num_nodes), + id="add_selfloops", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.get_sparse_adjacency_matrix( + num_nodes=num_nodes + ), + id="get_sparse_adjacency_matrix", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.get_sparse_identity_matrix( + num_nodes=num_nodes + ), + id="get_sparse_identity_matrix", + ), + pytest.param( + lambda edge_index, 
num_nodes: edge_index.get_sparse_normalized_degree_matrix( + num_nodes=num_nodes + ), + id="get_sparse_normalized_degree_matrix", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.get_sparse_normalized_laplacian( + num_nodes=num_nodes + ), + id="get_sparse_normalized_laplacian", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.get_sparse_normalized_gcn_laplacian( + num_nodes=num_nodes + ), + id="get_sparse_normalized_gcn_laplacian", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.remove_duplicate_edges(num_nodes=num_nodes), + id="remove_duplicate_edges", + ), + pytest.param( + lambda edge_index, num_nodes: edge_index.to_undirected(num_nodes=num_nodes), + id="to_undirected", + ), + ], +) +@pytest.mark.parametrize( + ("num_nodes", "expected_message"), + [ + pytest.param(-1, "'num_nodes' must be non-negative", id="negative"), + pytest.param(1, "'num_nodes' is too small for the edge index", id="too_small"), + ], +) +def test_edge_index_methods_reject_invalid_num_nodes( + call_with_num_nodes: Callable[[EdgeIndex, int], object], + num_nodes: int, + expected_message: str, +): + edge_index = EdgeIndex(torch.tensor([[0, 1], [1, 2]], dtype=torch.long)) + + with pytest.raises(ValueError, match=expected_message): + call_with_num_nodes(edge_index, num_nodes) + + @pytest.mark.parametrize( "edge_index, num_nodes, expected_entries", [ diff --git a/hyperbench/tests/types/hdata_test.py b/hyperbench/tests/types/hdata_test.py index a519e3ae..dd8d5ca2 100644 --- a/hyperbench/tests/types/hdata_test.py +++ b/hyperbench/tests/types/hdata_test.py @@ -336,7 +336,7 @@ def test_init_validates_input_values(kwargs, expected_message): "hyperedge_index": torch.empty((2, 0), dtype=torch.long), "num_nodes": -1, }, - "num_nodes must be non-negative, got -1.", + "'num_nodes' must be non-negative, got -1.", id="negative_num_nodes", ), pytest.param( @@ -345,7 +345,7 @@ def test_init_validates_input_values(kwargs, expected_message): "hyperedge_index": 
torch.empty((2, 0), dtype=torch.long), "num_hyperedges": -1, }, - "num_hyperedges must be non-negative, got -1.", + "'num_hyperedges' must be non-negative, got -1.", id="negative_num_hyperedges", ), ], @@ -567,7 +567,7 @@ def test_hdata_to_mps_handles_none_hyperedge_attr(mock_hdata): def test_cat_same_node_space_raises_on_empty_list(): - with pytest.raises(ValueError, match=re.escape("At least one instance is required.")): + with pytest.raises(ValueError, match=re.escape("'hdatas' cannot be empty.")): HData.cat_same_node_space([]) @@ -1965,6 +1965,15 @@ def test_remove_hyperedges_with_fewer_than_k_nodes_keeps_none_hyperedge_attr(): assert result.hyperedge_attr is None +def test_remove_hyperedges_with_fewer_than_k_nodes_rejects_invalid_k(): + x = torch.randn(2, 1) + hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) + hdata = HData(x=x, hyperedge_index=hyperedge_index) + + with pytest.raises(ValueError, match="'k' must be positive"): + hdata.remove_hyperedges_with_fewer_than_k_nodes(k=0) + + def test_remove_hyperedges_with_fewer_than_k_nodes_handles_none_global_node_ids(): x = torch.randn(5, 2) hyperedge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 1, 1]]) diff --git a/hyperbench/tests/types/hypergraph_test.py b/hyperbench/tests/types/hypergraph_test.py index d63b24a8..be8b7725 100644 --- a/hyperbench/tests/types/hypergraph_test.py +++ b/hyperbench/tests/types/hypergraph_test.py @@ -1,8 +1,7 @@ -import json -import re - import pytest import torch +import json +import re from typing import Any, cast from unittest.mock import patch @@ -669,6 +668,13 @@ def test_get_clique_expansion_adjacency(hyperedge_index_tensor, num_nodes, expec assert result == expected_adjacency +def test_get_clique_expansion_adjacency_rejects_num_nodes_too_small(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 2], [0, 0]], dtype=torch.long)) + + with pytest.raises(ValueError, match="'num_nodes' is too small for the hyperedge index"): + 
hyperedge_index.get_clique_expansion_adjacency_list(num_nodes=2) + + @pytest.mark.parametrize( "x, hyperedge_index, with_mediators, expected_num_edges", [ @@ -1037,6 +1043,13 @@ def test_remove_hyperedges_with_fewer_than_k_nodes_returns_self(): assert result is hyperedge_index +def test_remove_hyperedges_with_fewer_than_k_nodes_rejects_invalid_k(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1], [0, 0]], dtype=torch.long)) + + with pytest.raises(ValueError, match="'k' must be positive"): + hyperedge_index.remove_hyperedges_with_fewer_than_k_nodes(0) + + @pytest.mark.parametrize( "dropout", [ @@ -1294,18 +1307,42 @@ def test_get_sparse_incidence_matrix_sums_duplicate_incidences_when_coalesced(): assert torch.allclose(incidence_matrix.to_dense(), expected_incidence_matrix, atol=1e-6) -def test_get_sparse_incidence_matrix_rejects_explicit_num_nodes_too_small(): - hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1, 2], [0, 0, 0]], dtype=torch.long)) +@pytest.mark.parametrize( + ("kwargs", "expected_message"), + [ + pytest.param({"num_nodes": 3}, "'num_nodes' is too small", id="nodes"), + pytest.param( + {"num_hyperedges": 1}, + "'num_hyperedges' is too small", + id="hyperedges", + ), + ], +) +def test_get_sparse_incidence_matrix_rejects_explicit_dimensions_too_small( + kwargs, expected_message +): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], dtype=torch.long)) - with pytest.raises(ValueError, match="num_nodes is too small"): - hyperedge_index.get_sparse_incidence_matrix(num_nodes=2) + with pytest.raises(ValueError, match=expected_message): + hyperedge_index.get_sparse_incidence_matrix(**kwargs) -def test_get_sparse_incidence_matrix_rejects_explicit_num_hyperedges_too_small(): - hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], dtype=torch.long)) +@pytest.mark.parametrize( + ("kwargs", "expected_message"), + [ + pytest.param({"num_nodes": -1}, "'num_nodes' must be non-negative", id="nodes"), + 
pytest.param( + {"num_hyperedges": -1}, + "'num_hyperedges' must be non-negative", + id="hyperedges", + ), + ], +) +def test_get_sparse_incidence_matrix_rejects_negative_dimensions(kwargs, expected_message): + hyperedge_index = HyperedgeIndex(torch.zeros((2, 0), dtype=torch.long)) - with pytest.raises(ValueError, match="num_hyperedges is too small"): - hyperedge_index.get_sparse_incidence_matrix(num_hyperedges=1) + with pytest.raises(ValueError, match=expected_message): + hyperedge_index.get_sparse_incidence_matrix(**kwargs) def test_get_sparse_symnormalized_node_degree_matrix_is_expected_diagonal(): @@ -1384,6 +1421,30 @@ def test_get_sparse_normalized_node_degree_matrix_zeroes_isolated_nodes_for_nega assert torch.allclose(node_degree_matrix.to_dense(), torch.diag(expected_diagonal), atol=1e-6) +def test_get_sparse_normalized_node_degree_matrix_rejects_mismatched_num_nodes(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1], [0, 0]], dtype=torch.long)) + incidence_matrix = hyperedge_index.get_sparse_incidence_matrix(num_nodes=3) + + with pytest.raises(ValueError, match="'num_nodes' must match the incidence matrix dimension"): + hyperedge_index.get_sparse_normalized_node_degree_matrix( + incidence_matrix, + power=-1, + num_nodes=2, + ) + + +def test_get_sparse_normalized_node_degree_matrix_rejects_negative_num_nodes(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1], [0, 0]], dtype=torch.long)) + incidence_matrix = hyperedge_index.get_sparse_incidence_matrix() + + with pytest.raises(ValueError, match="'num_nodes' must be non-negative"): + hyperedge_index.get_sparse_normalized_node_degree_matrix( + incidence_matrix, + power=-1, + num_nodes=-1, + ) + + def test_get_sparse_rownormalized_node_degree_matrix_is_expected_diagonal(): hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1, 0, 2], [0, 0, 1, 1]])) incidence_matrix = hyperedge_index.get_sparse_incidence_matrix() # shape (3,2) @@ -1438,3 +1499,28 @@ def 
test_get_sparse_normalized_hyperedge_degree_matrix_infers_num_hyperedges(): ) assert hyperedge_degree_matrix.shape == (2, 2) + + +def test_get_sparse_normalized_hyperedge_degree_matrix_rejects_mismatched_num_hyperedges(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1], [0, 0]], dtype=torch.long)) + incidence_matrix = hyperedge_index.get_sparse_incidence_matrix(num_hyperedges=2) + + with pytest.raises( + ValueError, + match="'num_hyperedges' must match the incidence matrix dimension", + ): + hyperedge_index.get_sparse_normalized_hyperedge_degree_matrix( + incidence_matrix, + num_hyperedges=1, + ) + + +def test_get_sparse_normalized_hyperedge_degree_matrix_rejects_negative_num_hyperedges(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1], [0, 0]], dtype=torch.long)) + incidence_matrix = hyperedge_index.get_sparse_incidence_matrix() + + with pytest.raises(ValueError, match="'num_hyperedges' must be non-negative"): + hyperedge_index.get_sparse_normalized_hyperedge_degree_matrix( + incidence_matrix, + num_hyperedges=-1, + ) diff --git a/hyperbench/train/__init__.py b/hyperbench/train/__init__.py index 390b13d9..30f3a973 100644 --- a/hyperbench/train/__init__.py +++ b/hyperbench/train/__init__.py @@ -1,6 +1,6 @@ import logging -from .latex_logger import LaTexTableLogger +from .latex_logger import LaTexTableConfig, LaTexTableLogger, colorize_metric_value from .markdown_logger import MarkdownTableLogger @@ -9,7 +9,9 @@ logging.getLogger("lightning.pytorch").setLevel(logging.ERROR) __all__ = [ + "LaTexTableConfig", "LaTexTableLogger", "MarkdownTableLogger", "MultiModelTrainer", + "colorize_metric_value", ] diff --git a/hyperbench/train/latex_logger.py b/hyperbench/train/latex_logger.py index b2999aa9..13582d53 100644 --- a/hyperbench/train/latex_logger.py +++ b/hyperbench/train/latex_logger.py @@ -2,6 +2,7 @@ from typing import Any, ClassVar, TypedDict from collections.abc import Mapping from typing_extensions import NotRequired +from hyperbench.utils 
import validate_is_non_negative from lightning.pytorch.loggers import Logger @@ -35,6 +36,10 @@ def colorize_metric_value( if bounds is None: return text + normalized_sort_order = sort_order.lower() + if normalized_sort_order not in ("asc", "des"): + raise ValueError(f"'sort_order' must be 'asc' or 'des', got {sort_order!r}.") + min_metric_value, max_metric_value = bounds if max_metric_value == min_metric_value: quality = 1.0 @@ -44,7 +49,7 @@ def colorize_metric_value( ) # 0..1, low->high quality = ( (1.0 - normalized_metric_value) - if sort_order.lower() == "asc" + if normalized_sort_order == "asc" else normalized_metric_value ) @@ -107,6 +112,8 @@ def __init__( options: LaTexTableConfig | None = None, ) -> None: super().__init__() + validate_is_non_negative("precision", precision) + self.__save_dir = save_dir self.__model_name = model_name self.__experiment_name = experiment_name diff --git a/hyperbench/train/markdown_logger.py b/hyperbench/train/markdown_logger.py index ed699463..3e5e0565 100644 --- a/hyperbench/train/markdown_logger.py +++ b/hyperbench/train/markdown_logger.py @@ -4,6 +4,7 @@ from typing import Any, ClassVar from lightning.pytorch.loggers import Logger from collections.abc import Mapping +from hyperbench.utils import validate_is_non_negative class MarkdownTableLogger(Logger): @@ -35,6 +36,8 @@ def __init__( precision: int = 4, ) -> None: super().__init__() + validate_is_non_negative("precision", precision) + self.__save_dir = save_dir self.__model_name = model_name self.__experiment_name = experiment_name @@ -161,7 +164,8 @@ def __build_comparison_table( results: Mapping[str, Mapping[str, float]], precision: int = 4, ) -> str: - """Build a markdown comparison table from model results. + """ + Build a markdown comparison table from model results. 
Examples: Input: diff --git a/hyperbench/train/trainer.py b/hyperbench/train/trainer.py index 95987cb7..5a5031dd 100644 --- a/hyperbench/train/trainer.py +++ b/hyperbench/train/trainer.py @@ -9,8 +9,7 @@ from pathlib import Path from typing import Any -from collections.abc import Mapping -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from lightning.pytorch.accelerators import Accelerator from lightning.pytorch.callbacks import Callback from lightning.pytorch.loggers import CSVLogger, Logger @@ -18,6 +17,7 @@ from lightning.pytorch.strategies import Strategy from hyperbench.data import DataLoader from hyperbench.types import CkptStrategy, ModelConfig, TestResult +from hyperbench.utils import validate_is_non_empty, validate_is_non_negative from hyperbench.train.markdown_logger import MarkdownTableLogger from hyperbench.train.latex_logger import LaTexTableLogger @@ -165,13 +165,17 @@ def __init__( auto_wait: bool = False, **kwargs, ) -> None: + self.auto_wait = auto_wait + self.__tensorboard_process: subprocess.Popen | None = None + validate_is_non_negative("tensorboard_port", tensorboard_port) + self.model_configs = model_configs + validate_is_non_empty("model_configs", self.model_configs) + self.log_dir = self.__setup_logdir(default_root_dir, experiment_name) self.auto_start_tensorboard = auto_start_tensorboard - self.auto_wait = auto_wait self.tensorboard_port = tensorboard_port - self.__tensorboard_process: subprocess.Popen | None = None for model_config in model_configs: if model_config.trainer is None: @@ -238,9 +242,6 @@ def fit_all( ckpt_path: CkptStrategy | None = None, verbose: bool = True, ) -> None: - if len(self.model_configs) < 1: - raise ValueError("No models to fit.") - for i, config in enumerate(self.model_configs): if not config.is_trainable: if verbose: @@ -281,11 +282,7 @@ def test_all( verbose: bool = True, verbose_loop: bool = True, ) -> Mapping[str, TestResult]: - if len(self.model_configs) < 1: - raise 
ValueError("No models to test.") - test_results: dict[str, TestResult] = {} - for i, config in enumerate(self.model_configs): if config.trainer is None: raise ValueError(f"Trainer not defined for model {config.full_model_name()}.") diff --git a/hyperbench/types/graph.py b/hyperbench/types/graph.py index def2f277..8d8b2568 100644 --- a/hyperbench/types/graph.py +++ b/hyperbench/types/graph.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from hyperbench import utils +from hyperbench.utils import validate_is_non_negative, sparse_dropout class Graph: @@ -124,7 +124,7 @@ def smoothing_with_laplacian_matrix( x: The smoothed feature matrix. Size ``(num_nodes, C)``. """ if drop_rate > 0.0: - laplacian_matrix = utils.sparse_dropout(laplacian_matrix, drop_rate) + laplacian_matrix = sparse_dropout(laplacian_matrix, drop_rate) return laplacian_matrix.matmul(x) @@ -227,6 +227,7 @@ def add_selfloops( """ if self.__edge_index.size(1) < 1: raise ValueError("Edge index must have at least one edge to add self-loops.") + self.__validate_num_nodes(num_nodes) device = self.__edge_index.device src, dest = self.__edge_index[0], self.__edge_index[1] @@ -302,6 +303,7 @@ def get_sparse_adjacency_matrix( Returns: adjacency: The sparse adjacency matrix of shape ``(num_nodes, num_nodes)``. """ + self.__validate_num_nodes(num_nodes) device = self.__edge_index.device src, dest = self.__edge_index num_nodes = self.num_nodes if num_nodes is None else num_nodes @@ -350,6 +352,7 @@ def get_sparse_identity_matrix(self, num_nodes: int | None = None) -> Tensor: Returns: identity: The sparse identity matrix I of shape ``(num_nodes, num_nodes)``. """ + self.__validate_num_nodes(num_nodes) device = self.__edge_index.device num_nodes = self.num_nodes if num_nodes is None else num_nodes @@ -387,6 +390,7 @@ def get_sparse_normalized_degree_matrix( Returns: degree_matrix: The sparse normalized degree matrix D^-1/2 of shape ``(num_nodes, num_nodes)``. 
""" + self.__validate_num_nodes(num_nodes) device = self.__edge_index.device num_nodes = self.num_nodes if num_nodes is None else num_nodes @@ -440,6 +444,7 @@ def get_sparse_normalized_laplacian( Returns: laplacian: The sparse symmetric normalized Laplacian matrix of shape ``(num_nodes, num_nodes)``. """ + self.__validate_num_nodes(num_nodes) self.to_undirected(with_selfloops=False) num_nodes = self.num_nodes if num_nodes is None else num_nodes @@ -479,6 +484,7 @@ def get_sparse_normalized_gcn_laplacian( Returns: laplacian: The sparse symmetrically normalized Laplacian matrix of shape ``(num_nodes, num_nodes)``. """ + self.__validate_num_nodes(num_nodes) self.to_undirected(with_selfloops=True, num_nodes=num_nodes) num_nodes = self.num_nodes if num_nodes is None else num_nodes @@ -522,6 +528,7 @@ def remove_duplicate_edges(self, num_nodes: int | None = None) -> EdgeIndex: Returns: edge_index: This `EdgeIndex` instance with duplicate edges removed. """ + self.__validate_num_nodes(num_nodes) # Example: edge_index = [[0, 1, 2, 2, 0, 3, 2], # [1, 0, 3, 2, 1, 2, 2]], shape (2, |E| = 7) # -> after torch.unique(..., dim=1): @@ -576,6 +583,7 @@ def to_undirected( Returns: edge_index: This `EdgeIndex` instance converted to undirected. """ + self.__validate_num_nodes(num_nodes) device = self.__edge_index.device num_nodes = self.num_nodes if num_nodes is None else num_nodes @@ -659,3 +667,18 @@ def __validate_edge_weights(self, edge_weights: Tensor | None) -> None: "edge_weights must have the same number of entries as edge_index columns. " f"Got {edge_weights.size(0)} edge weights but {self.__edge_index.size(1)} edge columns." ) + + def __validate_num_nodes(self, num_nodes: int | None) -> None: + if num_nodes is None: + return + + validate_is_non_negative("num_nodes", num_nodes) + + if self.num_edges < 1: + return + + if self.max_node_id >= num_nodes: + raise ValueError( + "'num_nodes' is too small for the edge index. 
" + f"Got num_nodes={num_nodes}, but max node id is {self.max_node_id}." + ) diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index 92d28254..dd4e31d2 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -14,6 +14,9 @@ is_inductive_setting, is_transductive_setting, to_0based_ids, + validate_is_non_empty, + validate_is_non_negative, + validate_is_positive, validate_node_space_setting, ) @@ -78,10 +81,12 @@ def __init__( # as each isolated node gets its own self-loop hyperedge else hyperedge_index_wrapper.num_nodes_if_isolated_exist(num_nodes=x.size(0)) ) + validate_is_non_negative("num_nodes", self.num_nodes) + self.num_hyperedges: int = ( num_hyperedges if num_hyperedges is not None else hyperedge_index_wrapper.num_hyperedges ) - self.__validate_number_of_nodes_and_hyperedges() + validate_is_non_negative("num_hyperedges", self.num_hyperedges) self.global_node_ids = ( # torch.arange is to handle isolated nodes, as they are already considered @@ -166,17 +171,8 @@ def cat_same_node_space( ValueError: If no HData instances are provided, if there are overlapping hyperedge IDs across instances, or if ``x`` and ``global_node_ids`` are not both provided when one of them is provided. """ - if len(hdatas) < 1: - raise ValueError("At least one instance is required.") - - joint_hyperedge_ids = torch.cat([hdata.hyperedge_index[1].unique() for hdata in hdatas]) - unique_joint_hyperedge_ids = joint_hyperedge_ids.unique() - if unique_joint_hyperedge_ids.size(0) != joint_hyperedge_ids.size(0): - raise ValueError( - "Overlapping hyperedge IDs found across instances. Ensure each instance uses distinct hyperedge IDs." 
- ) - cls.__validate_can_perform_cat_same_node_space(hdatas, x, global_node_ids) + hdata_with_largest_node_space = max(hdatas, key=lambda hdata: hdata.num_nodes) new_x = (x.clone() if x is not None else hdata_with_largest_node_space.x).clone() new_global_node_ids = ( @@ -644,6 +640,8 @@ def get_device_if_all_consistent(self) -> torch.device: return devices.pop() if len(devices) == 1 else torch.device("cpu") def remove_hyperedges_with_fewer_than_k_nodes(self, k: int) -> HData: + validate_is_positive("k", k) + hyperedge_index_wrapper = HyperedgeIndex( self.hyperedge_index.clone() ).remove_hyperedges_with_fewer_than_k_nodes(k) @@ -939,8 +937,7 @@ def __validate_can_perform_cat_same_node_space( x: Tensor | None, global_node_ids: Tensor | None, ) -> None: - if len(hdatas) < 1: - raise ValueError("At least one instance is required.") + validate_is_non_empty("hdatas", hdatas) if x is not None and global_node_ids is None: raise ValueError( @@ -1115,13 +1112,6 @@ def __validate_node_space_setting_value(node_space_setting: NodeSpaceSetting) -> f"got {node_space_setting!r}." 
) - def __validate_number_of_nodes_and_hyperedges(self) -> None: - # Check on bool as bool is a subclass of int - if self.num_nodes < 0: - raise ValueError(f"num_nodes must be non-negative, got {self.num_nodes}.") - if self.num_hyperedges < 0: - raise ValueError(f"num_hyperedges must be non-negative, got {self.num_hyperedges}.") - def __validate_x_and_hyperedge_index_type_and_dim(self) -> None: if self.x.dim() != 2: raise ValueError(f"x must be a 2D tensor, got shape {tuple(self.x.shape)}.") diff --git a/hyperbench/types/hypergraph.py b/hyperbench/types/hypergraph.py index 6706ba0a..8ad70059 100644 --- a/hyperbench/types/hypergraph.py +++ b/hyperbench/types/hypergraph.py @@ -5,7 +5,13 @@ from itertools import combinations from torch import Tensor from typing import Any, Literal, TypeAlias -from hyperbench.utils import sparse_dropout, to_0based_ids, create_seeded_torch_generator +from hyperbench.utils import ( + create_seeded_torch_generator, + sparse_dropout, + to_0based_ids, + validate_is_non_negative, + validate_is_positive, +) from hyperbench.types.graph import EdgeIndex, Graph @@ -474,6 +480,8 @@ def get_clique_expansion_adjacency_list(self, num_nodes: int | None = None) -> l adjacency: A list where ``adjacency[node_id]`` is the set of nodes adjacent to ``node_id``. """ num_nodes = num_nodes if num_nodes is not None else self.num_nodes + self.__validate_num_nodes(num_nodes) + adjacency_list: list[set[int]] = [set() for _ in range(num_nodes)] for hyperedge_id in self.hyperedge_ids.tolist(): @@ -513,23 +521,12 @@ def get_sparse_incidence_matrix( Raises: ValueError: If the provided dimensions cannot contain the raw node or hyperedge IDs. 
""" - device = self.__hyperedge_index.device num_nodes = num_nodes if num_nodes is not None else self.num_nodes num_hyperedges = num_hyperedges if num_hyperedges is not None else self.num_hyperedges - if self.num_incidences > 0: - max_node_id = int(self.all_node_ids.max().item()) - if max_node_id >= num_nodes: - raise ValueError( - "num_nodes is too small for the hyperedge index. " - f"Got num_nodes={num_nodes}, but max node id is {max_node_id}." - ) - max_hyperedge_id = int(self.all_hyperedge_ids.max().item()) - if max_hyperedge_id >= num_hyperedges: - raise ValueError( - "num_hyperedges is too small for the hyperedge index. " - f"Got num_hyperedges={num_hyperedges}, but max hyperedge id is {max_hyperedge_id}." - ) + self.__validate_num_nodes(num_nodes) + self.__validate_num_hyperedges(num_hyperedges) + device = self.__hyperedge_index.device incidence_values = torch.ones(self.num_incidences, dtype=torch.float, device=device) incidence_indices = torch.stack([self.all_node_ids, self.all_hyperedge_ids], dim=0) incidence_matrix = torch.sparse_coo_tensor( @@ -557,7 +554,13 @@ def get_sparse_normalized_node_degree_matrix( degree_matrix: The sparse diagonal matrix of shape ``(num_nodes, num_nodes)``. """ device = self.__hyperedge_index.device - num_nodes = num_nodes if num_nodes is not None else self.num_nodes + num_nodes = num_nodes if num_nodes is not None else int(incidence_matrix.size(0)) + self.__validate_num_nodes(num_nodes) + self.__validate_degree_matrix_dimension( + name="num_nodes", + value=num_nodes, + expected=int(incidence_matrix.size(0)), + ) degrees = torch.sparse.sum(incidence_matrix, dim=1).to_dense() normalized_degrees = degrees.pow(power) @@ -654,7 +657,15 @@ def get_sparse_normalized_hyperedge_degree_matrix( degree_matrix: The sparse diagonal matrix D_e^-1 of shape ``(num_hyperedges, num_hyperedges)``. 
""" device = self.__hyperedge_index.device - num_hyperedges = num_hyperedges if num_hyperedges is not None else self.num_hyperedges + num_hyperedges = ( + num_hyperedges if num_hyperedges is not None else int(incidence_matrix.size(1)) + ) + self.__validate_num_hyperedges(num_hyperedges) + self.__validate_degree_matrix_dimension( + name="num_hyperedges", + value=num_hyperedges, + expected=int(incidence_matrix.size(1)), + ) # Example: hyperedge_index = [[0, 1, 2, 0], # [0, 0, 0, 1]] @@ -709,6 +720,8 @@ def get_sparse_hgnn_smoothing_matrix( """ num_nodes = num_nodes if num_nodes is not None else self.num_nodes num_hyperedges = num_hyperedges if num_hyperedges is not None else self.num_hyperedges + self.__validate_num_nodes(num_nodes) + self.__validate_num_hyperedges(num_hyperedges) incidence_matrix = self.get_sparse_incidence_matrix(num_nodes, num_hyperedges) node_degree_matrix = self.get_sparse_symnormalized_node_degree_matrix( @@ -755,6 +768,8 @@ def get_sparse_hgnnp_smoothing_matrix( """ num_nodes = num_nodes if num_nodes is not None else self.num_nodes num_hyperedges = num_hyperedges if num_hyperedges is not None else self.num_hyperedges + self.__validate_num_nodes(num_nodes) + self.__validate_num_hyperedges(num_hyperedges) incidence_matrix = self.get_sparse_incidence_matrix(num_nodes, num_hyperedges) node_degree_matrix = self.get_sparse_rownormalized_node_degree_matrix( @@ -815,6 +830,9 @@ def reduce_to_edge_index_on_clique_expansion( Returns: edge_index: The edge index of the clique-expanded graph. Size ``(2, |E'|)``. """ + self.__validate_num_nodes(num_nodes) + self.__validate_num_hyperedges(num_hyperedges) + incidence_matrix = self.get_sparse_incidence_matrix( num_nodes=num_nodes, num_hyperedges=num_hyperedges, @@ -961,6 +979,8 @@ def remove_hyperedges_with_fewer_than_k_nodes(self, k: int) -> HyperedgeIndex: Returns: hyperedge_index: A new `HyperedgeIndex` instance with hyperedges containing fewer than k nodes. 
""" + validate_is_positive("k", k) + _, idx_to_hyperedge_id, num_nodes_per_hyperedge = torch.unique( self.all_hyperedge_ids, return_inverse=True, @@ -1002,3 +1022,41 @@ def to_0based( self.__hyperedge_index[1] = to_0based_ids(self.all_hyperedge_ids, hyperedge_ids_to_rebase) return self + + def __validate_num_hyperedges(self, num_hyperedges: int | None) -> None: + if num_hyperedges is None: + return + validate_is_non_negative("num_hyperedges", num_hyperedges) + + if self.all_hyperedge_ids.numel() < 1: + return + + max_hyperedge_id = int(self.all_hyperedge_ids.max().item()) + if max_hyperedge_id >= num_hyperedges: + raise ValueError( + f"'num_hyperedges' is too small for the hyperedge index. " + f"Got num_hyperedges={num_hyperedges}, but max hyperedge id is {max_hyperedge_id}." + ) + + def __validate_num_nodes(self, num_nodes: int | None) -> None: + if num_nodes is None: + return + validate_is_non_negative("num_nodes", num_nodes) + + if self.all_node_ids.numel() < 1: + return + + max_node_id = int(self.all_node_ids.max().item()) + if max_node_id >= num_nodes: + raise ValueError( + f"'num_nodes' is too small for the hyperedge index. " + f"Got num_nodes={num_nodes}, but max node id is {max_node_id}." + ) + + def __validate_degree_matrix_dimension(self, name: str, value: int, expected: int) -> None: + validate_is_non_negative(name, value) + if value != expected: + raise ValueError( + f"'{name}' must match the incidence matrix dimension. " + f"Got {name}={value}, but expected {expected}." 
+ ) diff --git a/hyperbench/utils/__init__.py b/hyperbench/utils/__init__.py index 58711c04..ba40dea8 100644 --- a/hyperbench/utils/__init__.py +++ b/hyperbench/utils/__init__.py @@ -5,6 +5,13 @@ empty_nodefeatures, to_non_empty_edgeattr, to_0based_ids, + validate_is_between, + validate_is_finite, + validate_is_finite_when_provided, + validate_is_non_empty, + validate_is_non_negative, + validate_is_positive, + validate_split_ratios, ) from .hif_utils import ( @@ -88,7 +95,14 @@ "validate_hif_data", "validate_hif_json", "validate_http_url", + "validate_is_between", + "validate_is_finite", + "validate_is_finite_when_provided", + "validate_is_non_empty", + "validate_is_non_negative", + "validate_is_positive", "validate_node_space_setting", + "validate_split_ratios", "write_dataset_to_disk_as_zst", "write_zst_file_to_disk", ] diff --git a/hyperbench/utils/data_utils.py b/hyperbench/utils/data_utils.py index e94fefaa..a929453c 100644 --- a/hyperbench/utils/data_utils.py +++ b/hyperbench/utils/data_utils.py @@ -1,5 +1,7 @@ +import math import torch +from collections.abc import Sequence from torch import Tensor @@ -53,3 +55,54 @@ def to_0based_ids(original_ids: Tensor, ids_to_rebase: Tensor | None = None) -> ids_to_keep = original_ids[keep_mask] sorted_unique_ids_to_rebase = ids_to_rebase.unique(sorted=True) return torch.searchsorted(sorted_unique_ids_to_rebase, ids_to_keep) + + +def validate_is_between( + name: str, + value: int | float, + min_value: int | float, + max_value: int | float, +) -> None: + if not math.isfinite(value) or value < min_value or value > max_value: + raise ValueError(f"{name!r} must be between {min_value} and {max_value}, got {value}.") + + +def validate_is_finite(name: str, value: int | float) -> None: + if not math.isfinite(value): + raise ValueError(f"{name!r} must be finite, got {value}.") + + +def validate_is_finite_when_provided(name: str, value: int | float | None) -> None: + if value is not None and not math.isfinite(value): + raise 
ValueError(f"{name!r} must be finite when provided, got {value}.") + + +def validate_is_non_negative(name: str, value: int | float) -> None: + if value < 0: + raise ValueError(f"{name!r} must be non-negative, got {value}.") + + +def validate_is_positive(name: str, value: int | float) -> None: + if value <= 0: + raise ValueError(f"{name!r} must be positive, got {value}.") + + +def validate_is_non_empty(name: str, value: Sequence) -> None: + if len(value) < 1: + raise ValueError(f"{name!r} cannot be empty.") + + +def validate_split_ratios(ratios: list[int | float]) -> None: + validate_is_non_empty("ratios", ratios) + + for ratio in ratios: + validate_is_finite("ratios", ratio) + validate_is_positive("ratios", ratio) + + # Allow small imprecision in sum of ratios, but raise error if it's significant + # Example: ratios = [0.8, 0.1, 0.1] -> sum = 1.0 (valid) + # ratios = [0.8, 0.1, 0.05] -> sum = 0.95 (invalid, raises ValueError) + # ratios = [0.8, 0.1, 0.1, 0.0000001] -> sum = 1.0000001 (valid, allows small imprecision) + ratio_sum = float(sum(ratios)) + if abs(ratio_sum - 1.0) > 1e-6: + raise ValueError(f"'ratios' must sum to 1.0, got {ratio_sum}.") From 3004edb88ed943a3262703a82d83a572bf14266a Mon Sep 17 00:00:00 2001 From: Tiziano Date: Wed, 27 May 2026 14:39:06 +0200 Subject: [PATCH 13/15] fix: add remaining validation checks --- AGENTS.md | 1 + hyperbench/data/dataset.py | 6 ++-- hyperbench/data/splitter.py | 10 +++--- hyperbench/nn/aggregator.py | 1 - hyperbench/tests/types/hypergraph_test.py | 16 ++++++++-- hyperbench/types/graph.py | 39 ++++++++++++++--------- hyperbench/types/hdata.py | 6 ++-- hyperbench/types/hypergraph.py | 26 ++++++++++++--- 8 files changed, 71 insertions(+), 34 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 61b76386..2d0bc161 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -59,6 +59,7 @@ Write `HyperBench` when referring to the project, repository, organization, or p - `make check` - **Typing:** Add and preserve type annotations. 
Run `make typecheck` for changes that touch typed code. - **Imports:** Keep imports at module top level unless a delayed import is necessary. Use `TYPE_CHECKING` guards for type-only or heavyweight imports. +- **Validation:** Use explicit validation functions for argument checks. Avoid `assert` and do not validate types at runtime. - **Runtime checks:** Do not use `assert` for library-facing validation. Raise explicit exceptions instead. - **Public APIs:** Avoid changing public signatures without a clear reason and matching docs/tests updates. - **Scope:** Keep changes narrow. Do not mix behavioral edits with unrelated refactors. diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 33e25f1c..677609f0 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -195,7 +195,8 @@ def enrich_hyperedge_attr( enricher: HyperedgeEnricher, enrichment_mode: EnrichmentMode | None = None, ) -> None: - """Enrich hyperedge features using the provided hyperedge feature enricher. + """ + Enrich hyperedge features using the provided hyperedge feature enricher. Args: enricher: An instance of HyperedgeEnricher to generate structural hyperedge attributes from hypergraph topology. @@ -211,7 +212,8 @@ def enrich_hyperedge_weights( enricher: HyperedgeEnricher, enrichment_mode: EnrichmentMode | None = None, ) -> None: - """Enrich hyperedge weights using the provided hyperedge weight enricher. + """ + Enrich hyperedge weights using the provided hyperedge weight enricher. Args: enricher: An instance of HyperedgeEnricher to generate structural hyperedge weights from hypergraph topology. 
diff --git a/hyperbench/data/splitter.py b/hyperbench/data/splitter.py index 6b453a61..d4386dd1 100644 --- a/hyperbench/data/splitter.py +++ b/hyperbench/data/splitter.py @@ -4,7 +4,11 @@ from typing import cast from torch import Tensor from hyperbench.types import HData -from hyperbench.utils import validate_is_non_empty, validate_split_ratios +from hyperbench.utils import ( + create_seeded_torch_generator, + validate_is_non_empty, + validate_split_ratios, +) class Splitter(ABC): @@ -142,9 +146,7 @@ def get_hyperedge_ids_permutation(self, shuffle: bool | None, seed: int | None) # Shuffle hyperedge IDs if shuffle is requested, otherwise keep original order for deterministic splits if shuffle: - generator = torch.Generator(device=device) - if seed is not None: - generator.manual_seed(seed) + generator = create_seeded_torch_generator(device=device, seed=seed) random_hyperedge_ids_permutation = torch.randperm( n=num_hyperedges, diff --git a/hyperbench/nn/aggregator.py b/hyperbench/nn/aggregator.py index 7bfb6177..8be38081 100644 --- a/hyperbench/nn/aggregator.py +++ b/hyperbench/nn/aggregator.py @@ -1,7 +1,6 @@ from torch import Tensor from typing import Literal from torch_geometric.utils import scatter - from hyperbench.types import HyperedgeIndex from hyperbench.utils import maxmin_scatter diff --git a/hyperbench/tests/types/hypergraph_test.py b/hyperbench/tests/types/hypergraph_test.py index be8b7725..63cc25e5 100644 --- a/hyperbench/tests/types/hypergraph_test.py +++ b/hyperbench/tests/types/hypergraph_test.py @@ -503,6 +503,12 @@ def test_hifhypergraph_stats_returns_correct_statistics(): [2, 3], id="second_of_two_hyperedges", ), + pytest.param( + torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]), + 10, + [], + id="non_existent_hyperedge_id", + ), ], ) def test_hyperedge_index_nodes_in(hyperedge_index_tensor, hyperedge_id, expected_nodes): @@ -510,8 +516,10 @@ def test_hyperedge_index_nodes_in(hyperedge_index_tensor, hyperedge_id, expected assert 
hyperedge_index.nodes_in(hyperedge_id) == expected_nodes -def _edge_index_to_edge_set(edge_index: torch.Tensor) -> set[tuple[int, int]]: - return set(zip(edge_index[0].tolist(), edge_index[1].tolist(), strict=True)) +def test_hyperedge_index_nodes_in_raises_when_hyperedge_id_is_negative(): + hyperedge_index = HyperedgeIndex(torch.tensor([[0, 1, 2], [0, 0, 0]], dtype=torch.long)) + with pytest.raises(ValueError, match=re.escape("'hyperedge_id' must be non-negative, got -1")): + hyperedge_index.nodes_in(-1) @pytest.mark.parametrize( @@ -559,7 +567,9 @@ def test_reduce_with_clique_expansion_returns_expected_edges( assert result.dtype == torch.long assert result.shape[0] == 2 - assert _edge_index_to_edge_set(result) == expected_edges + + edges = set(zip(result[0].tolist(), result[1].tolist(), strict=True)) + assert edges == expected_edges def test_reduce_with_clique_expansion_matches_specialized_reducer(): diff --git a/hyperbench/types/graph.py b/hyperbench/types/graph.py index 8d8b2568..2b6890f4 100644 --- a/hyperbench/types/graph.py +++ b/hyperbench/types/graph.py @@ -3,6 +3,7 @@ import torch from torch import Tensor +from typing import cast from hyperbench.utils import validate_is_non_negative, sparse_dropout @@ -69,7 +70,8 @@ def remove_selfloops(self) -> Graph: # -> edges without self-loops = [[0, 1], # [2, 3]] no_selfloop_mask = edges_tensor[:, 0] != edges_tensor[:, 1] - self.edges = edges_tensor[no_selfloop_mask].tolist() + edges_without_selfloops: list[list[int]] = edges_tensor[no_selfloop_mask].tolist() + self.edges = edges_without_selfloops # Example: edge_weights = [0.5, 1.0, 0.8], no_selfloop_mask = [True, False, True] # -> edge_weights without self-loops = [0.5, 0.8] @@ -303,10 +305,11 @@ def get_sparse_adjacency_matrix( Returns: adjacency: The sparse adjacency matrix of shape ``(num_nodes, num_nodes)``. 
""" + num_nodes = self.num_nodes if num_nodes is None else num_nodes self.__validate_num_nodes(num_nodes) + device = self.__edge_index.device src, dest = self.__edge_index - num_nodes = self.num_nodes if num_nodes is None else num_nodes # Example: edge_index = [[0, 1, 2, 3], # [1, 0, 3, 2]] @@ -352,9 +355,10 @@ def get_sparse_identity_matrix(self, num_nodes: int | None = None) -> Tensor: Returns: identity: The sparse identity matrix I of shape ``(num_nodes, num_nodes)``. """ + num_nodes = self.num_nodes if num_nodes is None else num_nodes self.__validate_num_nodes(num_nodes) + device = self.__edge_index.device - num_nodes = self.num_nodes if num_nodes is None else num_nodes # Example: num_nodes = 3 # -> identity_indices = [[0, 1, 2], @@ -390,10 +394,10 @@ def get_sparse_normalized_degree_matrix( Returns: degree_matrix: The sparse normalized degree matrix D^-1/2 of shape ``(num_nodes, num_nodes)``. """ + num_nodes = self.num_nodes if num_nodes is None else num_nodes self.__validate_num_nodes(num_nodes) - device = self.__edge_index.device - num_nodes = self.num_nodes if num_nodes is None else num_nodes + device = self.__edge_index.device adj_matrix = self.get_sparse_adjacency_matrix( num_nodes=num_nodes, use_edge_weights=use_edge_weights @@ -444,10 +448,10 @@ def get_sparse_normalized_laplacian( Returns: laplacian: The sparse symmetric normalized Laplacian matrix of shape ``(num_nodes, num_nodes)``. """ + num_nodes = self.num_nodes if num_nodes is None else num_nodes self.__validate_num_nodes(num_nodes) - self.to_undirected(with_selfloops=False) - num_nodes = self.num_nodes if num_nodes is None else num_nodes + self.to_undirected(with_selfloops=False) degree_matrix = self.get_sparse_normalized_degree_matrix(num_nodes) adj_matrix = self.get_sparse_adjacency_matrix(num_nodes) @@ -484,10 +488,10 @@ def get_sparse_normalized_gcn_laplacian( Returns: laplacian: The sparse symmetrically normalized Laplacian matrix of shape ``(num_nodes, num_nodes)``. 
""" + num_nodes = self.num_nodes if num_nodes is None else num_nodes self.__validate_num_nodes(num_nodes) - self.to_undirected(with_selfloops=True, num_nodes=num_nodes) - num_nodes = self.num_nodes if num_nodes is None else num_nodes + self.to_undirected(with_selfloops=True, num_nodes=num_nodes) degree_matrix = self.get_sparse_normalized_degree_matrix( num_nodes=num_nodes, use_edge_weights=use_edge_weights @@ -511,7 +515,8 @@ def remove_selfloops(self) -> EdgeIndex: # -> edge_index = [[0, 2, 3], # [1, 3, 2]], shape (2, |E'| = 3) keep_mask = self.__edge_index[0] != self.__edge_index[1] - self.__edge_index = self.__edge_index[:, keep_mask] + edge_index_without_selfloops: Tensor = self.__edge_index[:, keep_mask] + self.__edge_index = edge_index_without_selfloops if self.__edge_weights is not None: self.__edge_weights = self.__edge_weights[keep_mask] return self @@ -528,7 +533,9 @@ def remove_duplicate_edges(self, num_nodes: int | None = None) -> EdgeIndex: Returns: edge_index: This `EdgeIndex` instance with duplicate edges removed. """ + num_nodes = self.num_nodes if num_nodes is None else num_nodes self.__validate_num_nodes(num_nodes) + # Example: edge_index = [[0, 1, 2, 2, 0, 3, 2], # [1, 0, 3, 2, 1, 2, 2]], shape (2, |E| = 7) # -> after torch.unique(..., dim=1): @@ -537,7 +544,10 @@ def remove_duplicate_edges(self, num_nodes: int | None = None) -> EdgeIndex: # Note: we call contiguous() to ensure that the resulting tensor is contiguous in memory, # which can improve performance for subsequent operations that require contiguous tensors. 
if self.__edge_weights is None: - self.__edge_index = torch.unique(self.__edge_index, dim=1).contiguous() + edge_index_without_duplicate_edges = cast( + Tensor, torch.unique(self.__edge_index, dim=1) + ) + self.__edge_index = edge_index_without_duplicate_edges.contiguous() return self # No edges to process, just ensure tensors are contiguous @@ -556,7 +566,6 @@ def remove_duplicate_edges(self, num_nodes: int | None = None) -> EdgeIndex: # -> edge_index = [[0, 1], # [1, 2]] # -> edge_weights = [3.0, 3.0] (weights of duplicate edges are summed) - num_nodes = self.num_nodes if num_nodes is None else num_nodes coalesced = torch.sparse_coo_tensor( self.__edge_index, self.__edge_weights, @@ -583,10 +592,10 @@ def to_undirected( Returns: edge_index: This `EdgeIndex` instance converted to undirected. """ - self.__validate_num_nodes(num_nodes) - device = self.__edge_index.device num_nodes = self.num_nodes if num_nodes is None else num_nodes + self.__validate_num_nodes(num_nodes) + device = self.__edge_index.device orig_src, orig_dest = self.__edge_index[0], self.__edge_index[1] # Encode each directed edge (u, v) as a unique scalar key u * num_nodes + v. @@ -649,7 +658,7 @@ def to_undirected( # In this way, we don't do the duplicate edge removal twice, which would be redundant and inefficient self.add_selfloops(num_nodes=num_nodes, with_duplicate_removal=False) - self.remove_duplicate_edges() + self.remove_duplicate_edges(num_nodes=num_nodes) return self diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index dd4e31d2..f09373ce 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -9,6 +9,7 @@ NodeSpaceFiller, NodeSpaceSetting, clone_optional_tensor, + create_seeded_torch_generator, empty_hyperedgeindex, empty_nodefeatures, is_inductive_setting, @@ -694,10 +695,7 @@ def shuffle(self, seed: int | None = None) -> HData: Returns: hdata: A new `HData` instance with hyperedge IDs, ``y``, and ``hyperedge_attr`` permuted. 
""" - generator = torch.Generator(device=self.device) - if seed is not None: - generator.manual_seed(seed) - + generator = create_seeded_torch_generator(device=self.device, seed=seed) permutation = torch.randperm(self.num_hyperedges, generator=generator, device=self.device) # permutation[new_id] = old_id, so y[permutation] puts old labels into new slots diff --git a/hyperbench/types/hypergraph.py b/hyperbench/types/hypergraph.py index 8ad70059..64c938ec 100644 --- a/hyperbench/types/hypergraph.py +++ b/hyperbench/types/hypergraph.py @@ -4,7 +4,7 @@ from itertools import combinations from torch import Tensor -from typing import Any, Literal, TypeAlias +from typing import Any, Literal, TypeAlias, cast from hyperbench.utils import ( create_seeded_torch_generator, sparse_dropout, @@ -230,6 +230,8 @@ def neighbors_of(self, node: int) -> Neighborhood: Returns: neighbors: A set of neighbor node IDs (excluding the node itself). """ + validate_is_non_negative("node", node) + neighbors: Neighborhood = set() for hyperedge in self.hyperedges: if node in hyperedge: @@ -450,7 +452,16 @@ def num_incidences(self) -> int: return self.__hyperedge_index.size(1) def nodes_in(self, hyperedge_id: int) -> list[int]: - """Return the list of node IDs that belong to the given hyperedge.""" + """ + Return the list of node IDs that belong to the given hyperedge. + + Args: + hyperedge_id: The ID of the hyperedge to query. + + Returns: + node_ids: A list of node IDs that belong to the specified hyperedge. + """ + validate_is_non_negative("hyperedge_id", hyperedge_id) return self.__hyperedge_index[0, self.__hyperedge_index[1] == hyperedge_id].tolist() def num_nodes_if_isolated_exist(self, num_nodes: int) -> int: @@ -553,7 +564,6 @@ def get_sparse_normalized_node_degree_matrix( Returns: degree_matrix: The sparse diagonal matrix of shape ``(num_nodes, num_nodes)``. 
""" - device = self.__hyperedge_index.device num_nodes = num_nodes if num_nodes is not None else int(incidence_matrix.size(0)) self.__validate_num_nodes(num_nodes) self.__validate_degree_matrix_dimension( @@ -562,6 +572,8 @@ def get_sparse_normalized_node_degree_matrix( expected=int(incidence_matrix.size(0)), ) + device = self.__hyperedge_index.device + degrees = torch.sparse.sum(incidence_matrix, dim=1).to_dense() normalized_degrees = degrees.pow(power) normalized_degrees[normalized_degrees == float("inf")] = 0 @@ -656,7 +668,6 @@ def get_sparse_normalized_hyperedge_degree_matrix( Returns: degree_matrix: The sparse diagonal matrix D_e^-1 of shape ``(num_hyperedges, num_hyperedges)``. """ - device = self.__hyperedge_index.device num_hyperedges = ( num_hyperedges if num_hyperedges is not None else int(incidence_matrix.size(1)) ) @@ -667,6 +678,8 @@ def get_sparse_normalized_hyperedge_degree_matrix( expected=int(incidence_matrix.size(1)), ) + device = self.__hyperedge_index.device + # Example: hyperedge_index = [[0, 1, 2, 0], # [0, 0, 0, 1]] # hyperedges 0 1 @@ -948,7 +961,10 @@ def remove_duplicate_edges(self) -> HyperedgeIndex: # Note: we need to call contiguous() after torch.unique() to ensure # the resulting tensor is contiguous in memory, which is important for efficient indexing # and further operations (e.g., searchsorted) - self.__hyperedge_index = torch.unique(self.__hyperedge_index, dim=1).contiguous() + hyperedge_index_without_duplicates = cast( + Tensor, torch.unique(self.__hyperedge_index, dim=1) + ) + self.__hyperedge_index = hyperedge_index_without_duplicates.contiguous() return self def remove_hyperedges_with_fewer_than_k_nodes(self, k: int) -> HyperedgeIndex: From 20747297afdc9ceebc9b7daa68a5dd656a130c64 Mon Sep 17 00:00:00 2001 From: Tiziano Date: Fri, 29 May 2026 16:23:18 +0200 Subject: [PATCH 14/15] fix: fix integration tests after adding validation --- .../data/enricher_integration_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 
deletions(-) diff --git a/hyperbench/integration_tests/data/enricher_integration_test.py b/hyperbench/integration_tests/data/enricher_integration_test.py index 6f0f7577..095e2b4d 100644 --- a/hyperbench/integration_tests/data/enricher_integration_test.py +++ b/hyperbench/integration_tests/data/enricher_integration_test.py @@ -3,15 +3,15 @@ from hyperbench.data import ( list_datasets, get_dataset_by_name, - Node2VecEnricher, FillValueHyperedgeAttrsEnricher, + LaplacianPositionalEncodingEnricher, + Node2VecEnricher, VilLainHyperedgeAttrsEnricher, VilLainEnricher, ) from hyperbench.integration_tests.common import ( split_dataset, ) -from hyperbench.data import LaplacianPositionalEncodingEnricher excluded_dataset = [ @@ -78,7 +78,7 @@ def test_n2v_node_enricher(dataset_name): walk_length=5, num_walks_per_node=2, num_negative_samples=1, - num_nodes=dataset.hdata.num_nodes, + num_nodes=to_enrich_dataset.hdata.num_nodes, num_epochs=3, learning_rate=0.01, batch_size=128, @@ -104,8 +104,8 @@ def test_villain_node_enricher(dataset_name): to_enrich_dataset.enrich_node_features( enricher=VilLainEnricher( num_features=NUM_FEATURES, - num_nodes=dataset.hdata.num_nodes, - num_hyperedges=dataset.hdata.num_hyperedges, + num_nodes=to_enrich_dataset.hdata.num_nodes, + num_hyperedges=to_enrich_dataset.hdata.num_hyperedges, labels_per_subspace=2, training_steps=2, generation_steps=4, @@ -150,8 +150,8 @@ def test_villain_hyperedge_attrs_enricher(dataset_name): to_enrich_dataset.enrich_hyperedge_attr( enricher=VilLainHyperedgeAttrsEnricher( num_features=NUM_FEATURES, - num_nodes=dataset.hdata.num_nodes, - num_hyperedges=dataset.hdata.num_hyperedges, + num_nodes=to_enrich_dataset.hdata.num_nodes, + num_hyperedges=to_enrich_dataset.hdata.num_hyperedges, labels_per_subspace=2, training_steps=2, generation_steps=4, From 74e786d95080f1aaed106f6a31e554cd4178b102 Mon Sep 17 00:00:00 2001 From: Tiziano Date: Fri, 29 May 2026 18:11:38 +0200 Subject: [PATCH 15/15] fix: remove incorrect raise in 
EdgeIndex.add_selfloop --- hyperbench/integration_tests/common.py | 30 ++++++-------------------- hyperbench/tests/types/graph_test.py | 13 ++--------- hyperbench/types/graph.py | 15 +++++-------- 3 files changed, 14 insertions(+), 44 deletions(-) diff --git a/hyperbench/integration_tests/common.py b/hyperbench/integration_tests/common.py index 72cc7f77..6c470164 100644 --- a/hyperbench/integration_tests/common.py +++ b/hyperbench/integration_tests/common.py @@ -1,10 +1,10 @@ -from typing import Literal -from collections.abc import Sequence import lightning as L import torch +from collections.abc import Sequence from functools import cache from pathlib import Path +from typing import Literal from torchmetrics import MetricCollection from torchmetrics.classification import ( BinaryAUROC, @@ -13,7 +13,6 @@ BinaryPrecision, BinaryRecall, ) - from hyperbench.data import ( Dataset, DataLoader, @@ -24,27 +23,12 @@ ) from hyperbench.train import MultiModelTrainer from hyperbench.types import ModelConfig, HData -from torch import Generator +from hyperbench.utils import create_seeded_torch_generator SEED = 42 -def __create_seeded_torch_generator( - device: torch.device, - seed: int | None, -) -> Generator | None: - - if seed is None: - return None - generator = Generator(device=device) - generator.manual_seed(seed) - return generator - - -generator = __create_seeded_torch_generator(device=torch.device("cpu"), seed=SEED) - - @cache def _cached_split_dataset( sampling_strategy: SamplingStrategy, @@ -52,6 +36,7 @@ def _cached_split_dataset( node_space_setting: Literal["transductive", "inductive"] = "transductive", ) -> tuple[Dataset, Dataset, Dataset]: if dataset is None: + generator = create_seeded_torch_generator(device=torch.device("cpu"), seed=SEED) x = torch.randn((100, 4), generator=generator) # 100 nodes with 4 features each hyperedge_index = torch.cat( # 200 hyperedges, each connecting 5 nodes [ @@ -166,27 +151,26 @@ def loaders( batch_size: int = 1, 
sample_full_hypergraph: bool = False, ) -> tuple[DataLoader, DataLoader, DataLoader]: - train_loader = DataLoader( train_dataset, batch_size=batch_size, sample_full_hypergraph=sample_full_hypergraph, shuffle=False, - generator=generator, + generator=create_seeded_torch_generator(device=torch.device("cpu"), seed=SEED), ) val_loader = DataLoader( val_dataset, batch_size=batch_size, sample_full_hypergraph=sample_full_hypergraph, shuffle=False, - generator=generator, + generator=create_seeded_torch_generator(device=torch.device("cpu"), seed=SEED), ) tests_loader = DataLoader( test_dataset, batch_size=batch_size, sample_full_hypergraph=sample_full_hypergraph, shuffle=False, - generator=generator, + generator=create_seeded_torch_generator(device=torch.device("cpu"), seed=SEED), ) return train_loader, val_loader, tests_loader diff --git a/hyperbench/tests/types/graph_test.py b/hyperbench/tests/types/graph_test.py index 8463eafa..326e38a3 100644 --- a/hyperbench/tests/types/graph_test.py +++ b/hyperbench/tests/types/graph_test.py @@ -1,10 +1,8 @@ -import re - -from collections.abc import Callable - import pytest import torch +import re +from collections.abc import Callable from hyperbench.types import EdgeIndex, Graph @@ -608,13 +606,6 @@ def test_add_selfloops_adds_unit_edge_weights(): assert weighted_edges == expected_weighted_edges -def test_add_selfloops_raises_on_empty_edge_index(): - edge_index = EdgeIndex(torch.tensor([[], []], dtype=torch.long)) - - with pytest.raises(ValueError, match="Edge index must have at least one edge"): - edge_index.add_selfloops() - - def test_add_selfloops_does_not_duplicate_selfloops(): edge_index = EdgeIndex(torch.tensor([[0, 1, 1], [1, 2, 1]])) edge_index.add_selfloops() diff --git a/hyperbench/types/graph.py b/hyperbench/types/graph.py index 2b6890f4..4e2ba40d 100644 --- a/hyperbench/types/graph.py +++ b/hyperbench/types/graph.py @@ -4,7 +4,7 @@ from torch import Tensor from typing import cast -from hyperbench.utils import 
validate_is_non_negative, sparse_dropout +from hyperbench.utils import sparse_dropout, validate_is_non_negative class Graph: @@ -227,9 +227,8 @@ def add_selfloops( Raises: ValueError: If the input edge index has no edges (i.e., ``shape (2, 0)``). """ - if self.__edge_index.size(1) < 1: - raise ValueError("Edge index must have at least one edge to add self-loops.") - self.__validate_num_nodes(num_nodes) + num_selfloop_nodes = self.num_nodes if num_nodes is None else num_nodes + self.__validate_num_nodes(num_selfloop_nodes) device = self.__edge_index.device src, dest = self.__edge_index[0], self.__edge_index[1] @@ -251,7 +250,6 @@ def add_selfloops( # -> dest = [1, 0, 3, 0, 1, 2, 3, 4, 5] # -> edge_index_with_selfloops = [[0, 1, 2, 0, 1, 2, 3, 4, 5], # [1, 0, 3, 0, 1, 2, 3, 4, 5]] - num_selfloop_nodes = self.num_nodes if num_nodes is None else num_nodes selfloop_indices = torch.arange(num_selfloop_nodes, device=device) src = torch.cat([src, selfloop_indices]) dest = torch.cat([dest, selfloop_indices]) @@ -451,7 +449,7 @@ def get_sparse_normalized_laplacian( num_nodes = self.num_nodes if num_nodes is None else num_nodes self.__validate_num_nodes(num_nodes) - self.to_undirected(with_selfloops=False) + self.to_undirected(with_selfloops=False, num_nodes=num_nodes) degree_matrix = self.get_sparse_normalized_degree_matrix(num_nodes) adj_matrix = self.get_sparse_adjacency_matrix(num_nodes) @@ -677,10 +675,7 @@ def __validate_edge_weights(self, edge_weights: Tensor | None) -> None: f"Got {edge_weights.size(0)} edge weights but {self.__edge_index.size(1)} edge columns." ) - def __validate_num_nodes(self, num_nodes: int | None) -> None: - if num_nodes is None: - return - + def __validate_num_nodes(self, num_nodes: int) -> None: validate_is_non_negative("num_nodes", num_nodes) if self.num_edges < 1: