From 87e2bfd401dd8251ff52c372f05bbd974780bbae Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 25 Mar 2026 18:40:21 +0000 Subject: [PATCH 01/16] Add C++ infrastructure: build, format, lint, and unit test tooling --- .clang-format | 31 ++++++++++++ .clang-tidy | 37 ++++++++++++++ Makefile | 30 +++++++++-- requirements/install_cpp_deps.sh | 40 +++++++++++++++ scripts/generate_compile_commands.py | 69 ++++++++++++++++++++++++++ setup.py | 42 ++++++++++++++++ tests/unit/cpp/CMakeLists.txt | 44 ++++++++++++++++ tests/unit/cpp/infrastructure_test.cpp | 14 ++++++ 8 files changed, 302 insertions(+), 5 deletions(-) create mode 100644 .clang-format create mode 100644 .clang-tidy create mode 100644 requirements/install_cpp_deps.sh create mode 100644 scripts/generate_compile_commands.py create mode 100644 setup.py create mode 100644 tests/unit/cpp/CMakeLists.txt create mode 100644 tests/unit/cpp/infrastructure_test.cpp diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..f702e9dce --- /dev/null +++ b/.clang-format @@ -0,0 +1,31 @@ +# clang-format configuration for GiGL C++ sources. +# Run: clang-format -i (format in-place) +# clang-format --dry-run --Werror (check only) +BasedOnStyle: Google + +# Match the 100-column limit used for Python throughout the codebase. +ColumnLimit: 100 + +# 4-space indentation (Google style defaults to 2; override to match existing code). +IndentWidth: 4 +ContinuationIndentWidth: 4 + +# Align consecutive = signs in variable declarations for readability. +AlignConsecutiveAssignments: + Enabled: true + AcrossEmptyLines: false + AcrossComments: false + +# Keep short functions on one line only when they are truly trivial (getters/setters). +AllowShortFunctionsOnASingleLine: Inline + +# Never put if/else/loop bodies on the same line as the condition. +AllowShortIfStatementsOnASingleLine: Never +AllowShortLoopsOnASingleLine: false + +# Always break before the opening brace of a namespace, class, or function body. +BreakBeforeBraces: Attach + +# Sort #include blocks: standard library first, then project headers. +SortIncludes: CaseSensitive +IncludeBlocks: Regroup diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 000000000..2c746e6b2 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,37 @@ +# clang-tidy configuration for GiGL C++ sources. +# Run: clang-tidy -p build/compile_commands.json +# +# Checks are opt-in by prefix. Each line below enables (*) or disables (-) +# a category. More specific patterns override broader ones. + +Checks: > + clang-analyzer-*, + bugprone-use-after-move, + bugprone-incorrect-roundings, + bugprone-integer-division, + bugprone-signed-char-misuse, + bugprone-suspicious-memset-usage, + modernize-use-nullptr, + modernize-loop-convert, + modernize-use-override, + performance-unnecessary-value-param, + performance-unnecessary-copy-initialization, + readability-const-return-type, + -bugprone-easily-swappable-parameters, + -clang-analyzer-optin.cplusplus.UninitializedObject, + -modernize-use-trailing-return-type + +# Treat every enabled warning as an error so the build fails loudly. +WarningsAsErrors: "*" + +# Only apply header-file checks to GiGL-owned headers, not third-party +# (torch, pybind11) headers pulled in by includes. +HeaderFilterRegex: "gigl/.*" + +# Per-check configuration options. +CheckOptions: + # Enforce lower_case naming for local variables (matches existing style). + - key: readability-identifier-naming.LocalVariableCase + value: lower_case + - key: readability-identifier-naming.ParameterCase + value: lower_case diff --git a/Makefile b/Makefile index e15a063f3..11d8ba848 100644 --- a/Makefile +++ b/Makefile @@ -22,6 +22,7 @@ DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG?=${DOCKER_IMAGE_MAIN_CPU_NAME}:${DATE} DOCKER_IMAGE_DEV_WORKBENCH_NAME_WITH_TAG?=${DOCKER_IMAGE_DEV_WORKBENCH_NAME}:${DATE} PYTHON_DIRS:=.github/scripts examples gigl tests snapchat scripts +CPP_SOURCES:=$(shell find gigl -name "*.cpp") PY_TEST_FILES?="*_test.py" # You can override GIGL_TEST_DEFAULT_RESOURCE_CONFIG by setting it in your environment i.e. # adding `export GIGL_TEST_DEFAULT_RESOURCE_CONFIG=your_resource_config` to your shell config (~/.bashrc, ~/.zshrc, etc.) @@ -49,6 +50,7 @@ install_dev_deps: check_if_valid_env gcloud auth configure-docker us-central1-docker.pkg.dev bash ./requirements/install_py_deps.sh --dev bash ./requirements/install_scala_deps.sh + bash ./requirements/install_cpp_deps.sh uv pip install -e . uv run pre-commit install --hook-type pre-commit --hook-type pre-push @@ -94,7 +96,12 @@ unit_test_scala: clean_build_files_scala # Eventually, we should look into splitting these up. # We run `make check_format` separately instead of as a dependent make rule so that it always runs after the actual testing. # We don't want to fail the tests due to non-conformant formatting during development. -unit_test: precondition_tests unit_test_py unit_test_scala +unit_test_cpp: + cmake -S tests/unit/cpp -B build/cpp_tests + cmake --build build/cpp_tests --parallel + ctest --test-dir build/cpp_tests --output-on-failure + +unit_test: precondition_tests unit_test_py unit_test_scala unit_test_cpp check_format_py: uv run autoflake --check --config pyproject.toml ${PYTHON_DIRS} @@ -109,7 +116,10 @@ check_format_md: @echo "Checking markdown files..." uv run mdformat --check ${MD_FILES} -check_format: check_format_py check_format_scala check_format_md +check_format_cpp: + clang-format --dry-run --Werror --style=file $(CPP_SOURCES) + +check_format: check_format_py check_format_scala check_format_md check_format_cpp # Set PY_TEST_FILES= to test a specifc file. # Ex. `make integration_test PY_TEST_FILES="dataflow_test.py"` @@ -143,12 +153,19 @@ format_md: @echo "Formatting markdown files..." uv run mdformat ${MD_FILES} -format: format_py format_scala format_md +format_cpp: + clang-format -i --style=file $(CPP_SOURCES) + +format: format_py format_scala format_md format_cpp type_check: uv run mypy ${PYTHON_DIRS} --check-untyped-defs -lint_test: check_format assert_yaml_configs_parse +lint_cpp: build_cpp_extensions + uv run python scripts/generate_compile_commands.py + clang-tidy -p build/compile_commands.json $(CPP_SOURCES) + +lint_test: check_format assert_yaml_configs_parse lint_cpp @echo "Lint checks pass!" # compiles current working state of scala projects to local jars @@ -313,7 +330,10 @@ clean_build_files_scala: ( cd scala; sbt clean; find . -type d -name "target" -prune -exec rm -rf {} \; ) ( cd scala_spark35; sbt clean; find . -type d -name "target" -prune -exec rm -rf {} \; ) -clean_build_files: clean_build_files_py clean_build_files_scala +clean_build_files_cpp: + rm -rf build/ + +clean_build_files: clean_build_files_py clean_build_files_scala clean_build_files_cpp # Call to generate new proto definitions if any of the .proto files have been changed. # We intentionally rebuild *all* protos with one commmand as they should all be in sync. diff --git a/requirements/install_cpp_deps.sh b/requirements/install_cpp_deps.sh new file mode 100644 index 000000000..9a3aacbd7 --- /dev/null +++ b/requirements/install_cpp_deps.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Install C++ development tools: clang-format, clang-tidy, cmake. +# +# Usage: +# bash requirements/install_cpp_deps.sh +# +# Called by `make install_dev_deps` alongside install_py_deps.sh and +# install_scala_deps.sh. + +set -e +set -x + +is_running_on_mac() { + [ "$(uname)" == "Darwin" ] + return $? +} + +if is_running_on_mac; then + # `brew install llvm` provides clang-format and clang-tidy. + # Homebrew does not add llvm to PATH by default to avoid shadowing Apple's + # clang, so we print an instruction for the developer to do it manually. + brew install llvm cmake + LLVM_PREFIX=$(brew --prefix llvm) + set +x + echo "" + echo "NOTE: Add the LLVM bin directory to your PATH to use clang-format and clang-tidy:" + echo " export PATH=\"${LLVM_PREFIX}/bin:\$PATH\"" + echo " (Add this to your ~/.zshrc or ~/.bashrc to make it permanent.)" + echo "" + set -x +else + # Ubuntu / Debian — clang 15 is the highest version available on Ubuntu 22.04. + apt-get install -y clang-format-15 clang-tidy-15 cmake + # Register versioned binaries as the default so bare `clang-format` and + # `clang-tidy` resolve to them without callers specifying the version suffix. + update-alternatives --install /usr/bin/clang-format clang-format /usr/bin/clang-format-15 100 + update-alternatives --install /usr/bin/clang-tidy clang-tidy /usr/bin/clang-tidy-15 100 +fi + +echo "Finished installing C++ tooling" diff --git a/scripts/generate_compile_commands.py b/scripts/generate_compile_commands.py new file mode 100644 index 000000000..8146f2c18 --- /dev/null +++ b/scripts/generate_compile_commands.py @@ -0,0 +1,69 @@ +"""Generate build/compile_commands.json for clang-tidy analysis of GiGL C++ extensions. + +clang-tidy needs a compilation database to resolve include paths and compiler flags. +This script derives those paths directly from the installed torch and pybind11 packages, +avoiding the need for `bear` or a separate CMake build of the extension. + +Usage:: + + uv run python scripts/generate_compile_commands.py + +Output: ``build/compile_commands.json`` (created or overwritten). +""" + +import json +import sys +import sysconfig +from pathlib import Path + + +def main() -> None: + try: + import pybind11 + import torch # noqa: F401 — imported to verify it is installed + from torch.utils.cpp_extension import include_paths as torch_include_paths + except ImportError as exc: + print( + f"Error: {exc}\n" + "Run `make build_cpp_extensions` first to ensure torch and pybind11 are available.", + file=sys.stderr, + ) + sys.exit(1) + + repo_root = Path(__file__).parent.parent.resolve() + + # Collect all include directories needed to compile the extension. + include_flags: list[str] = [] + for path in torch_include_paths(): + include_flags.append(f"-I{path}") + include_flags.append(f"-I{pybind11.get_include()}") + # Python C API headers (e.g. Python.h) required by pybind11. + include_flags.append(f"-I{sysconfig.get_path('include')}") + + cpp_sources = sorted((repo_root / "gigl").rglob("*.cpp")) + if not cpp_sources: + print("Warning: no .cpp files found under gigl/", file=sys.stderr) + + # Each entry in compile_commands.json describes how one source file is compiled. + # clang-tidy reads this to reproduce the exact compilation environment. + commands: list[dict[str, str]] = [ + { + "directory": str(repo_root), + "file": str(source), + "command": ( + f"c++ -std=c++17 -Wall -Wextra " + f"{' '.join(include_flags)} " + f"-c {source}" + ), + } + for source in cpp_sources + ] + + output = repo_root / "build" / "compile_commands.json" + output.parent.mkdir(exist_ok=True) + output.write_text(json.dumps(commands, indent=2)) + print(f"Wrote {len(commands)} entr{'y' if len(commands) == 1 else 'ies'} to {output}") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..a9678da1f --- /dev/null +++ b/setup.py @@ -0,0 +1,42 @@ +from pathlib import Path + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CppExtension + + +def find_cpp_extensions() -> list[CppExtension]: + """Auto-discover pybind11 extension modules. + + Any .cpp file anywhere under ``gigl/`` is compiled as a Python C++ + extension. The module name is derived from the file path, so the + extension is importable at the same location as its Python neighbours. + + Example:: + + gigl/distributed/cpp_extensions/ppr_forward_push.cpp + → importable as ``gigl.distributed.cpp_extensions.ppr_forward_push`` + + To add a new extension, drop a .cpp file anywhere under ``gigl/`` — + no changes to this file required. + """ + extensions = [] + for cpp_file in sorted(Path("gigl").rglob("*.cpp")): + # Convert path separators to dots and strip the .cpp suffix to get the + # fully-qualified Python module name, e.g.: + # gigl/distributed/cpp_extensions/ppr_forward_push.cpp + # → "gigl.distributed.cpp_extensions.ppr_forward_push" + module_name = ".".join(cpp_file.with_suffix("").parts) + extensions.append( + CppExtension( + name=module_name, + sources=[str(cpp_file)], + extra_compile_args=["-O3", "-std=c++17", "-Wall", "-Wextra"], + ) + ) + return extensions + + +setup( + ext_modules=find_cpp_extensions(), + cmdclass={"build_ext": BuildExtension}, +) diff --git a/tests/unit/cpp/CMakeLists.txt b/tests/unit/cpp/CMakeLists.txt new file mode 100644 index 000000000..fd55255b2 --- /dev/null +++ b/tests/unit/cpp/CMakeLists.txt @@ -0,0 +1,44 @@ +cmake_minimum_required(VERSION 3.18) +project(GiGLCppTests CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# --------------------------------------------------------------------------- +# GoogleTest via FetchContent +# --------------------------------------------------------------------------- +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/refs/tags/v1.14.0.tar.gz + # TODO: pin URL_HASH once infrastructure is validated, e.g.: + # URL_HASH SHA256= +) +# Prevent GoogleTest from overriding the compiler's runtime on Windows +# (no-op on Linux/Mac, but required for portable CMake config). +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + +# Required for add_test() to register tests with CTest. +enable_testing() + +# --------------------------------------------------------------------------- +# Auto-discover test targets +# --------------------------------------------------------------------------- +# Any file named *_test.cpp in this directory (or subdirectories) is +# automatically compiled into its own test binary and registered with CTest. +# To add a new test suite, drop a *_test.cpp file here — no changes to this +# file required. This matches the *_test.py convention used for Python tests. +file(GLOB_RECURSE TEST_SOURCES "*_test.cpp") + +foreach(test_source ${TEST_SOURCES}) + # Derive the binary name from the filename, e.g.: + # ppr_forward_push_test.cpp → ppr_forward_push_test + get_filename_component(test_name ${test_source} NAME_WE) + add_executable(${test_name} ${test_source}) + target_link_libraries(${test_name} GTest::gtest_main) + # add_test registers the binary with CTest. Each *_test binary is one + # CTest entry; GoogleTest itself reports individual TEST() results inside it. + add_test(NAME ${test_name} COMMAND ${test_name}) +endforeach() diff --git a/tests/unit/cpp/infrastructure_test.cpp b/tests/unit/cpp/infrastructure_test.cpp new file mode 100644 index 000000000..d8d252f8e --- /dev/null +++ b/tests/unit/cpp/infrastructure_test.cpp @@ -0,0 +1,14 @@ +// Placeholder C++ unit test. +// +// This file exists to verify that the GoogleTest infrastructure compiles and +// runs end-to-end. Replace or supplement it with tests for actual GiGL C++ +// code (e.g. PPRForwardPushState) as those components are added. + +#include + +// A trivial sanity-check test — if this fails, something is very wrong with +// the build environment itself. +TEST(PlaceholderTest, BasicArithmetic) { + EXPECT_EQ(1 + 1, 2); + EXPECT_NE(1 + 1, 3); +} From a18fb47e53865408a0f3828ae82d5da130228565 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 25 Mar 2026 18:48:13 +0000 Subject: [PATCH 02/16] =?UTF-8?q?Remove=20pybind11=20import=20from=20gener?= =?UTF-8?q?ate=5Fcompile=5Fcommands=20=E2=80=94=20bundled=20in=20torch=20i?= =?UTF-8?q?ncludes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/generate_compile_commands.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/generate_compile_commands.py b/scripts/generate_compile_commands.py index 8146f2c18..05873a486 100644 --- a/scripts/generate_compile_commands.py +++ b/scripts/generate_compile_commands.py @@ -19,13 +19,12 @@ def main() -> None: try: - import pybind11 import torch # noqa: F401 — imported to verify it is installed from torch.utils.cpp_extension import include_paths as torch_include_paths except ImportError as exc: print( f"Error: {exc}\n" - "Run `make build_cpp_extensions` first to ensure torch and pybind11 are available.", + "Run `make build_cpp_extensions` first to ensure torch is available.", file=sys.stderr, ) sys.exit(1) @@ -33,10 +32,11 @@ def main() -> None: repo_root = Path(__file__).parent.parent.resolve() # Collect all include directories needed to compile the extension. + # torch_include_paths() returns the torch headers, which already bundle + # pybind11 under torch/include/pybind11/ — no separate pybind11 import needed. include_flags: list[str] = [] for path in torch_include_paths(): include_flags.append(f"-I{path}") - include_flags.append(f"-I{pybind11.get_include()}") # Python C API headers (e.g. Python.h) required by pybind11. include_flags.append(f"-I{sysconfig.get_path('include')}") From c127f2b86568b69cf7c4fac4100bdcfaf8620e04 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Wed, 25 Mar 2026 12:13:46 -0700 Subject: [PATCH 03/16] ppr seq --- .../graph_transformer/graph_transformer.py | 235 ++++++++-- gigl/transforms/graph_transformer.py | 431 +++++++++++++++--- .../graph_transformer_test.py | 205 ++++++++- .../unit/transforms/graph_transformer_test.py | 163 ++++++- 4 files changed, 918 insertions(+), 116 deletions(-) diff --git a/gigl/src/common/models/graph_transformer/graph_transformer.py b/gigl/src/common/models/graph_transformer/graph_transformer.py index 2d07b849d..4308b495f 100644 --- a/gigl/src/common/models/graph_transformer/graph_transformer.py +++ b/gigl/src/common/models/graph_transformer/graph_transformer.py @@ -11,6 +11,7 @@ replacement as the encoder in ``LinkPredictionGNN``. """ +import math from typing import Callable, Literal, Optional, cast import torch @@ -20,7 +21,10 @@ from torch import Tensor from gigl.src.common.types.graph_data import EdgeType, NodeType -from gigl.transforms.graph_transformer import heterodata_to_graph_transformer_input +from gigl.transforms.graph_transformer import ( + PPR_WEIGHT_FEATURE_NAME, + heterodata_to_graph_transformer_input, +) def _get_node_type_positional_encodings( @@ -54,6 +58,25 @@ def _get_node_type_positional_encodings( return torch.cat(pe_parts, dim=-1) +def _build_sinusoidal_sequence_position_table( + max_seq_len: int, + hid_dim: int, +) -> Tensor: + """Build a standard sinusoidal absolute position table.""" + positions = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, hid_dim, 2, dtype=torch.float) * (-math.log(10000.0) / hid_dim) + ) + + position_table = torch.zeros(max_seq_len, hid_dim, dtype=torch.float) + position_table[:, 0::2] = torch.sin(positions * div_term) + if hid_dim > 1: + position_table[:, 1::2] = torch.cos( + positions * div_term[: position_table[:, 1::2].shape[1]] + ) + return position_table + + # Supported activation functions for FeedForwardNetwork _ACTIVATION_FNS = { "gelu": nn.GELU, @@ -364,7 +387,16 @@ class GraphTransformerEncoder(nn.Module): max_seq_len: Maximum sequence length for the graph-to-sequence transform. Neighborhoods are truncated to this length. hop_distance: Number of hops for neighborhood extraction in the - graph-to-sequence transform. + graph-to-sequence transform when using ``"khop"`` sequence construction. + sequence_construction_method: Sequence builder used to create tokens for + each anchor. ``"khop"`` expands the sampled graph by hop distance, + while ``"ppr"`` consumes outgoing ``"ppr"`` edges sorted by weight. + sequence_positional_encoding_type: Optional sequence-level positional + encoding applied after sequence construction. Supported values are + ``None`` and ``"sinusoidal"``. Lower-cost future extensions could + add learned absolute position embeddings here, while attention-level + options like RoPE or ALiBi would require changes inside the + attention block. dropout_rate: Dropout probability for feed-forward layers. attention_dropout_rate: Dropout probability for attention weights. should_l2_normalize_embedding_layer_output: Whether to L2 normalize @@ -373,14 +405,18 @@ class GraphTransformerEncoder(nn.Module): In ``"concat"`` mode these are concatenated to sequence features. In ``"add"`` mode they are projected to ``hid_dim`` and added to node features before sequence construction. - anchor_based_pe_attr_names: List of relative-encoding attribute names - containing sparse (N x N) matrices for anchor-relative positional - encodings. - These are used as additive attention bias for sequence keys. - pairwise_pe_attr_names: List of relative-encoding attribute names - containing sparse (N x N) matrices for pairwise relative encodings - between sequence nodes. These are used as additive attention bias - and can be combined with anchor-relative bias in the same model. + anchor_based_attention_bias_attr_names: List of anchor-relative feature + names used as additive attention bias for sequence keys. Sparse + graph-level attributes are looked up from ``data`` and the reserved + name ``"ppr_weight"`` resolves to PPR edge weights in PPR mode. + anchor_based_input_attr_names: List of anchor-relative attribute names + used as token-aligned input features. Sparse graph-level attributes + are looked up from ``data`` and ``"ppr_weight"`` resolves to PPR + edge weights in PPR mode. These are projected to ``hid_dim`` and + added to the sequence tokens after sequence construction. + pairwise_attention_bias_attr_names: List of pairwise feature names used + as additive attention bias. These must correspond to sparse + graph-level attributes on ``data``. feature_embedding_layer_dict: Optional ModuleDict mapping node types to feature embedding layers. If provided, these are applied to node features before node projection. (default: None) @@ -436,12 +472,15 @@ def __init__( num_heads: int = 2, max_seq_len: int = 128, hop_distance: int = 2, + sequence_construction_method: Literal["khop", "ppr"] = "khop", + sequence_positional_encoding_type: Optional[str] = None, dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, should_l2_normalize_embedding_layer_output: bool = False, pe_attr_names: Optional[list[str]] = None, - anchor_based_pe_attr_names: Optional[list[str]] = None, - pairwise_pe_attr_names: Optional[list[str]] = None, + anchor_based_attention_bias_attr_names: Optional[list[str]] = None, + anchor_based_input_attr_names: Optional[list[str]] = None, + pairwise_attention_bias_attr_names: Optional[list[str]] = None, feature_embedding_layer_dict: Optional[nn.ModuleDict] = None, pe_integration_mode: Literal["concat", "add"] = "concat", activation: str = "gelu", @@ -461,15 +500,68 @@ def __init__( self._out_dim = out_dim self._max_seq_len = max_seq_len self._hop_distance = hop_distance + if sequence_construction_method not in {"khop", "ppr"}: + raise ValueError( + "sequence_construction_method must be one of {'khop', 'ppr'}, " + f"got '{sequence_construction_method}'" + ) + if sequence_positional_encoding_type is not None: + sequence_positional_encoding_type = ( + sequence_positional_encoding_type.lower() + ) + if sequence_positional_encoding_type == "none": + sequence_positional_encoding_type = None + if sequence_positional_encoding_type not in {None, "sinusoidal"}: + raise ValueError( + "sequence_positional_encoding_type must be one of " + "{None, 'sinusoidal'}, " + f"got '{sequence_positional_encoding_type}'" + ) + anchor_bias_attr_names = anchor_based_attention_bias_attr_names or [] + anchor_input_attr_names = anchor_based_input_attr_names or [] + pairwise_bias_attr_names = pairwise_attention_bias_attr_names or [] + if PPR_WEIGHT_FEATURE_NAME in pairwise_bias_attr_names: + raise ValueError( + f"'{PPR_WEIGHT_FEATURE_NAME}' is an anchor-relative feature and " + "cannot be used as pairwise attention bias." + ) + if ( + PPR_WEIGHT_FEATURE_NAME in anchor_bias_attr_names + anchor_input_attr_names + and sequence_construction_method != "ppr" + ): + raise ValueError( + "The reserved anchor-relative feature 'ppr_weight' requires " + "sequence_construction_method='ppr'." + ) + self._sequence_construction_method = sequence_construction_method + self._sequence_positional_encoding_type = sequence_positional_encoding_type self._should_l2_normalize_embedding_layer_output = ( should_l2_normalize_embedding_layer_output ) self._pe_attr_names = pe_attr_names - self._anchor_based_pe_attr_names = anchor_based_pe_attr_names - self._pairwise_pe_attr_names = pairwise_pe_attr_names + self._anchor_based_attention_bias_attr_names = ( + anchor_based_attention_bias_attr_names + ) + self._anchor_based_input_attr_names = anchor_based_input_attr_names + self._pairwise_attention_bias_attr_names = pairwise_attention_bias_attr_names self._feature_embedding_layer_dict = feature_embedding_layer_dict self._pe_integration_mode = pe_integration_mode self._num_heads = num_heads + if self._sequence_positional_encoding_type == "sinusoidal": + self.register_buffer( + "_sequence_positional_encoding_table", + _build_sinusoidal_sequence_position_table( + max_seq_len=max_seq_len, + hid_dim=hid_dim, + ), + persistent=False, + ) + else: + self.register_buffer( + "_sequence_positional_encoding_table", + None, + persistent=False, + ) # Per-node-type input projection to hid_dim (like HGT's lin_dict) self._node_projection_dict = nn.ModuleDict( @@ -483,25 +575,31 @@ def __init__( # In "concat" mode: projects [node_features || PE] → hid_dim # In "add" mode: projects PE → hid_dim, then adds to node features self._concat_pe_fusion_projection: Optional[nn.Module] = None - if pe_integration_mode == "concat" and pe_attr_names: + has_node_level_pe = bool(pe_attr_names) + if pe_integration_mode == "concat" and has_node_level_pe: self._concat_pe_fusion_projection = nn.LazyLinear(hid_dim) self._pe_projection: Optional[nn.Module] = None - if pe_integration_mode == "add" and pe_attr_names: + if pe_integration_mode == "add" and has_node_level_pe: self._pe_projection = nn.LazyLinear(hid_dim, bias=False) + self._token_input_projection: Optional[nn.Module] = None + if self._anchor_based_input_attr_names: + self._token_input_projection = nn.LazyLinear(hid_dim, bias=False) + self._anchor_pe_attention_bias_projection: Optional[nn.Linear] = None - if anchor_based_pe_attr_names: + num_anchor_bias_attrs = len(self._anchor_based_attention_bias_attr_names or []) + if num_anchor_bias_attrs > 0: self._anchor_pe_attention_bias_projection = nn.Linear( - len(anchor_based_pe_attr_names), + num_anchor_bias_attrs, num_heads, bias=False, ) self._pairwise_pe_attention_bias_projection: Optional[nn.Linear] = None - if pairwise_pe_attr_names: + if self._pairwise_attention_bias_attr_names: self._pairwise_pe_attention_bias_projection = nn.Linear( - len(pairwise_pe_attr_names), + len(self._pairwise_attention_bias_attr_names), num_heads, bias=False, ) @@ -572,21 +670,29 @@ def forward( projected_x_dict: dict[NodeType, torch.Tensor] = {} for node_type, x in data.x_dict.items(): x_processed = x.to(device) + feature_embedding_layer = None + if ( + self._feature_embedding_layer_dict is not None + and node_type in self._feature_embedding_layer_dict + ): + feature_embedding_layer = self._feature_embedding_layer_dict[node_type] # Apply feature embedding if available for this node type - if self._feature_embedding_layer_dict is not None: - if node_type in self._feature_embedding_layer_dict: - x_processed = self._feature_embedding_layer_dict[node_type]( - x_processed - ) + if feature_embedding_layer is not None: + x_processed = feature_embedding_layer(x_processed) # Project to hid_dim x_projected = self._node_projection_dict[str(node_type)](x_processed) + node_pe_parts = [] if self._pe_attr_names: - node_pe = _get_node_type_positional_encodings( - data=data, - node_type=node_type, - pe_attr_names=self._pe_attr_names, - device=device, + node_pe_parts.append( + _get_node_type_positional_encodings( + data=data, + node_type=node_type, + pe_attr_names=self._pe_attr_names, + device=device, + ) ) + if node_pe_parts: + node_pe = torch.cat(node_pe_parts, dim=-1) if self._pe_integration_mode == "add": if self._pe_projection is None: raise ValueError("PE projection layer is not initialized.") @@ -610,9 +716,17 @@ def forward( projected_data[node_type].batch_size = data[node_type].batch_size for edge_type in data.edge_types: projected_data[edge_type].edge_index = data[edge_type].edge_index + if hasattr(data[edge_type], "edge_attr"): + projected_data[edge_type].edge_attr = data[edge_type].edge_attr # Copy relative-encoding attributes (e.g., hop_distance stored as sparse matrix) - relative_pe_attr_names = set(self._anchor_based_pe_attr_names or []) - relative_pe_attr_names.update(self._pairwise_pe_attr_names or []) + relative_pe_attr_names = { + attr_name + for attr_name in (self._anchor_based_attention_bias_attr_names or []) + if attr_name != PPR_WEIGHT_FEATURE_NAME + } + relative_pe_attr_names.update(self._anchor_based_input_attr_names or []) + relative_pe_attr_names.update(self._pairwise_attention_bias_attr_names or []) + relative_pe_attr_names.discard(PPR_WEIGHT_FEATURE_NAME) if relative_pe_attr_names: for attr_name in sorted(relative_pe_attr_names): if hasattr(data, attr_name): @@ -632,7 +746,7 @@ def forward( ( sequences, valid_mask, - attention_bias_data, + sequence_auxiliary_data, ) = heterodata_to_graph_transformer_input( data=projected_data, batch_size=num_anchor_nodes, @@ -640,8 +754,10 @@ def forward( anchor_node_type=anchor_node_type, anchor_node_ids=anchor_node_ids, hop_distance=self._hop_distance, - anchor_based_pe_attr_names=self._anchor_based_pe_attr_names, - pairwise_pe_attr_names=self._pairwise_pe_attr_names, + sequence_construction_method=self._sequence_construction_method, + anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names, + anchor_based_input_attr_names=self._anchor_based_input_attr_names, + pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names, ) # Free memory after sequences are built @@ -653,10 +769,25 @@ def forward( f"got {sequences.size(-1)}." ) + token_input_features = sequence_auxiliary_data.get("token_input") + if token_input_features is not None: + if self._token_input_projection is None: + raise ValueError("Token-input projection is not initialized.") + sequences = sequences + self._token_input_projection( + token_input_features.to(sequences.dtype) + ) + + sequence_positional_encoding = self._get_sequence_positional_encoding( + valid_mask=valid_mask, + sequences=sequences, + ) + if sequence_positional_encoding is not None: + sequences = sequences + sequence_positional_encoding + attn_bias = self._build_attention_bias( valid_mask=valid_mask, sequences=sequences, - attention_bias_data=attention_bias_data, + attention_bias_data=sequence_auxiliary_data, ) embeddings = self._encode_and_readout( @@ -671,6 +802,38 @@ def forward( return embeddings + def _get_sequence_positional_encoding( + self, + valid_mask: Tensor, + sequences: Tensor, + ) -> Optional[Tensor]: + if self._sequence_positional_encoding_type is None: + return None + if self._sequence_positional_encoding_type != "sinusoidal": + raise ValueError( + "Unsupported sequence_positional_encoding_type " + f"'{self._sequence_positional_encoding_type}'." + ) + if self._sequence_positional_encoding_table is None: + raise ValueError("Sequence positional encoding table is not initialized.") + + seq_len = sequences.size(1) + if seq_len > self._sequence_positional_encoding_table.size(0): + raise ValueError( + f"Sequence length {seq_len} exceeds configured max_seq_len " + f"{self._sequence_positional_encoding_table.size(0)}." + ) + + position_encoding = self._sequence_positional_encoding_table[:seq_len] + position_encoding = position_encoding.to( + device=sequences.device, + dtype=sequences.dtype, + ) + position_encoding = position_encoding.unsqueeze(0).expand( + sequences.size(0), -1, -1 + ) + return position_encoding * valid_mask.unsqueeze(-1).to(sequences.dtype) + def _build_attention_bias( self, valid_mask: Tensor, diff --git a/gigl/transforms/graph_transformer.py b/gigl/transforms/graph_transformer.py index 1b4ca7fa4..f99f8b264 100644 --- a/gigl/transforms/graph_transformer.py +++ b/gigl/transforms/graph_transformer.py @@ -1,4 +1,3 @@ -# TODO: support RW sampling, edges from data.ppr_edge, data.ppr_weights """ Transform HeteroData to Graph Transformer sequence input. @@ -52,21 +51,22 @@ ... batch_size=32, ... max_seq_len=128, ... anchor_node_type='user', - ... anchor_based_pe_attr_names=['hop_distance'], # Anchor-based PE (N×N sparse CSR) + ... anchor_based_attention_bias_attr_names=['hop_distance'], ... ) >>> # sequences: (batch_size, max_seq_len, feature_dim) >>> # attention_bias_data['anchor_bias']: (batch_size, max_seq_len, 1) """ -from typing import Optional +from typing import Literal, Optional import torch from torch import Tensor -from torch_geometric.data import HeteroData +from torch_geometric.data import Data, HeteroData from torch_geometric.typing import NodeType from torch_geometric.utils import to_torch_sparse_tensor -AttentionBiasData = dict[str, Optional[Tensor]] +SequenceAuxiliaryData = dict[str, Optional[Tensor]] +PPR_WEIGHT_FEATURE_NAME = "ppr_weight" def heterodata_to_graph_transformer_input( @@ -76,11 +76,13 @@ def heterodata_to_graph_transformer_input( anchor_node_type: NodeType, anchor_node_ids: Optional[Tensor] = None, hop_distance: int = 2, + sequence_construction_method: Literal["khop", "ppr"] = "khop", include_anchor_first: bool = True, padding_value: float = 0.0, - anchor_based_pe_attr_names: Optional[list[str]] = None, - pairwise_pe_attr_names: Optional[list[str]] = None, -) -> tuple[Tensor, Tensor, AttentionBiasData]: + anchor_based_attention_bias_attr_names: Optional[list[str]] = None, + anchor_based_input_attr_names: Optional[list[str]] = None, + pairwise_attention_bias_attr_names: Optional[list[str]] = None, +) -> tuple[Tensor, Tensor, SequenceAuxiliaryData]: """ Transform a HeteroData object to Graph Transformer sequence input. @@ -100,20 +102,27 @@ def heterodata_to_graph_transformer_input( anchor_node_type: The node type of anchor nodes. anchor_node_ids: Optional tensor of local node indices within anchor_node_type to use as anchors. If None, uses first batch_size nodes. (default: None) - hop_distance: Number of hops to consider for neighborhood (default: 2). + hop_distance: Number of hops to consider for neighborhood when + ``sequence_construction_method="khop"``. (default: 2) + sequence_construction_method: Strategy used to build per-anchor sequences. + ``"khop"`` performs the existing k-hop expansion over the sampled graph. + ``"ppr"`` uses outgoing ``(anchor_type, "ppr", neighbor_type)`` edges, + sorted by descending PPR weight from ``edge_attr``. (default: ``"khop"``) include_anchor_first: If True, anchor node is always first in sequence. padding_value: Value to use for padding (default: 0.0). - anchor_based_pe_attr_names: List of relative-encoding attribute names - containing sparse (N x N) matrices for anchor-relative positional - encodings. - For each node in the sequence, the value PE[anchor_idx, node_idx] is - looked up and returned as attention-bias features. - Examples: ['hop_distance'] (from AddHeteroHopDistanceEncoding) - If None, no anchor-based PEs are attached. (default: None) - pairwise_pe_attr_names: List of relative-encoding attribute names - containing sparse (N x N) matrices for pairwise relative encodings. - For each pair of sequence nodes (i, j), the value PE[node_i, node_j] - is looked up and returned as attention-bias features. (default: None) + anchor_based_attention_bias_attr_names: List of anchor-relative feature + names used as attention bias. Sparse graph-level attributes are + looked up from ``data`` and the reserved name ``"ppr_weight"`` + resolves to PPR edge weights in PPR sequence mode. + Example: ['hop_distance', 'ppr_weight']. + anchor_based_input_attr_names: List of anchor-relative attribute names + returned as token-aligned model-input features. Sparse graph-level + attributes are looked up from ``data`` and ``"ppr_weight"`` resolves + to PPR edge weights in PPR sequence mode. + Example: ['hop_distance', 'ppr_weight']. + pairwise_attention_bias_attr_names: List of pairwise feature names used + as attention bias. These must correspond to sparse graph-level + attributes on ``data``. Example: ['pairwise_distance']. Returns: (sequences, valid_mask, attention_bias_data), where: @@ -121,10 +130,13 @@ def heterodata_to_graph_transformer_input( taken directly from ``data[node_type].x`` in homogeneous order. valid_mask: (batch_size, max_seq_len) bool tensor indicating which sequence positions correspond to real nodes. - attention_bias_data: dictionary of raw attention-bias features with: + sequence_auxiliary_data: dictionary of raw token-aligned and + attention-bias features with: ``"anchor_bias"`` shaped ``(batch, seq, num_anchor_attrs)`` or None ``"pairwise_bias"`` shaped ``(batch, seq, seq, num_pairwise_attrs)`` or None + ``"token_input"`` shaped + ``(batch, seq, num_token_input_attrs)`` or None Raises: ValueError: If node types have different feature dimensions. @@ -160,73 +172,115 @@ def heterodata_to_graph_transformer_input( f"Found different dimensions: {feature_dims}" ) + anchor_bias_attr_names = anchor_based_attention_bias_attr_names or [] + anchor_input_attr_names = anchor_based_input_attr_names or [] + pairwise_bias_attr_names = pairwise_attention_bias_attr_names or [] + + if PPR_WEIGHT_FEATURE_NAME in pairwise_bias_attr_names: + raise ValueError( + f"'{PPR_WEIGHT_FEATURE_NAME}' is an anchor-relative feature and cannot " + "be used as pairwise attention bias." + ) + + if ( + PPR_WEIGHT_FEATURE_NAME in anchor_bias_attr_names + anchor_input_attr_names + and sequence_construction_method != "ppr" + ): + raise ValueError( + "The reserved anchor-relative feature 'ppr_weight' requires " + "sequence_construction_method='ppr'." + ) + + if sequence_construction_method == "ppr": + _validate_ppr_sequence_input(data) + device = data[anchor_node_type].x.device # Convert to homogeneous for easier neighborhood extraction homo_data = data.to_homogeneous() homo_x = homo_data.x # (total_nodes, feature_dim) - homo_edge_index = homo_data.edge_index # (2, num_edges) num_nodes = homo_data.num_nodes - # Get node type to index mapping (sorted alphabetically) - sorted_node_types = sorted(data.node_types) + # Match the node-type ordering used by to_homogeneous() so homogeneous + # indices line up with homo_x / homo_edge_index. + node_type_order = list(getattr(homo_data, "_node_type_names", data.node_types)) + node_type_offsets = _get_node_type_offsets( + data=data, node_type_order=node_type_order + ) # Find offset for anchor_node_type in homogeneous graph - # Nodes are ordered by node_type (alphabetically), then by original index - offset = 0 - for nt in sorted_node_types: - if nt == anchor_node_type: - break - offset += data[nt].num_nodes + # Nodes are ordered by the homogeneous node-type order, then by original index. + offset = node_type_offsets[anchor_node_type] # Determine anchor indices in homogeneous graph if anchor_node_ids is not None: # Use provided local indices, convert to homogeneous indices - anchor_indices = offset + anchor_node_ids.to(device) + anchor_local_indices = anchor_node_ids.to(device) else: # Default: first batch_size nodes of anchor_node_type - anchor_indices = torch.arange(offset, offset + batch_size, device=device) + anchor_local_indices = torch.arange(batch_size, device=device) + anchor_indices = offset + anchor_local_indices + + ppr_weight_sequences: Optional[Tensor] = None + if sequence_construction_method == "khop": + homo_edge_index = homo_data.edge_index # (2, num_edges) + # Use sparse matrix operations for efficient k-hop neighbor extraction + # Returns: (batch_size, num_nodes) sparse matrix where non-zero entries are reachable + reachable = _get_k_hop_neighbors_sparse( + anchor_indices=anchor_indices, + edge_index=homo_edge_index, + num_nodes=num_nodes, + k=hop_distance, + device=device, + ) + node_index_sequences, valid_mask = _build_sequence_layout_from_sparse_neighbors( + reachable=reachable, + anchor_indices=anchor_indices, + max_seq_len=max_seq_len, + include_anchor_first=include_anchor_first, + device=device, + ) + elif sequence_construction_method == "ppr": + ( + node_index_sequences, + valid_mask, + ppr_weight_sequences, + ) = _build_sequence_layout_from_ppr_edges( + homo_data=homo_data, + anchor_indices=anchor_indices, + max_seq_len=max_seq_len, + include_anchor_first=include_anchor_first, + num_nodes=num_nodes, + device=device, + return_edge_weights=( + PPR_WEIGHT_FEATURE_NAME + in anchor_bias_attr_names + anchor_input_attr_names + ), + ) + else: + raise ValueError( + "sequence_construction_method must be one of ['khop', 'ppr'], " + f"got '{sequence_construction_method}'." + ) - # Use sparse matrix operations for efficient k-hop neighbor extraction - # Returns: (batch_size, num_nodes) sparse matrix where non-zero entries are reachable - reachable = _get_k_hop_neighbors_sparse( - anchor_indices=anchor_indices, - edge_index=homo_edge_index, - num_nodes=num_nodes, - k=hop_distance, - device=device, + anchor_matrix_attr_names = list( + { + attr_name + for attr_name in (anchor_bias_attr_names + anchor_input_attr_names) + if attr_name != PPR_WEIGHT_FEATURE_NAME + } + ) + anchor_based_matrices = _get_sparse_feature_matrices( + data=data, + attr_names=anchor_matrix_attr_names, + missing_attr_error_prefix="Anchor-based attribute", ) - # Get anchor-based PE matrices if specified - anchor_based_pe_matrices = [] - if anchor_based_pe_attr_names: - for attr_name in anchor_based_pe_attr_names: - if hasattr(data, attr_name): - anchor_based_pe_matrices.append(getattr(data, attr_name)) - else: - raise ValueError( - f"Anchor-based PE attribute '{attr_name}' not found in data. " - f"Make sure to apply the corresponding transform first." - ) - - pairwise_pe_matrices = [] - if pairwise_pe_attr_names: - for attr_name in pairwise_pe_attr_names: - if hasattr(data, attr_name): - pairwise_pe_matrices.append(getattr(data, attr_name)) - else: - raise ValueError( - f"Pairwise PE attribute '{attr_name}' not found in data. " - f"Make sure to apply the corresponding transform first." - ) - - node_index_sequences, valid_mask = _build_sequence_layout_from_sparse_neighbors( - reachable=reachable, - anchor_indices=anchor_indices, - max_seq_len=max_seq_len, - include_anchor_first=include_anchor_first, - device=device, + pairwise_pe_matrices = _get_sparse_feature_matrices( + data=data, + attr_names=pairwise_bias_attr_names, + missing_attr_error_prefix="Pairwise PE attribute", ) node_feature_sequences = _gather_sequences_from_node_indices( @@ -240,7 +294,7 @@ def heterodata_to_graph_transformer_input( anchor_indices=anchor_indices, node_index_sequences=node_index_sequences, valid_mask=valid_mask, - csr_matrices=anchor_based_pe_matrices if anchor_based_pe_matrices else None, + csr_matrices=anchor_based_matrices if anchor_based_matrices else None, device=device, ) @@ -251,16 +305,118 @@ def heterodata_to_graph_transformer_input( device=device, ) + anchor_bias_features = _compose_anchor_feature_tensor( + anchor_relative_feature_sequences=anchor_relative_feature_sequences, + available_anchor_attr_names=anchor_matrix_attr_names, + requested_anchor_attr_names=anchor_bias_attr_names, + ppr_weight_sequences=ppr_weight_sequences, + ) + token_input_features = _compose_anchor_feature_tensor( + anchor_relative_feature_sequences=anchor_relative_feature_sequences, + available_anchor_attr_names=anchor_matrix_attr_names, + requested_anchor_attr_names=anchor_input_attr_names, + ppr_weight_sequences=ppr_weight_sequences, + ) + return ( node_feature_sequences, valid_mask, { - "anchor_bias": anchor_relative_feature_sequences, + "anchor_bias": anchor_bias_features, "pairwise_bias": pairwise_feature_sequences, + "token_input": token_input_features, }, ) +def _get_node_type_offsets( + data: HeteroData, + node_type_order: list[NodeType], +) -> dict[NodeType, int]: + offsets: dict[NodeType, int] = {} + offset = 0 + for node_type in node_type_order: + offsets[node_type] = offset + offset += data[node_type].num_nodes + return offsets + + +def _validate_ppr_sequence_input(data: HeteroData) -> None: + if not data.edge_types: + raise ValueError( + "sequence_construction_method='ppr' requires at least one PPR edge type." + ) + + if any(edge_type[1] != "ppr" for edge_type in data.edge_types): + raise ValueError( + "sequence_construction_method='ppr' expects the hetero batch to contain " + f"only PPR edges, got edge types: {data.edge_types}." + ) + + for edge_type in data.edge_types: + edge_store = data[edge_type] + if not hasattr(edge_store, "edge_attr") or edge_store.edge_attr is None: + raise ValueError( + "sequence_construction_method='ppr' requires every PPR edge type to " + f"have edge_attr weights, but {edge_type} is missing them." + ) + + +def _get_sparse_feature_matrices( + data: HeteroData, + attr_names: Optional[list[str]], + missing_attr_error_prefix: str, +) -> list[Tensor]: + matrices: list[Tensor] = [] + for attr_name in attr_names or []: + if not hasattr(data, attr_name): + raise ValueError( + f"{missing_attr_error_prefix} '{attr_name}' not found in data. " + "Make sure to apply the corresponding transform first." + ) + matrices.append(getattr(data, attr_name)) + return matrices + + +def _compose_anchor_feature_tensor( + anchor_relative_feature_sequences: Optional[Tensor], + available_anchor_attr_names: list[str], + requested_anchor_attr_names: list[str], + ppr_weight_sequences: Optional[Tensor], +) -> Optional[Tensor]: + if not requested_anchor_attr_names: + return None + + feature_parts: list[Tensor] = [] + feature_index_by_name = { + attr_name: idx for idx, attr_name in enumerate(available_anchor_attr_names) + } + + for attr_name in requested_anchor_attr_names: + if attr_name == PPR_WEIGHT_FEATURE_NAME: + if ppr_weight_sequences is None: + raise ValueError( + f"Requested '{PPR_WEIGHT_FEATURE_NAME}' but it was not computed." + ) + feature_parts.append(ppr_weight_sequences) + continue + + if anchor_relative_feature_sequences is None: + raise ValueError( + "Anchor-relative features were requested but not computed." + ) + if attr_name not in feature_index_by_name: + raise ValueError( + f"Anchor-relative feature '{attr_name}' was requested but not found." + ) + feature_idx = feature_index_by_name[attr_name] + feature_parts.append( + anchor_relative_feature_sequences[..., feature_idx : feature_idx + 1] + ) + + return torch.cat(feature_parts, dim=-1) + + def _build_sequence_layout_from_sparse_neighbors( reachable: Tensor, anchor_indices: Tensor, @@ -357,6 +513,139 @@ def _build_sequence_layout_from_sparse_neighbors( return node_index_sequences, valid_mask +def _build_sequence_layout_from_ppr_edges( + homo_data: Data, + anchor_indices: Tensor, + max_seq_len: int, + include_anchor_first: bool, + num_nodes: int, + device: torch.device, + return_edge_weights: bool = False, +) -> tuple[Tensor, Tensor, Optional[Tensor]]: + """Build sequences directly from outgoing PPR edges for each anchor. + + The sequence order is: + 1. Anchor node first, when ``include_anchor_first`` is True. + 2. Destination nodes reachable by outgoing ``"ppr"`` edges from that anchor, + sorted by descending PPR weight. + """ + batch_size = anchor_indices.size(0) + node_index_sequences = torch.full( + (batch_size, max_seq_len), + fill_value=-1, + dtype=torch.long, + device=device, + ) + valid_mask = torch.zeros( + (batch_size, max_seq_len), + dtype=torch.bool, + device=device, + ) + ppr_weight_sequences = None + if return_edge_weights: + ppr_weight_sequences = torch.zeros( + (batch_size, max_seq_len, 1), + dtype=torch.float, + device=device, + ) + + if include_anchor_first and max_seq_len > 0: + node_index_sequences[:, 0] = anchor_indices + valid_mask[:, 0] = True + start_pos = 1 + else: + start_pos = 0 + + if start_pos >= max_seq_len: + return node_index_sequences, valid_mask, ppr_weight_sequences + + if not hasattr(homo_data, "edge_attr") or homo_data.edge_attr is None: + raise ValueError( + "sequence_construction_method='ppr' requires homogeneous edge_attr weights." + ) + + edge_weights = homo_data.edge_attr + if edge_weights.dim() == 2: + if edge_weights.size(1) != 1: + raise ValueError( + "PPR edge weights must be 1D or shape [N, 1], " + f"got {tuple(edge_weights.shape)}." + ) + edge_weights = edge_weights.squeeze(1) + elif edge_weights.dim() != 1: + raise ValueError( + "PPR edge weights must be 1D or shape [N, 1], " + f"got {tuple(edge_weights.shape)}." + ) + + anchor_batch_index_by_homo_idx = torch.full( + (num_nodes,), + fill_value=-1, + dtype=torch.long, + device=device, + ) + anchor_batch_index_by_homo_idx[anchor_indices] = torch.arange( + batch_size, device=device + ) + + src_idx = homo_data.edge_index[0] + dst_idx = homo_data.edge_index[1] + anchor_batch_idx = anchor_batch_index_by_homo_idx[src_idx] + keep = anchor_batch_idx >= 0 + if not keep.any(): + return node_index_sequences, valid_mask, ppr_weight_sequences + + all_anchor_batch_idx = anchor_batch_idx[keep] + all_dst_idx = dst_idx[keep] + all_weights = edge_weights[keep] + + if include_anchor_first: + keep = all_dst_idx != anchor_indices[all_anchor_batch_idx] + if not keep.any(): + return node_index_sequences, valid_mask, ppr_weight_sequences + all_anchor_batch_idx = all_anchor_batch_idx[keep] + all_dst_idx = all_dst_idx[keep] + all_weights = all_weights[keep] + + # Flattened COO edges can be laid out in one pass by sorting first on weight + # and then stably on anchor batch id, which preserves descending-weight order + # within each anchor group without a Python loop. + weight_order = torch.argsort(all_weights, descending=True, stable=True) + all_anchor_batch_idx = all_anchor_batch_idx[weight_order] + all_dst_idx = all_dst_idx[weight_order] + all_weights = all_weights[weight_order] + + batch_order = torch.argsort(all_anchor_batch_idx, stable=True) + sorted_batch_idx = all_anchor_batch_idx[batch_order] + sorted_dst_idx = all_dst_idx[batch_order] + sorted_weights = all_weights[batch_order] + + n = sorted_batch_idx.size(0) + is_group_start = torch.zeros(n, dtype=torch.long, device=device) + is_group_start[0] = 1 + if n > 1: + is_group_start[1:] = (sorted_batch_idx[1:] != sorted_batch_idx[:-1]).long() + + group_id = is_group_start.cumsum(0) - 1 + group_starts = torch.nonzero(is_group_start, as_tuple=True)[0] + positions = torch.arange(n, device=device) - group_starts[group_id] + start_pos + + valid = positions < max_seq_len + valid_batch_idx = sorted_batch_idx[valid] + valid_positions = positions[valid] + valid_dst_idx = sorted_dst_idx[valid] + valid_weights = sorted_weights[valid] + + node_index_sequences[valid_batch_idx, valid_positions] = valid_dst_idx + valid_mask[valid_batch_idx, valid_positions] = True + if ppr_weight_sequences is not None: + ppr_weight_sequences[ + valid_batch_idx, valid_positions, 0 + ] = valid_weights.float() + + return node_index_sequences, valid_mask, ppr_weight_sequences + + def _gather_sequences_from_node_indices( node_index_sequences: Tensor, node_features: Tensor, diff --git a/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py b/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py index 3a3c558bd..103237c72 100644 --- a/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py +++ b/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py @@ -257,6 +257,20 @@ def _create_user_graph_with_pe() -> HeteroData: return data +def _create_user_graph_with_ppr_edges() -> HeteroData: + data = HeteroData() + + data["user"].x = torch.tensor( + [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]] + ) + data["user", "ppr", "user"].edge_index = torch.tensor([[0, 0, 1], [1, 2, 2]]) + data["user", "ppr", "user"].edge_attr = torch.tensor([0.9, 0.4, 0.7]) + hop_distance = torch.tensor([[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]]) + data.hop_distance = hop_distance.to_sparse_csr() + + return data + + class TestGraphTransformerEncoderPEModes(TestCase): def setUp(self) -> None: self._node_type = NodeType("user") @@ -328,8 +342,8 @@ def test_forward_accepts_pairwise_attention_bias(self) -> None: encoder = self._create_encoder( pe_attr_names=["random_walk_pe"], - anchor_based_pe_attr_names=["hop_distance"], - pairwise_pe_attr_names=["pairwise_distance"], + anchor_based_attention_bias_attr_names=["hop_distance"], + pairwise_attention_bias_attr_names=["pairwise_distance"], pe_integration_mode="add", ) encoder.eval() @@ -365,8 +379,8 @@ def test_concat_mode_infers_sequence_width_without_explicit_pe_dim(self) -> None def test_attention_bias_features_are_projected_per_head(self) -> None: encoder = self._create_encoder( - anchor_based_pe_attr_names=["hop_distance"], - pairwise_pe_attr_names=["pairwise_distance"], + anchor_based_attention_bias_attr_names=["hop_distance"], + pairwise_attention_bias_attr_names=["pairwise_distance"], ) assert encoder._anchor_pe_attention_bias_projection is not None @@ -403,6 +417,189 @@ def test_attention_bias_features_are_projected_per_head(self) -> None: self.assertEqual(attn_bias[0, 0, 2, 2].item(), 27.0) self.assertEqual(attn_bias[0, 1, 2, 2].item(), 38.0) + def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights( + self, + ) -> None: + encoder = self._create_encoder( + edge_type_to_feat_dim_map={ + EdgeType(self._node_type, Relation("ppr"), self._node_type): 0 + }, + sequence_construction_method="ppr", + anchor_based_attention_bias_attr_names=["hop_distance", "ppr_weight"], + ) + + assert encoder._anchor_pe_attention_bias_projection is not None + + with torch.no_grad(): + encoder._anchor_pe_attention_bias_projection.weight.copy_( + torch.tensor([[1.0, 10.0], [2.0, 20.0]]) + ) + + attn_bias = encoder._build_attention_bias( + valid_mask=torch.ones((1, 3), dtype=torch.bool), + sequences=torch.zeros((1, 3, 8), dtype=torch.float), + attention_bias_data={ + "anchor_bias": torch.tensor( + [[[1.0, 0.5], [2.0, 0.25], [3.0, 0.125]]] + ), + "pairwise_bias": None, + "token_input": None, + }, + ) + + self.assertEqual(attn_bias.shape, (1, 2, 3, 3)) + self.assertEqual(attn_bias[0, 0, 0, 1].item(), 4.5) + self.assertEqual(attn_bias[0, 1, 0, 1].item(), 9.0) + self.assertEqual(attn_bias[0, 0, 2, 2].item(), 4.25) + self.assertEqual(attn_bias[0, 1, 2, 2].item(), 8.5) + + def test_sinusoidal_sequence_positional_encoding_masks_padding(self) -> None: + encoder = self._create_encoder( + sequence_positional_encoding_type="sinusoidal", + ) + + sequence_positional_encoding = encoder._get_sequence_positional_encoding( + valid_mask=torch.tensor([[True, True, False, False]]), + sequences=torch.zeros((1, 4, 8), dtype=torch.float), + ) + + assert sequence_positional_encoding is not None + expected_position_zero = torch.tensor( + [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0], + dtype=torch.float, + ) + self.assertEqual(sequence_positional_encoding.shape, (1, 4, 8)) + self.assertTrue( + torch.allclose(sequence_positional_encoding[0, 0], expected_position_zero) + ) + self.assertFalse( + torch.allclose( + sequence_positional_encoding[0, 1], + torch.zeros(8, dtype=torch.float), + ) + ) + self.assertTrue( + torch.allclose( + sequence_positional_encoding[0, 2:], + torch.zeros((2, 8), dtype=torch.float), + ) + ) + + def test_forward_supports_ppr_sequence_construction(self) -> None: + data = _create_user_graph_with_ppr_edges() + + encoder = self._create_encoder( + edge_type_to_feat_dim_map={ + EdgeType(self._node_type, Relation("ppr"), self._node_type): 0 + }, + sequence_construction_method="ppr", + ) + encoder.eval() + + with torch.no_grad(): + embeddings = encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + + self.assertEqual(embeddings.shape, (3, 6)) + self.assertFalse(torch.isnan(embeddings).any()) + + def test_forward_supports_sinusoidal_sequence_position_encoding_in_ppr_mode( + self, + ) -> None: + data = _create_user_graph_with_ppr_edges() + encoder = self._create_encoder( + edge_type_to_feat_dim_map={ + EdgeType(self._node_type, Relation("ppr"), self._node_type): 0 + }, + sequence_construction_method="ppr", + sequence_positional_encoding_type="sinusoidal", + anchor_based_attention_bias_attr_names=["hop_distance", "ppr_weight"], + anchor_based_input_attr_names=["hop_distance", "ppr_weight"], + ) + encoder.eval() + + with torch.no_grad(): + _ = encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + assert encoder._sequence_positional_encoding_table is not None + original_position_table = ( + encoder._sequence_positional_encoding_table.detach().clone() + ) + + embeddings_with_position_encoding = encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + encoder._sequence_positional_encoding_table.zero_() + embeddings_without_position_encoding = encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + encoder._sequence_positional_encoding_table.copy_(original_position_table) + + self.assertEqual(embeddings_with_position_encoding.shape, (3, 6)) + self.assertFalse(torch.isnan(embeddings_with_position_encoding).any()) + self.assertFalse( + torch.allclose( + embeddings_with_position_encoding, + embeddings_without_position_encoding, + ) + ) + + def test_forward_supports_anchor_relative_and_ppr_token_input_features( + self, + ) -> None: + data = _create_user_graph_with_ppr_edges() + ppr_edge_type = EdgeType(self._node_type, Relation("ppr"), self._node_type) + + base_encoder = self._create_encoder( + edge_type_to_feat_dim_map={ppr_edge_type: 0}, + sequence_construction_method="ppr", + ) + augmented_encoder = self._create_encoder( + edge_type_to_feat_dim_map={ppr_edge_type: 0}, + sequence_construction_method="ppr", + anchor_based_input_attr_names=["hop_distance", "ppr_weight"], + ) + augmented_encoder.load_state_dict(base_encoder.state_dict(), strict=False) + + base_encoder.eval() + augmented_encoder.eval() + + with torch.no_grad(): + _ = augmented_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + assert augmented_encoder._token_input_projection is not None + assert isinstance(augmented_encoder._token_input_projection, nn.Linear) + augmented_encoder._token_input_projection.weight.data.zero_() + + base_embeddings = base_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + augmented_embeddings = augmented_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + + self.assertEqual(augmented_embeddings.shape, (3, 6)) + self.assertTrue( + torch.allclose(base_embeddings, augmented_embeddings, atol=1e-6) + ) + class TestFeedForwardNetwork(TestCase): """Tests for FeedForwardNetwork with various activations.""" diff --git a/tests/unit/transforms/graph_transformer_test.py b/tests/unit/transforms/graph_transformer_test.py index 14382d680..fffac1805 100644 --- a/tests/unit/transforms/graph_transformer_test.py +++ b/tests/unit/transforms/graph_transformer_test.py @@ -68,6 +68,60 @@ def create_larger_hetero_data(num_users: int = 10, num_items: int = 5) -> Hetero return data +def create_ppr_sequence_hetero_data() -> HeteroData: + """Create a graph with explicit PPR edges for sequence-construction tests.""" + data = HeteroData() + + data["user"].x = torch.tensor([[10.0, 0.0], [11.0, 0.0]]) + data["item"].x = torch.tensor([[0.0, 20.0], [0.0, 21.0]]) + + data["user", "ppr", "item"].edge_index = torch.tensor( + [ + [0, 0, 1], + [1, 0, 0], + ] + ) + data["user", "ppr", "item"].edge_attr = torch.tensor([0.9, 0.6, 0.8]) + + data["user", "ppr", "user"].edge_index = torch.tensor( + [ + [0, 1], + [1, 0], + ] + ) + data["user", "ppr", "user"].edge_attr = torch.tensor([0.4, 0.3]) + + homo_data = data.to_homogeneous() + node_type_order = list(getattr(homo_data, "_node_type_names", data.node_types)) + offsets = {} + offset = 0 + for node_type in node_type_order: + offsets[node_type] = offset + offset += data[node_type].num_nodes + + total_nodes = homo_data.num_nodes + hop_distance = torch.zeros((total_nodes, total_nodes), dtype=torch.float) + + user0_idx = offsets["user"] + 0 + user1_idx = offsets["user"] + 1 + item0_idx = offsets["item"] + 0 + item1_idx = offsets["item"] + 1 + + hop_distance[user0_idx, user0_idx] = 0.0 + hop_distance[user0_idx, user1_idx] = 1.0 + hop_distance[user0_idx, item0_idx] = 2.0 + hop_distance[user0_idx, item1_idx] = 3.0 + + hop_distance[user1_idx, user0_idx] = 1.0 + hop_distance[user1_idx, user1_idx] = 0.0 + hop_distance[user1_idx, item0_idx] = 4.0 + hop_distance[user1_idx, item1_idx] = 5.0 + + data.hop_distance = hop_distance.to_sparse_csr() + + return data + + class TestGetKHopNeighborsSparse(TestCase): """Tests for _get_k_hop_neighbors_sparse helper function.""" @@ -232,8 +286,13 @@ def test_anchor_first(self): # Get anchor node feature homo_data = data.to_homogeneous() - # 'item' comes before 'user' alphabetically, so user offset = num_items = 2 - anchor_feature = homo_data.x[2] # First user node + node_type_order = list(getattr(homo_data, "_node_type_names", data.node_types)) + user_offset = 0 + for node_type in node_type_order: + if node_type == "user": + break + user_offset += data[node_type].num_nodes + anchor_feature = homo_data.x[user_offset] # First user node sequences, _, _ = heterodata_to_graph_transformer_input( data=data, @@ -327,6 +386,100 @@ def test_custom_padding_value(self): expected_padding = torch.full_like(padding_features, -1.0) self.assertTrue(torch.allclose(padding_features, expected_padding)) + def test_ppr_sequence_construction_sorts_tokens_by_weight(self): + """Test that PPR mode uses outgoing PPR edges ordered by descending weight.""" + data = create_ppr_sequence_hetero_data() + + sequences, valid_mask, _ = heterodata_to_graph_transformer_input( + data=data, + batch_size=2, + max_seq_len=4, + anchor_node_type="user", + sequence_construction_method="ppr", + ) + + expected_anchor_0 = torch.tensor( + [ + [10.0, 0.0], # anchor user0 + [0.0, 21.0], # item1, weight 0.9 + [0.0, 20.0], # item0, weight 0.6 + [11.0, 0.0], # user1, weight 0.4 + ] + ) + expected_anchor_1 = torch.tensor( + [ + [11.0, 0.0], # anchor user1 + [0.0, 20.0], # item0, weight 0.8 + [10.0, 0.0], # user0, weight 0.3 + [0.0, 0.0], # padding + ] + ) + + self.assertTrue(torch.allclose(sequences[0], expected_anchor_0)) + self.assertTrue(torch.allclose(sequences[1], expected_anchor_1)) + self.assertTrue( + torch.equal(valid_mask[0], torch.tensor([True, True, True, True])) + ) + self.assertTrue( + torch.equal(valid_mask[1], torch.tensor([True, True, True, False])) + ) + + def test_ppr_sequence_construction_requires_only_ppr_relations(self): + data = create_simple_hetero_data() + + with self.assertRaisesRegex(ValueError, "contain only PPR edges"): + heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + sequence_construction_method="ppr", + ) + + def test_ppr_sequence_can_return_token_input_and_attention_bias_features(self): + data = create_ppr_sequence_hetero_data() + + _, valid_mask, sequence_auxiliary_data = heterodata_to_graph_transformer_input( + data=data, + batch_size=2, + max_seq_len=4, + anchor_node_type="user", + sequence_construction_method="ppr", + anchor_based_attention_bias_attr_names=["hop_distance", "ppr_weight"], + anchor_based_input_attr_names=["hop_distance", "ppr_weight"], + ) + + anchor_bias = sequence_auxiliary_data["anchor_bias"] + token_input = sequence_auxiliary_data["token_input"] + assert anchor_bias is not None + assert token_input is not None + + expected_anchor_0 = torch.tensor( + [ + [0.0, 0.0], + [3.0, 0.9], + [2.0, 0.6], + [1.0, 0.4], + ] + ) + expected_anchor_1 = torch.tensor( + [ + [0.0, 0.0], + [4.0, 0.8], + [1.0, 0.3], + [0.0, 0.0], + ] + ) + + self.assertEqual(anchor_bias.shape, (2, 4, 2)) + self.assertEqual(token_input.shape, (2, 4, 2)) + self.assertTrue(torch.allclose(anchor_bias[0], expected_anchor_0)) + self.assertTrue(torch.allclose(anchor_bias[1], expected_anchor_1)) + self.assertTrue(torch.allclose(token_input, anchor_bias)) + self.assertTrue( + torch.equal(valid_mask[1], torch.tensor([True, True, True, False])) + ) + class TestPyTorchTransformerIntegration(TestCase): """Tests for integration with PyTorch TransformerEncoderLayer.""" @@ -620,7 +773,7 @@ def test_transform_returns_base_sequences_and_anchor_relative_bias(self) -> None max_seq_len=4, anchor_node_type="user", hop_distance=2, - anchor_based_pe_attr_names=["hop_distance"], + anchor_based_attention_bias_attr_names=["hop_distance"], ) self.assertEqual(sequences.shape, (1, 4, 4)) @@ -645,8 +798,8 @@ def test_attention_bias_outputs_include_valid_mask_and_relative_features( max_seq_len=4, anchor_node_type="user", hop_distance=2, - anchor_based_pe_attr_names=["hop_distance"], - pairwise_pe_attr_names=["pairwise_distance"], + anchor_based_attention_bias_attr_names=["hop_distance"], + pairwise_attention_bias_attr_names=["pairwise_distance"], ) self.assertEqual(sequences.shape, (1, 4, 4)) From b5b802726ffa51fa90efdb1be2fd8ced45105405 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 25 Mar 2026 19:51:55 +0000 Subject: [PATCH 04/16] Auto-build C++ extensions in post_install; auto-add LLVM to PATH on Mac --- Makefile | 2 +- gigl/scripts/post_install.py | 40 +++++++++++++++++++--------- requirements/install_cpp_deps.sh | 22 ++++++++------- scripts/generate_compile_commands.py | 26 +++++++++--------- 4 files changed, 56 insertions(+), 34 deletions(-) diff --git a/Makefile b/Makefile index 11d8ba848..f13be1c18 100644 --- a/Makefile +++ b/Makefile @@ -161,7 +161,7 @@ format: format_py format_scala format_md format_cpp type_check: uv run mypy ${PYTHON_DIRS} --check-untyped-defs -lint_cpp: build_cpp_extensions +lint_cpp: uv run python scripts/generate_compile_commands.py clang-tidy -p build/compile_commands.json $(CPP_SOURCES) diff --git a/gigl/scripts/post_install.py b/gigl/scripts/post_install.py index cd13c4b0b..918d62fb5 100644 --- a/gigl/scripts/post_install.py +++ b/gigl/scripts/post_install.py @@ -3,8 +3,9 @@ Once GiGL is installed w/ `pip install gigl`, this script can be executed by running: `gigl-post-install` -This script is used to install the dependencies for GIGL. -- Currently, it installs GLT by running install_glt.sh. +This script is used to install the dependencies for GIGL: +- Installs GLT by running install_glt.sh. +- Builds pybind11 C++ extensions in-place so they are importable without a separate build step. """ import subprocess @@ -38,24 +39,19 @@ def main(): """Main entry point for the post-install script.""" print("Running GIGL post-install script...") - # Get the directory where this script is located script_dir = Path(__file__).parent + repo_root = script_dir.parent.parent - # Path to the install_glt.sh script + # Step 1: Install GLT install_glt_script = script_dir / "install_glt.sh" - if not install_glt_script.exists(): print(f"Error: install_glt.sh not found at {install_glt_script}") sys.exit(1) - cmd = f"bash {install_glt_script}" - try: - print(f"Executing {cmd}...") - result = run_command_and_stream_stdout(cmd) - print("Post-install script finished running, with return code: ", result) - return result - + print(f"Executing bash {install_glt_script}...") + result = run_command_and_stream_stdout(f"bash {install_glt_script}") + print("GLT install finished with return code:", result) except subprocess.CalledProcessError as e: print(f"Error running install_glt.sh: {e}") sys.exit(1) @@ -63,6 +59,26 @@ def main(): print(f"Unexpected error: {e}") sys.exit(1) + # Step 2: Build pybind11 C++ extensions in-place so they are importable + # without requiring a separate `make build_cpp_extensions` call. + setup_py = repo_root / "setup.py" + if setup_py.exists(): + cmd = f"cd {repo_root} && {sys.executable} setup.py build_ext --inplace" + try: + print("Building C++ extensions...") + result = run_command_and_stream_stdout(cmd) + print("C++ extension build finished with return code:", result) + except subprocess.CalledProcessError as e: + print(f"Error building C++ extensions: {e}") + sys.exit(1) + except Exception as e: + print(f"Unexpected error building C++ extensions: {e}") + sys.exit(1) + else: + print(f"Warning: {setup_py} not found, skipping C++ extension build") + + return result + if __name__ == "__main__": main() diff --git a/requirements/install_cpp_deps.sh b/requirements/install_cpp_deps.sh index 9a3aacbd7..7675a0396 100644 --- a/requirements/install_cpp_deps.sh +++ b/requirements/install_cpp_deps.sh @@ -18,16 +18,20 @@ is_running_on_mac() { if is_running_on_mac; then # `brew install llvm` provides clang-format and clang-tidy. # Homebrew does not add llvm to PATH by default to avoid shadowing Apple's - # clang, so we print an instruction for the developer to do it manually. + # clang, so we add it ourselves. brew install llvm cmake - LLVM_PREFIX=$(brew --prefix llvm) - set +x - echo "" - echo "NOTE: Add the LLVM bin directory to your PATH to use clang-format and clang-tidy:" - echo " export PATH=\"${LLVM_PREFIX}/bin:\$PATH\"" - echo " (Add this to your ~/.zshrc or ~/.bashrc to make it permanent.)" - echo "" - set -x + LLVM_BIN="$(brew --prefix llvm)/bin" + + # Append to any shell rc files that exist and don't already include it. + for rc_file in ~/.zshrc ~/.bashrc; do + if [ -f "$rc_file" ] && ! grep -qF "$LLVM_BIN" "$rc_file"; then + printf '\n# Added by GiGL install_cpp_deps.sh\nexport PATH="%s:$PATH"\n' "$LLVM_BIN" >> "$rc_file" + echo "Added LLVM bin to PATH in $rc_file" + fi + done + + # Export for the current shell session so make targets work immediately. + export PATH="$LLVM_BIN:$PATH" else # Ubuntu / Debian — clang 15 is the highest version available on Ubuntu 22.04. apt-get install -y clang-format-15 clang-tidy-15 cmake diff --git a/scripts/generate_compile_commands.py b/scripts/generate_compile_commands.py index 05873a486..7aea7582b 100644 --- a/scripts/generate_compile_commands.py +++ b/scripts/generate_compile_commands.py @@ -12,25 +12,25 @@ """ import json +import subprocess import sys import sysconfig from pathlib import Path +from torch.utils.cpp_extension import include_paths as torch_include_paths -def main() -> None: - try: - import torch # noqa: F401 — imported to verify it is installed - from torch.utils.cpp_extension import include_paths as torch_include_paths - except ImportError as exc: - print( - f"Error: {exc}\n" - "Run `make build_cpp_extensions` first to ensure torch is available.", - file=sys.stderr, - ) - sys.exit(1) +def main() -> None: repo_root = Path(__file__).parent.parent.resolve() + # Always rebuild C++ extensions before generating compile_commands.json so + # the database reflects the current state of the code. + subprocess.run( + [sys.executable, "setup.py", "build_ext", "--inplace"], + cwd=repo_root, + check=True, + ) + # Collect all include directories needed to compile the extension. # torch_include_paths() returns the torch headers, which already bundle # pybind11 under torch/include/pybind11/ — no separate pybind11 import needed. @@ -62,7 +62,9 @@ def main() -> None: output = repo_root / "build" / "compile_commands.json" output.parent.mkdir(exist_ok=True) output.write_text(json.dumps(commands, indent=2)) - print(f"Wrote {len(commands)} entr{'y' if len(commands) == 1 else 'ies'} to {output}") + print( + f"Wrote {len(commands)} entr{'y' if len(commands) == 1 else 'ies'} to {output}" + ) if __name__ == "__main__": From 9d3c8df496755b36f81e1fde551fd0238c3938be Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 25 Mar 2026 20:00:20 +0000 Subject: [PATCH 05/16] Rename setup.py to build_cpp_extensions.py; add build_cpp_extensions make target --- Makefile | 3 +++ setup.py => build_cpp_extensions.py | 9 --------- gigl/scripts/post_install.py | 4 ++-- scripts/generate_compile_commands.py | 2 +- 4 files changed, 6 insertions(+), 12 deletions(-) rename setup.py => build_cpp_extensions.py (71%) diff --git a/Makefile b/Makefile index f13be1c18..7e215c650 100644 --- a/Makefile +++ b/Makefile @@ -161,6 +161,9 @@ format: format_py format_scala format_md format_cpp type_check: uv run mypy ${PYTHON_DIRS} --check-untyped-defs +build_cpp_extensions: + uv run --no-sync python build_cpp_extensions.py build_ext --inplace + lint_cpp: uv run python scripts/generate_compile_commands.py clang-tidy -p build/compile_commands.json $(CPP_SOURCES) diff --git a/setup.py b/build_cpp_extensions.py similarity index 71% rename from setup.py rename to build_cpp_extensions.py index a9678da1f..99ca11a6a 100644 --- a/setup.py +++ b/build_cpp_extensions.py @@ -11,20 +11,11 @@ def find_cpp_extensions() -> list[CppExtension]: extension. The module name is derived from the file path, so the extension is importable at the same location as its Python neighbours. - Example:: - - gigl/distributed/cpp_extensions/ppr_forward_push.cpp - → importable as ``gigl.distributed.cpp_extensions.ppr_forward_push`` - To add a new extension, drop a .cpp file anywhere under ``gigl/`` — no changes to this file required. """ extensions = [] for cpp_file in sorted(Path("gigl").rglob("*.cpp")): - # Convert path separators to dots and strip the .cpp suffix to get the - # fully-qualified Python module name, e.g.: - # gigl/distributed/cpp_extensions/ppr_forward_push.cpp - # → "gigl.distributed.cpp_extensions.ppr_forward_push" module_name = ".".join(cpp_file.with_suffix("").parts) extensions.append( CppExtension( diff --git a/gigl/scripts/post_install.py b/gigl/scripts/post_install.py index 918d62fb5..93d517f31 100644 --- a/gigl/scripts/post_install.py +++ b/gigl/scripts/post_install.py @@ -61,9 +61,9 @@ def main(): # Step 2: Build pybind11 C++ extensions in-place so they are importable # without requiring a separate `make build_cpp_extensions` call. - setup_py = repo_root / "setup.py" + setup_py = repo_root / "build_cpp_extensions.py" if setup_py.exists(): - cmd = f"cd {repo_root} && {sys.executable} setup.py build_ext --inplace" + cmd = f"cd {repo_root} && {sys.executable} build_cpp_extensions.py build_ext --inplace" try: print("Building C++ extensions...") result = run_command_and_stream_stdout(cmd) diff --git a/scripts/generate_compile_commands.py b/scripts/generate_compile_commands.py index 7aea7582b..a3f77ef5f 100644 --- a/scripts/generate_compile_commands.py +++ b/scripts/generate_compile_commands.py @@ -26,7 +26,7 @@ def main() -> None: # Always rebuild C++ extensions before generating compile_commands.json so # the database reflects the current state of the code. subprocess.run( - [sys.executable, "setup.py", "build_ext", "--inplace"], + [sys.executable, "build_cpp_extensions.py", "build_ext", "--inplace"], cwd=repo_root, check=True, ) From ace71264cc046e4a9cef2ba43eb2c9b8276e79cd Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 25 Mar 2026 20:02:34 +0000 Subject: [PATCH 06/16] Scope C++ extension discovery to gigl/cpp_extensions/ --- Makefile | 2 +- build_cpp_extensions.py | 12 ++++-------- scripts/generate_compile_commands.py | 2 +- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index 7e215c650..e0f05494c 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG?=${DOCKER_IMAGE_MAIN_CPU_NAME}:${DATE} DOCKER_IMAGE_DEV_WORKBENCH_NAME_WITH_TAG?=${DOCKER_IMAGE_DEV_WORKBENCH_NAME}:${DATE} PYTHON_DIRS:=.github/scripts examples gigl tests snapchat scripts -CPP_SOURCES:=$(shell find gigl -name "*.cpp") +CPP_SOURCES:=$(shell find gigl/cpp_extensions -name "*.cpp" 2>/dev/null) PY_TEST_FILES?="*_test.py" # You can override GIGL_TEST_DEFAULT_RESOURCE_CONFIG by setting it in your environment i.e. # adding `export GIGL_TEST_DEFAULT_RESOURCE_CONFIG=your_resource_config` to your shell config (~/.bashrc, ~/.zshrc, etc.) diff --git a/build_cpp_extensions.py b/build_cpp_extensions.py index 99ca11a6a..406f02a6e 100644 --- a/build_cpp_extensions.py +++ b/build_cpp_extensions.py @@ -5,17 +5,13 @@ def find_cpp_extensions() -> list[CppExtension]: - """Auto-discover pybind11 extension modules. + """Auto-discover pybind11 extension modules under ``gigl/cpp_extensions/``. - Any .cpp file anywhere under ``gigl/`` is compiled as a Python C++ - extension. The module name is derived from the file path, so the - extension is importable at the same location as its Python neighbours. - - To add a new extension, drop a .cpp file anywhere under ``gigl/`` — - no changes to this file required. + The module name is derived from the file path, so each extension is + importable at the Python path corresponding to its location. """ extensions = [] - for cpp_file in sorted(Path("gigl").rglob("*.cpp")): + for cpp_file in sorted(Path("gigl/cpp_extensions").rglob("*.cpp")): module_name = ".".join(cpp_file.with_suffix("").parts) extensions.append( CppExtension( diff --git a/scripts/generate_compile_commands.py b/scripts/generate_compile_commands.py index a3f77ef5f..630febfd7 100644 --- a/scripts/generate_compile_commands.py +++ b/scripts/generate_compile_commands.py @@ -40,7 +40,7 @@ def main() -> None: # Python C API headers (e.g. Python.h) required by pybind11. include_flags.append(f"-I{sysconfig.get_path('include')}") - cpp_sources = sorted((repo_root / "gigl").rglob("*.cpp")) + cpp_sources = sorted((repo_root / "gigl" / "cpp_extensions").rglob("*.cpp")) if not cpp_sources: print("Warning: no .cpp files found under gigl/", file=sys.stderr) From 48be4ccbbcddea57f80e33664ef3ea6fea7fce75 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 25 Mar 2026 20:05:53 +0000 Subject: [PATCH 07/16] Remove unnecessary existence check for build_cpp_extensions.py in post_install --- gigl/scripts/post_install.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/gigl/scripts/post_install.py b/gigl/scripts/post_install.py index 93d517f31..f7ee5c1cd 100644 --- a/gigl/scripts/post_install.py +++ b/gigl/scripts/post_install.py @@ -61,21 +61,17 @@ def main(): # Step 2: Build pybind11 C++ extensions in-place so they are importable # without requiring a separate `make build_cpp_extensions` call. - setup_py = repo_root / "build_cpp_extensions.py" - if setup_py.exists(): - cmd = f"cd {repo_root} && {sys.executable} build_cpp_extensions.py build_ext --inplace" - try: - print("Building C++ extensions...") - result = run_command_and_stream_stdout(cmd) - print("C++ extension build finished with return code:", result) - except subprocess.CalledProcessError as e: - print(f"Error building C++ extensions: {e}") - sys.exit(1) - except Exception as e: - print(f"Unexpected error building C++ extensions: {e}") - sys.exit(1) - else: - print(f"Warning: {setup_py} not found, skipping C++ extension build") + cmd = f"cd {repo_root} && {sys.executable} build_cpp_extensions.py build_ext --inplace" + try: + print("Building C++ extensions...") + result = run_command_and_stream_stdout(cmd) + print("C++ extension build finished with return code:", result) + except subprocess.CalledProcessError as e: + print(f"Error building C++ extensions: {e}") + sys.exit(1) + except Exception as e: + print(f"Unexpected error building C++ extensions: {e}") + sys.exit(1) return result From 1b153b7fdab0a3e7460664560f803c70ec31c6f0 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 25 Mar 2026 20:29:30 +0000 Subject: [PATCH 08/16] Review fixes + adopt PyTorch csrc conventions for C++ layout --- .clang-tidy | 8 ------- Makefile | 6 ++--- build_cpp_extensions.py | 33 ++++++++++++++++++++++---- gigl/scripts/post_install.py | 16 ++++++------- requirements/install_cpp_deps.sh | 4 +++- scripts/generate_compile_commands.py | 5 ++-- tests/unit/cpp/CMakeLists.txt | 3 +-- tests/unit/cpp/infrastructure_test.cpp | 3 +-- 8 files changed, 47 insertions(+), 31 deletions(-) diff --git a/.clang-tidy b/.clang-tidy index 2c746e6b2..fe02e7100 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -27,11 +27,3 @@ WarningsAsErrors: "*" # Only apply header-file checks to GiGL-owned headers, not third-party # (torch, pybind11) headers pulled in by includes. HeaderFilterRegex: "gigl/.*" - -# Per-check configuration options. -CheckOptions: - # Enforce lower_case naming for local variables (matches existing style). - - key: readability-identifier-naming.LocalVariableCase - value: lower_case - - key: readability-identifier-naming.ParameterCase - value: lower_case diff --git a/Makefile b/Makefile index e0f05494c..86fea095e 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG?=${DOCKER_IMAGE_MAIN_CPU_NAME}:${DATE} DOCKER_IMAGE_DEV_WORKBENCH_NAME_WITH_TAG?=${DOCKER_IMAGE_DEV_WORKBENCH_NAME}:${DATE} PYTHON_DIRS:=.github/scripts examples gigl tests snapchat scripts -CPP_SOURCES:=$(shell find gigl/cpp_extensions -name "*.cpp" 2>/dev/null) +CPP_SOURCES:=$(shell find gigl/csrc -name "*.cpp" 2>/dev/null) PY_TEST_FILES?="*_test.py" # You can override GIGL_TEST_DEFAULT_RESOURCE_CONFIG by setting it in your environment i.e. # adding `export GIGL_TEST_DEFAULT_RESOURCE_CONFIG=your_resource_config` to your shell config (~/.bashrc, ~/.zshrc, etc.) @@ -117,7 +117,7 @@ check_format_md: uv run mdformat --check ${MD_FILES} check_format_cpp: - clang-format --dry-run --Werror --style=file $(CPP_SOURCES) + $(if $(CPP_SOURCES), clang-format --dry-run --Werror --style=file $(CPP_SOURCES)) check_format: check_format_py check_format_scala check_format_md check_format_cpp @@ -154,7 +154,7 @@ format_md: uv run mdformat ${MD_FILES} format_cpp: - clang-format -i --style=file $(CPP_SOURCES) + $(if $(CPP_SOURCES), clang-format -i --style=file $(CPP_SOURCES)) format: format_py format_scala format_md format_cpp diff --git a/build_cpp_extensions.py b/build_cpp_extensions.py index 406f02a6e..53c823333 100644 --- a/build_cpp_extensions.py +++ b/build_cpp_extensions.py @@ -1,18 +1,41 @@ +"""Build script for GiGL pybind11 C++ extensions. + +Invoked by ``make build_cpp_extensions`` and automatically during ``make install_dev_deps`` +via ``post_install.py``. Not a general-purpose setup.py — only builds C++ extensions. + +Usage:: + + python build_cpp_extensions.py build_ext --inplace +""" + from pathlib import Path from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CppExtension +_CSRC_DIR = Path("gigl/csrc") + def find_cpp_extensions() -> list[CppExtension]: - """Auto-discover pybind11 extension modules under ``gigl/cpp_extensions/``. + """Auto-discover pybind11 extension modules under ``gigl/csrc/``. + + Following PyTorch's csrc convention, only files named ``python_*.cpp`` are + compiled as Python extension modules. Pure C++ files (without the + ``python_`` prefix) are used only in C++ unit tests. + + The module name is derived from the file path with the ``python_`` prefix + stripped, so ``gigl/csrc/distributed/python_ppr_forward_push.cpp`` is + importable as ``gigl.csrc.distributed.ppr_forward_push``. - The module name is derived from the file path, so each extension is - importable at the Python path corresponding to its location. + Returns an empty list if ``gigl/csrc/`` does not yet exist. """ + if not _CSRC_DIR.exists(): + return [] extensions = [] - for cpp_file in sorted(Path("gigl/cpp_extensions").rglob("*.cpp")): - module_name = ".".join(cpp_file.with_suffix("").parts) + for cpp_file in sorted(_CSRC_DIR.rglob("python_*.cpp")): + parts = list(cpp_file.with_suffix("").parts) + parts[-1] = parts[-1].removeprefix("python_") + module_name = ".".join(parts) extensions.append( CppExtension( name=module_name, diff --git a/gigl/scripts/post_install.py b/gigl/scripts/post_install.py index f7ee5c1cd..677e45a6c 100644 --- a/gigl/scripts/post_install.py +++ b/gigl/scripts/post_install.py @@ -52,20 +52,22 @@ def main(): print(f"Executing bash {install_glt_script}...") result = run_command_and_stream_stdout(f"bash {install_glt_script}") print("GLT install finished with return code:", result) - except subprocess.CalledProcessError as e: - print(f"Error running install_glt.sh: {e}") - sys.exit(1) except Exception as e: print(f"Unexpected error: {e}") sys.exit(1) # Step 2: Build pybind11 C++ extensions in-place so they are importable # without requiring a separate `make build_cpp_extensions` call. - cmd = f"cd {repo_root} && {sys.executable} build_cpp_extensions.py build_ext --inplace" + # subprocess.run streams stdout/stderr to the terminal and raises + # CalledProcessError on a non-zero exit code. try: print("Building C++ extensions...") - result = run_command_and_stream_stdout(cmd) - print("C++ extension build finished with return code:", result) + subprocess.run( + [sys.executable, "build_cpp_extensions.py", "build_ext", "--inplace"], + cwd=repo_root, + check=True, + ) + print("C++ extension build finished.") except subprocess.CalledProcessError as e: print(f"Error building C++ extensions: {e}") sys.exit(1) @@ -73,8 +75,6 @@ def main(): print(f"Unexpected error building C++ extensions: {e}") sys.exit(1) - return result - if __name__ == "__main__": main() diff --git a/requirements/install_cpp_deps.sh b/requirements/install_cpp_deps.sh index 7675a0396..6b9f19113 100644 --- a/requirements/install_cpp_deps.sh +++ b/requirements/install_cpp_deps.sh @@ -30,7 +30,9 @@ if is_running_on_mac; then fi done - # Export for the current shell session so make targets work immediately. + # NOTE: this export only affects subprocesses of this script, not the calling + # shell or make process. Open a new terminal (or run `source ~/.zshrc`) after + # install_dev_deps to pick up the PATH change. export PATH="$LLVM_BIN:$PATH" else # Ubuntu / Debian — clang 15 is the highest version available on Ubuntu 22.04. diff --git a/scripts/generate_compile_commands.py b/scripts/generate_compile_commands.py index 630febfd7..767a4860b 100644 --- a/scripts/generate_compile_commands.py +++ b/scripts/generate_compile_commands.py @@ -40,9 +40,10 @@ def main() -> None: # Python C API headers (e.g. Python.h) required by pybind11. include_flags.append(f"-I{sysconfig.get_path('include')}") - cpp_sources = sorted((repo_root / "gigl" / "cpp_extensions").rglob("*.cpp")) + cpp_dir = repo_root / "gigl" / "csrc" + cpp_sources = sorted(cpp_dir.rglob("*.cpp")) if cpp_dir.exists() else [] if not cpp_sources: - print("Warning: no .cpp files found under gigl/", file=sys.stderr) + print("Warning: no .cpp files found under gigl/csrc/", file=sys.stderr) # Each entry in compile_commands.json describes how one source file is compiled. # clang-tidy reads this to reproduce the exact compilation environment. diff --git a/tests/unit/cpp/CMakeLists.txt b/tests/unit/cpp/CMakeLists.txt index fd55255b2..514625587 100644 --- a/tests/unit/cpp/CMakeLists.txt +++ b/tests/unit/cpp/CMakeLists.txt @@ -12,8 +12,7 @@ include(FetchContent) FetchContent_Declare( googletest URL https://github.com/google/googletest/archive/refs/tags/v1.14.0.tar.gz - # TODO: pin URL_HASH once infrastructure is validated, e.g.: - # URL_HASH SHA256= + URL_HASH SHA256=8372520bab45e12a97cd2f49dad36d07e55ddb89c0e39fad4a1a64cab0bdf35d ) # Prevent GoogleTest from overriding the compiler's runtime on Windows # (no-op on Linux/Mac, but required for portable CMake config). diff --git a/tests/unit/cpp/infrastructure_test.cpp b/tests/unit/cpp/infrastructure_test.cpp index d8d252f8e..af85a6660 100644 --- a/tests/unit/cpp/infrastructure_test.cpp +++ b/tests/unit/cpp/infrastructure_test.cpp @@ -1,8 +1,7 @@ // Placeholder C++ unit test. // // This file exists to verify that the GoogleTest infrastructure compiles and -// runs end-to-end. Replace or supplement it with tests for actual GiGL C++ -// code (e.g. PPRForwardPushState) as those components are added. +// runs end-to-end. #include From 03ed8c458c3478f981e9c8b85805af79c8c607ff Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 25 Mar 2026 21:52:34 +0000 Subject: [PATCH 09/16] Add multi-source C++ ext support, gigl/csrc package init, and .so gitignore --- .gitignore | 3 +++ build_cpp_extensions.py | 10 +++++++++- gigl/csrc/__init__.py | 0 3 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 gigl/csrc/__init__.py diff --git a/.gitignore b/.gitignore index e06a57bc8..da3c883c5 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,6 @@ fossa*.zip # https://github.com/google-github-actions/auth/issues/497 gha-creds-*.json + +# Compiled C++ extension modules +gigl/csrc/**/*.so diff --git a/build_cpp_extensions.py b/build_cpp_extensions.py index 53c823333..a0095464a 100644 --- a/build_cpp_extensions.py +++ b/build_cpp_extensions.py @@ -27,6 +27,10 @@ def find_cpp_extensions() -> list[CppExtension]: stripped, so ``gigl/csrc/distributed/python_ppr_forward_push.cpp`` is importable as ``gigl.csrc.distributed.ppr_forward_push``. + If a matching implementation file exists alongside the binding file (e.g. + ``ppr_forward_push.cpp`` next to ``python_ppr_forward_push.cpp``), it is + compiled into the same extension module. + Returns an empty list if ``gigl/csrc/`` does not yet exist. """ if not _CSRC_DIR.exists(): @@ -36,10 +40,14 @@ def find_cpp_extensions() -> list[CppExtension]: parts = list(cpp_file.with_suffix("").parts) parts[-1] = parts[-1].removeprefix("python_") module_name = ".".join(parts) + impl_file = cpp_file.parent / (parts[-1] + ".cpp") + sources = [str(cpp_file)] + if impl_file.exists(): + sources.append(str(impl_file)) extensions.append( CppExtension( name=module_name, - sources=[str(cpp_file)], + sources=sources, extra_compile_args=["-O3", "-std=c++17", "-Wall", "-Wextra"], ) ) diff --git a/gigl/csrc/__init__.py b/gigl/csrc/__init__.py new file mode 100644 index 000000000..e69de29bb From 638e667f2c346286ffe64f5f2c3e55cefa14c3d0 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 25 Mar 2026 22:31:53 +0000 Subject: [PATCH 10/16] Move build_cpp_extensions.py to scripts/ and wire into relevant make targets --- Makefile | 10 +++++----- gigl/scripts/post_install.py | 2 +- .../build_cpp_extensions.py | 11 +---------- scripts/generate_compile_commands.py | 2 +- 4 files changed, 8 insertions(+), 17 deletions(-) rename build_cpp_extensions.py => scripts/build_cpp_extensions.py (73%) diff --git a/Makefile b/Makefile index 86fea095e..89416cfa7 100644 --- a/Makefile +++ b/Makefile @@ -77,7 +77,7 @@ assert_yaml_configs_parse: # Ex. `make unit_test_py PY_TEST_FILES="eval_metrics_test.py"` # By default, runs all tests under tests/unit. # See the help text for "--test_file_pattern" in tests/test_args.py for more details. -unit_test_py: clean_build_files_py type_check +unit_test_py: clean_build_files_py type_check build_cpp_extensions uv run python -m tests.unit.main \ --env=test \ --resource_config_uri=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} \ @@ -125,7 +125,7 @@ check_format: check_format_py check_format_scala check_format_md check_format_cp # Ex. `make integration_test PY_TEST_FILES="dataflow_test.py"` # By default, runs all tests under tests/integration. # See the help text for "--test_file_pattern" in tests/test_args.py for more details. -integration_test: +integration_test: build_cpp_extensions uv run python -m tests.integration.main \ --env=test \ --resource_config_uri=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} \ @@ -162,7 +162,7 @@ type_check: uv run mypy ${PYTHON_DIRS} --check-untyped-defs build_cpp_extensions: - uv run --no-sync python build_cpp_extensions.py build_ext --inplace + uv run --no-sync python scripts/build_cpp_extensions.py build_ext --inplace lint_cpp: uv run python scripts/generate_compile_commands.py @@ -282,7 +282,7 @@ run_all_e2e_tests: # Example: # `make compiled_pipeline_path="/tmp/gigl/my_pipeline.yaml" compile_gigl_kubeflow_pipeline` # Can be a GCS URI as well -compile_gigl_kubeflow_pipeline: compile_jars push_new_docker_images +compile_gigl_kubeflow_pipeline: build_cpp_extensions compile_jars push_new_docker_images uv run python -m gigl.orchestration.kubeflow.runner \ --action=compile \ --container_image_cuda=${DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG} \ @@ -308,7 +308,7 @@ _skip_build_deps: # job_name=... \ , and other params # compiled_pipeline_path="/tmp/gigl/my_pipeline.yaml" \ # run_dev_gnn_kubeflow_pipeline -run_dev_gnn_kubeflow_pipeline: $(if $(compiled_pipeline_path), _skip_build_deps, compile_jars push_new_docker_images) +run_dev_gnn_kubeflow_pipeline: $(if $(compiled_pipeline_path), _skip_build_deps, build_cpp_extensions compile_jars push_new_docker_images) uv run python -m gigl.orchestration.kubeflow.runner \ $(if $(compiled_pipeline_path),,--container_image_cuda=${DOCKER_IMAGE_MAIN_CUDA_NAME_WITH_TAG}) \ $(if $(compiled_pipeline_path),,--container_image_cpu=${DOCKER_IMAGE_MAIN_CPU_NAME_WITH_TAG}) \ diff --git a/gigl/scripts/post_install.py b/gigl/scripts/post_install.py index 677e45a6c..eed1528f6 100644 --- a/gigl/scripts/post_install.py +++ b/gigl/scripts/post_install.py @@ -63,7 +63,7 @@ def main(): try: print("Building C++ extensions...") subprocess.run( - [sys.executable, "build_cpp_extensions.py", "build_ext", "--inplace"], + [sys.executable, "scripts/build_cpp_extensions.py", "build_ext", "--inplace"], cwd=repo_root, check=True, ) diff --git a/build_cpp_extensions.py b/scripts/build_cpp_extensions.py similarity index 73% rename from build_cpp_extensions.py rename to scripts/build_cpp_extensions.py index a0095464a..d05740860 100644 --- a/build_cpp_extensions.py +++ b/scripts/build_cpp_extensions.py @@ -20,16 +20,7 @@ def find_cpp_extensions() -> list[CppExtension]: """Auto-discover pybind11 extension modules under ``gigl/csrc/``. Following PyTorch's csrc convention, only files named ``python_*.cpp`` are - compiled as Python extension modules. Pure C++ files (without the - ``python_`` prefix) are used only in C++ unit tests. - - The module name is derived from the file path with the ``python_`` prefix - stripped, so ``gigl/csrc/distributed/python_ppr_forward_push.cpp`` is - importable as ``gigl.csrc.distributed.ppr_forward_push``. - - If a matching implementation file exists alongside the binding file (e.g. - ``ppr_forward_push.cpp`` next to ``python_ppr_forward_push.cpp``), it is - compiled into the same extension module. + compiled as Python extension modules. Returns an empty list if ``gigl/csrc/`` does not yet exist. """ diff --git a/scripts/generate_compile_commands.py b/scripts/generate_compile_commands.py index 767a4860b..eec176848 100644 --- a/scripts/generate_compile_commands.py +++ b/scripts/generate_compile_commands.py @@ -26,7 +26,7 @@ def main() -> None: # Always rebuild C++ extensions before generating compile_commands.json so # the database reflects the current state of the code. subprocess.run( - [sys.executable, "build_cpp_extensions.py", "build_ext", "--inplace"], + [sys.executable, "scripts/build_cpp_extensions.py", "build_ext", "--inplace"], cwd=repo_root, check=True, ) From 416f6b4dd01079667d82660eac072eab0b247f30 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Mon, 30 Mar 2026 22:33:43 +0000 Subject: [PATCH 11/16] Initial commit --- gigl/distributed/dist_dataset.py | 24 +-- gigl/distributed/dist_ppr_sampler.py | 74 +++------ gigl/distributed/dist_sampling_producer.py | 8 +- gigl/distributed/sampler_options.py | 6 +- gigl/distributed/utils/degree.py | 147 +++++++++++++----- .../graph_transformer/graph_transformer.py | 3 + tests/unit/distributed/utils/degree_test.py | 58 ++++--- 7 files changed, 198 insertions(+), 122 deletions(-) diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index f495a6d96..20436f4bc 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -78,7 +78,7 @@ def __init__( Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ] = None, degree_tensor: Optional[ - Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + Union[torch.Tensor, dict[NodeType, torch.Tensor]] ] = None, ) -> None: """ @@ -104,7 +104,7 @@ def __init__( Note this will be None in the homogeneous case if the data has no node features, or will only contain node types with node features in the heterogeneous case. edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Dimension of edge features and its data type, will be a dict if heterogeneous. Note this will be None in the homogeneous case if the data has no edge features, or will only contain edge types with edge features in the heterogeneous case. - degree_tensor: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. + degree_tensor: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. """ self._rank: int = rank self._world_size: int = world_size @@ -141,7 +141,7 @@ def __init__( self._edge_feature_info = edge_feature_info self._degree_tensor: Optional[ - Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + Union[torch.Tensor, dict[NodeType, torch.Tensor]] ] = degree_tensor # TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear @@ -300,7 +300,7 @@ def edge_feature_info( @property def degree_tensor( self, - ) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: + ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: """ Lazily compute and return the degree tensor for the graph. @@ -308,15 +308,19 @@ def degree_tensor( all-reduce to aggregate across all machines. Requires torch.distributed to be initialized. + For heterogeneous graphs, degrees are summed across all edge types sharing + the same anchor node type (determined by ``self._edge_dir``), yielding one + ``int16`` tensor per node type rather than one per edge type. + Over-counting correction (for processes sharing the same data on the same machine) is handled automatically by detecting the distributed topology. The result is cached for subsequent accesses. Returns: - Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: The aggregated degree tensor. + Union[torch.Tensor, dict[NodeType, torch.Tensor]]: The aggregated degree tensor. - For homogeneous graphs: A tensor of shape [num_nodes]. - - For heterogeneous graphs: A dict mapping EdgeType to degree tensors. + - For heterogeneous graphs: A dict mapping NodeType to ``int16`` degree tensors. Raises: RuntimeError: If torch.distributed is not initialized. @@ -326,7 +330,9 @@ def degree_tensor( if self.graph is None: raise ValueError("Dataset graph is None. Cannot compute degrees.") - self._degree_tensor = compute_and_broadcast_degree_tensor(self.graph) + self._degree_tensor = compute_and_broadcast_degree_tensor( + self.graph, edge_dir=self._edge_dir + ) return self._degree_tensor @property @@ -857,7 +863,7 @@ def share_ipc( Optional[Union[int, dict[NodeType, int]]], Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]], Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]], - Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], ]: """ Serializes the member variables of the DistDatasetClass @@ -879,7 +885,7 @@ def share_ipc( Optional[Union[int, dict[NodeType, int]]]: Number of test nodes on the current machine. Will be a dict if heterogeneous. Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]: Node feature dim and its data type, will be a dict if heterogeneous Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous - Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous + Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous """ # TODO (mkolodner-sc): Investigate moving share_memory calls to the build() function diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 17673a72d..34cb1661b 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -103,7 +103,7 @@ def __init__( max_ppr_nodes: int = 50, num_neighbors_per_hop: int = 100_000, total_degree_dtype: torch.dtype = torch.int32, - degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], + degree_tensors: Union[torch.Tensor, dict[NodeType, torch.Tensor]], **kwargs, ): super().__init__(*args, **kwargs) @@ -144,69 +144,45 @@ def __init__( ] self._is_homogeneous = True - # Precompute total degree per node type: the sum of degrees across all - # edge types traversable from that node type. This is a graph-level - # property used on every PPR iteration, so computing it once at init - # avoids per-node summation and cache lookups in the hot loop. - # TODO (mkolodner-sc): This trades memory for throughput — we - # materialize a tensor per node type to avoid recomputing total degree - # on every neighbor during sampling. Computing it here (rather than in - # the dataset) also keeps the door open for edge-specific degree - # strategies. If memory becomes a bottleneck, revisit this. + # The dataset pre-aggregates degrees per node type as int16. + # _build_total_degree_tensors reindexes by node type with no further casting. self._node_type_to_total_degree: dict[ NodeType, torch.Tensor - ] = self._build_total_degree_tensors(degree_tensors, total_degree_dtype) + ] = self._build_total_degree_tensors(degree_tensors) def _build_total_degree_tensors( self, - degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], - dtype: torch.dtype, + degree_tensors: Union[torch.Tensor, dict[NodeType, torch.Tensor]], ) -> dict[NodeType, torch.Tensor]: - """Build total-degree tensors by summing per-edge-type degrees for each node type. + """Reindex pre-aggregated per-node-type degree tensors for use in the PPR loop. - For homogeneous graphs, the total degree is just the single degree tensor. - For heterogeneous graphs, it sums degree tensors across all edge types - traversable from each node type, padding shorter tensors with zeros. + The dataset provides degrees already summed across edge types per node type + (as ``int16``). This method maps those tensors into the internal node-type + keying used by the sampler, with no further summation or dtype casting. Args: - degree_tensors: Per-edge-type degree tensors from the dataset. - dtype: Dtype for the output tensors. + degree_tensors: Per-node-type degree tensors from the dataset, already + aggregated across edge types. Returns: - Dict mapping node type to a 1-D tensor of total degrees. - """ - result: dict[NodeType, torch.Tensor] = {} + Dict mapping node type to a 1-D degree tensor. + Raises: + ValueError: If a required node type is absent from ``degree_tensors``. + """ if self._is_homogeneous: assert isinstance(degree_tensors, torch.Tensor) - # Single edge type: degree values fit directly in the target dtype. - result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype) - else: - assert isinstance(degree_tensors, dict) - dtype_max = torch.iinfo(dtype).max - for node_type, edge_types in self._node_type_to_edge_types.items(): - max_len = 0 - for et in edge_types: - if et not in degree_tensors: - raise ValueError( - f"Edge type {et} not found in degree tensors. " - f"Available: {list(degree_tensors.keys())}" - ) - max_len = max(max_len, len(degree_tensors[et])) - - # Each degree tensor is indexed by node ID (derived from CSR - # indptr), so index i in every edge type's tensor refers to - # the same node. Element-wise summation gives the total degree - # per node across all edge types. Shorter tensors are padded - # implicitly (only the first len(et_degrees) entries are added). - # Sum in int64: aggregate degrees are bounded by partition size - # and fit comfortably within int64 range in practice. - summed = torch.zeros(max_len, dtype=torch.int64) - for et in edge_types: - et_degrees = degree_tensors[et] - summed[: len(et_degrees)] += et_degrees.to(torch.int64) - result[node_type] = summed.clamp(max=dtype_max).to(dtype) + return {_PPR_HOMOGENEOUS_NODE_TYPE: degree_tensors} + assert isinstance(degree_tensors, dict) + result: dict[NodeType, torch.Tensor] = {} + for node_type in self._node_type_to_edge_types: + if node_type not in degree_tensors: + raise ValueError( + f"Node type '{node_type}' not found in degree tensors. " + f"Available: {list(degree_tensors.keys())}" + ) + result[node_type] = degree_tensors[node_type] return result def _get_total_degree(self, node_id: int, node_type: NodeType) -> int: diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index e9c8cc0b8..3f630183e 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -29,7 +29,7 @@ SamplingConfig, SamplingType, ) -from graphlearn_torch.typing import EdgeType +from graphlearn_torch.typing import EdgeType, NodeType from graphlearn_torch.utils import seed_everything from torch._C import _set_worker_signal_handlers from torch.utils.data.dataloader import DataLoader @@ -60,7 +60,7 @@ def _sampling_worker_loop( sampling_completed_worker_count, # mp.Value mp_barrier: Barrier, sampler_options: SamplerOptions, - degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + degree_tensors: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], ): dist_sampler = None try: @@ -224,14 +224,14 @@ def init(self): # where torch.distributed IS initialized — lets the tensor be shared # to workers via IPC. degree_tensors: Optional[ - Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + Union[torch.Tensor, dict[NodeType, torch.Tensor]] ] = None if isinstance(self._sampler_options, PPRSamplerOptions): assert isinstance(self.data, GiglDistDataset) degree_tensors = self.data.degree_tensor if isinstance(degree_tensors, dict): logger.info( - f"Pre-computed degree tensors for PPR sampling across {len(degree_tensors)} edge types." + f"Pre-computed degree tensors for PPR sampling across {len(degree_tensors)} node types." ) else: logger.info( diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index d87a83d52..b3165b2af 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -56,9 +56,9 @@ class PPRSamplerOptions: num_neighbors_per_hop: Maximum number of neighbors fetched per node per edge type during PPR traversal. Set large to approximate fetching all neighbors. - total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults - to ``torch.int32``, which supports total degrees up to ~2 billion. - Use a larger dtype if nodes have exceptionally high aggregate degrees. + total_degree_dtype: Retained for backwards compatibility; currently unused. + Degree tensors are stored as ``int16`` and no dtype conversion is + applied in the sampler. """ alpha: float = 0.5 diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index 0785a68a7..188dfa4de 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -27,38 +27,46 @@ import torch from graphlearn_torch.data import Graph -from torch_geometric.typing import EdgeType +from torch_geometric.typing import EdgeType, NodeType from gigl.common.logger import Logger from gigl.distributed.utils.device import get_device_from_process_group from gigl.distributed.utils.networking import get_internal_ip_from_all_ranks +from gigl.types.graph import is_label_edge_type logger = Logger() def compute_and_broadcast_degree_tensor( graph: Union[Graph, dict[EdgeType, Graph]], -) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: + edge_dir: str, +) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: """ Compute node degrees from a graph and aggregate across all machines. - Computes degrees from the CSR row pointers (indptr) and performs all-reduce - to aggregate across ranks. + For heterogeneous graphs, degrees are summed across all edge types sharing + the same anchor node type (source for ``edge_dir="out"``, destination for + ``edge_dir="in"``), then all-reduced across ranks. The result is one + ``int16`` tensor per anchor node type rather than one per edge type, + reducing both stored memory and the number of all-reduce calls. Over-counting correction (for processes sharing the same data) is handled automatically by detecting the distributed topology. Args: graph: A Graph (homogeneous) or dict[EdgeType, Graph] (heterogeneous). + edge_dir: ``"out"`` to group by source node type (out-degree); + ``"in"`` to group by destination node type (in-degree). Returns: - Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: The aggregated degree tensors. + Union[torch.Tensor, dict[NodeType, torch.Tensor]]: The aggregated degree tensors. - For homogeneous graphs: A tensor of shape [num_nodes]. - - For heterogeneous graphs: A dict mapping EdgeType to degree tensors. + - For heterogeneous graphs: A dict mapping NodeType to ``int16`` degree tensors, + where each entry is the total degree across all edge types for that anchor node type. Raises: RuntimeError: If torch.distributed is not initialized. - ValueError: If topology is unavailable. + ValueError: If topology is unavailable for a homogeneous graph. """ if not torch.distributed.is_initialized(): raise RuntimeError( @@ -71,20 +79,10 @@ def compute_and_broadcast_degree_tensor( if topo is None or topo.indptr is None: raise ValueError("Topology/indptr not available for graph.") local_degrees: Union[ - torch.Tensor, dict[EdgeType, torch.Tensor] + torch.Tensor, dict[NodeType, torch.Tensor] ] = _compute_degrees_from_indptr(topo.indptr) else: - local_dict: dict[EdgeType, torch.Tensor] = {} - for edge_type, edge_graph in graph.items(): - topo = edge_graph.topo - if topo is None or topo.indptr is None: - logger.warning( - f"Topology/indptr not available for edge type {edge_type}, using empty tensor." - ) - local_dict[edge_type] = torch.empty(0, dtype=torch.int16) - else: - local_dict[edge_type] = _compute_degrees_from_indptr(topo.indptr) - local_degrees = local_dict + local_degrees = _compute_hetero_degrees_by_node_type(graph, edge_dir) # All-reduce across ranks (over-counting correction handled internally) result = _all_reduce_degrees(local_degrees) @@ -98,19 +96,90 @@ def compute_and_broadcast_degree_tensor( else: logger.info("Graph contained 0 nodes when computing degrees") else: - for edge_type, degrees in result.items(): + for node_type, degrees in result.items(): if degrees.numel() > 0: logger.info( - f"{edge_type}: {degrees.size(0)} nodes, max={degrees.max().item()}, min={degrees.min().item()}" + f"{node_type}: {degrees.size(0)} nodes, max={degrees.max().item()}, min={degrees.min().item()}" ) else: logger.info( - f"Graph contained 0 nodes for edge type {edge_type} when computing degrees" + f"Graph contained 0 nodes for node type {node_type} when computing degrees" ) return result +def _compute_hetero_degrees_by_node_type( + graph: dict[EdgeType, Graph], + edge_dir: str, +) -> dict[NodeType, torch.Tensor]: + """Sum per-edge-type degrees into per-anchor-node-type totals. + + Label edge types (ABLP supervision edges) are excluded — they should not + contribute to degree counts used in PPR, matching the sampler's own exclusion. + + Uses a two-pass approach to avoid dynamic accumulator resizing: + + - Pass 1: compute ``int32`` degrees per edge type and record the maximum + node count per anchor type (to pre-size the accumulator). + - Pass 2: allocate one ``int32`` accumulator per anchor type and sum. + + Returns ``int32`` tensors; the caller is responsible for clamping to + ``int16`` after the all-reduce (see ``_all_reduce_degrees``). + + All anchor node types derived from non-label ``graph.keys()`` are present in + the output (with size-0 tensors for types whose ``indptr`` is unavailable), + ensuring symmetric participation in the subsequent all-reduce. + + Args: + graph: Heterogeneous graph mapping EdgeType to Graph. + edge_dir: ``"out"`` to anchor on source node type; ``"in"`` to anchor + on destination node type. + + Returns: + dict[NodeType, torch.Tensor]: ``int32`` total-degree tensor per anchor node type. + """ + # All anchor node types must appear in the output so the all_reduce is + # symmetric across ranks, even if this rank has no edges for a given type. + # Label edge types are excluded — they represent ABLP supervision edges, not + # structural neighbors, and should not contribute to degree counts used in PPR. + anchor_ntypes: set[NodeType] = { + etype[-1] if edge_dir == "in" else etype[0] + for etype in graph.keys() + if not is_label_edge_type(etype) + } + + # Pass 1: compute int32 degrees per valid edge type and track max node + # count per anchor type. int32 is sufficient: per-node degrees are always + # well below int32 max for any realistic graph. + et_degrees_i32: dict[EdgeType, torch.Tensor] = {} + max_sizes: dict[NodeType, int] = {ntype: 0 for ntype in anchor_ntypes} + for edge_type, edge_graph in graph.items(): + if is_label_edge_type(edge_type): + continue + anchor_ntype: NodeType = edge_type[-1] if edge_dir == "in" else edge_type[0] + topo = edge_graph.topo + if topo is None or topo.indptr is None: + logger.warning( + f"Topology/indptr not available for edge type {edge_type}, skipping." + ) + continue + degrees_i32 = (topo.indptr[1:] - topo.indptr[:-1]).to(torch.int32) + et_degrees_i32[edge_type] = degrees_i32 + max_sizes[anchor_ntype] = max(max_sizes[anchor_ntype], len(degrees_i32)) + + # Pass 2: allocate accumulators once (sized from pass 1) and sum. + accumulators: dict[NodeType, torch.Tensor] = { + ntype: torch.zeros(size, dtype=torch.int32) + for ntype, size in max_sizes.items() + } + for edge_type, degrees_i32 in et_degrees_i32.items(): + anchor_ntype = edge_type[-1] if edge_dir == "in" else edge_type[0] + accumulators[anchor_ntype][: len(degrees_i32)] += degrees_i32 + + return accumulators + + def _pad_to_size(tensor: torch.Tensor, target_size: int) -> torch.Tensor: """Pad tensor with zeros to reach target_size.""" if tensor.size(0) >= target_size: @@ -135,12 +204,12 @@ def _compute_degrees_from_indptr(indptr: torch.Tensor) -> torch.Tensor: def _all_reduce_degrees( - local_degrees: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], -) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: + local_degrees: Union[torch.Tensor, dict[NodeType, torch.Tensor]], +) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: """All-reduce degree tensors across ranks, handling both homogeneous and heterogeneous cases. - For heterogeneous graphs, iterates over the edge types in local_degrees. All partitions - are expected to have entries for all edge types (even if some have empty tensors). + For heterogeneous graphs, iterates over the node types in local_degrees. All partitions + are expected to have entries for all node types (even if some have empty tensors). Moves tensors to GPU for the all-reduce if using NCCL backend (which requires CUDA), otherwise keeps tensors on CPU (for Gloo backend). @@ -164,12 +233,12 @@ def _all_reduce_degrees( IP addresses, then divides by that count to correct the over-counting. Args: - local_degrees: Either a single tensor (homogeneous) or dict mapping EdgeType + local_degrees: Either a single tensor (homogeneous) or dict mapping NodeType to tensors (heterogeneous). For heterogeneous graphs, all partitions must - have entries for all edge types. + have entries for all node types. Returns: - Aggregated degree tensors in the same format as input. + Aggregated degree tensors in the same format as input, stored as ``int16``. Raises: RuntimeError: If torch.distributed is not initialized. @@ -195,22 +264,24 @@ def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor: torch.distributed.all_reduce(local_size, op=torch.distributed.ReduceOp.MAX) max_size = int(local_size.item()) - # Pad, convert to int64 (all_reduce doesn't support int16), move to device - padded = _pad_to_size(tensor, max_size).to(torch.int64).to(device) + # Pad and convert to int32 for the all-reduce (int16 is not supported). + # int32 is sufficient: the raw sum across W ranks is at most W * max_degree, + # which fits comfortably in int32 for any realistic world size and degree. + padded = _pad_to_size(tensor, max_size).to(torch.int32).to(device) torch.distributed.all_reduce(padded, op=torch.distributed.ReduceOp.SUM) - # Correct for over-counting, move back to CPU, and clamp to int16 - # TODO (mkolodner-sc): Potentially want to paramaterize this in the future if we want degrees higher than the int16 max. + # Correct for over-counting, move back to CPU, and clamp to int16. + # TODO (mkolodner-sc): Potentially want to parameterize this in the future if we want degrees higher than the int16 max. return _clamp_to_int16((padded // local_world_size).cpu()) # Homogeneous case if isinstance(local_degrees, torch.Tensor): return reduce_tensor(local_degrees) - # Heterogeneous case: all-reduce each edge type - # Sort edge types for deterministic ordering across ranks - result: dict[EdgeType, torch.Tensor] = {} - for edge_type in sorted(local_degrees.keys()): - result[edge_type] = reduce_tensor(local_degrees[edge_type]) + # Heterogeneous case: all-reduce each node type. + # Sort node types for deterministic ordering across ranks. + result: dict[NodeType, torch.Tensor] = {} + for node_type in sorted(local_degrees.keys()): + result[node_type] = reduce_tensor(local_degrees[node_type]) return result diff --git a/gigl/src/common/models/graph_transformer/graph_transformer.py b/gigl/src/common/models/graph_transformer/graph_transformer.py index 4308b495f..d8e1a7370 100644 --- a/gigl/src/common/models/graph_transformer/graph_transformer.py +++ b/gigl/src/common/models/graph_transformer/graph_transformer.py @@ -547,6 +547,9 @@ def __init__( self._feature_embedding_layer_dict = feature_embedding_layer_dict self._pe_integration_mode = pe_integration_mode self._num_heads = num_heads + # Explicit annotation so mypy can narrow past the None check below; + # register_buffer sets the value at runtime. + self._sequence_positional_encoding_table: Optional[torch.Tensor] if self._sequence_positional_encoding_type == "sinusoidal": self.register_buffer( "_sequence_positional_encoding_table", diff --git a/tests/unit/distributed/utils/degree_test.py b/tests/unit/distributed/utils/degree_test.py index 6836e4581..16f334509 100644 --- a/tests/unit/distributed/utils/degree_test.py +++ b/tests/unit/distributed/utils/degree_test.py @@ -60,7 +60,7 @@ def test_homogeneous_graph(self): dataset = create_homogeneous_dataset(edge_index=edge_index) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph) + result = compute_and_broadcast_degree_tensor(dataset.graph, edge_dir="out") assert isinstance(result, torch.Tensor) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) @@ -68,20 +68,28 @@ def test_homogeneous_graph(self): self.assert_tensor_equality(result, expected) def test_heterogeneous_graph(self): - """Test degree computation for a heterogeneous graph.""" + """Test degree computation for a heterogeneous graph. + + Result is keyed by anchor node type (source for edge_dir="out"), not edge type. + Degrees are summed across all edge types sharing the same anchor node type. + """ edge_indices = DEFAULT_HETEROGENEOUS_EDGE_INDICES dataset = create_heterogeneous_dataset(edge_indices=edge_indices) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph) + result = compute_and_broadcast_degree_tensor(dataset.graph, edge_dir="out") assert isinstance(result, dict) - self.assertEqual(set(result.keys()), set(edge_indices.keys())) + expected_node_types = {etype[0] for etype in edge_indices.keys()} + self.assertEqual(set(result.keys()), expected_node_types) + # For the default test data each source node type maps to exactly one edge type, + # so per-node-type degrees equal per-edge-type degrees. for edge_type, edge_index in edge_indices.items(): + src_node_type = edge_type[0] num_nodes = int(edge_index[0].max().item() + 1) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assert_tensor_equality(result[edge_type], expected) + self.assert_tensor_equality(result[src_node_type], expected) def test_heterogeneous_graph_with_missing_topology(self): """Test that edge types with missing topology get empty tensors. @@ -113,16 +121,18 @@ def test_heterogeneous_graph_with_missing_topology(self): # Manually set one graph's topology to None to test the edge case dataset.graph[edge_type_without_topo].topo = None - result = compute_and_broadcast_degree_tensor(dataset.graph) + result = compute_and_broadcast_degree_tensor(dataset.graph, edge_dir="out") assert isinstance(result, dict) - self.assertEqual(set(result.keys()), set(edge_types)) + # Keys are source node types (anchor for edge_dir="out"), not edge types. + expected_node_types = {etype[0] for etype in edge_types} + self.assertEqual(set(result.keys()), expected_node_types) - # Edge type with topology should have computed degrees - self.assert_tensor_equality(result[edge_type_with_topo], expected_degrees) + # Edge type with topology: anchor node type should have computed degrees. + self.assert_tensor_equality(result[edge_type_with_topo[0]], expected_degrees) - # Edge type without topology should have empty tensor - self.assertEqual(result[edge_type_without_topo].numel(), 0) + # Edge type without topology: anchor node type should have empty tensor. + self.assertEqual(result[edge_type_without_topo[0]].numel(), 0) def _run_local_world_size_correction_homogeneous( @@ -142,7 +152,7 @@ def _run_local_world_size_correction_homogeneous( try: dataset = create_homogeneous_dataset(edge_index=edge_index) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph) + result = compute_and_broadcast_degree_tensor(dataset.graph, edge_dir="out") assert isinstance(result, torch.Tensor) assert_tensor_equality(result, expected_degrees) @@ -157,7 +167,10 @@ def _run_local_world_size_correction_heterogeneous( edge_indices: dict, expected_degrees: dict, ) -> None: - """Worker function for multi-process local_world_size correction test (heterogeneous).""" + """Worker function for multi-process local_world_size correction test (heterogeneous). + + expected_degrees must be keyed by node type (anchor for edge_dir="out"). + """ dist.init_process_group( backend="gloo", init_method=init_method, @@ -167,12 +180,12 @@ def _run_local_world_size_correction_heterogeneous( try: dataset = create_heterogeneous_dataset(edge_indices=edge_indices) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph) + result = compute_and_broadcast_degree_tensor(dataset.graph, edge_dir="out") assert isinstance(result, dict) assert set(result.keys()) == set(expected_degrees.keys()) - for edge_type, expected in expected_degrees.items(): - assert_tensor_equality(result[edge_type], expected) + for node_type, expected in expected_degrees.items(): + assert_tensor_equality(result[node_type], expected) finally: dist.destroy_process_group() @@ -204,13 +217,16 @@ def test_local_world_size_correction_heterogeneous(self): """Test over-counting correction for heterogeneous graphs with 2 processes.""" edge_indices = DEFAULT_HETEROGENEOUS_EDGE_INDICES + # Build expected degrees keyed by source node type (anchor for edge_dir="out"). + # For the default test data each source node type maps to exactly one edge type. expected_degrees = {} for edge_type, edge_index in edge_indices.items(): + src_node_type = edge_type[0] num_nodes = int(edge_index[0].max().item() + 1) raw_degrees = _compute_expected_degrees_from_edge_index( edge_index, num_nodes ) - expected_degrees[edge_type] = raw_degrees + expected_degrees[src_node_type] = raw_degrees init_method = get_process_group_init_method() mp.spawn( @@ -263,12 +279,16 @@ def test_degree_tensor_heterogeneous(self): result = dataset.degree_tensor assert isinstance(result, dict) - self.assertEqual(set(result.keys()), set(edge_indices.keys())) + # degree_tensor is keyed by node type (anchor for dataset's edge_dir="out"), not edge type. + expected_node_types = {etype[0] for etype in edge_indices.keys()} + self.assertEqual(set(result.keys()), expected_node_types) + # For the default test data each source node type maps to exactly one edge type. for edge_type, edge_index in edge_indices.items(): + src_node_type = edge_type[0] num_nodes = int(edge_index[0].max().item() + 1) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assert_tensor_equality(result[edge_type], expected) + self.assert_tensor_equality(result[src_node_type], expected) class TestHelperFunctions(TestCase): From 91d99d3a6ce6e4c4e0cfadf4e796356cec3ed101 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 31 Mar 2026 16:59:07 +0000 Subject: [PATCH 12/16] Update --- gigl/distributed/base_dist_loader.py | 7 +++++++ gigl/distributed/dist_sampling_producer.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index ed0783ef6..4823813f1 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -566,6 +566,13 @@ def _init_colocated_connections( self._channel = producer.output_channel self._mp_producer = producer + # Run distributed collectives (e.g. degree tensor all_reduce for PPR) + # BEFORE the staggered sleep. Collectives require all ranks to + # participate simultaneously; sleeping first would cause all ranks to + # exit the collective together and then proceed to worker spawning at + # the same time, nullifying the stagger. + producer.pre_init_collectives() + # Staggered init — sleep proportional to local_rank to avoid # concurrent initialization spikes that cause CPU memory OOM. logger.info( diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 3f630183e..1294d2d4d 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -216,6 +216,24 @@ def __init__( super().__init__(data, sampler_input, sampling_config, worker_options, channel) self._sampler_options = sampler_options + def pre_init_collectives(self) -> None: + """Run distributed collectives that must complete before the staggered init sleep. + + ``init()`` calls ``self.data.degree_tensor``, which triggers a + ``torch.distributed.all_reduce``. Collectives require all ranks to + participate simultaneously. If that collective runs inside ``init()`` + (after the staggered sleep), every rank exits the collective at the same + instant and proceeds to worker spawning together, negating the stagger. + + Call this before the staggered sleep so all ranks complete the collective + together, then the sleep staggering applies only to the expensive worker- + spawn step. The result is cached on the dataset; ``init()`` reuses it + without re-running the collective. + """ + if isinstance(self._sampler_options, PPRSamplerOptions): + assert isinstance(self.data, GiglDistDataset) + _ = self.data.degree_tensor + def init(self): r"""Create the subprocess pool. Init samplers and rpc server.""" # Extract degree tensors before spawning workers. Worker subprocesses From 3d41dc31bf2c89ae0c125b68abc0d81989bb3f03 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 31 Mar 2026 17:44:33 +0000 Subject: [PATCH 13/16] Update --- gigl/distributed/base_dist_loader.py | 7 - gigl/distributed/dist_ablp_neighborloader.py | 17 ++ gigl/distributed/dist_dataset.py | 24 ++- gigl/distributed/dist_ppr_sampler.py | 74 ++++++--- gigl/distributed/dist_sampling_producer.py | 49 +----- .../distributed/distributed_neighborloader.py | 17 ++ gigl/distributed/sampler_options.py | 6 +- gigl/distributed/utils/degree.py | 153 ++++++------------ tests/unit/distributed/utils/degree_test.py | 58 +++---- 9 files changed, 166 insertions(+), 239 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 4823813f1..ed0783ef6 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -566,13 +566,6 @@ def _init_colocated_connections( self._channel = producer.output_channel self._mp_producer = producer - # Run distributed collectives (e.g. degree tensor all_reduce for PPR) - # BEFORE the staggered sleep. Collectives require all ranks to - # participate simultaneously; sleeping first would cause all ranks to - # exit the collective together and then proceed to worker spawning at - # the same time, nullifying the stagger. - producer.pre_init_collectives() - # Staggered init — sleep proportional to local_rank to avoid # concurrent initialization spikes that cause CPU memory OOM. logger.info( diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index f96fd3bc0..8c86c1c3d 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -359,6 +359,22 @@ def __init__( assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) channel = BaseDistLoader.create_colocated_channel(worker_options) + # Compute degree tensors here, before the producer is constructed and + # before the staggered sleep in _init_colocated_connections. The + # all_reduce requires all ranks to participate simultaneously; deferring + # it to after the sleep would negate the stagger on worker spawning. + if isinstance(sampler_options, PPRSamplerOptions): + degree_tensors = dataset.degree_tensor + if isinstance(degree_tensors, dict): + logger.info( + f"Pre-computed degree tensors for PPR sampling across {len(degree_tensors)} edge types." + ) + else: + logger.info( + f"Pre-computed degree tensor for PPR sampling with {degree_tensors.size(0)} nodes." + ) + else: + degree_tensors = None producer: Union[ DistSamplingProducer, Callable[..., int] ] = DistSamplingProducer( @@ -368,6 +384,7 @@ def __init__( worker_options=worker_options, channel=channel, sampler_options=sampler_options, + degree_tensors=degree_tensors, ) else: producer = DistServer.create_sampling_producer diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index 20436f4bc..f495a6d96 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -78,7 +78,7 @@ def __init__( Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ] = None, degree_tensor: Optional[ - Union[torch.Tensor, dict[NodeType, torch.Tensor]] + Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = None, ) -> None: """ @@ -104,7 +104,7 @@ def __init__( Note this will be None in the homogeneous case if the data has no node features, or will only contain node types with node features in the heterogeneous case. edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Dimension of edge features and its data type, will be a dict if heterogeneous. Note this will be None in the homogeneous case if the data has no edge features, or will only contain edge types with edge features in the heterogeneous case. - degree_tensor: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. + degree_tensor: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. """ self._rank: int = rank self._world_size: int = world_size @@ -141,7 +141,7 @@ def __init__( self._edge_feature_info = edge_feature_info self._degree_tensor: Optional[ - Union[torch.Tensor, dict[NodeType, torch.Tensor]] + Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = degree_tensor # TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear @@ -300,7 +300,7 @@ def edge_feature_info( @property def degree_tensor( self, - ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: + ) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: """ Lazily compute and return the degree tensor for the graph. @@ -308,19 +308,15 @@ def degree_tensor( all-reduce to aggregate across all machines. Requires torch.distributed to be initialized. - For heterogeneous graphs, degrees are summed across all edge types sharing - the same anchor node type (determined by ``self._edge_dir``), yielding one - ``int16`` tensor per node type rather than one per edge type. - Over-counting correction (for processes sharing the same data on the same machine) is handled automatically by detecting the distributed topology. The result is cached for subsequent accesses. Returns: - Union[torch.Tensor, dict[NodeType, torch.Tensor]]: The aggregated degree tensor. + Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: The aggregated degree tensor. - For homogeneous graphs: A tensor of shape [num_nodes]. - - For heterogeneous graphs: A dict mapping NodeType to ``int16`` degree tensors. + - For heterogeneous graphs: A dict mapping EdgeType to degree tensors. Raises: RuntimeError: If torch.distributed is not initialized. @@ -330,9 +326,7 @@ def degree_tensor( if self.graph is None: raise ValueError("Dataset graph is None. Cannot compute degrees.") - self._degree_tensor = compute_and_broadcast_degree_tensor( - self.graph, edge_dir=self._edge_dir - ) + self._degree_tensor = compute_and_broadcast_degree_tensor(self.graph) return self._degree_tensor @property @@ -863,7 +857,7 @@ def share_ipc( Optional[Union[int, dict[NodeType, int]]], Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]], Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]], - Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], + Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], ]: """ Serializes the member variables of the DistDatasetClass @@ -885,7 +879,7 @@ def share_ipc( Optional[Union[int, dict[NodeType, int]]]: Number of test nodes on the current machine. Will be a dict if heterogeneous. Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]: Node feature dim and its data type, will be a dict if heterogeneous Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous - Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous + Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous """ # TODO (mkolodner-sc): Investigate moving share_memory calls to the build() function diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 34cb1661b..17673a72d 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -103,7 +103,7 @@ def __init__( max_ppr_nodes: int = 50, num_neighbors_per_hop: int = 100_000, total_degree_dtype: torch.dtype = torch.int32, - degree_tensors: Union[torch.Tensor, dict[NodeType, torch.Tensor]], + degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], **kwargs, ): super().__init__(*args, **kwargs) @@ -144,45 +144,69 @@ def __init__( ] self._is_homogeneous = True - # The dataset pre-aggregates degrees per node type as int16. - # _build_total_degree_tensors reindexes by node type with no further casting. + # Precompute total degree per node type: the sum of degrees across all + # edge types traversable from that node type. This is a graph-level + # property used on every PPR iteration, so computing it once at init + # avoids per-node summation and cache lookups in the hot loop. + # TODO (mkolodner-sc): This trades memory for throughput — we + # materialize a tensor per node type to avoid recomputing total degree + # on every neighbor during sampling. Computing it here (rather than in + # the dataset) also keeps the door open for edge-specific degree + # strategies. If memory becomes a bottleneck, revisit this. self._node_type_to_total_degree: dict[ NodeType, torch.Tensor - ] = self._build_total_degree_tensors(degree_tensors) + ] = self._build_total_degree_tensors(degree_tensors, total_degree_dtype) def _build_total_degree_tensors( self, - degree_tensors: Union[torch.Tensor, dict[NodeType, torch.Tensor]], + degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], + dtype: torch.dtype, ) -> dict[NodeType, torch.Tensor]: - """Reindex pre-aggregated per-node-type degree tensors for use in the PPR loop. + """Build total-degree tensors by summing per-edge-type degrees for each node type. - The dataset provides degrees already summed across edge types per node type - (as ``int16``). This method maps those tensors into the internal node-type - keying used by the sampler, with no further summation or dtype casting. + For homogeneous graphs, the total degree is just the single degree tensor. + For heterogeneous graphs, it sums degree tensors across all edge types + traversable from each node type, padding shorter tensors with zeros. Args: - degree_tensors: Per-node-type degree tensors from the dataset, already - aggregated across edge types. + degree_tensors: Per-edge-type degree tensors from the dataset. + dtype: Dtype for the output tensors. Returns: - Dict mapping node type to a 1-D degree tensor. - - Raises: - ValueError: If a required node type is absent from ``degree_tensors``. + Dict mapping node type to a 1-D tensor of total degrees. """ + result: dict[NodeType, torch.Tensor] = {} + if self._is_homogeneous: assert isinstance(degree_tensors, torch.Tensor) - return {_PPR_HOMOGENEOUS_NODE_TYPE: degree_tensors} + # Single edge type: degree values fit directly in the target dtype. + result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype) + else: + assert isinstance(degree_tensors, dict) + dtype_max = torch.iinfo(dtype).max + for node_type, edge_types in self._node_type_to_edge_types.items(): + max_len = 0 + for et in edge_types: + if et not in degree_tensors: + raise ValueError( + f"Edge type {et} not found in degree tensors. " + f"Available: {list(degree_tensors.keys())}" + ) + max_len = max(max_len, len(degree_tensors[et])) + + # Each degree tensor is indexed by node ID (derived from CSR + # indptr), so index i in every edge type's tensor refers to + # the same node. Element-wise summation gives the total degree + # per node across all edge types. Shorter tensors are padded + # implicitly (only the first len(et_degrees) entries are added). + # Sum in int64: aggregate degrees are bounded by partition size + # and fit comfortably within int64 range in practice. + summed = torch.zeros(max_len, dtype=torch.int64) + for et in edge_types: + et_degrees = degree_tensors[et] + summed[: len(et_degrees)] += et_degrees.to(torch.int64) + result[node_type] = summed.clamp(max=dtype_max).to(dtype) - assert isinstance(degree_tensors, dict) - result: dict[NodeType, torch.Tensor] = {} - for node_type in self._node_type_to_edge_types: - if node_type not in degree_tensors: - raise ValueError( - f"Node type '{node_type}' not found in degree tensors. " - f"Available: {list(degree_tensors.keys())}" - ) - result[node_type] = degree_tensors[node_type] return result def _get_total_degree(self, node_id: int, node_type: NodeType) -> int: diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 1294d2d4d..78aaeb388 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -29,14 +29,13 @@ SamplingConfig, SamplingType, ) -from graphlearn_torch.typing import EdgeType, NodeType +from graphlearn_torch.typing import EdgeType from graphlearn_torch.utils import seed_everything from torch._C import _set_worker_signal_handlers from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset from gigl.common.logger import Logger -from gigl.distributed.dist_dataset import DistDataset as GiglDistDataset from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler from gigl.distributed.sampler_options import ( @@ -60,7 +59,7 @@ def _sampling_worker_loop( sampling_completed_worker_count, # mp.Value mp_barrier: Barrier, sampler_options: SamplerOptions, - degree_tensors: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]], + degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], ): dist_sampler = None try: @@ -212,50 +211,16 @@ def __init__( worker_options: MpDistSamplingWorkerOptions, channel: ChannelBase, sampler_options: SamplerOptions, + degree_tensors: Optional[ + Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + ] = None, ): super().__init__(data, sampler_input, sampling_config, worker_options, channel) self._sampler_options = sampler_options - - def pre_init_collectives(self) -> None: - """Run distributed collectives that must complete before the staggered init sleep. - - ``init()`` calls ``self.data.degree_tensor``, which triggers a - ``torch.distributed.all_reduce``. Collectives require all ranks to - participate simultaneously. If that collective runs inside ``init()`` - (after the staggered sleep), every rank exits the collective at the same - instant and proceeds to worker spawning together, negating the stagger. - - Call this before the staggered sleep so all ranks complete the collective - together, then the sleep staggering applies only to the expensive worker- - spawn step. The result is cached on the dataset; ``init()`` reuses it - without re-running the collective. - """ - if isinstance(self._sampler_options, PPRSamplerOptions): - assert isinstance(self.data, GiglDistDataset) - _ = self.data.degree_tensor + self._degree_tensors = degree_tensors def init(self): r"""Create the subprocess pool. Init samplers and rpc server.""" - # Extract degree tensors before spawning workers. Worker subprocesses - # only initialize RPC (not torch.distributed), so the lazy degree - # computation on GiglDistDataset would fail there. Computing here — - # where torch.distributed IS initialized — lets the tensor be shared - # to workers via IPC. - degree_tensors: Optional[ - Union[torch.Tensor, dict[NodeType, torch.Tensor]] - ] = None - if isinstance(self._sampler_options, PPRSamplerOptions): - assert isinstance(self.data, GiglDistDataset) - degree_tensors = self.data.degree_tensor - if isinstance(degree_tensors, dict): - logger.info( - f"Pre-computed degree tensors for PPR sampling across {len(degree_tensors)} node types." - ) - else: - logger.info( - f"Pre-computed degree tensor for PPR sampling with {degree_tensors.size(0)} nodes." - ) - if self.sampling_config.seed is not None: seed_everything(self.sampling_config.seed) if not self.sampling_config.shuffle: @@ -284,7 +249,7 @@ def init(self): self.sampling_completed_worker_count, barrier, self._sampler_options, - degree_tensors, + self._degree_tensors, ), ) w.daemon = True diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index e47075343..2d3b82e9a 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -268,6 +268,22 @@ def __init__( assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) channel = BaseDistLoader.create_colocated_channel(worker_options) + # Compute degree tensors here, before the producer is constructed and + # before the staggered sleep in _init_colocated_connections. The + # all_reduce requires all ranks to participate simultaneously; deferring + # it to after the sleep would negate the stagger on worker spawning. + if isinstance(sampler_options, PPRSamplerOptions): + degree_tensors = dataset.degree_tensor + if isinstance(degree_tensors, dict): + logger.info( + f"Pre-computed degree tensors for PPR sampling across {len(degree_tensors)} edge types." + ) + else: + logger.info( + f"Pre-computed degree tensor for PPR sampling with {degree_tensors.size(0)} nodes." + ) + else: + degree_tensors = None producer: Union[ DistSamplingProducer, Callable[..., int] ] = DistSamplingProducer( @@ -277,6 +293,7 @@ def __init__( worker_options=worker_options, channel=channel, sampler_options=sampler_options, + degree_tensors=degree_tensors, ) else: producer = GiglDistServer.create_sampling_producer diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index b3165b2af..d87a83d52 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -56,9 +56,9 @@ class PPRSamplerOptions: num_neighbors_per_hop: Maximum number of neighbors fetched per node per edge type during PPR traversal. Set large to approximate fetching all neighbors. - total_degree_dtype: Retained for backwards compatibility; currently unused. - Degree tensors are stored as ``int16`` and no dtype conversion is - applied in the sampler. + total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults + to ``torch.int32``, which supports total degrees up to ~2 billion. + Use a larger dtype if nodes have exceptionally high aggregate degrees. """ alpha: float = 0.5 diff --git a/gigl/distributed/utils/degree.py b/gigl/distributed/utils/degree.py index 188dfa4de..e5cb29cb9 100644 --- a/gigl/distributed/utils/degree.py +++ b/gigl/distributed/utils/degree.py @@ -27,7 +27,7 @@ import torch from graphlearn_torch.data import Graph -from torch_geometric.typing import EdgeType, NodeType +from torch_geometric.typing import EdgeType from gigl.common.logger import Logger from gigl.distributed.utils.device import get_device_from_process_group @@ -39,34 +39,30 @@ def compute_and_broadcast_degree_tensor( graph: Union[Graph, dict[EdgeType, Graph]], - edge_dir: str, -) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: +) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: """ Compute node degrees from a graph and aggregate across all machines. - For heterogeneous graphs, degrees are summed across all edge types sharing - the same anchor node type (source for ``edge_dir="out"``, destination for - ``edge_dir="in"``), then all-reduced across ranks. The result is one - ``int16`` tensor per anchor node type rather than one per edge type, - reducing both stored memory and the number of all-reduce calls. + Computes degrees from the CSR row pointers (indptr) and performs all-reduce + to aggregate across ranks. Over-counting correction (for processes sharing the same data) is handled automatically by detecting the distributed topology. Args: graph: A Graph (homogeneous) or dict[EdgeType, Graph] (heterogeneous). - edge_dir: ``"out"`` to group by source node type (out-degree); - ``"in"`` to group by destination node type (in-degree). + For heterogeneous graphs, label edge types are automatically excluded + from the computation — they are supervision edges and should not + contribute to node degree for graph traversal algorithms like PPR. Returns: - Union[torch.Tensor, dict[NodeType, torch.Tensor]]: The aggregated degree tensors. + Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: The aggregated degree tensors. - For homogeneous graphs: A tensor of shape [num_nodes]. - - For heterogeneous graphs: A dict mapping NodeType to ``int16`` degree tensors, - where each entry is the total degree across all edge types for that anchor node type. + - For heterogeneous graphs: A dict mapping non-label EdgeType to degree tensors. Raises: RuntimeError: If torch.distributed is not initialized. - ValueError: If topology is unavailable for a homogeneous graph. + ValueError: If topology is unavailable. """ if not torch.distributed.is_initialized(): raise RuntimeError( @@ -79,10 +75,24 @@ def compute_and_broadcast_degree_tensor( if topo is None or topo.indptr is None: raise ValueError("Topology/indptr not available for graph.") local_degrees: Union[ - torch.Tensor, dict[NodeType, torch.Tensor] + torch.Tensor, dict[EdgeType, torch.Tensor] ] = _compute_degrees_from_indptr(topo.indptr) else: - local_degrees = _compute_hetero_degrees_by_node_type(graph, edge_dir) + local_dict: dict[EdgeType, torch.Tensor] = {} + for edge_type, edge_graph in graph.items(): + # Label edge types are supervision edges and should not contribute + # to node degree for graph traversal algorithms like PPR. + if is_label_edge_type(edge_type): + continue + topo = edge_graph.topo + if topo is None or topo.indptr is None: + logger.warning( + f"Topology/indptr not available for edge type {edge_type}, using empty tensor." + ) + local_dict[edge_type] = torch.empty(0, dtype=torch.int16) + else: + local_dict[edge_type] = _compute_degrees_from_indptr(topo.indptr) + local_degrees = local_dict # All-reduce across ranks (over-counting correction handled internally) result = _all_reduce_degrees(local_degrees) @@ -96,90 +106,19 @@ def compute_and_broadcast_degree_tensor( else: logger.info("Graph contained 0 nodes when computing degrees") else: - for node_type, degrees in result.items(): + for edge_type, degrees in result.items(): if degrees.numel() > 0: logger.info( - f"{node_type}: {degrees.size(0)} nodes, max={degrees.max().item()}, min={degrees.min().item()}" + f"{edge_type}: {degrees.size(0)} nodes, max={degrees.max().item()}, min={degrees.min().item()}" ) else: logger.info( - f"Graph contained 0 nodes for node type {node_type} when computing degrees" + f"Graph contained 0 nodes for edge type {edge_type} when computing degrees" ) return result -def _compute_hetero_degrees_by_node_type( - graph: dict[EdgeType, Graph], - edge_dir: str, -) -> dict[NodeType, torch.Tensor]: - """Sum per-edge-type degrees into per-anchor-node-type totals. - - Label edge types (ABLP supervision edges) are excluded — they should not - contribute to degree counts used in PPR, matching the sampler's own exclusion. - - Uses a two-pass approach to avoid dynamic accumulator resizing: - - - Pass 1: compute ``int32`` degrees per edge type and record the maximum - node count per anchor type (to pre-size the accumulator). - - Pass 2: allocate one ``int32`` accumulator per anchor type and sum. - - Returns ``int32`` tensors; the caller is responsible for clamping to - ``int16`` after the all-reduce (see ``_all_reduce_degrees``). - - All anchor node types derived from non-label ``graph.keys()`` are present in - the output (with size-0 tensors for types whose ``indptr`` is unavailable), - ensuring symmetric participation in the subsequent all-reduce. - - Args: - graph: Heterogeneous graph mapping EdgeType to Graph. - edge_dir: ``"out"`` to anchor on source node type; ``"in"`` to anchor - on destination node type. - - Returns: - dict[NodeType, torch.Tensor]: ``int32`` total-degree tensor per anchor node type. - """ - # All anchor node types must appear in the output so the all_reduce is - # symmetric across ranks, even if this rank has no edges for a given type. - # Label edge types are excluded — they represent ABLP supervision edges, not - # structural neighbors, and should not contribute to degree counts used in PPR. - anchor_ntypes: set[NodeType] = { - etype[-1] if edge_dir == "in" else etype[0] - for etype in graph.keys() - if not is_label_edge_type(etype) - } - - # Pass 1: compute int32 degrees per valid edge type and track max node - # count per anchor type. int32 is sufficient: per-node degrees are always - # well below int32 max for any realistic graph. - et_degrees_i32: dict[EdgeType, torch.Tensor] = {} - max_sizes: dict[NodeType, int] = {ntype: 0 for ntype in anchor_ntypes} - for edge_type, edge_graph in graph.items(): - if is_label_edge_type(edge_type): - continue - anchor_ntype: NodeType = edge_type[-1] if edge_dir == "in" else edge_type[0] - topo = edge_graph.topo - if topo is None or topo.indptr is None: - logger.warning( - f"Topology/indptr not available for edge type {edge_type}, skipping." - ) - continue - degrees_i32 = (topo.indptr[1:] - topo.indptr[:-1]).to(torch.int32) - et_degrees_i32[edge_type] = degrees_i32 - max_sizes[anchor_ntype] = max(max_sizes[anchor_ntype], len(degrees_i32)) - - # Pass 2: allocate accumulators once (sized from pass 1) and sum. - accumulators: dict[NodeType, torch.Tensor] = { - ntype: torch.zeros(size, dtype=torch.int32) - for ntype, size in max_sizes.items() - } - for edge_type, degrees_i32 in et_degrees_i32.items(): - anchor_ntype = edge_type[-1] if edge_dir == "in" else edge_type[0] - accumulators[anchor_ntype][: len(degrees_i32)] += degrees_i32 - - return accumulators - - def _pad_to_size(tensor: torch.Tensor, target_size: int) -> torch.Tensor: """Pad tensor with zeros to reach target_size.""" if tensor.size(0) >= target_size: @@ -204,12 +143,12 @@ def _compute_degrees_from_indptr(indptr: torch.Tensor) -> torch.Tensor: def _all_reduce_degrees( - local_degrees: Union[torch.Tensor, dict[NodeType, torch.Tensor]], -) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: + local_degrees: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], +) -> Union[torch.Tensor, dict[EdgeType, torch.Tensor]]: """All-reduce degree tensors across ranks, handling both homogeneous and heterogeneous cases. - For heterogeneous graphs, iterates over the node types in local_degrees. All partitions - are expected to have entries for all node types (even if some have empty tensors). + For heterogeneous graphs, iterates over the edge types in local_degrees. All partitions + are expected to have entries for all edge types (even if some have empty tensors). Moves tensors to GPU for the all-reduce if using NCCL backend (which requires CUDA), otherwise keeps tensors on CPU (for Gloo backend). @@ -233,12 +172,12 @@ def _all_reduce_degrees( IP addresses, then divides by that count to correct the over-counting. Args: - local_degrees: Either a single tensor (homogeneous) or dict mapping NodeType + local_degrees: Either a single tensor (homogeneous) or dict mapping EdgeType to tensors (heterogeneous). For heterogeneous graphs, all partitions must - have entries for all node types. + have entries for all edge types. Returns: - Aggregated degree tensors in the same format as input, stored as ``int16``. + Aggregated degree tensors in the same format as input. Raises: RuntimeError: If torch.distributed is not initialized. @@ -264,24 +203,22 @@ def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor: torch.distributed.all_reduce(local_size, op=torch.distributed.ReduceOp.MAX) max_size = int(local_size.item()) - # Pad and convert to int32 for the all-reduce (int16 is not supported). - # int32 is sufficient: the raw sum across W ranks is at most W * max_degree, - # which fits comfortably in int32 for any realistic world size and degree. - padded = _pad_to_size(tensor, max_size).to(torch.int32).to(device) + # Pad, convert to int64 (all_reduce doesn't support int16), move to device + padded = _pad_to_size(tensor, max_size).to(torch.int64).to(device) torch.distributed.all_reduce(padded, op=torch.distributed.ReduceOp.SUM) - # Correct for over-counting, move back to CPU, and clamp to int16. - # TODO (mkolodner-sc): Potentially want to parameterize this in the future if we want degrees higher than the int16 max. + # Correct for over-counting, move back to CPU, and clamp to int16 + # TODO (mkolodner-sc): Potentially want to paramaterize this in the future if we want degrees higher than the int16 max. return _clamp_to_int16((padded // local_world_size).cpu()) # Homogeneous case if isinstance(local_degrees, torch.Tensor): return reduce_tensor(local_degrees) - # Heterogeneous case: all-reduce each node type. - # Sort node types for deterministic ordering across ranks. - result: dict[NodeType, torch.Tensor] = {} - for node_type in sorted(local_degrees.keys()): - result[node_type] = reduce_tensor(local_degrees[node_type]) + # Heterogeneous case: all-reduce each edge type + # Sort edge types for deterministic ordering across ranks + result: dict[EdgeType, torch.Tensor] = {} + for edge_type in sorted(local_degrees.keys()): + result[edge_type] = reduce_tensor(local_degrees[edge_type]) return result diff --git a/tests/unit/distributed/utils/degree_test.py b/tests/unit/distributed/utils/degree_test.py index 16f334509..6836e4581 100644 --- a/tests/unit/distributed/utils/degree_test.py +++ b/tests/unit/distributed/utils/degree_test.py @@ -60,7 +60,7 @@ def test_homogeneous_graph(self): dataset = create_homogeneous_dataset(edge_index=edge_index) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph, edge_dir="out") + result = compute_and_broadcast_degree_tensor(dataset.graph) assert isinstance(result, torch.Tensor) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) @@ -68,28 +68,20 @@ def test_homogeneous_graph(self): self.assert_tensor_equality(result, expected) def test_heterogeneous_graph(self): - """Test degree computation for a heterogeneous graph. - - Result is keyed by anchor node type (source for edge_dir="out"), not edge type. - Degrees are summed across all edge types sharing the same anchor node type. - """ + """Test degree computation for a heterogeneous graph.""" edge_indices = DEFAULT_HETEROGENEOUS_EDGE_INDICES dataset = create_heterogeneous_dataset(edge_indices=edge_indices) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph, edge_dir="out") + result = compute_and_broadcast_degree_tensor(dataset.graph) assert isinstance(result, dict) - expected_node_types = {etype[0] for etype in edge_indices.keys()} - self.assertEqual(set(result.keys()), expected_node_types) + self.assertEqual(set(result.keys()), set(edge_indices.keys())) - # For the default test data each source node type maps to exactly one edge type, - # so per-node-type degrees equal per-edge-type degrees. for edge_type, edge_index in edge_indices.items(): - src_node_type = edge_type[0] num_nodes = int(edge_index[0].max().item() + 1) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assert_tensor_equality(result[src_node_type], expected) + self.assert_tensor_equality(result[edge_type], expected) def test_heterogeneous_graph_with_missing_topology(self): """Test that edge types with missing topology get empty tensors. @@ -121,18 +113,16 @@ def test_heterogeneous_graph_with_missing_topology(self): # Manually set one graph's topology to None to test the edge case dataset.graph[edge_type_without_topo].topo = None - result = compute_and_broadcast_degree_tensor(dataset.graph, edge_dir="out") + result = compute_and_broadcast_degree_tensor(dataset.graph) assert isinstance(result, dict) - # Keys are source node types (anchor for edge_dir="out"), not edge types. - expected_node_types = {etype[0] for etype in edge_types} - self.assertEqual(set(result.keys()), expected_node_types) + self.assertEqual(set(result.keys()), set(edge_types)) - # Edge type with topology: anchor node type should have computed degrees. - self.assert_tensor_equality(result[edge_type_with_topo[0]], expected_degrees) + # Edge type with topology should have computed degrees + self.assert_tensor_equality(result[edge_type_with_topo], expected_degrees) - # Edge type without topology: anchor node type should have empty tensor. - self.assertEqual(result[edge_type_without_topo[0]].numel(), 0) + # Edge type without topology should have empty tensor + self.assertEqual(result[edge_type_without_topo].numel(), 0) def _run_local_world_size_correction_homogeneous( @@ -152,7 +142,7 @@ def _run_local_world_size_correction_homogeneous( try: dataset = create_homogeneous_dataset(edge_index=edge_index) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph, edge_dir="out") + result = compute_and_broadcast_degree_tensor(dataset.graph) assert isinstance(result, torch.Tensor) assert_tensor_equality(result, expected_degrees) @@ -167,10 +157,7 @@ def _run_local_world_size_correction_heterogeneous( edge_indices: dict, expected_degrees: dict, ) -> None: - """Worker function for multi-process local_world_size correction test (heterogeneous). - - expected_degrees must be keyed by node type (anchor for edge_dir="out"). - """ + """Worker function for multi-process local_world_size correction test (heterogeneous).""" dist.init_process_group( backend="gloo", init_method=init_method, @@ -180,12 +167,12 @@ def _run_local_world_size_correction_heterogeneous( try: dataset = create_heterogeneous_dataset(edge_indices=edge_indices) assert dataset.graph is not None - result = compute_and_broadcast_degree_tensor(dataset.graph, edge_dir="out") + result = compute_and_broadcast_degree_tensor(dataset.graph) assert isinstance(result, dict) assert set(result.keys()) == set(expected_degrees.keys()) - for node_type, expected in expected_degrees.items(): - assert_tensor_equality(result[node_type], expected) + for edge_type, expected in expected_degrees.items(): + assert_tensor_equality(result[edge_type], expected) finally: dist.destroy_process_group() @@ -217,16 +204,13 @@ def test_local_world_size_correction_heterogeneous(self): """Test over-counting correction for heterogeneous graphs with 2 processes.""" edge_indices = DEFAULT_HETEROGENEOUS_EDGE_INDICES - # Build expected degrees keyed by source node type (anchor for edge_dir="out"). - # For the default test data each source node type maps to exactly one edge type. expected_degrees = {} for edge_type, edge_index in edge_indices.items(): - src_node_type = edge_type[0] num_nodes = int(edge_index[0].max().item() + 1) raw_degrees = _compute_expected_degrees_from_edge_index( edge_index, num_nodes ) - expected_degrees[src_node_type] = raw_degrees + expected_degrees[edge_type] = raw_degrees init_method = get_process_group_init_method() mp.spawn( @@ -279,16 +263,12 @@ def test_degree_tensor_heterogeneous(self): result = dataset.degree_tensor assert isinstance(result, dict) - # degree_tensor is keyed by node type (anchor for dataset's edge_dir="out"), not edge type. - expected_node_types = {etype[0] for etype in edge_indices.keys()} - self.assertEqual(set(result.keys()), expected_node_types) + self.assertEqual(set(result.keys()), set(edge_indices.keys())) - # For the default test data each source node type maps to exactly one edge type. for edge_type, edge_index in edge_indices.items(): - src_node_type = edge_type[0] num_nodes = int(edge_index[0].max().item() + 1) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assert_tensor_equality(result[src_node_type], expected) + self.assert_tensor_equality(result[edge_type], expected) class TestHelperFunctions(TestCase): From 04cee0c96aeb426fa8d05829daef209aa5cf348a Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 31 Mar 2026 18:22:30 +0000 Subject: [PATCH 14/16] Fix for gs mode --- gigl/distributed/graph_store/dist_server.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 2fba2d93d..bd991c5eb 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -39,7 +39,7 @@ FetchNodesRequest, ) from gigl.distributed.sampler import ABLPNodeSamplerInput -from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.distributed.utils.neighborloader import shard_nodes_by_process from gigl.src.common.types.graph_data import EdgeType, NodeType from gigl.types.graph import FeatureInfo, select_label_edge_types @@ -485,6 +485,16 @@ def create_sampling_producer( buffer = ShmChannel( worker_options.buffer_capacity, worker_options.buffer_size ) + # Degree tensors for PPR must be computed before constructing + # the producer. The all_reduce inside degree_tensor requires + # all ranks to participate simultaneously and cannot run inside + # worker subprocesses (which only initialize RPC, not + # torch.distributed). + degree_tensors = ( + self.dataset.degree_tensor + if isinstance(sampler_options, PPRSamplerOptions) + else None + ) producer = DistSamplingProducer( data=self.dataset, sampler_input=sampler_input, @@ -492,6 +502,7 @@ def create_sampling_producer( worker_options=worker_options, channel=buffer, sampler_options=sampler_options, + degree_tensors=degree_tensors, ) producer_start_time = time.monotonic() producer.init() From 99fdef2a1a3b8ec8c1d65d7f05946bd70787800f Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 31 Mar 2026 20:40:58 +0000 Subject: [PATCH 15/16] Address comment --- gigl/distributed/base_dist_loader.py | 56 ++++++++++++++++++- gigl/distributed/dist_ablp_neighborloader.py | 23 +------- .../distributed/distributed_neighborloader.py | 23 +------- 3 files changed, 59 insertions(+), 43 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index ed0783ef6..b0ce59afd 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -26,6 +26,7 @@ from graphlearn_torch.distributed.dist_client import async_request_server from graphlearn_torch.distributed.rpc import rpc_is_initialized from graphlearn_torch.sampler import ( + EdgeSamplerInput, NodeSamplerInput, RemoteSamplerInput, SamplingConfig, @@ -42,7 +43,7 @@ from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.dist_server import DistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset -from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions from gigl.distributed.utils.neighborloader import ( DatasetSchema, patch_fanout_for_sampling, @@ -388,6 +389,59 @@ def create_colocated_channel( channel.pin_memory() return channel + @staticmethod + def create_mp_producer( + dataset: DistDataset, + sampler_input: Union[NodeSamplerInput, EdgeSamplerInput], + sampling_config: SamplingConfig, + worker_options: MpDistSamplingWorkerOptions, + sampler_options: SamplerOptions, + ) -> DistSamplingProducer: + """Create a colocated-mode DistSamplingProducer with pre-computed degree tensors. + + Creates the shared-memory channel and, for PPR sampling, pre-computes + degree tensors via all-reduce before constructing the producer. The + all-reduce must happen here — before the staggered sleep in + ``_init_colocated_connections`` — so that all ranks complete the + collective together and the stagger applies only to worker spawning. + + Args: + dataset: The local DistDataset for this rank. + sampler_input: Node or edge sampler input (ABLPNodeSamplerInput is + also accepted as it extends NodeSamplerInput). + sampling_config: Sampling configuration. + worker_options: Colocated worker options (must be fully configured). + sampler_options: Controls which sampler class is instantiated. + + Returns: + A fully constructed DistSamplingProducer, ready to be passed to + ``_init_colocated_connections``. + """ + channel = BaseDistLoader.create_colocated_channel(worker_options) + if isinstance(sampler_options, PPRSamplerOptions): + degree_tensors = dataset.degree_tensor + if isinstance(degree_tensors, dict): + logger.info( + f"Pre-computed degree tensors for PPR sampling across " + f"{len(degree_tensors)} edge types." + ) + else: + logger.info( + f"Pre-computed degree tensor for PPR sampling with " + f"{degree_tensors.size(0)} nodes." + ) + else: + degree_tensors = None + return DistSamplingProducer( + data=dataset, + sampler_input=sampler_input, + sampling_config=sampling_config, + worker_options=worker_options, + channel=channel, + sampler_options=sampler_options, + degree_tensors=degree_tensors, + ) + @staticmethod def initialize_colocated_sampling_worker( *, diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 8c86c1c3d..f9efd01a0 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -358,33 +358,14 @@ def __init__( if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) - channel = BaseDistLoader.create_colocated_channel(worker_options) - # Compute degree tensors here, before the producer is constructed and - # before the staggered sleep in _init_colocated_connections. The - # all_reduce requires all ranks to participate simultaneously; deferring - # it to after the sleep would negate the stagger on worker spawning. - if isinstance(sampler_options, PPRSamplerOptions): - degree_tensors = dataset.degree_tensor - if isinstance(degree_tensors, dict): - logger.info( - f"Pre-computed degree tensors for PPR sampling across {len(degree_tensors)} edge types." - ) - else: - logger.info( - f"Pre-computed degree tensor for PPR sampling with {degree_tensors.size(0)} nodes." - ) - else: - degree_tensors = None producer: Union[ DistSamplingProducer, Callable[..., int] - ] = DistSamplingProducer( - data=dataset, + ] = BaseDistLoader.create_mp_producer( + dataset=dataset, sampler_input=sampler_input, sampling_config=sampling_config, worker_options=worker_options, - channel=channel, sampler_options=sampler_options, - degree_tensors=degree_tensors, ) else: producer = DistServer.create_sampling_producer diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 2d3b82e9a..1cd12cdc9 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -267,33 +267,14 @@ def __init__( if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) - channel = BaseDistLoader.create_colocated_channel(worker_options) - # Compute degree tensors here, before the producer is constructed and - # before the staggered sleep in _init_colocated_connections. The - # all_reduce requires all ranks to participate simultaneously; deferring - # it to after the sleep would negate the stagger on worker spawning. - if isinstance(sampler_options, PPRSamplerOptions): - degree_tensors = dataset.degree_tensor - if isinstance(degree_tensors, dict): - logger.info( - f"Pre-computed degree tensors for PPR sampling across {len(degree_tensors)} edge types." - ) - else: - logger.info( - f"Pre-computed degree tensor for PPR sampling with {degree_tensors.size(0)} nodes." - ) - else: - degree_tensors = None producer: Union[ DistSamplingProducer, Callable[..., int] - ] = DistSamplingProducer( - data=dataset, + ] = BaseDistLoader.create_mp_producer( + dataset=dataset, sampler_input=input_data, sampling_config=sampling_config, worker_options=worker_options, - channel=channel, sampler_options=sampler_options, - degree_tensors=degree_tensors, ) else: producer = GiglDistServer.create_sampling_producer From 210c1ddc7be8d0e542d7a65cbab25a4c1690b90b Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 1 Apr 2026 22:16:42 +0000 Subject: [PATCH 16/16] Upate --- containers/Dockerfile.src | 2 ++ gigl/scripts/post_install.py | 7 ++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/containers/Dockerfile.src b/containers/Dockerfile.src index b80295962..7b105084a 100644 --- a/containers/Dockerfile.src +++ b/containers/Dockerfile.src @@ -15,8 +15,10 @@ COPY uv.lock uv.lock COPY gigl/dep_vars.env gigl/dep_vars.env COPY deployment deployment COPY gigl gigl +COPY scripts scripts COPY snapchat snapchat COPY tests tests COPY examples examples RUN uv pip install -e . +RUN uv run python scripts/build_cpp_extensions.py build_ext --inplace diff --git a/gigl/scripts/post_install.py b/gigl/scripts/post_install.py index eed1528f6..d31b4244a 100644 --- a/gigl/scripts/post_install.py +++ b/gigl/scripts/post_install.py @@ -63,7 +63,12 @@ def main(): try: print("Building C++ extensions...") subprocess.run( - [sys.executable, "scripts/build_cpp_extensions.py", "build_ext", "--inplace"], + [ + sys.executable, + "scripts/build_cpp_extensions.py", + "build_ext", + "--inplace", + ], cwd=repo_root, check=True, )