From c9346d84750b1190bf80a647b1557a6442842f0c Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Apr 2026 12:09:51 -0500 Subject: [PATCH 01/10] test: remove setting maxk since this now works too? --- tests/jax/test_interpolatedimage_utils.py | 6 +- tests/jax/test_metacal_jax.py | 73 ----------------------- 2 files changed, 2 insertions(+), 77 deletions(-) diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 97f07663..02e5dac1 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -234,8 +234,7 @@ def test_interpolatedimage_utils_comp_to_galsim( ) np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) - # FIXME: match maxk - np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0.5, atol=0) + np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0, atol=1e-6) assert jgii.maxk >= gii.maxk kxvals = [ (0, 0), @@ -485,8 +484,7 @@ def test_interpolatedimage_utils_comp_stepk_maxk_to_galsim( np.testing.assert_allclose(gR, lgR, rtol=0, atol=1e-6) np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) - # FIXME: make maxk match - np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0.5, atol=0) + np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0, atol=1e-6) # this is a copy of the galsim C++ algorithm in a pure python diff --git a/tests/jax/test_metacal_jax.py b/tests/jax/test_metacal_jax.py index 9e2aeb25..a7f0041a 100644 --- a/tests/jax/test_metacal_jax.py +++ b/tests/jax/test_metacal_jax.py @@ -17,28 +17,22 @@ def _metacal_galsim( scale, target_fwhm, g1, - iim_kwargs, - ipsf_kwargs, - inse_kwargs, nk, ): iim = _galsim.InterpolatedImage( _galsim.ImageD(im), scale=scale, x_interpolant="lanczos15", - **iim_kwargs, ) ipsf = _galsim.InterpolatedImage( _galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15", - **ipsf_kwargs, ) inse = _galsim.InterpolatedImage( _galsim.ImageD(np.rot90(nse_im, 1)), scale=scale, x_interpolant="lanczos15", - **inse_kwargs, ) ppsf_iim = _galsim.Convolve(iim, _galsim.Deconvolve(ipsf)) @@ -153,34 +147,6 @@ def test_metacal_comp_to_galsim(nse): nse_im = rng.normal(size=im.shape, scale=nse) im += rng.normal(size=im.shape, scale=nse) - # jax galsim and galsim set stepk and maxk differently due to slight - # algorithmic differences. We force them to be the same here for this - # test so it passes. - iim = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(im), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=128), - ) - iim_kwargs = { - "_force_maxk": iim.maxk.item(), - } - inse = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(jnp.rot90(nse_im, 1)), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=128), - ) - inse_kwargs = { - "_force_maxk": inse.maxk.item(), - } - ipsf = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" - ) - ipsf_kwargs = { - "_force_maxk": ipsf.maxk.item(), - } - gt0 = time.time() gres = _metacal_galsim( im.copy(), @@ -189,9 +155,6 @@ def test_metacal_comp_to_galsim(nse): scale, target_fwhm, g1, - iim_kwargs, - ipsf_kwargs, - inse_kwargs, 128, ) gt0 = time.time() - gt0 @@ -252,7 +215,6 @@ def test_metacal_vmap(ntest): ims = [] nse_ims = [] psfs = [] - init_done = False for _seed in range(ntest): seed = _seed + start_seed rng = np.random.RandomState(seed) @@ -287,37 +249,6 @@ def test_metacal_vmap(ntest): psfs.append(psf) nse_ims.append(nse_im) - if not init_done: - init_done = True - - # jax galsim and galsim set stepk and maxk differently due to slight - # algorithmic differences. We force them to be the same here for this - # test so it passes. - iim = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(im), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=128), - ) - iim_kwargs = { - "_force_maxk": iim.maxk.item(), - } - inse = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(jnp.rot90(nse_im, 1)), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=128), - ) - inse_kwargs = { - "_force_maxk": inse.maxk.item(), - } - ipsf = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" - ) - ipsf_kwargs = { - "_force_maxk": ipsf.maxk.item(), - } - ims = np.stack(ims) psfs = np.stack(psfs) nse_ims = np.stack(nse_ims) @@ -331,9 +262,6 @@ def test_metacal_vmap(ntest): scale, target_fwhm, g1, - iim_kwargs, - ipsf_kwargs, - inse_kwargs, 128, ) gt0 = time.time() - gt0 @@ -417,7 +345,6 @@ def test_metacal_iimage_with_noise(nse, draw_method): scale=scale, x_interpolant="lanczos15", gsparams=_galsim.GSParams(minimum_fft_size=nk), - _force_maxk=jgiim.maxk.item(), ) def _plot_real(gim, jgim): From 1b8f8f6f2db150fc422d176bd45d8a95781c87ee Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Apr 2026 12:14:36 -0500 Subject: [PATCH 02/10] test: rename test to make purpose clearer --- tests/jax/test_interpolatedimage_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 02e5dac1..67a8e7b5 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -182,7 +182,7 @@ def test_interpolatedimage_utils_stepk_maxk(): ], ) @pytest.mark.parametrize("method", ["kValue", "xValue"]) -def test_interpolatedimage_utils_comp_to_galsim( +def test_interpolatedimage_utils_comp_to_galsim_xk_value( method, ref_array, offset_x, @@ -235,7 +235,6 @@ def test_interpolatedimage_utils_comp_to_galsim( np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0, atol=1e-6) - assert jgii.maxk >= gii.maxk kxvals = [ (0, 0), (-5, -5), From 86d669bf38de552919eedd427c5539092ca91aff Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Apr 2026 15:43:20 -0500 Subject: [PATCH 03/10] test: run new test suite --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 1a490c3b..23368789 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 1a490c3b558fddf2cab1fc0e6d449b73fa3b4eda +Subproject commit 2336878979462fc726fabc7fbc5a89fba4ef2648 From 5a7b1c35e18a509fad4f234fcd5b1cf32c2e11c7 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Apr 2026 15:53:30 -0500 Subject: [PATCH 04/10] test: tighter tolerance --- tests/jax/test_moffat_comp_galsim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index 20d026e8..a6a44910 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -46,7 +46,7 @@ def test_moffat_comp_galsim_maxk(psf, thresh): f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}", flush=True, ) - np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=1e-4, atol=0) + np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=5e-5, atol=0) np.testing.assert_allclose( psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5, atol=1e-5 ) From 8902ea63d0e5b80a4b06b0803f8eab948275a1bb Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Apr 2026 05:53:36 -0500 Subject: [PATCH 05/10] feat: add galsim maxk version --- jax_galsim/interpolatedimage.py | 88 ++++++++++++++++++++++++++++++--- tests/jax/test_metacal_jax.py | 18 ++++--- 2 files changed, 91 insertions(+), 15 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index d652e8a9..71cf02c2 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1214,7 +1214,86 @@ def _inner_comp_find_maxk(arr, thresh, kx, ky): -jnp.inf, ) ) - return jnp.maximum(max_kx, max_ky) + # galsim adds one pixel at the end so that maxk is + # the k value where things do not pass the threshold, + # so we do that here too. + return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0] + + +# I am keeping this bit of dead code for posterity. +# It uses a memory-intensive algorithm that matches galsim's +# approach for maxk. In practice, the direct search +# above has MUCH better performance. The code below won't even run +# on my laptop for parts of the galsim test suite due to memory +# use. - MRB +# @jax.jit +# def _inner_comp_find_maxk(a, thresh, kx, ky): +# # Galsim searches the boundary pixels of increasing squares +# # at a given distance. It finds the location of the start of +# # the first five boundaries where none of the pixels are above +# # threshold. To do this operation in JAX, we first compute +# # a cumulative version which is the number of pixels within a +# # distance that are above threshold. Then we search for the +# # location of the first string of five adjacent distances where the +# # value stays the same. +# +# # step 0) compute # of pixels within a distance above threshold +# # This bit of code builds an array of shape (ny, nx, # of distances) +# # that holds the number of pixels with |kx| <= distance, +# # |ky| <= distance, and value |a|^2 > thresh^2. +# # We build the array by broadcasting. +# val = (a * a.conjugate()).real +# val = jnp.reshape(val, (a.shape[0], a.shape[1], 1)) +# +# # we add one pixel scale here so that we have the distances +# # starting at one pixel scale. GalSim's version adds a pixel +# # scale at the end, so that matches. +# scale = kx[0, -1] - kx[0, -2] +# dk = jnp.reshape(kx[0, :] + scale, (1, 1, -1)) +# +# msk_kx = jnp.reshape(jnp.abs(kx), (a.shape[0], a.shape[1], 1)) <= dk +# msk_ky = jnp.reshape(jnp.abs(ky), (a.shape[0], a.shape[1], 1)) <= dk +# +# msk = (val > thresh * thresh) & msk_kx & msk_ky +# msk = jnp.sum(msk, axis=(0, 1)) +# +# # Our approach to finding the location of the five adjacent +# # values that are the same is as follows. +# # We put the array through jnp.diff, which +# # computes the difference of adjacent elements. Then we convolve +# # with a filter of ones of length five to sum groups of five +# # elements together. The first location where the result is +# # zero is the location we want. The tricky bit however is getting +# # the indexing right. +# +# # step 1) compute the diff of adjacent elements +# # The function jnp.diff returns an array of size one less than +# # the input. So we concatenate a zero at the front. This makes +# # sense since if the original array is all constant, then the +# # location of the first five zeros is at the start of the array. +# delta_msk = jnp.concatenate( +# [jnp.array([0], dtype=int), jnp.diff(msk)], +# axis=0, +# dtype=int, +# ) +# +# # step 2) convolve with the filter +# # In the discrete convolution, you have to deal with edge +# # behavior where the filter only partially overlaps the arrays. +# # We use the mode `full` which returns an array containing +# # every possible combination with missing elements set to zero. +# # We cut the first `length of filter - 1` elements so that +# # index i of the result is the sum of the filter starting +# # at index i of the input. +# sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:] +# +# # step 3) find first location of zero in the convolution +# # Finally, we use jnp.argmin to find the location of the first +# # zero. Per the doc string, if there is more than one zero, this +# # function returns the first location (i.e., smallest index) +# # which is what we want. +# sind = jnp.argmin(jnp.where(sums == 0, 0, 1)) +# return dk[0, 0, sind] @jax.jit @@ -1226,11 +1305,6 @@ def _find_maxk(kim, max_maxk, thresh): # maxk from the image (computed by _inner_comp_find_maxk) # by max_maxk from above return jnp.minimum( - # jax-galsim tends to be less conservative for maxk - # since compared to galsim, it does NOT require 5 rows - # of pixels in a row below the threshold. - # thus we add pixels here to ensure the galsim tests pass. - # it turns out one worked ok so that is what we did. - MRB - _inner_comp_find_maxk(kim.array, thresh, kx, ky) + 1 * kim.scale, + _inner_comp_find_maxk(kim.array, thresh, kx, ky), max_maxk, ) diff --git a/tests/jax/test_metacal_jax.py b/tests/jax/test_metacal_jax.py index a7f0041a..eaa6f947 100644 --- a/tests/jax/test_metacal_jax.py +++ b/tests/jax/test_metacal_jax.py @@ -282,14 +282,16 @@ def test_metacal_vmap(ntest): msg = "jit" jgt0 = time.time() - vmap_mcal( - ims, - psfs, - nse_ims, - scale, - target_fwhm, - g1, - 128, + jax.block_until_ready( + vmap_mcal( + ims, + psfs, + nse_ims, + scale, + target_fwhm, + g1, + 128, + ) ) jgt0 = time.time() - jgt0 print("Jax-Galsim time (%s): " % msg, jgt0 * 1e3, " [ms]") From d6f791129c2192f977c9fde85e1cc9cdcd81481d Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Apr 2026 09:59:34 -0500 Subject: [PATCH 06/10] feat: use a scan op to match galsim exactly --- jax_galsim/interpolatedimage.py | 179 +++++++++++++++----------------- 1 file changed, 84 insertions(+), 95 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 71cf02c2..f2ca0153 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1197,103 +1197,92 @@ def _calculate_size_containing_flux(image, thresh): ) +# MRB - This version usually matches galsim, but not always. +# The version below that uses the scan op is equivalent to +# galsim's approach. +# @jax.jit +# def _inner_comp_find_maxk(arr, thresh, kx, ky): +# msk = (arr * arr.conjugate()).real > thresh * thresh +# max_kx = jnp.max( +# jnp.where( +# msk, +# jnp.abs(kx), +# -jnp.inf, +# ) +# ) +# max_ky = jnp.max( +# jnp.where( +# msk, +# jnp.abs(ky), +# -jnp.inf, +# ) +# ) +# # galsim adds one pixel at the end so that maxk is +# # the k value where things do not pass the threshold, +# # so we do that here too. +# return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0] + + @jax.jit -def _inner_comp_find_maxk(arr, thresh, kx, ky): - msk = (arr * arr.conjugate()).real > thresh * thresh - max_kx = jnp.max( - jnp.where( - msk, - jnp.abs(kx), - -jnp.inf, - ) +def _inner_comp_find_maxk_scan(arr, thresh, kx, ky): + val = (arr * arr.conjugate()).real + msk_thresh = val > thresh * thresh + akx = jnp.abs(kx) + aky = jnp.abs(ky) + + def _func(carry, x): + msk_kx = akx <= x + msk_ky = aky <= x + return carry, jnp.sum(msk_thresh & msk_kx & msk_ky) + + scale = kx[0, -1] - kx[0, -2] + dk = kx[0, :] + scale + _, msk = jax.lax.scan(_func, None, xs=dk) + + # We are searching for the location of the first string of + # five locations in a row in `msk` where the value stays the + # same. + # We do this by putting the array through jnp.diff, which + # computes the difference of adjacent elements. Then we convolve + # with a filter of ones of length five to sum groups of five + # elements together. The first location where the result is + # zero is the location we want. The tricky bit however is getting + # the indexing right. + + # step 1. compute the diff of adjacent elements + # The function jnp.diff returns an array of size one less than + # the input. So we concatenate a zero at the front. This makes + # sense since if the original array is all constant, then the + # location of the first five zeros is at the start of the array. + delta_msk = jnp.concatenate( + [jnp.array([0], dtype=int), jnp.diff(msk)], + axis=0, + dtype=int, ) - max_ky = jnp.max( - jnp.where( - msk, - jnp.abs(ky), - -jnp.inf, - ) + + # step 2. convolve with the filter + # In the discrete convolution, you have to deal with edge + # behavior where the filter only partially overlaps the arrays. + # We use the mode `full` which returns an array containing + # every possible combination with missing elements set to zero. + # We cut the first `length of filter - 1` elements so that + # index i of the result is the sum of the filter starting + # at index i of the input. + sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:] + + # step 3. find first location of zero in the convolution + # Finally, we use jnp.argmin to find the location of the first + # zero. Per the doc string, if there is more than one zero, this + # function returns the first location (i.e., smallest index) + # which is what we want. + msk_zero = sums == 0 + sind = jax.lax.cond( + jnp.any(msk_zero), + lambda x: jnp.argmin(jnp.where(x == 0, 0, 1)), + lambda x: -1, + sums, ) - # galsim adds one pixel at the end so that maxk is - # the k value where things do not pass the threshold, - # so we do that here too. - return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0] - - -# I am keeping this bit of dead code for posterity. -# It uses a memory-intensive algorithm that matches galsim's -# approach for maxk. In practice, the direct search -# above has MUCH better performance. The code below won't even run -# on my laptop for parts of the galsim test suite due to memory -# use. - MRB -# @jax.jit -# def _inner_comp_find_maxk(a, thresh, kx, ky): -# # Galsim searches the boundary pixels of increasing squares -# # at a given distance. It finds the location of the start of -# # the first five boundaries where none of the pixels are above -# # threshold. To do this operation in JAX, we first compute -# # a cumulative version which is the number of pixels within a -# # distance that are above threshold. Then we search for the -# # location of the first string of five adjacent distances where the -# # value stays the same. -# -# # step 0) compute # of pixels within a distance above threshold -# # This bit of code builds an array of shape (ny, nx, # of distances) -# # that holds the number of pixels with |kx| <= distance, -# # |ky| <= distance, and value |a|^2 > thresh^2. -# # We build the array by broadcasting. -# val = (a * a.conjugate()).real -# val = jnp.reshape(val, (a.shape[0], a.shape[1], 1)) -# -# # we add one pixel scale here so that we have the distances -# # starting at one pixel scale. GalSim's version adds a pixel -# # scale at the end, so that matches. -# scale = kx[0, -1] - kx[0, -2] -# dk = jnp.reshape(kx[0, :] + scale, (1, 1, -1)) -# -# msk_kx = jnp.reshape(jnp.abs(kx), (a.shape[0], a.shape[1], 1)) <= dk -# msk_ky = jnp.reshape(jnp.abs(ky), (a.shape[0], a.shape[1], 1)) <= dk -# -# msk = (val > thresh * thresh) & msk_kx & msk_ky -# msk = jnp.sum(msk, axis=(0, 1)) -# -# # Our approach to finding the location of the five adjacent -# # values that are the same is as follows. -# # We put the array through jnp.diff, which -# # computes the difference of adjacent elements. Then we convolve -# # with a filter of ones of length five to sum groups of five -# # elements together. The first location where the result is -# # zero is the location we want. The tricky bit however is getting -# # the indexing right. -# -# # step 1) compute the diff of adjacent elements -# # The function jnp.diff returns an array of size one less than -# # the input. So we concatenate a zero at the front. This makes -# # sense since if the original array is all constant, then the -# # location of the first five zeros is at the start of the array. -# delta_msk = jnp.concatenate( -# [jnp.array([0], dtype=int), jnp.diff(msk)], -# axis=0, -# dtype=int, -# ) -# -# # step 2) convolve with the filter -# # In the discrete convolution, you have to deal with edge -# # behavior where the filter only partially overlaps the arrays. -# # We use the mode `full` which returns an array containing -# # every possible combination with missing elements set to zero. -# # We cut the first `length of filter - 1` elements so that -# # index i of the result is the sum of the filter starting -# # at index i of the input. -# sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:] -# -# # step 3) find first location of zero in the convolution -# # Finally, we use jnp.argmin to find the location of the first -# # zero. Per the doc string, if there is more than one zero, this -# # function returns the first location (i.e., smallest index) -# # which is what we want. -# sind = jnp.argmin(jnp.where(sums == 0, 0, 1)) -# return dk[0, 0, sind] + return dk[sind] @jax.jit @@ -1305,6 +1294,6 @@ def _find_maxk(kim, max_maxk, thresh): # maxk from the image (computed by _inner_comp_find_maxk) # by max_maxk from above return jnp.minimum( - _inner_comp_find_maxk(kim.array, thresh, kx, ky), + _inner_comp_find_maxk_scan(kim.array, thresh, kx, ky), max_maxk, ) From 9973dc27bbfc66aa85ceb0c3173a2cb30e380791 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Apr 2026 12:43:06 -0500 Subject: [PATCH 07/10] fix: adjust indexing a bit --- jax_galsim/interpolatedimage.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index f2ca0153..f306b030 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1235,9 +1235,7 @@ def _func(carry, x): msk_ky = aky <= x return carry, jnp.sum(msk_thresh & msk_kx & msk_ky) - scale = kx[0, -1] - kx[0, -2] - dk = kx[0, :] + scale - _, msk = jax.lax.scan(_func, None, xs=dk) + _, msk = jax.lax.scan(_func, None, xs=kx[0, :]) # We are searching for the location of the first string of # five locations in a row in `msk` where the value stays the @@ -1278,11 +1276,12 @@ def _func(carry, x): msk_zero = sums == 0 sind = jax.lax.cond( jnp.any(msk_zero), - lambda x: jnp.argmin(jnp.where(x == 0, 0, 1)), + lambda x: jnp.argmin(jnp.where(x, 0, 1)), lambda x: -1, - sums, + msk_zero, ) - return dk[sind] + # add pixel so we have the bound + return kx[0, sind] + kx[0, -1] - kx[0, -2] @jax.jit From fd7770f81766ae19feea05077d1913e3ba71c1f8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Apr 2026 12:37:17 -0500 Subject: [PATCH 08/10] fix: off by one only if we reach the end of the array; galsim is fine --- jax_galsim/interpolatedimage.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index f306b030..5eda8ac9 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1274,14 +1274,18 @@ def _func(carry, x): # function returns the first location (i.e., smallest index) # which is what we want. msk_zero = sums == 0 - sind = jax.lax.cond( + sind, dk = jax.lax.cond( jnp.any(msk_zero), - lambda x: jnp.argmin(jnp.where(x, 0, 1)), - lambda x: -1, + # if we find a set of zeros, the code computes the next pixel past + # the pixels where |kval| > thresh. So we set dk = 0 since we don't + # need to shift things. + lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0), + # if we get to the end of the array, we add one pixel spacing + # so we match galsim + lambda x: (-1, kx[0, -1] - kx[0, -2]), msk_zero, ) - # add pixel so we have the bound - return kx[0, sind] + kx[0, -1] - kx[0, -2] + return kx[0, sind] + dk @jax.jit From 1758be557e36b7ce1a4854fda810e874cc577dd6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Apr 2026 12:57:41 -0500 Subject: [PATCH 09/10] test: simplify tests --- tests/jax/test_interpolatedimage_utils.py | 304 +++++----------------- 1 file changed, 62 insertions(+), 242 deletions(-) diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 67a8e7b5..b45ffbd1 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -116,7 +116,7 @@ def test_interpolatedimage_utils_draw_with_interpolant_kval(interp): ) -def test_interpolatedimage_utils_stepk_maxk(): +def test_interpolatedimage_utils_stepk_maxk_simple(): hlr = 0.5 fwhm = 0.9 scale = 0.2 @@ -182,7 +182,7 @@ def test_interpolatedimage_utils_stepk_maxk(): ], ) @pytest.mark.parametrize("method", ["kValue", "xValue"]) -def test_interpolatedimage_utils_comp_to_galsim_xk_value( +def test_interpolatedimage_utils_comp_to_galsim_xkvalue_stepk_maxk( method, ref_array, offset_x, @@ -213,6 +213,11 @@ def test_interpolatedimage_utils_comp_to_galsim_xk_value( "Skipping `test_interpolatedimage_utils_comp_to_galsim` case at random to save time." ) + if rng.uniform() < 0.5: + ref_array = ref_array + rng.normal( + size=ref_array.shape, scale=0.1 * ref_array.max() + ) + gimage_in = _galsim.Image(ref_array, scale=0.2) jgimage_in = jax_galsim.Image(ref_array, scale=0.2) @@ -233,8 +238,6 @@ def test_interpolatedimage_utils_comp_to_galsim_xk_value( x_interpolant=x_interp, ) - np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) - np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0, atol=1e-6) kxvals = [ (0, 0), (-5, -5), @@ -267,6 +270,61 @@ def test_interpolatedimage_utils_comp_to_galsim_xk_value( err_msg=f"xValue mismatch: wcs={wcs}, x={x}, y={y}", ) + gthresh = (1.0 - gii.gsparams.folding_threshold) * gii._image_flux + gR = _galsim._galsim.CalculateSizeContainingFlux(gii._image._image, gthresh) + + from jax_galsim.interpolatedimage import _calculate_size_containing_flux + + jgthresh = ( + 1.0 - jgii._original.gsparams.folding_threshold + ) * jgii._original._image_flux + jgR = _calculate_size_containing_flux(jgii._original.image, jgthresh) + + lgR = _galsim_stepk_loop(gii._image, gthresh) + ljgR = _galsim_stepk_loop(jgii._original.image, jgthresh) + + np.testing.assert_allclose(jgii._original.image.center.x, gii._image.center.x) + np.testing.assert_allclose(jgii._original.image.center.y, gii._image.center.y) + np.testing.assert_allclose(jgii._original.image(0, 0), gii._image(0, 0)) + np.testing.assert_allclose(jgii._original.image.array.sum(), gii._image.array.sum()) + np.testing.assert_allclose(jgthresh, gthresh, rtol=0, atol=1e-6) + np.testing.assert_allclose(jgR, gR, rtol=0, atol=1e-6) + np.testing.assert_allclose(ljgR, gR, rtol=0, atol=1e-6) + np.testing.assert_allclose(gR, lgR, rtol=0, atol=1e-6) + + np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) + np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0, atol=1e-6) + + # test forcing stepk/maxk + maxk = gii.maxk * 1.04 + stepk = gii.stepk / 1.04 + + gii = _galsim.InterpolatedImage( + gimage_in, + wcs=wcs, + offset=_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + x_interpolant=x_interp, + _force_maxk=maxk, + _force_stepk=stepk, + ) + jgii = jax_galsim.InterpolatedImage( + jgimage_in, + wcs=jax_galsim.BaseWCS.from_galsim(wcs), + offset=jax_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + x_interpolant=x_interp, + _force_maxk=maxk, + _force_stepk=stepk, + ) + + np.testing.assert_allclose(gii.maxk, maxk) + np.testing.assert_allclose(gii.stepk, stepk) + np.testing.assert_allclose(jgii.maxk, maxk) + np.testing.assert_allclose(jgii.stepk, stepk) + def _compute_fft_with_numpy_jax_galsim(im): import numpy as np @@ -359,133 +417,6 @@ def test_interpolatedimage_interpolant_sample(interp): np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}") -@pytest.mark.parametrize("x_interp", ["lanczos15", "quintic"]) -@pytest.mark.parametrize("normalization", ["sb", "flux"]) -@pytest.mark.parametrize("use_true_center", [True, False]) -@pytest.mark.parametrize( - "wcs", - [ - _galsim.PixelScale(0.2), - _galsim.JacobianWCS(0.21, 0.03, -0.04, 0.23), - _galsim.AffineTransform(-0.03, 0.21, 0.18, 0.01, _galsim.PositionD(0.3, -0.4)), - ], -) -@pytest.mark.parametrize( - "offset_x", - [ - -4.35, - -0.45, - 0.0, - 0.67, - 3.78, - ], -) -@pytest.mark.parametrize( - "offset_y", - [ - -2.12, - -0.33, - 0.0, - 0.12, - 1.45, - ], -) -@pytest.mark.parametrize( - "ref_array", - [ - _galsim.Gaussian(fwhm=0.9) - .shear(g1=0.3, g2=-0.2) - .drawImage(nx=33, ny=33, scale=0.2) - .array, - _galsim.Gaussian(fwhm=0.9) - .shear(g1=-0.03, g2=0.1) - .drawImage(nx=32, ny=32, scale=0.2) - .array, - ], -) -def test_interpolatedimage_utils_comp_stepk_maxk_to_galsim( - ref_array, - offset_x, - offset_y, - wcs, - use_true_center, - normalization, - x_interp, -): - seed = max( - abs( - int( - hashlib.sha1( - f"{ref_array}{offset_x}{offset_y}{wcs}{use_true_center}{normalization}{x_interp}".encode( - "utf-8" - ) - ).hexdigest(), - 16, - ) - ) - % (10**7), - 1, - ) - - rng = np.random.RandomState(seed=seed) - if rng.uniform() < FRAC_TEST_TO_KEEP: - pytest.skip( - "Skipping `test_interpolatedimage_utils_comp_stepk_maxk_to_galsim` case at random to save time." - ) - - nse = rng.uniform(size=ref_array.shape) * ref_array.max() * 0.05 - - gimage_in = _galsim.Image(ref_array + nse, scale=0.2) - jgimage_in = jax_galsim.Image(ref_array + nse, scale=0.2) - - np.testing.assert_allclose(gimage_in.center.x, jgimage_in.center.x) - np.testing.assert_allclose(gimage_in.center.y, jgimage_in.center.y) - - gii = _galsim.InterpolatedImage( - gimage_in, - wcs=wcs, - offset=_galsim.PositionD(offset_x, offset_y), - use_true_center=use_true_center, - normalization=normalization, - x_interpolant=x_interp, - flux=20, - ) - jgii = jax_galsim.InterpolatedImage( - jgimage_in, - wcs=jax_galsim.BaseWCS.from_galsim(wcs), - offset=jax_galsim.PositionD(offset_x, offset_y), - use_true_center=use_true_center, - normalization=normalization, - x_interpolant=x_interp, - flux=20, - ) - - gthresh = (1.0 - gii.gsparams.folding_threshold) * gii._image_flux - gR = _galsim._galsim.CalculateSizeContainingFlux(gii._image._image, gthresh) - - from jax_galsim.interpolatedimage import _calculate_size_containing_flux - - jgthresh = ( - 1.0 - jgii._original.gsparams.folding_threshold - ) * jgii._original._image_flux - jgR = _calculate_size_containing_flux(jgii._original.image, jgthresh) - - lgR = _galsim_stepk_loop(gii._image, gthresh) - ljgR = _galsim_stepk_loop(jgii._original.image, jgthresh) - - np.testing.assert_allclose(jgii._original.image.center.x, gii._image.center.x) - np.testing.assert_allclose(jgii._original.image.center.y, gii._image.center.y) - np.testing.assert_allclose(jgii._original.image(0, 0), gii._image(0, 0)) - np.testing.assert_allclose(jgii._original.image.array.sum(), gii._image.array.sum()) - np.testing.assert_allclose(jgthresh, gthresh, rtol=0, atol=1e-6) - np.testing.assert_allclose(jgR, gR, rtol=0, atol=1e-6) - np.testing.assert_allclose(ljgR, gR, rtol=0, atol=1e-6) - np.testing.assert_allclose(gR, lgR, rtol=0, atol=1e-6) - - np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) - np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0, atol=1e-6) - - # this is a copy of the galsim C++ algorithm in a pure python # loop to help with debugging and testing def _galsim_stepk_loop(im, target_flux): @@ -515,114 +446,3 @@ def _galsim_stepk_loop(im, target_flux): d += 1 return d + 0.5 - - -@pytest.mark.parametrize("x_interp", ["lanczos15", "quintic"]) -@pytest.mark.parametrize("normalization", ["sb", "flux"]) -@pytest.mark.parametrize("use_true_center", [True, False]) -@pytest.mark.parametrize( - "wcs", - [ - _galsim.PixelScale(0.2), - _galsim.JacobianWCS(0.21, 0.03, -0.04, 0.23), - _galsim.AffineTransform(-0.03, 0.21, 0.18, 0.01, _galsim.PositionD(0.3, -0.4)), - ], -) -@pytest.mark.parametrize( - "offset_x", - [ - -4.35, - -0.45, - 0.0, - 0.67, - 3.78, - ], -) -@pytest.mark.parametrize( - "offset_y", - [ - -2.12, - -0.33, - 0.0, - 0.12, - 1.45, - ], -) -@pytest.mark.parametrize( - "ref_array", - [ - _galsim.Gaussian(fwhm=0.9).drawImage(nx=33, ny=33, scale=0.2).array, - _galsim.Gaussian(fwhm=0.9).drawImage(nx=32, ny=32, scale=0.2).array, - ], -) -@pytest.mark.parametrize("method", ["kValue", "xValue"]) -def test_interpolatedimage_utils_force_stepk_maxk( - method, - ref_array, - offset_x, - offset_y, - wcs, - use_true_center, - normalization, - x_interp, -): - seed = max( - abs( - int( - hashlib.sha1( - f"{method}{ref_array}{offset_x}{offset_y}{wcs}{use_true_center}{normalization}{x_interp}".encode( - "utf-8" - ) - ).hexdigest(), - 16, - ) - ) - % (10**7), - 1, - ) - - rng = np.random.RandomState(seed=seed) - if rng.uniform() < FRAC_TEST_TO_KEEP: - pytest.skip( - "Skipping `test_interpolatedimage_utils_force_stepk_maxk` case at random to save time." - ) - - gimage_in = _galsim.Image(ref_array, scale=0.2) - jgimage_in = jax_galsim.Image(ref_array, scale=0.2) - - gii = _galsim.InterpolatedImage( - gimage_in, - wcs=wcs, - offset=_galsim.PositionD(offset_x, offset_y), - use_true_center=use_true_center, - normalization=normalization, - x_interpolant=x_interp, - ) - maxk = gii.maxk * 1.04 - stepk = gii.stepk / 1.04 - - gii = _galsim.InterpolatedImage( - gimage_in, - wcs=wcs, - offset=_galsim.PositionD(offset_x, offset_y), - use_true_center=use_true_center, - normalization=normalization, - x_interpolant=x_interp, - _force_maxk=maxk, - _force_stepk=stepk, - ) - jgii = jax_galsim.InterpolatedImage( - jgimage_in, - wcs=jax_galsim.BaseWCS.from_galsim(wcs), - offset=jax_galsim.PositionD(offset_x, offset_y), - use_true_center=use_true_center, - normalization=normalization, - x_interpolant=x_interp, - _force_maxk=maxk, - _force_stepk=stepk, - ) - - np.testing.assert_allclose(gii.maxk, maxk) - np.testing.assert_allclose(gii.stepk, stepk) - np.testing.assert_allclose(jgii.maxk, maxk) - np.testing.assert_allclose(jgii.stepk, stepk) From ffcef75b5a14d53bc3a1a253f29515bc8345625d Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Apr 2026 13:07:43 -0500 Subject: [PATCH 10/10] perf: use faster algorithm --- jax_galsim/interpolatedimage.py | 174 ++++++++++++++++---------------- 1 file changed, 87 insertions(+), 87 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 5eda8ac9..8561c159 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1197,95 +1197,95 @@ def _calculate_size_containing_flux(image, thresh): ) -# MRB - This version usually matches galsim, but not always. -# The version below that uses the scan op is equivalent to -# galsim's approach. -# @jax.jit -# def _inner_comp_find_maxk(arr, thresh, kx, ky): -# msk = (arr * arr.conjugate()).real > thresh * thresh -# max_kx = jnp.max( -# jnp.where( -# msk, -# jnp.abs(kx), -# -jnp.inf, -# ) -# ) -# max_ky = jnp.max( -# jnp.where( -# msk, -# jnp.abs(ky), -# -jnp.inf, -# ) -# ) -# # galsim adds one pixel at the end so that maxk is -# # the k value where things do not pass the threshold, -# # so we do that here too. -# return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0] - - @jax.jit -def _inner_comp_find_maxk_scan(arr, thresh, kx, ky): - val = (arr * arr.conjugate()).real - msk_thresh = val > thresh * thresh - akx = jnp.abs(kx) - aky = jnp.abs(ky) - - def _func(carry, x): - msk_kx = akx <= x - msk_ky = aky <= x - return carry, jnp.sum(msk_thresh & msk_kx & msk_ky) - - _, msk = jax.lax.scan(_func, None, xs=kx[0, :]) - - # We are searching for the location of the first string of - # five locations in a row in `msk` where the value stays the - # same. - # We do this by putting the array through jnp.diff, which - # computes the difference of adjacent elements. Then we convolve - # with a filter of ones of length five to sum groups of five - # elements together. The first location where the result is - # zero is the location we want. The tricky bit however is getting - # the indexing right. - - # step 1. compute the diff of adjacent elements - # The function jnp.diff returns an array of size one less than - # the input. So we concatenate a zero at the front. This makes - # sense since if the original array is all constant, then the - # location of the first five zeros is at the start of the array. - delta_msk = jnp.concatenate( - [jnp.array([0], dtype=int), jnp.diff(msk)], - axis=0, - dtype=int, +def _inner_comp_find_maxk(arr, thresh, kx, ky): + msk = (arr * arr.conjugate()).real > thresh * thresh + max_kx = jnp.max( + jnp.where( + msk, + jnp.abs(kx), + -jnp.inf, + ) ) - - # step 2. convolve with the filter - # In the discrete convolution, you have to deal with edge - # behavior where the filter only partially overlaps the arrays. - # We use the mode `full` which returns an array containing - # every possible combination with missing elements set to zero. - # We cut the first `length of filter - 1` elements so that - # index i of the result is the sum of the filter starting - # at index i of the input. - sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:] - - # step 3. find first location of zero in the convolution - # Finally, we use jnp.argmin to find the location of the first - # zero. Per the doc string, if there is more than one zero, this - # function returns the first location (i.e., smallest index) - # which is what we want. - msk_zero = sums == 0 - sind, dk = jax.lax.cond( - jnp.any(msk_zero), - # if we find a set of zeros, the code computes the next pixel past - # the pixels where |kval| > thresh. So we set dk = 0 since we don't - # need to shift things. - lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0), - # if we get to the end of the array, we add one pixel spacing - # so we match galsim - lambda x: (-1, kx[0, -1] - kx[0, -2]), - msk_zero, + max_ky = jnp.max( + jnp.where( + msk, + jnp.abs(ky), + -jnp.inf, + ) ) - return kx[0, sind] + dk + # galsim adds one pixel at the end so that maxk is + # the k value where things do not pass the threshold, + # so we do that here too. + return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0] + + +# this version matches galsim's maxk operation exactly, but is +# more expensive to compute since it has a scan operation. +# I am leaving it here for posterity. - MRB +# @jax.jit +# def _inner_comp_find_maxk_scan(arr, thresh, kx, ky): +# val = (arr * arr.conjugate()).real +# msk_thresh = val > thresh * thresh +# akx = jnp.abs(kx) +# aky = jnp.abs(ky) +# +# def _func(carry, x): +# msk_kx = akx <= x +# msk_ky = aky <= x +# return carry, jnp.sum(msk_thresh & msk_kx & msk_ky) +# +# _, msk = jax.lax.scan(_func, None, xs=kx[0, :]) +# +# # We are searching for the location of the first string of +# # five locations in a row in `msk` where the value stays the +# # same. +# # We do this by putting the array through jnp.diff, which +# # computes the difference of adjacent elements. Then we convolve +# # with a filter of ones of length five to sum groups of five +# # elements together. The first location where the result is +# # zero is the location we want. The tricky bit however is getting +# # the indexing right. +# +# # step 1. compute the diff of adjacent elements +# # The function jnp.diff returns an array of size one less than +# # the input. So we concatenate a zero at the front. This makes +# # sense since if the original array is all constant, then the +# # location of the first five zeros is at the start of the array. +# delta_msk = jnp.concatenate( +# [jnp.array([0], dtype=int), jnp.diff(msk)], +# axis=0, +# dtype=int, +# ) +# +# # step 2. convolve with the filter +# # In the discrete convolution, you have to deal with edge +# # behavior where the filter only partially overlaps the arrays. +# # We use the mode `full` which returns an array containing +# # every possible combination with missing elements set to zero. +# # We cut the first `length of filter - 1` elements so that +# # index i of the result is the sum of the filter starting +# # at index i of the input. +# sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:] +# +# # step 3. find first location of zero in the convolution +# # Finally, we use jnp.argmin to find the location of the first +# # zero. Per the doc string, if there is more than one zero, this +# # function returns the first location (i.e., smallest index) +# # which is what we want. +# msk_zero = sums == 0 +# sind, dk = jax.lax.cond( +# jnp.any(msk_zero), +# # if we find a set of zeros, the code computes the next pixel past +# # the pixels where |kval| > thresh. So we set dk = 0 since we don't +# # need to shift things. +# lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0), +# # if we get to the end of the array, we add one pixel spacing +# # so we match galsim +# lambda x: (-1, kx[0, -1] - kx[0, -2]), +# msk_zero, +# ) +# return kx[0, sind] + dk @jax.jit @@ -1297,6 +1297,6 @@ def _find_maxk(kim, max_maxk, thresh): # maxk from the image (computed by _inner_comp_find_maxk) # by max_maxk from above return jnp.minimum( - _inner_comp_find_maxk_scan(kim.array, thresh, kx, ky), + _inner_comp_find_maxk(kim.array, thresh, kx, ky), max_maxk, )