
Commit 8093664

Add threshold calibration utilities and corresponding tests
- Implemented threshold calibration utilities in `threshold_calibration.py` to compute optimal thresholds and evaluate metrics at specified thresholds.
- Added unit tests for the threshold calibration functions in `test_threshold_calibration.py` to ensure correct functionality.
- Introduced tests for connectome-aware analysis utilities in `test_connectome_analysis.py`, covering orientation and alignment of connectome matrices.
- Created extensive tests for behavior rate model functionality in `test_behavior_rate_model.py`, validating model training, predictions, and ablation importance.
1 parent bd12a4b commit 8093664

11 files changed

Lines changed: 3697 additions & 4 deletions

docs/BEHAVIOR_RATE_MODEL.md

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
# Dataset-level PER Rate Model (Option 1)

This document describes the dataset-level (aggregated) PER rate model used to reproduce a behavioral matrix like `reaction_rates_summary_unordered.csv` (rows = datasets/conditions, columns = test odors, values in `[0,1]`).

## Key definitions

- **Dataset**: a training condition (e.g., `opto_hex`, `hex_control`, `opto_EB_6_training`).
- **Trained odor**: the odor paired with reward for `opto_*` datasets (or the matched odor for controls).
- **PER rate**: the aggregate probability of proboscis extension response for a (dataset, test odor) cell, computed as the mean of binary outcomes across flies in that condition. This is a **dataset-level summary statistic**, NOT per-fly dynamics or trial-by-trial probabilities.
- **Observed cell**: a cell in the behavioral matrix that had measurements (non-NaN in the input CSV).
- **Extrapolated cell**: a cell that was NaN in the input (no data), but for which the model can still predict from DoOR profiles.

## Model

We fit a sparse generalized linear model (GLM) on DoOR receptor-response features:

```
logit(p̂) = b + b_reward * reward_flag
         + w_test · x_test
         + w_train · x_train
         + w_int · (x_test ⊙ x_train)
         + w_diff · (x_test - x_train)   (optional)
```

Where:
- `x_test` = DoOR receptor-response vector for the test odor (length = #receptors, e.g. 78).
- `x_train` = DoOR receptor-response vector for the trained odor for that dataset.
- `reward_flag` = `1` if the dataset name starts with `opto_`, else `0`.
- `⊙` = elementwise product.

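The per-cell prediction can be sketched term by term in NumPy (an illustration of the formula above; the parameter names mirror the equation, not the actual training code):

```python
import numpy as np

def glm_logit(x_test, x_train, reward_flag, p):
    """logit(p_hat) for one (dataset, test_odor) cell, term by term."""
    z = p["b"] + p["b_reward"] * reward_flag
    z += p["w_test"] @ x_test
    z += p["w_train"] @ x_train
    z += p["w_int"] @ (x_test * x_train)   # elementwise co-activation
    if "w_diff" in p:                      # optional directional term
        z += p["w_diff"] @ (x_test - x_train)
    return z

def predict_rate(x_test, x_train, reward_flag, p):
    # sigmoid inverts the logit link, giving a rate in (0, 1)
    return 1.0 / (1.0 + np.exp(-glm_logit(x_test, x_train, reward_flag, p)))

# toy example with 4 receptors
rng = np.random.default_rng(0)
n = 4
params = {"b": 0.0, "b_reward": 1.0,
          "w_test": rng.normal(size=n), "w_train": rng.normal(size=n),
          "w_int": rng.normal(size=n)}
rate = predict_rate(rng.random(n), rng.random(n), 1.0, params)
```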
## Feature Construction

For each (dataset, test_odor) cell:

1. **Test profile** (`x_test`): DoOR receptor-response vector for the test odor
2. **Trained profile** (`x_train`): DoOR vector for the trained odor (inferred from the dataset name)
3. **Interaction** (`x_test ⊙ x_train`): Element-wise product (co-activation)
4. **Difference** (`x_test - x_train`, optional): Directional contrast

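The four blocks above concatenate into one feature vector per cell; a minimal sketch (the function name is illustrative, not the toolkit's API):

```python
import numpy as np

def build_features(x_test, x_train, use_diff=True):
    """Stack the per-cell feature blocks: test, train, interaction, diff."""
    blocks = [x_test, x_train, x_test * x_train]
    if use_diff:
        blocks.append(x_test - x_train)    # optional directional contrast
    return np.concatenate(blocks)

x_te = np.array([0.2, 0.8, 0.0])
x_tr = np.array([0.1, 0.9, 0.5])
feats = build_features(x_te, x_tr)         # length 4 * n_receptors
```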
### Receptor Masking

- By default, uses the 55 adult-only ORNs from `data/mappings/training_receptor_set.json`
- Excludes 23 larval/unmapped receptors
- A manifest with the details is saved to `adult_only_mask_manifest.json`

### Odor Name Resolution

- Behavioral CSV labels (e.g., "Hexanol") → DoOR canonical names ("1-hexanol")
- Uses InChIKey-based lookup via candidate mappings + fuzzy matching
- All resolution decisions are logged to `odor_name_resolution.json`
- Control stimuli (air, paraffin oil) are encoded as zero vectors (logged as `control_stimulus`)
- Unknown odors are encoded as zero vectors (logged as `NOT_FOUND`)

## Decision → Evidence → Implementation notes

### BCE on fractional labels

- **Decision**: use binary cross-entropy (BCE) with logits on PER rates `y ∈ [0,1]`.
- **Evidence**: a PER rate is an aggregate mean of Bernoulli outcomes; BCE corresponds to the negative log-likelihood for Bernoulli targets, and fractional targets are a standard "soft label" relaxation.
- **Implementation**: training uses `torch.nn.functional.binary_cross_entropy_with_logits(logits, y)`.

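For readers without a torch environment, the same loss can be written in NumPy using the numerically stable form that torch documents (a sketch, not the project's code):

```python
import numpy as np

def bce_with_logits(logits, y):
    """Numerically stable BCE for soft labels y in [0,1] (mean reduction):
    max(z, 0) - z*y + log(1 + exp(-|z|)) avoids overflow for large |z|."""
    z = np.asarray(logits, dtype=float)
    y = np.asarray(y, dtype=float)
    loss = np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))
    return loss.mean()
```

At `z = 0` any target gives `log 2 ≈ 0.693`, a handy sanity check.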
### L1 (lasso-style) regularization

- **Decision**: apply L1 regularization to the ORN weight vectors to encourage sparsity and interpretability.
- **Evidence**: sparsity yields simpler receptor circuits and reduces sensitivity to correlated features.
- **Implementation**: `loss = BCE + λ * (|w_test|_1 + |w_train|_1 + |w_int|_1 + |w_diff|_1)`.

### Ablation-based importance (preferred)

- **Decision**: prioritize ablation-based ORN importance for experiment planning.
- **Evidence**: correlated inputs can make weight magnitudes misleading; ablation measures the sensitivity of model fit in the data regime.
- **Implementation**: for each ORN channel `i`, set channel `i` to 0 in both `x_test` and `x_train` (and derived features), recompute BCE on the full observed training table, and record `ΔBCE = BCE_ablated - BCE_baseline`.

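A schematic of that ablation loop (the scoring function and names are placeholders, not the toolkit's API):

```python
import numpy as np

def ablation_importance(X_test, X_train, score_fn, n_receptors):
    """Delta-score per ORN channel: zero channel i in both profile matrices
    (rows = cells), re-score on the full observed table, subtract baseline."""
    baseline = score_fn(X_test, X_train)
    deltas = np.empty(n_receptors)
    for i in range(n_receptors):
        xt, xr = X_test.copy(), X_train.copy()
        xt[:, i] = 0.0    # ablate channel i in the test profiles
        xr[:, i] = 0.0    # ... and in the trained-odor profiles
        deltas[i] = score_fn(xt, xr) - baseline
    return deltas
```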
### Trained-odor resolution

- **Decision**: resolve trained odors via an explicit mapping first, then deterministic parsing.
- **Evidence**: dataset naming conventions are not fully standardized (e.g., `opto_benz_1`).
- **Implementation**: `door_toolkit.pathways.behavior_rate_model.resolve_trained_odor_for_dataset()` uses a default mapping and common token parsing; unknown datasets raise a clear error unless you provide an override mapping.

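A hypothetical sketch of the mapping-then-parse strategy; the mapping entries and token table below are illustrative, not the toolkit's actual tables:

```python
# Illustrative tables only (not the real defaults).
DEFAULT_MAP = {"opto_benz_1": "benzaldehyde"}
TOKEN_TO_ODOR = {"hex": "1-hexanol", "EB": "ethyl butyrate"}

def resolve_trained_odor(dataset, overrides=None):
    """Explicit mapping first, then deterministic token parsing."""
    mapping = {**DEFAULT_MAP, **(overrides or {})}
    if dataset in mapping:
        return mapping[dataset]
    for token in dataset.replace("opto_", "").split("_"):
        if token in TOKEN_TO_ODOR:
            return TOKEN_TO_ODOR[token]
    raise ValueError(f"Cannot resolve trained odor for {dataset!r}; "
                     "pass an override mapping.")
```

E.g. `resolve_trained_odor("opto_EB_6_training")` misses the explicit map and falls through to the `EB` token.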
### Input validation

- **Decision**: auto-detect wide vs long CSV format and validate rate ranges.
- **Evidence**: user CSVs may use percentage `[0,100]` or fractional `[0,1]` rates; explicit validation prevents silent errors.
- **Implementation**: `load_behavior_matrix_csv()` detects the format, rescales if needed (with a warning), and logs metadata to `input_format_metadata.json`.

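The range check and rescaling step might look like this (an illustrative sketch, assuming a simple max-based heuristic rather than the loader's actual logic):

```python
import warnings
import numpy as np

def normalize_rates(values):
    """If rates look like percentages (finite max > 1), rescale to [0,1]
    with a warning; then validate the range."""
    v = np.asarray(values, dtype=float)
    finite = v[np.isfinite(v)]          # ignore NaN cells (unobserved)
    if finite.size and finite.max() > 1.0:
        warnings.warn("Rates appear to be percentages; rescaling by 1/100.")
        v = v / 100.0
        finite = v[np.isfinite(v)]
    if finite.size and (finite.min() < 0.0 or finite.max() > 1.0):
        raise ValueError("PER rates must lie in [0,1] after rescaling.")
    return v
```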
### Learning-effect computation

- **Decision**: compute ΔPER = PER_opto - PER_control for matched opto/control pairs.
- **Evidence**: the learning effect is the causal difference between optogenetic stimulation and matched control; it requires exact dataset pairing.
- **Implementation**: `compute_learning_effect_metrics()` merges predictions for matched pairs (e.g., `opto_hex` vs `hex_control`), computes the delta for both y_true and y_pred, and reports per-odor errors.

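The pairing-and-delta step reduces to something like the following (a minimal sketch with an illustrative data layout, not the real merge over prediction tables):

```python
def learning_effect(per, pairs):
    """ΔPER = PER_opto - PER_control per test odor for matched pairs.
    `per[(dataset, odor)]` holds a rate; `pairs` maps opto → control."""
    deltas = {}
    for opto, ctrl in pairs.items():
        for (ds, odor), rate in per.items():
            if ds == opto and (ctrl, odor) in per:
                deltas[odor] = rate - per[(ctrl, odor)]
    return deltas

per = {("opto_hex", "1-hexanol"): 0.75, ("hex_control", "1-hexanol"): 0.25}
delta = learning_effect(per, {"opto_hex": "hex_control"})
```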
## Outputs

The training script writes artifacts under `--output_dir`:

**Core predictions:**
- `predicted_rates.csv`: **Long format** with observed cells only. Columns: `[dataset, test_odor, y_true, y_pred, split, trained_odor, trained_odor_door, test_odor_door]`. `split` = "observed" for all rows. `y_true` = original measured PER rate (no NaN). `y_pred` = model prediction.
- `extrapolated_predictions.csv`: **Long format** with extrapolated cells (originally NaN in the input). Same columns, but `split` = "extrapolated" and `y_true` = NaN.

**Metrics:**
- `metrics.json`: Global fit metrics (BCE, R², Pearson r) on observed cells.
- `metrics_per_dataset.csv`: Breakdown by dataset (one row per dataset): `[dataset, n_cells, bce, r2, pearson_r, mae, trained_odor]`.
- `metrics_per_test_odor.csv`: Breakdown by test odor: `[test_odor, n_cells, bce, r2, pearson_r, mae, door_name]`.
- `learning_effect_metrics.csv`: ΔPER = opto - control for matched pairs: `[test_odor, delta_per_true, delta_per_pred, error, opto_dataset, control_dataset]` (only if matched controls exist).
- `predicted_learning_effect_opto_minus_control.csv`: Old wide-format learning-effect table (kept for backward compatibility).

**Receptor importance:**
- `orn_importance_global.csv`: Combined weight-based and ablation-based importance.
- `orn_importance_by_dataset.csv`: Ablation ΔBCE per dataset.
- `orn_importance_by_test_odor.csv`: Ablation ΔBCE per test odor.
- `ablation_sweep.csv`: Top-K receptors with learning-effect deltas.

**Connectome-aware interpretation (optional):**
- `connectome_analysis/connectome_inputs.json`: paths, hashes, shapes, and orientation/alignment reports used for connectome propagation.
- `connectome_analysis/orn_connectome_amplified_importance.csv`: ORN-level base importance + fanout metrics + connectome-amplified importance (ranked).
- `connectome_analysis/pn_influence.csv`: Top PNs by propagated influence.
- `connectome_analysis/kc_influence.csv`: Top KCs by propagated influence.
- `connectome_analysis/connectome_summary.json`: Concentration metrics (top-K fractions, Gini-like summary) and top ORNs/PNs/KCs.
- `connectome_analysis/orn_to_pn_top_edges.csv`: (Optional) ORN→PN edge list for the top ORNs by amplified importance.

**Provenance & logging:**
- `run_config.json`: CLI args, git SHA, hyperparameters.
- `input_format_metadata.json`: CSV format detection and rescaling log.
- `odor_name_resolution.json`: All odor-name mappings (csv_name → door_name or NOT_FOUND).
- `adult_only_mask_manifest.json`: Receptor mask details (55 adult ORNs kept, 23 excluded).
- `training_loss.csv`: Training curve for diagnostics.

## Connectome-aware interpretation (post-training)

This repository also contains FlyWire-derived connectivity artifacts for ORN→PN and PN→KC. The behavior-rate model training objective is unchanged; after training, we can **optionally** propagate the learned ORN importance through these matrices to get a first-order downstream readout.

### Inputs

Expected files (from `scripts/extract_flywire_connectivity.py`):
- `orn_pn_connectivity.pt`
- `pn_kc_connectivity.pt`
- `connectivity_metadata.json` (required for ORN ordering alignment; must include `receptor_names`)

By default, the training script auto-detects connectivity under:
- `data/pgcn_features/connectivity/` (preferred)
- `data/connectivity/`

Override with `--connectome_dir`, or skip with `--disable_connectome_analysis`.

### Propagation definition

Let:
- `s_orn` be the non-negative ORN importance vector from the trained GLM weights, where `s_orn[i] = |w_test[i]| + |w_train[i]| + |w_int[i]| (+ |w_diff[i]| if enabled)`.
- `A` be the ORN→PN matrix, oriented as `(n_pn, n_orn)` after detecting whether the stored file is ORN×PN or PN×ORN.
- `B` be the PN→KC matrix, oriented as `(n_kc, n_pn)` after detecting whether the stored file is PN×KC or KC×PN.

We compute:

```
s_pn = A @ s_orn
s_kc = B @ s_pn
```

These are not firing rates or dynamics; they are a linear "influence mass" readout under the assumption that larger ORN importance combined with stronger fanout yields a larger downstream footprint.

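With toy shapes (3 ORNs, 2 PNs, 2 KCs, values purely illustrative), the two matrix products look like:

```python
import numpy as np

s_orn = np.array([1.0, 0.5, 0.0])   # ORN importance from the GLM weights
A = np.array([[1.0, 0.0, 2.0],      # ORN→PN, oriented (n_pn, n_orn)
              [0.0, 1.0, 1.0]])
B = np.array([[1.0, 1.0],           # PN→KC, oriented (n_kc, n_pn)
              [2.0, 0.0]])

s_pn = A @ s_orn                    # influence mass reaching each PN
s_kc = B @ s_pn                     # ... and each KC
```

An ORN with zero importance (the third one here) contributes nothing downstream, regardless of its fanout.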
### Connectome-amplified ORN importance

We define a simple amplification term based on **two-hop KC reach**:

```
fanout_kc[orn] = sum_kc ( (B @ A)[kc, orn] )
amp_factor = fanout_kc / mean(fanout_kc)
connectome_amplified_importance = s_orn * amp_factor
```

This answers: *do the ORNs that matter in the GLM also sit in wiring positions that project broadly into KCs?*

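Continuing with the same toy matrices, the amplification term is a direct transcription of the formulas above:

```python
import numpy as np

s_orn = np.array([1.0, 0.5, 0.0])
A = np.array([[1.0, 0.0, 2.0],   # ORN→PN, (n_pn, n_orn)
              [0.0, 1.0, 1.0]])
B = np.array([[1.0, 1.0],        # PN→KC, (n_kc, n_pn)
              [2.0, 0.0]])

M = B @ A                        # two-hop ORN→KC reach, (n_kc, n_orn)
fanout_kc = M.sum(axis=0)        # total KC mass reachable from each ORN
amp_factor = fanout_kc / fanout_kc.mean()
amplified = s_orn * amp_factor   # connectome-amplified importance
```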
### Decision → Evidence → Implementation notes

- **Decision**: infer matrix orientation from shape compatibility (the shared PN dimension), not hardcoded assumptions.
- **Evidence**: exported connectivity artifacts may be stored as ORN×PN vs PN×ORN (and PN×KC vs KC×PN) depending on the pipeline.
- **Implementation**: `door_toolkit.pathways.connectome_analysis.orient_connectome()` orients matrices to `(PN, ORN)` and `(KC, PN)` by matching the shared PN dimension.

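One way to sketch that shared-dimension matching, assuming the shared PN size is unambiguous (`orient_pair` is an illustrative name, not the toolkit's function):

```python
import numpy as np

def orient_pair(orn_pn, pn_kc):
    """Return (A, B) with A as (n_pn, n_orn) and B as (n_kc, n_pn),
    trying both transposes until A's row count equals B's column count
    (the shared PN axis). Ambiguous square cases would need metadata."""
    for a in (orn_pn, orn_pn.T):
        for b in (pn_kc, pn_kc.T):
            if a.shape[0] == b.shape[1]:
                return a, b
    raise ValueError("no shared PN dimension found")

# stored as ORN×PN (3, 2) and KC×PN (4, 2): only A needs transposing
A, B = orient_pair(np.zeros((3, 2)), np.zeros((4, 2)))
```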
- **Decision**: require an explicit receptor-name mapping for ORN alignment.
- **Evidence**: the behavior-rate model uses a 55-ORN adult-only subset in encoder order; connectome matrices often include additional receptors in a different order.
- **Implementation**: `align_orn_connectome()` maps by name using `connectivity_metadata.json["receptor_names"]`; missing receptors raise a clear error (and auto-detected runs skip the analysis rather than producing misaligned outputs).

## How to run

```
python scripts/train_behavior_rate_model.py \
  --behavior_csv /path/to/reaction_rates_summary_unordered.csv \
  --output_dir outputs/behavior_rate_model \
  --epochs 500 --lr 1e-2 --l1_lambda 1e-3
```

Default behavior:
- trains on datasets starting with `opto_`
- excludes `opto_AIR` from training

To train on all datasets in the CSV, set `--train_prefix ''`.

docs/CONNECTOME_RNN_MODEL.md

Lines changed: 22 additions & 0 deletions
@@ -194,6 +194,28 @@ Report AUROC, AUPRC (average precision), and balanced accuracy. Not just raw acc

**Location**: [train_static_door_rnn.py:108-123](../scripts/train_static_door_rnn.py)

### Threshold Calibration (Validation → Test)

#### Decision
Choose a probability threshold on the **validation set only**, then apply it once on the held-out **test set**. Report test metrics at both:
1) fixed `threshold=0.5` (status quo), and
2) calibrated `threshold=thr_opt_from_val`.

#### Evidence
- **Imbalanced labels** can push predicted probabilities below 0.5, yielding **all-negative predictions** and `balanced_acc=0.5` even when AUROC is strong.
- Selecting a threshold on the test set is **data leakage** and inflates reported metrics.
- Comparing against baselines (e.g., logistic regression) requires a **fair thresholding protocol** when balanced accuracy is reported.

#### Implementation
- Collect `y_val_true`/`y_val_prob` after training; compute `thr_opt_from_val` by maximizing balanced accuracy over a deterministic midpoint grid (ties → lowest threshold).
- Apply `thr_opt_from_val` to `y_test_prob` once (no re-optimization on test).
- Save a self-contained evaluation artifact, `threshold_calibration_eval.json`, in the run output directory.
- Also embed the same summary block in `metrics.json` under `threshold_calibration_eval`.

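A minimal version of that midpoint-grid search (a sketch of the protocol, not the contents of `threshold_calibration.py`):

```python
import numpy as np

def balanced_accuracy(y_true, y_pred):
    """Mean of true-positive and true-negative rates (assumes both classes
    are present in y_true)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    tpr = (y_pred[y_true == 1] == 1).mean()
    tnr = (y_pred[y_true == 0] == 0).mean()
    return 0.5 * (tpr + tnr)

def calibrate_threshold(y_true, y_prob):
    """Scan midpoints between sorted unique probabilities (plus one grid
    point below the minimum); strict '>' keeps the lowest tied threshold."""
    y_prob = np.asarray(y_prob, dtype=float)
    u = np.unique(y_prob)
    grid = np.concatenate(([u[0] - 1e-9], (u[:-1] + u[1:]) / 2.0))
    best_thr, best_score = grid[0], -1.0
    for thr in grid:
        score = balanced_accuracy(y_true, (y_prob >= thr).astype(int))
        if score > best_score:
            best_thr, best_score = thr, score
    return float(best_thr), float(best_score)
```

The chosen threshold would then be applied to `y_test_prob` exactly once, with no re-search on the test split.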

**Location**:
- Threshold search + metrics helpers: [`src/door_toolkit/threshold_calibration.py`](../src/door_toolkit/threshold_calibration.py)
- Training integration + artifact write: [`scripts/train_static_door_rnn.py`](../scripts/train_static_door_rnn.py)

---

## 8. Provenance Tracking
