diff --git a/cuda_bindings/tests/nvml/test_device.py b/cuda_bindings/tests/nvml/test_device.py index c9cb967efb8..4b5f6a1e160 100644 --- a/cuda_bindings/tests/nvml/test_device.py +++ b/cuda_bindings/tests/nvml/test_device.py @@ -4,6 +4,7 @@ from functools import cache +import numpy as np import pytest from cuda.bindings import nvml @@ -78,7 +79,7 @@ def test_get_nv_link_supported_bw_modes(all_devices): assert not hasattr(modes, "total_bw_modes") for mode in modes.bw_modes: - assert isinstance(mode, int) + assert isinstance(mode, np.uint8) def test_device_get_pdi(all_devices):