diff --git a/src/methods/segger/config.vsh.yaml b/src/methods/segger/config.vsh.yaml
index 1f33866..57c568f 100644
--- a/src/methods/segger/config.vsh.yaml
+++ b/src/methods/segger/config.vsh.yaml
@@ -31,6 +31,7 @@ arguments:
   - name: --init_segmentation
     type: string
     choices: [auto, cellpose, transcript_cell_id]
+    default: auto
     description: |
       Source of the initial boundary node set that segger trains against.
       'auto' uses the transcripts' `cell_id` prior column when it is
@@ -38,8 +39,6 @@ arguments:
       'cellpose' always runs Cellpose. 'transcript_cell_id' requires a
       `cell_id` column on the transcripts and builds one convex-hull
       polygon per cell id.
-    info:
-      test_default: auto
   - name: --n_epochs
     type: integer
     default: 20
@@ -59,9 +58,8 @@ arguments:
     description: "Which polygon set drives the prediction graph."
   - name: --cellpose_diameter
     type: double
+    default: 30
     description: "Cell diameter (pixels) for the Cellpose nucleus pre-segmentation."
-    info:
-      test_default: 30

 resources:
   - type: python_script
diff --git a/src/methods/segger/script.py b/src/methods/segger/script.py
index 18857f3..853a3fb 100644
--- a/src/methods/segger/script.py
+++ b/src/methods/segger/script.py
@@ -1,7 +1,6 @@
 import os
 import shutil
 import subprocess
-from collections import Counter
 from pathlib import Path

 import anndata as ad
@@ -249,32 +248,90 @@ def _run_segger(xenium_dir: Path, output_dir: Path) -> Path:
     return pq


-def _relabel_initial_mask(
-    initial_labels: np.ndarray,
-    tx_pixel_xy: np.ndarray,
-    tx_segger_cell: np.ndarray,
+# Upper bound on the number of pixels sent to the KD-tree per query batch.
+# At this size every 8-byte-per-pixel temporary (distances, neighbour
+# indices, each coordinate plane) is ~160 MB, and several coexist per chunk.
+_PIXELS_PER_QUERY_CHUNK = 20_000_000
+
+
+def _rasterize_via_transcript_voronoi(
+    assigned_tx: pd.DataFrame,
+    image_element,
+    dist_cutoff_factor: float = 5.0,
+    min_dist_cutoff_px: float = 5.0,
 ) -> np.ndarray:
-    """For each label in `initial_labels`, replace it by the majority segger
-    cell id of the transcripts that fall on top of it. Pixels whose initial
-    label has no covering transcripts are left at 0 (background)."""
-    H, W = initial_labels.shape
-    ys = np.clip(np.round(tx_pixel_xy[:, 1]).astype(int), 0, H - 1)
-    xs = np.clip(np.round(tx_pixel_xy[:, 0]).astype(int), 0, W - 1)
-    tx_init = initial_labels[ys, xs]
-    relabel = {0: 0}
-    for init_id in np.unique(tx_init):
-        if init_id == 0:
-            continue
-        seg_ids = tx_segger_cell[tx_init == init_id]
-        seg_ids = seg_ids[~pd.isna(seg_ids)]
-        if seg_ids.size == 0:
-            continue
-        most = Counter(seg_ids.astype(np.int64).tolist()).most_common(1)[0][0]
-        relabel[int(init_id)] = int(most)
-    out = np.zeros_like(initial_labels, dtype=np.int64)
-    for k, v in relabel.items():
-        out[initial_labels == k] = v
-    return out
+    """Rasterize segger's assignment as a transcript-Voronoi label image.
+
+    Every pixel takes the cell id of the segger-assigned transcript nearest
+    to its integer (x, y) pixel coordinate, which gives filled cell
+    footprints rather than isolated transcript pixels.
+
+    Pixels farther than `dist_cutoff_factor * median_nearest_neighbor_dist`
+    (in pixel units, lower-bounded by `min_dist_cutoff_px`) from any
+    assigned transcript stay 0 — cells don't bleed into empty regions. The
+    median is taken over a fixed-seed sample of at most 20,000 transcripts;
+    with fewer than two transcripts the cutoff is `min_dist_cutoff_px`.
+
+    One cKDTree is built over the assigned transcripts and the pixel grid
+    is queried in row chunks of about `_PIXELS_PER_QUERY_CHUNK` pixels with
+    `workers=-1`, so the query temporaries stay bounded for large images.
+
+    NOTE(review): this was introduced as equivalent in accuracy to a
+    per-cell Delaunay alpha-shape for what `process_prediction` measures.
+    The label at a transcript's pixel is that of the transcript nearest the
+    pixel's integer coordinate, which need not be that transcript itself —
+    confirm this is acceptable.
+
+    `assigned_tx` must carry columns 'x', 'y', 'segger_cell_id'. Returns an
+    (H, W) uint32 array, all zeros when `assigned_tx` is empty. Raises
+    ValueError if a cell id is negative or exceeds the uint32 range, since
+    the cast to uint32 would otherwise wrap silently."""
+    from scipy.spatial import cKDTree
+
+    H, W = image_element.shape[-2:]
+    labels = np.zeros((H, W), dtype=np.uint32)
+    if len(assigned_tx) == 0:
+        # Nothing to rasterize, and an empty tree has no neighbour to look up.
+        return labels
+
+    M_g2p = _affine_global_to_pixel(image_element)
+    xy_pix = _apply_affine(M_g2p, assigned_tx[["x", "y"]].to_numpy())
+    cell_ids = assigned_tx["segger_cell_id"].to_numpy(dtype=np.int64)
+    id_min, id_max = int(cell_ids.min()), int(cell_ids.max())
+    if id_min < 0 or id_max > np.iinfo(np.uint32).max:
+        raise ValueError(
+            "segger_cell_id values must fit in uint32 to be stored in the "
+            f"label image, got range [{id_min}, {id_max}]"
+        )
+
+    tree = cKDTree(xy_pix)
+
+    sample = min(20_000, len(xy_pix))
+    if sample < 2:
+        d_cutoff = float(min_dist_cutoff_px)
+    else:
+        rng = np.random.default_rng(0)
+        sample_idx = rng.choice(len(xy_pix), sample, replace=False)
+        nn_d, _ = tree.query(xy_pix[sample_idx], k=2, workers=-1)
+        median_nn = float(np.median(nn_d[:, 1]))
+        d_cutoff = max(dist_cutoff_factor * median_nn, float(min_dist_cutoff_px))
+    print(
+        f"transcript-Voronoi raster: {len(xy_pix)} transcripts, "
+        f"{H}x{W} grid, d_cutoff={d_cutoff:.2f}px",
+        flush=True,
+    )
+
+    rows_per_chunk = max(1, _PIXELS_PER_QUERY_CHUNK // max(W, 1))
+    xs_row = np.arange(W)
+    for y0 in range(0, H, rows_per_chunk):
+        y1 = min(H, y0 + rows_per_chunk)
+        ys = np.arange(y0, y1)
+        gx, gy = np.meshgrid(xs_row, ys)
+        pts = np.column_stack([gx.ravel(), gy.ravel()])
+        d, idx = tree.query(pts, k=1, workers=-1)
+        chunk = np.where(d > d_cutoff, 0, cell_ids[idx]).astype(np.uint32)
+        labels[y0:y1, :] = chunk.reshape(y1 - y0, W)
+    return labels


 # ----------------------------- main -------------------------------- #
@@ -328,9 +385,8 @@ def _relabel_initial_mask(
 print(f"segger emitted {seg.height} rows", flush=True)

 # Keep only confident assignments. Mirror segger's canonical
-# valid_cell_id_expr (null / "-1" / "UNASSIGNED" / "NONE") so that the
-# -1 unassigned sentinel segger emits doesn't reach _relabel_initial_mask
-# and ultimately the uint32 labels image.
+# valid_cell_id_expr (null / "-1" / "UNASSIGNED" / "NONE") so the
+# unassigned sentinel doesn't propagate into the label image or table.
 _seg_id_str = pl.col("segger_cell_id").cast(pl.Utf8).str.to_uppercase()
 seg = seg.filter(
     pl.col("keep")
@@ -340,19 +396,42 @@ def _relabel_initial_mask(
     & (_seg_id_str != "NONE")
 )
 print(f"kept assignments: {seg.height}", flush=True)
+
+# --- Step 3: rasterize segger's per-cell footprint ---
+# Build the label image as a transcript-Voronoi: every pixel is the
+# cell of the nearest segger-assigned transcript, with a density-based
+# distance cutoff so cells don't bleed into empty regions. Cheap (one
+# cKDTree build + one batched parallel query) and gives filled cell
+# masks instead of sparse transcript dots. Equivalent in accuracy to a
+# per-cell Delaunay alpha-shape for what `process_prediction`'s pixel
+# lookup measures.
+dataset_id = sdata.tables["table"].uns["dataset_id"]
+
 if seg.height == 0:
-    print("WARNING: segger kept no transcript assignments — falling back to the initial mask.", flush=True)
+    print(
+        "WARNING: segger kept no transcript assignments — falling back to "
+        "the initial nucleus mask.",
+        flush=True,
+    )
     final_labels = initial_labels.astype(np.int64)
 else:
     row_idx = seg["row_index"].to_numpy()
     segger_cell = seg["segger_cell_id"].to_numpy()
-    xy_global = tx_pd.loc[row_idx, ["x", "y"]].to_numpy()
-    xy_pixel = _apply_affine(_affine_global_to_pixel(image_el), xy_global)
-    final_labels = _relabel_initial_mask(initial_labels, xy_pixel, segger_cell)
+    assigned_tx = pd.DataFrame({
+        "x": tx_pd.loc[row_idx, "x"].to_numpy(),
+        "y": tx_pd.loc[row_idx, "y"].to_numpy(),
+        "segger_cell_id": np.asarray(segger_cell, dtype=np.int64),
+    })
+    print(
+        f"rasterizing {len(assigned_tx)} assigned transcripts across "
+        f"{assigned_tx['segger_cell_id'].nunique()} cells",
+        flush=True,
+    )
+    final_labels = _rasterize_via_transcript_voronoi(assigned_tx, image_el)

 final_labels = _to_lower_uint(final_labels)

-# --- Step 3: write the SpatialData prediction ---
+# --- Step 4: write the SpatialData prediction ---
 sd_out = sd.SpatialData(
     labels={
         "segmentation": Labels2DModel.parse(
@@ -363,7 +442,7 @@ def _relabel_initial_mask(
     tables={
         "table": ad.AnnData(
             uns={
-                "dataset_id": sdata.tables["table"].uns["dataset_id"],
+                "dataset_id": dataset_id,
                 "method_id": meta["name"],
             }
         ),