Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,23 @@ def __contains__(self, other):
)


# complex128 if available else complex64
widest_complex_dtype = max(
[(dt, dtype_nbits[dt]) for dt in complex_dtypes], key=lambda x: x[1]
)[0]


# float64 if available else float32
widest_real_dtype = max(
[(dt, dtype_nbits[dt]) for dt in real_float_dtypes], key=lambda x: x[1]
)[0]


dtype_components = _make_dtype_mapping_from_names(
{"complex64": xp.float32, "complex128": xp.float64}
)


def as_real_dtype(dtype):
"""
Return the corresponding real dtype for a given floating-point dtype.
Expand Down
31 changes: 13 additions & 18 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from . import xps
from ._array_module import _UndefinedStub
from ._array_module import bool as bool_dtype
from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128
from ._array_module import broadcast_to, eye, full
from .stubs import category_to_funcs
from .pytest_helpers import nargs
from .typing import Array, DataType, Scalar, Shape
Expand Down Expand Up @@ -465,26 +465,21 @@ def scalars(draw, dtypes, finite=False, **kwds):
m, M = dh.dtype_ranges[dtype]
min_value = kwds.get('min_value', m)
max_value = kwds.get('max_value', M)

return draw(integers(min_value, max_value))

elif dtype == bool_dtype:
return draw(booleans())
elif dtype == float64:
if finite:
return draw(floats(allow_nan=False, allow_infinity=False, **kwds))
return draw(floats(), **kwds)
elif dtype == float32:
if finite:
return draw(floats(width=32, allow_nan=False, allow_infinity=False, **kwds))
return draw(floats(width=32, **kwds))
elif dtype == complex64:
if finite:
return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False))
return draw(complex_numbers(width=32))
elif dtype == complex128:
if finite:
return draw(complex_numbers(allow_nan=False, allow_infinity=False))
return draw(complex_numbers())

elif dtype in dh.real_float_dtypes:
f_kwds = dict(allow_nan=False, allow_infinity=False) if finite else dict()
width = dh.dtype_nbits[dtype] # 32 or 64
return draw(floats(width=width, **f_kwds, **kwds))

elif dtype in dh.complex_dtypes:
f_kwds = dict(allow_nan=False, allow_infinity=False) if finite else dict()
width = dh.dtype_nbits[dtype] # 64 or 128
return draw(complex_numbers(width=width, **f_kwds, **kwds))

else:
raise ValueError(f"Unrecognized dtype {dtype}")

Expand Down
16 changes: 4 additions & 12 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,8 @@ def assert_dtype(
def assert_float_to_complex_dtype(
func_name: str, *, in_dtype: DataType, out_dtype: DataType
):
if in_dtype == xp.float32:
expected = xp.complex64
else:
assert in_dtype == xp.float64 # sanity check
expected = xp.complex128
assert in_dtype in dh.real_float_dtypes # sanity check
expected = dh.complex_dtype_for(in_dtype)
assert_dtype(
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected
)
Expand All @@ -174,13 +171,8 @@ def assert_float_to_complex_dtype(
def assert_complex_to_float_dtype(
func_name: str, *, in_dtype: DataType, out_dtype: DataType, repr_name: str = "out.dtype"
):
if in_dtype == xp.complex64:
expected = xp.float32
elif in_dtype == xp.complex128:
expected = xp.float64
else:
assert in_dtype in (xp.float32, xp.float64) # sanity check
expected = in_dtype
assert in_dtype in dh.all_float_dtypes
expected = dh.real_dtype_for(in_dtype)
assert_dtype(
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected, repr_name=repr_name
)
Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test_arange(dtype, data):
), f"out[0]={out[0]}, but should be {_start} {f_func}"
except Exception as exc:
ph.add_note(exc, repro_snippet)
raise
raise


@given(shape=hh.shapes(min_side=1), data=st.data())
Expand Down
7 changes: 1 addition & 6 deletions array_api_tests/test_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,7 @@ def test_finfo_dtype(dtype):
try:
out = xp.finfo(dtype)

if dtype == xp.complex64:
assert out.dtype == xp.float32
elif dtype == xp.complex128:
assert out.dtype == xp.float64
else:
assert out.dtype == dtype
assert out.dtype == dh.real_dtype_for(dtype)

# Guard vs. numpy.dtype.__eq__ lax comparison
assert not isinstance(out.dtype, str)
Expand Down
8 changes: 4 additions & 4 deletions array_api_tests/test_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str:
{
"stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"},
"iinfo": {"type": "xp.int64"},
"finfo": {"type": "xp.float64"},
"cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)"},
"inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)"},
"finfo": {"type": "xp.float32"},
"cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float32)"},
"inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float32)"},
"solve": {
a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1", "x2"]
a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float32)" for a in ["x1", "x2"]
},
"outer": {"x1": "xp.ones((5,))", "x2": "xp.ones((5,))"},
},
Expand Down
Loading
Loading