From 00e575bc673172990539e5ad7015ae26c826ecee Mon Sep 17 00:00:00 2001 From: samueljwu <56311527+samueljwu@users.noreply.github.com> Date: Fri, 1 May 2026 22:02:31 +0800 Subject: [PATCH] Add focus ordering and ispta balancing OpenwaterHealth/openlifu-python#449 --- src/openlifu/bf/sequence.py | 15 +++- src/openlifu/plan/protocol.py | 23 +++++-- src/openlifu/plan/solution.py | 126 +++++++++++++++++++++++++++++++--- tests/test_protocol.py | 41 +++++++++++ tests/test_sequence.py | 18 ++++- tests/test_solution.py | 55 +++++++++++++++ 6 files changed, 264 insertions(+), 14 deletions(-) diff --git a/src/openlifu/bf/sequence.py b/src/openlifu/bf/sequence.py index a5601075..882c551f 100644 --- a/src/openlifu/bf/sequence.py +++ b/src/openlifu/bf/sequence.py @@ -27,6 +27,9 @@ class Sequence(DictMixin): pulse_train_count: Annotated[int, OpenLIFUFieldData("Pulse train count", "Number of pulse trains in the sequence")] = 1 """Number of pulse trains in the sequence""" + focus_order: Annotated[list[int] | None, OpenLIFUFieldData("Focus order", "Optional focus index order for each pulse")] = None + """Optional focus index order for each pulse""" + def __post_init__(self): if self.pulse_interval <= 0: raise ValueError("Pulse interval must be positive") @@ -38,6 +41,15 @@ def __post_init__(self): raise ValueError("Pulse train interval must be greater than or equal to the total pulse interval") if self.pulse_train_count <= 0: raise ValueError("Pulse train count must be positive") + if self.focus_order is not None: + if len(self.focus_order) == 0: + raise ValueError("Focus order must not be empty") + if len(self.focus_order) != self.pulse_count: + raise ValueError("Focus order length must match pulse count") + if any(not isinstance(focus_index, int) for focus_index in self.focus_order): + raise TypeError("Focus order entries must be integers") + if any(focus_index < 1 for focus_index in self.focus_order): + raise ValueError("Focus order entries must be positive") def to_table(self) -> pd.DataFrame: """ @@ -49,7 +61,8 @@ def to_table(self) -> pd.DataFrame: {"Name": "Pulse Interval", "Value": self.pulse_interval, "Unit": "s"}, {"Name": "Pulse Count", "Value": self.pulse_count, "Unit": ""}, {"Name": "Pulse Train Interval", "Value": self.pulse_train_interval, "Unit": "s"}, - {"Name": "Pulse Train Count", "Value": self.pulse_train_count, "Unit": ""} + {"Name": "Pulse Train Count", "Value": self.pulse_train_count, "Unit": ""}, + {"Name": "Focus Order", "Value": self.focus_order, "Unit": ""} ] return pd.DataFrame.from_records(records) diff --git a/src/openlifu/plan/protocol.py b/src/openlifu/plan/protocol.py index b1635e84..25bc9ea7 100644 --- a/src/openlifu/plan/protocol.py +++ b/src/openlifu/plan/protocol.py @@ -76,6 +76,9 @@ class Protocol: virtual_fit_options: Annotated[VirtualFitOptions, OpenLIFUFieldData("Virtual fit options", "Configuration of the virtual fit algorithm")] = field(default_factory=VirtualFitOptions) """Configuration of the virtual fit algorithm""" + scaling_options: Annotated[dict, OpenLIFUFieldData("Scaling options", "Options to adjust solution scaling. By default, no additional scaling options are applied")] = field(default_factory=dict) + """Options to adjust solution scaling. By default, no additional scaling options are applied""" + def __post_init__(self): self.logger = logging.getLogger(__name__) @@ -97,6 +100,7 @@ def from_dict(d : Dict[str,Any]) -> Protocol: if "virtual_fit_options" in d: d['virtual_fit_options'] = VirtualFitOptions.from_dict(d['virtual_fit_options']) d["analysis_options"] = SolutionAnalysisOptions.from_dict(d.get("analysis_options", {})) + d["scaling_options"] = d.get("scaling_options", {}) return Protocol(**d) def to_dict(self): @@ -116,6 +120,7 @@ def to_dict(self): "target_constraints": [tc.to_dict() for tc in self.target_constraints], "virtual_fit_options": self.virtual_fit_options.to_dict(), "analysis_options": self.analysis_options.to_dict(), + "scaling_options": self.scaling_options, } @staticmethod @@ -316,8 +321,11 @@ def calc_solution( simulation_result_aggregated: xa.Dataset = xa.Dataset() foci: List[Point] = self.focal_pattern.get_targets(target) + if self.sequence.focus_order is not None and max(self.sequence.focus_order) > len(foci): + raise ValueError(f"Focus order index {max(self.sequence.focus_order)} exceeds number of foci ({len(foci)})") + # updating solution sequence if pulse mismatch - if (self.sequence.pulse_count % len(foci)) != 0: + if self.sequence.focus_order is None and (self.sequence.pulse_count % len(foci)) != 0: self.fix_pulse_mismatch(on_pulse_mismatch, foci) # run simulation and aggregate the results for focus in foci: @@ -364,14 +372,21 @@ def calc_solution( raise ValueError(f"Cannot scale solution {solution.id} if simulation is not enabled!") self.logger.info(f"Scaling solution {solution.id}...") #TODO can analysis be an attribute of solution ? - solution.scale(self.focal_pattern, analysis_options=analysis_options) + solution.scale(self.focal_pattern, analysis_options=analysis_options, **self.scaling_options) if simulate: # Finally the resulting pressure is max-aggregated and intensity is mean-aggregated, over all focus points . pnp_aggregated = solution.simulation_result['p_min'].max(dim="focal_point_index", keep_attrs=True) ppp_aggregated = solution.simulation_result['p_max'].max(dim="focal_point_index", keep_attrs=True) - # TODO: Ensure this mean is weighted by the number of times each point is focused on, once openlifu supports hitting points different numbers of times - intensity_aggregated = solution.simulation_result['intensity'].mean(dim="focal_point_index", keep_attrs=True) + focus_counts = solution.get_focus_counts() + focus_weights = xa.DataArray( + focus_counts / np.sum(focus_counts), + dims=("focal_point_index",), + coords={"focal_point_index": solution.simulation_result.coords["focal_point_index"]}, + ) + intensity = solution.simulation_result['intensity'] + intensity_aggregated = (intensity * focus_weights).sum(dim="focal_point_index", keep_attrs=True) + intensity_aggregated.attrs.update(intensity.attrs) simulation_result_aggregated = deepcopy(solution.simulation_result) simulation_result_aggregated = simulation_result_aggregated.drop_dims("focal_point_index") simulation_result_aggregated['p_min'] = pnp_aggregated diff --git a/src/openlifu/plan/solution.py b/src/openlifu/plan/solution.py index 0ab43b5c..2a662a17 100644 --- a/src/openlifu/plan/solution.py +++ b/src/openlifu/plan/solution.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import heapq import json import logging import tempfile @@ -123,6 +124,8 @@ def __post_init__(self): raise ValueError("Pulse train interval must be greater than or equal to the total pulse interval") if self.sequence.pulse_train_count <= 0: raise ValueError("Pulse train count must be positive") + if (self.sequence.focus_order is not None and len(self.foci) > 0 and max(self.sequence.focus_order) > len(self.foci)): + raise ValueError(f"Focus order index {max(self.sequence.focus_order)} exceeds number of foci ({len(self.foci)})") if len(self.foci)>0 and self.delays is not None and self.delays.shape[0] != len(self.foci): raise ValueError(f"Delays number of foci ({self.delays.shape[0]}) does not match number of foci ({len(self.foci)})") if len(self.foci)>0 and self.apodizations is not None and self.apodizations.shape[0] != len(self.foci): @@ -138,6 +141,83 @@ def num_foci(self) -> int: """Get the number of foci""" return len(self.foci) + def get_focus_order(self) -> np.ndarray: + """Get the focus index order for each pulse.""" + if self.sequence.focus_order is not None: + return np.array(self.sequence.focus_order) + return (np.arange(self.sequence.pulse_count) - 1) % self.num_foci() + 1 + + def get_focus_counts(self) -> np.ndarray: + """Get the number of pulses assigned to each focus.""" + focus_order = self.get_focus_order() + return np.array([ + np.sum(focus_order == (focus_index + 1)) + for focus_index in range(self.num_foci()) + ]) + + def compute_balanced_focus_counts(self, balance_metric_values: np.ndarray, pulse_count: int) -> np.ndarray: + """Compute per-focus pulse counts that balance a positive per-focus metric.""" + balance_metric_values = np.asarray(balance_metric_values, dtype=float) + if balance_metric_values.shape != (self.num_foci(),): + raise ValueError(f"Balance metric must have one value per focus ({self.num_foci()})") + if pulse_count < self.num_foci(): + raise ValueError(f"Pulse count ({pulse_count}) must be greater than or equal to number of foci ({self.num_foci()})") + if np.any(~np.isfinite(balance_metric_values)) or np.any(balance_metric_values <= 0): + raise ValueError("Balance metric values must be finite and positive") + + remaining_pulses = pulse_count - self.num_foci() + counts = np.ones(self.num_foci(), dtype=int) + if remaining_pulses == 0: + return counts + + weights = 1 / balance_metric_values + ideal_extra_counts = weights / np.sum(weights) * remaining_pulses + extra_counts = np.floor(ideal_extra_counts).astype(int) + counts += extra_counts + + leftover_pulses = remaining_pulses - int(np.sum(extra_counts)) + remainders = ideal_extra_counts - extra_counts + for focus_index in np.argsort(remainders)[::-1][:leftover_pulses]: + counts[focus_index] += 1 + return counts + + def build_focus_order(self, focus_counts: np.ndarray, ordering: str = "minimize_repeats") -> list[int]: + """Build a focus order from per-focus pulse counts.""" + if ordering != "minimize_repeats": + raise ValueError(f"Unsupported focus ordering '{ordering}'") + focus_counts = np.asarray(focus_counts, dtype=int) + if focus_counts.shape != (self.num_foci(),): + raise ValueError(f"Focus counts must have one value per focus ({self.num_foci()})") + if np.any(focus_counts < 0): + raise ValueError("Focus counts must be non-negative") + + heap = [] + for focus_index, focus_count in enumerate(focus_counts): + if focus_count > 0: + heap.append((-focus_count, focus_index + 1)) + + heapq.heapify(heap) + focus_order = [] + previous_count = 0 + previous_focus_index = None + + while heap or previous_count < 0: + if not heap: + focus_order.append(previous_focus_index) + previous_count += 1 + continue + + focus_count, focus_index = heapq.heappop(heap) + focus_order.append(focus_index) + focus_count += 1 + + if previous_count < 0: + heapq.heappush(heap, (previous_count, previous_focus_index)) + + previous_count = focus_count + previous_focus_index = focus_index + return focus_order + def simulate(self, params: xa.Dataset, sim_options: SimSetup | None = None, @@ -195,7 +275,8 @@ def simulate(self, def analyze(self, simulation_result: xa.Dataset | None = None, options: SolutionAnalysisOptions = SolutionAnalysisOptions(), - param_constraints: Dict[str,ParameterConstraint] | None = None) -> SolutionAnalysis: + param_constraints: Dict[str,ParameterConstraint] | None = None, + focus_counts: np.ndarray | None = None) -> SolutionAnalysis: """Analyzes the treatment solution. Args: @@ -203,6 +284,7 @@ def analyze(self, options: A struct for solution analysis options. param_constraints: A dictionary of parameter constraints to apply to the analysis. The keys are the parameter names and the values are the ParameterConstraint objects. + focus_counts: Optional per-focus pulse counts to use for ITA calculations. Returns: A struct containing the results of the analysis. """ @@ -241,7 +323,7 @@ def analyze(self, solution_analysis.sequence_duration_s = float(self.sequence.pulse_interval * self.sequence.pulse_count * self.sequence.pulse_train_count) else: solution_analysis.sequence_duration_s = float(self.sequence.pulse_train_interval * self.sequence.pulse_train_count) - ita_mWcm2 = rescale_coords(self.get_ita(intensity=simulation_result['intensity'], units="mW/cm^2"), options.distance_units) + ita_mWcm2 = rescale_coords(self.get_ita(intensity=simulation_result['intensity'], units="mW/cm^2", focus_counts=focus_counts), options.distance_units) power_W = np.zeros(self.num_foci()) TIC = np.zeros(self.num_foci()) @@ -422,7 +504,10 @@ def compute_scaling_factors( def scale( self, focal_pattern: FocalPattern, - analysis_options: SolutionAnalysisOptions = SolutionAnalysisOptions() + analysis_options: SolutionAnalysisOptions = SolutionAnalysisOptions(), + balance_method: str | None = None, + balance_metric: str = "mainlobe_ispta_mWcm2", + ordering: str = "minimize_repeats", ) -> None: """ Scale the solution in-place to match the target pressure. @@ -430,6 +515,9 @@ def scale( Args: focal_pattern: FocalPattern analysis_options: plan.solution.SolutionAnalysisOptions + balance_method: Optional method for balancing scaled delivery. Supported: "ispta_repeats". + balance_metric: The per-focus analysis metric used for balancing. + ordering: How to order balanced repeats. Supported: "minimize_repeats". Returns: analysis_scaled: the resulting plan.solution.SolutionAnalysis from scaled solution @@ -446,6 +534,18 @@ def scale( self.apodizations[i] = self.apodizations[i]*apod_factors[i] self.voltage = v1 + if balance_method is None: + return + if balance_method != "ispta_repeats": + raise ValueError(f"Unsupported balance method '{balance_method}'") + baseline_focus_counts = np.ones(self.num_foci(), dtype=int) + scaled_analysis = self.analyze(options=analysis_options, focus_counts=baseline_focus_counts) + if not hasattr(scaled_analysis, balance_metric): + raise ValueError(f"Unknown balance metric '{balance_metric}'") + balance_metric_values = np.array(getattr(scaled_analysis, balance_metric)) + focus_counts = self.compute_balanced_focus_counts(balance_metric_values, self.sequence.pulse_count) + self.sequence.focus_order = self.build_focus_order(focus_counts, ordering=ordering) + def get_pulsetrain_dutycycle(self) -> float: """ Compute the pulse train dutycycle given a sequence. @@ -471,7 +571,12 @@ def get_sequence_dutycycle(self) -> float: sequence_duty_cycle = self.get_pulsetrain_dutycycle() * between_pulsetrain_duty_cycle return sequence_duty_cycle - def get_ita(self, intensity: xa.DataArray | None = None, units: str = "mW/cm^2") -> xa.DataArray: + def get_ita( + self, + intensity: xa.DataArray | None = None, + units: str = "mW/cm^2", + focus_counts: np.ndarray | None = None + ) -> xa.DataArray: """ Calculate the intensity-time-area product for a treatment solution. @@ -480,6 +585,7 @@ def get_ita(self, intensity: xa.DataArray | None = None, units: str = "mW/cm^2") If provided, use this intensity data array instead of the one from the simulation result. units: str Target units. Default "mW/cm^2". + focus_counts: Optional per-focus pulse counts. If not provided, use the sequence focus order. Returns: xa.DataArray @@ -491,10 +597,14 @@ def get_ita(self, intensity: xa.DataArray | None = None, units: str = "mW/cm^2") intensity_scaled = rescale_data_arr(self.simulation_result['intensity'], units) pulsetrain_dutycycle = self.get_pulsetrain_dutycycle() treatment_dutycycle = self.get_sequence_dutycycle() - pulse_seq = (np.arange(self.sequence.pulse_count) - 1) % self.num_foci() + 1 - counts = np.zeros((1, 1, 1, self.num_foci())) - for i in range(self.num_foci()): - counts[0, 0, 0, i] = np.sum(pulse_seq == (i+1)) + if focus_counts is None: + focus_counts = self.get_focus_counts() + focus_counts = np.asarray(focus_counts) + if focus_counts.shape != (self.num_foci(),): + raise ValueError(f"Focus counts must have one value per focus ({self.num_foci()})") + if np.any(focus_counts < 0): + raise ValueError("Focus counts must be non-negative") + counts = focus_counts.reshape((1, 1, 1, self.num_foci())) intensity = intensity_scaled.copy(deep=True) isppa_avg = np.sum(np.expand_dims(intensity.data, axis=-1) * counts, axis=-1) / np.sum(counts) intensity.data = isppa_avg * pulsetrain_dutycycle * treatment_dutycycle diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 3e0d37b1..d4918357 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -3,7 +3,9 @@ import logging from pathlib import Path +import numpy as np import pytest +import xarray as xa from openlifu import Protocol, Transducer from openlifu.bf.focal_patterns import Wheel @@ -29,6 +31,11 @@ def example_wheel_pattern() -> Wheel: return Wheel(num_spokes=6) def test_to_dict_from_dict(example_protocol: Protocol): + example_protocol.scaling_options = { + "balance_method": "ispta_repeats", + "balance_metric": "mainlobe_ispta_mWcm2", + "ordering": "minimize_repeats", + } proto_dict = example_protocol.to_dict() new_protocol = Protocol.from_dict(proto_dict) assert new_protocol == example_protocol @@ -106,3 +113,37 @@ def test_fix_pulse_mismatch( assert example_protocol.sequence.pulse_count == 2*num_foci elif on_pulse_mismatch is OnPulseMismatchAction.ROUNDDOWN: assert example_protocol.sequence.pulse_count == num_foci + + +def test_calc_solution_skips_pulse_mismatch_when_focus_order_present( + example_protocol: Protocol, + example_transducer: Transducer, + example_session: Session, + mocker + ): + """Test explicit focus_order allows pulse counts that are not divisible by number of foci.""" + example_protocol.focal_pattern = Wheel(num_spokes=3) + num_foci = example_protocol.focal_pattern.num_foci() + example_protocol.sequence.pulse_count = 5 + example_protocol.sequence.focus_order = [1, 2, 3, 1, 2] + beamform_mock = mocker.patch.object( + example_protocol, + "beamform", + return_value=(np.zeros(len(example_transducer.elements)), np.ones(len(example_transducer.elements))), + ) + fix_pulse_mismatch_mock = mocker.patch.object(example_protocol, "fix_pulse_mismatch") + + solution, simulation_result_aggregated, solution_analysis = example_protocol.calc_solution( + target=example_session.targets[0], + transducer=example_transducer, + params=xa.Dataset(), + simulate=False, + scale=False, + ) + + assert solution.sequence.focus_order == [1, 2, 3, 1, 2] + assert solution.sequence.pulse_count == 5 + assert beamform_mock.call_count == num_foci + fix_pulse_mismatch_mock.assert_not_called() + assert simulation_result_aggregated is None + assert solution_analysis is None diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 6458ff62..e7ab1cd4 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,5 +1,6 @@ from __future__ import annotations +import pytest from helpers import dataclasses_are_equal from openlifu import Sequence @@ -7,6 +8,21 @@ def test_dict_undict_sequence(): """Test that conversion between Sequence and dict works""" - sequence = Sequence(pulse_interval=2, pulse_count=5, pulse_train_interval=11, pulse_train_count=3) + sequence = Sequence(pulse_interval=2, pulse_count=5, pulse_train_interval=11, pulse_train_count=3, focus_order=[1, 2, 1, 3, 2]) reconstructed_sequence = Sequence.from_dict(sequence.to_dict()) assert dataclasses_are_equal(sequence, reconstructed_sequence) + +@pytest.mark.parametrize( + ("focus_order", "error_type", "match"), + [ + ([], ValueError, "must not be empty"), + ([1, 2], ValueError, "length must match pulse count"), + ([1, 2, 1.5], TypeError, "entries must be integers"), + ([1, 2, 0], ValueError, "entries must be positive"), + ([1, 2, -1], ValueError, "entries must be positive"), + ], +) +def test_sequence_focus_order_validation(focus_order, error_type, match): + """Test validation of focus_order values.""" + with pytest.raises(error_type, match=match): + Sequence(pulse_count=3, pulse_train_interval=3, focus_order=focus_order) diff --git a/tests/test_solution.py b/tests/test_solution.py index 56f85a76..1e666ab4 100644 --- a/tests/test_solution.py +++ b/tests/test_solution.py @@ -141,6 +141,61 @@ def test_num_foci(example_solution:Solution): assert example_solution.delays.shape[0] == num_foci assert example_solution.apodizations.shape[0] == num_foci + +def test_solution_get_focus_counts_with_explicit_focus_order(example_solution: Solution): + """Test that explicit focus_order is counted per focus.""" + example_solution.foci = [ + Point(id="focus_1"), + Point(id="focus_2"), + Point(id="focus_3"), + ] + example_solution.sequence = Sequence( + pulse_count=6, + pulse_interval=1, + pulse_train_interval=6, + focus_order=[1, 3, 2, 1, 3, 1], + ) + + np.testing.assert_array_equal(example_solution.get_focus_counts(), np.array([3, 1, 2])) + + +def test_solution_compute_balanced_focus_counts_preserves_total_and_weights_inverse_metric(example_solution: Solution): + """Test ISPTA balancing count allocation preserves pulse count and favors lower metrics.""" + example_solution.foci = [ + Point(id="focus_1"), + Point(id="focus_2"), + Point(id="focus_3"), + ] + + focus_counts = example_solution.compute_balanced_focus_counts( + balance_metric_values=np.array([10, 20, 10]), + pulse_count=8, + ) + + assert np.sum(focus_counts) == 8 + assert np.all(focus_counts >= 1) + assert focus_counts[1] < focus_counts[0] + assert focus_counts[1] < focus_counts[2] + + +def test_solution_build_focus_order_minimizes_repeats(example_solution: Solution): + """Test focus order construction preserves counts and avoids repeats when possible.""" + example_solution.foci = [ + Point(id="focus_1"), + Point(id="focus_2"), + Point(id="focus_3"), + ] + focus_counts = np.array([3, 2, 3]) + + focus_order = example_solution.build_focus_order(focus_counts) + + assert len(focus_order) == np.sum(focus_counts) + np.testing.assert_array_equal( + np.bincount(focus_order, minlength=example_solution.num_foci() + 1)[1:], + focus_counts, + ) + assert all(focus_order[i] != focus_order[i + 1] for i in range(len(focus_order) - 1)) + @pytest.mark.parametrize("compact_representation", [True, False]) def test_json_serialize_deserialize_solution_analysis(compact_representation: bool): """Verify that turning a SolutionAnalysis into json and then re-constructing it gets back to the original"""