diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index d652e8a9..8561c159 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1214,7 +1214,78 @@ def _inner_comp_find_maxk(arr, thresh, kx, ky): -jnp.inf, ) ) - return jnp.maximum(max_kx, max_ky) + # galsim adds one pixel at the end so that maxk is + # the k value where things do not pass the threshold, + # so we do that here too. + return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0] + + +# this version matches galsim's maxk operation exactly, but is +# more expensive to compute since it has a scan operation. +# I am leaving it here for posterity. - MRB +# @jax.jit +# def _inner_comp_find_maxk_scan(arr, thresh, kx, ky): +# val = (arr * arr.conjugate()).real +# msk_thresh = val > thresh * thresh +# akx = jnp.abs(kx) +# aky = jnp.abs(ky) +# +# def _func(carry, x): +# msk_kx = akx <= x +# msk_ky = aky <= x +# return carry, jnp.sum(msk_thresh & msk_kx & msk_ky) +# +# _, msk = jax.lax.scan(_func, None, xs=kx[0, :]) +# +# # We are searching for the location of the first string of +# # five locations in a row in `msk` where the value stays the +# # same. +# # We do this by putting the array through jnp.diff, which +# # computes the difference of adjacent elements. Then we convolve +# # with a filter of ones of length five to sum groups of five +# # elements together. The first location where the result is +# # zero is the location we want. The tricky bit however is getting +# # the indexing right. +# +# # step 1. compute the diff of adjacent elements +# # The function jnp.diff returns an array of size one less than +# # the input. So we concatenate a zero at the front. This makes +# # sense since if the original array is all constant, then the +# # location of the first five zeros is at the start of the array. +# delta_msk = jnp.concatenate( +# [jnp.array([0], dtype=int), jnp.diff(msk)], +# axis=0, +# dtype=int, +# ) +# +# # step 2. convolve with the filter +# # In the discrete convolution, you have to deal with edge +# # behavior where the filter only partially overlaps the arrays. +# # We use the mode `full` which returns an array containing +# # every possible combination with missing elements set to zero. +# # We cut the first `length of filter - 1` elements so that +# # index i of the result is the sum of the filter starting +# # at index i of the input. +# sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:] +# +# # step 3. find first location of zero in the convolution +# # Finally, we use jnp.argmin to find the location of the first +# # zero. Per the doc string, if there is more than one zero, this +# # function returns the first location (i.e., smallest index) +# # which is what we want. +# msk_zero = sums == 0 +# sind, dk = jax.lax.cond( +# jnp.any(msk_zero), +# # if we find a set of zeros, the code computes the next pixel past +# # the pixels where |kval| > thresh. So we set dk = 0 since we don't +# # need to shift things. +# lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0), +# # if we get to the end of the array, we add one pixel spacing +# # so we match galsim +# lambda x: (-1, kx[0, -1] - kx[0, -2]), +# msk_zero, +# ) +# return kx[0, sind] + dk @jax.jit @@ -1226,11 +1297,6 @@ def _find_maxk(kim, max_maxk, thresh): # maxk from the image (computed by _inner_comp_find_maxk) # by max_maxk from above return jnp.minimum( - # jax-galsim tends to be less conservative for maxk - # since compared to galsim, it does NOT require 5 rows - # of pixels in a row below the threshold. - # thus we add pixels here to ensure the galsim tests pass. - # it turns out one worked ok so that is what we did. - MRB - _inner_comp_find_maxk(kim.array, thresh, kx, ky) + 1 * kim.scale, + _inner_comp_find_maxk(kim.array, thresh, kx, ky), max_maxk, ) diff --git a/tests/GalSim b/tests/GalSim index 1a490c3b..23368789 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 1a490c3b558fddf2cab1fc0e6d449b73fa3b4eda +Subproject commit 2336878979462fc726fabc7fbc5a89fba4ef2648 diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 97f07663..b45ffbd1 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -116,7 +116,7 @@ def test_interpolatedimage_utils_draw_with_interpolant_kval(interp): ) -def test_interpolatedimage_utils_stepk_maxk(): +def test_interpolatedimage_utils_stepk_maxk_simple(): hlr = 0.5 fwhm = 0.9 scale = 0.2 @@ -182,7 +182,7 @@ def test_interpolatedimage_utils_stepk_maxk(): ], ) @pytest.mark.parametrize("method", ["kValue", "xValue"]) -def test_interpolatedimage_utils_comp_to_galsim( +def test_interpolatedimage_utils_comp_to_galsim_xkvalue_stepk_maxk( method, ref_array, offset_x, @@ -213,6 +213,11 @@ def test_interpolatedimage_utils_comp_to_galsim( "Skipping `test_interpolatedimage_utils_comp_to_galsim` case at random to save time." ) + if rng.uniform() < 0.5: + ref_array = ref_array + rng.normal( + size=ref_array.shape, scale=0.1 * ref_array.max() + ) + gimage_in = _galsim.Image(ref_array, scale=0.2) jgimage_in = jax_galsim.Image(ref_array, scale=0.2) @@ -233,10 +238,6 @@ def test_interpolatedimage_utils_comp_to_galsim( x_interpolant=x_interp, ) - np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) - # FIXME: match maxk - np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0.5, atol=0) - assert jgii.maxk >= gii.maxk kxvals = [ (0, 0), (-5, -5), @@ -269,6 +270,61 @@ def test_interpolatedimage_utils_comp_to_galsim( err_msg=f"xValue mismatch: wcs={wcs}, x={x}, y={y}", ) + gthresh = (1.0 - gii.gsparams.folding_threshold) * gii._image_flux + gR = _galsim._galsim.CalculateSizeContainingFlux(gii._image._image, gthresh) + + from jax_galsim.interpolatedimage import _calculate_size_containing_flux + + jgthresh = ( + 1.0 - jgii._original.gsparams.folding_threshold + ) * jgii._original._image_flux + jgR = _calculate_size_containing_flux(jgii._original.image, jgthresh) + + lgR = _galsim_stepk_loop(gii._image, gthresh) + ljgR = _galsim_stepk_loop(jgii._original.image, jgthresh) + + np.testing.assert_allclose(jgii._original.image.center.x, gii._image.center.x) + np.testing.assert_allclose(jgii._original.image.center.y, gii._image.center.y) + np.testing.assert_allclose(jgii._original.image(0, 0), gii._image(0, 0)) + np.testing.assert_allclose(jgii._original.image.array.sum(), gii._image.array.sum()) + np.testing.assert_allclose(jgthresh, gthresh, rtol=0, atol=1e-6) + np.testing.assert_allclose(jgR, gR, rtol=0, atol=1e-6) + np.testing.assert_allclose(ljgR, gR, rtol=0, atol=1e-6) + np.testing.assert_allclose(gR, lgR, rtol=0, atol=1e-6) + + np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) + np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0, atol=1e-6) + + # test forcing stepk/maxk + maxk = gii.maxk * 1.04 + stepk = gii.stepk / 1.04 + + gii = _galsim.InterpolatedImage( + gimage_in, + wcs=wcs, + offset=_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + x_interpolant=x_interp, + _force_maxk=maxk, + _force_stepk=stepk, + ) + jgii = jax_galsim.InterpolatedImage( + jgimage_in, + wcs=jax_galsim.BaseWCS.from_galsim(wcs), + offset=jax_galsim.PositionD(offset_x, offset_y), + use_true_center=use_true_center, + normalization=normalization, + x_interpolant=x_interp, + _force_maxk=maxk, + _force_stepk=stepk, + ) + + np.testing.assert_allclose(gii.maxk, maxk) + np.testing.assert_allclose(gii.stepk, stepk) + np.testing.assert_allclose(jgii.maxk, maxk) + np.testing.assert_allclose(jgii.stepk, stepk) + def _compute_fft_with_numpy_jax_galsim(im): import numpy as np @@ -361,134 +417,6 @@ def test_interpolatedimage_interpolant_sample(interp): np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}") -@pytest.mark.parametrize("x_interp", ["lanczos15", "quintic"]) -@pytest.mark.parametrize("normalization", ["sb", "flux"]) -@pytest.mark.parametrize("use_true_center", [True, False]) -@pytest.mark.parametrize( - "wcs", - [ - _galsim.PixelScale(0.2), - _galsim.JacobianWCS(0.21, 0.03, -0.04, 0.23), - _galsim.AffineTransform(-0.03, 0.21, 0.18, 0.01, _galsim.PositionD(0.3, -0.4)), - ], -) -@pytest.mark.parametrize( - "offset_x", - [ - -4.35, - -0.45, - 0.0, - 0.67, - 3.78, - ], -) -@pytest.mark.parametrize( - "offset_y", - [ - -2.12, - -0.33, - 0.0, - 0.12, - 1.45, - ], -) -@pytest.mark.parametrize( - "ref_array", - [ - _galsim.Gaussian(fwhm=0.9) - .shear(g1=0.3, g2=-0.2) - .drawImage(nx=33, ny=33, scale=0.2) - .array, - _galsim.Gaussian(fwhm=0.9) - .shear(g1=-0.03, g2=0.1) - .drawImage(nx=32, ny=32, scale=0.2) - .array, - ], -) -def test_interpolatedimage_utils_comp_stepk_maxk_to_galsim( - ref_array, - offset_x, - offset_y, - wcs, - use_true_center, - normalization, - x_interp, -): - seed = max( - abs( - int( - hashlib.sha1( - f"{ref_array}{offset_x}{offset_y}{wcs}{use_true_center}{normalization}{x_interp}".encode( - "utf-8" - ) - ).hexdigest(), - 16, - ) - ) - % (10**7), - 1, - ) - - rng = np.random.RandomState(seed=seed) - if rng.uniform() < FRAC_TEST_TO_KEEP: - pytest.skip( - "Skipping `test_interpolatedimage_utils_comp_stepk_maxk_to_galsim` case at random to save time." - ) - - nse = rng.uniform(size=ref_array.shape) * ref_array.max() * 0.05 - - gimage_in = _galsim.Image(ref_array + nse, scale=0.2) - jgimage_in = jax_galsim.Image(ref_array + nse, scale=0.2) - - np.testing.assert_allclose(gimage_in.center.x, jgimage_in.center.x) - np.testing.assert_allclose(gimage_in.center.y, jgimage_in.center.y) - - gii = _galsim.InterpolatedImage( - gimage_in, - wcs=wcs, - offset=_galsim.PositionD(offset_x, offset_y), - use_true_center=use_true_center, - normalization=normalization, - x_interpolant=x_interp, - flux=20, - ) - jgii = jax_galsim.InterpolatedImage( - jgimage_in, - wcs=jax_galsim.BaseWCS.from_galsim(wcs), - offset=jax_galsim.PositionD(offset_x, offset_y), - use_true_center=use_true_center, - normalization=normalization, - x_interpolant=x_interp, - flux=20, - ) - - gthresh = (1.0 - gii.gsparams.folding_threshold) * gii._image_flux - gR = _galsim._galsim.CalculateSizeContainingFlux(gii._image._image, gthresh) - - from jax_galsim.interpolatedimage import _calculate_size_containing_flux - - jgthresh = ( - 1.0 - jgii._original.gsparams.folding_threshold - ) * jgii._original._image_flux - jgR = _calculate_size_containing_flux(jgii._original.image, jgthresh) - - lgR = _galsim_stepk_loop(gii._image, gthresh) - ljgR = _galsim_stepk_loop(jgii._original.image, jgthresh) - - np.testing.assert_allclose(jgii._original.image.center.x, gii._image.center.x) - np.testing.assert_allclose(jgii._original.image.center.y, gii._image.center.y) - np.testing.assert_allclose(jgii._original.image(0, 0), gii._image(0, 0)) - np.testing.assert_allclose(jgii._original.image.array.sum(), gii._image.array.sum()) - np.testing.assert_allclose(jgthresh, gthresh, rtol=0, atol=1e-6) - np.testing.assert_allclose(jgR, gR, rtol=0, atol=1e-6) - np.testing.assert_allclose(ljgR, gR, rtol=0, atol=1e-6) - np.testing.assert_allclose(gR, lgR, rtol=0, atol=1e-6) - - np.testing.assert_allclose(jgii.stepk, gii.stepk, rtol=0, atol=1e-6) - # FIXME: make maxk match - np.testing.assert_allclose(jgii.maxk, gii.maxk, rtol=0.5, atol=0) - - # this is a copy of the galsim C++ algorithm in a pure python # loop to help with debugging and testing def _galsim_stepk_loop(im, target_flux): @@ -518,114 +446,3 @@ def _galsim_stepk_loop(im, target_flux): d += 1 return d + 0.5 - - -@pytest.mark.parametrize("x_interp", ["lanczos15", "quintic"]) -@pytest.mark.parametrize("normalization", ["sb", "flux"]) -@pytest.mark.parametrize("use_true_center", [True, False]) -@pytest.mark.parametrize( - "wcs", - [ - _galsim.PixelScale(0.2), - _galsim.JacobianWCS(0.21, 0.03, -0.04, 0.23), - _galsim.AffineTransform(-0.03, 0.21, 0.18, 0.01, _galsim.PositionD(0.3, -0.4)), - ], -) -@pytest.mark.parametrize( - "offset_x", - [ - -4.35, - -0.45, - 0.0, - 0.67, - 3.78, - ], -) -@pytest.mark.parametrize( - "offset_y", - [ - -2.12, - -0.33, - 0.0, - 0.12, - 1.45, - ], -) -@pytest.mark.parametrize( - "ref_array", - [ - _galsim.Gaussian(fwhm=0.9).drawImage(nx=33, ny=33, scale=0.2).array, - _galsim.Gaussian(fwhm=0.9).drawImage(nx=32, ny=32, scale=0.2).array, - ], -) -@pytest.mark.parametrize("method", ["kValue", "xValue"]) -def test_interpolatedimage_utils_force_stepk_maxk( - method, - ref_array, - offset_x, - offset_y, - wcs, - use_true_center, - normalization, - x_interp, -): - seed = max( - abs( - int( - hashlib.sha1( - f"{method}{ref_array}{offset_x}{offset_y}{wcs}{use_true_center}{normalization}{x_interp}".encode( - "utf-8" - ) - ).hexdigest(), - 16, - ) - ) - % (10**7), - 1, - ) - - rng = np.random.RandomState(seed=seed) - if rng.uniform() < FRAC_TEST_TO_KEEP: - pytest.skip( - "Skipping `test_interpolatedimage_utils_force_stepk_maxk` case at random to save time." - ) - - gimage_in = _galsim.Image(ref_array, scale=0.2) - jgimage_in = jax_galsim.Image(ref_array, scale=0.2) - - gii = _galsim.InterpolatedImage( - gimage_in, - wcs=wcs, - offset=_galsim.PositionD(offset_x, offset_y), - use_true_center=use_true_center, - normalization=normalization, - x_interpolant=x_interp, - ) - maxk = gii.maxk * 1.04 - stepk = gii.stepk / 1.04 - - gii = _galsim.InterpolatedImage( - gimage_in, - wcs=wcs, - offset=_galsim.PositionD(offset_x, offset_y), - use_true_center=use_true_center, - normalization=normalization, - x_interpolant=x_interp, - _force_maxk=maxk, - _force_stepk=stepk, - ) - jgii = jax_galsim.InterpolatedImage( - jgimage_in, - wcs=jax_galsim.BaseWCS.from_galsim(wcs), - offset=jax_galsim.PositionD(offset_x, offset_y), - use_true_center=use_true_center, - normalization=normalization, - x_interpolant=x_interp, - _force_maxk=maxk, - _force_stepk=stepk, - ) - - np.testing.assert_allclose(gii.maxk, maxk) - np.testing.assert_allclose(gii.stepk, stepk) - np.testing.assert_allclose(jgii.maxk, maxk) - np.testing.assert_allclose(jgii.stepk, stepk) diff --git a/tests/jax/test_metacal_jax.py b/tests/jax/test_metacal_jax.py index 9e2aeb25..eaa6f947 100644 --- a/tests/jax/test_metacal_jax.py +++ b/tests/jax/test_metacal_jax.py @@ -17,28 +17,22 @@ def _metacal_galsim( scale, target_fwhm, g1, - iim_kwargs, - ipsf_kwargs, - inse_kwargs, nk, ): iim = _galsim.InterpolatedImage( _galsim.ImageD(im), scale=scale, x_interpolant="lanczos15", - **iim_kwargs, ) ipsf = _galsim.InterpolatedImage( _galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15", - **ipsf_kwargs, ) inse = _galsim.InterpolatedImage( _galsim.ImageD(np.rot90(nse_im, 1)), scale=scale, x_interpolant="lanczos15", - **inse_kwargs, ) ppsf_iim = _galsim.Convolve(iim, _galsim.Deconvolve(ipsf)) @@ -153,34 +147,6 @@ def test_metacal_comp_to_galsim(nse): nse_im = rng.normal(size=im.shape, scale=nse) im += rng.normal(size=im.shape, scale=nse) - # jax galsim and galsim set stepk and maxk differently due to slight - # algorithmic differences. We force them to be the same here for this - # test so it passes. - iim = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(im), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=128), - ) - iim_kwargs = { - "_force_maxk": iim.maxk.item(), - } - inse = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(jnp.rot90(nse_im, 1)), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=128), - ) - inse_kwargs = { - "_force_maxk": inse.maxk.item(), - } - ipsf = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" - ) - ipsf_kwargs = { - "_force_maxk": ipsf.maxk.item(), - } - gt0 = time.time() gres = _metacal_galsim( im.copy(), @@ -189,9 +155,6 @@ def test_metacal_comp_to_galsim(nse): scale, target_fwhm, g1, - iim_kwargs, - ipsf_kwargs, - inse_kwargs, 128, ) gt0 = time.time() - gt0 @@ -252,7 +215,6 @@ def test_metacal_vmap(ntest): ims = [] nse_ims = [] psfs = [] - init_done = False for _seed in range(ntest): seed = _seed + start_seed rng = np.random.RandomState(seed) @@ -287,37 +249,6 @@ def test_metacal_vmap(ntest): psfs.append(psf) nse_ims.append(nse_im) - if not init_done: - init_done = True - - # jax galsim and galsim set stepk and maxk differently due to slight - # algorithmic differences. We force them to be the same here for this - # test so it passes. - iim = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(im), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=128), - ) - iim_kwargs = { - "_force_maxk": iim.maxk.item(), - } - inse = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(jnp.rot90(nse_im, 1)), - scale=scale, - x_interpolant="lanczos15", - gsparams=jax_galsim.GSParams(minimum_fft_size=128), - ) - inse_kwargs = { - "_force_maxk": inse.maxk.item(), - } - ipsf = jax_galsim.InterpolatedImage( - jax_galsim.ImageD(psf), scale=scale, x_interpolant="lanczos15" - ) - ipsf_kwargs = { - "_force_maxk": ipsf.maxk.item(), - } - ims = np.stack(ims) psfs = np.stack(psfs) nse_ims = np.stack(nse_ims) @@ -331,9 +262,6 @@ def test_metacal_vmap(ntest): scale, target_fwhm, g1, - iim_kwargs, - ipsf_kwargs, - inse_kwargs, 128, ) gt0 = time.time() - gt0 @@ -354,14 +282,16 @@ def test_metacal_vmap(ntest): msg = "jit" jgt0 = time.time() - vmap_mcal( - ims, - psfs, - nse_ims, - scale, - target_fwhm, - g1, - 128, + jax.block_until_ready( + vmap_mcal( + ims, + psfs, + nse_ims, + scale, + target_fwhm, + g1, + 128, + ) ) jgt0 = time.time() - jgt0 print("Jax-Galsim time (%s): " % msg, jgt0 * 1e3, " [ms]") @@ -417,7 +347,6 @@ def test_metacal_iimage_with_noise(nse, draw_method): scale=scale, x_interpolant="lanczos15", gsparams=_galsim.GSParams(minimum_fft_size=nk), - _force_maxk=jgiim.maxk.item(), ) def _plot_real(gim, jgim): diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index 20d026e8..a6a44910 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -46,7 +46,7 @@ def test_moffat_comp_galsim_maxk(psf, thresh): f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}", flush=True, ) - np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=1e-4, atol=0) + np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=5e-5, atol=0) np.testing.assert_allclose( psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5, atol=1e-5 )