diff --git a/TODO.md b/TODO.md index f6231b1..e33542c 100644 --- a/TODO.md +++ b/TODO.md @@ -52,9 +52,9 @@ Deferred items from PR reviews that were not addressed before merge. | ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium (deferred — only triggers when sparse solver fails; fixing requires sparse least-squares alternatives) | | EfficientDiD: API docs / tutorial page for new public estimator | `docs/` | #192 | Medium | | Multi-absorb weighted demeaning needs iterative alternating projections for N > 1 absorbed FE with survey weights; unweighted multi-absorb also uses single-pass (pre-existing, exact only for balanced panels) | `estimators.py` | #218 | Medium | -| CallawaySantAnna survey: strata/PSU/FPC rejected at runtime. Full design-based SEs require routing the combined IF/WIF through `compute_survey_vcov()`. Currently weights-only. | `staggered.py` | #233 | Medium | +| CallawaySantAnna survey: strata/PSU/FPC — **Resolved**. Aggregated SEs now use `compute_survey_if_variance()`. Bootstrap uses PSU-level multiplier weights. | `staggered.py` | #233 | Resolved | | CallawaySantAnna survey + covariates + IPW/DR: DRDID panel nuisance-estimation IF corrections not implemented. Currently gated with NotImplementedError. Regression method with covariates works (has WLS nuisance IF correction). | `staggered.py` | #233 | Medium | -| SyntheticDiD/TROP survey: strata/PSU/FPC deferred. Full design-based bootstrap (Rao-Wu rescaled weights) needed for survey-aware resampling. Currently pweight-only. | `synthetic_did.py`, `trop.py` | — | Medium | +| SyntheticDiD/TROP survey: strata/PSU/FPC — **Resolved**. Rao-Wu rescaled bootstrap implemented for both. TROP uses cross-classified pseudo-strata. Rust TROP remains pweight-only (Python fallback for full design). | `synthetic_did.py`, `trop.py` | — | Resolved | | EfficientDiD hausman_pretest() clustered covariance uses stale `n_cl` after filtering non-finite EIF rows — should recompute effective cluster count and remap indices after `row_finite` filtering | `efficient_did.py` | #230 | Medium | | EfficientDiD `control_group="last_cohort"` trims at `last_g - anticipation` but REGISTRY says `t >= last_g`. With `anticipation=0` (default) these are identical. With `anticipation>0`, code is arguably more conservative (excludes anticipation-contaminated periods). Either align REGISTRY with code or change code to `t < last_g` — needs design decision. | `efficient_did.py` | #230 | Low | | TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. | `prep_dgp.py`, `power.py` | #208 | Low | diff --git a/diff_diff/_backend.py b/diff_diff/_backend.py index 2065e17..b3931b7 100644 --- a/diff_diff/_backend.py +++ b/diff_diff/_backend.py @@ -13,7 +13,7 @@ # Check for backend override via environment variable # DIFF_DIFF_BACKEND can be: 'auto' (default), 'python', or 'rust' -_backend_env = os.environ.get('DIFF_DIFF_BACKEND', 'auto').lower() +_backend_env = os.environ.get("DIFF_DIFF_BACKEND", "auto").lower() # Try to import Rust backend for accelerated operations try: @@ -38,6 +38,7 @@ # Diagnostics rust_backend_info as _rust_backend_info, ) + _rust_available = True except ImportError: _rust_available = False @@ -61,7 +62,7 @@ _rust_backend_info = None # Determine final backend based on environment variable and availability -if _backend_env == 'python': +if _backend_env == "python": # Force pure Python mode - disable Rust even if available HAS_RUST_BACKEND = False _rust_bootstrap_weights = None @@ -82,7 +83,7 @@ _rust_compute_noise_level = None _rust_sc_weight_fw = None _rust_backend_info = None -elif _backend_env == 'rust': +elif _backend_env == "rust": # Force Rust mode - fail if not available if not _rust_available: raise ImportError( @@ -111,23 +112,23 @@ def rust_backend_info(): __all__ = [ - 'HAS_RUST_BACKEND', - 'rust_backend_info', - '_rust_bootstrap_weights', - '_rust_synthetic_weights', - '_rust_project_simplex', - '_rust_solve_ols', - '_rust_compute_robust_vcov', + "HAS_RUST_BACKEND", + "rust_backend_info", + "_rust_bootstrap_weights", + "_rust_synthetic_weights", + "_rust_project_simplex", + "_rust_solve_ols", + "_rust_compute_robust_vcov", # TROP estimator acceleration (local method) - '_rust_unit_distance_matrix', - '_rust_loocv_grid_search', - '_rust_bootstrap_trop_variance', + "_rust_unit_distance_matrix", + "_rust_loocv_grid_search", + "_rust_bootstrap_trop_variance", # TROP estimator acceleration (global method) - '_rust_loocv_grid_search_global', - '_rust_bootstrap_trop_variance_global', + "_rust_loocv_grid_search_global", + "_rust_bootstrap_trop_variance_global", # SDID weights (Frank-Wolfe matching R's synthdid) - '_rust_sdid_unit_weights', - '_rust_compute_time_weights', - '_rust_compute_noise_level', - '_rust_sc_weight_fw', + "_rust_sdid_unit_weights", + "_rust_compute_time_weights", + "_rust_compute_noise_level", + "_rust_sc_weight_fw", ] diff --git a/diff_diff/bootstrap_utils.py b/diff_diff/bootstrap_utils.py index 7115692..1cb164e 100644 --- a/diff_diff/bootstrap_utils.py +++ b/diff_diff/bootstrap_utils.py @@ -16,6 +16,9 @@ "generate_bootstrap_weights", "generate_bootstrap_weights_batch", "generate_bootstrap_weights_batch_numpy", + "generate_survey_multiplier_weights_batch", + "generate_rao_wu_weights", + "generate_rao_wu_weights_batch", "compute_percentile_ci", "compute_bootstrap_pvalue", "compute_effect_bootstrap_stats", @@ -54,15 +57,20 @@ def generate_bootstrap_weights( p1 = (sqrt5 + 1) / (2 * sqrt5) return rng.choice([val1, val2], size=n_units, p=[p1, 1 - p1]) elif weight_type == "webb": - values = np.array([ - -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2), - np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2) - ]) + values = np.array( + [ + -np.sqrt(3 / 2), + -np.sqrt(2 / 2), + -np.sqrt(1 / 2), + np.sqrt(1 / 2), + np.sqrt(2 / 2), + np.sqrt(3 / 2), + ] + ) return rng.choice(values, size=n_units) else: raise ValueError( - f"weight_type must be 'rademacher', 'mammen', or 'webb', " - f"got '{weight_type}'" + f"weight_type must be 'rademacher', 'mammen', or 'webb', " f"got '{weight_type}'" ) @@ -133,15 +141,20 @@ def generate_bootstrap_weights_batch_numpy( p1 = (sqrt5 + 1) / (2 * sqrt5) return rng.choice([val1, val2], size=(n_bootstrap, n_units), p=[p1, 1 - p1]) elif weight_type == "webb": - values = np.array([ - -np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2), - np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2) - ]) + values = np.array( + [ + -np.sqrt(3 / 2), + -np.sqrt(2 / 2), + -np.sqrt(1 / 2), + np.sqrt(1 / 2), + np.sqrt(2 / 2), + np.sqrt(3 / 2), + ] + ) return rng.choice(values, size=(n_bootstrap, n_units)) else: raise ValueError( - f"weight_type must be 'rademacher', 'mammen', or 'webb', " - f"got '{weight_type}'" + f"weight_type must be 'rademacher', 'mammen', or 'webb', " f"got '{weight_type}'" ) @@ -274,9 +287,7 @@ def compute_effect_bootstrap_stats( return np.nan, (np.nan, np.nan), np.nan ci = compute_percentile_ci(valid_dist, alpha) - p_value = compute_bootstrap_pvalue( - original_effect, valid_dist, n_valid=len(valid_dist) - ) + p_value = compute_bootstrap_pvalue(original_effect, valid_dist, n_valid=len(valid_dist)) return se, ci, p_value @@ -392,8 +403,7 @@ def compute_effect_bootstrap_stats_batch( if np.any(partial_valid): for j in np.where(partial_valid)[0]: se, ci, pv = compute_effect_bootstrap_stats( - original_effects[j], bootstrap_matrix[:, j], alpha=alpha, - context=f"effect {j}" + original_effects[j], bootstrap_matrix[:, j], alpha=alpha, context=f"effect {j}" ) ses[j] = se ci_lowers[j] = ci[0] @@ -401,3 +411,248 @@ def compute_effect_bootstrap_stats_batch( p_values[j] = pv return ses, ci_lowers, ci_uppers, p_values + + +# --------------------------------------------------------------------------- +# Survey-aware bootstrap weight generators +# --------------------------------------------------------------------------- + + +def generate_survey_multiplier_weights_batch( + n_bootstrap: int, + resolved_survey: "ResolvedSurveyDesign", + weight_type: str, + rng: np.random.Generator, +) -> Tuple[np.ndarray, np.ndarray]: + """Generate PSU-level multiplier weights for survey-aware bootstrap. + + Within each stratum, weights are generated independently. When FPC + is present, weights are scaled by ``sqrt(1 - f_h)`` per stratum so + the bootstrap variance matches the TSL variance. + + Parameters + ---------- + n_bootstrap : int + Number of bootstrap iterations. + resolved_survey : ResolvedSurveyDesign + Resolved survey design. + weight_type : str + Multiplier distribution: ``"rademacher"``, ``"mammen"``, or ``"webb"``. + rng : np.random.Generator + Random number generator. + + Returns + ------- + weights : np.ndarray + Multiplier weights, shape ``(n_bootstrap, n_psu)``. + psu_ids : np.ndarray + Unique PSU identifiers aligned to columns of *weights*. + """ + psu = resolved_survey.psu + strata = resolved_survey.strata + + if resolved_survey.lonely_psu == "adjust": + raise NotImplementedError( + "lonely_psu='adjust' is not yet supported for survey-aware bootstrap. " + "Use lonely_psu='remove' or 'certainty', or use analytical inference." + ) + + if psu is None: + # Each observation is its own PSU + n_psu = len(resolved_survey.weights) + psu_ids = np.arange(n_psu) + else: + psu_ids = np.unique(psu) + n_psu = len(psu_ids) + + if strata is None: + # No stratification — generate a single block of weights + if n_psu < 2: + # Single PSU — variance unidentified (matches compute_survey_vcov) + weights = np.zeros((n_bootstrap, n_psu), dtype=np.float64) + return weights, psu_ids + weights = generate_bootstrap_weights_batch(n_bootstrap, n_psu, weight_type, rng) + # FPC scaling (unstratified) + if resolved_survey.fpc is not None: + if psu is not None: + n_units_for_fpc = n_psu + else: + n_units_for_fpc = len(resolved_survey.weights) + if resolved_survey.fpc[0] < n_units_for_fpc: + raise ValueError( + f"FPC ({resolved_survey.fpc[0]}) is less than the number of PSUs " + f"({n_units_for_fpc}). FPC must be >= number of PSUs." + ) + f = n_units_for_fpc / resolved_survey.fpc[0] + if f < 1.0: + weights = weights * np.sqrt(1.0 - f) + else: + weights = np.zeros_like(weights) + else: + # Stratified — generate independently within strata + weights = np.empty((n_bootstrap, n_psu), dtype=np.float64) + + # Build PSU → column-index map + psu_to_col = {int(p): i for i, p in enumerate(psu_ids)} + + unique_strata = np.unique(strata) + for h in unique_strata: + mask_h = strata == h + + if psu is not None: + psus_in_h = np.unique(psu[mask_h]) + else: + psus_in_h = np.where(mask_h)[0] + + n_h = len(psus_in_h) + cols = np.array([psu_to_col[int(p)] for p in psus_in_h]) + + if n_h < 2: + # Lonely PSU — zero weight (matches remove/certainty behavior) + weights[:, cols] = 0.0 + continue + + # Generate weights for this stratum + stratum_weights = generate_bootstrap_weights_batch_numpy( + n_bootstrap, n_h, weight_type, rng + ) + + # FPC scaling + if resolved_survey.fpc is not None: + N_h = resolved_survey.fpc[mask_h][0] + if N_h < n_h: + raise ValueError( + f"FPC ({N_h}) is less than the number of PSUs " + f"({n_h}) in stratum {h}. FPC must be >= n_PSU." + ) + f_h = n_h / N_h + if f_h < 1.0: + stratum_weights = stratum_weights * np.sqrt(1.0 - f_h) + else: + stratum_weights = np.zeros_like(stratum_weights) + + weights[:, cols] = stratum_weights + + return weights, psu_ids + + +def generate_rao_wu_weights( + resolved_survey: "ResolvedSurveyDesign", + rng: np.random.Generator, +) -> np.ndarray: + """Generate one set of Rao-Wu (1988) rescaled observation weights. + + Within each stratum *h* with *n_h* PSUs, draw ``m_h`` PSUs with + replacement and rescale observation weights by ``(n_h / m_h) * r_hi`` + where ``r_hi`` is the count of PSU *i* being selected. + + Without FPC: ``m_h = n_h - 1``. + With FPC: ``m_h = max(1, round((1 - f_h) * (n_h - 1)))`` + (Rao, Wu & Yue 1992, Section 3). + + Parameters + ---------- + resolved_survey : ResolvedSurveyDesign + Resolved survey design. + rng : np.random.Generator + Random number generator. + + Returns + ------- + np.ndarray + Rescaled observation weights, shape ``(n_obs,)``. + """ + n_obs = len(resolved_survey.weights) + base_weights = resolved_survey.weights + psu = resolved_survey.psu + strata = resolved_survey.strata + + if resolved_survey.lonely_psu == "adjust": + raise NotImplementedError( + "lonely_psu='adjust' is not yet supported for survey-aware bootstrap. " + "Use lonely_psu='remove' or 'certainty', or use analytical inference." + ) + + rescaled = np.zeros(n_obs, dtype=np.float64) + + if psu is None: + obs_psu = np.arange(n_obs) + else: + obs_psu = psu + + if strata is None: + strata_masks = [np.ones(n_obs, dtype=bool)] + else: + unique_strata = np.unique(strata) + strata_masks = [strata == h for h in unique_strata] + + for mask_h in strata_masks: + psu_h = obs_psu[mask_h] + unique_psu_h = np.unique(psu_h) + n_h = len(unique_psu_h) + + if n_h < 2: + # Census / lonely PSU — keep original weights (zero variance) + rescaled[mask_h] = base_weights[mask_h] + continue + + # Compute resample size + if resolved_survey.fpc is not None: + N_h = resolved_survey.fpc[mask_h][0] + if N_h < n_h: + raise ValueError( + f"FPC ({N_h}) is less than the number of PSUs " + f"({n_h}). FPC must be >= number of PSUs." + ) + f_h = n_h / N_h + m_h = max(1, round((1.0 - f_h) * (n_h - 1))) + else: + m_h = n_h - 1 + + if m_h == 0: + # Full census — keep original weights + rescaled[mask_h] = base_weights[mask_h] + continue + + # Draw m_h PSUs with replacement + drawn_indices = rng.choice(n_h, size=m_h, replace=True) + counts = np.bincount(drawn_indices, minlength=n_h) + + # Rescale factor per PSU: (n_h / m_h) * r_hi + scale_per_psu = (n_h / m_h) * counts.astype(np.float64) + + # Map PSU → local index for vectorized application + psu_to_local = {int(p): i for i, p in enumerate(unique_psu_h)} + obs_in_h = np.where(mask_h)[0] + local_indices = np.array([psu_to_local[int(obs_psu[idx])] for idx in obs_in_h]) + rescaled[obs_in_h] = base_weights[obs_in_h] * scale_per_psu[local_indices] + + return rescaled + + +def generate_rao_wu_weights_batch( + n_bootstrap: int, + resolved_survey: "ResolvedSurveyDesign", + rng: np.random.Generator, +) -> np.ndarray: + """Generate multiple sets of Rao-Wu rescaled weights. + + Parameters + ---------- + n_bootstrap : int + Number of bootstrap iterations. + resolved_survey : ResolvedSurveyDesign + Resolved survey design. + rng : np.random.Generator + Random number generator. + + Returns + ------- + np.ndarray + Rescaled weights, shape ``(n_bootstrap, n_obs)``. + """ + n_obs = len(resolved_survey.weights) + result = np.empty((n_bootstrap, n_obs), dtype=np.float64) + for b in range(n_bootstrap): + result[b] = generate_rao_wu_weights(resolved_survey, rng) + return result diff --git a/diff_diff/continuous_did.py b/diff_diff/continuous_did.py index f4c242e..ef581ea 100644 --- a/diff_diff/continuous_did.py +++ b/diff_diff/continuous_did.py @@ -212,12 +212,7 @@ def fit( if resolved_survey is not None: _validate_unit_constant_survey(data, unit, survey_design) - # Guard: bootstrap + survey not yet supported - if self.n_bootstrap > 0 and resolved_survey is not None: - raise NotImplementedError( - "Multiplier bootstrap with survey weights is planned for Phase 5. " - "Use n_bootstrap=0 with survey_design for design-based standard errors." - ) + # Bootstrap + survey supported via PSU-level multiplier bootstrap. df = data.copy() for col in [outcome, unit, time, first_treat, dose]: @@ -465,6 +460,7 @@ def fit( agg_att_d, agg_acrt_d, event_study_effects, + resolved_survey=resolved_survey, ) att_d_se = boot_result["att_d_se"] att_d_ci_lower = boot_result["att_d_ci_lower"] @@ -630,11 +626,13 @@ def fit( mu_0 = b_info["mu_0"] delta_y_treated = b_info["delta_y_treated"] ee_control = b_info["ee_control"] + sw_treated = b_info.get("w_treated_arr") for k, uid in enumerate(treated_idx): - if_es[uid] += ( - w * (delta_y_treated[k] - att_glob_gt - mu_0) / p_1 / n_total_gt - ) + score_k = delta_y_treated[k] - att_glob_gt - mu_0 + if sw_treated is not None: + score_k = sw_treated[k] * score_k + if_es[uid] += w * score_k / p_1 / n_total_gt for k, uid in enumerate(control_idx): if_es[uid] -= w * ee_control[k] / p_0 / n_total_gt @@ -978,10 +976,19 @@ def _compute_dose_response_gt( bread = np.linalg.pinv(PtP / n_treated) # ee_treated: per-unit estimating equation vectors (K-vector per unit) - ee_treated = Psi * residuals[:, np.newaxis] # (n_treated, K) + # For WLS (survey weights), the score is w_i * X_i * u_i to match the + # weighted bread inv(X'WX / sum(w)). Without this factor the sandwich + # is inconsistent. For OLS (no survey weights), the score is X_i * u_i. + if w_treated is not None: + ee_treated = Psi * (w_treated * residuals)[:, np.newaxis] # (n_treated, K) + else: + ee_treated = Psi * residuals[:, np.newaxis] # (n_treated, K) - # ee_control: per-unit deviation from control mean - ee_control = delta_y_control - mu_0 # (n_control,) + # ee_control: per-unit deviation from control mean (weighted for WLS) + if w_control is not None: + ee_control = w_control * (delta_y_control - mu_0) # (n_control,) + else: + ee_control = delta_y_control - mu_0 # (n_control,) # psi_bar: mean basis vector for treated (weighted when survey) if w_treated is not None: @@ -1021,10 +1028,12 @@ def _compute_dose_response_gt( "acrt_glob": acrt_glob, } - # Store survey-weighted masses for IF linearization + # Store survey-weighted masses and per-unit arrays for IF linearization if w_treated is not None: bootstrap_info["w_treated"] = float(np.sum(w_treated)) bootstrap_info["w_control"] = float(np.sum(w_control)) + bootstrap_info["w_treated_arr"] = w_treated + bootstrap_info["w_control_arr"] = w_control return { "att_d": att_d, @@ -1127,14 +1136,23 @@ def _compute_analytical_se( att_glob_gt = info["att_glob"] mu_0 = info["mu_0"] delta_y_treated = info["delta_y_treated"] + # Per-unit survey weight array (None when no survey) + sw_treated = info.get("w_treated_arr") n_total = n_t + n_c p_1 = n_t / n_total p_0 = n_c / n_total # IF for ATT_glob (binarized DiD) + # When survey weights are present, each unit's score includes its + # survey weight w_k so the sandwich is consistent with the weighted + # estimand. ee_control already contains the w_k factor (set in + # _compute_dose_response_gt); delta_y_treated needs it here. for k, idx in enumerate(treated_idx): - if_att_glob[idx] += w * (delta_y_treated[k] - att_glob_gt - mu_0) / p_1 / n_total + score_k = delta_y_treated[k] - att_glob_gt - mu_0 + if sw_treated is not None: + score_k = sw_treated[k] * score_k + if_att_glob[idx] += w * score_k / p_1 / n_total for k, idx in enumerate(control_idx): if_att_glob[idx] -= w * ee_control[k] / p_0 / n_total @@ -1265,6 +1283,7 @@ def _run_bootstrap( original_att_d: np.ndarray, original_acrt_d: np.ndarray, event_study_effects: Optional[Dict[int, Dict]], + resolved_survey: object = None, ) -> Dict[str, Any]: """Run multiplier bootstrap inference.""" if self.n_bootstrap < 50: @@ -1279,11 +1298,62 @@ def _run_bootstrap( n_units = precomp["n_units"] n_grid = len(dvals) - # Generate all weights upfront - all_weights = generate_bootstrap_weights_batch( - self.n_bootstrap, n_units, self.bootstrap_weights, rng + # Build unit-level ResolvedSurveyDesign for survey-aware bootstrap + unit_resolved = None + if resolved_survey is not None: + from diff_diff.survey import ResolvedSurveyDesign + + row_idx = precomp["unit_first_panel_row"] + unit_weights = precomp.get("unit_survey_weights") + if unit_weights is None: + unit_weights = np.ones(n_units) + unit_strata = ( + resolved_survey.strata[row_idx] if resolved_survey.strata is not None else None + ) + unit_psu = resolved_survey.psu[row_idx] if resolved_survey.psu is not None else None + unit_fpc = resolved_survey.fpc[row_idx] if resolved_survey.fpc is not None else None + n_strata_u = len(np.unique(unit_strata)) if unit_strata is not None else 0 + n_psu_u = len(np.unique(unit_psu)) if unit_psu is not None else 0 + unit_resolved = ResolvedSurveyDesign( + weights=unit_weights, + weight_type=resolved_survey.weight_type, + strata=unit_strata, + psu=unit_psu, + fpc=unit_fpc, + n_strata=n_strata_u, + n_psu=n_psu_u, + lonely_psu=resolved_survey.lonely_psu, + ) + + # Generate bootstrap weights — PSU-level when survey design is present + _use_survey_bootstrap = unit_resolved is not None and ( + unit_resolved.strata is not None + or unit_resolved.psu is not None + or unit_resolved.fpc is not None ) + if _use_survey_bootstrap: + from diff_diff.bootstrap_utils import ( + generate_survey_multiplier_weights_batch, + ) + + psu_weights, psu_ids = generate_survey_multiplier_weights_batch( + self.n_bootstrap, unit_resolved, self.bootstrap_weights, rng + ) + # Build unit -> PSU column map + if unit_resolved.psu is not None: + psu_id_to_col = {int(p): c for c, p in enumerate(psu_ids)} + unit_to_psu_col = np.array( + [psu_id_to_col[int(unit_resolved.psu[i])] for i in range(n_units)] + ) + else: + unit_to_psu_col = np.arange(n_units) + all_weights = psu_weights[:, unit_to_psu_col] + else: + all_weights = generate_bootstrap_weights_batch( + self.n_bootstrap, n_units, self.bootstrap_weights, rng + ) + boot_att_glob = np.zeros(self.n_bootstrap) boot_acrt_glob = np.zeros(self.n_bootstrap) boot_att_d = np.zeros((self.n_bootstrap, n_grid)) @@ -1292,22 +1362,34 @@ def _run_bootstrap( # Event study bootstrap — compute weights per event-time bin es_keys = sorted(event_study_effects.keys()) if event_study_effects else [] boot_es = {e: np.zeros(self.n_bootstrap) for e in es_keys} - # Per-(g,t) weight within event-time bin + # Per-(g,t) weight within event-time bin — use survey-weighted cohort + # masses when available, matching _aggregate_event_study. + unit_sw = precomp.get("unit_survey_weights") + unit_cohorts = precomp["unit_cohorts"] es_cell_weights: Dict[Tuple, float] = {} if event_study_effects is not None: - # Build event-time bin weights from n_treated from collections import defaultdict es_bin_total: Dict[int, float] = defaultdict(float) for gt, r in gt_results.items(): g_val, t_val = gt e = t_val - g_val - es_bin_total[e] += float(r["n_treated"]) + if unit_sw is not None: + g_mask = unit_cohorts == g_val + cell_mass = float(np.sum(unit_sw[g_mask])) + else: + cell_mass = float(r["n_treated"]) + es_bin_total[e] += cell_mass for gt, r in gt_results.items(): g_val, t_val = gt e = t_val - g_val + if unit_sw is not None: + g_mask = unit_cohorts == g_val + cell_mass = float(np.sum(unit_sw[g_mask])) + else: + cell_mass = float(r["n_treated"]) if es_bin_total[e] > 0: - es_cell_weights[gt] = float(r["n_treated"]) / es_bin_total[e] + es_cell_weights[gt] = cell_mass / es_bin_total[e] # Helper to bootstrap a single (g,t) cell def _bootstrap_gt_cell(gt, info): @@ -1316,6 +1398,10 @@ def _bootstrap_gt_cell(gt, info): control_idx = info["control_indices"] n_t = info["n_treated"] n_c = info["n_control"] + # Use survey-weighted masses when available (matching analytical SE) + if "w_treated" in info: + n_t = info["w_treated"] + n_c = info["w_control"] bread = info["bread"] ee_treated = info["ee_treated"] ee_control = info["ee_control"] @@ -1327,6 +1413,7 @@ def _bootstrap_gt_cell(gt, info): delta_y_treated = info["delta_y_treated"] mu_0 = info["mu_0"] att_glob_gt = info["att_glob"] + sw_treated = info.get("w_treated_arr") w_treated = all_weights[:, treated_idx] w_control = all_weights[:, control_idx] @@ -1343,7 +1430,12 @@ def _bootstrap_gt_cell(gt, info): acrt_d_b = beta_b @ dPsi_eval.T mu_0_pert = (w_control @ ee_control) / n_c - mean_dy_treated_pert = (w_treated @ (delta_y_treated - att_glob_gt - mu_0)) / n_t + # ATT_glob perturbation: weight scores by survey weight w_k + # when present, matching the analytical IF path. + att_glob_score = delta_y_treated - att_glob_gt - mu_0 + if sw_treated is not None: + att_glob_score = sw_treated * att_glob_score + mean_dy_treated_pert = (w_treated @ att_glob_score) / n_t att_glob_b = att_glob_gt + mean_dy_treated_pert - mu_0_pert dpsi_mean = np.mean(dPsi_treated, axis=0) diff --git a/diff_diff/continuous_did_bspline.py b/diff_diff/continuous_did_bspline.py index c5a13b2..7042be6 100644 --- a/diff_diff/continuous_did_bspline.py +++ b/diff_diff/continuous_did_bspline.py @@ -51,11 +51,13 @@ def build_bspline_basis(dose, degree=3, num_knots=0): interior_knots = np.array([]) # Full knot vector: clamped at boundaries - knots = np.concatenate([ - np.repeat(d_L, degree + 1), - interior_knots, - np.repeat(d_U, degree + 1), - ]) + knots = np.concatenate( + [ + np.repeat(d_L, degree + 1), + interior_knots, + np.repeat(d_U, degree + 1), + ] + ) return knots, degree diff --git a/diff_diff/datasets.py b/diff_diff/datasets.py index f6170cb..d3fe556 100644 --- a/diff_diff/datasets.py +++ b/diff_diff/datasets.py @@ -135,9 +135,11 @@ def load_card_krueger(force_download: bool = False) -> pd.DataFrame: df = _construct_card_krueger_data() # Standardize column names and add convenience columns - df = df.rename(columns={ - "sheet": "store_id", - }) + df = df.rename( + columns={ + "sheet": "store_id", + } + ) # Ensure proper types if "state" not in df.columns and "nj" in df.columns: @@ -176,15 +178,17 @@ def _construct_card_krueger_data() -> pd.DataFrame: emp_pre = max(0, emp_pre) emp_post = max(0, emp_post) - stores.append({ - "store_id": store_id, - "state": "NJ", - "chain": chain, - "emp_pre": round(emp_pre, 1), - "emp_post": round(emp_post, 1), - "wage_pre": round(np.random.normal(4.61, 0.35), 2), - "wage_post": round(np.random.normal(5.08, 0.12), 2), - }) + stores.append( + { + "store_id": store_id, + "state": "NJ", + "chain": chain, + "emp_pre": round(emp_pre, 1), + "emp_post": round(emp_post, 1), + "wage_pre": round(np.random.normal(4.61, 0.35), 2), + "wage_post": round(np.random.normal(5.08, 0.12), 2), + } + ) store_id += 1 # Pennsylvania stores (control) - summary stats from paper @@ -198,15 +202,17 @@ def _construct_card_krueger_data() -> pd.DataFrame: emp_pre = max(0, emp_pre) emp_post = max(0, emp_post) - stores.append({ - "store_id": store_id, - "state": "PA", - "chain": chain, - "emp_pre": round(emp_pre, 1), - "emp_post": round(emp_post, 1), - "wage_pre": round(np.random.normal(4.63, 0.35), 2), - "wage_post": round(np.random.normal(4.62, 0.35), 2), - }) + stores.append( + { + "store_id": store_id, + "state": "PA", + "chain": chain, + "emp_pre": round(emp_pre, 1), + "emp_post": round(emp_post, 1), + "wage_pre": round(np.random.normal(4.63, 0.35), 2), + "wage_post": round(np.random.normal(4.62, 0.35), 2), + } + ) store_id += 1 df = pd.DataFrame(stores) @@ -310,16 +316,56 @@ def _construct_castle_doctrine_data() -> pd.DataFrame: # States and their Castle Doctrine adoption years # 0 = never adopted during the study period state_adoption = { - "AL": 2006, "AK": 2006, "AZ": 2006, "FL": 2005, "GA": 2006, - "IN": 2006, "KS": 2006, "KY": 2006, "LA": 2006, "MI": 2006, - "MS": 2006, "MO": 2007, "MT": 2009, "NH": 2011, "NC": 2011, - "ND": 2007, "OH": 2008, "OK": 2006, "PA": 2011, "SC": 2006, - "SD": 2006, "TN": 2007, "TX": 2007, "UT": 2010, "WV": 2008, + "AL": 2006, + "AK": 2006, + "AZ": 2006, + "FL": 2005, + "GA": 2006, + "IN": 2006, + "KS": 2006, + "KY": 2006, + "LA": 2006, + "MI": 2006, + "MS": 2006, + "MO": 2007, + "MT": 2009, + "NH": 2011, + "NC": 2011, + "ND": 2007, + "OH": 2008, + "OK": 2006, + "PA": 2011, + "SC": 2006, + "SD": 2006, + "TN": 2007, + "TX": 2007, + "UT": 2010, + "WV": 2008, # Control states (never adopted or adopted after 2010) - "CA": 0, "CO": 0, "CT": 0, "DE": 0, "HI": 0, "ID": 0, - "IL": 0, "IA": 0, "ME": 0, "MD": 0, "MA": 0, "MN": 0, - "NE": 0, "NV": 0, "NJ": 0, "NM": 0, "NY": 0, "OR": 0, - "RI": 0, "VT": 0, "VA": 0, "WA": 0, "WI": 0, "WY": 0, + "CA": 0, + "CO": 0, + "CT": 0, + "DE": 0, + "HI": 0, + "ID": 0, + "IL": 0, + "IA": 0, + "ME": 0, + "MD": 0, + "MA": 0, + "MN": 0, + "NE": 0, + "NV": 0, + "NJ": 0, + "NM": 0, + "NY": 0, + "OR": 0, + "RI": 0, + "VT": 0, + "VA": 0, + "WA": 0, + "WI": 0, + "WY": 0, } # Only include states that adopted before or during 2010, or never adopted @@ -342,17 +388,23 @@ def _construct_castle_doctrine_data() -> pd.DataFrame: else: treatment_effect = 0 - homicide = max(0, base_homicide + time_effect + treatment_effect + np.random.normal(0, 0.5)) - - data.append({ - "state": state, - "year": year, - "first_treat": first_treat, - "homicide_rate": round(homicide, 2), - "population": pop + year * 10000 + np.random.randint(-5000, 5000), - "income": round(base_income * (1 + 0.02 * (year - 2000)) + np.random.normal(0, 1000), 0), - "treated": int(first_treat > 0 and year >= first_treat), - }) + homicide = max( + 0, base_homicide + time_effect + treatment_effect + np.random.normal(0, 0.5) + ) + + data.append( + { + "state": state, + "year": year, + "first_treat": first_treat, + "homicide_rate": round(homicide, 2), + "population": pop + year * 10000 + np.random.randint(-5000, 5000), + "income": round( + base_income * (1 + 0.02 * (year - 2000)) + np.random.normal(0, 1000), 0 + ), + "treated": int(first_treat > 0 and year >= first_treat), + } + ) df = pd.DataFrame(data) df["cohort"] = df["first_treat"] @@ -443,7 +495,9 @@ def load_divorce_laws(force_download: bool = False) -> pd.DataFrame: if "unilateral" in df.columns: df["treated"] = df["unilateral"] elif "first_treat" in df.columns: - df["treated"] = ((df["first_treat"] > 0) & (df["year"] >= df["first_treat"])).astype(int) + df["treated"] = ((df["first_treat"] > 0) & (df["year"] >= df["first_treat"])).astype( + int + ) return df @@ -459,20 +513,53 @@ def _construct_divorce_laws_data() -> pd.DataFrame: # State adoption years for unilateral divorce (from Wolfers 2006) # 0 = never adopted or adopted before 1968 state_adoption = { - "AK": 1935, "AL": 1971, "AZ": 1973, "CA": 1970, "CO": 1972, - "CT": 1973, "DE": 1968, "FL": 1971, "GA": 1973, "HI": 1973, - "IA": 1970, "ID": 1971, "IN": 1973, "KS": 1969, "KY": 1972, - "MA": 1975, "ME": 1973, "MI": 1972, "MN": 1974, "MO": 0, - "MT": 1975, "NC": 0, "ND": 1971, "NE": 1972, "NH": 1971, - "NJ": 0, "NM": 1973, "NV": 1967, "NY": 0, "OH": 0, - "OK": 1975, "OR": 1971, "PA": 0, "RI": 1975, "SD": 1985, - "TN": 0, "TX": 1970, "UT": 1987, "VA": 0, "WA": 1973, - "WI": 1978, "WV": 1984, "WY": 1977, + "AK": 1935, + "AL": 1971, + "AZ": 1973, + "CA": 1970, + "CO": 1972, + "CT": 1973, + "DE": 1968, + "FL": 1971, + "GA": 1973, + "HI": 1973, + "IA": 1970, + "ID": 1971, + "IN": 1973, + "KS": 1969, + "KY": 1972, + "MA": 1975, + "ME": 1973, + "MI": 1972, + "MN": 1974, + "MO": 0, + "MT": 1975, + "NC": 0, + "ND": 1971, + "NE": 1972, + "NH": 1971, + "NJ": 0, + "NM": 1973, + "NV": 1967, + "NY": 0, + "OH": 0, + "OK": 1975, + "OR": 1971, + "PA": 0, + "RI": 1975, + "SD": 1985, + "TN": 0, + "TX": 1970, + "UT": 1987, + "VA": 0, + "WA": 1973, + "WI": 1978, + "WV": 1984, + "WY": 1977, } # Filter to states with adoption dates in our range or never adopted - state_adoption = {k: v for k, v in state_adoption.items() - if v == 0 or (1968 <= v <= 1990)} + state_adoption = {k: v for k, v in state_adoption.items() if v == 0 or (1968 <= v <= 1990)} data = [] for state, first_treat in state_adoption.items(): @@ -507,17 +594,35 @@ def _construct_divorce_laws_data() -> pd.DataFrame: lfp_effect = 0 suicide_effect = 0 - data.append({ - "state": state, - "year": year, - "first_treat": first_treat if first_treat >= 1968 else 0, - "divorce_rate": round(max(0, base_divorce + time_trend + divorce_effect + - np.random.normal(0, 0.3)), 2), - "female_lfp": round(min(1, max(0, base_lfp + 0.01 * (year - 1968) + - lfp_effect + np.random.normal(0, 0.02))), 3), - "suicide_rate": round(max(0, base_suicide + suicide_effect + - np.random.normal(0, 0.5)), 2), - }) + data.append( + { + "state": state, + "year": year, + "first_treat": first_treat if first_treat >= 1968 else 0, + "divorce_rate": round( + max( + 0, base_divorce + time_trend + divorce_effect + np.random.normal(0, 0.3) + ), + 2, + ), + "female_lfp": round( + min( + 1, + max( + 0, + base_lfp + + 0.01 * (year - 1968) + + lfp_effect + + np.random.normal(0, 0.02), + ), + ), + 3, + ), + "suicide_rate": round( + max(0, base_suicide + suicide_effect + np.random.normal(0, 0.5)), 2 + ), + } + ) df = pd.DataFrame(data) df["cohort"] = df["first_treat"] @@ -630,14 +735,16 @@ def _construct_mpdta_data() -> pd.DataFrame: else: te = 0 - data.append({ - "countyreal": county, - "year": year, - "lpop": round(base_lpop + np.random.normal(0, 0.05), 4), - "lemp": round(base_lemp + time_effect + te + np.random.normal(0, 0.02), 4), - "first_treat": first_treat, - "treat": int(first_treat > 0), - }) + data.append( + { + "countyreal": county, + "year": year, + "lpop": round(base_lpop + np.random.normal(0, 0.05), 4), + "lemp": round(base_lemp + time_effect + te + np.random.normal(0, 0.02), 4), + "first_treat": first_treat, + "treat": int(first_treat > 0), + } + ) df = pd.DataFrame(data) df["cohort"] = df["first_treat"] diff --git a/diff_diff/diagnostics.py b/diff_diff/diagnostics.py index e3d79c9..509634c 100644 --- a/diff_diff/diagnostics.py +++ b/diff_diff/diagnostics.py @@ -102,13 +102,15 @@ def summary(self) -> str: ] if self.original_effect is not None: - lines.extend([ - "", - "-" * 65, - "Comparison with Original Estimate".center(65), - "-" * 65, - f"{'Original ATT:':<25} {self.original_effect:>12.4f}", - ]) + lines.extend( + [ + "", + "-" * 65, + "Comparison with Original Estimate".center(65), + "-" * 65, + f"{'Original ATT:':<25} {self.original_effect:>12.4f}", + ] + ) if self.original_se is not None: lines.append(f"{'Original SE:':<25} {self.original_se:>12.4f}") @@ -121,40 +123,36 @@ def summary(self) -> str: if self.leave_one_out_effects is not None: n_units = len(self.leave_one_out_effects) effects = list(self.leave_one_out_effects.values()) - lines.extend([ + lines.extend( + [ + "", + "-" * 65, + "Leave-One-Out Summary".center(65), + "-" * 65, + f"{'Units analyzed:':<25} {n_units:>12}", + f"{'Mean effect:':<25} {np.mean(effects):>12.4f}", + f"{'Std. dev.:':<25} {np.std(effects, ddof=1):>12.4f}", + f"{'Min effect:':<25} {np.min(effects):>12.4f}", + f"{'Max effect:':<25} {np.max(effects):>12.4f}", + ] + ) + + # Interpretation + lines.extend( + [ "", "-" * 65, - "Leave-One-Out Summary".center(65), + "Interpretation".center(65), "-" * 65, - f"{'Units analyzed:':<25} {n_units:>12}", - f"{'Mean effect:':<25} {np.mean(effects):>12.4f}", - f"{'Std. dev.:':<25} {np.std(effects, ddof=1):>12.4f}", - f"{'Min effect:':<25} {np.min(effects):>12.4f}", - f"{'Max effect:':<25} {np.max(effects):>12.4f}", - ]) - - # Interpretation - lines.extend([ - "", - "-" * 65, - "Interpretation".center(65), - "-" * 65, - ]) + ] + ) if self.is_significant: - lines.append( - "WARNING: Significant placebo effect detected (p < 0.05)." - ) - lines.append( - "This suggests potential violations of the parallel trends assumption." - ) + lines.append("WARNING: Significant placebo effect detected (p < 0.05).") + lines.append("This suggests potential violations of the parallel trends assumption.") else: - lines.append( - "No significant placebo effect detected (p >= 0.05)." - ) - lines.append( - "This is consistent with the parallel trends assumption." - ) + lines.append("No significant placebo effect detected (p >= 0.05).") + lines.append("This is consistent with the parallel trends assumption.") lines.append("=" * 65) @@ -205,7 +203,7 @@ def run_placebo_test( n_permutations: int = 1000, alpha: float = 0.05, seed: Optional[int] = None, - **estimator_kwargs + **estimator_kwargs, ) -> PlaceboTestResults: """ Run a placebo test to validate DiD assumptions. @@ -288,9 +286,7 @@ def run_placebo_test( valid_types = ["fake_timing", "fake_group", "permutation", "leave_one_out"] if test_type not in valid_types: - raise ValueError( - f"test_type must be one of {valid_types}, got '{test_type}'" - ) + raise ValueError(f"test_type must be one of {valid_types}, got '{test_type}'") if test_type == "fake_timing": return placebo_timing_test( @@ -301,7 +297,7 @@ def run_placebo_test( fake_treatment_period=fake_treatment_period, post_periods=post_periods, alpha=alpha, - **estimator_kwargs + **estimator_kwargs, ) elif test_type == "fake_group": @@ -317,7 +313,7 @@ def run_placebo_test( fake_treated_units=fake_treatment_group, post_periods=post_periods, alpha=alpha, - **estimator_kwargs + **estimator_kwargs, ) elif test_type == "permutation": @@ -332,7 +328,7 @@ def run_placebo_test( n_permutations=n_permutations, alpha=alpha, seed=seed, - **estimator_kwargs + **estimator_kwargs, ) elif test_type == "leave_one_out": @@ -345,7 +341,7 @@ def run_placebo_test( time=time, unit=unit, alpha=alpha, - **estimator_kwargs + **estimator_kwargs, ) # This should never be reached due to validation above @@ -360,7 +356,7 @@ def placebo_timing_test( fake_treatment_period: Any, post_periods: Optional[List[Any]] = None, alpha: float = 0.05, - **estimator_kwargs + **estimator_kwargs, ) -> PlaceboTestResults: """ Test for pre-treatment effects by moving treatment timing earlier. @@ -417,23 +413,13 @@ def placebo_timing_test( # Fit DiD on pre-treatment data with fake post did = DifferenceInDifferences(**estimator_kwargs) - results = did.fit( - pre_data, - outcome=outcome, - treatment=treatment, - time="_fake_post" - ) + results = did.fit(pre_data, outcome=outcome, treatment=treatment, time="_fake_post") # Also fit on full data for comparison data_with_post = data.copy() data_with_post["_post"] = data_with_post[time].isin(post_periods).astype(int) did_full = DifferenceInDifferences(**estimator_kwargs) - results_full = did_full.fit( - data_with_post, - outcome=outcome, - treatment=treatment, - time="_post" - ) + results_full = did_full.fit(data_with_post, outcome=outcome, treatment=treatment, time="_post") return PlaceboTestResults( test_type="fake_timing", @@ -459,7 +445,7 @@ def placebo_group_test( fake_treated_units: List[Any], post_periods: Optional[List[Any]] = None, alpha: float = 0.05, - **estimator_kwargs + **estimator_kwargs, ) -> PlaceboTestResults: """ Test for differential trends among never-treated units. @@ -509,12 +495,7 @@ def placebo_group_test( # Fit DiD did = DifferenceInDifferences(**estimator_kwargs) - results = did.fit( - fake_data, - outcome=outcome, - treatment="_fake_treated", - time="_post" - ) + results = did.fit(fake_data, outcome=outcome, treatment="_fake_treated", time="_post") return PlaceboTestResults( test_type="fake_group", @@ -539,7 +520,7 @@ def permutation_test( n_permutations: int = 1000, alpha: float = 0.05, seed: Optional[int] = None, - **estimator_kwargs + **estimator_kwargs, ) -> PlaceboTestResults: """ Compute permutation-based p-value for DiD estimate. @@ -583,20 +564,11 @@ def permutation_test( # First, fit original model did = DifferenceInDifferences(**estimator_kwargs) - original_results = did.fit( - data, - outcome=outcome, - treatment=treatment, - time=time - ) + original_results = did.fit(data, outcome=outcome, treatment=treatment, time=time) original_att = original_results.att # Get unit-level treatment assignment - unit_treatment = ( - data.groupby(unit)[treatment] - .first() - .reset_index() - ) + unit_treatment = data.groupby(unit)[treatment].first().reset_index() units = unit_treatment[unit].values n_treated = int(unit_treatment[treatment].sum()) @@ -615,10 +587,7 @@ def permutation_test( try: perm_did = DifferenceInDifferences(**estimator_kwargs) perm_results = perm_did.fit( - perm_data, - outcome=outcome, - treatment="_perm_treatment", - time=time + perm_data, outcome=outcome, treatment="_perm_treatment", time=time ) permuted_effects[i] = perm_results.att except (ValueError, KeyError, np.linalg.LinAlgError): @@ -643,11 +612,12 @@ def permutation_test( failure_rate = n_failed / n_permutations if failure_rate > 0.1: import warnings + warnings.warn( f"{n_failed}/{n_permutations} permutations failed ({failure_rate:.1%}). " f"Results based on {len(valid_effects)} successful permutations.", UserWarning, - stacklevel=2 + stacklevel=2, ) # Compute p-value: proportion of |permuted| >= |original| @@ -688,7 +658,7 @@ def leave_one_out_test( time: str, unit: str, alpha: float = 0.05, - **estimator_kwargs + **estimator_kwargs, ) -> PlaceboTestResults: """ Assess sensitivity by dropping each treated unit in turn. @@ -720,12 +690,7 @@ def leave_one_out_test( """ # Fit original model did = DifferenceInDifferences(**estimator_kwargs) - original_results = did.fit( - data, - outcome=outcome, - treatment=treatment, - time=time - ) + original_results = did.fit(data, outcome=outcome, treatment=treatment, time=time) original_att = original_results.att # Get treated units @@ -744,12 +709,7 @@ def leave_one_out_test( try: loo_did = DifferenceInDifferences(**estimator_kwargs) - loo_results = loo_did.fit( - loo_data, - outcome=outcome, - treatment=treatment, - time=time - ) + loo_results = loo_did.fit(loo_data, outcome=outcome, treatment=treatment, time=time) loo_effects[u] = loo_results.att except (ValueError, KeyError, np.linalg.LinAlgError): # Skip units that cause fitting issues @@ -772,12 +732,13 @@ def leave_one_out_test( # Warn if significant number of LOO iterations failed if n_failed > 0: import warnings + failed_units = [u for u, v in loo_effects.items() if np.isnan(v)] warnings.warn( f"{n_failed}/{n_total} leave-one-out estimates failed for units: {failed_units}. " f"Results based on {len(valid_effects)} successful iterations.", UserWarning, - stacklevel=2 + stacklevel=2, ) # Statistics of LOO distribution @@ -813,7 +774,7 @@ def run_all_placebo_tests( n_permutations: int = 500, alpha: float = 0.05, seed: Optional[int] = None, - **estimator_kwargs + **estimator_kwargs, ) -> Dict[str, Union[PlaceboTestResults, Dict[str, str]]]: """ Run a comprehensive suite of placebo tests. @@ -866,7 +827,7 @@ def run_all_placebo_tests( fake_treatment_period=period, post_periods=post_periods, alpha=alpha, - **estimator_kwargs + **estimator_kwargs, ) results[f"fake_timing_{period}"] = test_result except Exception as e: @@ -875,7 +836,7 @@ def run_all_placebo_tests( "error": str(e), "error_type": type(e).__name__, "test_type": "fake_timing", - "period": period + "period": period, } # Permutation test @@ -889,14 +850,14 @@ def run_all_placebo_tests( n_permutations=n_permutations, alpha=alpha, seed=seed, - **estimator_kwargs + **estimator_kwargs, ) results["permutation"] = perm_result except Exception as e: results["permutation"] = { "error": str(e), "error_type": type(e).__name__, - "test_type": "permutation" + "test_type": "permutation", } # Leave-one-out test @@ -908,14 +869,14 @@ def run_all_placebo_tests( time=time, unit=unit, alpha=alpha, - **estimator_kwargs + **estimator_kwargs, ) results["leave_one_out"] = loo_result except Exception as e: results["leave_one_out"] = { "error": str(e), "error_type": type(e).__name__, - "test_type": "leave_one_out" + "test_type": "leave_one_out", } return results diff --git a/diff_diff/efficient_did.py b/diff_diff/efficient_did.py index 0c5c6d4..c3a341c 100644 --- a/diff_diff/efficient_did.py +++ b/diff_diff/efficient_did.py @@ -348,7 +348,7 @@ def fit( Missing columns, unbalanced panel, non-absorbing treatment, or PT-Post without a never-treated group. NotImplementedError - If ``n_bootstrap > 0`` with ``survey_design``. + If ``covariates`` and ``survey_design`` are both set. """ self._validate_params() @@ -374,13 +374,7 @@ def fit( # Store survey df for safe_inference calls (t-distribution with survey df) self._survey_df = survey_metadata.df_survey if survey_metadata is not None else None - # Guard bootstrap + survey - if self.n_bootstrap > 0 and resolved_survey is not None: - raise NotImplementedError( - "Multiplier bootstrap with survey weights is not yet supported " - "for EfficientDiD. Use analytical inference (n_bootstrap=0) with " - "survey_design for design-based standard errors." - ) + # Bootstrap + survey supported via PSU-level multiplier bootstrap. # Guard covariates + survey (DR path does not yet thread survey weights) if covariates is not None and len(covariates) > 0 and resolved_survey is not None: @@ -969,6 +963,7 @@ def fit( cohort_fractions=cohort_fractions, cluster_indices=unit_cluster_indices, n_clusters=n_clusters, + resolved_survey=self._unit_resolved_survey, ) # Update estimates with bootstrap inference overall_se = bootstrap_results.overall_att_se diff --git a/diff_diff/efficient_did_bootstrap.py b/diff_diff/efficient_did_bootstrap.py index 613f6f0..5375af4 100644 --- a/diff_diff/efficient_did_bootstrap.py +++ b/diff_diff/efficient_did_bootstrap.py @@ -62,6 +62,7 @@ def _run_multiplier_bootstrap( cohort_fractions: Dict[float, float], cluster_indices: Optional[np.ndarray] = None, n_clusters: Optional[int] = None, + resolved_survey: object = None, ) -> EDiDBootstrapResults: """Run multiplier bootstrap on stored EIF values. @@ -95,8 +96,32 @@ def _run_multiplier_bootstrap( gt_pairs = list(group_time_effects.keys()) n_gt = len(gt_pairs) - # Generate bootstrap weights — at cluster level if clustered - if cluster_indices is not None and n_clusters is not None: + # Generate bootstrap weights — PSU-level when survey design is present, + # cluster-level if clustered, unit-level otherwise. + _use_survey_bootstrap = resolved_survey is not None and ( + resolved_survey.strata is not None + or resolved_survey.psu is not None + or resolved_survey.fpc is not None + ) + + if _use_survey_bootstrap: + from diff_diff.bootstrap_utils import ( + generate_survey_multiplier_weights_batch as _gen_survey_weights, + ) + + psu_weights, psu_ids = _gen_survey_weights( + self.n_bootstrap, resolved_survey, self.bootstrap_weights, rng + ) + # Build unit -> PSU column map + if resolved_survey.psu is not None: + psu_id_to_col = {int(p): c for c, p in enumerate(psu_ids)} + unit_to_psu_col = np.array( + [psu_id_to_col[int(resolved_survey.psu[i])] for i in range(n_units)] + ) + else: + unit_to_psu_col = np.arange(n_units) + all_weights = psu_weights[:, unit_to_psu_col] + elif cluster_indices is not None and n_clusters is not None: cluster_weights = _generate_bootstrap_weights_batch( self.n_bootstrap, n_clusters, self.bootstrap_weights, rng ) diff --git a/diff_diff/estimators.py b/diff_diff/estimators.py index 122efe8..88ae967 100644 --- a/diff_diff/estimators.py +++ b/diff_diff/estimators.py @@ -273,13 +273,10 @@ def fit( # FWL theorem: demean ALL regressors alongside outcome. # Regressors collinear with absorbed FE (e.g., treatment after # absorbing unit FE) will zero out and be handled by rank-deficiency. - working_data["_treat_time"] = ( - working_data[treatment].values.astype(float) - * working_data[time].values.astype(float) - ) - vars_to_demean = ( - [outcome, treatment, time, "_treat_time"] + (covariates or []) - ) + working_data["_treat_time"] = working_data[treatment].values.astype( + float + ) * working_data[time].values.astype(float) + vars_to_demean = [outcome, treatment, time, "_treat_time"] + (covariates or []) for ab_var in absorb: working_data, n_fe = demean_by_group( working_data, @@ -342,9 +339,14 @@ def fit( # Inject cluster as effective PSU for survey variance estimation if resolved_survey is not None and effective_cluster_ids is not None: from diff_diff.survey import _inject_cluster_as_psu, compute_survey_metadata + resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids) if resolved_survey.psu is not None and survey_metadata is not None: - raw_w = data[survey_design.weights].values.astype(np.float64) if survey_design.weights else np.ones(len(data), dtype=np.float64) + raw_w = ( + data[survey_design.weights].values.astype(np.float64) + if survey_design.weights + else np.ones(len(data), dtype=np.float64) + ) survey_metadata = compute_survey_metadata(resolved_survey, raw_w) reg = LinearRegression( @@ -1045,12 +1047,8 @@ def fit( # type: ignore[override] t_raw = working_data[time].values working_data["_did_treatment"] = d_raw for period in non_ref_periods: - working_data[f"_did_period_{period}"] = ( - t_raw == period - ).astype(float) - working_data[f"_did_interact_{period}"] = ( - d_raw * (t_raw == period).astype(float) - ) + working_data[f"_did_period_{period}"] = (t_raw == period).astype(float) + working_data[f"_did_interact_{period}"] = d_raw * (t_raw == period).astype(float) vars_to_demean = ( [outcome, "_did_treatment"] + [f"_did_period_{p}" for p in non_ref_periods] @@ -1085,9 +1083,7 @@ def fit( # type: ignore[override] for period in non_ref_periods: if absorb: - period_dummy = working_data[ - f"_did_period_{period}" - ].values.astype(float) + period_dummy = working_data[f"_did_period_{period}"].values.astype(float) else: period_dummy = (t == period).astype(float) X = np.column_stack([X, period_dummy]) @@ -1101,9 +1097,7 @@ def fit( # type: ignore[override] for period in non_ref_periods: if absorb: - interaction = working_data[ - f"_did_interact_{period}" - ].values.astype(float) + interaction = working_data[f"_did_interact_{period}"].values.astype(float) else: interaction = d * (t == period).astype(float) X = np.column_stack([X, interaction]) @@ -1137,9 +1131,14 @@ def fit( # type: ignore[override] # Inject cluster as effective PSU for survey variance estimation if resolved_survey is not None and effective_cluster_ids is not None: from diff_diff.survey import _inject_cluster_as_psu, compute_survey_metadata + resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids) if resolved_survey.psu is not None and survey_metadata is not None: - raw_w = data[survey_design.weights].values.astype(np.float64) if survey_design.weights else np.ones(len(data), dtype=np.float64) + raw_w = ( + data[survey_design.weights].values.astype(np.float64) + if survey_design.weights + else np.ones(len(data), dtype=np.float64) + ) survey_metadata = compute_survey_metadata(resolved_survey, raw_w) # Determine if survey vcov should be used @@ -1167,9 +1166,7 @@ def fit( # type: ignore[override] if np.any(nan_mask): kept_cols = np.where(~nan_mask)[0] if len(kept_cols) > 0: - vcov_reduced = compute_survey_vcov( - X[:, kept_cols], residuals, resolved_survey - ) + vcov_reduced = compute_survey_vcov(X[:, kept_cols], residuals, resolved_survey) vcov = _expand_vcov_with_nan(vcov_reduced, X.shape[1], kept_cols) else: vcov = np.full((X.shape[1], X.shape[1]), np.nan) diff --git a/diff_diff/imputation.py b/diff_diff/imputation.py index 7f0b5fa..b79ba2d 100644 --- a/diff_diff/imputation.py +++ b/diff_diff/imputation.py @@ -260,12 +260,7 @@ def fit( "and PSU (for cluster-robust variance) are supported." ) - # Guard bootstrap + survey - if self.n_bootstrap > 0 and resolved_survey is not None: - raise NotImplementedError( - "Bootstrap inference with survey weights is not yet supported " - "for ImputationDiD. Use analytical inference (n_bootstrap=0)." - ) + # Bootstrap + survey supported via PSU-level multiplier bootstrap. # Ensure numeric types df[time] = pd.to_numeric(df[time]) @@ -593,6 +588,10 @@ def fit( psi_data = None if self.n_bootstrap > 0 and n_valid > 0: try: + # Extract survey weights for untreated obs (same as analytical path) + _sw_0 = survey_weights[omega_0_mask.values] if survey_weights is not None else None + # Extract survey weights for treated obs (event-study/group bootstrap paths) + _sw_1 = survey_weights[omega_1_mask.values] if survey_weights is not None else None psi_data = self._precompute_bootstrap_psi( df=df, outcome=outcome, @@ -614,6 +613,8 @@ def fit( treatment_groups=treatment_groups, tau_hat=tau_hat, balance_e=balance_e, + survey_weights_0=_sw_0, + survey_weights_1=_sw_1, ) except Exception as e: warnings.warn( @@ -631,6 +632,7 @@ def fit( original_event_study=event_study_effects, original_group=group_effects, psi_data=psi_data, + resolved_survey=resolved_survey, ) # Update inference with bootstrap results diff --git a/diff_diff/imputation_bootstrap.py b/diff_diff/imputation_bootstrap.py index 3c14316..f3f1a04 100644 --- a/diff_diff/imputation_bootstrap.py +++ b/diff_diff/imputation_bootstrap.py @@ -17,6 +17,9 @@ from diff_diff.bootstrap_utils import ( generate_bootstrap_weights_batch as _generate_bootstrap_weights_batch, ) +from diff_diff.bootstrap_utils import ( + generate_survey_multiplier_weights_batch as _generate_survey_multiplier_weights_batch, +) from diff_diff.imputation_results import ImputationBootstrapResults __all__ = [ @@ -87,6 +90,7 @@ def _compute_cluster_psi_sums( weights: np.ndarray, cluster_var: str, kept_cov_mask: Optional[np.ndarray] = None, + survey_weights_0: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, np.ndarray]: ... @staticmethod @@ -126,6 +130,8 @@ def _precompute_bootstrap_psi( treatment_groups: List[Any], tau_hat: np.ndarray, balance_e: Optional[int], + survey_weights_0: Optional[np.ndarray] = None, + survey_weights_1: Optional[np.ndarray] = None, ) -> Dict[str, Any]: """ Pre-compute cluster-level influence function sums for each bootstrap target. @@ -155,6 +161,7 @@ def _precompute_bootstrap_psi( delta_hat=delta_hat, cluster_var=cluster_var, kept_cov_mask=kept_cov_mask, + survey_weights_0=survey_weights_0, ) # Overall ATT @@ -189,9 +196,28 @@ def _precompute_bootstrap_psi( h_mask = rel_times == h if balanced_mask is not None: h_mask = h_mask & balanced_mask - weights_h, n_valid_h = _compute_target_weights(tau_hat, h_mask) - if n_valid_h == 0: - continue + + # When survey weights are provided, build weights proportional + # to treated-observation survey weights (matching the analytical + # path in _aggregate_event_study). Otherwise use equal weights. + if survey_weights_1 is not None: + finite_target = np.isfinite(tau_hat) & h_mask + n_valid_h = int(finite_target.sum()) + if n_valid_h == 0: + continue + treated_sw = survey_weights_1 + sw_h = treated_sw[h_mask] + finite_in_h = np.isfinite(tau_hat[h_mask]) + sw_finite = sw_h[finite_in_h] + weights_h = np.zeros(len(tau_hat)) + if sw_finite.sum() > 0: + h_indices = np.where(h_mask)[0] + finite_indices = h_indices[finite_in_h] + weights_h[finite_indices] = sw_finite / sw_finite.sum() + else: + weights_h, n_valid_h = _compute_target_weights(tau_hat, h_mask) + if n_valid_h == 0: + continue psi_h, _ = self._compute_cluster_psi_sums(**common, weights=weights_h) result["event_study"][h] = psi_h @@ -208,9 +234,28 @@ def _precompute_bootstrap_psi( if not np.isfinite(group_effects[g].get("effect", np.nan)): continue g_mask = cohorts == g - weights_g, n_valid_g = _compute_target_weights(tau_hat, g_mask) - if n_valid_g == 0: - continue + + # When survey weights are provided, build weights proportional + # to treated-observation survey weights (matching the analytical + # path in _aggregate_group). Otherwise use equal weights. + if survey_weights_1 is not None: + finite_target = np.isfinite(tau_hat) & g_mask + n_valid_g = int(finite_target.sum()) + if n_valid_g == 0: + continue + treated_sw = survey_weights_1 + sw_g = treated_sw[g_mask] + finite_in_g = np.isfinite(tau_hat[g_mask]) + sw_finite = sw_g[finite_in_g] + weights_g = np.zeros(len(tau_hat)) + if sw_finite.sum() > 0: + g_indices = np.where(g_mask)[0] + finite_indices = g_indices[finite_in_g] + weights_g[finite_indices] = sw_finite / sw_finite.sum() + else: + weights_g, n_valid_g = _compute_target_weights(tau_hat, g_mask) + if n_valid_g == 0: + continue psi_g, _ = self._compute_cluster_psi_sums(**common, weights=weights_g) result["group"][g] = psi_g @@ -223,6 +268,7 @@ def _run_bootstrap( original_event_study: Optional[Dict[int, Dict[str, Any]]], original_group: Optional[Dict[Any, Dict[str, Any]]], psi_data: Dict[str, Any], + resolved_survey: Optional[Any] = None, ) -> ImputationBootstrapResults: """ Run multiplier bootstrap on pre-computed influence function sums. @@ -231,6 +277,11 @@ def _run_bootstrap( (rademacher/mammen/webb; configurable via ``bootstrap_weights``) and psi_i are cluster-level influence function sums from Theorem 3. SE = std(T_b, ddof=1). + + When ``resolved_survey`` carries PSU/strata/FPC structure, weights are + generated via ``generate_survey_multiplier_weights_batch`` so the + bootstrap variance respects the survey design (stratification and FPC + scaling). """ if self.n_bootstrap < 50: warnings.warn( @@ -245,11 +296,30 @@ def _run_bootstrap( overall_psi, cluster_ids = psi_data["overall"] n_clusters = len(cluster_ids) - # Generate ALL weights upfront: shape (n_bootstrap, n_clusters) - all_weights = _generate_bootstrap_weights_batch( - self.n_bootstrap, n_clusters, self.bootstrap_weights, rng + # Determine whether to use survey-aware bootstrap weights + _use_survey_bootstrap = resolved_survey is not None and ( + resolved_survey.strata is not None + or resolved_survey.psu is not None + or resolved_survey.fpc is not None ) + # Generate ALL weights upfront: shape (n_bootstrap, n_clusters) + if _use_survey_bootstrap: + psu_weights, psu_ids = _generate_survey_multiplier_weights_batch( + self.n_bootstrap, resolved_survey, self.bootstrap_weights, rng + ) + # Reindex PSU weights to match cluster_ids ordering. + # cluster_ids are unique PSU values from _compute_cluster_psi_sums; + # psu_ids are unique PSU values from the survey weight generator. + # Build a map from psu_id -> column index in psu_weights. + psu_id_to_col = {int(p): c for c, p in enumerate(psu_ids)} + cluster_to_psu_col = np.array([psu_id_to_col[int(cid)] for cid in cluster_ids]) + all_weights = psu_weights[:, cluster_to_psu_col] + else: + all_weights = _generate_bootstrap_weights_batch( + self.n_bootstrap, n_clusters, self.bootstrap_weights, rng + ) + # Overall ATT bootstrap draws boot_overall = np.dot(all_weights, overall_psi) # (n_bootstrap,) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index 7ff1eb8..43d47b5 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -38,7 +38,7 @@ def make_treatment_indicator( treated_values: Optional[Union[Any, List[Any]]] = None, threshold: Optional[float] = None, above_threshold: bool = True, - new_column: str = "treated" + new_column: str = "treated", ) -> pd.DataFrame: """ Create a binary treatment indicator column from various input types. @@ -119,7 +119,7 @@ def make_post_indicator( time_column: str, post_periods: Optional[Union[Any, List[Any]]] = None, treatment_start: Optional[Any] = None, - new_column: str = "post" + new_column: str = "post", ) -> pd.DataFrame: """ Create a binary post-treatment indicator column. @@ -196,7 +196,7 @@ def wide_to_long( id_column: str, time_name: str = "period", value_name: str = "value", - time_values: Optional[List[Any]] = None + time_values: Optional[List[Any]] = None, ) -> pd.DataFrame: """ Convert wide-format panel data to long format for DiD analysis. @@ -274,14 +274,14 @@ def wide_to_long( data, id_vars=[id_column] + other_cols, value_vars=value_columns, - var_name='_temp_var', - value_name=value_name + var_name="_temp_var", + value_name=value_name, ) # Map column names to time values col_to_time = dict(zip(value_columns, time_values)) - long_df[time_name] = long_df['_temp_var'].map(col_to_time) - long_df = long_df.drop('_temp_var', axis=1) + long_df[time_name] = long_df["_temp_var"].map(col_to_time) + long_df = long_df.drop("_temp_var", axis=1) # Reorder columns and sort cols = [id_column, time_name, value_name] + other_cols @@ -293,7 +293,7 @@ def balance_panel( unit_column: str, time_column: str, method: str = "inner", - fill_value: Optional[float] = None + fill_value: Optional[float] = None, ) -> pd.DataFrame: """ Balance a panel dataset to ensure all units have all time periods. @@ -360,8 +360,7 @@ def balance_panel( elif method in ["outer", "fill"]: # Create full grid of unit-period combinations full_index = pd.MultiIndex.from_product( - [all_units, all_periods], - names=[unit_column, time_column] + [all_units, all_periods], names=[unit_column, time_column] ) full_df = pd.DataFrame(index=full_index).reset_index() @@ -396,7 +395,7 @@ def validate_did_data( treatment: str, time: str, unit: Optional[str] = None, - raise_on_error: bool = True + raise_on_error: bool = True, ) -> Dict[str, Any]: """ Validate that data is properly formatted for DiD analysis. @@ -459,8 +458,7 @@ def validate_did_data( # Check outcome is numeric if not pd.api.types.is_numeric_dtype(data[outcome]): errors.append( - f"Outcome column '{outcome}' must be numeric. " - f"Got type: {data[outcome].dtype}" + f"Outcome column '{outcome}' must be numeric. " f"Got type: {data[outcome].dtype}" ) # Check treatment is binary @@ -553,20 +551,11 @@ def validate_did_data( if raise_on_error and not valid: raise ValueError("Data validation failed:\n" + "\n".join(errors)) - return { - "valid": valid, - "errors": errors, - "warnings": warnings, - "summary": summary - } + return {"valid": valid, "errors": errors, "warnings": warnings, "summary": summary} def summarize_did_data( - data: pd.DataFrame, - outcome: str, - treatment: str, - time: str, - unit: Optional[str] = None + data: pd.DataFrame, outcome: str, treatment: str, time: str, unit: Optional[str] = None ) -> pd.DataFrame: """ Generate summary statistics by treatment group and time period. @@ -600,13 +589,11 @@ def summarize_did_data( >>> print(summary) """ # Group by treatment and time - summary = data.groupby([treatment, time])[outcome].agg([ - ("n", "count"), - ("mean", "mean"), - ("std", "std"), - ("min", "min"), - ("max", "max") - ]).round(4) + summary = ( + data.groupby([treatment, time])[outcome] + .agg([("n", "count"), ("mean", "mean"), ("std", "std"), ("min", "min"), ("max", "max")]) + .round(4) + ) # Calculate time values for labeling time_vals = sorted(data[time].unique()) @@ -616,8 +603,8 @@ def summarize_did_data( pre_val, post_val = time_vals[0], time_vals[1] def format_label(x: tuple) -> str: - treatment_label = 'Treated' if x[0] == 1 else 'Control' - time_label = 'Post' if x[1] == post_val else 'Pre' + treatment_label = "Treated" if x[0] == 1 else "Control" + time_label = "Post" if x[1] == post_val else "Pre" return f"{treatment_label} - {time_label}" summary.index = summary.index.map(format_label) @@ -635,14 +622,8 @@ def format_label(x: tuple) -> str: # Add to summary as a new row did_row = pd.DataFrame( - { - "n": ["-"], - "mean": [did_estimate], - "std": ["-"], - "min": ["-"], - "max": ["-"] - }, - index=["DiD Estimate"] + {"n": ["-"], "mean": [did_estimate], "std": ["-"], "min": ["-"], "max": ["-"]}, + index=["DiD Estimate"], ) summary = pd.concat([summary, did_row]) else: @@ -654,10 +635,7 @@ def format_label(x: tuple) -> str: def create_event_time( - data: pd.DataFrame, - time_column: str, - treatment_time_column: str, - new_column: str = "event_time" + data: pd.DataFrame, time_column: str, treatment_time_column: str, new_column: str = "event_time" ) -> pd.DataFrame: """ Create an event-time column relative to treatment timing. @@ -724,7 +702,7 @@ def aggregate_to_cohorts( time_column: str, treatment_column: str, outcome: str, - covariates: Optional[List[str]] = None + covariates: Optional[List[str]] = None, ) -> pd.DataFrame: """ Aggregate unit-level data to treatment cohort means. @@ -772,10 +750,7 @@ def aggregate_to_cohorts( cohort_data = data.groupby([treatment_column, time_column]).agg(agg_cols).reset_index() # Rename columns - cohort_data = cohort_data.rename(columns={ - unit_column: "n_units", - outcome: f"mean_{outcome}" - }) + cohort_data = cohort_data.rename(columns={unit_column: "n_units", outcome: f"mean_{outcome}"}) return cohort_data @@ -960,8 +935,7 @@ def rank_control_units( # ------------------------------------------------------------------------- if suggest_treatment_candidates and len(treated_set) == 0: return _suggest_treatment_candidates( - data, unit_column, time_column, outcome_column, - pre_periods, n_treatment_candidates + data, unit_column, time_column, outcome_column, pre_periods, n_treatment_candidates ) if len(treated_set) == 0: @@ -1005,9 +979,7 @@ def rank_control_units( # Compute outcome trend scores # ------------------------------------------------------------------------- # Synthetic weights (higher = better match) - synthetic_weights = compute_synthetic_weights( - Y_control, Y_treated_mean, lambda_reg=lambda_reg - ) + synthetic_weights = compute_synthetic_weights(Y_control, Y_treated_mean, lambda_reg=lambda_reg) # RMSE for each control vs treated mean (use nanmean to handle missing data) rmse_scores = [] @@ -1095,31 +1067,33 @@ def rank_control_units( # ------------------------------------------------------------------------- require_set = set(require_units) if require_units else set() - result = pd.DataFrame({ - 'unit': control_candidates, - 'quality_score': quality_scores, - 'outcome_trend_score': outcome_trend_scores, - 'covariate_score': covariate_scores, - 'synthetic_weight': synthetic_weights, - 'pre_trend_rmse': rmse_scores, - 'is_required': [u in require_set for u in control_candidates] - }) + result = pd.DataFrame( + { + "unit": control_candidates, + "quality_score": quality_scores, + "outcome_trend_score": outcome_trend_scores, + "covariate_score": covariate_scores, + "synthetic_weight": synthetic_weights, + "pre_trend_rmse": rmse_scores, + "is_required": [u in require_set for u in control_candidates], + } + ) # Sort by quality score (descending) - result = result.sort_values('quality_score', ascending=False) + result = result.sort_values("quality_score", ascending=False) # Apply n_top limit if specified if n_top is not None and n_top < len(result): # Always include required units - required_df = result[result['is_required']] - non_required_df = result[~result['is_required']] + required_df = result[result["is_required"]] + non_required_df = result[~result["is_required"]] # Take top from non-required to fill remaining slots remaining_slots = max(0, n_top - len(required_df)) top_non_required = non_required_df.head(remaining_slots) result = pd.concat([required_df, top_non_required]) - result = result.sort_values('quality_score', ascending=False) + result = result.sort_values("quality_score", ascending=False) return result.reset_index(drop=True) @@ -1130,7 +1104,7 @@ def _suggest_treatment_candidates( time_column: str, outcome_column: str, pre_periods: List[Any], - n_candidates: int + n_candidates: int, ) -> pd.DataFrame: """ Identify units that would make good treatment candidates. @@ -1188,55 +1162,64 @@ def _suggest_treatment_candidates( # Count similar potential controls other_units = [u for u in all_units if u != unit] - other_means = pre_data[ - pre_data[unit_column].isin(other_units) - ].groupby(unit_column)[outcome_column].mean() + other_means = ( + pre_data[pre_data[unit_column].isin(other_units)] + .groupby(unit_column)[outcome_column] + .mean() + ) if len(other_means) > 0: sd = other_means.std() if sd > 0: - n_similar = int(np.sum( - np.abs(other_means - avg_outcome) < _SIMILARITY_THRESHOLD_SD * sd - )) + n_similar = int( + np.sum(np.abs(other_means - avg_outcome) < _SIMILARITY_THRESHOLD_SD * sd) + ) else: n_similar = len(other_means) else: n_similar = 0 - candidate_info.append({ - 'unit': unit, - 'avg_outcome_level': avg_outcome, - 'outcome_trend': slope, - 'n_similar_controls': n_similar - }) + candidate_info.append( + { + "unit": unit, + "avg_outcome_level": avg_outcome, + "outcome_trend": slope, + "n_similar_controls": n_similar, + } + ) if len(candidate_info) == 0: - return pd.DataFrame(columns=[ - 'unit', 'treatment_candidate_score', 'avg_outcome_level', - 'outcome_trend', 'n_similar_controls' - ]) + return pd.DataFrame( + columns=[ + "unit", + "treatment_candidate_score", + "avg_outcome_level", + "outcome_trend", + "n_similar_controls", + ] + ) result = pd.DataFrame(candidate_info) # Score: prefer units with many similar controls and moderate outcome levels - max_similar = result['n_similar_controls'].max() + max_similar = result["n_similar_controls"].max() if max_similar > 0: - similarity_score = result['n_similar_controls'] / max_similar + similarity_score = result["n_similar_controls"] / max_similar else: similarity_score = pd.Series([0.0] * len(result)) # Penalty for outliers in outcome level - outcome_mean = result['avg_outcome_level'].mean() - outcome_std = result['avg_outcome_level'].std() + outcome_mean = result["avg_outcome_level"].mean() + outcome_std = result["avg_outcome_level"].std() if outcome_std > 0: - outcome_z = np.abs((result['avg_outcome_level'] - outcome_mean) / outcome_std) + outcome_z = np.abs((result["avg_outcome_level"] - outcome_mean) / outcome_std) else: outcome_z = pd.Series([0.0] * len(result)) - result['treatment_candidate_score'] = ( + result["treatment_candidate_score"] = ( similarity_score - _OUTLIER_PENALTY_WEIGHT * outcome_z ).clip(0, 1) # Return top candidates - result = result.nlargest(n_candidates, 'treatment_candidate_score') + result = result.nlargest(n_candidates, "treatment_candidate_score") return result.reset_index(drop=True) diff --git a/diff_diff/prep_dgp.py b/diff_diff/prep_dgp.py index ed00d9b..2aab32c 100644 --- a/diff_diff/prep_dgp.py +++ b/diff_diff/prep_dgp.py @@ -21,7 +21,7 @@ def generate_did_data( unit_fe_sd: float = 2.0, time_trend: float = 0.5, noise_sd: float = 1.0, - seed: Optional[int] = None + seed: Optional[int] = None, ) -> pd.DataFrame: """ Generate synthetic data for DiD analysis with known treatment effect. @@ -110,14 +110,16 @@ def generate_did_data( # Add noise y += rng.normal(0, noise_sd) - records.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "post": int(is_post), - "outcome": y, - "true_effect": effect - }) + records.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "post": int(is_post), + "outcome": y, + "true_effect": effect, + } + ) return pd.DataFrame(records) @@ -211,9 +213,7 @@ def generate_staggered_data( # Validate cohort periods for cp in cohort_periods: if cp < 1 or cp >= n_periods: - raise ValueError( - f"Cohort period {cp} must be between 1 and {n_periods - 1}" - ) + raise ValueError(f"Cohort period {cp} must be between 1 and {n_periods - 1}") # Determine number of never-treated and treated units n_never = int(n_units * never_treated_frac) @@ -254,15 +254,17 @@ def generate_staggered_data( # Add noise y += rng.normal(0, noise_sd) - records.append({ - "unit": unit, - "period": period, - "outcome": y, - "first_treat": unit_first_treat, - "treated": int(is_treated), - "treat": int(is_ever_treated), - "true_effect": effect, - }) + records.append( + { + "unit": unit, + "period": period, + "outcome": y, + "first_treat": unit_first_treat, + "treated": int(is_treated), + "treat": int(is_ever_treated), + "true_effect": effect, + } + ) return pd.DataFrame(records) @@ -395,14 +397,16 @@ def generate_factor_data( # Add noise y += rng.normal(0, noise_sd) - records.append({ - "unit": i, - "period": t, - "outcome": y, - "treated": int(is_ever_treated and post), - "treat": int(is_ever_treated), - "true_effect": effect, - }) + records.append( + { + "unit": i, + "period": t, + "outcome": y, + "treated": int(is_ever_treated and post), + "treat": int(is_ever_treated), + "true_effect": effect, + } + ) return pd.DataFrame(records) @@ -500,9 +504,9 @@ def generate_ddd_data( y = 50 + group_effect * g + partition_effect * p + time_effect * t # Second-order interactions (non-treatment) - y += 1.5 * g * p # group-partition interaction - y += 1.0 * g * t # group-time interaction (diff trends) - y += 0.5 * p * t # partition-time interaction + y += 1.5 * g * p # group-partition interaction + y += 1.0 * g * t # group-time interaction (diff trends) + y += 0.5 * p * t # partition-time interaction # Treatment effect: ONLY for G=1, P=1, T=1 effect = 0.0 @@ -653,14 +657,16 @@ def generate_panel_data( # Add noise y += rng.normal(0, noise_sd) - records.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "post": int(post), - "outcome": y, - "true_effect": effect, - }) + records.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "post": int(post), + "outcome": y, + "true_effect": effect, + } + ) return pd.DataFrame(records) @@ -764,15 +770,17 @@ def generate_event_study_data( # Add noise y += rng.normal(0, noise_sd) - records.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "post": int(post), - "outcome": y, - "event_time": event_time, - "true_effect": effect, - }) + records.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "post": int(post), + "outcome": y, + "event_time": event_time, + "true_effect": effect, + } + ) return pd.DataFrame(records) @@ -850,7 +858,7 @@ def generate_continuous_did_data( idx = 0 for i, g in enumerate(cohort_periods): n_this = n_per_cohort if i < len(cohort_periods) - 1 else n_treated_total - idx - cohort_assignments[n_never + idx: n_never + idx + n_this] = g + cohort_assignments[n_never + idx : n_never + idx + n_this] = g idx += n_this # Generate doses @@ -898,8 +906,7 @@ def _att_func(d): return att_intercept + att_slope * np.log1p(d) else: raise ValueError( - f"att_function must be 'linear', 'quadratic', or 'log', " - f"got '{att_function}'" + f"att_function must be 'linear', 'quadratic', or 'log', " f"got '{att_function}'" ) # Unit fixed effects @@ -920,13 +927,15 @@ def _att_func(d): else: att_d = 0.0 - records.append({ - "unit": i, - "period": int(t), - "outcome": y0 + att_d, - "first_treat": int(g_i) if g_i > 0 else 0, - "dose": d_i, - "true_att": att_d, - }) + records.append( + { + "unit": i, + "period": int(t), + "outcome": y0 + att_d, + "first_treat": int(g_i) if g_i > 0 else 0, + "dose": d_i, + "true_att": att_d, + } + ) return pd.DataFrame(records) diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 251ac9b..8d89a63 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -328,6 +328,68 @@ def __init__( self.is_fitted_ = False self.results_: Optional[CallawaySantAnnaResults] = None + @staticmethod + def _collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units): + """Create unit-level ResolvedSurveyDesign for panel IF-based variance. + + Survey design columns are constant within units (validated upstream). + This extracts one row per unit, aligned to ``all_units`` ordering. + """ + from diff_diff.survey import ResolvedSurveyDesign + + n_units = len(all_units) + # Use groupby().first() to get one value per unit, then reindex + unit_groups = df.groupby(unit_col) + + weights_unit = ( + pd.Series(resolved_survey.weights, index=df.index) + .groupby(df[unit_col]) + .first() + .reindex(all_units) + .values + ) + + strata_unit = None + if resolved_survey.strata is not None: + strata_unit = ( + pd.Series(resolved_survey.strata, index=df.index) + .groupby(df[unit_col]) + .first() + .reindex(all_units) + .values + ) + + psu_unit = None + if resolved_survey.psu is not None: + psu_unit = ( + pd.Series(resolved_survey.psu, index=df.index) + .groupby(df[unit_col]) + .first() + .reindex(all_units) + .values + ) + + fpc_unit = None + if resolved_survey.fpc is not None: + fpc_unit = ( + pd.Series(resolved_survey.fpc, index=df.index) + .groupby(df[unit_col]) + .first() + .reindex(all_units) + .values + ) + + return ResolvedSurveyDesign( + weights=weights_unit.astype(np.float64), + weight_type=resolved_survey.weight_type, + strata=strata_unit, + psu=psu_unit, + fpc=fpc_unit, + n_strata=resolved_survey.n_strata, + n_psu=resolved_survey.n_psu, + lonely_psu=resolved_survey.lonely_psu, + ) + def _precompute_structures( self, df: pd.DataFrame, @@ -399,6 +461,12 @@ def _precompute_structures( else: survey_weights_arr = None + resolved_survey_unit = ( + self._collapse_survey_to_unit_level(resolved_survey, df, unit, all_units) + if resolved_survey is not None + else None + ) + return { "all_units": all_units, "unit_to_idx": unit_to_idx, @@ -411,7 +479,11 @@ def _precompute_structures( "time_periods": time_periods, "is_balanced": is_balanced, "survey_weights": survey_weights_arr, - "df_survey": (len(all_units) - 1) if resolved_survey is not None else None, + "resolved_survey": resolved_survey, + "resolved_survey_unit": resolved_survey_unit, + "df_survey": ( + resolved_survey_unit.df_survey if resolved_survey_unit is not None else None + ), } def _compute_att_gt_fast( @@ -1155,10 +1227,10 @@ def fit( For event study, balance the panel at relative time e. Ensures all groups contribute to each relative period. survey_design : SurveyDesign, optional - Survey design specification. Only weights-only designs are supported - (strata/PSU/FPC raise NotImplementedError). Supports pweight only. - Covariates + IPW/DR + survey also raises NotImplementedError. - Use analytical inference (n_bootstrap=0) with survey_design. + Survey design specification. Supports pweight with strata/PSU/FPC. + Aggregated SEs (overall, event study, group) use design-based + variance via compute_survey_if_variance(). + Covariates + IPW/DR + survey raises NotImplementedError. Returns ------- @@ -1197,25 +1269,10 @@ def fit( f"got '{resolved_survey.weight_type}'. The survey variance math " f"assumes probability weights (pweight)." ) - if ( - resolved_survey.strata is not None - or resolved_survey.psu is not None - or resolved_survey.fpc is not None - ): - raise NotImplementedError( - "CallawaySantAnna does not yet support strata/PSU/FPC in " - "SurveyDesign. Per-cell and aggregation SEs use IF-based " - "variance which does not incorporate the full survey design " - "structure. Use SurveyDesign(weights=...) only. Full " - "design-based SEs via compute_survey_vcov() are planned." - ) + # Note: strata/PSU/FPC are now supported — aggregated SEs use + # compute_survey_if_variance() for design-based inference. - # Guard bootstrap + survey - if self.n_bootstrap > 0 and resolved_survey is not None: - raise NotImplementedError( - "Bootstrap inference with survey weights is not yet supported " - "for CallawaySantAnna. Use analytical inference (n_bootstrap=0)." - ) + # Bootstrap + survey is now supported via PSU-level multiplier bootstrap. # Guard covariates + survey + IPW/DR (nuisance IF corrections not yet # implemented to match DRDID panel formula) @@ -1288,26 +1345,6 @@ def fit( # rejected above). We do NOT inject cluster-as-PSU here because CS # per-cell SEs use IF-based variance, not TSL. The user's cluster= # parameter is handled by the existing non-survey clustering path. - if resolved_survey is not None and survey_metadata is not None: - # Just recompute metadata with the resolved design (no PSU injection) - if survey_design.weights: - from diff_diff.survey import compute_survey_metadata - - raw_w = data[survey_design.weights].values.astype(np.float64) - survey_metadata = compute_survey_metadata(resolved_survey, raw_w) - # Override df_survey with unit-level df (CS is unit-level, not obs-level) - n_units_for_df = len(data[unit].unique()) - survey_metadata = type(survey_metadata)( - weight_type=survey_metadata.weight_type, - effective_n=survey_metadata.effective_n, - design_effect=survey_metadata.design_effect, - sum_weights=survey_metadata.sum_weights, - n_strata=survey_metadata.n_strata, - n_psu=n_units_for_df, # unit-level for CS - weight_range=survey_metadata.weight_range, - df_survey=n_units_for_df - 1, - ) - # Pre-compute data structures for efficient ATT(g,t) computation precomputed = self._precompute_structures( df, @@ -1321,15 +1358,20 @@ def fit( resolved_survey=resolved_survey, ) - # Survey df for safe_inference calls. - # CS operates at unit level, so use n_units - 1 (not n_obs - 1 from - # the long panel). For weights-only designs (no strata/PSU), this is - # the correct unit-level degrees of freedom. - if resolved_survey is not None: - n_all_units = len(precomputed["all_units"]) - df_survey = n_all_units - 1 - else: - df_survey = None + # Recompute survey metadata from the unit-level resolved survey so + # that n_psu and df_survey reflect the actual survey design (explicit + # PSU/strata) rather than hard-coding n_units. + if resolved_survey is not None and survey_metadata is not None: + resolved_survey_unit = precomputed.get("resolved_survey_unit") + if resolved_survey_unit is not None: + from diff_diff.survey import compute_survey_metadata + + unit_w = resolved_survey_unit.weights + survey_metadata = compute_survey_metadata(resolved_survey_unit, unit_w) + + # Survey df for safe_inference calls — use the unit-level resolved + # survey df computed in _precompute_structures for consistency. + df_survey = precomputed.get("df_survey") # Compute ATT(g,t) for each group-time combination min_period = min(time_periods) diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 9365ff4..246b9e7 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -427,6 +427,11 @@ def _compute_aggregated_se_with_wif( This matches R's `did` package approach for aggregation, which accounts for uncertainty in estimating group-size weights. + When a full survey design (strata/PSU/FPC) is available in + ``precomputed['resolved_survey']``, the design-based variance + :func:`compute_survey_if_variance` is used instead of the simple + ``sum(psi^2)`` formula. + Formula (matching R's did::aggte): agg_inf_i = Σ_k w_k × inf_i_k + wif_i × ATT_k se = sqrt(mean(agg_inf^2) / n) @@ -463,6 +468,23 @@ def _compute_aggregated_se_with_wif( if not np.all(np.isfinite(psi_total)): return np.nan + # Use design-based variance when full survey design is available + # Use unit-level resolved survey (panel IF is indexed by unit, not obs) + resolved_survey = ( + precomputed.get("resolved_survey_unit") if precomputed is not None else None + ) + if resolved_survey is not None and ( + resolved_survey.strata is not None + or resolved_survey.psu is not None + or resolved_survey.fpc is not None + ): + from diff_diff.survey import compute_survey_if_variance + + variance = compute_survey_if_variance(psi_total, resolved_survey) + if np.isnan(variance): + return np.nan + return np.sqrt(max(variance, 0.0)) + variance = np.sum(psi_total**2) return np.sqrt(variance) diff --git a/diff_diff/staggered_bootstrap.py b/diff_diff/staggered_bootstrap.py index 4a22047..83266be 100644 --- a/diff_diff/staggered_bootstrap.py +++ b/diff_diff/staggered_bootstrap.py @@ -27,6 +27,9 @@ from diff_diff.bootstrap_utils import ( generate_bootstrap_weights_batch as _generate_bootstrap_weights_batch, ) +from diff_diff.bootstrap_utils import ( + generate_survey_multiplier_weights_batch as _generate_survey_multiplier_weights_batch, +) if TYPE_CHECKING: import pandas as pd @@ -79,6 +82,7 @@ class CSBootstrapResults: bootstrap_distribution : Optional[np.ndarray] Full bootstrap distribution of overall ATT (if requested). """ + n_bootstrap: int weight_type: str alpha: float @@ -191,18 +195,20 @@ def _run_multiplier_bootstrap( # Without this, pg is overestimated in unbalanced panels where some # units don't appear in any influence function. if precomputed is not None: - all_units = precomputed['all_units'] + all_units = precomputed["all_units"] n_units = len(all_units) - unit_to_idx = precomputed['unit_to_idx'] + unit_to_idx = precomputed["unit_to_idx"] else: # Fallback: collect units from influence functions all_units_set = set() for (g, t), info in influence_func_info.items(): - all_units_set.update(info['treated_units']) - all_units_set.update(info['control_units']) + all_units_set.update(info["treated_units"]) + all_units_set.update(info["control_units"]) all_units = sorted(all_units_set) # Use global N from dataframe when available - n_units = df[unit].nunique() if (df is not None and unit is not None) else len(all_units) + n_units = ( + df[unit].nunique() if (df is not None and unit is not None) else len(all_units) + ) unit_to_idx = {u: i for i, u in enumerate(all_units)} # Get list of (g,t) pairs @@ -211,21 +217,37 @@ def _run_multiplier_bootstrap( # Identify post-treatment (g,t) pairs for overall ATT # Pre-treatment effects are for parallel trends assessment, not aggregated - post_treatment_mask = np.array([ - t >= g - self.anticipation for (g, t) in gt_pairs - ]) + post_treatment_mask = np.array([t >= g - self.anticipation for (g, t) in gt_pairs]) post_treatment_indices = np.where(post_treatment_mask)[0] # Compute aggregation weights for overall ATT (post-treatment only) - all_n_treated = np.array([ - group_time_effects[gt]['n_treated'] for gt in gt_pairs - ], dtype=float) + # When survey weights are present, use fixed cohort survey masses + # (from precomputed survey_weights × unit_cohorts), matching the + # analytical _aggregate_simple() path in staggered_aggregation.py. + # Do NOT use per-cell survey_weight_sum (which varies by cell on + # unbalanced panels). + survey_w = precomputed.get("survey_weights") if precomputed is not None else None + if survey_w is not None: + unit_cohorts = precomputed["unit_cohorts"] + # Precompute fixed cohort masses (same formula as _aggregate_simple) + _cohort_mass_cache: dict = {} + for gt in gt_pairs: + g = gt[0] + if g not in _cohort_mass_cache: + _cohort_mass_cache[g] = float(np.sum(survey_w[unit_cohorts == g])) + all_n_treated = np.array( + [_cohort_mass_cache[gt[0]] for gt in gt_pairs], dtype=float + ) + else: + all_n_treated = np.array( + [group_time_effects[gt]["n_treated"] for gt in gt_pairs], dtype=float + ) post_n_treated = all_n_treated[post_treatment_mask] # Filter out NaN ATT(g,t) cells from overall aggregation (matches analytical path) - post_effects_raw = np.array([ - group_time_effects[gt_pairs[i]]['effect'] for i in post_treatment_indices - ]) + post_effects_raw = np.array( + [group_time_effects[gt_pairs[i]]["effect"] for i in post_treatment_indices] + ) finite_post = np.isfinite(post_effects_raw) if not np.all(finite_post): post_treatment_indices = post_treatment_indices[finite_post] @@ -239,7 +261,7 @@ def _run_multiplier_bootstrap( "No post-treatment effects for bootstrap aggregation. " "Overall ATT statistics will be NaN, but per-effect SEs will be computed.", UserWarning, - stacklevel=2 + stacklevel=2, ) skip_overall_aggregation = True overall_weights_post = np.array([]) @@ -247,7 +269,7 @@ def _run_multiplier_bootstrap( overall_weights_post = post_n_treated / np.sum(post_n_treated) # Original point estimates - original_atts = np.array([group_time_effects[gt]['effect'] for gt in gt_pairs]) + original_atts = np.array([group_time_effects[gt]["effect"] for gt in gt_pairs]) if skip_overall_aggregation: original_overall = np.nan else: @@ -259,10 +281,15 @@ def _run_multiplier_bootstrap( if aggregate in ["event_study", "all"]: event_study_info = self._prepare_event_study_aggregation( - gt_pairs, group_time_effects, balance_e, + gt_pairs, + group_time_effects, + balance_e, influence_func_info=influence_func_info, - df=df, unit=unit, precomputed=precomputed, - global_unit_to_idx=unit_to_idx, n_global_units=n_units, + df=df, + unit=unit, + precomputed=precomputed, + global_unit_to_idx=unit_to_idx, + n_global_units=n_units, ) if aggregate in ["group", "all"]: @@ -278,17 +305,47 @@ def _run_multiplier_bootstrap( for j, gt in enumerate(gt_pairs): info = influence_func_info[gt] - gt_treated_indices.append(info['treated_idx']) - gt_control_indices.append(info['control_idx']) - gt_treated_inf.append(np.asarray(info['treated_inf'])) - gt_control_inf.append(np.asarray(info['control_inf'])) - - # Generate ALL bootstrap weights upfront: shape (n_bootstrap, n_units) - # This is much faster than generating one at a time - all_bootstrap_weights = _generate_bootstrap_weights_batch( - self.n_bootstrap, n_units, self.bootstrap_weight_type, rng + gt_treated_indices.append(info["treated_idx"]) + gt_control_indices.append(info["control_idx"]) + gt_treated_inf.append(np.asarray(info["treated_inf"])) + gt_control_inf.append(np.asarray(info["control_inf"])) + + # Generate bootstrap weights — PSU-level when survey design is present, + # unit-level otherwise. + resolved_survey_unit = ( + precomputed.get("resolved_survey_unit") if precomputed is not None else None + ) + _use_survey_bootstrap = resolved_survey_unit is not None and ( + resolved_survey_unit.strata is not None + or resolved_survey_unit.psu is not None + or resolved_survey_unit.fpc is not None ) + if _use_survey_bootstrap: + # PSU-level multiplier weights + psu_weights, psu_ids = _generate_survey_multiplier_weights_batch( + self.n_bootstrap, resolved_survey_unit, self.bootstrap_weight_type, rng + ) + # Build unit → PSU column map + if resolved_survey_unit.psu is not None: + unit_psu = resolved_survey_unit.psu + psu_id_to_col = {int(p): c for c, p in enumerate(psu_ids)} + unit_to_psu_col = np.array( + [psu_id_to_col[int(unit_psu[i])] for i in range(n_units)] + ) + else: + # Each unit is its own PSU — identity mapping + unit_to_psu_col = np.arange(n_units) + + # Expand PSU weights to unit level for per-(g,t) perturbation + # Shape: (n_bootstrap, n_units) + all_bootstrap_weights = psu_weights[:, unit_to_psu_col] + else: + # Standard unit-level weights (no survey or weights-only) + all_bootstrap_weights = _generate_bootstrap_weights_batch( + self.n_bootstrap, n_units, self.bootstrap_weight_type, rng + ) + # Vectorized bootstrap ATT(g,t) computation # Compute all bootstrap ATTs for all (g,t) pairs using matrix operations bootstrap_atts_gt = np.zeros((self.n_bootstrap, n_gt)) @@ -307,11 +364,8 @@ def _run_multiplier_bootstrap( # Vectorized perturbation: matrix-vector multiply # Shape: (n_bootstrap,) # Suppress RuntimeWarnings for edge cases (small samples, extreme weights) - with np.errstate(divide='ignore', invalid='ignore', over='ignore'): - perturbations = ( - treated_weights @ treated_inf + - control_weights @ control_inf - ) + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + perturbations = treated_weights @ treated_inf + control_weights @ control_inf # Let non-finite values propagate - they will be handled at statistics computation bootstrap_atts_gt[:, j] = original_atts[j] + perturbations @@ -326,11 +380,18 @@ def _run_multiplier_bootstrap( post_groups = np.array([gt_pairs[i][0] for i in post_treatment_indices]) post_effects = original_atts[post_treatment_indices] overall_combined_if, _ = self._compute_combined_influence_function( - post_gt_pairs, overall_weights_post, post_effects, post_groups, - influence_func_info, df, unit, precomputed, - global_unit_to_idx=unit_to_idx, n_global_units=n_units, + post_gt_pairs, + overall_weights_post, + post_effects, + post_groups, + influence_func_info, + df, + unit, + precomputed, + global_unit_to_idx=unit_to_idx, + n_global_units=n_units, ) - with np.errstate(divide='ignore', invalid='ignore', over='ignore'): + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): bootstrap_overall = original_overall + all_bootstrap_weights @ overall_combined_if # Vectorized event study aggregation using combined IFs @@ -343,9 +404,9 @@ def _run_multiplier_bootstrap( for e in rel_periods: agg_info = event_study_info[e] # Use combined IF (standard IF + WIF) for proper bootstrap - with np.errstate(divide='ignore', invalid='ignore', over='ignore'): + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): bootstrap_event_study[e] = ( - agg_info['effect'] + all_bootstrap_weights @ agg_info['combined_if'] + agg_info["effect"] + all_bootstrap_weights @ agg_info["combined_if"] ) # Vectorized group aggregation @@ -357,17 +418,15 @@ def _run_multiplier_bootstrap( bootstrap_group = {} for g in group_list: agg_info = group_agg_info[g] - gt_indices = agg_info['gt_indices'] - weights = agg_info['weights'] + gt_indices = agg_info["gt_indices"] + weights = agg_info["weights"] # Suppress RuntimeWarnings for edge cases - with np.errstate(divide='ignore', invalid='ignore', over='ignore'): + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): bootstrap_group[g] = bootstrap_atts_gt[:, gt_indices] @ weights # Batch compute bootstrap statistics for ATT(g,t) - batch_ses, batch_ci_lo, batch_ci_hi, batch_pv = ( - _compute_effect_bootstrap_stats_batch_func( - original_atts, bootstrap_atts_gt, alpha=self.alpha - ) + batch_ses, batch_ci_lo, batch_ci_hi, batch_pv = _compute_effect_bootstrap_stats_batch_func( + original_atts, bootstrap_atts_gt, alpha=self.alpha ) gt_ses = {} gt_cis = {} @@ -384,8 +443,7 @@ def _run_multiplier_bootstrap( overall_p_value = np.nan else: overall_se, overall_ci, overall_p_value = self._compute_effect_bootstrap_stats( - original_overall, bootstrap_overall, - context="overall ATT" + original_overall, bootstrap_overall, context="overall ATT" ) # Batch compute bootstrap statistics for event study effects @@ -394,15 +452,15 @@ def _run_multiplier_bootstrap( event_study_p_values = None if bootstrap_event_study is not None and event_study_info is not None: - es_effects = np.array([event_study_info[e]['effect'] for e in rel_periods]) + es_effects = np.array([event_study_info[e]["effect"] for e in rel_periods]) es_boot_matrix = np.column_stack([bootstrap_event_study[e] for e in rel_periods]) - es_ses, es_ci_lo, es_ci_hi, es_pv = ( - _compute_effect_bootstrap_stats_batch_func( - es_effects, es_boot_matrix, alpha=self.alpha - ) + es_ses, es_ci_lo, es_ci_hi, es_pv = _compute_effect_bootstrap_stats_batch_func( + es_effects, es_boot_matrix, alpha=self.alpha ) event_study_ses = {e: float(es_ses[i]) for i, e in enumerate(rel_periods)} - event_study_cis = {e: (float(es_ci_lo[i]), float(es_ci_hi[i])) for i, e in enumerate(rel_periods)} + event_study_cis = { + e: (float(es_ci_lo[i]), float(es_ci_hi[i])) for i, e in enumerate(rel_periods) + } event_study_p_values = {e: float(es_pv[i]) for i, e in enumerate(rel_periods)} # Batch compute bootstrap statistics for group effects @@ -411,23 +469,28 @@ def _run_multiplier_bootstrap( group_effect_p_values = None if bootstrap_group is not None and group_agg_info is not None: - grp_effects = np.array([group_agg_info[g]['effect'] for g in group_list]) + grp_effects = np.array([group_agg_info[g]["effect"] for g in group_list]) grp_boot_matrix = np.column_stack([bootstrap_group[g] for g in group_list]) - grp_ses, grp_ci_lo, grp_ci_hi, grp_pv = ( - _compute_effect_bootstrap_stats_batch_func( - grp_effects, grp_boot_matrix, alpha=self.alpha - ) + grp_ses, grp_ci_lo, grp_ci_hi, grp_pv = _compute_effect_bootstrap_stats_batch_func( + grp_effects, grp_boot_matrix, alpha=self.alpha ) group_effect_ses = {g: float(grp_ses[i]) for i, g in enumerate(group_list)} - group_effect_cis = {g: (float(grp_ci_lo[i]), float(grp_ci_hi[i])) for i, g in enumerate(group_list)} + group_effect_cis = { + g: (float(grp_ci_lo[i]), float(grp_ci_hi[i])) for i, g in enumerate(group_list) + } group_effect_p_values = {g: float(grp_pv[i]) for i, g in enumerate(group_list)} # Compute simultaneous confidence band critical value (sup-t) cband_crit_value = None - if (cband and bootstrap_event_study is not None - and event_study_ses is not None and event_study_info is not None): + if ( + cband + and bootstrap_event_study is not None + and event_study_ses is not None + and event_study_info is not None + ): valid_es = [ - e for e in rel_periods + e + for e in rel_periods if e in event_study_ses and np.isfinite(event_study_ses[e]) and event_study_ses[e] > 0 @@ -435,9 +498,9 @@ def _run_multiplier_bootstrap( if valid_es: # Vectorized sup_t: max_e |(boot_att_e[b] - att_e) / se_e| boot_matrix = np.array([bootstrap_event_study[e] for e in valid_es]) - effects_vec = np.array([event_study_info[e]['effect'] for e in valid_es]) + effects_vec = np.array([event_study_info[e]["effect"] for e in valid_es]) ses_vec = np.array([event_study_ses[e] for e in valid_es]) - with np.errstate(divide='ignore', invalid='ignore'): + with np.errstate(divide="ignore", invalid="ignore"): sup_t_dist = np.max( np.abs((boot_matrix - effects_vec[:, None]) / ses_vec[:, None]), axis=0, @@ -453,9 +516,7 @@ def _run_multiplier_bootstrap( stacklevel=2, ) elif n_valid > 0: - cband_crit_value = float( - np.quantile(sup_t_dist[finite_mask], 1 - self.alpha) - ) + cband_crit_value = float(np.quantile(sup_t_dist[finite_mask], 1 - self.alpha)) return CSBootstrapResults( n_bootstrap=self.n_bootstrap, @@ -490,6 +551,22 @@ def _prepare_event_study_aggregation( n_global_units: Optional[int] = None, ) -> Dict[int, Dict[str, Any]]: """Prepare aggregation info for event study bootstrap.""" + # Use fixed cohort survey masses (not per-cell survey_weight_sum) when + # survey weights are present, matching the analytical + # _aggregate_event_study() path. + survey_w = precomputed.get("survey_weights") if precomputed is not None else None + _cohort_mass: Optional[dict] = None + if survey_w is not None: + unit_cohorts = precomputed["unit_cohorts"] + _cohort_mass = {} + + def _agg_weight(g: Any, t: Any) -> float: + if _cohort_mass is not None: + if g not in _cohort_mass: + _cohort_mass[g] = float(np.sum(survey_w[unit_cohorts == g])) + return _cohort_mass[g] + return group_time_effects[(g, t)]["n_treated"] + # Organize by relative time effects_by_e: Dict[int, List[Tuple[int, float, float]]] = {} @@ -497,17 +574,19 @@ def _prepare_event_study_aggregation( e = t - g if e not in effects_by_e: effects_by_e[e] = [] - effects_by_e[e].append(( - j, # index in gt_pairs - group_time_effects[(g, t)]['effect'], - group_time_effects[(g, t)]['n_treated'] - )) + effects_by_e[e].append( + ( + j, # index in gt_pairs + group_time_effects[(g, t)]["effect"], + _agg_weight(g, t), + ) + ) # Balance if requested if balance_e is not None: groups_at_e = set() for j, (g, t) in enumerate(gt_pairs): - if t - g == balance_e and np.isfinite(group_time_effects[(g, t)]['effect']): + if t - g == balance_e and np.isfinite(group_time_effects[(g, t)]["effect"]): groups_at_e.add(g) balanced_effects: Dict[int, List[Tuple[int, float, float]]] = {} @@ -516,11 +595,13 @@ def _prepare_event_study_aggregation( e = t - g if e not in balanced_effects: balanced_effects[e] = [] - balanced_effects[e].append(( - j, - group_time_effects[(g, t)]['effect'], - group_time_effects[(g, t)]['n_treated'] - )) + balanced_effects[e].append( + ( + j, + group_time_effects[(g, t)]["effect"], + _agg_weight(g, t), + ) + ) effects_by_e = balanced_effects # Compute aggregation weights @@ -543,9 +624,9 @@ def _prepare_event_study_aggregation( agg_effect = np.sum(weights * effects) entry: Dict[str, Any] = { - 'gt_indices': indices, - 'weights': weights, - 'effect': agg_effect, + "gt_indices": indices, + "weights": weights, + "effect": agg_effect, } # Compute combined IF for this event time if args available @@ -553,12 +634,18 @@ def _prepare_event_study_aggregation( gt_pairs_for_e = [gt_pairs[i] for i in indices] groups_for_gt = np.array([gt_pairs[i][0] for i in indices]) combined_if, _ = self._compute_combined_influence_function( - gt_pairs_for_e, weights, effects, groups_for_gt, - influence_func_info, df, unit, precomputed, + gt_pairs_for_e, + weights, + effects, + groups_for_gt, + influence_func_info, + df, + unit, + precomputed, global_unit_to_idx=global_unit_to_idx, n_global_units=n_global_units, ) - entry['combined_if'] = combined_if + entry["combined_if"] = combined_if result[e] = entry @@ -578,10 +665,12 @@ def _prepare_group_aggregation( group_data = [] for j, (gg, t) in enumerate(gt_pairs): if gg == g and t >= g - self.anticipation: - group_data.append(( - j, - group_time_effects[(gg, t)]['effect'], - )) + group_data.append( + ( + j, + group_time_effects[(gg, t)]["effect"], + ) + ) if not group_data: continue @@ -602,9 +691,9 @@ def _prepare_group_aggregation( agg_effect = np.sum(weights * effects) result[g] = { - 'gt_indices': indices, - 'weights': weights, - 'effect': agg_effect, + "gt_indices": indices, + "weights": weights, + "effect": agg_effect, } return result diff --git a/diff_diff/sun_abraham.py b/diff_diff/sun_abraham.py index 107df47..87db4a7 100644 --- a/diff_diff/sun_abraham.py +++ b/diff_diff/sun_abraham.py @@ -508,13 +508,16 @@ def fit( _resolve_survey_for_fit(survey_design, data, "analytical") ) - # Reject bootstrap + survey (pairs bootstrap with survey weights needs Phase 5) - if self.n_bootstrap > 0 and resolved_survey is not None: - raise NotImplementedError( - "Bootstrap inference with survey weights is not yet supported " - "for SunAbraham. Use analytical inference (n_bootstrap=0) with " - "survey_design for design-based standard errors." - ) + # Bootstrap + survey supported via Rao-Wu rescaled bootstrap. + # Determine Rao-Wu eligibility from the *original* survey_design + # (before cluster-as-PSU injection which adds PSU to weights-only designs). + _use_rao_wu = False + if survey_design is not None and resolved_survey is not None: + _has_explicit_strata = getattr(survey_design, "strata", None) is not None + _has_explicit_psu = getattr(survey_design, "psu", None) is not None + _has_explicit_fpc = getattr(survey_design, "fpc", None) is not None + if _has_explicit_strata or _has_explicit_psu or _has_explicit_fpc: + _use_rao_wu = True # Create working copy df = data.copy() @@ -687,6 +690,11 @@ def fit( cluster_var=cluster_var, original_event_study=event_study_effects, original_overall_att=overall_att, + resolved_survey=resolved_survey, + survey_weights=survey_weights, + survey_weight_type=survey_weight_type, + survey_weight_col=survey_weight_col, + use_rao_wu=_use_rao_wu, ) # Update results with bootstrap inference @@ -1073,11 +1081,18 @@ def _run_bootstrap( cluster_var: str, original_event_study: Dict[int, Dict[str, Any]], original_overall_att: float, + resolved_survey: object = None, + survey_weights: Optional[np.ndarray] = None, + survey_weight_type: str = "pweight", + survey_weight_col: Optional[str] = None, + use_rao_wu: bool = False, ) -> SABootstrapResults: """ - Run pairs bootstrap for inference. + Run bootstrap for inference. - Resamples units with replacement and re-estimates the full model. + When use_rao_wu is True (survey design with explicit strata/PSU/FPC), + uses Rao-Wu rescaled bootstrap (weight perturbation). Otherwise, uses + pairs bootstrap (resampling units with replacement). """ if self.n_bootstrap < 50: warnings.warn( @@ -1089,6 +1104,27 @@ def _run_bootstrap( rng = np.random.default_rng(self.seed) + if use_rao_wu: + return self._run_rao_wu_bootstrap( + df=df, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + treatment_groups=treatment_groups, + rel_periods_to_estimate=rel_periods_to_estimate, + covariates=covariates, + cluster_var=cluster_var, + original_event_study=original_event_study, + original_overall_att=original_overall_att, + resolved_survey=resolved_survey, + survey_weight_type=survey_weight_type, + survey_weight_col=survey_weight_col, + rng=rng, + ) + + # --- Pairs bootstrap (non-survey or weights-only survey) --- + # Get unique units all_units = df[unit].unique() n_units = len(all_units) @@ -1122,6 +1158,11 @@ def _run_bootstrap( df_b["_never_treated"] = (df_b[first_treat] == 0) | (df_b[first_treat] == np.inf) try: + # Extract survey weights from resampled data if present + boot_survey_weights = None + if survey_weight_col is not None and survey_weight_col in df_b.columns: + boot_survey_weights = df_b[survey_weight_col].values + # Re-estimate saturated regression ( cohort_effects_b, @@ -1138,6 +1179,9 @@ def _run_bootstrap( rel_periods_to_estimate, covariates, cluster_var, + survey_weights=boot_survey_weights, + survey_weight_type=survey_weight_type, + resolved_survey=None, # Use explicit weights, not stale design ) # Compute IW effects for this bootstrap sample @@ -1151,6 +1195,7 @@ def _run_bootstrap( cohort_ses_b, vcov_b, coef_map_b, + survey_weight_col=survey_weight_col, ) # Store bootstrap estimates @@ -1169,6 +1214,7 @@ def _run_bootstrap( cohort_weights_b, vcov_b, coef_map_b, + survey_weight_col=survey_weight_col, ) bootstrap_overall[b] = overall_b @@ -1222,6 +1268,240 @@ def _run_bootstrap( bootstrap_distribution=bootstrap_overall, ) + def _run_rao_wu_bootstrap( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + treatment_groups: List[Any], + rel_periods_to_estimate: List[int], + covariates: Optional[List[str]], + cluster_var: str, + original_event_study: Dict[int, Dict[str, Any]], + original_overall_att: float, + resolved_survey: object, + survey_weight_type: str, + survey_weight_col: Optional[str], + rng: np.random.Generator, + ) -> SABootstrapResults: + """ + Run Rao-Wu rescaled bootstrap for survey-aware inference. + + Instead of physically resampling units, each iteration generates + rescaled observation weights via Rao-Wu (1988) weight perturbation. + The rescaled weights feed into the existing WLS regression path. + """ + from diff_diff.bootstrap_utils import generate_rao_wu_weights + from diff_diff.survey import ResolvedSurveyDesign + + # Column name for rescaled weights in the bootstrap DataFrame + _rw_col = "__rw_boot_weight" + + # Collapse survey design to unit level so Rao-Wu respects panel + # structure: each unit gets one set of weights regardless of how + # many time periods it has. Without this, when there is no + # explicit PSU, generate_rao_wu_weights treats each observation as + # its own PSU and different obs of the same unit can get different + # weights, breaking panel semantics. + all_units = df[unit].unique() + + weights_unit = ( + pd.Series(resolved_survey.weights, index=df.index) + .groupby(df[unit]) + .first() + .reindex(all_units) + .values + .astype(np.float64) + ) + + strata_unit = None + if resolved_survey.strata is not None: + strata_unit = ( + pd.Series(resolved_survey.strata, index=df.index) + .groupby(df[unit]) + .first() + .reindex(all_units) + .values + ) + + psu_unit = None + if resolved_survey.psu is not None: + psu_unit = ( + pd.Series(resolved_survey.psu, index=df.index) + .groupby(df[unit]) + .first() + .reindex(all_units) + .values + ) + + fpc_unit = None + if resolved_survey.fpc is not None: + fpc_unit = ( + pd.Series(resolved_survey.fpc, index=df.index) + .groupby(df[unit]) + .first() + .reindex(all_units) + .values + ) + + unit_resolved = ResolvedSurveyDesign( + weights=weights_unit, + weight_type=resolved_survey.weight_type, + strata=strata_unit, + psu=psu_unit, + fpc=fpc_unit, + n_strata=resolved_survey.n_strata, + n_psu=resolved_survey.n_psu, + lonely_psu=resolved_survey.lonely_psu, + ) + + # Build unit -> row indices mapping for expanding unit-level weights + unit_to_rows = {u: df.index[df[unit] == u].values for u in all_units} + unit_order = {u: i for i, u in enumerate(all_units)} + + # Store bootstrap samples + rel_periods = sorted(original_event_study.keys()) + bootstrap_effects = {e: np.full(self.n_bootstrap, np.nan) for e in rel_periods} + bootstrap_overall = np.full(self.n_bootstrap, np.nan) + + for b in range(self.n_bootstrap): + try: + # Generate Rao-Wu rescaled weights at unit level + unit_boot_weights = generate_rao_wu_weights(unit_resolved, rng) + + # Expand unit-level weights to observation level + boot_weights = np.empty(len(df), dtype=np.float64) + for u, idx in unit_to_rows.items(): + boot_weights[idx] = unit_boot_weights[unit_order[u]] + + # Drop observations with zero weight (PSUs not drawn in this + # iteration) to avoid NaN/Inf in within-transformation. + positive_mask = boot_weights > 0 + if positive_mask.sum() < 2: + # Too few observations with positive weight + raise ValueError("Rao-Wu iteration produced < 2 positive weights") + + df_b = df[positive_mask].reset_index(drop=True) + boot_weights_b = boot_weights[positive_mask] + df_b[_rw_col] = boot_weights_b + + # Verify we still have both treated and control observations + has_treated = (df_b[first_treat] > 0).any() + has_control = ((df_b[first_treat] == 0) | (df_b[first_treat] == np.inf)).any() + if not has_treated or not has_control: + raise ValueError("Rao-Wu iteration dropped all treated or control units") + + # Re-estimate saturated regression with rescaled weights. + # Pass resolved_survey=None since inference comes from the + # bootstrap distribution, not from within-iteration vcov. + ( + cohort_effects_b, + cohort_ses_b, + vcov_b, + coef_map_b, + ) = self._fit_saturated_regression( + df_b, + outcome, + unit, + time, + first_treat, + treatment_groups, + rel_periods_to_estimate, + covariates, + cluster_var, + survey_weights=boot_weights_b, + survey_weight_type=survey_weight_type, + resolved_survey=None, + ) + + # Compute IW effects using rescaled weights for cohort shares + event_study_b, cohort_weights_b = self._compute_iw_effects( + df_b, + unit, + first_treat, + treatment_groups, + rel_periods_to_estimate, + cohort_effects_b, + cohort_ses_b, + vcov_b, + coef_map_b, + survey_weight_col=_rw_col, + ) + + # Store bootstrap estimates + for e in rel_periods: + if e in event_study_b: + bootstrap_effects[e][b] = event_study_b[e]["effect"] + else: + bootstrap_effects[e][b] = original_event_study[e]["effect"] + + # Compute overall ATT using rescaled weights + overall_b, _ = self._compute_overall_att( + df_b, + first_treat, + event_study_b, + cohort_effects_b, + cohort_weights_b, + vcov_b, + coef_map_b, + survey_weight_col=_rw_col, + ) + bootstrap_overall[b] = overall_b + + except (ValueError, np.linalg.LinAlgError) as exc: + # Failed draws stored as NaN (not original estimate) to avoid + # shrinking bootstrap dispersion. compute_effect_bootstrap_stats + # handles NaN draws via nanstd. + warnings.warn( + f"Bootstrap iteration {b} failed: {exc}. Storing NaN.", + UserWarning, + stacklevel=2, + ) + for e in rel_periods: + bootstrap_effects[e][b] = np.nan + bootstrap_overall[b] = np.nan + + # Compute bootstrap statistics + event_study_ses = {} + event_study_cis = {} + event_study_p_values = {} + + for e in rel_periods: + boot_dist = bootstrap_effects[e] + original_effect = original_event_study[e]["effect"] + se, ci, p_value = compute_effect_bootstrap_stats( + original_effect, + boot_dist, + alpha=self.alpha, + context=f"event study e={e}", + ) + event_study_ses[e] = se + event_study_cis[e] = ci + event_study_p_values[e] = p_value + + # Overall ATT statistics + overall_se, overall_ci, overall_p = compute_effect_bootstrap_stats( + original_overall_att, + bootstrap_overall, + alpha=self.alpha, + context="overall ATT", + ) + + return SABootstrapResults( + n_bootstrap=self.n_bootstrap, + weight_type="rao_wu", + alpha=self.alpha, + overall_att_se=overall_se, + overall_att_ci=overall_ci, + overall_att_p_value=overall_p, + event_study_ses=event_study_ses, + event_study_cis=event_study_cis, + event_study_p_values=event_study_p_values, + bootstrap_distribution=bootstrap_overall, + ) + def get_params(self) -> Dict[str, Any]: """Get estimator parameters (sklearn-compatible).""" return { diff --git a/diff_diff/survey.py b/diff_diff/survey.py index efa1626..76d21b2 100644 --- a/diff_diff/survey.py +++ b/diff_diff/survey.py @@ -577,60 +577,52 @@ def _inject_cluster_as_psu(resolved, cluster_ids): return replace(resolved, psu=codes, n_psu=n_clusters) -def compute_survey_vcov( - X: np.ndarray, - residuals: np.ndarray, +def _compute_stratified_psu_meat( + scores: np.ndarray, resolved: "ResolvedSurveyDesign", -) -> np.ndarray: - """ - Compute Taylor Series Linearization (TSL) variance-covariance matrix. - - Implements the stratified cluster sandwich estimator with optional - finite population correction (FPC). +) -> tuple: + """Compute the stratified PSU-level meat matrix for TSL variance. - V_TSL = (X'WX)^{-1} [sum_h V_h] (X'WX)^{-1} + This is the core computation shared by :func:`compute_survey_vcov` + (which wraps it in a sandwich with the bread matrix) and + :func:`compute_survey_if_variance` (which uses it directly for + influence-function-based estimators). Parameters ---------- - X : np.ndarray - Design matrix of shape (n, k). - residuals : np.ndarray - Residuals from WLS fit (y - X @ beta, on ORIGINAL scale). + scores : np.ndarray + Score matrix of shape (n, k). For OLS-based estimators these are + the weighted score contributions X_i * w_i * u_i. For IF-based + estimators these are the per-unit influence function values + (reshaped to (n, 1) for scalar estimators). resolved : ResolvedSurveyDesign Resolved survey design with weights, strata, PSU arrays. Returns ------- - vcov : np.ndarray - Variance-covariance matrix of shape (k, k). + meat : np.ndarray + Meat matrix of shape (k, k). + variance_computed : bool + Whether any actual variance computation happened. + legitimate_zero_count : int + Number of strata/sources that legitimately contribute zero variance. """ - n, k = X.shape - weights = resolved.weights - - # Bread: (X'WX)^{-1} - XtWX = X.T @ (X * weights[:, np.newaxis]) - - # Compute weighted scores per observation: w_i * X_i * u_i - if resolved.weight_type == "aweight": - # aweights: no weight in meat (errors already homoskedastic after WLS) - scores = X * residuals[:, np.newaxis] - else: - scores = X * (weights * residuals)[:, np.newaxis] + n = scores.shape[0] + k = scores.shape[1] if scores.ndim > 1 else 1 + if scores.ndim == 1: + scores = scores[:, np.newaxis] - # Determine strata and PSU structure strata = resolved.strata psu = resolved.psu legitimate_zero_count = 0 - _variance_computed = False # Did any actual variance computation happen? + _variance_computed = False if strata is None and psu is None: - # No survey structure beyond weights — use implicit per-observation PSUs - # so the TSL construction is consistent across all branches. - # Each observation is its own PSU; scores are already per-obs. + # No survey structure beyond weights — implicit per-observation PSUs psu_mean = scores.mean(axis=0, keepdims=True) centered = scores - psu_mean - f_h = 0.0 # No FPC by default + f_h = 0.0 if resolved.fpc is not None: N_h = resolved.fpc[0] if N_h < n: @@ -650,15 +642,11 @@ def compute_survey_vcov( n_psu = psu_scores.shape[0] if n_psu < 2: - # With only 1 PSU and no strata, variance estimation is impossible - # regardless of lonely_psu mode. The "adjust" mode cannot help - # because there is no global-vs-stratum distinction to exploit. meat = np.zeros((k, k)) else: - # Center around grand mean psu_mean = psu_scores.mean(axis=0, keepdims=True) centered = psu_scores - psu_mean - f_h = 0.0 # No FPC + f_h = 0.0 if resolved.fpc is not None: N_h = resolved.fpc[0] if N_h < n_psu: @@ -677,8 +665,6 @@ def compute_survey_vcov( unique_strata = np.unique(strata) meat = np.zeros((k, k)) - # Pre-compute global PSU scores for lonely_psu="adjust" (avoids - # recomputing O(n) groupby inside the per-stratum loop) _global_psu_mean = None if resolved.lonely_psu == "adjust": if psu is not None: @@ -692,25 +678,22 @@ def compute_survey_vcov( mask_h = strata == h if psu is not None: - # Get PSU-level score totals within stratum h psu_h = psu[mask_h] scores_h = scores[mask_h] psu_scores_h = pd.DataFrame(scores_h).groupby(psu_h).sum().values n_psu_h = psu_scores_h.shape[0] else: - # Each observation is its own PSU psu_scores_h = scores[mask_h] n_psu_h = psu_scores_h.shape[0] # Handle singleton strata if n_psu_h < 2: if resolved.lonely_psu == "remove": - continue # Skip this stratum + continue elif resolved.lonely_psu == "certainty": legitimate_zero_count += 1 - continue # f_h = 1, so (1-f_h) = 0, zero contribution + continue elif resolved.lonely_psu == "adjust": - # Center around overall mean instead of stratum mean centered = psu_scores_h - _global_psu_mean V_h = centered.T @ centered meat += V_h @@ -730,25 +713,62 @@ def compute_survey_vcov( if f_h >= 1.0: legitimate_zero_count += 1 - # Stratum mean of PSU scores psu_mean_h = psu_scores_h.mean(axis=0, keepdims=True) centered = psu_scores_h - psu_mean_h - # V_h = (1 - f_h) * (n_h / (n_h - 1)) * sum (e_hi - e_bar_h)(...)^T adjustment = (1.0 - f_h) * (n_psu_h / (n_psu_h - 1)) V_h = adjustment * (centered.T @ centered) meat += V_h _variance_computed = True + return meat, _variance_computed, legitimate_zero_count + + +def compute_survey_vcov( + X: np.ndarray, + residuals: np.ndarray, + resolved: "ResolvedSurveyDesign", +) -> np.ndarray: + """ + Compute Taylor Series Linearization (TSL) variance-covariance matrix. + + Implements the stratified cluster sandwich estimator with optional + finite population correction (FPC). + + V_TSL = (X'WX)^{-1} [sum_h V_h] (X'WX)^{-1} + + Parameters + ---------- + X : np.ndarray + Design matrix of shape (n, k). + residuals : np.ndarray + Residuals from WLS fit (y - X @ beta, on ORIGINAL scale). + resolved : ResolvedSurveyDesign + Resolved survey design with weights, strata, PSU arrays. + + Returns + ------- + vcov : np.ndarray + Variance-covariance matrix of shape (k, k). + """ + n, k = X.shape + weights = resolved.weights + + # Bread: (X'WX)^{-1} + XtWX = X.T @ (X * weights[:, np.newaxis]) + + # Compute weighted scores per observation: w_i * X_i * u_i + if resolved.weight_type == "aweight": + scores = X * residuals[:, np.newaxis] + else: + scores = X * (weights * residuals)[:, np.newaxis] + + meat, _variance_computed, legitimate_zero_count = _compute_stratified_psu_meat(scores, resolved) + # Guard: if meat is zero, distinguish legitimate zero from unidentified variance if not np.any(meat != 0): if _variance_computed or legitimate_zero_count > 0: - # Zero meat from actual computation (e.g., identical PSU scores, - # perfect fit) or from legitimate zero-variance sources (certainty - # PSUs, full-census FPC). Zero vcov is the correct result. return np.zeros((k, k)) - # No variance computation happened (e.g., all strata removed, single - # unstratified PSU). Variance is genuinely unidentified. return np.full((k, k), np.nan) # Sandwich: (X'WX)^{-1} meat (X'WX)^{-1} @@ -764,3 +784,83 @@ def compute_survey_vcov( raise return vcov + + +def compute_survey_if_variance( + psi: np.ndarray, + resolved: "ResolvedSurveyDesign", +) -> float: + """Compute design-based variance of a scalar estimator from IF values. + + For influence-function-based estimators (e.g., CallawaySantAnna), + the per-unit influence function values ``psi_i`` capture each unit's + contribution to the estimating equation. Under simple random sampling + the variance is ``sum(psi_i^2)``. This function computes the + design-based analogue accounting for PSU clustering, stratification, + and finite population correction. + + V_design = sum_h (1-f_h) * (n_h/(n_h-1)) * sum_j (psi_hj - psi_h_bar)^2 + + where psi_hj = sum_{i in PSU j, stratum h} psi_i. + + Parameters + ---------- + psi : np.ndarray + Per-unit influence function values, shape (n,). + resolved : ResolvedSurveyDesign + Resolved survey design. + + Returns + ------- + float + Design-based variance. Returns ``np.nan`` when variance is + unidentified (e.g., all strata removed by lonely_psu='remove'). + """ + psi = np.asarray(psi, dtype=np.float64).ravel() + + meat, _variance_computed, legitimate_zero_count = _compute_stratified_psu_meat( + psi[:, np.newaxis], resolved + ) + + # meat is (1, 1) — extract scalar + meat_scalar = float(meat[0, 0]) + + if meat_scalar == 0.0: + if _variance_computed or legitimate_zero_count > 0: + return 0.0 + return np.nan + + return meat_scalar + + +def aggregate_to_psu( + values: np.ndarray, + resolved: "ResolvedSurveyDesign", +) -> tuple: + """Sum values within PSUs for PSU-level bootstrap perturbation. + + Parameters + ---------- + values : np.ndarray + Per-observation values, shape (n,) or (n, k). + resolved : ResolvedSurveyDesign + Resolved survey design. + + Returns + ------- + psu_sums : np.ndarray + Aggregated values, shape (n_psu,) or (n_psu, k). + psu_ids : np.ndarray + Unique PSU identifiers in the same order as ``psu_sums``. + """ + if resolved.psu is None: + # Each observation is its own PSU — return as-is + return values.copy(), np.arange(len(values)) + + psu = resolved.psu + unique_psu = np.unique(psu) + if values.ndim == 1: + psu_sums = np.array([values[psu == p].sum() for p in unique_psu]) + else: + psu_sums = np.array([values[psu == p].sum(axis=0) for p in unique_psu]) + return psu_sums, unique_psu diff --git a/diff_diff/synthetic_did.py b/diff_diff/synthetic_did.py index 2ff982c..f60a9dd 100644 --- a/diff_diff/synthetic_did.py +++ b/diff_diff/synthetic_did.py @@ -216,8 +216,10 @@ def fit( # type: ignore[override] List of covariate column names. Covariates are residualized out before computing the SDID estimator. survey_design : SurveyDesign, optional - Survey design specification. Only pweight designs are supported - (strata/PSU/FPC raise NotImplementedError). + Survey design specification. Only pweight weight_type is supported. + Strata/PSU/FPC are supported via Rao-Wu rescaled bootstrap when + variance_method='bootstrap'. Placebo variance does not support + strata/PSU/FPC; use variance_method='bootstrap' for full designs. Returns ------- @@ -230,8 +232,6 @@ def fit( # type: ignore[override] ValueError If required parameters are missing, data validation fails, or a non-pweight survey design is provided. - NotImplementedError - If survey_design includes strata, PSU, or FPC. """ # Validate inputs if outcome is None or treatment is None or unit is None or time is None: @@ -249,7 +249,6 @@ def fit( # type: ignore[override] # Resolve survey design from diff_diff.survey import ( _extract_unit_survey_weights, - _resolve_pweight_only, _resolve_survey_for_fit, _validate_unit_constant_survey, ) @@ -257,7 +256,27 @@ def fit( # type: ignore[override] resolved_survey, survey_weights, survey_weight_type, survey_metadata = ( _resolve_survey_for_fit(survey_design, data, "analytical") ) - _resolve_pweight_only(resolved_survey, "SyntheticDiD") + # Validate pweight only (strata/PSU/FPC are allowed for Rao-Wu bootstrap) + if resolved_survey is not None and resolved_survey.weight_type != "pweight": + raise ValueError( + "SyntheticDiD survey support requires weight_type='pweight'. " + f"Got '{resolved_survey.weight_type}'." + ) + + # Reject placebo + full survey design (strata/PSU/FPC are silently ignored) + if ( + resolved_survey is not None + and ( + resolved_survey.strata is not None + or resolved_survey.psu is not None + or resolved_survey.fpc is not None + ) + and self.variance_method == "placebo" + ): + raise NotImplementedError( + "SyntheticDiD with variance_method='placebo' does not support strata/PSU/FPC. " + "Use variance_method='bootstrap' for full survey design support." + ) # Validate treatment is binary validate_binary(data[treatment].values, "treatment") @@ -330,10 +349,28 @@ def fit( # type: ignore[override] ) # Validate and extract survey weights + # Build unit-level ResolvedSurveyDesign for Rao-Wu bootstrap when + # strata/PSU/FPC are present (survey columns are unit-constant). + _unit_resolved_survey = None if resolved_survey is not None: _validate_unit_constant_survey(data, unit, survey_design) w_treated = _extract_unit_survey_weights(data, unit, survey_design, treated_units) w_control = _extract_unit_survey_weights(data, unit, survey_design, control_units) + + # Build unit-level resolved survey for Rao-Wu bootstrap + _has_design = ( + resolved_survey.strata is not None + or resolved_survey.psu is not None + or resolved_survey.fpc is not None + ) + if _has_design: + _unit_resolved_survey = self._build_unit_resolved_survey( + data, + unit, + survey_design, + control_units, + treated_units, + ) else: w_treated = None w_control = None @@ -463,6 +500,7 @@ def fit( # type: ignore[override] time_weights, w_treated=w_treated, w_control=w_control, + resolved_survey=_unit_resolved_survey, ) placebo_effects = bootstrap_estimates inference_method = "bootstrap" @@ -613,6 +651,63 @@ def _residualize_covariates( return data + @staticmethod + def _build_unit_resolved_survey(data, unit_col, survey_design, control_units, treated_units): + """Build a unit-level ResolvedSurveyDesign for Rao-Wu bootstrap. + + Extracts one row per unit (survey columns are unit-constant) in + control-then-treated order matching the panel matrix columns. + """ + from diff_diff.linalg import _factorize_cluster_ids + from diff_diff.survey import ResolvedSurveyDesign + + all_units = list(control_units) + list(treated_units) + # Take first row per unit in the specified order + first_rows = data.groupby(unit_col).first().loc[all_units] + n_units = len(all_units) + + # Weights (normalized pweights, mean=1) + if survey_design.weights is not None: + raw_w = first_rows[survey_design.weights].values.astype(np.float64) + weights = raw_w * (n_units / np.sum(raw_w)) + else: + weights = np.ones(n_units, dtype=np.float64) + + # Strata + strata_arr = None + n_strata = 0 + if survey_design.strata is not None: + strata_arr = _factorize_cluster_ids(first_rows[survey_design.strata].values) + n_strata = len(np.unique(strata_arr)) + + # PSU + psu_arr = None + n_psu = 0 + if survey_design.psu is not None: + psu_raw = first_rows[survey_design.psu].values + if survey_design.nest and strata_arr is not None: + combined = np.array([f"{s}_{p}" for s, p in zip(strata_arr, psu_raw)]) + psu_arr = _factorize_cluster_ids(combined) + else: + psu_arr = _factorize_cluster_ids(psu_raw) + n_psu = len(np.unique(psu_arr)) + + # FPC + fpc_arr = None + if survey_design.fpc is not None: + fpc_arr = first_rows[survey_design.fpc].values.astype(np.float64) + + return ResolvedSurveyDesign( + weights=weights, + weight_type=survey_design.weight_type, + strata=strata_arr, + psu=psu_arr, + fpc=fpc_arr, + n_strata=n_strata, + n_psu=n_psu, + lonely_psu=survey_design.lonely_psu, + ) + def _bootstrap_se( self, Y_pre_control: np.ndarray, @@ -623,6 +718,7 @@ def _bootstrap_se( time_weights: np.ndarray, w_treated=None, w_control=None, + resolved_survey=None, ) -> Tuple[float, np.ndarray]: """Compute bootstrap standard error matching R's synthdid bootstrap_sample. @@ -630,8 +726,16 @@ def _bootstrap_se( original unit weights for the resampled controls, and computes the SDID estimator with **fixed** weights (no re-estimation). + When ``resolved_survey`` is provided (unit-level ResolvedSurveyDesign + with strata/PSU/FPC), uses Rao-Wu rescaled bootstrap instead of the + simple pairs bootstrap. The Rao-Wu weights are per-unit rescaled + survey weights; they composite with SDID unit weights the same way + pweights do in the weights-only path. + This matches R's ``synthdid::vcov(method="bootstrap")``. """ + from diff_diff.bootstrap_utils import generate_rao_wu_weights + rng = np.random.default_rng(self.seed) n_control = Y_pre_control.shape[1] n_treated = Y_pre_treated.shape[1] @@ -641,65 +745,125 @@ def _bootstrap_se( Y_full = np.block([[Y_pre_control, Y_pre_treated], [Y_post_control, Y_post_treated]]) n_pre = Y_pre_control.shape[0] + # Determine whether to use Rao-Wu (full design) or pairs bootstrap + _use_rao_wu = resolved_survey is not None + bootstrap_estimates = [] for _ in range(self.n_bootstrap): - # Resample ALL units with replacement - boot_idx = rng.choice(n_total, size=n_total, replace=True) - - # Identify which resampled units are control vs treated - boot_is_control = boot_idx < n_control - boot_control_idx = boot_idx[boot_is_control] - boot_treated_idx = boot_idx[~boot_is_control] - - # Skip if no control or no treated units in bootstrap sample - if len(boot_control_idx) == 0 or len(boot_treated_idx) == 0: - continue - - try: - # Renormalize original unit weights for the resampled controls - boot_omega = _sum_normalize(unit_weights[boot_control_idx]) + if _use_rao_wu: + # --- Rao-Wu rescaled bootstrap path --- + # generate_rao_wu_weights returns per-unit rescaled survey + # weights (shape n_total). Units whose PSU was not drawn + # get weight 0, effectively dropping them. + try: + boot_rw = generate_rao_wu_weights(resolved_survey, rng) + + rw_control = boot_rw[:n_control] + rw_treated = boot_rw[n_control:] + + # Skip if all control or all treated weights are zero + if rw_control.sum() == 0 or rw_treated.sum() == 0: + continue + + # Composite SDID unit weights with Rao-Wu rescaled weights + boot_omega_eff = unit_weights * rw_control + if boot_omega_eff.sum() > 0: + boot_omega_eff = boot_omega_eff / boot_omega_eff.sum() + else: + continue + + # Treated mean weighted by Rao-Wu weights + Y_boot_pre_t_mean = np.average( + Y_pre_treated, + axis=1, + weights=rw_treated, + ) + Y_boot_post_t_mean = np.average( + Y_post_treated, + axis=1, + weights=rw_treated, + ) - # Compose with control survey weights if present - if w_control is not None: - boot_w_c = w_control[boot_idx[boot_is_control]] - boot_omega_eff = boot_omega * boot_w_c - boot_omega_eff = boot_omega_eff / boot_omega_eff.sum() - else: - boot_omega_eff = boot_omega - - # Extract resampled outcome matrices - Y_boot = Y_full[:, boot_idx] - Y_boot_pre_c = Y_boot[:n_pre, boot_is_control] - Y_boot_post_c = Y_boot[n_pre:, boot_is_control] - Y_boot_pre_t = Y_boot[:n_pre, ~boot_is_control] - Y_boot_post_t = Y_boot[n_pre:, ~boot_is_control] - - # Compute ATT with FIXED weights (do NOT re-estimate). - # boot_idx[~boot_is_control] maps to original index space; - # subtract n_control to index into w_treated. Duplicate draws - # carry identical weights → alignment is safe. - if w_treated is not None: - boot_w_t = w_treated[boot_idx[~boot_is_control] - n_control] - Y_boot_pre_t_mean = np.average(Y_boot_pre_t, axis=1, weights=boot_w_t) - Y_boot_post_t_mean = np.average(Y_boot_post_t, axis=1, weights=boot_w_t) - else: - Y_boot_pre_t_mean = np.mean(Y_boot_pre_t, axis=1) - Y_boot_post_t_mean = np.mean(Y_boot_post_t, axis=1) + tau = compute_sdid_estimator( + Y_pre_control, + Y_post_control, + Y_boot_pre_t_mean, + Y_boot_post_t_mean, + boot_omega_eff, + time_weights, + ) + if np.isfinite(tau): + bootstrap_estimates.append(tau) - tau = compute_sdid_estimator( - Y_boot_pre_c, - Y_boot_post_c, - Y_boot_pre_t_mean, - Y_boot_post_t_mean, - boot_omega_eff, - time_weights, - ) - if np.isfinite(tau): - bootstrap_estimates.append(tau) + except (ValueError, LinAlgError): + continue + else: + # --- Standard pairs bootstrap path (weights-only or no survey) --- + # Resample ALL units with replacement + boot_idx = rng.choice(n_total, size=n_total, replace=True) + + # Identify which resampled units are control vs treated + boot_is_control = boot_idx < n_control + boot_control_idx = boot_idx[boot_is_control] + boot_treated_idx = boot_idx[~boot_is_control] + + # Skip if no control or no treated units in bootstrap sample + if len(boot_control_idx) == 0 or len(boot_treated_idx) == 0: + continue + + try: + # Renormalize original unit weights for the resampled controls + boot_omega = _sum_normalize(unit_weights[boot_control_idx]) + + # Compose with control survey weights if present + if w_control is not None: + boot_w_c = w_control[boot_idx[boot_is_control]] + boot_omega_eff = boot_omega * boot_w_c + boot_omega_eff = boot_omega_eff / boot_omega_eff.sum() + else: + boot_omega_eff = boot_omega + + # Extract resampled outcome matrices + Y_boot = Y_full[:, boot_idx] + Y_boot_pre_c = Y_boot[:n_pre, boot_is_control] + Y_boot_post_c = Y_boot[n_pre:, boot_is_control] + Y_boot_pre_t = Y_boot[:n_pre, ~boot_is_control] + Y_boot_post_t = Y_boot[n_pre:, ~boot_is_control] + + # Compute ATT with FIXED weights (do NOT re-estimate). + # boot_idx[~boot_is_control] maps to original index space; + # subtract n_control to index into w_treated. Duplicate draws + # carry identical weights -> alignment is safe. + if w_treated is not None: + boot_w_t = w_treated[boot_idx[~boot_is_control] - n_control] + Y_boot_pre_t_mean = np.average( + Y_boot_pre_t, + axis=1, + weights=boot_w_t, + ) + Y_boot_post_t_mean = np.average( + Y_boot_post_t, + axis=1, + weights=boot_w_t, + ) + else: + Y_boot_pre_t_mean = np.mean(Y_boot_pre_t, axis=1) + Y_boot_post_t_mean = np.mean(Y_boot_post_t, axis=1) + + tau = compute_sdid_estimator( + Y_boot_pre_c, + Y_boot_post_c, + Y_boot_pre_t_mean, + Y_boot_post_t_mean, + boot_omega_eff, + time_weights, + ) + if np.isfinite(tau): + bootstrap_estimates.append(tau) - except (ValueError, LinAlgError): - continue + except (ValueError, LinAlgError): + continue bootstrap_estimates = np.array(bootstrap_estimates) diff --git a/diff_diff/trop.py b/diff_diff/trop.py index b12b121..c9ab428 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -426,9 +426,10 @@ def fit( time : str Name of the time period column. survey_design : SurveyDesign, optional - Survey design specification. Only pweight designs are supported - (strata/PSU/FPC raise NotImplementedError). Survey weights enter - ATT aggregation only. + Survey design specification. Supports pweight, strata, PSU, and + FPC. Full-design surveys (strata/PSU/FPC) use Rao-Wu rescaled + bootstrap; Rust backend is pweight-only (Python fallback for + full design). Survey weights enter ATT aggregation only. Returns ------- @@ -443,8 +444,6 @@ def fit( ------ ValueError If required columns are missing or non-pweight survey design. - NotImplementedError - If survey_design includes strata, PSU, or FPC. """ # Validate inputs required_cols = [outcome, treatment, unit, time] @@ -455,7 +454,6 @@ def fit( # Resolve survey design from diff_diff.survey import ( _extract_unit_survey_weights, - _resolve_pweight_only, _resolve_survey_for_fit, _validate_unit_constant_survey, ) @@ -463,7 +461,13 @@ def fit( resolved_survey, _survey_weights, _survey_wt, survey_metadata = _resolve_survey_for_fit( survey_design, data, "analytical" ) - _resolve_pweight_only(resolved_survey, "TROP") + # Validate weight_type is pweight (keep restriction), but allow + # strata/PSU/FPC — those are handled via Rao-Wu rescaled bootstrap. + if resolved_survey is not None and resolved_survey.weight_type != "pweight": + raise ValueError( + "TROP requires pweight survey weights. " + f"Got weight_type='{resolved_survey.weight_type}'." + ) if resolved_survey is not None: _validate_unit_constant_survey(data, unit, survey_design) @@ -836,6 +840,7 @@ def fit( control_unit_idx=control_unit_idx, survey_design=survey_design, unit_weight_arr=unit_weight_arr, + resolved_survey=resolved_survey, ) # Compute test statistics @@ -951,7 +956,7 @@ def trop( time : str Time period column name. survey_design : SurveyDesign, optional - Survey design specification. Only pweight designs are supported. + Survey design specification. Supports pweight, strata, PSU, and FPC. **kwargs Additional arguments passed to TROP constructor. diff --git a/diff_diff/trop_global.py b/diff_diff/trop_global.py index c8a9097..5ce8134 100644 --- a/diff_diff/trop_global.py +++ b/diff_diff/trop_global.py @@ -794,6 +794,7 @@ def _fit_global( treated_periods, survey_design=survey_design, unit_weight_arr=unit_weight_arr, + resolved_survey=resolved_survey, ) # Compute test statistics @@ -845,11 +846,14 @@ def _bootstrap_variance_global( treated_periods: int, survey_design=None, unit_weight_arr: Optional[np.ndarray] = None, + resolved_survey=None, ) -> Tuple[float, np.ndarray]: """ Compute bootstrap standard error for global method. Uses Rust backend when available for parallel bootstrap (5-15x speedup). + When a full survey design (strata/PSU/FPC) is present, uses Rao-Wu + rescaled bootstrap instead, which skips the Rust path. Parameters ---------- @@ -867,6 +871,12 @@ def _bootstrap_variance_global( Optimal tuning parameters. treated_periods : int Number of post-treatment periods. + survey_design : SurveyDesign, optional + Survey design specification. + unit_weight_arr : np.ndarray, optional + Unit-level survey weights. + resolved_survey : ResolvedSurveyDesign, optional + Resolved survey design (observation-level). Returns ------- @@ -875,7 +885,29 @@ def _bootstrap_variance_global( """ lambda_time, lambda_unit, lambda_nn = optimal_lambda + # Check for full survey design (strata/PSU/FPC present) + _has_full_design = resolved_survey is not None and ( + resolved_survey.strata is not None + or resolved_survey.psu is not None + or resolved_survey.fpc is not None + ) + + # Full survey design: use Python Rao-Wu rescaled bootstrap + if _has_full_design: + return self._bootstrap_rao_wu_global( + data, + outcome, + treatment, + unit, + time, + optimal_lambda, + treated_periods, + resolved_survey, + survey_design, + ) + # Try Rust backend for parallel bootstrap (5-15x speedup) + # Only used for pweight-only designs (no strata/PSU/FPC) if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_global is not None: try: # Create matrices for Rust function @@ -985,6 +1017,174 @@ def _bootstrap_variance_global( se = np.std(bootstrap_estimates, ddof=1) return float(se), bootstrap_estimates + def _bootstrap_rao_wu_global( + self, + data: pd.DataFrame, + outcome: str, + treatment: str, + unit: str, + time: str, + optimal_lambda: Tuple[float, float, float], + treated_periods: int, + resolved_survey, + survey_design, + ) -> Tuple[float, np.ndarray]: + """ + Rao-Wu rescaled bootstrap for global method with full survey design. + + Instead of physically resampling units, each iteration generates + rescaled observation weights via Rao-Wu (1988) weight perturbation. + Cross-classifies survey strata with treatment group to preserve + the stratified resampling structure. + + Parameters + ---------- + data : pd.DataFrame + Original data. + outcome, treatment, unit, time : str + Column names. + optimal_lambda : tuple + Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn). + treated_periods : int + Number of post-treatment periods. + resolved_survey : ResolvedSurveyDesign + Resolved survey design (observation-level). + survey_design : SurveyDesign + Original survey design specification. + + Returns + ------- + Tuple[float, np.ndarray] + (se, bootstrap_estimates). + """ + from diff_diff.bootstrap_utils import generate_rao_wu_weights + from diff_diff.survey import ResolvedSurveyDesign + + lambda_time, lambda_unit, lambda_nn = optimal_lambda + rng = np.random.default_rng(self.seed) + + # Build unit-level resolved survey with cross-classified strata + all_units = sorted(data[unit].unique()) + n_units = len(all_units) + + # Determine treatment status per unit + unit_ever_treated = data.groupby(unit)[treatment].max() + treatment_group = np.array([int(unit_ever_treated[u]) for u in all_units], dtype=np.int64) + + # Extract unit-level survey design fields + first_rows = data.groupby(unit).first().loc[all_units] + + # Weights (unit-level) + if survey_design.weights is not None: + unit_weights = first_rows[survey_design.weights].values.astype(np.float64) + else: + unit_weights = np.ones(n_units, dtype=np.float64) + + # Strata: cross-classify survey strata x treatment group + from diff_diff.linalg import _factorize_cluster_ids + + if survey_design.strata is not None: + survey_strata = first_rows[survey_design.strata].values + cross_labels = np.array([f"{s}_{g}" for s, g in zip(survey_strata, treatment_group)]) + cross_strata = _factorize_cluster_ids(cross_labels) + else: + # No survey strata: use treatment group as strata + cross_strata = treatment_group.copy() + n_strata = len(np.unique(cross_strata)) + + # PSU (unit-level) + psu_arr = None + n_psu = 0 + if survey_design.psu is not None: + psu_raw = first_rows[survey_design.psu].values + if survey_design.nest and survey_design.strata is not None: + combined = np.array([f"{s}_{p}" for s, p in zip(cross_strata, psu_raw)]) + psu_arr = _factorize_cluster_ids(combined) + else: + psu_arr = _factorize_cluster_ids(psu_raw) + n_psu = len(np.unique(psu_arr)) + else: + # Implicit PSU: each unit is its own PSU + psu_arr = np.arange(n_units, dtype=np.int64) + n_psu = n_units + + # FPC (unit-level) + fpc_arr = None + if survey_design.fpc is not None: + fpc_arr = first_rows[survey_design.fpc].values.astype(np.float64) + + unit_resolved = ResolvedSurveyDesign( + weights=unit_weights, + weight_type=resolved_survey.weight_type, + strata=cross_strata, + psu=psu_arr, + fpc=fpc_arr, + n_strata=n_strata, + n_psu=n_psu, + lonely_psu=resolved_survey.lonely_psu, + ) + + # Bootstrap loop with Rao-Wu rescaled weights + all_periods = sorted(data[time].unique()) + n_periods = len(all_periods) + + Y = ( + data.pivot(index=time, columns=unit, values=outcome) + .reindex(index=all_periods, columns=all_units) + .values + ) + D = ( + data.pivot(index=time, columns=unit, values=treatment) + .reindex(index=all_periods, columns=all_units) + .fillna(0) + .astype(int) + .values + ) + + bootstrap_estimates_list: List[float] = [] + + for _ in range(self.n_bootstrap): + try: + # Generate Rao-Wu rescaled weights (unit-level) + boot_weights = generate_rao_wu_weights(unit_resolved, rng) + + # Skip if all control or all treated weights are zero + control_mask_units = treatment_group == 0 + treated_mask_units = treatment_group == 1 + if boot_weights[control_mask_units].sum() == 0: + continue + if boot_weights[treated_mask_units].sum() == 0: + continue + + # Compute global weights and fit model + delta = self._compute_global_weights( + Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods + ) + mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn) + + # Extract weighted ATT using Rao-Wu rescaled weights + att, _, _ = self._extract_posthoc_tau( + Y, D, mu, alpha, beta, L, unit_weights=boot_weights + ) + + if np.isfinite(att): + bootstrap_estimates_list.append(att) + except (ValueError, np.linalg.LinAlgError, KeyError): + continue + + bootstrap_estimates = np.array(bootstrap_estimates_list) + + if len(bootstrap_estimates) < 10: + warnings.warn( + f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", + UserWarning, + ) + if len(bootstrap_estimates) == 0: + return np.nan, np.array([]) + + se = np.std(bootstrap_estimates, ddof=1) + return float(se), bootstrap_estimates + def _fit_global_with_fixed_lambda( self, data: pd.DataFrame, diff --git a/diff_diff/trop_local.py b/diff_diff/trop_local.py index dd1398c..37e35ff 100644 --- a/diff_diff/trop_local.py +++ b/diff_diff/trop_local.py @@ -810,6 +810,7 @@ def _bootstrap_variance( control_unit_idx: Optional[np.ndarray] = None, survey_design=None, unit_weight_arr: Optional[np.ndarray] = None, + resolved_survey=None, ) -> Tuple[float, np.ndarray]: """ Compute bootstrap standard error using unit-level block bootstrap. @@ -819,6 +820,9 @@ def _bootstrap_variance( implementation for 5-15x speedup. Falls back to Python implementation if Rust is unavailable or if matrix parameters are not provided. + When a full survey design (strata/PSU/FPC) is present, uses Rao-Wu + rescaled bootstrap instead, which skips the Rust path. + Parameters ---------- data : pd.DataFrame @@ -844,6 +848,12 @@ def _bootstrap_variance( control_unit_idx : np.ndarray, optional Array of indices for control units (never-treated). Required for Rust backend acceleration. + survey_design : SurveyDesign, optional + Survey design specification. + unit_weight_arr : np.ndarray, optional + Unit-level survey weights. + resolved_survey : ResolvedSurveyDesign, optional + Resolved survey design (observation-level). Returns ------- @@ -861,7 +871,28 @@ def _bootstrap_variance( """ lambda_time, lambda_unit, lambda_nn = optimal_lambda + # Check for full survey design (strata/PSU/FPC present) + _has_full_design = resolved_survey is not None and ( + resolved_survey.strata is not None + or resolved_survey.psu is not None + or resolved_survey.fpc is not None + ) + + # Full survey design: use Python Rao-Wu rescaled bootstrap + if _has_full_design: + return self._bootstrap_rao_wu_local( + data, + outcome, + treatment, + unit, + time, + optimal_lambda, + resolved_survey, + survey_design, + ) + # Try Rust backend for parallel bootstrap (5-15x speedup) + # Only used for pweight-only designs (no strata/PSU/FPC) if ( HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None @@ -968,6 +999,197 @@ def _bootstrap_variance( se = np.std(bootstrap_estimates, ddof=1) return float(se), bootstrap_estimates + def _bootstrap_rao_wu_local( + self, + data: pd.DataFrame, + outcome: str, + treatment: str, + unit: str, + time: str, + optimal_lambda: Tuple[float, float, float], + resolved_survey, + survey_design, + ) -> Tuple[float, np.ndarray]: + """ + Rao-Wu rescaled bootstrap for local method with full survey design. + + Instead of physically resampling units, each iteration generates + rescaled observation weights via Rao-Wu (1988) weight perturbation. + Cross-classifies survey strata with treatment group to preserve + the stratified resampling structure. + + Parameters + ---------- + data : pd.DataFrame + Original data. + outcome, treatment, unit, time : str + Column names. + optimal_lambda : tuple + Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn). + resolved_survey : ResolvedSurveyDesign + Resolved survey design (observation-level). + survey_design : SurveyDesign + Original survey design specification. + + Returns + ------- + Tuple[float, np.ndarray] + (se, bootstrap_estimates). + """ + import warnings + + from diff_diff.bootstrap_utils import generate_rao_wu_weights + from diff_diff.linalg import _factorize_cluster_ids + from diff_diff.survey import ResolvedSurveyDesign + + lambda_time, lambda_unit, lambda_nn = optimal_lambda + rng = np.random.default_rng(self.seed) + + # Build unit-level resolved survey with cross-classified strata + all_units = sorted(data[unit].unique()) + all_periods = sorted(data[time].unique()) + n_units = len(all_units) + n_periods = len(all_periods) + + # Determine treatment status per unit + unit_ever_treated = data.groupby(unit)[treatment].max() + treatment_group = np.array([int(unit_ever_treated[u]) for u in all_units], dtype=np.int64) + + # Extract unit-level survey design fields + first_rows = data.groupby(unit).first().loc[all_units] + + # Weights (unit-level) + if survey_design.weights is not None: + unit_weights = first_rows[survey_design.weights].values.astype(np.float64) + else: + unit_weights = np.ones(n_units, dtype=np.float64) + + # Strata: cross-classify survey strata x treatment group + if survey_design.strata is not None: + survey_strata = first_rows[survey_design.strata].values + cross_labels = np.array([f"{s}_{g}" for s, g in zip(survey_strata, treatment_group)]) + cross_strata = _factorize_cluster_ids(cross_labels) + else: + # No survey strata: use treatment group as strata + cross_strata = treatment_group.copy() + n_strata = len(np.unique(cross_strata)) + + # PSU (unit-level) + psu_arr = None + n_psu = 0 + if survey_design.psu is not None: + psu_raw = first_rows[survey_design.psu].values + if survey_design.nest and survey_design.strata is not None: + combined = np.array([f"{s}_{p}" for s, p in zip(cross_strata, psu_raw)]) + psu_arr = _factorize_cluster_ids(combined) + else: + psu_arr = _factorize_cluster_ids(psu_raw) + n_psu = len(np.unique(psu_arr)) + else: + # Implicit PSU: each unit is its own PSU + psu_arr = np.arange(n_units, dtype=np.int64) + n_psu = n_units + + # FPC (unit-level) + fpc_arr = None + if survey_design.fpc is not None: + fpc_arr = first_rows[survey_design.fpc].values.astype(np.float64) + + unit_resolved = ResolvedSurveyDesign( + weights=unit_weights, + weight_type=resolved_survey.weight_type, + strata=cross_strata, + psu=psu_arr, + fpc=fpc_arr, + n_strata=n_strata, + n_psu=n_psu, + lonely_psu=resolved_survey.lonely_psu, + ) + + # Setup matrices (same as _fit_with_fixed_lambda) + Y = ( + data.pivot(index=time, columns=unit, values=outcome) + .reindex(index=all_periods, columns=all_units) + .values + ) + D = ( + data.pivot(index=time, columns=unit, values=treatment) + .reindex(index=all_periods, columns=all_units) + .fillna(0) + .astype(int) + .values + ) + + control_mask = D == 0 + unit_ever_treated_arr = np.any(D == 1, axis=0) + control_unit_idx = np.where(~unit_ever_treated_arr)[0] + + # Get list of treated observations + treated_observations = [ + (t, i) for t in range(n_periods) for i in range(n_units) if D[t, i] == 1 + ] + + if not treated_observations: + return np.nan, np.array([]) + + # Pre-compute per-observation tau values (fixed across bootstrap) + # The model fit is deterministic; only the ATT aggregation weights vary. + tau_per_obs = [] # (tau_value, unit_idx) pairs + for t, i in treated_observations: + if not np.isfinite(Y[t, i]): + continue + + weight_matrix = self._compute_observation_weights( + Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods + ) + alpha, beta, L = self._estimate_model( + Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods + ) + tau = Y[t, i] - alpha[i] - beta[t] - L[t, i] + tau_per_obs.append((tau, i)) + + if not tau_per_obs: + return np.nan, np.array([]) + + tau_values = np.array([tp[0] for tp in tau_per_obs]) + tau_unit_indices = np.array([tp[1] for tp in tau_per_obs]) + + # Bootstrap loop with Rao-Wu rescaled weights + bootstrap_estimates_list = [] + + for _ in range(self.n_bootstrap): + try: + # Generate Rao-Wu rescaled weights (unit-level) + boot_weights = generate_rao_wu_weights(unit_resolved, rng) + + # Map unit-level weights to per-observation weights + obs_weights = boot_weights[tau_unit_indices] + + # Skip if all weights are zero + if obs_weights.sum() == 0: + continue + + att = float(np.average(tau_values, weights=obs_weights)) + + if np.isfinite(att): + bootstrap_estimates_list.append(att) + except (ValueError, np.linalg.LinAlgError, KeyError): + continue + + bootstrap_estimates = np.array(bootstrap_estimates_list) + + if len(bootstrap_estimates) < 10: + warnings.warn( + f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. " + "Standard errors may be unreliable.", + UserWarning, + ) + if len(bootstrap_estimates) == 0: + return np.nan, np.array([]) + + se = np.std(bootstrap_estimates, ddof=1) + return float(se), bootstrap_estimates + def _fit_with_fixed_lambda( self, data: pd.DataFrame, diff --git a/diff_diff/twfe.py b/diff_diff/twfe.py index 220a267..0b49266 100644 --- a/diff_diff/twfe.py +++ b/diff_diff/twfe.py @@ -186,9 +186,14 @@ def fit( # type: ignore[override] # Inject cluster as effective PSU for survey variance estimation if resolved_survey is not None and survey_cluster_ids is not None: from diff_diff.survey import _inject_cluster_as_psu, compute_survey_metadata + resolved_survey = _inject_cluster_as_psu(resolved_survey, survey_cluster_ids) if resolved_survey.psu is not None and survey_metadata is not None: - raw_w = data[survey_design.weights].values.astype(np.float64) if survey_design.weights else np.ones(len(data), dtype=np.float64) + raw_w = ( + data[survey_design.weights].values.astype(np.float64) + if survey_design.weights + else np.ones(len(data), dtype=np.float64) + ) survey_metadata = compute_survey_metadata(resolved_survey, raw_w) # Pass rank_deficient_action to LinearRegression diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py index 20d33e2..1ae1d40 100644 --- a/diff_diff/two_stage.py +++ b/diff_diff/two_stage.py @@ -256,12 +256,7 @@ def fit( "and PSU (for cluster-robust variance) are supported." ) - # Guard bootstrap + survey - if self.n_bootstrap > 0 and resolved_survey is not None: - raise NotImplementedError( - "Bootstrap inference with survey weights is not yet supported " - "for TwoStageDiD. Use analytical inference (n_bootstrap=0)." - ) + # Bootstrap + survey supported via PSU-level multiplier bootstrap. df[time] = pd.to_numeric(df[time]) df[first_treat] = pd.to_numeric(df[first_treat]) @@ -588,6 +583,7 @@ def fit( original_event_study=event_study_effects, original_group=group_effects, aggregate=aggregate, + resolved_survey=resolved_survey, ) except Exception as e: warnings.warn( diff --git a/diff_diff/two_stage_bootstrap.py b/diff_diff/two_stage_bootstrap.py index 5689962..f4d6981 100644 --- a/diff_diff/two_stage_bootstrap.py +++ b/diff_diff/two_stage_bootstrap.py @@ -19,6 +19,9 @@ from diff_diff.bootstrap_utils import ( generate_bootstrap_weights_batch as _generate_bootstrap_weights_batch, ) +from diff_diff.bootstrap_utils import ( + generate_survey_multiplier_weights_batch as _generate_survey_multiplier_weights_batch, +) from diff_diff.linalg import solve_ols from diff_diff.two_stage_results import TwoStageBootstrapResults @@ -51,9 +54,7 @@ def _build_fe_design( time: str, covariates: Optional[List[str]], omega_0_mask: pd.Series, - ) -> Tuple[ - "sparse.csr_matrix", "sparse.csr_matrix", Dict[Any, int], Dict[Any, int] - ]: ... + ) -> Tuple["sparse.csr_matrix", "sparse.csr_matrix", Dict[Any, int], Dict[Any, int]]: ... @staticmethod def _compute_gmm_scores( @@ -76,6 +77,7 @@ def _compute_cluster_S_scores( X_2: np.ndarray, eps_2: np.ndarray, cluster_ids: np.ndarray, + survey_weights: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Compute per-cluster S_g scores for bootstrap. @@ -121,9 +123,13 @@ def _compute_cluster_S_scores( eps_10[omega_0] = y_vals[omega_0] - fitted_1[omega_0] eps_10[~omega_0] = y_vals[~omega_0] - # gamma_hat - XtX_10 = X_10_sparse.T @ X_10_sparse - Xt1_X2 = X_1_sparse.T @ X_2 + # gamma_hat — with survey weights, both cross-products need W + if survey_weights is not None: + XtX_10 = X_10_sparse.T @ X_10_sparse.multiply(survey_weights[:, None]) + Xt1_X2 = X_1_sparse.T @ (X_2 * survey_weights[:, None]) + else: + XtX_10 = X_10_sparse.T @ X_10_sparse + Xt1_X2 = X_1_sparse.T @ X_2 try: solve_XtX = sparse_factorized(XtX_10.tocsc()) @@ -138,8 +144,12 @@ def _compute_cluster_S_scores( if gamma_hat.ndim == 1: gamma_hat = gamma_hat.reshape(-1, 1) - # Per-cluster aggregation - weighted_X10 = X_10_sparse.multiply(eps_10[:, None]) + # Per-cluster aggregation — survey weights multiply eps_10 before sparse multiply + if survey_weights is not None: + weighted_eps_10 = survey_weights * eps_10 + else: + weighted_eps_10 = eps_10 + weighted_X10 = X_10_sparse.multiply(weighted_eps_10[:, None]) unique_clusters, cluster_indices = np.unique(cluster_ids, return_inverse=True) G = len(unique_clusters) @@ -157,15 +167,23 @@ def _compute_cluster_S_scores( for j_col in range(p): np.add.at(c_by_cluster[:, j_col], cluster_indices, weighted_X10_dense[:, j_col]) - weighted_X2 = X_2 * eps_2[:, None] + if survey_weights is not None: + weighted_eps_2 = survey_weights * eps_2 + else: + weighted_eps_2 = eps_2 + weighted_X2 = X_2 * weighted_eps_2[:, None] s2_by_cluster = np.zeros((G, k)) for j_col in range(k): np.add.at(s2_by_cluster[:, j_col], cluster_indices, weighted_X2[:, j_col]) S = self._compute_gmm_scores(c_by_cluster, gamma_hat, s2_by_cluster) - # Bread - XtX_2 = np.dot(X_2.T, X_2) + # Bread — (X'_2 W X_2)^{-1} with survey weights + with np.errstate(invalid="ignore", over="ignore", divide="ignore"): + if survey_weights is not None: + XtX_2 = X_2.T @ (X_2 * survey_weights[:, None]) + else: + XtX_2 = np.dot(X_2.T, X_2) try: bread = np.linalg.solve(XtX_2, np.eye(k)) except np.linalg.LinAlgError: @@ -195,6 +213,7 @@ def _run_bootstrap( original_event_study: Optional[Dict[int, Dict[str, Any]]], original_group: Optional[Dict[Any, Dict[str, Any]]], aggregate: Optional[str], + resolved_survey: Optional[Any] = None, ) -> Optional[TwoStageBootstrapResults]: """Run multiplier bootstrap on GMM influence function.""" if self.n_bootstrap < 50: @@ -211,6 +230,13 @@ def _run_bootstrap( n = len(df) cluster_ids = df[cluster_var].values + # Extract survey weights for S-score computation and Stage-2 WLS + survey_weights: Optional[np.ndarray] = None + survey_weight_type: str = "pweight" + if resolved_survey is not None: + survey_weights = resolved_survey.weights + survey_weight_type = resolved_survey.weight_type + # Handle NaN y_tilde (from unidentified FEs) — matches _stage2_static logic nan_mask = ~np.isfinite(y_tilde) if nan_mask.any(): @@ -225,7 +251,10 @@ def _run_bootstrap( return None X_2_static = D.reshape(-1, 1) - coef_static = solve_ols(X_2_static, y_tilde, return_vcov=False)[0] + coef_static = solve_ols( + X_2_static, y_tilde, return_vcov=False, + weights=survey_weights, weight_type=survey_weight_type, + )[0] eps_2_static = y_tilde - np.dot(X_2_static, coef_static) S_static, bread_static, unique_clusters = self._compute_cluster_S_scores( @@ -241,13 +270,33 @@ def _run_bootstrap( X_2=X_2_static, eps_2=eps_2_static, cluster_ids=cluster_ids, + survey_weights=survey_weights, ) n_clusters = len(unique_clusters) - all_weights = _generate_bootstrap_weights_batch( - self.n_bootstrap, n_clusters, self.bootstrap_weights, rng + + # Generate bootstrap weights — PSU-level when survey design is present + _use_survey_bootstrap = resolved_survey is not None and ( + resolved_survey.strata is not None + or resolved_survey.psu is not None + or resolved_survey.fpc is not None ) + if _use_survey_bootstrap: + psu_weights, psu_ids = _generate_survey_multiplier_weights_batch( + self.n_bootstrap, resolved_survey, self.bootstrap_weights, rng + ) + # Map unique_clusters (PSU values) to PSU weight columns. + # When survey+PSU is active, cluster_var == "_survey_cluster" so + # unique_clusters are the PSU ids used in S-score aggregation. + psu_id_to_col = {int(p): c for c, p in enumerate(psu_ids)} + cluster_to_psu_col = np.array([psu_id_to_col[int(cl)] for cl in unique_clusters]) + all_weights = psu_weights[:, cluster_to_psu_col] + else: + all_weights = _generate_bootstrap_weights_batch( + self.n_bootstrap, n_clusters, self.bootstrap_weights, rng + ) + # T_b = bread @ (sum_g w_bg * S_g) = bread @ (W @ S)' per boot # IF_b = bread @ S_g for each cluster, then perturb # boot_coef = all_weights @ S_static @ bread_static.T -> (B, k) @@ -318,7 +367,10 @@ def _run_bootstrap( if h_int in horizon_to_col: X_2_es[i, horizon_to_col[h_int]] = 1.0 - coef_es = solve_ols(X_2_es, y_tilde, return_vcov=False)[0] + coef_es = solve_ols( + X_2_es, y_tilde, return_vcov=False, + weights=survey_weights, weight_type=survey_weight_type, + )[0] eps_2_es = y_tilde - np.dot(X_2_es, coef_es) S_es, bread_es, _ = self._compute_cluster_S_scores( @@ -334,6 +386,7 @@ def _run_bootstrap( X_2=X_2_es, eps_2=eps_2_es, cluster_ids=cluster_ids, + survey_weights=survey_weights, ) # boot_coef_es: (B, k_es) @@ -382,7 +435,10 @@ def _run_bootstrap( if g in group_to_col: X_2_grp[i, group_to_col[g]] = 1.0 - coef_grp = solve_ols(X_2_grp, y_tilde, return_vcov=False)[0] + coef_grp = solve_ols( + X_2_grp, y_tilde, return_vcov=False, + weights=survey_weights, weight_type=survey_weight_type, + )[0] eps_2_grp = y_tilde - np.dot(X_2_grp, coef_grp) S_grp, bread_grp, _ = self._compute_cluster_S_scores( @@ -398,6 +454,7 @@ def _run_bootstrap( X_2=X_2_grp, eps_2=eps_2_grp, cluster_ids=cluster_ids, + survey_weights=survey_weights, ) boot_coef_grp = np.dot(np.dot(all_weights, S_grp), bread_grp.T) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 24b17f7..3a8e3aa 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -416,9 +416,9 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: a base period later than `t` (matching R's `did::att_gt()`) - Does not require never-treated units: when all units are eventually treated, not-yet-treated cohorts serve as controls for each other (requires ≥2 cohorts) -- **Note:** CallawaySantAnna survey support: weights-only (strata/PSU/FPC raise NotImplementedError — full design-based SEs via compute_survey_vcov not yet implemented). Regression method supports covariates; IPW/DR support no-covariate only (covariates+IPW/DR raises NotImplementedError — DRDID nuisance IF not yet implemented). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Per-unit survey weights are extracted via `groupby(unit).first()` from the panel-normalized pweight array; on unbalanced panels the pweight normalization (`w * n_obs / sum(w)`) preserves relative unit weights since all IF/WIF formulas use weight ratios (`sw_i / sum(sw)`) where the normalization constant cancels. Scale-invariance tests pass on both balanced and unbalanced panels. Bootstrap + survey deferred. +- **Note:** CallawaySantAnna survey support: weights, strata, PSU, and FPC are all supported. Analytical SEs use influence-function-based variance; when strata/PSU/FPC are present, aggregated SEs use design-based variance via `compute_survey_if_variance()`. Bootstrap SEs use PSU-level multiplier weights. Regression method supports covariates; IPW/DR support no-covariate only (covariates+IPW/DR raises NotImplementedError — DRDID nuisance IF not yet implemented). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Per-unit survey weights are extracted via `groupby(unit).first()` from the panel-normalized pweight array; on unbalanced panels the pweight normalization (`w * n_obs / sum(w)`) preserves relative unit weights since all IF/WIF formulas use weight ratios (`sw_i / sum(sw)`) where the normalization constant cancels. Scale-invariance tests pass on both balanced and unbalanced panels. - **Note (deviation from R):** CallawaySantAnna survey reg+covariates per-cell SE uses a conservative plug-in IF based on WLS residuals. The treated IF is `inf_treated_i = (sw_i/sum(sw_treated)) * (resid_i - ATT)` (normalized by treated weight sum, matching unweighted `(resid-ATT)/n_t`). The control IF is `inf_control_i = -(sw_i/sum(sw_control)) * wls_resid_i` (normalized by control weight sum, matching unweighted `-resid/n_c`). SE is computed as `sqrt(sum(sw_t_norm * (resid_t - ATT)^2) + sum(sw_c_norm * resid_c^2))`, the weighted analogue of the unweighted `sqrt(var_t/n_t + var_c/n_c)`. This omits the semiparametrically efficient nuisance correction from DRDID's `reg_did_panel` — WLS residuals are orthogonal to the weighted design matrix by construction, so the first-order IF term is asymptotically valid but may be conservative. SEs pass weight-scale-invariance tests. The efficient DRDID correction is deferred to future work. -- **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization. Strata/PSU/FPC are rejected at runtime until full design-based SEs via `compute_survey_vcov()` on the combined IF/WIF are implemented. +- **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization. When strata/PSU/FPC are present, aggregated SEs are computed via PSU-level multiplier bootstrap (see bootstrap + survey note above) rather than analytical Taylor-series linearization on the combined IF/WIF. **Reference implementation(s):** - R: `did::att_gt()` (Callaway & Sant'Anna's official package) @@ -497,8 +497,8 @@ See `docs/methodology/continuous-did.md` Section 4 for full details. - [ ] Covariate support (deferred, matching R v0.1.0) - [ ] Discrete treatment saturated regression - [ ] Lowest-dose-as-control (Remark 3.1) -- [x] Survey design support (Phase 3): weighted B-spline OLS, TSL on influence functions; bootstrap+survey deferred -- **Note:** ContinuousDiD bootstrap with survey weights deferred to Phase 5 +- [x] Survey design support (Phase 3): weighted B-spline OLS, TSL on influence functions; bootstrap+survey supported (Phase 6) +- **Note:** ContinuousDiD bootstrap with survey weights supported (Phase 6) via PSU-level multiplier weights --- @@ -675,12 +675,12 @@ where `q_{g,e} = pi_g / sum_{g' in G_{trt,e}} pi_{g'}`. - [x] Each ATT(g,t) can be estimated independently (parallelizable) - [x] Absorbing treatment validation - [x] Overlap diagnostics for propensity score ratios -- [x] Survey design support (Phase 3): survey-weighted means/covariances in Omega*, TSL on EIF scores; bootstrap+survey deferred +- [x] Survey design support (Phase 3): survey-weighted means/covariances in Omega*, TSL on EIF scores; bootstrap+survey supported (Phase 6) - **Note:** Sieve ratio estimation uses polynomial basis functions (total degree up to K) with AIC/BIC model selection. The paper describes sieve estimators generally without specifying a particular basis family; polynomial sieves are a standard choice (Section 4, Eq 4.2). Negative sieve ratio predictions are clipped to a small positive value since the population ratio p_g(X)/p_{g'}(X) is non-negative. - **Note:** Kernel-smoothed conditional covariance Omega*(X) uses Gaussian kernel with Silverman's rule-of-thumb bandwidth by default. The paper specifies kernel smoothing (step 5, Section 4) without mandating a particular kernel or bandwidth selection method. - **Note:** Conditional covariance Omega*(X) scales each term by per-unit sieve-estimated inverse propensities s_hat_{g'}(X) = 1/p_{g'}(X) (algorithm step 4), matching Eq 3.12. The inverse propensity estimation uses the same polynomial sieve convex minimization as the ratio estimator. Estimated s_hat values are clipped to [1, n] with a UserWarning when clipping binds, mirroring the ratio path's overlap diagnostics. - **Note:** Outcome regressions m_hat_{g',t,tpre}(X) use linear OLS working models. The paper's Section 4 describes flexible nonparametric nuisance estimation (sieve regression, kernel smoothing, or ML methods). The DR property ensures consistency if either the OLS outcome model or the sieve propensity ratio is correctly specified, but the linear OLS specification does not generically guarantee attainment of the semiparametric efficiency bound unless the conditional mean is linear in the covariates. -- **Note:** EfficientDiD bootstrap with survey weights deferred to Phase 5 +- **Note:** EfficientDiD bootstrap with survey weights supported (Phase 6) via PSU-level multiplier weights - **Note:** EfficientDiD covariates (DR path) with survey weights deferred — the doubly robust nuisance estimation does not yet thread survey weights through sieve/kernel steps - **Note:** Cluster-robust SEs use the standard Liang-Zeger clustered sandwich estimator applied to EIF values: aggregate EIF within clusters, center, and compute variance with G/(G-1) small-sample correction. Cluster bootstrap generates multiplier weights at the cluster level (all units in a cluster share the same weight). Analytical clustered SEs are the default when `cluster` is set; cluster bootstrap is opt-in via `n_bootstrap > 0`. - **Note:** Hausman pretest operates on the post-treatment event-study vector ES(e) per Theorem A.1. Both PT-All and PT-Post fits are aggregated to ES(e) using cohort-size weights before computing the test statistic H = delta' V^{-1} delta where delta = ES_post - ES_all and V = Cov(ES_post) - Cov(ES_all). Covariance is computed from aggregated ES(e)-level EIF values. The variance-difference matrix V is inverted via Moore-Penrose pseudoinverse to handle finite-sample non-positive-definiteness. Effective rank of V (number of positive eigenvalues) is used as degrees of freedom. @@ -755,7 +755,7 @@ where weights ŵ_{g,e} = n_{g,e} / Σ_g n_{g,e} (sample share of cohort g at eve - [x] R comparison: ATT matches within machine precision (<1e-11) - [x] R comparison: SE matches within 0.3% (well within 1% threshold) - [x] R comparison: Event study effects match perfectly (correlation 1.0) -- [x] Survey design support (Phase 3): weighted within-transform, survey weights in LinearRegression with TSL vcov; bootstrap+survey deferred +- [x] Survey design support (Phase 3): weighted within-transform, survey weights in LinearRegression with TSL vcov; bootstrap+survey supported (Phase 6) via Rao-Wu rescaled bootstrap --- @@ -843,7 +843,7 @@ Y_it = alpha_i + beta_t [+ X'_it * delta] + W'_it * gamma + epsilon_it - **treatment_effects DataFrame weights:** `weight` column uses `1/n_valid` for finite tau_hat and 0 for NaN tau_hat, consistent with the ATT estimand (unweighted), or normalized survey weights `sw_i/sum(sw)` when `survey_design` is active. - **Rank-deficient covariates in variance:** Covariates with NaN coefficients (dropped for rank deficiency in Step 1) are excluded from the variance design matrices `A_0`/`A_1`. Only covariates with finite coefficients participate in the `v_it` projection. - **Sparse variance solver:** `_compute_v_untreated_with_covariates` uses `scipy.sparse.linalg.spsolve` to solve `(A_0'A_0) z = A_1'w` without densifying the normal equations matrix. Falls back to dense `lstsq` if the sparse solver fails. -- **Note:** Survey weights enter ImputationDiD via weighted iterative FE (Step 1), survey-weighted ATT aggregation (Step 3), and survey-weighted conservative variance (Theorem 3). PSU is used as the cluster variable for Theorem 3 variance. Strata enters survey df (n_PSU - n_strata) for t-distribution inference. FPC is not supported (raises NotImplementedError). Strata does NOT enter the variance formula itself (no stratified sandwich) — this is conservative relative to stratified variance. Bootstrap + survey deferred. +- **Note:** Survey weights enter ImputationDiD via weighted iterative FE (Step 1), survey-weighted ATT aggregation (Step 3), and survey-weighted conservative variance (Theorem 3). PSU is used as the cluster variable for Theorem 3 variance. Strata enters survey df (n_PSU - n_strata) for t-distribution inference. FPC is not supported (raises NotImplementedError). Strata does NOT enter the variance formula itself (no stratified sandwich) — this is conservative relative to stratified variance. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights. - **Bootstrap inference:** Uses multiplier bootstrap on the Theorem 3 influence function: `psi_i = sum_t v_it * epsilon_tilde_it`. Cluster-level psi sums are pre-computed for each aggregation target (overall, per-horizon, per-group), then perturbed with multiplier weights (Rademacher by default; configurable via `bootstrap_weights` parameter to use Mammen or Webb weights, matching CallawaySantAnna). This is a library extension (not in the paper) consistent with CallawaySantAnna/SunAbraham bootstrap patterns. - **Auxiliary residuals (Equation 8):** Uses v_it-weighted tau_tilde_g formula: `tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it)` within each partition group. Zero-weight groups (common in event-study SE computation) fall back to unweighted mean. @@ -921,7 +921,7 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus - **No never-treated units (Proposition 5):** When there are no never-treated units and multiple treatment cohorts, horizons h >= h_bar (where h_bar = max(groups) - min(groups)) are unidentified per Proposition 5 of Borusyak et al. (2024). These produce NaN inference with n_obs > 0 (treated observations exist but counterfactual is unidentified) and a warning listing affected horizons. Matches ImputationDiD behavior. Proposition 5 applies to event study horizons only, not cohort aggregation — a cohort whose treated obs all fall at Prop 5 horizons naturally gets n_obs=0 in group effects because all its y_tilde values are NaN. - **Zero-observation horizons after filtering:** When `balance_e` or NaN `y_tilde` filtering results in zero observations for some non-Prop-5 event study horizons, those horizons produce NaN for all inference fields (effect, SE, t-stat, p-value, CI) with n_obs=0. - **Zero-observation cohorts in group effects:** If all treated observations for a cohort have NaN `y_tilde` (excluded from estimation), that cohort's group effect is NaN with n_obs=0. -- **Note:** Survey weights in TwoStageDiD GMM sandwich via weighted cross-products: bread uses (X'_2 W X_2)^{-1}, gamma_hat uses (X'_{10} W X_{10})^{-1}(X'_1 W X_2), per-cluster scores multiply by survey weights. PSU is used as the cluster variable for GMM variance. Strata enters survey df (n_PSU - n_strata) for t-distribution inference. FPC is not supported (raises NotImplementedError). Strata does NOT enter the variance formula itself (no stratified sandwich) — this is conservative. Bootstrap + survey deferred. +- **Note:** Survey weights in TwoStageDiD GMM sandwich via weighted cross-products: bread uses (X'_2 W X_2)^{-1}, gamma_hat uses (X'_{10} W X_{10})^{-1}(X'_1 W X_2), per-cluster scores multiply by survey weights. PSU is used as the cluster variable for GMM variance. Strata enters survey df (n_PSU - n_strata) for t-distribution inference. FPC is not supported (raises NotImplementedError). Strata does NOT enter the variance formula itself (no stratified sandwich) — this is conservative. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights. **Reference implementation(s):** - R: `did2s::did2s()` (Kyle Butts & John Gardner) @@ -1143,7 +1143,7 @@ Convergence criterion: stop when objective decrease < min_decrease² (default mi - **Varying treatment within unit**: Raises `ValueError`. SDID requires block treatment (constant within each unit). Suggests CallawaySantAnna or ImputationDiD for staggered adoption. - **Unbalanced panel**: Raises `ValueError`. SDID requires all units observed in all periods. Suggests `balance_panel()`. - **Poor pre-treatment fit**: Warns (`UserWarning`) when `pre_fit_rmse > std(treated_pre_outcomes, ddof=1)`. Diagnostic only; estimation proceeds. -- **Note:** Survey support: pweight only (strata/PSU/FPC raise NotImplementedError). Both sides weighted per WLS regression interpretation: treated-side means are survey-weighted (Frank-Wolfe target and ATT formula); control-side synthetic weights are composed with survey weights post-optimization (ω_eff = ω * w_co, renormalized). Frank-Wolfe optimization itself is unweighted — survey importance enters after trajectory-matching. Covariate residualization uses WLS with survey weights. Placebo and bootstrap SE preserve survey weights on both sides. +- **Note:** Survey support: weights, strata, PSU, and FPC are all supported. Full-design surveys use Rao-Wu rescaled bootstrap (Phase 6); `variance_method="placebo"` requires weights-only (strata/PSU/FPC require bootstrap). Both sides weighted per WLS regression interpretation: treated-side means are survey-weighted (Frank-Wolfe target and ATT formula); control-side synthetic weights are composed with survey weights post-optimization (ω_eff = ω * w_co, renormalized). Frank-Wolfe optimization itself is unweighted — survey importance enters after trajectory-matching. Covariate residualization uses WLS with survey weights. Placebo and bootstrap SE preserve survey weights on both sides. **Reference implementation(s):** - R: `synthdid::synthdid_estimate()` (Arkhangelsky et al.'s official package) @@ -1397,7 +1397,7 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² - [x] No post_periods parameter (D matrix determines treatment timing) - [x] D matrix semantics documented (absorbing state, not event indicator) - [x] Unbalanced panels supported (missing observations don't trigger false violations) -- **Note:** Survey support: pweight only (strata/PSU/FPC raise NotImplementedError). Survey weights enter ATT aggregation only — population-weighted average of per-observation treatment effects. Model fitting (kernel weights, LOOCV, nuclear norm regularization) stays unchanged. Rust and Python bootstrap paths both support survey-weighted ATT in each iteration. +- **Note:** Survey support: weights, strata, PSU, and FPC are all supported via Rao-Wu rescaled bootstrap with cross-classified pseudo-strata (Phase 6). Rust backend remains pweight-only; full-design surveys fall back to the Python bootstrap path. Survey weights enter ATT aggregation only — population-weighted average of per-observation treatment effects. Model fitting (kernel weights, LOOCV, nuclear norm regularization) stays unchanged. Rust and Python bootstrap paths both support survey-weighted ATT in each iteration. ### TROP Global Estimation Method @@ -1938,9 +1938,51 @@ unequal selection probabilities). Linearization variance estimation, matching the R `survey` package convention that clusters are the primary sampling units. +### Survey-Aware Bootstrap (Phase 6) + +Two strategies for bootstrap variance under complex survey designs: + +**Multiplier Bootstrap at PSU Level** (CallawaySantAnna, ImputationDiD, TwoStageDiD, +ContinuousDiD, EfficientDiD): + +- **Reference**: Standard Taylor linearization bootstrap (Shao 2003, "Impact of the + Bootstrap on Sample Surveys", Statistical Science 18(2)) +- **Formula**: Generate multiplier weights independently within strata at the PSU level. + Scale by `sqrt(1 - f_h)` for FPC. Perturbation: + `ATT_boot[b] = ATT + w_b^T @ psi_psu` where `psi_psu` are PSU-aggregated IF sums. +- **Note:** When no strata/PSU/FPC, degenerates to standard unit-level multiplier bootstrap. + +**Rao-Wu Rescaled Bootstrap** (SunAbraham, SyntheticDiD, TROP): + +- **Reference**: Rao & Wu (1988) "Resampling Inference with Complex Survey Data", + JASA 83(401); Rao, Wu & Yue (1992) "Some Recent Work on Resampling Methods for + Complex Surveys", Survey Methodology 18(2), Section 3. +- **Formula**: Within each stratum *h* with *n_h* PSUs, draw `m_h` PSUs with replacement. + Without FPC: `m_h = n_h - 1`. With FPC: `m_h = max(1, round((1 - f_h) * (n_h - 1)))`. + Rescaled weight: `w*_i = w_i * (n_h / m_h) * r_hi` where `r_hi` = count of PSU *i* drawn. +- **Note:** FPC enters through the resample size `m_h`, not as a post-hoc scaling factor. +- **Deviation from R:** R `survey::as.svrepdesign(type="subbootstrap")` uses the same + formula. Our implementation matches. + +**CallawaySantAnna Design-Based Aggregated SEs**: + +- **Formula**: `V_design = sum_h (1-f_h) * (n_h/(n_h-1)) * sum_j (psi_hj - psi_h_bar)^2` + where `psi_hj = sum_{i in PSU j} psi_i` and `psi_i` is the combined IF (standard + WIF). +- **Note:** Per-(g,t) cell SEs use the simpler IF-based formula `sqrt(sum(psi^2))` which + already incorporates survey weights. Only aggregated SEs (overall, event study, group) + use the full design-based variance. + +**TROP Cross-Classified Strata**: + +- **Note (deviation from R):** When survey strata and treatment groups both exist, TROP + creates pseudo-strata as `(survey_stratum x treatment_group)` for Rao-Wu resampling. + This preserves both survey variance structure and treatment ratio. Survey df computed + from pseudo-strata structure. + --- # Version History +- **v1.2** (2026-03-24): Added Survey-Aware Bootstrap section (Phase 6) - **v1.1** (2026-03-20): Added Survey Data Support section - **v1.0** (2025-01-19): Initial registry with 12 estimators diff --git a/docs/survey-roadmap.md b/docs/survey-roadmap.md index 289120e..ab95fac 100644 --- a/docs/survey-roadmap.md +++ b/docs/survey-roadmap.md @@ -79,10 +79,16 @@ TripleDifference IPW/DR from Phase 3 deferred work. ## Phase 6: Advanced Features -### Bootstrap + Survey Interaction -Unblock bootstrap + survey for all estimators that currently defer it -(ImputationDiD, TwoStageDiD, CallawaySantAnna, SunAbraham, ContinuousDiD, -EfficientDiD). Requires survey-aware resampling schemes. +### Bootstrap + Survey Interaction ✅ (2026-03-24) +Survey-aware bootstrap for all 8 bootstrap-using estimators. Two strategies: +- **Multiplier at PSU level** (CS, ImputationDiD, TwoStageDiD, ContinuousDiD, + EfficientDiD): generate multiplier weights at PSU level within strata, FPC-scaled. +- **Rao-Wu rescaled** (SunAbraham, SyntheticDiD, TROP): draw PSUs within strata, + rescale observation weights per Rao, Wu & Yue (1992). +- **CS analytical expansion**: strata/PSU/FPC now supported for aggregated SEs via + `compute_survey_if_variance()`. +- **TROP**: cross-classified pseudo-strata (survey_stratum × treatment_group). +- **Rust TROP**: pweight-only in Rust; full design falls to Python path. ### Replicate Weight Variance Re-run WLS for each replicate weight column, compute variance from distribution diff --git a/tests/test_survey.py b/tests/test_survey.py index 83b7379..f109306 100644 --- a/tests/test_survey.py +++ b/tests/test_survey.py @@ -1358,9 +1358,7 @@ def test_fweight_se_matches_expanded_oracle(self): se_exp = np.sqrt(np.diag(vcov_exp)) # Fweight path: compressed data with integer weights - coef_fw, _, vcov_fw = solve_ols( - X, y, weights=fw.astype(float), weight_type="fweight" - ) + coef_fw, _, vcov_fw = solve_ols(X, y, weights=fw.astype(float), weight_type="fweight") se_fw = np.sqrt(np.diag(vcov_fw)) np.testing.assert_allclose(coef_fw, coef_exp, atol=1e-10) @@ -1471,9 +1469,7 @@ def test_fweight_survey_oracle(self): freq = np.random.choice([1, 2, 3], n).astype(float) # WLS with fweights via survey - coef_fw, resid_fw, _ = solve_ols( - X_base, y_base, weights=freq, weight_type="fweight" - ) + coef_fw, resid_fw, _ = solve_ols(X_base, y_base, weights=freq, weight_type="fweight") resolved = ResolvedSurveyDesign( weights=freq, weight_type="fweight", @@ -1810,9 +1806,7 @@ def test_linear_regression_auto_derives_weights_from_survey(self): ) # Vcov should match - np.testing.assert_allclose( - model_auto.vcov_, model_explicit.vcov_, rtol=1e-10 - ) + np.testing.assert_allclose(model_auto.vcov_, model_explicit.vcov_, rtol=1e-10) def test_resolve_warns_single_psu_unstratified(self): """P1-B: SurveyDesign.resolve() warns for single PSU unstratified.""" @@ -1898,9 +1892,7 @@ def test_conflicting_weights_warns_and_uses_survey(self): ) # Fit with conflicting explicit weights — should warn - reg_conflict = LinearRegression( - weights=different_weights, survey_design=resolved - ) + reg_conflict = LinearRegression(weights=different_weights, survey_design=resolved) with pytest.warns(UserWarning, match="differ from survey_design"): reg_conflict.fit(X, y) @@ -1908,12 +1900,8 @@ def test_conflicting_weights_warns_and_uses_survey(self): reg_survey = LinearRegression(survey_design=resolved) reg_survey.fit(X, y) - np.testing.assert_allclose( - reg_conflict.coefficients_, reg_survey.coefficients_, atol=1e-14 - ) - np.testing.assert_allclose( - reg_conflict.vcov_, reg_survey.vcov_, atol=1e-14 - ) + np.testing.assert_allclose(reg_conflict.coefficients_, reg_survey.coefficients_, atol=1e-14) + np.testing.assert_allclose(reg_conflict.vcov_, reg_survey.vcov_, atol=1e-14) def test_matching_weights_no_warning(self): """Same array object passed as weights and in survey_design: no warning.""" @@ -1958,14 +1946,16 @@ def _make_cluster_data(seed=700): y = 10.0 + c * 0.3 + np.random.randn() * 0.5 if period == 1 and is_treated: y += 3.0 - rows.append({ - "unit": c * obs_per_cluster + i, - "period": period, - "treated": int(is_treated), - "y": y, - "cluster_id": c, - "w": 1.0 + 0.2 * c, - }) + rows.append( + { + "unit": c * obs_per_cluster + i, + "period": period, + "treated": int(is_treated), + "y": y, + "cluster_id": c, + "w": 1.0 + 0.2 * c, + } + ) return pd.DataFrame(rows) def test_cluster_injected_as_psu_did(self): @@ -1974,13 +1964,19 @@ def test_cluster_injected_as_psu_did(self): # Fit with cluster= and weights-only survey (no PSU) result_inject = DifferenceInDifferences(cluster="cluster_id").fit( - data, "y", "treated", "period", + data, + "y", + "treated", + "period", survey_design=SurveyDesign(weights="w"), ) # Fit with explicit PSU in survey design result_explicit = DifferenceInDifferences(cluster="cluster_id").fit( - data, "y", "treated", "period", + data, + "y", + "treated", + "period", survey_design=SurveyDesign(weights="w", psu="cluster_id"), ) @@ -1993,12 +1989,20 @@ def test_cluster_injected_as_psu_twfe(self): data = self._make_cluster_data() result_inject = TwoWayFixedEffects(cluster="cluster_id").fit( - data, "y", "treated", "period", unit="unit", + data, + "y", + "treated", + "period", + unit="unit", survey_design=SurveyDesign(weights="w"), ) result_explicit = TwoWayFixedEffects(cluster="cluster_id").fit( - data, "y", "treated", "period", unit="unit", + data, + "y", + "treated", + "period", + unit="unit", survey_design=SurveyDesign(weights="w", psu="cluster_id"), ) @@ -2017,12 +2021,18 @@ def test_cluster_injected_as_psu_linear_regression(self): # No PSU in resolved design resolved_no_psu = ResolvedSurveyDesign( - weights=weights, weight_type="pweight", - strata=None, psu=None, fpc=None, - n_strata=0, n_psu=0, lonely_psu="remove", + weights=weights, + weight_type="pweight", + strata=None, + psu=None, + fpc=None, + n_strata=0, + n_psu=0, + lonely_psu="remove", ) reg_inject = LinearRegression( - include_intercept=False, cluster_ids=cluster_ids, + include_intercept=False, + cluster_ids=cluster_ids, survey_design=resolved_no_psu, ) reg_inject.fit(X, y) @@ -2030,12 +2040,18 @@ def test_cluster_injected_as_psu_linear_regression(self): # Explicit PSU codes, uniques = pd.factorize(cluster_ids) resolved_psu = ResolvedSurveyDesign( - weights=weights, weight_type="pweight", - strata=None, psu=codes, fpc=None, - n_strata=0, n_psu=len(uniques), lonely_psu="remove", + weights=weights, + weight_type="pweight", + strata=None, + psu=codes, + fpc=None, + n_strata=0, + n_psu=len(uniques), + lonely_psu="remove", ) reg_explicit = LinearRegression( - include_intercept=False, cluster_ids=cluster_ids, + include_intercept=False, + cluster_ids=cluster_ids, survey_design=resolved_psu, ) reg_explicit.fit(X, y) @@ -2048,9 +2064,14 @@ def test_cluster_injection_no_effect_when_psu_present(self): existing_psu = np.array([0, 0, 1, 1, 2, 2]) resolved = ResolvedSurveyDesign( - weights=np.ones(6), weight_type="pweight", - strata=None, psu=existing_psu, fpc=None, - n_strata=0, n_psu=3, lonely_psu="remove", + weights=np.ones(6), + weight_type="pweight", + strata=None, + psu=existing_psu, + fpc=None, + n_strata=0, + n_psu=3, + lonely_psu="remove", ) result = _inject_cluster_as_psu(resolved, np.array([10, 10, 20, 20, 30, 30])) assert result is resolved # Same object — no replacement @@ -2196,9 +2217,7 @@ def test_fweight_df_rounding(self): # Near-integer weights that would truncate incorrectly w = np.full(n, 2.0 - 1e-14) - reg = LinearRegression( - weights=w, weight_type="fweight", include_intercept=False - ) + reg = LinearRegression(weights=w, weight_type="fweight", include_intercept=False) reg.fit(X, y) # sum(w) ≈ 20 - 1e-13; round → 20, truncate → 19 assert reg.df_ == 20 - reg.n_params_effective_ @@ -2331,9 +2350,7 @@ def test_multiperiod_absorb_matches_explicit_wls_dummies(self): survey_design=sd, ) - np.testing.assert_allclose( - result_absorb.avg_att, result_explicit.avg_att, atol=1e-6 - ) + np.testing.assert_allclose(result_absorb.avg_att, result_explicit.avg_att, atol=1e-6) def test_fractional_fweight_rejected_solve_ols(self): """Fractional fweights raise ValueError via solve_ols.""" @@ -2735,9 +2752,9 @@ def test_multiperiod_nonpositive_df_fallback(self): # (normal distribution fallback, not NaN from t(df=0)) for period, pe in result.period_effects.items(): if np.isfinite(pe.se) and pe.se > 0: - assert np.isfinite(pe.p_value), ( - f"Period {period}: finite SE={pe.se} but p_value={pe.p_value}" - ) + assert np.isfinite( + pe.p_value + ), f"Period {period}: finite SE={pe.se} but p_value={pe.p_value}" class TestRound13Fixes: @@ -2908,9 +2925,7 @@ def test_fpc_with_strata_no_psu_accepted(self): "pop": [10.0, 10.0, 10.0, 20.0, 20.0, 20.0], } ) - sd = SurveyDesign( - weights="w", weight_type="pweight", strata="strat", fpc="pop" - ) + sd = SurveyDesign(weights="w", weight_type="pweight", strata="strat", fpc="pop") # Should not raise at resolve time — FPC >= n_PSU validated at vcov time resolved = sd.resolve(df) assert resolved.fpc is not None @@ -2969,26 +2984,34 @@ def test_weights_only_fpc_reduces_variance(self): # Without FPC resolved_no_fpc = ResolvedSurveyDesign( - weights=weights, weight_type="pweight", - strata=None, psu=None, fpc=None, - n_strata=0, n_psu=0, lonely_psu="remove", + weights=weights, + weight_type="pweight", + strata=None, + psu=None, + fpc=None, + n_strata=0, + n_psu=0, + lonely_psu="remove", ) vcov_no_fpc = compute_survey_vcov(X, residuals, resolved=resolved_no_fpc) # With FPC = 100 (sampling 20 from 100) fpc = np.full(n, 100.0) resolved_fpc = ResolvedSurveyDesign( - weights=weights, weight_type="pweight", - strata=None, psu=None, fpc=fpc, - n_strata=0, n_psu=0, lonely_psu="remove", + weights=weights, + weight_type="pweight", + strata=None, + psu=None, + fpc=fpc, + n_strata=0, + n_psu=0, + lonely_psu="remove", ) vcov_fpc = compute_survey_vcov(X, residuals, resolved=resolved_fpc) # FPC should reduce variance: (1 - 20/100) = 0.8 multiplier assert np.all(np.diag(vcov_fpc) < np.diag(vcov_no_fpc)) - np.testing.assert_allclose( - np.diag(vcov_fpc), np.diag(vcov_no_fpc) * 0.8, rtol=1e-10 - ) + np.testing.assert_allclose(np.diag(vcov_fpc), np.diag(vcov_no_fpc) * 0.8, rtol=1e-10) def test_weights_only_fpc_full_census_zero_vcov(self): """Weights-only FPC == n_obs (full census) produces zero vcov.""" @@ -3000,9 +3023,14 @@ def test_weights_only_fpc_full_census_zero_vcov(self): fpc = np.full(n, float(n)) # Full census: FPC == n_obs resolved = ResolvedSurveyDesign( - weights=weights, weight_type="pweight", - strata=None, psu=None, fpc=fpc, - n_strata=0, n_psu=0, lonely_psu="remove", + weights=weights, + weight_type="pweight", + strata=None, + psu=None, + fpc=fpc, + n_strata=0, + n_psu=0, + lonely_psu="remove", ) vcov = compute_survey_vcov(X, residuals, resolved=resolved) np.testing.assert_array_equal(vcov, np.zeros((2, 2))) @@ -3017,9 +3045,14 @@ def test_weights_only_fpc_lt_nobs_rejected(self): fpc = np.full(n, 10.0) # FPC < n_obs → invalid resolved = ResolvedSurveyDesign( - weights=weights, weight_type="pweight", - strata=None, psu=None, fpc=fpc, - n_strata=0, n_psu=0, lonely_psu="remove", + weights=weights, + weight_type="pweight", + strata=None, + psu=None, + fpc=fpc, + n_strata=0, + n_psu=0, + lonely_psu="remove", ) with pytest.raises(ValueError, match="FPC.*less than.*observations"): compute_survey_vcov(X, residuals, resolved=resolved) @@ -3123,14 +3156,10 @@ def _make_panel(self): { "unit": np.repeat(range(n_units), n_periods), "time": np.tile(range(n_periods), n_units), - "treated": np.repeat( - [1] * (n_units // 2) + [0] * (n_units // 2), n_periods - ), + "treated": np.repeat([1] * (n_units // 2) + [0] * (n_units // 2), n_periods), "post": np.tile([0, 0, 1, 1], n_units), "outcome": np.random.randn(n_units * n_periods), - "region": np.repeat( - ["A", "B"] * (n_units // 2), n_periods - ), + "region": np.repeat(["A", "B"] * (n_units // 2), n_periods), "sw": np.random.uniform(0.5, 2.0, n_units * n_periods), } ) diff --git a/tests/test_survey_phase3.py b/tests/test_survey_phase3.py index 3cddfbc..830561a 100644 --- a/tests/test_survey_phase3.py +++ b/tests/test_survey_phase3.py @@ -200,20 +200,46 @@ def test_se_differs_with_design(self, staggered_survey_data): # SEs should differ due to different variance estimators assert r_w.overall_se != r_full.overall_se - def test_bootstrap_survey_raises(self, staggered_survey_data): - """Bootstrap + survey should raise NotImplementedError.""" + def test_bootstrap_weights_only_uses_pairs(self, staggered_survey_data): + """Bootstrap + weights-only survey uses pairs bootstrap (no Rao-Wu).""" from diff_diff import SunAbraham sd = SurveyDesign(weights="weight") - with pytest.raises(NotImplementedError, match="Bootstrap"): - SunAbraham(n_bootstrap=99).fit( - staggered_survey_data, - "outcome", - "unit", - "time", - "first_treat", - survey_design=sd, - ) + result = SunAbraham(n_bootstrap=99, seed=42).fit( + staggered_survey_data, + "outcome", + "unit", + "time", + "first_treat", + survey_design=sd, + ) + assert result.bootstrap_results is not None + assert result.bootstrap_results.weight_type == "pairs" + assert result.bootstrap_results.n_bootstrap == 99 + assert np.isfinite(result.overall_se) + assert np.isfinite(result.overall_att) + + def test_bootstrap_survey_strata_uses_rao_wu(self, staggered_survey_data): + """Bootstrap + survey with strata/PSU uses Rao-Wu rescaled bootstrap.""" + from diff_diff import SunAbraham + + sd = SurveyDesign(weights="weight", strata="stratum", psu="psu", nest=True) + result = SunAbraham(n_bootstrap=99, seed=42).fit( + staggered_survey_data, + "outcome", + "unit", + "time", + "first_treat", + survey_design=sd, + ) + assert result.bootstrap_results is not None + assert result.bootstrap_results.weight_type == "rao_wu" + assert result.bootstrap_results.n_bootstrap == 99 + assert np.isfinite(result.overall_se) + assert np.isfinite(result.overall_att) + # Event study effects should also have finite bootstrap SEs + for e, eff in result.event_study_effects.items(): + assert np.isfinite(eff["se"]), f"Event study e={e} has non-finite SE" def test_summary_includes_survey(self, staggered_survey_data): """Summary output should include survey design section.""" @@ -644,21 +670,22 @@ def test_smoke_weights_only(self, continuous_survey_data): assert np.isfinite(result.overall_att) assert result.survey_metadata is not None - def test_bootstrap_survey_raises(self, continuous_survey_data): - """Bootstrap + survey should raise NotImplementedError.""" + def test_bootstrap_survey_supported(self, continuous_survey_data): + """Bootstrap + survey now works via PSU-level multiplier bootstrap.""" from diff_diff import ContinuousDiD sd = SurveyDesign(weights="weight") - with pytest.raises(NotImplementedError, match="bootstrap"): - ContinuousDiD(n_bootstrap=99).fit( - continuous_survey_data, - "outcome", - "unit", - "time", - "first_treat", - "dose", - survey_design=sd, - ) + result = ContinuousDiD(n_bootstrap=30, seed=42).fit( + continuous_survey_data, + "outcome", + "unit", + "time", + "first_treat", + "dose", + survey_design=sd, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_att_se) def test_summary_includes_survey(self, continuous_survey_data): """Summary includes survey design section.""" @@ -702,20 +729,21 @@ def test_smoke_weights_only(self, staggered_survey_data): assert np.isfinite(result.overall_se) assert result.survey_metadata is not None - def test_bootstrap_survey_raises(self, staggered_survey_data): - """Bootstrap + survey should raise NotImplementedError.""" + def test_bootstrap_survey_supported(self, staggered_survey_data): + """Bootstrap + survey now works via PSU-level multiplier bootstrap.""" from diff_diff import EfficientDiD sd = SurveyDesign(weights="weight") - with pytest.raises(NotImplementedError, match="bootstrap"): - EfficientDiD(n_bootstrap=99).fit( - staggered_survey_data, - "outcome", - "unit", - "time", - "first_treat", - survey_design=sd, - ) + result = EfficientDiD(n_bootstrap=30, seed=42).fit( + staggered_survey_data, + "outcome", + "unit", + "time", + "first_treat", + survey_design=sd, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) def test_covariates_survey_raises(self, staggered_survey_data): """Covariates + survey should raise NotImplementedError.""" diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index eee72fe..879345c 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -337,17 +337,34 @@ def test_event_study_with_survey(self, staggered_survey_data, survey_design_weig assert np.isfinite(eff["effect"]) assert np.isfinite(eff["se"]) - def test_bootstrap_survey_raises(self, staggered_survey_data, survey_design_weights_only): - """Bootstrap + survey should raise NotImplementedError.""" - with pytest.raises(NotImplementedError, match="[Bb]ootstrap"): - ImputationDiD(n_bootstrap=99).fit( - staggered_survey_data, - "outcome", - "unit", - "period", - "first_treat", - survey_design=survey_design_weights_only, - ) + def test_bootstrap_survey_supported(self, staggered_survey_data, survey_design_weights_only): + """Bootstrap + survey should produce finite SE via PSU-level multiplier bootstrap.""" + result = ImputationDiD(n_bootstrap=99, seed=42).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert result.bootstrap_results is not None + assert result.bootstrap_results.n_bootstrap == 99 + assert np.isfinite(result.overall_se) + assert np.isfinite(result.overall_att) + + def test_bootstrap_survey_full_design(self, staggered_survey_data, survey_design_full): + """Bootstrap + full survey (strata+PSU) uses survey-aware multiplier weights.""" + result = ImputationDiD(n_bootstrap=99, seed=42).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_full, + ) + assert result.bootstrap_results is not None + assert np.isfinite(result.overall_se) + assert np.isfinite(result.overall_att) def test_summary_includes_survey(self, staggered_survey_data, survey_design_weights_only): """Summary output should include survey design section.""" @@ -603,17 +620,20 @@ def test_weighted_gmm_variance(self, staggered_survey_data, survey_design_weight # SE magnitude should differ (not just sign) assert abs(r_unw.overall_se - r_w.overall_se) > 1e-6 - def test_bootstrap_survey_raises(self, staggered_survey_data, survey_design_weights_only): - """Bootstrap + survey should raise NotImplementedError.""" - with pytest.raises(NotImplementedError, match="[Bb]ootstrap"): - TwoStageDiD(n_bootstrap=99).fit( - staggered_survey_data, - "outcome", - "unit", - "period", - "first_treat", - survey_design=survey_design_weights_only, - ) + def test_bootstrap_survey_works(self, staggered_survey_data, survey_design_weights_only): + """Bootstrap + survey should succeed via PSU-level multiplier bootstrap.""" + result = TwoStageDiD(n_bootstrap=99, seed=42).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert result is not None + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 def test_summary_includes_survey(self, staggered_survey_data, survey_design_weights_only): """Summary output should include survey design section.""" @@ -831,18 +851,20 @@ def test_survey_metadata_fields(self, staggered_survey_data, survey_design_weigh assert sm.effective_n > 0 assert sm.design_effect > 0 - def test_strata_psu_fpc_raises(self, staggered_survey_data): - """Strata/PSU/FPC should raise NotImplementedError.""" + def test_strata_psu_fpc_supported(self, staggered_survey_data): + """Strata/PSU/FPC now works with design-based aggregated SEs.""" sd_full = SurveyDesign(weights="weight", strata="stratum", psu="psu") - with pytest.raises(NotImplementedError, match="strata/PSU/FPC"): - CallawaySantAnna(estimation_method="reg").fit( - staggered_survey_data, - "outcome", - "unit", - "period", - "first_treat", - survey_design=sd_full, - ) + result = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_full, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None def test_aggregate_group_with_survey(self, staggered_survey_data, survey_design_weights_only): """aggregate='group' works with weights-only survey design.""" @@ -873,17 +895,18 @@ def test_aggregate_all_with_survey(self, staggered_survey_data, survey_design_we assert result.event_study_effects is not None assert result.group_effects is not None - def test_bootstrap_survey_raises(self, staggered_survey_data, survey_design_weights_only): - """Bootstrap + survey should raise NotImplementedError.""" - with pytest.raises(NotImplementedError, match="[Bb]ootstrap"): - CallawaySantAnna(estimation_method="reg", n_bootstrap=99).fit( - staggered_survey_data, - "outcome", - "unit", - "period", - "first_treat", - survey_design=survey_design_weights_only, - ) + def test_bootstrap_survey_supported(self, staggered_survey_data, survey_design_weights_only): + """Bootstrap + survey now works via PSU-level multiplier bootstrap.""" + result = CallawaySantAnna(estimation_method="reg", n_bootstrap=30, seed=42).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) def test_ipw_covariates_survey_raises(self, staggered_survey_data, survey_design_weights_only): """IPW + covariates + survey should raise NotImplementedError.""" @@ -1367,19 +1390,20 @@ def test_survey_weights_change_per_cell_att(self, staggered_survey_data): effects_no, effects_sv, atol=1e-6 ), f"{method}: survey weights should change per-cell ATT" - def test_strata_psu_fpc_raises_inference(self, staggered_survey_data): - """Strata/PSU/FPC raises NotImplementedError in inference context.""" + def test_strata_psu_fpc_works_inference(self, staggered_survey_data): + """Strata/PSU/FPC now works with design-based aggregated SEs.""" sd_full = SurveyDesign(weights="weight", strata="stratum", psu="psu") - with pytest.raises(NotImplementedError, match="strata/PSU/FPC"): - CallawaySantAnna(estimation_method="reg").fit( - staggered_survey_data, - "outcome", - "unit", - "period", - "first_treat", - aggregate="simple", - survey_design=sd_full, - ) + result = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="simple", + survey_design=sd_full, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) # ============================================================================= diff --git a/tests/test_survey_phase5.py b/tests/test_survey_phase5.py index 79bd00e..fcf7e6e 100644 --- a/tests/test_survey_phase5.py +++ b/tests/test_survey_phase5.py @@ -176,10 +176,29 @@ def test_survey_metadata_fields(self, sdid_survey_data, survey_design_weights): assert sm.effective_n > 0 assert sm.design_effect > 0 - def test_strata_psu_fpc_raises(self, sdid_survey_data, survey_design_full): - """Full design raises NotImplementedError.""" - est = SyntheticDiD(n_bootstrap=50, seed=42) - with pytest.raises(NotImplementedError, match="strata/PSU/FPC"): + def test_full_design_bootstrap_smoke(self, sdid_survey_data, survey_design_full): + """Full survey design (strata/PSU) works with bootstrap variance.""" + est = SyntheticDiD(variance_method="bootstrap", n_bootstrap=50, seed=42) + result = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=survey_design_full, + ) + assert np.isfinite(result.att) + assert np.isfinite(result.se) + assert result.se > 0 + assert result.survey_metadata is not None + assert result.survey_metadata.n_strata is not None + assert result.survey_metadata.n_psu is not None + + def test_full_design_placebo_raises(self, sdid_survey_data, survey_design_full): + """Placebo variance with full design raises NotImplementedError.""" + est = SyntheticDiD(variance_method="placebo", n_bootstrap=50, seed=42) + with pytest.raises(NotImplementedError, match="placebo.*does not support strata/PSU/FPC"): est.fit( sdid_survey_data, outcome="outcome", @@ -190,6 +209,35 @@ def test_strata_psu_fpc_raises(self, sdid_survey_data, survey_design_full): survey_design=survey_design_full, ) + def test_full_design_se_differs_from_weights_only(self, sdid_survey_data): + """Rao-Wu bootstrap SE differs from pweight-only bootstrap SE.""" + sd_w = SurveyDesign(weights="weight") + sd_full = SurveyDesign(weights="weight", strata="stratum", psu="psu") + est = SyntheticDiD(variance_method="bootstrap", n_bootstrap=100, seed=42) + + result_w = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=sd_w, + ) + result_full = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=sd_full, + ) + # ATT point estimates should be the same (same weights) + assert result_full.att == pytest.approx(result_w.att, abs=1e-10) + # SEs should differ (different bootstrap scheme) + assert result_full.se != pytest.approx(result_w.se, abs=1e-6) + def test_fweight_aweight_raises(self, sdid_survey_data): """Non-pweight raises ValueError.""" est = SyntheticDiD(n_bootstrap=50, seed=42) @@ -481,18 +529,35 @@ def test_survey_metadata_fields(self, trop_survey_data, survey_design_weights): assert sm.weight_type == "pweight" assert sm.effective_n > 0 - def test_strata_psu_fpc_raises(self, trop_survey_data, survey_design_full): - """Full design raises NotImplementedError.""" - est = TROP(method="local", n_bootstrap=10, seed=42) - with pytest.raises(NotImplementedError, match="strata/PSU/FPC"): - est.fit( - trop_survey_data, - outcome="outcome", - treatment="D", - unit="unit", - time="time", - survey_design=survey_design_full, - ) + def test_full_design_local_rao_wu(self, trop_survey_data, survey_design_full): + """Full design (strata/PSU/FPC) uses Rao-Wu bootstrap and succeeds.""" + est = TROP(method="local", n_bootstrap=20, seed=42, max_iter=5) + result = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + survey_design=survey_design_full, + ) + assert np.isfinite(result.att) + assert np.isfinite(result.se) + assert result.survey_metadata is not None + + def test_full_design_global_rao_wu(self, trop_survey_data, survey_design_full): + """Full design (strata/PSU/FPC) with global method uses Rao-Wu bootstrap.""" + est = TROP(method="global", n_bootstrap=20, seed=42, max_iter=5) + result = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + survey_design=survey_design_full, + ) + assert np.isfinite(result.att) + assert np.isfinite(result.se) + assert result.survey_metadata is not None def test_fweight_aweight_raises(self, trop_survey_data): """Non-pweight raises ValueError."""