Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def __init__(
max_margin_collisions = delta_collision_samples + margin_waveforms
self._margin = max_margin_collisions

# for some edge cases a template can be zero, leading to problems later
template_is_zero = [np.all(template == 0) for template in all_templates]

self._all_templates = all_templates
self._sparsity_mask = sparsity_mask
self._nbefore = nbefore
Expand All @@ -209,6 +212,7 @@ def __init__(
self._cut_out_after = cut_out_after
self._handle_collisions = handle_collisions
self._delta_collision_samples = delta_collision_samples
self._template_is_zero = template_is_zero

self._kwargs.update(
all_templates=all_templates,
Expand All @@ -220,6 +224,7 @@ def __init__(
return_in_uV=return_in_uV,
handle_collisions=handle_collisions,
delta_collision_samples=delta_collision_samples,
template_is_zero=template_is_zero,
)

def get_dtype(self):
Expand All @@ -239,6 +244,7 @@ def compute(self, traces, peaks):
cut_out_after = self._cut_out_after
handle_collisions = self._handle_collisions
delta_collision_samples = self._delta_collision_samples
template_is_zero = self._template_is_zero

# local_spikes_within_margin = peaks
# i0 = np.searchsorted(local_spikes_within_margin["sample_index"], left_margin)
Expand All @@ -265,7 +271,14 @@ def compute(self, traces, peaks):
if spike_index in collisions.keys():
# we deal with overlapping spikes later
continue

unit_index = spike["unit_index"]

if template_is_zero[unit_index]:
# if template is zero, linregress will fail so we intervene
scalings[spike_index] = 0
continue

sample_centered = spike["sample_index"]
(sparse_indices,) = np.nonzero(sparsity_mask[unit_index])
template = all_templates[unit_index][:, sparse_indices]
Expand Down
Loading