# Dataset-level PER Rate Model (Option 1)

This document describes the dataset-level (aggregated) PER rate model used to reproduce a behavioral matrix like `reaction_rates_summary_unordered.csv` (rows = datasets/conditions, columns = test odors, values in `[0,1]`).

## Key definitions

- **Dataset**: a training condition (e.g., `opto_hex`, `hex_control`, `opto_EB_6_training`).
- **Trained odor**: the odor paired with reward for `opto_*` datasets (or the matched odor for controls).
- **PER rate**: the aggregate probability of proboscis extension response for a (dataset, test odor) cell, computed as the mean of binary outcomes across flies in that condition. This is a **dataset-level summary statistic**, NOT per-fly dynamics or trial-by-trial probabilities.
- **Observed cell**: a cell in the behavioral matrix that had measurements (non-NaN in the input CSV).
- **Extrapolated cell**: a cell that was NaN in the input (no data) but for which the model can still predict from DoOR profiles.

## Model

We fit a sparse generalized linear model (GLM) on DoOR receptor-response features:

```
logit(p̂) = b + b_reward * reward_flag
         + w_test · x_test
         + w_train · x_train
         + w_int · (x_test ⊙ x_train)
         + w_diff · (x_test - x_train)   (optional)
```

Where:
- `x_test` = DoOR receptor-response vector for the test odor (length = #receptors, e.g., 78).
- `x_train` = DoOR receptor-response vector for the trained odor for that dataset.
- `reward_flag` = `1` if the dataset name starts with `opto_`, else `0`.
- `⊙` = elementwise product.
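
As a concrete illustration, the logit above can be evaluated for a single (dataset, test odor) cell. This is a minimal NumPy sketch with toy random profiles and weights (all values are illustrative, not fitted parameters):

```python
import numpy as np

rng = np.random.default_rng(0)
n_receptors = 55  # adult-only ORN subset

# Toy DoOR profiles and toy model parameters (illustrative only)
x_test = rng.random(n_receptors)
x_train = rng.random(n_receptors)
w_test = rng.normal(0, 0.1, n_receptors)
w_train = rng.normal(0, 0.1, n_receptors)
w_int = rng.normal(0, 0.1, n_receptors)
b, b_reward = -1.0, 0.5
reward_flag = 1.0  # dataset name starts with "opto_"

# logit(p̂) = b + b_reward*reward_flag + w_test·x_test + w_train·x_train + w_int·(x_test ⊙ x_train)
logit = (b + b_reward * reward_flag
         + w_test @ x_test
         + w_train @ x_train
         + w_int @ (x_test * x_train))
p_hat = 1.0 / (1.0 + np.exp(-logit))  # predicted PER rate in (0, 1)
```

The sigmoid maps any real-valued logit into a valid rate, so predictions stay in `(0, 1)` by construction.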

## Feature Construction

For each (dataset, test_odor) cell:

1. **Test profile** (`x_test`): DoOR receptor-response vector for the test odor.
2. **Trained profile** (`x_train`): DoOR vector for the trained odor (inferred from the dataset name).
3. **Interaction** (`x_test ⊙ x_train`): element-wise product (co-activation).
4. **Difference** (`x_test - x_train`, optional): directional contrast.
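
A minimal sketch of assembling these feature groups (plus the reward flag) into one design vector, using toy four-receptor profiles:

```python
import numpy as np

n = 4  # toy receptor count
x_test = np.array([0.2, 0.0, 0.8, 0.5])
x_train = np.array([0.1, 0.0, 0.9, 0.0])

interaction = x_test * x_train   # co-activation: large only where both profiles respond
difference = x_test - x_train    # directional contrast (optional)
reward_flag = np.array([1.0])    # opto_* dataset

features = np.concatenate([reward_flag, x_test, x_train, interaction, difference])
```

With all four groups enabled, the design vector has `1 + 4 * n_receptors` entries per cell.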

### Receptor Masking

- By default, uses the 55 adult-only ORNs from `data/mappings/training_receptor_set.json`.
- Excludes 23 larval/unmapped receptors.
- A manifest with details is saved to `adult_only_mask_manifest.json`.

### Odor Name Resolution

- Behavioral CSV labels (e.g., "Hexanol") → DoOR canonical names ("1-hexanol").
- Uses InChIKey-based lookup via candidate mappings plus fuzzy matching.
- All resolution decisions are logged to `odor_name_resolution.json`.
- Control stimuli (air, paraffin oil) are encoded as zero vectors (logged as `control_stimulus`).
- Unknown odors are encoded as zero vectors (logged as `NOT_FOUND`).
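
A toy sketch of the fuzzy-matching fallback only (the real resolver also uses InChIKey-based candidate mappings; the `DOOR_NAMES` list and cutoff here are illustrative):

```python
import difflib

DOOR_NAMES = ["1-hexanol", "ethyl butyrate", "benzaldehyde"]  # toy subset
CONTROL_STIMULI = {"air", "paraffin oil"}

def resolve_odor_name(csv_name: str) -> str:
    """Toy resolver: control check, then fuzzy match against DoOR names; NOT_FOUND otherwise."""
    name = csv_name.strip().lower()
    if name in CONTROL_STIMULI:
        return "control_stimulus"
    matches = difflib.get_close_matches(name, DOOR_NAMES, n=1, cutoff=0.6)
    return matches[0] if matches else "NOT_FOUND"
```

For example, `resolve_odor_name("Hexanol")` fuzzy-matches to `"1-hexanol"`, while an unrecognized label falls through to `"NOT_FOUND"` and would be encoded as a zero vector.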

## Decision → Evidence → Implementation notes

### BCE on fractional labels

- **Decision**: use binary cross-entropy (BCE) with logits on PER rates `y ∈ [0,1]`.
- **Evidence**: a PER rate is an aggregate mean of Bernoulli outcomes; BCE corresponds to the negative log-likelihood for Bernoulli targets, and fractional targets are a standard “soft label” relaxation.
- **Implementation**: training uses `torch.nn.functional.binary_cross_entropy_with_logits(logits, y)`.
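
A NumPy illustration of the quantity being minimized (training itself uses the PyTorch call above; the logits and rates here are toy values):

```python
import numpy as np

def bce_with_logits(logits, y):
    """Plain BCE on fractional targets: mean of -[y*log(p) + (1-y)*log(1-p)]."""
    p = 1.0 / (1.0 + np.exp(-logits))
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))

y = np.array([0.0, 0.25, 0.9])        # PER rates, not hard 0/1 labels
logits = np.array([-2.0, -1.0, 2.0])  # toy model outputs
loss = bce_with_logits(logits, y)
```

Because the targets are fractional, the minimum loss is the entropy of `y` rather than zero, but the gradient still pushes `sigmoid(logit)` toward the observed rate.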

### L1 (lasso-style) regularization

- **Decision**: apply L1 regularization to the ORN weight vectors to encourage sparsity and interpretability.
- **Evidence**: sparsity yields simpler receptor circuits and reduces sensitivity to correlated features.
- **Implementation**: `loss = BCE + λ * (|w_test|_1 + |w_train|_1 + |w_int|_1 + |w_diff|_1)`.
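
The combined objective can be sketched as follows (the weight values and `λ` below are toy numbers, not fitted ones):

```python
import numpy as np

def total_loss(bce, weights, l1_lambda=1e-3):
    """BCE plus an L1 penalty summed over all ORN weight vectors (lasso-style sparsity)."""
    l1 = sum(np.abs(w).sum() for w in weights)
    return bce + l1_lambda * l1

w_test = np.array([0.5, 0.0, -0.2])
w_train = np.array([0.0, 0.1, 0.0])
w_int = np.array([0.0, 0.0, 0.3])
loss = total_loss(bce=0.4, weights=[w_test, w_train, w_int])
```

The L1 term is what drives many weights exactly to zero, leaving a small set of receptors to interpret.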

### Ablation-based importance (preferred)

- **Decision**: prioritize ablation-based ORN importance for experiment planning.
- **Evidence**: correlated inputs can make weight magnitudes misleading; ablation measures the sensitivity of model fit in the data regime.
- **Implementation**: for each ORN channel `i`, set channel `i` to 0 in both `x_test` and `x_train` (and derived features), recompute BCE on the full observed training table, and record `ΔBCE = BCE_ablated - BCE_baseline`.
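
A toy version of the ablation loop, using a reduced model without the interaction and difference terms (data and weights are random placeholders):

```python
import numpy as np

rng = np.random.default_rng(1)
n_cells, n_orn = 20, 5
X_test = rng.random((n_cells, n_orn))
X_train = rng.random((n_cells, n_orn))
y = rng.random(n_cells)
w_test = rng.normal(0, 0.5, n_orn)
w_train = rng.normal(0, 0.5, n_orn)

def bce(X_t, X_tr):
    logits = X_t @ w_test + X_tr @ w_train
    p = 1.0 / (1.0 + np.exp(-logits))
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))

baseline = bce(X_test, X_train)
delta_bce = {}
for i in range(n_orn):
    Xt, Xtr = X_test.copy(), X_train.copy()
    Xt[:, i] = 0.0   # silence channel i in both test and trained profiles
    Xtr[:, i] = 0.0
    delta_bce[i] = bce(Xt, Xtr) - baseline
```

A large positive `ΔBCE` means the fit degrades without that channel, i.e., the receptor carries information the model actually uses.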

### Trained-odor resolution

- **Decision**: resolve trained odors via an explicit mapping first, then deterministic parsing.
- **Evidence**: dataset naming conventions are not fully standardized (e.g., `opto_benz_1`).
- **Implementation**: `door_toolkit.pathways.behavior_rate_model.resolve_trained_odor_for_dataset()` uses a default mapping and common-token parsing; unknown datasets raise a clear error unless you provide an override mapping.
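
A toy resolver mirroring the mapping-then-parse order (the maps below are illustrative, not the shipped defaults):

```python
# Explicit mapping wins; token parsing is the deterministic fallback.
EXPLICIT_MAP = {"opto_benz_1": "benzaldehyde"}
TOKEN_MAP = {"hex": "1-hexanol", "EB": "ethyl butyrate"}

def resolve_trained_odor(dataset: str) -> str:
    if dataset in EXPLICIT_MAP:
        return EXPLICIT_MAP[dataset]
    tokens = dataset.split("_")
    for token, odor in TOKEN_MAP.items():
        if token in tokens:
            return odor
    raise ValueError(f"Cannot resolve trained odor for dataset {dataset!r}")
```

Raising on unknown names (rather than guessing) keeps a mislabeled dataset from silently training against the wrong odor profile.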

### Input validation

- **Decision**: auto-detect wide vs long CSV format and validate rate ranges.
- **Evidence**: user CSVs may use percentage `[0,100]` or fractional `[0,1]` rates; explicit validation prevents silent errors.
- **Implementation**: `load_behavior_matrix_csv()` detects the format, rescales if needed (with a warning), and logs metadata to `input_format_metadata.json`.
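
A sketch of the detect-and-rescale logic; the column names `dataset`/`test_odor`/`rate` are assumptions for this example, and the shipped loader may differ:

```python
import numpy as np
import pandas as pd

def load_rates(df: pd.DataFrame) -> pd.DataFrame:
    """Toy loader: melt wide -> long if needed, rescale [0,100] -> [0,1], then validate."""
    if {"dataset", "test_odor", "rate"}.issubset(df.columns):
        long_df = df.copy()  # already long format
    else:
        long_df = df.melt(id_vars="dataset", var_name="test_odor", value_name="rate")
    observed = long_df["rate"].dropna()
    if len(observed) and observed.max() > 1.0:  # looks like percentages
        long_df["rate"] = long_df["rate"] / 100.0
    assert long_df["rate"].dropna().between(0, 1).all()  # validate final range
    return long_df

wide = pd.DataFrame({"dataset": ["opto_hex"], "Hexanol": [80.0], "EB": [np.nan]})
out = load_rates(wide)
```

NaN cells survive the melt untouched, which is what lets the pipeline distinguish observed from extrapolated cells later.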

### Learning-effect computation

- **Decision**: `ΔPER = PER_opto - PER_control` for matched opto/control pairs.
- **Evidence**: the learning effect is the causal difference between optogenetic stimulation and matched control, so it requires exact dataset pairing.
- **Implementation**: `compute_learning_effect_metrics()` merges predictions for matched pairs (e.g., `opto_hex` vs `hex_control`), computes the delta for both `y_true` and `y_pred`, and reports per-odor errors.
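
The pairing-and-delta step can be sketched with a pandas merge (toy numbers; column names follow the long-format prediction table described under Outputs):

```python
import pandas as pd

pred = pd.DataFrame({
    "dataset": ["opto_hex", "opto_hex", "hex_control", "hex_control"],
    "test_odor": ["Hexanol", "EB", "Hexanol", "EB"],
    "y_true": [0.8, 0.3, 0.2, 0.25],
    "y_pred": [0.75, 0.35, 0.25, 0.2],
})

# Align the matched pair on test odor, then take opto - control per odor.
opto = pred[pred["dataset"] == "opto_hex"]
ctrl = pred[pred["dataset"] == "hex_control"]
merged = opto.merge(ctrl, on="test_odor", suffixes=("_opto", "_ctrl"))
merged["delta_per_true"] = merged["y_true_opto"] - merged["y_true_ctrl"]
merged["delta_per_pred"] = merged["y_pred_opto"] - merged["y_pred_ctrl"]
merged["error"] = (merged["delta_per_pred"] - merged["delta_per_true"]).abs()
```

The inner merge drops any odor tested in only one of the two conditions, which enforces the exact-pairing requirement.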

## Outputs

The training script writes artifacts under `--output_dir`:

**Core predictions:**
- `predicted_rates.csv`: **long format** with observed cells only. Columns: `[dataset, test_odor, y_true, y_pred, split, trained_odor, trained_odor_door, test_odor_door]`. `split` = "observed" for all rows; `y_true` = the original measured PER rate (no NaN); `y_pred` = the model prediction.
- `extrapolated_predictions.csv`: **long format** with extrapolated cells (originally NaN in the input). Same columns, but `split` = "extrapolated" and `y_true` = NaN.

**Metrics:**
- `metrics.json`: global fit metrics (BCE, R², Pearson r) on observed cells.
- `metrics_per_dataset.csv`: breakdown by dataset (one row per dataset): `[dataset, n_cells, bce, r2, pearson_r, mae, trained_odor]`.
- `metrics_per_test_odor.csv`: breakdown by test odor: `[test_odor, n_cells, bce, r2, pearson_r, mae, door_name]`.
- `learning_effect_metrics.csv`: `ΔPER = opto - control` for matched pairs: `[test_odor, delta_per_true, delta_per_pred, error, opto_dataset, control_dataset]` (only if matched controls exist).
- `predicted_learning_effect_opto_minus_control.csv`: old wide-format learning-effect table (kept for backward compatibility).

**Receptor importance:**
- `orn_importance_global.csv`: combined weight-based and ablation-based importance.
- `orn_importance_by_dataset.csv`: ablation ΔBCE per dataset.
- `orn_importance_by_test_odor.csv`: ablation ΔBCE per test odor.
- `ablation_sweep.csv`: top-K receptors with learning-effect deltas.

**Connectome-aware interpretation (optional):**
- `connectome_analysis/connectome_inputs.json`: paths, hashes, shapes, and orientation/alignment reports used for connectome propagation.
- `connectome_analysis/orn_connectome_amplified_importance.csv`: ORN-level base importance + fanout metrics + connectome-amplified importance (ranked).
- `connectome_analysis/pn_influence.csv`: top PNs by propagated influence.
- `connectome_analysis/kc_influence.csv`: top KCs by propagated influence.
- `connectome_analysis/connectome_summary.json`: concentration metrics (top-K fractions, Gini-like summary) and top ORNs/PNs/KCs.
- `connectome_analysis/orn_to_pn_top_edges.csv`: (optional) ORN→PN edge list for the top ORNs by amplified importance.

**Provenance & logging:**
- `run_config.json`: CLI args, git SHA, hyperparameters.
- `input_format_metadata.json`: CSV format detection and rescaling log.
- `odor_name_resolution.json`: all odor name mappings (`csv_name` → `door_name` or `NOT_FOUND`).
- `adult_only_mask_manifest.json`: receptor mask details (55 adult ORNs kept, 23 excluded).
- `training_loss.csv`: training curve for diagnostics.

## Connectome-aware interpretation (post-training)

This repository also contains FlyWire-derived connectivity artifacts for ORN→PN and PN→KC. The behavior-rate model's training objective is unchanged; after training, we can **optionally** propagate the learned ORN importance through these matrices to obtain a first-order downstream readout.

### Inputs

Expected files (from `scripts/extract_flywire_connectivity.py`):
- `orn_pn_connectivity.pt`
- `pn_kc_connectivity.pt`
- `connectivity_metadata.json` (required for ORN ordering alignment; must include `receptor_names`)

By default, the training script auto-detects connectivity under:
- `data/pgcn_features/connectivity/` (preferred)
- `data/connectivity/`

Override with `--connectome_dir`, or skip with `--disable_connectome_analysis`.

### Propagation definition

Let:
- `s_orn` be the non-negative ORN importance vector from the trained GLM weights, where `s_orn[i] = |w_test[i]| + |w_train[i]| + |w_int[i]| (+ |w_diff[i]| if enabled)`.
- `A` be the ORN→PN matrix, oriented as `(n_pn, n_orn)` after detecting whether the stored file is ORN×PN or PN×ORN.
- `B` be the PN→KC matrix, oriented as `(n_kc, n_pn)` after detecting whether the stored file is PN×KC or KC×PN.

We compute:

```
s_pn = A @ s_orn
s_kc = B @ s_pn
```

These are not firing rates or dynamics; they are a linear "influence mass" readout under the assumption that larger ORN importance combined with stronger fanout yields a larger downstream footprint.
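
The two matrix-vector products above, with toy shapes in place of the real ORN/PN/KC counts:

```python
import numpy as np

rng = np.random.default_rng(2)
n_orn, n_pn, n_kc = 6, 4, 10  # toy sizes, not the real circuit

s_orn = np.abs(rng.normal(size=n_orn))  # non-negative ORN importance
A = rng.random((n_pn, n_orn))           # ORN->PN, oriented (n_pn, n_orn)
B = rng.random((n_kc, n_pn))            # PN->KC, oriented (n_kc, n_pn)

s_pn = A @ s_orn
s_kc = B @ s_pn  # equivalently (B @ A) @ s_orn
```

Because the propagation is linear, the two-hop readout factors through `B @ A`, which is also what the amplification term below uses.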

### Connectome-amplified ORN importance

We define a simple amplification term based on **two-hop KC reach**:

```
fanout_kc[orn] = sum_kc ( (B @ A)[kc, orn] )
amp_factor = fanout_kc / mean(fanout_kc)
connectome_amplified_importance = s_orn * amp_factor
```

This answers: *do the ORNs that matter in the GLM also sit in wiring positions that project broadly into KCs?*
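
The amplification on the same toy shapes (random placeholder matrices, not real connectivity):

```python
import numpy as np

rng = np.random.default_rng(3)
n_orn, n_pn, n_kc = 6, 4, 10
s_orn = np.abs(rng.normal(size=n_orn))
A = rng.random((n_pn, n_orn))
B = rng.random((n_kc, n_pn))

two_hop = B @ A                   # (n_kc, n_orn): ORN -> KC reach through PNs
fanout_kc = two_hop.sum(axis=0)   # total KC reach per ORN
amp_factor = fanout_kc / fanout_kc.mean()  # mean-normalized, so average factor is 1
amplified = s_orn * amp_factor
```

Mean-normalizing the fanout keeps the amplified scores on the same overall scale as `s_orn`, so ranking changes reflect wiring, not a global rescale.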

### Decision → Evidence → Implementation notes

- **Decision**: infer matrix orientation from shape compatibility (the shared PN dimension), not hardcoded assumptions.
  - **Evidence**: exported connectivity artifacts may be stored as ORN×PN or PN×ORN (and PN×KC or KC×PN), depending on the pipeline.
  - **Implementation**: `door_toolkit.pathways.connectome_analysis.orient_connectome()` orients matrices to `(PN, ORN)` and `(KC, PN)` by matching the shared PN dimension.

- **Decision**: require an explicit receptor-name mapping for ORN alignment.
  - **Evidence**: the behavior-rate model uses a 55-ORN adult-only subset in encoder order; connectome matrices often include additional receptors in a different order.
  - **Implementation**: `align_orn_connectome()` maps by name using `connectivity_metadata.json["receptor_names"]`; missing receptors raise a clear error (and auto-detected runs skip the analysis rather than producing misaligned outputs).
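
A simplified stand-in for the orientation idea (the actual `orient_connectome()` implementation may differ; this sketch assumes the two dimensions are distinct so the check is unambiguous):

```python
import numpy as np

def orient_pn_by_orn(M: np.ndarray, n_orn: int) -> np.ndarray:
    """Toy orientation check: return M as (n_pn, n_orn), transposing if stored ORN-by-PN."""
    if M.shape[1] == n_orn:
        return M            # already (n_pn, n_orn)
    if M.shape[0] == n_orn:
        return M.T          # stored as (n_orn, n_pn); transpose
    raise ValueError(f"No axis of shape {M.shape} matches n_orn={n_orn}")

stored = np.zeros((55, 120))  # toy file stored ORN-by-PN
oriented = orient_pn_by_orn(stored, n_orn=55)
```

Failing loudly when neither axis matches is the same philosophy as the alignment step: a wrong-but-plausible orientation would silently corrupt every downstream influence score.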

## How to run

```
python scripts/train_behavior_rate_model.py \
  --behavior_csv /path/to/reaction_rates_summary_unordered.csv \
  --output_dir outputs/behavior_rate_model \
  --epochs 500 --lr 1e-2 --l1_lambda 1e-3
```

Default behavior:
- trains on datasets starting with `opto_`
- excludes `opto_AIR` from training

To train on all datasets in the CSV, set `--train_prefix ''`.