Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 87 additions & 87 deletions jax_galsim/interpolatedimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,95 +1195,95 @@ def _calculate_size_containing_flux(image, thresh):
)


@jax.jit
def _inner_comp_find_maxk(arr, thresh, kx, ky):
msk = (arr * arr.conjugate()).real > thresh * thresh
max_kx = jnp.max(
jnp.where(
msk,
jnp.abs(kx),
-jnp.inf,
)
)
max_ky = jnp.max(
jnp.where(
msk,
jnp.abs(ky),
-jnp.inf,
)
)
# galsim adds one pixel at the end so that maxk is
# the k value where things do not pass the threshold,
# so we do that here too.
return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0]


# this version matches galsim's maxk operation exactly, but is
# more expensive to compute since it has a scan operation.
# I am leaving it here for posterity. - MRB
# this version doe snot match galsim's maxk operation exactly,
# but is faster to compute. I am leaving it here for
# posterity. - MRB
# @jax.jit
# def _inner_comp_find_maxk_scan(arr, thresh, kx, ky):
# val = (arr * arr.conjugate()).real
# msk_thresh = val > thresh * thresh
# akx = jnp.abs(kx)
# aky = jnp.abs(ky)
#
# def _func(carry, x):
# msk_kx = akx <= x
# msk_ky = aky <= x
# return carry, jnp.sum(msk_thresh & msk_kx & msk_ky)
#
# _, msk = jax.lax.scan(_func, None, xs=kx[0, :])
#
# # We are searching for the location of the first string of
# # five locations in a row in `msk` where the value stays the
# # same.
# # We do this by putting the array through jnp.diff, which
# # computes the difference of adjacent elements. Then we convolve
# # with a filter of ones of length five to sum groups of five
# # elements together. The first location where the result is
# # zero is the location we want. The tricky bit however is getting
# # the indexing right.
#
# # step 1. compute the diff of adjacent elements
# # The function jnp.diff returns an array of size one less than
# # the input. So we concatenate a zero at the front. This makes
# # sense since if the original array is all constant, then the
# # location of the first five zeros is at the start of the array.
# delta_msk = jnp.concatenate(
# [jnp.array([0], dtype=int), jnp.diff(msk)],
# axis=0,
# dtype=int,
# def _inner_comp_find_maxk(arr, thresh, kx, ky):
# msk = (arr * arr.conjugate()).real > thresh * thresh
# max_kx = jnp.max(
# jnp.where(
# msk,
# jnp.abs(kx),
# -jnp.inf,
# )
# )
#
# # step 2. convolve with the filter
# # In the discrete convolution, you have to deal with edge
# # behavior where the filter only partially overlaps the arrays.
# # We use the mode `full` which returns an array containing
# # every possible combination with missing elements set to zero.
# # We cut the first `length of filter - 1` elements so that
# # index i of the result is the sum of the filter starting
# # at index i of the input.
# sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:]
#
# # step 3. find first location of zero in the convolution
# # Finally, we use jnp.argmin to find the location of the first
# # zero. Per the doc string, if there is more than one zero, this
# # function returns the first location (i.e., smallest index)
# # which is what we want.
# msk_zero = sums == 0
# sind, dk = jax.lax.cond(
# jnp.any(msk_zero),
# # if we find a set of zeros, the code computes the next pixel past
# # the pixels where |kval| > thresh. So we set dk = 0 since we don't
# # need to shift things.
# lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0),
# # if we get to the end of the array, we add one pixel spacing
# # so we match galsim
# lambda x: (-1, kx[0, -1] - kx[0, -2]),
# msk_zero,
# max_ky = jnp.max(
# jnp.where(
# msk,
# jnp.abs(ky),
# -jnp.inf,
# )
# )
# return kx[0, sind] + dk
# # galsim adds one pixel at the end so that maxk is
# # the k value where things do not pass the threshold,
# # so we do that here too.
# return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0]


@jax.jit
def _inner_comp_find_maxk_scan(arr, thresh, kx, ky):
val = (arr * arr.conjugate()).real
msk_thresh = val > thresh * thresh
akx = jnp.abs(kx)
aky = jnp.abs(ky)

def _func(carry, x):
msk_kx = akx <= x
msk_ky = aky <= x
return carry, jnp.sum(msk_thresh & msk_kx & msk_ky)

_, msk = jax.lax.scan(_func, None, xs=kx[0, :])

# We are searching for the location of the first string of
# five locations in a row in `msk` where the value stays the
# same.
# We do this by putting the array through jnp.diff, which
# computes the difference of adjacent elements. Then we convolve
# with a filter of ones of length five to sum groups of five
# elements together. The first location where the result is
# zero is the location we want. The tricky bit however is getting
# the indexing right.

# step 1. compute the diff of adjacent elements
# The function jnp.diff returns an array of size one less than
# the input. So we concatenate a zero at the front. This makes
# sense since if the original array is all constant, then the
# location of the first five zeros is at the start of the array.
delta_msk = jnp.concatenate(
[jnp.array([0], dtype=int), jnp.diff(msk)],
axis=0,
dtype=int,
)

# step 2. convolve with the filter
# In the discrete convolution, you have to deal with edge
# behavior where the filter only partially overlaps the arrays.
# We use the mode `full` which returns an array containing
# every possible combination with missing elements set to zero.
# We cut the first `length of filter - 1` elements so that
# index i of the result is the sum of the filter starting
# at index i of the input.
sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:]

# step 3. find first location of zero in the convolution
# Finally, we use jnp.argmin to find the location of the first
# zero. Per the doc string, if there is more than one zero, this
# function returns the first location (i.e., smallest index)
# which is what we want.
msk_zero = sums == 0
sind, dk = jax.lax.cond(
jnp.any(msk_zero),
# if we find a set of zeros, the code computes the next pixel past
# the pixels where |kval| > thresh. So we set dk = 0 since we don't
# need to shift things.
lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0),
# if we get to the end of the array, we add one pixel spacing
# so we match galsim
lambda x: (-1, kx[0, -1] - kx[0, -2]),
msk_zero,
)
return kx[0, sind] + dk


@jax.jit
Expand All @@ -1295,6 +1295,6 @@ def _find_maxk(kim, max_maxk, thresh):
# maxk from the image (computed by _inner_comp_find_maxk)
# by max_maxk from above
return jnp.minimum(
_inner_comp_find_maxk(kim.array, thresh, kx, ky),
_inner_comp_find_maxk_scan(kim.array, thresh, kx, ky),
max_maxk,
)
2 changes: 1 addition & 1 deletion tests/GalSim
Loading