Skip to content
Merged
80 changes: 73 additions & 7 deletions jax_galsim/interpolatedimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,78 @@ def _inner_comp_find_maxk(arr, thresh, kx, ky):
-jnp.inf,
)
)
return jnp.maximum(max_kx, max_ky)
# galsim adds one pixel at the end so that maxk is
# the k value where things do not pass the threshold,
# so we do that here too.
return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0]


# this version matches galsim's maxk operation exactly, but is
# more expensive to compute since it has a scan operation.
# I am leaving it here for posterity. - MRB
# @jax.jit
# def _inner_comp_find_maxk_scan(arr, thresh, kx, ky):
# val = (arr * arr.conjugate()).real
# msk_thresh = val > thresh * thresh
# akx = jnp.abs(kx)
# aky = jnp.abs(ky)
#
# def _func(carry, x):
# msk_kx = akx <= x
# msk_ky = aky <= x
# return carry, jnp.sum(msk_thresh & msk_kx & msk_ky)
#
# _, msk = jax.lax.scan(_func, None, xs=kx[0, :])
#
# # We are searching for the location of the first string of
# # five locations in a row in `msk` where the value stays the
# # same.
# # We do this by putting the array through jnp.diff, which
# # computes the difference of adjacent elements. Then we convolve
# # with a filter of ones of length five to sum groups of five
# # elements together. The first location where the result is
# # zero is the location we want. The tricky bit however is getting
# # the indexing right.
#
# # step 1. compute the diff of adjacent elements
# # The function jnp.diff returns an array of size one less than
# # the input. So we concatenate a zero at the front. This makes
# # sense since if the original array is all constant, then the
# # location of the first five zeros is at the start of the array.
# delta_msk = jnp.concatenate(
# [jnp.array([0], dtype=int), jnp.diff(msk)],
# axis=0,
# dtype=int,
# )
#
# # step 2. convolve with the filter
# # In the discrete convolution, you have to deal with edge
# # behavior where the filter only partially overlaps the arrays.
# # We use the mode `full` which returns an array containing
# # every possible combination with missing elements set to zero.
# # We cut the first `length of filter - 1` elements so that
# # index i of the result is the sum of the filter starting
# # at index i of the input.
# sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:]
#
# # step 3. find first location of zero in the convolution
# # Finally, we use jnp.argmin to find the location of the first
# # zero. Per the doc string, if there is more than one zero, this
# # function returns the first location (i.e., smallest index)
# # which is what we want.
# msk_zero = sums == 0
# sind, dk = jax.lax.cond(
# jnp.any(msk_zero),
# # if we find a set of zeros, the code computes the next pixel past
# # the pixels where |kval| > thresh. So we set dk = 0 since we don't
# # need to shift things.
# lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0),
# # if we get to the end of the array, we add one pixel spacing
# # so we match galsim
# lambda x: (-1, kx[0, -1] - kx[0, -2]),
# msk_zero,
# )
# return kx[0, sind] + dk


@jax.jit
Expand All @@ -1226,11 +1297,6 @@ def _find_maxk(kim, max_maxk, thresh):
# maxk from the image (computed by _inner_comp_find_maxk)
# by max_maxk from above
return jnp.minimum(
# jax-galsim tends to be less conservative for maxk
# since compared to galsim, it does NOT require 5 rows
# of pixels in a row below the threshold.
# thus we add pixels here to ensure the galsim tests pass.
# it turns out one worked ok so that is what we did. - MRB
_inner_comp_find_maxk(kim.array, thresh, kx, ky) + 1 * kim.scale,
_inner_comp_find_maxk(kim.array, thresh, kx, ky),
max_maxk,
)
2 changes: 1 addition & 1 deletion tests/GalSim
Submodule GalSim updated 1 files
+122 −94 tests/test_moffat.py
Loading
Loading