Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cuda_core/cuda/core/_module.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ from cuda.core._utils.clear_error_support import (
raise_code_path_meant_to_be_unreachable,
)
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
from cuda.core._utils.version cimport cy_driver_version
from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
from cuda.core._utils.cuda_utils import driver
from cuda.bindings cimport cydriver

Expand Down Expand Up @@ -463,6 +463,11 @@ cdef class Kernel:
"Driver version 12.4 or newer is required for this function. "
f"Using driver version {'.'.join(map(str, cy_driver_version()))}"
)
if cy_binding_version() < (12, 4, 0):
raise NotImplementedError(
"cuda.bindings 12.4 or newer is required for this function. "
f"Using binding version {'.'.join(map(str, cy_binding_version()))}"
)
cdef size_t arg_pos = 0
cdef list param_info_data = []
cdef cydriver.CUkernel cu_kernel = as_cu(self._h_kernel)
Expand Down
6 changes: 4 additions & 2 deletions cuda_core/cuda/core/graph/_subclasses.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ cdef bint _version_checked = False
cdef bint _check_node_get_params():
global _has_cuGraphNodeGetParams, _version_checked
if not _version_checked:
from cuda.core._utils.version import driver_version
_has_cuGraphNodeGetParams = driver_version() >= (13, 2, 0)
from cuda.core._utils.version import binding_version, driver_version
_has_cuGraphNodeGetParams = (
driver_version() >= (13, 2, 0) and binding_version() >= (13, 2, 0)
)
_version_checked = True
return _has_cuGraphNodeGetParams

Expand Down
8 changes: 4 additions & 4 deletions cuda_core/tests/graph/test_graph_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ def _skip_if_no_managed_mempool():
pytest.skip("Device does not support managed memory pool operations")


def _driver_has_node_get_params():
from cuda.core._utils.version import driver_version
def _has_node_get_params():
from cuda.core._utils.version import binding_version, driver_version

return driver_version() >= (13, 2, 0)
return driver_version() >= (13, 2, 0) and binding_version() >= (13, 2, 0)


_HAS_NODE_GET_PARAMS = _driver_has_node_get_params()
_HAS_NODE_GET_PARAMS = _has_node_get_params()


def _bindings_major_version():
Expand Down
Loading