Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions kwave/kspaceFirstOrder.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,40 @@ def _strip_pml(result, pml_size, ndim, suffixes=_FULL_GRID_SUFFIXES):
}


def _resolve_dtype(value):
"""Normalize a dtype-like input to ``np.float32`` or ``np.float64``.

Accepts numpy dtypes/types (``np.float32``, ``np.float64``), strings
(``"float32"`` etc., plus MATLAB aliases ``"single"`` / ``"double"``),
Python ``float``, ``None`` (default → float64), and the legacy MATLAB
``"off"`` alias for float64. Anything that resolves to a non-float32 /
non-float64 dtype raises ``ValueError`` — the solver isn't validated
for ``float16`` / complex dtypes.

Cupy dtypes (``cp.float32``, ``cp.float64``) work for free because cupy
re-exports numpy's scalar types. Torch / JAX dtypes are not accepted —
they live in different ecosystems and don't translate via ``np.dtype()``;
the error message points the user at the equivalent numpy dtype.
"""
if value is None or value == "off":
return np.float64
try:
resolved = np.dtype(value).type
except TypeError as e:
framework = getattr(type(value), "__module__", "").split(".")[0]
hint = ""
if framework in ("torch", "jax", "jaxlib", "tensorflow"):
hint = f" {framework}.dtype objects aren't supported; pass the equivalent numpy dtype (np.float32 / np.float64)."
raise ValueError(
f"dtype must be a numpy dtype, type, or string (e.g. 'float32', 'single'), got {value!r}.{hint}"
) from e
if resolved is np.float32:
return np.float32
if resolved is np.float64:
return np.float64
raise ValueError(f"dtype must resolve to float32 or float64; got {resolved.__name__} from {value!r}")


def kspaceFirstOrder(
kgrid: kWaveGrid,
medium: kWaveMedium,
Expand All @@ -96,6 +130,7 @@ def kspaceFirstOrder(
smooth_p0: bool = True,
backend: str = "python",
device: str = "cpu",
dtype=None,
save_only: bool = False,
data_path: Optional[str] = None,
quiet: bool = False,
Expand Down Expand Up @@ -142,6 +177,24 @@ def kspaceFirstOrder(
device: ``"cpu"`` or ``"gpu"``. For ``backend="python"`` this
selects NumPy (cpu) vs CuPy (gpu). For ``backend="cpp"`` it
selects the OMP vs CUDA binary. Default ``"cpu"``.
dtype: Numerical precision for state arrays in the Python backend.
Accepts dtype-like input — a numpy dtype (``np.float32``,
``np.float64``; cupy aliases like ``cp.float32`` work since cupy
re-exports numpy's scalar types), a string (``"float32"``,
``"float64"``, ``"single"``, ``"double"``), a Python type
(``float``), or ``None`` for the default (float64). The
MATLAB-style alias ``"off"`` is accepted as a synonym for
float64 to ease migration from the legacy
``SimulationOptions.data_cast``. Torch / JAX dtypes are not
accepted; pass the numpy equivalent (e.g. ``np.float32`` for
``torch.float32``).
``np.float32`` uses roughly half the memory and is faster on
most hardware, at the cost of reduced numerical accuracy.
Only ``float32`` and ``float64`` are supported; other dtypes
raise ``ValueError``. Has no effect on ``backend="cpp"`` (the
C++ binary uses fixed internal precision regardless); a warning
is emitted if ``dtype`` resolves to anything other than float64
with the C++ backend. Default ``None`` (float64).
save_only: When ``True`` (``backend="cpp"`` only), write the HDF5
input file and return without running the binary. Useful for
cluster submission. Default ``False``.
Expand All @@ -167,6 +220,7 @@ def kspaceFirstOrder(
raise ValueError(f"device must be 'cpu' or 'gpu', got {device!r}")
if backend not in ("python", "cpp"):
raise ValueError(f"Unknown backend: {backend!r}. Use 'python' or 'cpp'.")
dtype = _resolve_dtype(dtype)

if isinstance(pml_size, str) and pml_size.lower() == "auto":
pml_size = tuple(int(x) for x in get_optimal_pml_size(kgrid))
Expand Down Expand Up @@ -206,6 +260,7 @@ def kspaceFirstOrder(
pml_size=pml_size,
pml_alpha=pml_alpha,
quiet=quiet,
dtype=dtype,
).run()

elif backend == "cpp":
Expand All @@ -215,6 +270,14 @@ def kspaceFirstOrder(
check_alpha_mode_cpp_compatible(medium)
warn_alpha_power_near_unity_cpp(medium)

if dtype is not np.float64:
warnings.warn(
f"dtype={np.dtype(dtype).name!r} has no effect with backend='cpp'; the C++ binary "
"uses fixed internal precision regardless. Use backend='python' to control "
"computational precision.",
stacklevel=2,
)

if not use_kspace:
warnings.warn(
"use_kspace=False has no effect with backend='cpp'; the C++ binary always applies k-space correction.",
Expand Down
108 changes: 75 additions & 33 deletions kwave/solvers/kspace_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,40 +32,55 @@ def _to_cpu(x):
return x.get() if hasattr(x, "get") else x


def _expand_to_grid(val, grid_shape, xp, name="parameter"):
def _array_sum(arrays):
"""Sum arrays, preserving dtype.

``sum(arrays)`` starts from Python ``int 0``; under numpy < 2 (NEP 50)
that promotes float32 inputs to float64. Starting from ``arrays[0]``
keeps the result's dtype equal to the elements'.
"""
out = arrays[0]
for a in arrays[1:]:
out = out + a
return out


def _expand_to_grid(val, grid_shape, xp, name="parameter", dtype=float):
if val is None:
raise ValueError(f"Missing required parameter: {name}")
arr = xp.array(val, dtype=float).ravel()
arr = xp.array(val, dtype=dtype).ravel()
grid_size = int(np.prod(grid_shape))
if arr.size == 1:
return xp.full(grid_shape, float(arr[0]), dtype=float)
return xp.full(grid_shape, arr[0], dtype=dtype)
if arr.size == grid_size:
return arr.reshape(grid_shape)
raise ValueError(f"{name} size {arr.size} incompatible with grid size {grid_size}")


def _build_source_op(mask_raw, signal_raw, mode, scale, *, xp, grid_shape, grid_size, source_kappa, diff_fn):
def _build_source_op(mask_raw, signal_raw, mode, scale, *, xp, grid_shape, grid_size, source_kappa, diff_fn, dtype=float):
"""Build a source injection operator for one field variable.

Returns a callable (t, field) → field that injects scaled source values.
``dtype`` controls the precision of the source signal buffer (matches the
field precision set by ``data_cast``).
"""
mask = xp.array(mask_raw, dtype=bool).ravel()
if mask.size == 1:
mask = xp.full(grid_shape, bool(mask[0]), dtype=bool).ravel()
n_src = int(xp.sum(mask))

signal_arr = xp.array(signal_raw, dtype=float)
signal_arr = xp.array(signal_raw, dtype=dtype)
if signal_arr.ndim == 1:
signal = signal_arr.reshape(1, -1)
else:
signal = signal_arr.reshape(-1, signal_arr.shape[-1]) if signal_arr.ndim > 2 else signal_arr

scaled = signal * xp.atleast_1d(xp.asarray(scale))[:, None]
scaled = signal * xp.atleast_1d(xp.asarray(scale, dtype=dtype))[:, None]
signal_len = scaled.shape[1]

def get_val(t):
if scaled.shape[0] == 1:
return xp.full(n_src, float(scaled[0, t]))
return xp.full(n_src, scaled[0, t], dtype=dtype)
return scaled[:, t]

def dirichlet(t, field):
Expand All @@ -76,7 +91,7 @@ def dirichlet(t, field):
return flat.reshape(grid_shape)

# Pre-allocate buffer to avoid per-step allocation
_src_buf = xp.zeros(grid_size, dtype=float)
_src_buf = xp.zeros(grid_size, dtype=dtype)

def additive_kspace(t, field):
if t >= signal_len:
Expand Down Expand Up @@ -131,6 +146,7 @@ def __init__(
pml_size=None,
pml_alpha=None,
quiet=False,
dtype=None,
):
self.kgrid = kgrid
self.medium = medium
Expand All @@ -142,6 +158,25 @@ def __init__(
self.quiet = quiet
self._pml_size_override = pml_size
self._pml_alpha_override = pml_alpha
# Compute precision for state arrays. ``None`` defaults to float64 (matches
# MATLAB k-Wave). Only float32 / float64 are validated by the solver.
if dtype is None:
self._dtype = np.float64
elif dtype in (np.float32, np.float64):
self._dtype = dtype
else:
try:
resolved = np.dtype(dtype).type
except TypeError as e:
raise ValueError(f"dtype must be np.float32 or np.float64 (or string equivalent), got {dtype!r}") from e
if resolved not in (np.float32, np.float64):
raise ValueError(f"dtype must resolve to float32 or float64, got {resolved.__name__}")
self._dtype = resolved
# Companion complex dtype for FFT outputs. numpy<2 (np.fft) always upcasts
# to complex128 regardless of input precision; we cast back so the rest of
# the pipeline stays in self._dtype. Harmless on numpy 2+ and on cupy
# (which already respects input precision).
self._complex_dtype = np.complex64 if self._dtype is np.float32 else np.complex128
# kWaveGrid doesn't have pml_size_x attrs; warn if PML will silently be disabled
if pml_size is None:
from kwave.kgrid import kWaveGrid as _KWG
Expand Down Expand Up @@ -190,9 +225,9 @@ def setup(self):
self.Nt = int(self.kgrid.Nt)
self.dt = float(self.kgrid.dt)

self.c0 = _expand_to_grid(self.medium.sound_speed, self.grid_shape, xp, "sound_speed")
self.c0 = _expand_to_grid(self.medium.sound_speed, self.grid_shape, xp, "sound_speed", dtype=self._dtype)
density = getattr(self.medium, "density", None)
self.rho0 = _expand_to_grid(density if density is not None else 1000.0, self.grid_shape, xp, "density")
self.rho0 = _expand_to_grid(density if density is not None else 1000.0, self.grid_shape, xp, "density", dtype=self._dtype)
self.c_ref = float(xp.max(self.c0))

self._setup_sensor_mask()
Expand Down Expand Up @@ -356,26 +391,32 @@ def _setup_pml(self):
if pml_size == 0 or pml_alpha == 0:
shape = [1] * self.ndim
shape[axis] = N
self.pml_list.append(xp.ones(shape, dtype=float))
self.pml_sg_list.append(xp.ones(shape, dtype=float))
self.pml_list.append(xp.ones(shape, dtype=self._dtype))
self.pml_sg_list.append(xp.ones(shape, dtype=self._dtype))
else:
# dimension=2 gives shape (1, N) which we reshape for broadcasting
# dimension=2 gives shape (1, N) which we reshape for broadcasting.
# get_pml returns float64; cast so the per-step PML multiply doesn't
# upcast self.p / self.u (would silently break dtype='single').
pml = get_pml(N, dx, self.dt, self.c_ref, pml_size, pml_alpha, staggered=False, dimension=2, xp=xp)
pml_sg = get_pml(N, dx, self.dt, self.c_ref, pml_size, pml_alpha, staggered=True, dimension=2, xp=xp)

shape = [1] * self.ndim
shape[axis] = N
self.pml_list.append(pml.flatten().reshape(shape))
self.pml_sg_list.append(pml_sg.flatten().reshape(shape))
self.pml_list.append(pml.flatten().reshape(shape).astype(self._dtype))
self.pml_sg_list.append(pml_sg.flatten().reshape(shape).astype(self._dtype))

def _setup_kspace_operators(self):
"""Build k-space gradient/divergence operators for each dimension."""
xp = self.xp
self.k_list = []

# First pass: build k-vectors for each dimension
# First pass: build k-vectors for each dimension.
# ``fftfreq`` returns float64 by default; cast so the k-space operators
# (kappa, op_grad_list, op_div_list, _k_mag) match self._dtype. Without
# this cast, _diff's FFT round-trip with a float64 op upcasts the
# float32 field back to float64 -- silently breaking dtype='single'.
for axis, (N, dx) in enumerate(zip(self.grid_shape, self.spacing)):
k = 2 * np.pi * xp.fft.fftfreq(N, d=dx)
k = (2 * np.pi * xp.fft.fftfreq(N, d=dx)).astype(self._dtype)
shape = [1] * self.ndim
shape[axis] = N
self.k_list.append(k.reshape(shape))
Expand Down Expand Up @@ -424,7 +465,7 @@ def _alpha_neper_and_power(self):
if not _is_enabled(getattr(self.medium, "alpha_coeff", 0)):
return None, None
alpha_power = float(self.xp.array(getattr(self.medium, "alpha_power", 1.5)).flatten()[0])
alpha_coeff = _expand_to_grid(self.medium.alpha_coeff, self.grid_shape, self.xp, "alpha_coeff")
alpha_coeff = _expand_to_grid(self.medium.alpha_coeff, self.grid_shape, self.xp, "alpha_coeff", dtype=self._dtype)
return db2neper(alpha_coeff, alpha_power), alpha_power

def _init_absorption(self, alpha_np, alpha_power):
Expand Down Expand Up @@ -504,9 +545,9 @@ def _init_nonlinearity(self):
self._nonlinearity = lambda rho: 0
self._nl_factor = lambda rho_split: 1.0
else:
self.BonA = _expand_to_grid(BonA_raw, self.grid_shape, self.xp, "BonA")
self.BonA = _expand_to_grid(BonA_raw, self.grid_shape, self.xp, "BonA", dtype=self._dtype)
self._nonlinearity = lambda rho: self.BonA * rho**2 / (2 * self.rho0)
self._nl_factor = lambda rho_split: (2 * sum(rho_split) + self.rho0) / self.rho0
self._nl_factor = lambda rho_split: (2 * _array_sum(rho_split) + self.rho0) / self.rho0

def _setup_source_operators(self):
"""Build time-varying source injection operators.
Expand Down Expand Up @@ -560,6 +601,7 @@ def build_op(mask_raw, signal_raw, mode, scale):
grid_size=grid_size,
source_kappa=self.source_kappa,
diff_fn=self._diff,
dtype=self._dtype,
)

# --- Pressure source (per-axis spacing for non-isotropic grids) ---
Expand Down Expand Up @@ -606,10 +648,10 @@ def _setup_fields(self):
"""Initialize pressure, velocity, and density fields."""
xp = self.xp

self.p = xp.zeros(self.grid_shape, dtype=float)
self.u = [xp.zeros(self.grid_shape, dtype=float) for _ in range(self.ndim)]
self.p = xp.zeros(self.grid_shape, dtype=self._dtype)
self.u = [xp.zeros(self.grid_shape, dtype=self._dtype) for _ in range(self.ndim)]
# Split density per dimension enables independent PML absorption in each direction
self.rho_split = [xp.zeros(self.grid_shape, dtype=float) for _ in range(self.ndim)]
self.rho_split = [xp.zeros(self.grid_shape, dtype=self._dtype) for _ in range(self.ndim)]

if self.use_sg:
self.rho0_staggered = [self._stagger(self.rho0, axis) for axis in range(self.ndim)]
Expand All @@ -623,25 +665,25 @@ def _setup_fields(self):
# Sensor data storage (sized based on record_start_index)
self.sensor_data = {}
if "p" in self.record:
self.sensor_data["p"] = xp.zeros((self.n_sensor_points, self.num_recorded_time_points), dtype=float)
self.sensor_data["p"] = xp.zeros((self.n_sensor_points, self.num_recorded_time_points), dtype=self._dtype)
for a in "xyz"[: self.ndim]:
for suffix in ("", "_staggered"):
v = f"u{a}{suffix}"
if v in self.record:
self.sensor_data[v] = xp.zeros((self.n_sensor_points, self.num_recorded_time_points), dtype=float)
self.sensor_data[v] = xp.zeros((self.n_sensor_points, self.num_recorded_time_points), dtype=self._dtype)

# Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
Comment on lines 675 to 676
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 unstagger_ops also built with 1j * literal → always complex128 on numpy < 2

The same Python-literal promotion issue as op_grad_list/op_div_list applies here: -1j * self.k_list[ax] is complex128 on numpy < 2, regardless of k_list dtype. Cast to self._complex_dtype after construction to keep the unstaggering FFT in the correct precision.

Suggested change
# Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
# Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2).astype(self._complex_dtype) for ax in range(self.ndim)]

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verified empirically on numpy 1.26.4: -1j * k_f32 * spacing / 2 evaluates to complex64 and np.exp(...) preserves complex64 — Python complex is a weak type just like Python float/int, so the multiplication chain stays in complex64 when k_list is float32. So unstagger_ops is already complex64 without an explicit cast.

Additionally, the prior commit (22ec78d) already added .astype(self._dtype, copy=False) to the staggered ifftn .real line in step(), so the recorded output stays float32 even if unstagger_ops were ever to change precision.

Added a regression test (d545952) that records ('p', 'ux_staggered', 'uy_staggered') with dtype=np.float32 and asserts each is float32 — locks the behavior in for both numpy 1.26.4 and 2.2.6.

@greptile-apps re-review

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The empirical result doesn't match what NEP 50's own migration table documents. The NEP 50 spec explicitly shows the old (numpy < 2) vs new (numpy ≥ 2 / NEP 50) behaviour:

Expression numpy < 2 (old) numpy ≥ 2 (NEP 50)
(3j + array(3, complex64)).dtype complex128 complex64
(float32(1) + 1j).dtype complex128 complex64

Source: https://numpy.org/neps/nep-0050-scalar-promotion.html

Python complex literals are not weak types in numpy < 2 — they behave like np.complex128 in the type promotion system. The "weak scalar" treatment is exactly what NEP 50 introduced. So on numpy 1.26.4, -1j * k_f32 * spacing / 2complex128, and np.exp(...) of a complex128 array → complex128.

It's possible the empirical test ran against a numpy 2.x environment, or that k_list[ax] was already float64 at the point of measurement (defeating the test). The regression test added in d545952 checks recorded sensor_data["p"] and "ux_staggered" dtypes, but those go through the .astype(self._dtype, copy=False) guard added in step() — so they'll pass regardless of what unstagger_ops itself is. The dtype of unstagger_ops isn't directly asserted.

The fix remains: cast the list comprehension result to self._complex_dtype:

Suggested change
# Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2).astype(self._complex_dtype) for ax in range(self.ndim)]

Tip: You can give me more context with .greptile/rules.md so I get this right next time.


# Initial pressure source (p0)
p0_raw = getattr(self.source, "p0", 0)
if _is_enabled(p0_raw):
p0 = _expand_to_grid(p0_raw, self.grid_shape, xp, "p0")
p0 = _expand_to_grid(p0_raw, self.grid_shape, xp, "p0", dtype=self._dtype)
if self.smooth_p0 and self.ndim >= 2:
from kwave.utils.filters import smooth

# smooth() is order-agnostic (uses FFT on shape)
p0 = xp.asarray(smooth(_to_cpu(p0), restore_max=True))
p0 = xp.asarray(smooth(_to_cpu(p0), restore_max=True), dtype=self._dtype)
self._p0_initial = p0
else:
self._p0_initial = None
Expand All @@ -657,16 +699,16 @@ def step(self):

# Momentum equation: du_i/dt = -grad_i(p)/rho, with PML
# Share forward FFT of p across all gradient axes
P = xp.fft.fftn(self.p)
P = xp.fft.fftn(self.p).astype(self._complex_dtype, copy=False)
for i in range(self.ndim):
pml_sg = self.pml_sg_list[i]
grad_p_i = xp.real(xp.fft.ifftn(self.op_grad_list[i] * P))
grad_p_i = xp.real(xp.fft.ifftn(self.op_grad_list[i] * P)).astype(self._dtype, copy=False)
self.u[i] = pml_sg * (pml_sg * self.u[i] - self.dt_over_rho0[i] * grad_p_i)
self.u[i] = self._source_u_ops[i](self.t, self.u[i])

# Mass conservation: drho_i/dt = -rho0 * div_i(u_i) * nl_factor, with PML
nl_factor = self._nl_factor(self.rho_split)
div_u_total = xp.zeros(self.grid_shape, dtype=float)
div_u_total = xp.zeros(self.grid_shape, dtype=self._dtype)
for i in range(self.ndim):
pml = self.pml_list[i]
div_u_i = self._diff(self.u[i], self.op_div_list[i])
Expand All @@ -675,7 +717,7 @@ def step(self):
self.rho_split[i] = self._source_p_ops[i](self.t, self.rho_split[i])

# Equation of state: p = c0^2 * (rho + absorption - dispersion + nonlinearity)
rho_total = sum(self.rho_split)
rho_total = _array_sum(self.rho_split)
self.p = self.c0_sq * (
rho_total + self._absorption(div_u_total) - self._dispersion(rho_total) + self._nonlinearity(rho_total)
)
Expand All @@ -697,7 +739,7 @@ def step(self):
self.sensor_data["p"][:, file_index] = self._extract(self.p)
for i, a in enumerate("xyz"[: self.ndim]):
if f"u{a}" in self.sensor_data: # non-staggered (collocated with pressure)
shifted = xp.real(xp.fft.ifftn(self.unstagger_ops[i] * xp.fft.fftn(self.u[i])))
shifted = xp.real(xp.fft.ifftn(self.unstagger_ops[i] * xp.fft.fftn(self.u[i]))).astype(self._dtype, copy=False)
self.sensor_data[f"u{a}"][:, file_index] = self._extract(shifted)
if f"u{a}_staggered" in self.sensor_data: # raw staggered grid
self.sensor_data[f"u{a}_staggered"][:, file_index] = self._extract(self.u[i])
Expand Down Expand Up @@ -751,7 +793,7 @@ def _diff(self, f, op):
if op is None:
return f
xp = self.xp
return xp.real(xp.fft.ifftn(op * xp.fft.fftn(f)))
return xp.real(xp.fft.ifftn(op * xp.fft.fftn(f))).astype(self._dtype, copy=False)

def _stagger(self, arr, axis):
"""Compute staggered grid values (average neighbors along axis)."""
Expand Down
Loading
Loading