diff --git a/cuda_core/cuda/core/_module.pyx b/cuda_core/cuda/core/_module.pyx index 4a8601f857..96ac65effc 100644 --- a/cuda_core/cuda/core/_module.pyx +++ b/cuda_core/cuda/core/_module.pyx @@ -32,7 +32,7 @@ from cuda.core._utils.clear_error_support import ( raise_code_path_meant_to_be_unreachable, ) from cuda.core._utils.cuda_utils cimport HANDLE_RETURN -from cuda.core._utils.version cimport cy_driver_version +from cuda.core._utils.version cimport cy_binding_version, cy_driver_version from cuda.core._utils.cuda_utils import driver from cuda.bindings cimport cydriver @@ -463,6 +463,11 @@ cdef class Kernel: "Driver version 12.4 or newer is required for this function. " f"Using driver version {'.'.join(map(str, cy_driver_version()))}" ) + if cy_binding_version() < (12, 4, 0): + raise NotImplementedError( + "cuda.bindings 12.4 or newer is required for this function. " + f"Using binding version {'.'.join(map(str, cy_binding_version()))}" + ) cdef size_t arg_pos = 0 cdef list param_info_data = [] cdef cydriver.CUkernel cu_kernel = as_cu(self._h_kernel) diff --git a/cuda_core/cuda/core/graph/_subclasses.pyx b/cuda_core/cuda/core/graph/_subclasses.pyx index 6d15ebc3ff..3550e993fe 100644 --- a/cuda_core/cuda/core/graph/_subclasses.pyx +++ b/cuda_core/cuda/core/graph/_subclasses.pyx @@ -60,8 +60,10 @@ cdef bint _version_checked = False cdef bint _check_node_get_params(): global _has_cuGraphNodeGetParams, _version_checked if not _version_checked: - from cuda.core._utils.version import driver_version - _has_cuGraphNodeGetParams = driver_version() >= (13, 2, 0) + from cuda.core._utils.version import binding_version, driver_version + _has_cuGraphNodeGetParams = ( + driver_version() >= (13, 2, 0) and binding_version() >= (13, 2, 0) + ) _version_checked = True return _has_cuGraphNodeGetParams diff --git a/cuda_core/tests/graph/test_graph_definition.py b/cuda_core/tests/graph/test_graph_definition.py index 7f70c74aa3..f9d10c766e 100644 --- a/cuda_core/tests/graph/test_graph_definition.py +++ b/cuda_core/tests/graph/test_graph_definition.py @@ -48,13 +48,13 @@ def _skip_if_no_managed_mempool(): pytest.skip("Device does not support managed memory pool operations") -def _driver_has_node_get_params(): - from cuda.core._utils.version import driver_version +def _has_node_get_params(): + from cuda.core._utils.version import binding_version, driver_version - return driver_version() >= (13, 2, 0) + return driver_version() >= (13, 2, 0) and binding_version() >= (13, 2, 0) -_HAS_NODE_GET_PARAMS = _driver_has_node_get_params() +_HAS_NODE_GET_PARAMS = _has_node_get_params() def _bindings_major_version():