From 1d0f975d2e784d0e60fd259677e28b00d35819d1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 20 Mar 2026 22:38:52 +0100 Subject: [PATCH 1/7] MAINT: remove explicit xp.complex128 from scalars() strategy NB: while at it, note that complex64 cases used to generate values of 32 bits for *both real and imag parts*: so in a range of complex32 / two float16. --- array_api_tests/hypothesis_helpers.py | 31 +++++++++++---------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index e870a61c..4669421c 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -20,7 +20,7 @@ from . import xps from ._array_module import _UndefinedStub from ._array_module import bool as bool_dtype -from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128 +from ._array_module import broadcast_to, eye, full from .stubs import category_to_funcs from .pytest_helpers import nargs from .typing import Array, DataType, Scalar, Shape @@ -465,26 +465,21 @@ def scalars(draw, dtypes, finite=False, **kwds): m, M = dh.dtype_ranges[dtype] min_value = kwds.get('min_value', m) max_value = kwds.get('max_value', M) - return draw(integers(min_value, max_value)) + elif dtype == bool_dtype: return draw(booleans()) - elif dtype == float64: - if finite: - return draw(floats(allow_nan=False, allow_infinity=False, **kwds)) - return draw(floats(), **kwds) - elif dtype == float32: - if finite: - return draw(floats(width=32, allow_nan=False, allow_infinity=False, **kwds)) - return draw(floats(width=32, **kwds)) - elif dtype == complex64: - if finite: - return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False)) - return draw(complex_numbers(width=32)) - elif dtype == complex128: - if finite: - return draw(complex_numbers(allow_nan=False, allow_infinity=False)) - return draw(complex_numbers()) + + elif dtype in dh.real_float_dtypes: + f_kwds = dict(allow_nan=False, allow_infinity=False) if finite else dict() + width = dh.dtype_nbits[dtype] # 32 or 64 + return draw(floats(width=width, **f_kwds, **kwds)) + + elif dtype in dh.complex_dtypes: + f_kwds = dict(allow_nan=False, allow_infinity=False) if finite else dict() + width = dh.dtype_nbits[dtype] # 64 or 128 + return draw(complex_numbers(width=width, **f_kwds, **kwds)) + else: raise ValueError(f"Unrecognized dtype {dtype}") From e473c80f28b63a8a4e6b46ae062e131668cf9e67 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 21 Mar 2026 10:39:58 +0100 Subject: [PATCH 2/7] MAINT: remove explicit complex128 from fft assertions, of real<->complex Reuse real_dtype_for/complex_dtype_for, which encapsulate the same mappings. --- array_api_tests/pytest_helpers.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index be601770..76d1c1b4 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -161,11 +161,8 @@ def assert_dtype( def assert_float_to_complex_dtype( func_name: str, *, in_dtype: DataType, out_dtype: DataType ): - if in_dtype == xp.float32: - expected = xp.complex64 - else: - assert in_dtype == xp.float64 # sanity check - expected = xp.complex128 + assert in_dtype in dh.real_float_dtypes # sanity check + expected = dh.complex_dtype_for(in_dtype) assert_dtype( func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected ) @@ -174,13 +171,8 @@ def assert_float_to_complex_dtype( def assert_complex_to_float_dtype( func_name: str, *, in_dtype: DataType, out_dtype: DataType, repr_name: str = "out.dtype" ): - if in_dtype == xp.complex64: - expected = xp.float32 - elif in_dtype == xp.complex128: - expected = xp.float64 - else: - assert in_dtype in (xp.float32, xp.float64) # sanity check - expected = in_dtype + assert in_dtype in dh.all_float_dtypes + expected = dh.real_dtype_for(in_dtype) assert_dtype( func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected, repr_name=repr_name ) From 5c50ff0f972da1a258a731e74ae460659492c1d4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 21 Mar 2026 10:52:10 +0100 Subject: [PATCH 3/7] MAINT: remove explicit dtypes from finfo tests --- array_api_tests/test_data_type_functions.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 2189ef88..84660771 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -215,12 +215,7 @@ def test_finfo_dtype(dtype): try: out = xp.finfo(dtype) - if dtype == xp.complex64: - assert out.dtype == xp.float32 - elif dtype == xp.complex128: - assert out.dtype == xp.float64 - else: - assert out.dtype == dtype + assert out.dtype == dh.real_dtype_for(dtype) # Guard vs. numpy.dtype.__eq__ lax comparison assert not isinstance(out.dtype, str) From 914124aa7358dd1dbcdfc2e4e1ca1757ff7f75d3 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 21 Mar 2026 11:15:19 +0100 Subject: [PATCH 4/7] STY: rm whitespace-only lines --- array_api_tests/test_special_cases.py | 66 +++++++++++++-------------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index a919fe08..4a94f922 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -496,7 +496,7 @@ def check_result(result: float) -> bool: def parse_complex_value(value_str: str) -> complex: """ Parses a complex value string to return a complex number, e.g. - + >>> parse_complex_value('+0 + 0j') 0j >>> parse_complex_value('NaN + NaN j') @@ -507,13 +507,13 @@ def parse_complex_value(value_str: str) -> complex: 1.5707963267948966j >>> parse_complex_value('+infinity + 3πj/4') (inf+2.356194490192345j) - + Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + Nπj/M" """ m = r_complex_value.match(value_str) if m is None: raise ParseError(value_str) - + # Parse real part with its sign # Normalize ± to + (we choose positive arbitrarily since sign is unspecified) real_sign = m.group(1) if m.group(1) else "+" @@ -521,7 +521,7 @@ def parse_complex_value(value_str: str) -> complex: real_sign = '+' real_val_str = m.group(2) real_val = parse_value(real_sign + real_val_str) - + # Parse imaginary part with its sign # Normalize ± to + for imaginary part as well imag_sign = m.group(3) @@ -536,9 +536,9 @@ def parse_complex_value(value_str: str) -> complex: imag_val_str_raw = m.group(5) # Strip trailing 'j' if present: "0j" -> "0" imag_val_str = imag_val_str_raw[:-1] if imag_val_str_raw.endswith('j') else imag_val_str_raw - + imag_val = parse_value(imag_sign + imag_val_str) - + return complex(real_val, imag_val) @@ -548,10 +548,10 @@ def make_strict_eq_complex(v: complex) -> Callable[[complex], bool]: """ real_check = make_strict_eq(v.real) imag_check = make_strict_eq(v.imag) - + def strict_eq_complex(z: complex) -> bool: return real_check(z.real) and imag_check(z.imag) - + return strict_eq_complex @@ -560,7 +560,7 @@ def parse_complex_cond( ) -> Tuple[Callable[[complex], bool], str, FromDtypeFunc]: """ Parses complex condition strings for real (a) and imaginary (b) parts. - + Returns: - cond: Function that checks if a complex number meets the condition - expr: String expression for the condition @@ -569,16 +569,16 @@ def parse_complex_cond( # Parse conditions for real and imaginary parts separately a_cond, a_expr_template, a_from_dtype = parse_cond(a_cond_str) b_cond, b_expr_template, b_from_dtype = parse_cond(b_cond_str) - + # Create compound condition def complex_cond(z: complex) -> bool: return a_cond(z.real) and b_cond(z.imag) - + # Create expression a_expr = a_expr_template.replace("{}", "real(x_i)") b_expr = b_expr_template.replace("{}", "imag(x_i)") expr = f"{a_expr} and {b_expr}" - + # Create strategy that generates complex numbers def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]: assert len(kw) == 0 # sanity check @@ -589,7 +589,7 @@ def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]: real_strat = a_from_dtype(float_dtype) imag_strat = b_from_dtype(float_dtype) return st.builds(complex, real_strat, imag_strat) - + return complex_cond, expr, complex_from_dtype @@ -609,7 +609,7 @@ def _check_component_with_tolerance(actual: float, expected: float, allow_any_si def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], str]: """ Parses a complex result string to return a checker and expression. - + Handles cases like: - "``+0 + 0j``" - exact complex value - "``0 + NaN j`` (sign of the real component is unspecified)" @@ -618,7 +618,7 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st # Check for unspecified sign notes (text-based detection) unspecified_real_sign = "sign of the real component is unspecified" in result_str unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str - + # Extract the complex value from backticks - need to handle spaces in complex values # Pattern: ``...`` where ... can contain spaces (for complex values like "0 + NaN j") m = re.search(r"``([^`]+)``", result_str) @@ -640,12 +640,12 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st # Check if the value contains π expressions (for approximate comparison) has_pi = 'π' in value_str - + try: expected = parse_complex_value(value_str) except ParseError: raise ParseError(result_str) - + # Create checker based on whether signs are unspecified and whether π is involved if has_pi: # Use approximate equality for both real and imaginary parts if they involve π @@ -670,7 +670,7 @@ def check_result(z: complex) -> bool: else: # Exact match including signs check_result = make_strict_eq_complex(expected) - + expr = value_str return check_result, expr else: @@ -884,35 +884,34 @@ def parse_unary_case_block(case_block: str, func_name: str, record_list: Optiona cases = [] # Check if the case block contains complex cases by looking for the marker in_complex_section = r_complex_marker.search(case_block) is not None - + for case_m in r_case.finditer(case_block): case_str = case_m.group(1) - + # Record this special case if a record list is provided if record_list is not None: record_list.append(f"{func_name}: {case_str}.") - - + # Try to parse complex cases if we're in the complex section if in_complex_section and (m := r_complex_case.search(case_str)): try: a_cond_str = m.group(1) b_cond_str = m.group(2) result_str = m.group(3) - + # Skip cases with complex expressions like "cis(b)" if "cis" in result_str or "*" in result_str: warn(f"case for {func_name} not machine-readable: '{case_str}'") continue - + # Parse the complex condition and result complex_cond, cond_expr, complex_from_dtype = parse_complex_cond( a_cond_str, b_cond_str ) _check_result, result_expr = parse_complex_result(result_str) - + check_result = make_complex_unary_check_result(_check_result) - + case = UnaryCase( cond_expr=cond_expr, cond=complex_cond, @@ -926,7 +925,7 @@ def parse_unary_case_block(case_block: str, func_name: str, record_list: Optiona except ParseError as e: warn(f"case for {func_name} not machine-readable: '{e.value}'") continue - + # Parse regular (real-valued) cases if r_already_int_case.search(case_str): cases.append(already_int_case) @@ -1394,11 +1393,11 @@ def parse_binary_case_block(case_block: str, func_name: str, record_list: Option cases = [] for case_m in r_case.finditer(case_block): case_str = case_m.group(1) - + # Record this special case if a record list is provided if record_list is not None: record_list.append(f"{func_name}: {case_str}.") - + if r_redundant_case.search(case_str): continue if r_binary_case.match(case_str): @@ -1528,7 +1527,7 @@ def test_unary(func_name, func, case): # drawing multiple examples like a normal test, or just hard-coding a # single example test case without using hypothesis. filterwarnings('ignore', category=NonInteractiveExampleWarning) - + # Use the is_complex flag to determine the appropriate dtype if case.is_complex: dtype = xp.complex128 @@ -1536,16 +1535,16 @@ def test_unary(func_name, func, case): else: dtype = xp.float64 in_value = case.cond_from_dtype(dtype).example() - + # Create array and compute result based on dtype x = xp.asarray(in_value, dtype=dtype) out = func(x) - + if case.is_complex: out_value = complex(out) else: out_value = float(out) - + assert case.check_result(in_value, out_value), ( f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n" ) @@ -1572,7 +1571,6 @@ def test_binary(func_name, func, case, data): ) - @pytest.mark.parametrize("iop_name, iop, case", iop_params) @settings(max_examples=1) @given(data=st.data()) From 3ab78bddce168fc59f3c5f88f442792991b8762e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 21 Mar 2026 11:26:30 +0100 Subject: [PATCH 5/7] ENH: add dtype_helpers.widest_{real,complex}_dtype, use in test_special_cases --- array_api_tests/dtype_helpers.py | 13 +++++++++++++ array_api_tests/test_special_cases.py | 4 ++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index a479095f..7e66ca65 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -281,10 +281,23 @@ def __contains__(self, other): ) +# complex128 if available else complex64 +widest_complex_dtype = max( + [(dt, dtype_nbits[dt]) for dt in complex_dtypes], key=lambda x: x[1] +)[0] + + +# float64 if available else float32 +widest_real_dtype = max( + [(dt, dtype_nbits[dt]) for dt in real_float_dtypes], key=lambda x: x[1] +)[0] + + dtype_components = _make_dtype_mapping_from_names( {"complex64": xp.float32, "complex128": xp.float64} ) + def as_real_dtype(dtype): """ Return the corresponding real dtype for a given floating-point dtype. diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 4a94f922..deb79915 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1530,10 +1530,10 @@ def test_unary(func_name, func, case): # Use the is_complex flag to determine the appropriate dtype if case.is_complex: - dtype = xp.complex128 + dtype = dh.widest_complex_dtype in_value = case.cond_from_dtype(dtype).example() else: - dtype = xp.float64 + dtype = dh.widest_real_dtype in_value = case.cond_from_dtype(dtype).example() # Create array and compute result based on dtype From e9c8606ef7370b5df20725d5a91f9b9fc4735daa Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 21 Mar 2026 12:09:42 +0100 Subject: [PATCH 6/7] MAINT: avoid explicit xp.float64 in test_special_cases --- array_api_tests/test_creation_functions.py | 2 +- array_api_tests/test_special_cases.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 1723fdaa..77e5dbb3 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -198,7 +198,7 @@ def test_arange(dtype, data): ), f"out[0]={out[0]}, but should be {_start} {f_func}" except Exception as exc: ph.add_note(exc, repro_snippet) - raise + raise @given(shape=hh.shapes(min_side=1), data=st.data()) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index deb79915..978a64df 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1556,10 +1556,11 @@ def test_unary(func_name, func, case): def test_binary(func_name, func, case, data): # We don't use example() like in test_unary because the same internal shared # strategies used in both x1's and x2's don't "sync" with example() draws. - x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value") - x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value") - x1 = xp.asarray(x1_value, dtype=xp.float64) - x2 = xp.asarray(x2_value, dtype=xp.float64) + dtyp = dh.widest_real_dtype # float64 if available else float32 + x1_value = data.draw(case.x1_cond_from_dtype(dtyp), label="x1_value") + x2_value = data.draw(case.x2_cond_from_dtype(dtyp), label="x2_value") + x1 = xp.asarray(x1_value, dtype=dtyp) + x2 = xp.asarray(x2_value, dtype=dtyp) out = func(x1, x2) out_value = float(out) @@ -1576,10 +1577,11 @@ def test_binary(func_name, func, case, data): @given(data=st.data()) def test_iop(iop_name, iop, case, data): # See test_binary comment - x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value") - x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value") - x1 = xp.asarray(x1_value, dtype=xp.float64) - x2 = xp.asarray(x2_value, dtype=xp.float64) + dtyp = dh.widest_real_dtype + x1_value = data.draw(case.x1_cond_from_dtype(dtyp), label="x1_value") + x2_value = data.draw(case.x2_cond_from_dtype(dtyp), label="x2_value") + x1 = xp.asarray(x1_value, dtype=dtyp) + x2 = xp.asarray(x2_value, dtype=dtyp) res = iop(x1, x2) res_value = float(res) From 7b7713f8dee626c56d086469c96d64f939747ab4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 21 Mar 2026 13:09:32 +0100 Subject: [PATCH 7/7] MAINT: floa64->float32 in test_signatures --- array_api_tests/test_signatures.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index 1c9a8ef6..ab1440cf 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -130,11 +130,11 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str: { "stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"}, "iinfo": {"type": "xp.int64"}, - "finfo": {"type": "xp.float64"}, - "cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)"}, - "inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)"}, + "finfo": {"type": "xp.float32"}, + "cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float32)"}, + "inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float32)"}, "solve": { - a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1", "x2"] + a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float32)" for a in ["x1", "x2"] }, "outer": {"x1": "xp.ones((5,))", "x2": "xp.ones((5,))"}, },