diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 0d7c24c6..c1c1e4e9 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1195,95 +1195,95 @@ def _calculate_size_containing_flux(image, thresh): ) -@jax.jit -def _inner_comp_find_maxk(arr, thresh, kx, ky): - msk = (arr * arr.conjugate()).real > thresh * thresh - max_kx = jnp.max( - jnp.where( - msk, - jnp.abs(kx), - -jnp.inf, - ) - ) - max_ky = jnp.max( - jnp.where( - msk, - jnp.abs(ky), - -jnp.inf, - ) - ) - # galsim adds one pixel at the end so that maxk is - # the k value where things do not pass the threshold, - # so we do that here too. - return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0] - - -# this version matches galsim's maxk operation exactly, but is -# more expensive to compute since it has a scan operation. -# I am leaving it here for posterity. - MRB +# this version doe snot match galsim's maxk operation exactly, +# but is faster to compute. I am leaving it here for +# posterity. - MRB # @jax.jit -# def _inner_comp_find_maxk_scan(arr, thresh, kx, ky): -# val = (arr * arr.conjugate()).real -# msk_thresh = val > thresh * thresh -# akx = jnp.abs(kx) -# aky = jnp.abs(ky) -# -# def _func(carry, x): -# msk_kx = akx <= x -# msk_ky = aky <= x -# return carry, jnp.sum(msk_thresh & msk_kx & msk_ky) -# -# _, msk = jax.lax.scan(_func, None, xs=kx[0, :]) -# -# # We are searching for the location of the first string of -# # five locations in a row in `msk` where the value stays the -# # same. -# # We do this by putting the array through jnp.diff, which -# # computes the difference of adjacent elements. Then we convolve -# # with a filter of ones of length five to sum groups of five -# # elements together. The first location where the result is -# # zero is the location we want. The tricky bit however is getting -# # the indexing right. -# -# # step 1. compute the diff of adjacent elements -# # The function jnp.diff returns an array of size one less than -# # the input. So we concatenate a zero at the front. This makes -# # sense since if the original array is all constant, then the -# # location of the first five zeros is at the start of the array. -# delta_msk = jnp.concatenate( -# [jnp.array([0], dtype=int), jnp.diff(msk)], -# axis=0, -# dtype=int, +# def _inner_comp_find_maxk(arr, thresh, kx, ky): +# msk = (arr * arr.conjugate()).real > thresh * thresh +# max_kx = jnp.max( +# jnp.where( +# msk, +# jnp.abs(kx), +# -jnp.inf, +# ) # ) -# -# # step 2. convolve with the filter -# # In the discrete convolution, you have to deal with edge -# # behavior where the filter only partially overlaps the arrays. -# # We use the mode `full` which returns an array containing -# # every possible combination with missing elements set to zero. -# # We cut the first `length of filter - 1` elements so that -# # index i of the result is the sum of the filter starting -# # at index i of the input. -# sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:] -# -# # step 3. find first location of zero in the convolution -# # Finally, we use jnp.argmin to find the location of the first -# # zero. Per the doc string, if there is more than one zero, this -# # function returns the first location (i.e., smallest index) -# # which is what we want. -# msk_zero = sums == 0 -# sind, dk = jax.lax.cond( -# jnp.any(msk_zero), -# # if we find a set of zeros, the code computes the next pixel past -# # the pixels where |kval| > thresh. So we set dk = 0 since we don't -# # need to shift things. -# lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0), -# # if we get to the end of the array, we add one pixel spacing -# # so we match galsim -# lambda x: (-1, kx[0, -1] - kx[0, -2]), -# msk_zero, +# max_ky = jnp.max( +# jnp.where( +# msk, +# jnp.abs(ky), +# -jnp.inf, +# ) # ) -# return kx[0, sind] + dk +# # galsim adds one pixel at the end so that maxk is +# # the k value where things do not pass the threshold, +# # so we do that here too. +# return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0] + + +@jax.jit +def _inner_comp_find_maxk_scan(arr, thresh, kx, ky): + val = (arr * arr.conjugate()).real + msk_thresh = val > thresh * thresh + akx = jnp.abs(kx) + aky = jnp.abs(ky) + + def _func(carry, x): + msk_kx = akx <= x + msk_ky = aky <= x + return carry, jnp.sum(msk_thresh & msk_kx & msk_ky) + + _, msk = jax.lax.scan(_func, None, xs=kx[0, :]) + + # We are searching for the location of the first string of + # five locations in a row in `msk` where the value stays the + # same. + # We do this by putting the array through jnp.diff, which + # computes the difference of adjacent elements. Then we convolve + # with a filter of ones of length five to sum groups of five + # elements together. The first location where the result is + # zero is the location we want. The tricky bit however is getting + # the indexing right. + + # step 1. compute the diff of adjacent elements + # The function jnp.diff returns an array of size one less than + # the input. So we concatenate a zero at the front. This makes + # sense since if the original array is all constant, then the + # location of the first five zeros is at the start of the array. + delta_msk = jnp.concatenate( + [jnp.array([0], dtype=int), jnp.diff(msk)], + axis=0, + dtype=int, + ) + + # step 2. convolve with the filter + # In the discrete convolution, you have to deal with edge + # behavior where the filter only partially overlaps the arrays. + # We use the mode `full` which returns an array containing + # every possible combination with missing elements set to zero. + # We cut the first `length of filter - 1` elements so that + # index i of the result is the sum of the filter starting + # at index i of the input. + sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:] + + # step 3. find first location of zero in the convolution + # Finally, we use jnp.argmin to find the location of the first + # zero. Per the doc string, if there is more than one zero, this + # function returns the first location (i.e., smallest index) + # which is what we want. + msk_zero = sums == 0 + sind, dk = jax.lax.cond( + jnp.any(msk_zero), + # if we find a set of zeros, the code computes the next pixel past + # the pixels where |kval| > thresh. So we set dk = 0 since we don't + # need to shift things. + lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0), + # if we get to the end of the array, we add one pixel spacing + # so we match galsim + lambda x: (-1, kx[0, -1] - kx[0, -2]), + msk_zero, + ) + return kx[0, sind] + dk @jax.jit @@ -1295,6 +1295,6 @@ def _find_maxk(kim, max_maxk, thresh): # maxk from the image (computed by _inner_comp_find_maxk) # by max_maxk from above return jnp.minimum( - _inner_comp_find_maxk(kim.array, thresh, kx, ky), + _inner_comp_find_maxk_scan(kim.array, thresh, kx, ky), max_maxk, ) diff --git a/tests/GalSim b/tests/GalSim index 4cf07122..26f97325 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 4cf07122455fc2f73be5ebd7ad51ed569e8ebd05 +Subproject commit 26f9732554815ea8879c686c101fab07f3e2c2b0