diff --git a/TODO.md b/TODO.md index bf0b49c2..f6231b16 100644 --- a/TODO.md +++ b/TODO.md @@ -54,6 +54,7 @@ Deferred items from PR reviews that were not addressed before merge. | Multi-absorb weighted demeaning needs iterative alternating projections for N > 1 absorbed FE with survey weights; unweighted multi-absorb also uses single-pass (pre-existing, exact only for balanced panels) | `estimators.py` | #218 | Medium | | CallawaySantAnna survey: strata/PSU/FPC rejected at runtime. Full design-based SEs require routing the combined IF/WIF through `compute_survey_vcov()`. Currently weights-only. | `staggered.py` | #233 | Medium | | CallawaySantAnna survey + covariates + IPW/DR: DRDID panel nuisance-estimation IF corrections not implemented. Currently gated with NotImplementedError. Regression method with covariates works (has WLS nuisance IF correction). | `staggered.py` | #233 | Medium | +| SyntheticDiD/TROP survey: strata/PSU/FPC deferred. Full design-based bootstrap (Rao-Wu rescaled weights) needed for survey-aware resampling. Currently pweight-only. | `synthetic_did.py`, `trop.py` | — | Medium | | EfficientDiD hausman_pretest() clustered covariance uses stale `n_cl` after filtering non-finite EIF rows — should recompute effective cluster count and remap indices after `row_finite` filtering | `efficient_did.py` | #230 | Medium | | EfficientDiD `control_group="last_cohort"` trims at `last_g - anticipation` but REGISTRY says `t >= last_g`. With `anticipation=0` (default) these are identical. With `anticipation>0`, code is arguably more conservative (excludes anticipation-contaminated periods). Either align REGISTRY with code or change code to `t < last_g` — needs design decision. | `efficient_did.py` | #230 | Low | | TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. 
| `prep_dgp.py`, `power.py` | #208 | Low | diff --git a/diff_diff/results.py b/diff_diff/results.py index 444c3e09..1c62b8b7 100644 --- a/diff_diff/results.py +++ b/diff_diff/results.py @@ -680,6 +680,8 @@ class SyntheticDiDResults: pre_treatment_fit: Optional[float] = field(default=None) placebo_effects: Optional[np.ndarray] = field(default=None) n_bootstrap: Optional[int] = field(default=None) + # Survey design metadata (SurveyMetadata instance from diff_diff.survey) + survey_metadata: Optional[Any] = field(default=None) def __repr__(self) -> str: """Concise string representation.""" @@ -735,6 +737,28 @@ def summary(self, alpha: Optional[float] = None) -> str: if self.variance_method == "bootstrap" and self.n_bootstrap is not None: lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}") + # Add survey design info + if self.survey_metadata is not None: + sm = self.survey_metadata + lines.extend( + [ + "", + "-" * 75, + "Survey Design".center(75), + "-" * 75, + f"{'Weight type:':<25} {sm.weight_type:>10}", + ] + ) + if sm.n_strata is not None: + lines.append(f"{'Strata:':<25} {sm.n_strata:>10}") + if sm.n_psu is not None: + lines.append(f"{'PSU/Cluster:':<25} {sm.n_psu:>10}") + lines.append(f"{'Effective sample size:':<25} {sm.effective_n:>10.1f}") + lines.append(f"{'Design effect (DEFF):':<25} {sm.design_effect:>10.2f}") + if sm.df_survey is not None: + lines.append(f"{'Survey d.f.:':<25} {sm.df_survey:>10}") + lines.append("-" * 75) + lines.extend( [ "", @@ -812,6 +836,15 @@ def to_dict(self) -> Dict[str, Any]: } if self.n_bootstrap is not None: result["n_bootstrap"] = self.n_bootstrap + if self.survey_metadata is not None: + sm = self.survey_metadata + result["weight_type"] = sm.weight_type + result["effective_n"] = sm.effective_n + result["design_effect"] = sm.design_effect + result["sum_weights"] = sm.sum_weights + result["n_strata"] = sm.n_strata + result["n_psu"] = sm.n_psu + result["df_survey"] = sm.df_survey return result def 
to_dataframe(self) -> pd.DataFrame: diff --git a/diff_diff/survey.py b/diff_diff/survey.py index d47d86ee..efa16262 100644 --- a/diff_diff/survey.py +++ b/diff_diff/survey.py @@ -430,6 +430,66 @@ def _validate_unit_constant_survey(data, unit_col, survey_design): ) +def _resolve_pweight_only(resolved_survey, estimator_name): + """Guard: reject non-pweight and strata/PSU/FPC for pweight-only estimators. + + Parameters + ---------- + resolved_survey : ResolvedSurveyDesign or None + Resolved survey design. If None, returns immediately. + estimator_name : str + Estimator name for error messages. + + Raises + ------ + ValueError + If weight_type is not 'pweight'. + NotImplementedError + If strata, PSU, or FPC are present. + """ + if resolved_survey is None: + return + if resolved_survey.weight_type != "pweight": + raise ValueError( + f"{estimator_name} survey support requires weight_type='pweight'. " + f"Got '{resolved_survey.weight_type}'." + ) + if ( + resolved_survey.strata is not None + or resolved_survey.psu is not None + or resolved_survey.fpc is not None + ): + raise NotImplementedError( + f"{estimator_name} does not yet support strata/PSU/FPC in " + "SurveyDesign. Use SurveyDesign(weights=...) only. Full " + "design-based bootstrap is planned for the Bootstrap + " + "Survey Interaction phase." + ) + + +def _extract_unit_survey_weights(data, unit_col, survey_design, unit_order): + """Extract unit-level survey weights aligned to a given unit ordering. + + Parameters + ---------- + data : pd.DataFrame + Panel data with survey weight column. + unit_col : str + Unit identifier column name. + survey_design : SurveyDesign + Survey design (uses ``weights`` column name). + unit_order : array-like + Ordered sequence of unit identifiers to align weights to. + + Returns + ------- + np.ndarray + Float64 array of unit-level weights, one per unit in ``unit_order``. 
+ """ + unit_w = data.groupby(unit_col)[survey_design.weights].first() + return np.array([unit_w[u] for u in unit_order], dtype=np.float64) + + def _resolve_survey_for_fit(survey_design, data, inference_mode="analytical"): """ Shared helper: validate and resolve a SurveyDesign for an estimator fit() call. diff --git a/diff_diff/synthetic_did.py b/diff_diff/synthetic_did.py index 59258452..2ff982ca 100644 --- a/diff_diff/synthetic_did.py +++ b/diff_diff/synthetic_did.py @@ -174,8 +174,7 @@ def __init__( valid_methods = ("bootstrap", "placebo") if variance_method not in valid_methods: raise ValueError( - f"variance_method must be one of {valid_methods}, " - f"got '{variance_method}'" + f"variance_method must be one of {valid_methods}, " f"got '{variance_method}'" ) self._unit_weights = None @@ -189,7 +188,8 @@ def fit( # type: ignore[override] unit: str, time: str, post_periods: Optional[List[Any]] = None, - covariates: Optional[List[str]] = None + covariates: Optional[List[str]] = None, + survey_design=None, ) -> SyntheticDiDResults: """ Fit the Synthetic Difference-in-Differences model. @@ -215,6 +215,9 @@ def fit( # type: ignore[override] covariates : list, optional List of covariate column names. Covariates are residualized out before computing the SDID estimator. + survey_design : SurveyDesign, optional + Survey design specification. Only pweight designs are supported + (strata/PSU/FPC raise NotImplementedError). Returns ------- @@ -225,13 +228,14 @@ def fit( # type: ignore[override] Raises ------ ValueError - If required parameters are missing or data validation fails. + If required parameters are missing, data validation fails, + or a non-pweight survey design is provided. + NotImplementedError + If survey_design includes strata, PSU, or FPC. 
""" # Validate inputs if outcome is None or treatment is None or unit is None or time is None: - raise ValueError( - "Must provide 'outcome', 'treatment', 'unit', and 'time'" - ) + raise ValueError("Must provide 'outcome', 'treatment', 'unit', and 'time'") # Check columns exist required_cols = [outcome, treatment, unit, time] @@ -242,6 +246,19 @@ def fit( # type: ignore[override] if missing: raise ValueError(f"Missing columns: {missing}") + # Resolve survey design + from diff_diff.survey import ( + _extract_unit_survey_weights, + _resolve_pweight_only, + _resolve_survey_for_fit, + _validate_unit_constant_survey, + ) + + resolved_survey, survey_weights, survey_weight_type, survey_metadata = ( + _resolve_survey_for_fit(survey_design, data, "analytical") + ) + _resolve_pweight_only(resolved_survey, "SyntheticDiD") + # Validate treatment is binary validate_binary(data[treatment].values, "treatment") @@ -279,9 +296,7 @@ def fit( # type: ignore[override] varying_units = treatment_nunique[treatment_nunique > 1] if len(varying_units) > 0: example_unit = varying_units.index[0] - example_vals = sorted( - data.loc[data[unit] == example_unit, treatment].unique() - ) + example_vals = sorted(data.loc[data[unit] == example_unit, treatment].unique()) raise ValueError( f"Treatment indicator varies within {len(varying_units)} unit(s) " f"(e.g., unit '{example_unit}' has values {example_vals}). " @@ -314,20 +329,42 @@ def fit( # type: ignore[override] f"diff_diff.prep.balance_panel() to balance the panel first." 
) + # Validate and extract survey weights + if resolved_survey is not None: + _validate_unit_constant_survey(data, unit, survey_design) + w_treated = _extract_unit_survey_weights(data, unit, survey_design, treated_units) + w_control = _extract_unit_survey_weights(data, unit, survey_design, control_units) + else: + w_treated = None + w_control = None + # Residualize covariates if provided working_data = data.copy() if covariates: working_data = self._residualize_covariates( - working_data, outcome, covariates, unit, time + working_data, + outcome, + covariates, + unit, + time, + survey_weights=survey_weights, + survey_weight_type=survey_weight_type, ) # Create outcome matrices # Shape: (n_periods, n_units) - Y_pre_control, Y_post_control, Y_pre_treated, Y_post_treated = \ + Y_pre_control, Y_post_control, Y_pre_treated, Y_post_treated = ( self._create_outcome_matrices( - working_data, outcome, unit, time, - pre_periods, post_periods, treated_units, control_units + working_data, + outcome, + unit, + time, + pre_periods, + post_periods, + treated_units, + control_units, ) + ) # Compute auto-regularization (or use user overrides) auto_zeta_omega, auto_zeta_lambda = _compute_regularization( @@ -338,6 +375,7 @@ def fit( # type: ignore[override] # Store noise level for diagnostics from diff_diff.utils import _compute_noise_level + noise_level = _compute_noise_level(Y_pre_control) # Data-dependent convergence threshold (matches R's 1e-5 * noise.level). 
@@ -347,7 +385,11 @@ def fit( # type: ignore[override] min_decrease = 1e-5 * noise_level if noise_level > 0 else 1e-5 # Compute unit weights (Frank-Wolfe with sparsification) - Y_pre_treated_mean = np.mean(Y_pre_treated, axis=1) + # Survey weights enter via the treated mean target + if w_treated is not None: + Y_pre_treated_mean = np.average(Y_pre_treated, axis=1, weights=w_treated) + else: + Y_pre_treated_mean = np.mean(Y_pre_treated, axis=1) unit_weights = compute_sdid_unit_weights( Y_pre_control, @@ -364,27 +406,41 @@ def fit( # type: ignore[override] min_decrease=min_decrease, ) + # Compose ω with control survey weights (WLS regression interpretation). + # Frank-Wolfe finds best trajectory match; survey weights reweight by + # population importance post-optimization. + if w_control is not None: + omega_eff = unit_weights * w_control + omega_eff = omega_eff / omega_eff.sum() + else: + omega_eff = unit_weights + # Compute SDID estimate - Y_post_treated_mean = np.mean(Y_post_treated, axis=1) + if w_treated is not None: + Y_post_treated_mean = np.average(Y_post_treated, axis=1, weights=w_treated) + else: + Y_post_treated_mean = np.mean(Y_post_treated, axis=1) att = compute_sdid_estimator( Y_pre_control, Y_post_control, Y_pre_treated_mean, Y_post_treated_mean, - unit_weights, - time_weights + omega_eff, + time_weights, ) - # Compute pre-treatment fit (RMSE) - synthetic_pre = Y_pre_control @ unit_weights + # Compute pre-treatment fit (RMSE) using composed weights + synthetic_pre = Y_pre_control @ omega_eff pre_fit_rmse = np.sqrt(np.mean((Y_pre_treated_mean - synthetic_pre) ** 2)) # Warn if pre-treatment fit is poor (Registry requirement). # Threshold: 1× SD of treated pre-treatment outcomes — a natural baseline # since RMSE exceeding natural variation indicates the synthetic control # fails to reproduce the treated series' level or trend. 
- pre_treatment_sd = np.std(Y_pre_treated_mean, ddof=1) if len(Y_pre_treated_mean) > 1 else 0.0 + pre_treatment_sd = ( + np.std(Y_pre_treated_mean, ddof=1) if len(Y_pre_treated_mean) > 1 else 0.0 + ) if pre_treatment_sd > 0 and pre_fit_rmse > pre_treatment_sd: warnings.warn( f"Pre-treatment fit is poor: RMSE ({pre_fit_rmse:.4f}) exceeds " @@ -399,9 +455,14 @@ def fit( # type: ignore[override] # Compute standard errors based on variance_method if self.variance_method == "bootstrap": se, bootstrap_estimates = self._bootstrap_se( - Y_pre_control, Y_post_control, - Y_pre_treated, Y_post_treated, - unit_weights, time_weights, + Y_pre_control, + Y_post_control, + Y_pre_treated, + Y_post_treated, + unit_weights, + time_weights, + w_treated=w_treated, + w_control=w_control, ) placebo_effects = bootstrap_estimates inference_method = "bootstrap" @@ -416,7 +477,8 @@ def fit( # type: ignore[override] zeta_omega=zeta_omega, zeta_lambda=zeta_lambda, min_decrease=min_decrease, - replications=self.n_bootstrap # Reuse n_bootstrap for replications + replications=self.n_bootstrap, + w_control=w_control, ) inference_method = "placebo" @@ -430,13 +492,11 @@ def fit( # type: ignore[override] else: p_value = p_value_analytical - # Create weight dictionaries - unit_weights_dict = { - unit_id: w for unit_id, w in zip(control_units, unit_weights) - } - time_weights_dict = { - period: w for period, w in zip(pre_periods, time_weights) - } + # Create weight dictionaries. When survey weights are active, store + # the effective (composed) weights that were actually used for the ATT + # so that results.unit_weights matches the estimator. 
+ unit_weights_dict = {unit_id: w for unit_id, w in zip(control_units, omega_eff)} + time_weights_dict = {period: w for period, w in zip(pre_periods, time_weights)} # Store results self.results_ = SyntheticDiDResults( @@ -459,7 +519,8 @@ def fit( # type: ignore[override] zeta_lambda=zeta_lambda, pre_treatment_fit=pre_fit_rmse, placebo_effects=placebo_effects if len(placebo_effects) > 0 else None, - n_bootstrap=self.n_bootstrap if inference_method == "bootstrap" else None + n_bootstrap=self.n_bootstrap if inference_method == "bootstrap" else None, + survey_metadata=survey_metadata, ) self._unit_weights = unit_weights @@ -477,7 +538,7 @@ def _create_outcome_matrices( pre_periods: List[Any], post_periods: List[Any], treated_units: List[Any], - control_units: List[Any] + control_units: List[Any], ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ Create outcome matrices for SDID estimation. @@ -501,7 +562,7 @@ def _create_outcome_matrices( Y_pre_control.astype(float), Y_post_control.astype(float), Y_pre_treated.astype(float), - Y_post_treated.astype(float) + Y_post_treated.astype(float), ) def _residualize_covariates( @@ -510,12 +571,16 @@ def _residualize_covariates( outcome: str, covariates: List[str], unit: str, - time: str + time: str, + survey_weights=None, + survey_weight_type=None, ) -> pd.DataFrame: """ Residualize outcome by regressing out covariates. - Uses two-way fixed effects to partial out covariates. + Uses two-way fixed effects to partial out covariates. When survey + weights are provided, uses WLS for population-representative + covariate removal. 
""" data = data.copy() @@ -523,23 +588,28 @@ def _residualize_covariates( X = data[covariates].values.astype(float) # Add unit and time dummies - unit_dummies = pd.get_dummies(data[unit], prefix='u', drop_first=True) - time_dummies = pd.get_dummies(data[time], prefix='t', drop_first=True) + unit_dummies = pd.get_dummies(data[unit], prefix="u", drop_first=True) + time_dummies = pd.get_dummies(data[time], prefix="t", drop_first=True) - X_full = np.column_stack([ - np.ones(len(data)), - X, - unit_dummies.values, - time_dummies.values - ]) + X_full = np.column_stack([np.ones(len(data)), X, unit_dummies.values, time_dummies.values]) y = data[outcome].values.astype(float) # Fit and get residuals using unified backend - coeffs, residuals, _ = solve_ols(X_full, y, return_vcov=False) + coeffs, residuals, _ = solve_ols( + X_full, + y, + return_vcov=False, + weights=survey_weights, + weight_type=survey_weight_type, + ) # Add back the mean for interpretability - data[outcome] = residuals + np.mean(y) + if survey_weights is not None: + y_center = np.average(y, weights=survey_weights) + else: + y_center = np.mean(y) + data[outcome] = residuals + y_center return data @@ -551,6 +621,8 @@ def _bootstrap_se( Y_post_treated: np.ndarray, unit_weights: np.ndarray, time_weights: np.ndarray, + w_treated=None, + w_control=None, ) -> Tuple[float, np.ndarray]: """Compute bootstrap standard error matching R's synthdid bootstrap_sample. 
@@ -566,10 +638,7 @@ def _bootstrap_se( n_total = n_control + n_treated # Build full panel matrix: (n_pre+n_post, n_control+n_treated) - Y_full = np.block([ - [Y_pre_control, Y_pre_treated], - [Y_post_control, Y_post_treated] - ]) + Y_full = np.block([[Y_pre_control, Y_pre_treated], [Y_post_control, Y_post_treated]]) n_pre = Y_pre_control.shape[0] bootstrap_estimates = [] @@ -591,6 +660,14 @@ def _bootstrap_se( # Renormalize original unit weights for the resampled controls boot_omega = _sum_normalize(unit_weights[boot_control_idx]) + # Compose with control survey weights if present + if w_control is not None: + boot_w_c = w_control[boot_idx[boot_is_control]] + boot_omega_eff = boot_omega * boot_w_c + boot_omega_eff = boot_omega_eff / boot_omega_eff.sum() + else: + boot_omega_eff = boot_omega + # Extract resampled outcome matrices Y_boot = Y_full[:, boot_idx] Y_boot_pre_c = Y_boot[:n_pre, boot_is_control] @@ -598,14 +675,25 @@ def _bootstrap_se( Y_boot_pre_t = Y_boot[:n_pre, ~boot_is_control] Y_boot_post_t = Y_boot[n_pre:, ~boot_is_control] - # Compute ATT with FIXED weights (do NOT re-estimate) - Y_boot_pre_t_mean = np.mean(Y_boot_pre_t, axis=1) - Y_boot_post_t_mean = np.mean(Y_boot_post_t, axis=1) + # Compute ATT with FIXED weights (do NOT re-estimate). + # boot_idx[~boot_is_control] maps to original index space; + # subtract n_control to index into w_treated. Duplicate draws + # carry identical weights → alignment is safe. 
+ if w_treated is not None: + boot_w_t = w_treated[boot_idx[~boot_is_control] - n_control] + Y_boot_pre_t_mean = np.average(Y_boot_pre_t, axis=1, weights=boot_w_t) + Y_boot_post_t_mean = np.average(Y_boot_post_t, axis=1, weights=boot_w_t) + else: + Y_boot_pre_t_mean = np.mean(Y_boot_pre_t, axis=1) + Y_boot_post_t_mean = np.mean(Y_boot_post_t, axis=1) tau = compute_sdid_estimator( - Y_boot_pre_c, Y_boot_post_c, - Y_boot_pre_t_mean, Y_boot_post_t_mean, - boot_omega, time_weights # time_weights = original lambda + Y_boot_pre_c, + Y_boot_post_c, + Y_boot_pre_t_mean, + Y_boot_post_t_mean, + boot_omega_eff, + time_weights, ) if np.isfinite(tau): bootstrap_estimates.append(tau) @@ -664,7 +752,8 @@ def _placebo_variance_se( zeta_omega: float = 0.0, zeta_lambda: float = 0.0, min_decrease: float = 1e-5, - replications: int = 200 + replications: int = 200, + w_control=None, ) -> Tuple[float, np.ndarray]: """ Compute placebo-based variance matching R's synthdid methodology. @@ -743,12 +832,27 @@ def _placebo_variance_se( # Get pseudo-control and pseudo-treated outcomes Y_pre_pseudo_control = Y_pre_control[:, pseudo_control_idx] Y_post_pseudo_control = Y_post_control[:, pseudo_control_idx] - Y_pre_pseudo_treated_mean = np.mean( - Y_pre_control[:, pseudo_treated_idx], axis=1 - ) - Y_post_pseudo_treated_mean = np.mean( - Y_post_control[:, pseudo_treated_idx], axis=1 - ) + + # Pseudo-treated means: survey-weighted when available + if w_control is not None: + pseudo_w_tr = w_control[pseudo_treated_idx] + Y_pre_pseudo_treated_mean = np.average( + Y_pre_control[:, pseudo_treated_idx], + axis=1, + weights=pseudo_w_tr, + ) + Y_post_pseudo_treated_mean = np.average( + Y_post_control[:, pseudo_treated_idx], + axis=1, + weights=pseudo_w_tr, + ) + else: + Y_pre_pseudo_treated_mean = np.mean( + Y_pre_control[:, pseudo_treated_idx], axis=1 + ) + Y_post_pseudo_treated_mean = np.mean( + Y_post_control[:, pseudo_treated_idx], axis=1 + ) # Re-estimate weights on permuted data (matching R's 
behavior) # R passes update.omega=TRUE, update.lambda=TRUE via opts, @@ -761,6 +865,14 @@ def _placebo_variance_se( min_decrease=min_decrease, ) + # Compose pseudo_omega with control survey weights + if w_control is not None: + pseudo_w_co = w_control[pseudo_control_idx] + pseudo_omega_eff = pseudo_omega * pseudo_w_co + pseudo_omega_eff = pseudo_omega_eff / pseudo_omega_eff.sum() + else: + pseudo_omega_eff = pseudo_omega + # Time weights: re-estimate on pseudo-control data pseudo_lambda = compute_time_weights( Y_pre_pseudo_control, @@ -775,8 +887,8 @@ def _placebo_variance_se( Y_post_pseudo_control, Y_pre_pseudo_treated_mean, Y_post_pseudo_treated_mean, - pseudo_omega, - pseudo_lambda + pseudo_omega_eff, + pseudo_lambda, ) if np.isfinite(tau): placebo_estimates.append(tau) @@ -811,9 +923,7 @@ def _placebo_variance_se( # Compute SE using R's formula: sqrt((r-1)/r) * sd(estimates) # This matches synthdid::vcov.R exactly - se = np.sqrt((n_successful - 1) / n_successful) * np.std( - placebo_estimates, ddof=1 - ) + se = np.sqrt((n_successful - 1) / n_successful) * np.std(placebo_estimates, ddof=1) return se, placebo_estimates @@ -835,8 +945,7 @@ def set_params(self, **params) -> "SyntheticDiD": for key, value in params.items(): if key in _deprecated: warnings.warn( - f"{key} is deprecated and ignored. Use zeta_omega/zeta_lambda " - f"instead.", + f"{key} is deprecated and ignored. 
Use zeta_omega/zeta_lambda " f"instead.", DeprecationWarning, stacklevel=2, ) diff --git a/diff_diff/trop.py b/diff_diff/trop.py index 6949a155..b12b1212 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -143,9 +143,7 @@ def __init__( # 'local'/'global' are preferred; 'twostep'/'joint' are deprecated aliases valid_methods = ("local", "twostep", "joint", "global") if method not in valid_methods: - raise ValueError( - f"method must be one of {valid_methods}, got '{method}'" - ) + raise ValueError(f"method must be one of {valid_methods}, got '{method}'") if method == "twostep": warnings.warn( "method='twostep' is deprecated and will be removed in v3.0. " @@ -263,9 +261,9 @@ def _univariate_loocv_search( for value in grid: params = {**fixed_params, param_name: value} - lambda_time = params.get('lambda_time', 0.0) - lambda_unit = params.get('lambda_unit', 0.0) - lambda_nn = params.get('lambda_nn', 0.0) + lambda_time = params.get("lambda_time", 0.0) + lambda_unit = params.get("lambda_unit", 0.0) + lambda_nn = params.get("lambda_nn", 0.0) # Convert λ_nn=∞ → large finite value (factor model disabled, L≈0) # λ_time and λ_unit use 0.0 for uniform weights per Eq. 
3 (no inf conversion needed) @@ -274,9 +272,15 @@ def _univariate_loocv_search( try: score = self._loocv_score_obs_specific( - Y, D, control_mask, control_unit_idx, - lambda_time, lambda_unit, lambda_nn, - n_units, n_periods + Y, + D, + control_mask, + control_unit_idx, + lambda_time, + lambda_unit, + lambda_nn, + n_units, + n_periods, ) if score < best_score: best_score = score @@ -333,30 +337,47 @@ def _cycling_parameter_search( for cycle in range(max_cycles): # Optimize λ_unit (fix λ_time, λ_nn) lambda_unit, _ = self._univariate_loocv_search( - Y, D, control_mask, control_unit_idx, n_units, n_periods, - 'lambda_unit', self.lambda_unit_grid, - {'lambda_time': lambda_time, 'lambda_nn': lambda_nn} + Y, + D, + control_mask, + control_unit_idx, + n_units, + n_periods, + "lambda_unit", + self.lambda_unit_grid, + {"lambda_time": lambda_time, "lambda_nn": lambda_nn}, ) # Optimize λ_time (fix λ_unit, λ_nn) lambda_time, _ = self._univariate_loocv_search( - Y, D, control_mask, control_unit_idx, n_units, n_periods, - 'lambda_time', self.lambda_time_grid, - {'lambda_unit': lambda_unit, 'lambda_nn': lambda_nn} + Y, + D, + control_mask, + control_unit_idx, + n_units, + n_periods, + "lambda_time", + self.lambda_time_grid, + {"lambda_unit": lambda_unit, "lambda_nn": lambda_nn}, ) # Optimize λ_nn (fix λ_unit, λ_time) lambda_nn, score = self._univariate_loocv_search( - Y, D, control_mask, control_unit_idx, n_units, n_periods, - 'lambda_nn', self.lambda_nn_grid, - {'lambda_unit': lambda_unit, 'lambda_time': lambda_time} + Y, + D, + control_mask, + control_unit_idx, + n_units, + n_periods, + "lambda_nn", + self.lambda_nn_grid, + {"lambda_unit": lambda_unit, "lambda_time": lambda_time}, ) # Check convergence if abs(score - prev_score) < 1e-6: logger.debug( - "Cycling search converged after %d cycles with score %.6f", - cycle + 1, score + "Cycling search converged after %d cycles with score %.6f", cycle + 1, score ) break prev_score = score @@ -374,6 +395,7 @@ def fit( treatment: str, 
unit: str, time: str, + survey_design=None, ) -> TROPResults: """ Fit the TROP model. @@ -403,6 +425,10 @@ def fit( Name of the unit identifier column. time : str Name of the time period column. + survey_design : SurveyDesign, optional + Survey design specification. Only pweight designs are supported + (strata/PSU/FPC raise NotImplementedError). Survey weights enter + ATT aggregation only. Returns ------- @@ -412,6 +438,13 @@ def fit( attributes show the selected grid values. For lambda_time and lambda_unit, 0.0 means uniform weights; inf is not accepted. For lambda_nn, inf is converted to 1e10 (factor model disabled). + + Raises + ------ + ValueError + If required columns are missing or non-pweight survey design. + NotImplementedError + If survey_design includes strata, PSU, or FPC. """ # Validate inputs required_cols = [outcome, treatment, unit, time] @@ -419,13 +452,43 @@ def fit( if missing: raise ValueError(f"Missing columns: {missing}") + # Resolve survey design + from diff_diff.survey import ( + _extract_unit_survey_weights, + _resolve_pweight_only, + _resolve_survey_for_fit, + _validate_unit_constant_survey, + ) + + resolved_survey, _survey_weights, _survey_wt, survey_metadata = _resolve_survey_for_fit( + survey_design, data, "analytical" + ) + _resolve_pweight_only(resolved_survey, "TROP") + if resolved_survey is not None: + _validate_unit_constant_survey(data, unit, survey_design) + # Dispatch based on estimation method if self.method == "global": - return self._fit_global(data, outcome, treatment, unit, time) + return self._fit_global( + data, + outcome, + treatment, + unit, + time, + resolved_survey=resolved_survey, + survey_metadata=survey_metadata, + survey_design=survey_design, + ) # Below is the local method (default) # Get unique units and periods all_units = sorted(data[unit].unique()) + + # Extract unit-level survey weights + if resolved_survey is not None: + unit_weight_arr = _extract_unit_survey_weights(data, unit, survey_design, all_units) + 
else: + unit_weight_arr = None all_periods = sorted(data[time].unique()) n_units = len(all_units) @@ -447,9 +510,8 @@ def fit( # For D matrix, track missing values BEFORE fillna to support unbalanced panels # Issue 3 fix: Missing observations should not trigger spurious violations - D_raw = ( - data.pivot(index=time, columns=unit, values=treatment) - .reindex(index=all_periods, columns=all_units) + D_raw = data.pivot(index=time, columns=unit, values=treatment).reindex( + index=all_periods, columns=all_units ) missing_mask = pd.isna(D_raw).values # True where originally missing D = D_raw.fillna(0).astype(int).values @@ -519,9 +581,7 @@ def fit( control_mask = D == 0 # Pre-compute structures that are reused across LOOCV iterations - self._precomputed = self._precompute_structures( - Y, D, control_unit_idx, n_units, n_periods - ) + self._precomputed = self._precompute_structures(Y, D, control_unit_idx, n_units, n_periods) # Use Rust backend for parallel LOOCV grid search (10-50x speedup) if HAS_RUST_BACKEND and _rust_loocv_grid_search is not None: @@ -535,13 +595,20 @@ def fit( lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64) result = _rust_loocv_grid_search( - Y, D.astype(np.float64), control_mask_u8, + Y, + D.astype(np.float64), + control_mask_u8, time_dist_matrix, - lambda_time_arr, lambda_unit_arr, lambda_nn_arr, - self.max_iter, self.tol, + lambda_time_arr, + lambda_unit_arr, + lambda_nn_arr, + self.max_iter, + self.tol, ) # Unpack result - 7 values including optional first_failed_obs - best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result + best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = ( + result + ) # Only accept finite scores - infinite means all fits failed if np.isfinite(best_score): best_lambda = (best_lt, best_lu, best_ln) @@ -557,7 +624,7 @@ def fit( f"LOOCV: All {n_attempted} fits failed for " f"\u03bb=({best_lt}, {best_lu}, {best_ln}). 
" f"Returning infinite score.{obs_info}", - UserWarning + UserWarning, ) elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted: n_failed = n_attempted - n_valid @@ -570,13 +637,11 @@ def fit( f"LOOCV: {n_failed}/{n_attempted} fits failed for " f"\u03bb=({best_lt}, {best_lu}, {best_ln}). " f"This may indicate numerical instability.{obs_info}", - UserWarning + UserWarning, ) except Exception as e: # Fall back to Python implementation on error - logger.debug( - "Rust LOOCV grid search failed, falling back to Python: %s", e - ) + logger.debug("Rust LOOCV grid search failed, falling back to Python: %s", e) best_lambda = None best_score = np.inf @@ -590,37 +655,66 @@ def fit( # λ_time search: fix λ_unit=0, λ_nn=∞ (disabled - no factor adjustment) lambda_time_init, _ = self._univariate_loocv_search( - Y, D, control_mask, control_unit_idx, n_units, n_periods, - 'lambda_time', self.lambda_time_grid, - {'lambda_unit': 0.0, 'lambda_nn': _LAMBDA_INF} + Y, + D, + control_mask, + control_unit_idx, + n_units, + n_periods, + "lambda_time", + self.lambda_time_grid, + {"lambda_unit": 0.0, "lambda_nn": _LAMBDA_INF}, ) # λ_nn search: fix λ_time=0 (uniform time weights), λ_unit=0 lambda_nn_init, _ = self._univariate_loocv_search( - Y, D, control_mask, control_unit_idx, n_units, n_periods, - 'lambda_nn', self.lambda_nn_grid, - {'lambda_time': 0.0, 'lambda_unit': 0.0} + Y, + D, + control_mask, + control_unit_idx, + n_units, + n_periods, + "lambda_nn", + self.lambda_nn_grid, + {"lambda_time": 0.0, "lambda_unit": 0.0}, ) # λ_unit search: fix λ_nn=∞, λ_time=0 lambda_unit_init, _ = self._univariate_loocv_search( - Y, D, control_mask, control_unit_idx, n_units, n_periods, - 'lambda_unit', self.lambda_unit_grid, - {'lambda_nn': _LAMBDA_INF, 'lambda_time': 0.0} + Y, + D, + control_mask, + control_unit_idx, + n_units, + n_periods, + "lambda_unit", + self.lambda_unit_grid, + {"lambda_nn": _LAMBDA_INF, "lambda_time": 0.0}, ) # Stage 2: Cycling refinement (coordinate descent) 
lambda_time, lambda_unit, lambda_nn = self._cycling_parameter_search( - Y, D, control_mask, control_unit_idx, n_units, n_periods, - (lambda_time_init, lambda_unit_init, lambda_nn_init) + Y, + D, + control_mask, + control_unit_idx, + n_units, + n_periods, + (lambda_time_init, lambda_unit_init, lambda_nn_init), ) # Compute final score for the optimized parameters try: best_score = self._loocv_score_obs_specific( - Y, D, control_mask, control_unit_idx, - lambda_time, lambda_unit, lambda_nn, - n_units, n_periods + Y, + D, + control_mask, + control_unit_idx, + lambda_time, + lambda_unit, + lambda_nn, + n_units, + n_periods, ) # Only accept finite scores - infinite means all fits failed if np.isfinite(best_score): @@ -631,10 +725,7 @@ def fit( pass if best_lambda is None: - warnings.warn( - "All tuning parameter combinations failed. Using defaults.", - UserWarning - ) + warnings.warn("All tuning parameter combinations failed. Using defaults.", UserWarning) best_lambda = (1.0, 1.0, 0.1) best_score = np.nan @@ -657,6 +748,7 @@ def fit( # For each treated (i,t): compute observation-specific weights, fit model, compute tau_{it} treatment_effects = {} tau_values = [] + tau_weights = [] # parallel to tau_values for survey-weighted ATT alpha_estimates = [] beta_estimates = [] L_estimates = [] @@ -676,14 +768,12 @@ def fit( # Compute observation-specific weights for this (i, t) weight_matrix = self._compute_observation_weights( - Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, - n_units, n_periods + Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods ) # Fit model with these weights alpha_hat, beta_hat, L_hat = self._estimate_model( - Y, control_mask, weight_matrix, lambda_nn, - n_units, n_periods + Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods ) # Compute treatment effect: tau_{it} = Y_{it} - alpha_i - beta_t - L_{it} @@ -691,6 +781,8 @@ def fit( treatment_effects[(unit_id, time_id)] = tau_it tau_values.append(tau_it) + if 
unit_weight_arr is not None: + tau_weights.append(unit_weight_arr[i]) # Store for averaging alpha_estimates.append(alpha_hat) @@ -711,8 +803,11 @@ def fit( UserWarning, ) - # Average ATT - att = np.mean(tau_values) if tau_values else np.nan + # Average ATT (survey-weighted when applicable) + if unit_weight_arr is not None and tau_values: + att = float(np.average(tau_values, weights=tau_weights)) + else: + att = np.mean(tau_values) if tau_values else np.nan # Average parameter estimates for output (representative) alpha_hat = np.mean(alpha_estimates, axis=0) if alpha_estimates else np.zeros(n_units) @@ -730,8 +825,17 @@ def fit( # Use effective_lambda (converted values) to ensure SE is computed with same # parameters as point estimation. This fixes the variance inconsistency issue. se, bootstrap_dist = self._bootstrap_variance( - data, outcome, treatment, unit, time, - effective_lambda, Y=Y, D=D, control_unit_idx=control_unit_idx + data, + outcome, + treatment, + unit, + time, + effective_lambda, + Y=Y, + D=D, + control_unit_idx=control_unit_idx, + survey_design=survey_design, + unit_weight_arr=unit_weight_arr, ) # Compute test statistics @@ -767,6 +871,7 @@ def fit( n_post_periods=n_post_periods, n_bootstrap=self.n_bootstrap, bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None, + survey_metadata=survey_metadata, ) self.is_fitted_ = True @@ -822,6 +927,7 @@ def trop( treatment: str, unit: str, time: str, + survey_design=None, **kwargs, ) -> TROPResults: """ @@ -844,6 +950,8 @@ def trop( Unit identifier column name. time : str Time period column name. + survey_design : SurveyDesign, optional + Survey design specification. Only pweight designs are supported. **kwargs Additional arguments passed to TROP constructor. 
@@ -859,4 +967,4 @@ def trop( >>> print(f"ATT: {results.att:.3f}") """ estimator = TROP(**kwargs) - return estimator.fit(data, outcome, treatment, unit, time) + return estimator.fit(data, outcome, treatment, unit, time, survey_design=survey_design) diff --git a/diff_diff/trop_global.py b/diff_diff/trop_global.py index 1fdde78c..c8a9097b 100644 --- a/diff_diff/trop_global.py +++ b/diff_diff/trop_global.py @@ -118,14 +118,12 @@ def _compute_global_weights( # dist_unit[i] = sqrt(sum_pre(avg_tr - Y_i)^2 / n_pre) # Use NaN-safe operations: treat NaN differences as 0 (excluded) diff = average_treated[:, np.newaxis] - Y - diff_sq = np.where(np.isfinite(diff), diff ** 2, 0.0) * pre_mask[:, np.newaxis] + diff_sq = np.where(np.isfinite(diff), diff**2, 0.0) * pre_mask[:, np.newaxis] # Count valid observations per unit in pre-period # Must check diff is finite (both Y and average_treated finite) # to match the periods contributing to diff_sq - valid_count = np.sum( - np.isfinite(diff) * pre_mask[:, np.newaxis], axis=0 - ) + valid_count = np.sum(np.isfinite(diff) * pre_mask[:, np.newaxis], axis=0) sum_sq = np.sum(diff_sq, axis=0) n_pre = np.sum(pre_mask) @@ -184,6 +182,7 @@ def _extract_posthoc_tau( L: np.ndarray, idx_to_unit: Optional[Dict] = None, idx_to_period: Optional[Dict] = None, + unit_weights: Optional[np.ndarray] = None, ) -> Tuple[float, Dict, List[float]]: """ Extract post-hoc treatment effects: tau_it = Y - mu - alpha - beta - L. 
@@ -199,7 +198,11 @@ def _extract_posthoc_tau( valid_treated = treated_mask & finite_mask tau_values = tau_matrix[valid_treated].tolist() - att = float(np.mean(tau_values)) if tau_values else np.nan + if unit_weights is not None and tau_values: + obs_weights = unit_weights[np.where(valid_treated)[1]] + att = float(np.average(tau_values, weights=obs_weights)) + else: + att = float(np.mean(tau_values)) if tau_values else np.nan # Build treatment effects dict treatment_effects: Dict = {} @@ -282,7 +285,7 @@ def _loocv_score_global( # Pseudo treatment effect: tau = Y - mu - alpha - beta - L if np.isfinite(Y[t_ex, i_ex]): tau_loocv = Y[t_ex, i_ex] - mu - alpha[i_ex] - beta[t_ex] - L[t_ex, i_ex] - tau_sq_sum += tau_loocv ** 2 + tau_sq_sum += tau_loocv**2 n_valid += 1 except (np.linalg.LinAlgError, ValueError): @@ -374,7 +377,7 @@ def _solve_global_no_lowrank( alpha = np.zeros(n_units) alpha[1:] = coeffs[1:n_units] beta = np.zeros(n_periods) - beta[1:] = coeffs[n_units:(n_units + n_periods - 1)] + beta[1:] = coeffs[n_units : (n_units + n_periods - 1)] return float(mu), alpha, beta @@ -490,6 +493,9 @@ def _fit_global( treatment: str, unit: str, time: str, + resolved_survey=None, + survey_metadata=None, + survey_design=None, ) -> TROPResults: """ Fit TROP using global weighted least squares method. 
@@ -529,6 +535,14 @@ def _fit_global( all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) + # Extract per-unit survey weights for weighted ATT aggregation + if resolved_survey is not None: + from diff_diff.survey import _extract_unit_survey_weights + + unit_weight_arr = _extract_unit_survey_weights(data, unit, survey_design, all_units) + else: + unit_weight_arr = None + n_units = len(all_units) n_periods = len(all_periods) @@ -542,9 +556,8 @@ def _fit_global( .values ) - D_raw = ( - data.pivot(index=time, columns=unit, values=treatment) - .reindex(index=all_periods, columns=all_units) + D_raw = data.pivot(index=time, columns=unit, values=treatment).reindex( + index=all_periods, columns=all_units ) missing_mask = pd.isna(D_raw).values D = D_raw.fillna(0).astype(int).values @@ -634,12 +647,19 @@ def _fit_global( lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64) result = _rust_loocv_grid_search_global( - Y, D.astype(np.float64), control_mask_u8, - lambda_time_arr, lambda_unit_arr, lambda_nn_arr, - self.max_iter, self.tol, + Y, + D.astype(np.float64), + control_mask_u8, + lambda_time_arr, + lambda_unit_arr, + lambda_nn_arr, + self.max_iter, + self.tol, ) # Unpack result - 7 values including optional first_failed_obs - best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result + best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = ( + result + ) # Only accept finite scores - infinite means all fits failed if np.isfinite(best_score): best_lambda = (best_lt, best_lu, best_ln) @@ -653,7 +673,7 @@ def _fit_global( f"LOOCV: All {n_attempted} fits failed for " f"\u03bb=({best_lt}, {best_lu}, {best_ln}). 
" f"Returning infinite score.{obs_info}", - UserWarning + UserWarning, ) elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted: n_failed = n_attempted - n_valid @@ -665,7 +685,7 @@ def _fit_global( f"LOOCV: {n_failed}/{n_attempted} fits failed for " f"\u03bb=({best_lt}, {best_lu}, {best_ln}). " f"This may indicate numerical instability.{obs_info}", - UserWarning + UserWarning, ) except Exception as e: # Fall back to Python implementation on error @@ -679,7 +699,9 @@ def _fit_global( if best_lambda is None: # Get control observations for LOOCV control_obs = [ - (t, i) for t in range(n_periods) for i in range(n_units) + (t, i) + for t in range(n_periods) + for i in range(n_units) if control_mask[t, i] and not np.isnan(Y[t, i]) ] @@ -694,8 +716,7 @@ def _fit_global( try: score = self._loocv_score_global( - Y, D, control_obs, lt, lu, ln, - treated_periods, n_units, n_periods + Y, D, control_obs, lt, lu, ln, treated_periods, n_units, n_periods ) if score < best_score: @@ -706,10 +727,7 @@ def _fit_global( continue if best_lambda is None: - warnings.warn( - "All tuning parameter combinations failed. Using defaults.", - UserWarning - ) + warnings.warn("All tuning parameter combinations failed. Using defaults.", UserWarning) best_lambda = (1.0, 1.0, 0.1) best_score = np.nan @@ -731,7 +749,15 @@ def _fit_global( # Post-hoc tau extraction (per paper Eq. 
2) att, treatment_effects, tau_values = self._extract_posthoc_tau( - Y, D, mu, alpha, beta, L, idx_to_unit, idx_to_period + Y, + D, + mu, + alpha, + beta, + L, + idx_to_unit, + idx_to_period, + unit_weights=unit_weight_arr, ) # Use count of valid (finite) treated outcomes for df and metadata @@ -759,8 +785,15 @@ def _fit_global( effective_lambda = (lambda_time, lambda_unit, lambda_nn) se, bootstrap_dist = self._bootstrap_variance_global( - data, outcome, treatment, unit, time, - effective_lambda, treated_periods + data, + outcome, + treatment, + unit, + time, + effective_lambda, + treated_periods, + survey_design=survey_design, + unit_weight_arr=unit_weight_arr, ) # Compute test statistics @@ -795,6 +828,7 @@ def _fit_global( n_post_periods=n_post_periods, n_bootstrap=self.n_bootstrap, bootstrap_distribution=bootstrap_dist if len(bootstrap_dist) > 0 else None, + survey_metadata=survey_metadata, ) self.is_fitted_ = True @@ -809,6 +843,8 @@ def _bootstrap_variance_global( time: str, optimal_lambda: Tuple[float, float, float], treated_periods: int, + survey_design=None, + unit_weight_arr: Optional[np.ndarray] = None, ) -> Tuple[float, np.ndarray]: """ Compute bootstrap standard error for global method. 
@@ -860,16 +896,22 @@ def _bootstrap_variance_global( ) bootstrap_estimates, se = _rust_bootstrap_trop_variance_global( - Y, D, - lambda_time, lambda_unit, lambda_nn, - self.n_bootstrap, self.max_iter, self.tol, - self.seed if self.seed is not None else 0 + Y, + D, + lambda_time, + lambda_unit, + lambda_nn, + self.n_bootstrap, + self.max_iter, + self.tol, + self.seed if self.seed is not None else 0, + unit_weight_arr, ) if len(bootstrap_estimates) < 10: warnings.warn( f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", - UserWarning + UserWarning, ) if len(bootstrap_estimates) == 0: return np.nan, np.array([]) @@ -877,9 +919,7 @@ def _bootstrap_variance_global( return float(se), np.array(bootstrap_estimates) except Exception as e: - logger.debug( - "Rust bootstrap (global) failed, falling back to Python: %s", e - ) + logger.debug("Rust bootstrap (global) failed, falling back to Python: %s", e) # Python fallback implementation rng = np.random.default_rng(self.seed) @@ -897,31 +937,36 @@ def _bootstrap_variance_global( for _ in range(self.n_bootstrap): # Stratified sampling if n_control_units > 0: - sampled_control = rng.choice( - control_units, size=n_control_units, replace=True - ) + sampled_control = rng.choice(control_units, size=n_control_units, replace=True) else: sampled_control = np.array([], dtype=object) if n_treated_units > 0: - sampled_treated = rng.choice( - treated_units, size=n_treated_units, replace=True - ) + sampled_treated = rng.choice(treated_units, size=n_treated_units, replace=True) else: sampled_treated = np.array([], dtype=object) sampled_units = np.concatenate([sampled_control, sampled_treated]) # Create bootstrap sample - boot_data = pd.concat([ - data[data[unit] == u].assign(**{unit: f"{u}_{idx}"}) - for idx, u in enumerate(sampled_units) - ], ignore_index=True) + boot_data = pd.concat( + [ + data[data[unit] == u].assign(**{unit: f"{u}_{idx}"}) + for idx, u in enumerate(sampled_units) + ], + ignore_index=True, + ) try: tau = 
self._fit_global_with_fixed_lambda( - boot_data, outcome, treatment, unit, time, - optimal_lambda, treated_periods + boot_data, + outcome, + treatment, + unit, + time, + optimal_lambda, + treated_periods, + survey_design=survey_design, ) if np.isfinite(tau): bootstrap_estimates_list.append(tau) @@ -932,8 +977,7 @@ def _bootstrap_variance_global( if len(bootstrap_estimates) < 10: warnings.warn( - f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", - UserWarning + f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", UserWarning ) if len(bootstrap_estimates) == 0: return np.nan, np.array([]) @@ -950,6 +994,7 @@ def _fit_global_with_fixed_lambda( time: str, fixed_lambda: Tuple[float, float, float], treated_periods: int, + survey_design=None, ) -> float: """ Fit global model with fixed tuning parameters. @@ -961,6 +1006,14 @@ def _fit_global_with_fixed_lambda( all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) + # Extract per-unit survey weights for weighted ATT in bootstrap + if survey_design is not None and survey_design.weights is not None: + from diff_diff.survey import _extract_unit_survey_weights + + local_weight_arr = _extract_unit_survey_weights(data, unit, survey_design, all_units) + else: + local_weight_arr = None + n_units = len(all_units) n_periods = len(all_periods) @@ -984,5 +1037,7 @@ def _fit_global_with_fixed_lambda( # Fit model on control data and extract post-hoc tau mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn) - att, _, _ = self._extract_posthoc_tau(Y, D, mu, alpha, beta, L) + att, _, _ = self._extract_posthoc_tau( + Y, D, mu, alpha, beta, L, unit_weights=local_weight_arr + ) return att diff --git a/diff_diff/trop_local.py b/diff_diff/trop_local.py index b7361fbe..dd1398ca 100644 --- a/diff_diff/trop_local.py +++ b/diff_diff/trop_local.py @@ -87,7 +87,7 @@ def _soft_threshold_svd( # Compute result, suppressing expected numerical warnings from # ill-conditioned 
matrices during alternating minimization - with np.errstate(divide='ignore', over='ignore', invalid='ignore'): + with np.errstate(divide="ignore", over="ignore", invalid="ignore"): result = (U_trunc * s_trunc) @ Vt_trunc # Replace any NaN/Inf in result with zeros @@ -178,8 +178,12 @@ def _precompute_structures( treated_observations = list(zip(*np.where(treated_mask))) # Control observations for LOOCV - control_obs = [(t, i) for t in range(n_periods) for i in range(n_units) - if control_mask[t, i] and not np.isnan(Y[t, i])] + control_obs = [ + (t, i) + for t in range(n_periods) + for i in range(n_units) + if control_mask[t, i] and not np.isnan(Y[t, i]) + ] return { "unit_dist_matrix": unit_dist_matrix, @@ -245,7 +249,7 @@ def _compute_all_unit_distances( # Y_T[np.newaxis, :, :] has shape (1, n_units, n_periods) # diff has shape (n_units, n_units, n_periods) diff = Y_T[:, np.newaxis, :] - Y_T[np.newaxis, :, :] - sq_diff = diff ** 2 + sq_diff = diff**2 # Count valid (non-NaN) observations per pair # A difference is valid only if both units have valid observations @@ -257,7 +261,7 @@ def _compute_all_unit_distances( # Compute RMSE distance: sqrt(sum / n_valid) # Avoid division by zero - with np.errstate(divide='ignore', invalid='ignore'): + with np.errstate(divide="ignore", invalid="ignore"): dist_matrix = np.sqrt(sq_diff_sum / n_valid) # Set pairs with no valid observations to inf @@ -733,8 +737,12 @@ def _loocv_score_obs_specific( control_obs = self._precomputed["control_obs"] else: # Get all control observations - control_obs = [(t, i) for t in range(n_periods) for i in range(n_units) - if control_mask[t, i] and not np.isnan(Y[t, i])] + control_obs = [ + (t, i) + for t in range(n_periods) + for i in range(n_units) + if control_mask[t, i] and not np.isnan(Y[t, i]) + ] # Empty control set check: if no control observations, return infinity # A score of 0.0 would incorrectly "win" over legitimate parameters @@ -743,7 +751,7 @@ def _loocv_score_obs_specific( f"LOOCV: No 
valid control observations for " f"\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}). " "Returning infinite score.", - UserWarning + UserWarning, ) return np.inf @@ -755,19 +763,23 @@ def _loocv_score_obs_specific( # Compute observation-specific weights for pseudo-treated (i, t) # Uses pre-computed distance matrices when available weight_matrix = self._compute_observation_weights( - Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, - n_units, n_periods + Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods ) # Estimate model excluding observation (t, i) alpha, beta, L = self._estimate_model( - Y, control_mask, weight_matrix, lambda_nn, - n_units, n_periods, exclude_obs=(t, i) + Y, + control_mask, + weight_matrix, + lambda_nn, + n_units, + n_periods, + exclude_obs=(t, i), ) # Pseudo treatment effect tau_ti = Y[t, i] - alpha[i] - beta[t] - L[t, i] - tau_squared_sum += tau_ti ** 2 + tau_squared_sum += tau_ti**2 n_valid += 1 except (np.linalg.LinAlgError, ValueError): @@ -777,7 +789,7 @@ def _loocv_score_obs_specific( f"LOOCV: Fit failed for observation ({t}, {i}) with " f"\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}). " "Returning infinite score per Equation 5.", - UserWarning + UserWarning, ) return np.inf @@ -796,6 +808,8 @@ def _bootstrap_variance( Y: Optional[np.ndarray] = None, D: Optional[np.ndarray] = None, control_unit_idx: Optional[np.ndarray] = None, + survey_design=None, + unit_weight_arr: Optional[np.ndarray] = None, ) -> Tuple[float, np.ndarray]: """ Compute bootstrap standard error using unit-level block bootstrap. 
@@ -848,20 +862,30 @@ def _bootstrap_variance( lambda_time, lambda_unit, lambda_nn = optimal_lambda # Try Rust backend for parallel bootstrap (5-15x speedup) - if (HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None - and self._precomputed is not None and Y is not None - and D is not None): + if ( + HAS_RUST_BACKEND + and _rust_bootstrap_trop_variance is not None + and self._precomputed is not None + and Y is not None + and D is not None + ): try: control_mask = self._precomputed["control_mask"] time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64) bootstrap_estimates, se = _rust_bootstrap_trop_variance( - Y, D.astype(np.float64), + Y, + D.astype(np.float64), control_mask.astype(np.uint8), time_dist_matrix, - lambda_time, lambda_unit, lambda_nn, - self.n_bootstrap, self.max_iter, self.tol, - self.seed if self.seed is not None else 0 + lambda_time, + lambda_unit, + lambda_nn, + self.n_bootstrap, + self.max_iter, + self.tol, + self.seed if self.seed is not None else 0, + unit_weight_arr, ) if len(bootstrap_estimates) >= 10: @@ -869,12 +893,10 @@ def _bootstrap_variance( # Fall through to Python if too few bootstrap samples logger.debug( "Rust bootstrap returned only %d samples, falling back to Python", - len(bootstrap_estimates) + len(bootstrap_estimates), ) except Exception as e: - logger.debug( - "Rust bootstrap variance failed, falling back to Python: %s", e - ) + logger.debug("Rust bootstrap variance failed, falling back to Python: %s", e) # Python implementation (fallback) rng = np.random.default_rng(self.seed) @@ -895,16 +917,12 @@ def _bootstrap_variance( # Stratified sampling: sample control and treated units separately # This preserves the treatment ratio in each bootstrap sample if n_control_units > 0: - sampled_control = rng.choice( - control_units, size=n_control_units, replace=True - ) + sampled_control = rng.choice(control_units, size=n_control_units, replace=True) else: sampled_control = np.array([], 
dtype=control_units.dtype) if n_treated_units > 0: - sampled_treated = rng.choice( - treated_units, size=n_treated_units, replace=True - ) + sampled_treated = rng.choice(treated_units, size=n_treated_units, replace=True) else: sampled_treated = np.array([], dtype=treated_units.dtype) @@ -912,18 +930,27 @@ def _bootstrap_variance( sampled_units = np.concatenate([sampled_control, sampled_treated]) # Create bootstrap sample with unique unit IDs - boot_data = pd.concat([ - data[data[unit] == u].assign(**{unit: f"{u}_{idx}"}) - for idx, u in enumerate(sampled_units) - ], ignore_index=True) + boot_data = pd.concat( + [ + data[data[unit] == u].assign(**{unit: f"{u}_{idx}"}) + for idx, u in enumerate(sampled_units) + ], + ignore_index=True, + ) try: # Fit with fixed lambda (skip LOOCV for speed) att = self._fit_with_fixed_lambda( - boot_data, outcome, treatment, unit, time, - optimal_lambda + boot_data, + outcome, + treatment, + unit, + time, + optimal_lambda, + survey_design=survey_design, ) - bootstrap_estimates_list.append(att) + if np.isfinite(att): + bootstrap_estimates_list.append(att) except (ValueError, np.linalg.LinAlgError, KeyError): continue @@ -933,7 +960,7 @@ def _bootstrap_variance( warnings.warn( f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. " "Standard errors may be unreliable.", - UserWarning + UserWarning, ) if len(bootstrap_estimates) == 0: return np.nan, np.array([]) @@ -949,6 +976,7 @@ def _fit_with_fixed_lambda( unit: str, time: str, fixed_lambda: Tuple[float, float, float], + survey_design=None, ) -> float: """ Fit model with fixed tuning parameters (for bootstrap). 
@@ -958,6 +986,17 @@ def _fit_with_fixed_lambda( """ lambda_time, lambda_unit, lambda_nn = fixed_lambda + # Extract survey weights from bootstrap data (units are renamed) + if survey_design is not None and survey_design.weights is not None: + from diff_diff.survey import _extract_unit_survey_weights + + local_all_units = sorted(data[unit].unique()) + local_weight_arr = _extract_unit_survey_weights( + data, unit, survey_design, local_all_units + ) + else: + local_weight_arr = None + # Setup matrices all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) @@ -986,29 +1025,39 @@ def _fit_with_fixed_lambda( control_unit_idx = np.where(~unit_ever_treated)[0] # Get list of treated observations - treated_observations = [(t, i) for t in range(n_periods) for i in range(n_units) - if D[t, i] == 1] + treated_observations = [ + (t, i) for t in range(n_periods) for i in range(n_units) if D[t, i] == 1 + ] if not treated_observations: raise ValueError("No treated observations") # Compute ATT using observation-specific weights (Algorithm 2) tau_values = [] + tau_weights = [] for t, i in treated_observations: + # Skip non-finite outcomes (match main fit NaN contract) + if not np.isfinite(Y[t, i]): + continue + # Compute observation-specific weights for this (i, t) weight_matrix = self._compute_observation_weights( - Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, - n_units, n_periods + Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods ) # Fit model with these weights alpha, beta, L = self._estimate_model( - Y, control_mask, weight_matrix, lambda_nn, - n_units, n_periods + Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods ) # Compute treatment effect: tau_{it} = Y_{it} - alpha_i - beta_t - L_{it} tau = Y[t, i] - alpha[i] - beta[t] - L[t, i] tau_values.append(tau) - - return np.mean(tau_values) + if local_weight_arr is not None: + tau_weights.append(local_weight_arr[i]) + + if not tau_values: + return float("nan") 
+ if local_weight_arr is not None: + return float(np.average(tau_values, weights=tau_weights)) + return float(np.mean(tau_values)) diff --git a/diff_diff/trop_results.py b/diff_diff/trop_results.py index a2189e81..94c38657 100644 --- a/diff_diff/trop_results.py +++ b/diff_diff/trop_results.py @@ -29,7 +29,7 @@ # Per paper's footnote 2: λ_nn=∞ disables the factor model (L=0). # For λ_time and λ_unit, 0.0 means disabled (uniform weights) per Eq. 3: # exp(-0 × dist) = 1 for all distances. -_LAMBDA_INF: float = float('inf') +_LAMBDA_INF: float = float("inf") class _PrecomputedStructures(TypedDict): @@ -147,6 +147,8 @@ class TROPResults: n_post_periods: int = 0 n_bootstrap: Optional[int] = field(default=None) bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False) + # Survey design metadata (SurveyMetadata instance from diff_diff.survey) + survey_metadata: Optional[Any] = field(default=None) def __repr__(self) -> str: """Concise string representation.""" @@ -203,25 +205,51 @@ def summary(self, alpha: Optional[float] = None) -> str: if self.n_bootstrap is not None: lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}") - lines.extend([ - "", - "-" * 75, - f"{'Parameter':<15} {'Estimate':>12} {'Std. 
Err.':>12} " - f"{'t-stat':>10} {'P>|t|':>10} {'':>5}", - "-" * 75, - f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} " - f"{self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", - "-" * 75, - "", - f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", - ]) + # Add survey design info + if self.survey_metadata is not None: + sm = self.survey_metadata + lines.extend( + [ + "", + "-" * 75, + "Survey Design".center(75), + "-" * 75, + f"{'Weight type:':<25} {sm.weight_type:>10}", + ] + ) + if sm.n_strata is not None: + lines.append(f"{'Strata:':<25} {sm.n_strata:>10}") + if sm.n_psu is not None: + lines.append(f"{'PSU/Cluster:':<25} {sm.n_psu:>10}") + lines.append(f"{'Effective sample size:':<25} {sm.effective_n:>10.1f}") + lines.append(f"{'Design effect (DEFF):':<25} {sm.design_effect:>10.2f}") + if sm.df_survey is not None: + lines.append(f"{'Survey d.f.:':<25} {sm.df_survey:>10}") + lines.append("-" * 75) + + lines.extend( + [ + "", + "-" * 75, + f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'':>5}", + "-" * 75, + f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} " + f"{self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", + "-" * 75, + "", + f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", + ] + ) # Add significance codes - lines.extend([ - "", - "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", - "=" * 75, - ]) + lines.extend( + [ + "", + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 75, + ] + ) return "\n".join(lines) @@ -238,7 +266,7 @@ def to_dict(self) -> Dict[str, Any]: Dict[str, Any] Dictionary containing all estimation results. 
""" - return { + result = { "att": self.att, "se": self.se, "t_stat": self.t_stat, @@ -257,6 +285,16 @@ def to_dict(self) -> Dict[str, Any]: "effective_rank": self.effective_rank, "loocv_score": self.loocv_score, } + if self.survey_metadata is not None: + sm = self.survey_metadata + result["weight_type"] = sm.weight_type + result["effective_n"] = sm.effective_n + result["design_effect"] = sm.design_effect + result["sum_weights"] = sm.sum_weights + result["n_strata"] = sm.n_strata + result["n_psu"] = sm.n_psu + result["df_survey"] = sm.df_survey + return result def to_dataframe(self) -> pd.DataFrame: """ @@ -278,10 +316,12 @@ def get_treatment_effects_df(self) -> pd.DataFrame: pd.DataFrame DataFrame with unit, time, and treatment effect columns. """ - return pd.DataFrame([ - {"unit": unit, "time": time, "effect": effect} - for (unit, time), effect in self.treatment_effects.items() - ]) + return pd.DataFrame( + [ + {"unit": unit, "time": time, "effect": effect} + for (unit, time), effect in self.treatment_effects.items() + ] + ) def get_unit_effects_df(self) -> pd.DataFrame: """ @@ -292,10 +332,9 @@ def get_unit_effects_df(self) -> pd.DataFrame: pd.DataFrame DataFrame with unit and effect columns. """ - return pd.DataFrame([ - {"unit": unit, "effect": effect} - for unit, effect in self.unit_effects.items() - ]) + return pd.DataFrame( + [{"unit": unit, "effect": effect} for unit, effect in self.unit_effects.items()] + ) def get_time_effects_df(self) -> pd.DataFrame: """ @@ -306,10 +345,9 @@ def get_time_effects_df(self) -> pd.DataFrame: pd.DataFrame DataFrame with time and effect columns. 
""" - return pd.DataFrame([ - {"time": time, "effect": effect} - for time, effect in self.time_effects.items() - ]) + return pd.DataFrame( + [{"time": time, "effect": effect} for time, effect in self.time_effects.items()] + ) @property def is_significant(self) -> bool: diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 60eb521d..24b17f75 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1143,6 +1143,7 @@ Convergence criterion: stop when objective decrease < min_decrease² (default mi - **Varying treatment within unit**: Raises `ValueError`. SDID requires block treatment (constant within each unit). Suggests CallawaySantAnna or ImputationDiD for staggered adoption. - **Unbalanced panel**: Raises `ValueError`. SDID requires all units observed in all periods. Suggests `balance_panel()`. - **Poor pre-treatment fit**: Warns (`UserWarning`) when `pre_fit_rmse > std(treated_pre_outcomes, ddof=1)`. Diagnostic only; estimation proceeds. +- **Note:** Survey support: pweight only (strata/PSU/FPC raise NotImplementedError). Both sides weighted per WLS regression interpretation: treated-side means are survey-weighted (Frank-Wolfe target and ATT formula); control-side synthetic weights are composed with survey weights post-optimization (ω_eff = ω * w_co, renormalized). Frank-Wolfe optimization itself is unweighted — survey importance enters after trajectory-matching. Covariate residualization uses WLS with survey weights. Placebo and bootstrap SE preserve survey weights on both sides. 
**Reference implementation(s):** - R: `synthdid::synthdid_estimate()` (Arkhangelsky et al.'s official package) @@ -1396,6 +1397,7 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² - [x] No post_periods parameter (D matrix determines treatment timing) - [x] D matrix semantics documented (absorbing state, not event indicator) - [x] Unbalanced panels supported (missing observations don't trigger false violations) +- **Note:** Survey support: pweight only (strata/PSU/FPC raise NotImplementedError). Survey weights enter ATT aggregation only — population-weighted average of per-observation treatment effects. Model fitting (kernel weights, LOOCV, nuclear norm regularization) stays unchanged. Rust and Python bootstrap paths both support survey-weighted ATT in each iteration. ### TROP Global Estimation Method diff --git a/docs/survey-roadmap.md b/docs/survey-roadmap.md index ba97f088..289120e6 100644 --- a/docs/survey-roadmap.md +++ b/docs/survey-roadmap.md @@ -63,18 +63,21 @@ TripleDifference IPW/DR from Phase 3 deferred work. | CallawaySantAnna | Covariates + IPW/DR + survey | Phase 5: DRDID panel nuisance IF corrections | | CallawaySantAnna | Efficient DRDID nuisance IF for reg+covariates | Phase 5: replace conservative plug-in IF with semiparametrically efficient IF | -### Remaining for Phase 5 +## Implemented (Phase 5): SyntheticDiD + TROP Survey Support -| Estimator | File | Complexity | Notes | -|-----------|------|------------|-------| -| SyntheticDiD | `synthetic_did.py` | Medium | Survey-weighted treated mean in optimization, weighted placebo variance (bootstrap-based SE only) | -| TROP | `trop.py` | Medium | Survey weights in ATT aggregation and LOOCV (bootstrap-based SE only) | +| Estimator | File | Survey Support | Notes | +|-----------|------|----------------|-------| +| SyntheticDiD | `synthetic_did.py` | pweight | Both sides weighted (WLS interpretation): treated means survey-weighted, ω composed with control survey weights post-optimization. 
Placebo and bootstrap SE preserve survey weights. Covariates residualized with WLS. | +| TROP | `trop.py` | pweight | ATT aggregation only: population-weighted average of per-observation treatment effects. Model fitting (kernel weights, LOOCV, nuclear norm) unchanged. Rust and Python bootstrap paths both support weighted ATT. | + +### Phase 5 Deferred Work -## Phase 5: Advanced Features + Remaining Estimators +| Estimator | Deferred Capability | Blocker | +|-----------|-------------------|---------| +| SyntheticDiD | strata/PSU/FPC + survey-aware bootstrap | Bootstrap + Survey Interaction | +| TROP | strata/PSU/FPC + survey-aware bootstrap | Bootstrap + Survey Interaction | -### SyntheticDiD and TROP Survey Support -Both estimators use bootstrap/placebo for SE with no analytical variance path. -Phase 5 provides survey-weighted point estimates and survey-aware bootstrap SE. +## Phase 6: Advanced Features ### Bootstrap + Survey Interaction Unblock bootstrap + survey for all estimators that currently defer it diff --git a/rust/src/trop.rs b/rust/src/trop.rs index e26528e9..605303a7 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -914,11 +914,15 @@ fn max_abs_diff_2d(a: &Array2, b: &Array2) -> f64 { /// * `max_iter` - Maximum iterations for model estimation /// * `tol` - Convergence tolerance /// * `seed` - Random seed +/// * `survey_weights` - Optional unit-level survey weights (length n_units). +/// When provided, ATT is computed as a weighted mean of per-observation +/// treatment effects using unit weights. Model fitting, LOOCV, and distance +/// computation are unchanged. 
/// /// # Returns /// (bootstrap_estimates, standard_error) #[pyfunction] -#[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed))] +#[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed, survey_weights=None))] #[allow(clippy::too_many_arguments)] pub fn bootstrap_trop_variance<'py>( py: Python<'py>, @@ -933,11 +937,13 @@ pub fn bootstrap_trop_variance<'py>( max_iter: usize, tol: f64, seed: u64, + survey_weights: Option>, ) -> PyResult<(Bound<'py, PyArray1>, f64)> { let y_arr = y.as_array().to_owned(); let d_arr = d.as_array().to_owned(); let control_mask_arr = control_mask.as_array().to_owned(); let time_dist_arr = time_dist_matrix.as_array().to_owned(); + let sw_arr: Option> = survey_weights.map(|sw| sw.as_array().to_owned()); let n_units = y_arr.ncols(); let n_periods = y_arr.nrows(); @@ -1022,9 +1028,18 @@ pub fn bootstrap_trop_variance<'py>( } // Compute ATT for bootstrap sample - let mut tau_values = Vec::with_capacity(boot_treated.len()); + // When survey weights are provided, ATT is a weighted mean of + // per-observation treatment effects using unit-level weights. 
+ let mut tau_sum = 0.0; + let mut weight_sum = 0.0; + let mut tau_count = 0usize; for (t, i) in boot_treated { + // Skip non-finite outcomes (match main fit NaN contract) + if !y_boot[[t, i]].is_finite() { + continue; + } + let weight_matrix = compute_weight_matrix( &y_boot.view(), &d_boot.view(), @@ -1049,14 +1064,20 @@ pub fn bootstrap_trop_variance<'py>( None, ) { let tau = y_boot[[t, i]] - alpha[i] - beta[t] - l[[t, i]]; - tau_values.push(tau); + let w = match &sw_arr { + Some(sw) => sw[sampled_units[i]], + None => 1.0, + }; + tau_sum += w * tau; + weight_sum += w; + tau_count += 1; } } - if tau_values.is_empty() { + if tau_count == 0 { None } else { - Some(tau_values.iter().sum::() / tau_values.len() as f64) + Some(tau_sum / weight_sum) } }) .collect(); @@ -1649,11 +1670,15 @@ pub fn loocv_grid_search_global<'py>( /// * `max_iter` - Maximum iterations for model estimation /// * `tol` - Convergence tolerance /// * `seed` - Random seed +/// * `survey_weights` - Optional unit-level survey weights (length n_units). +/// When provided, ATT is computed as a weighted mean of per-observation +/// treatment effects using unit weights. Model fitting, LOOCV, and distance +/// computation are unchanged. 
/// /// # Returns /// (bootstrap_estimates, standard_error) #[pyfunction] -#[pyo3(signature = (y, d, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed))] +#[pyo3(signature = (y, d, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed, survey_weights=None))] #[allow(clippy::too_many_arguments)] pub fn bootstrap_trop_variance_global<'py>( py: Python<'py>, @@ -1666,9 +1691,11 @@ pub fn bootstrap_trop_variance_global<'py>( max_iter: usize, tol: f64, seed: u64, + survey_weights: Option>, ) -> PyResult<(Bound<'py, PyArray1>, f64)> { let y_arr = y.as_array().to_owned(); let d_arr = d.as_array().to_owned(); + let sw_arr: Option> = survey_weights.map(|sw| sw.as_array().to_owned()); let n_units = y_arr.ncols(); let n_periods = y_arr.nrows(); @@ -1767,19 +1794,27 @@ pub fn bootstrap_trop_variance_global<'py>( }; // Post-hoc tau extraction: ATT = mean(Y - mu - alpha - beta - L) over treated + // When survey weights are provided, ATT is a weighted mean using unit-level weights. result.and_then(|(mu, alpha, beta, l)| { let mut tau_sum = 0.0; + let mut weight_sum = 0.0; let mut tau_count = 0; for t in 0..n_periods { for i in 0..n_units { if d_boot[[t, i]] == 1.0 && y_boot[[t, i]].is_finite() { - tau_sum += y_boot[[t, i]] - mu - alpha[i] - beta[t] - l[[t, i]]; + let tau = y_boot[[t, i]] - mu - alpha[i] - beta[t] - l[[t, i]]; + let w = match &sw_arr { + Some(sw) => sw[sampled_units[i]], + None => 1.0, + }; + tau_sum += w * tau; + weight_sum += w; tau_count += 1; } } } if tau_count > 0 { - Some(tau_sum / tau_count as f64) + Some(tau_sum / weight_sum) } else { None } diff --git a/tests/test_survey_phase5.py b/tests/test_survey_phase5.py new file mode 100644 index 00000000..79bd00e4 --- /dev/null +++ b/tests/test_survey_phase5.py @@ -0,0 +1,789 @@ +"""Tests for Phase 5 survey support: SyntheticDiD and TROP. 
+ +Covers: pweight-only survey integration for both estimators, including +point estimate weighting, bootstrap/placebo SE threading, survey_metadata +in results, error guards for unsupported designs, and scale invariance. +""" + +import numpy as np +import pandas as pd +import pytest + +from diff_diff import SurveyDesign, SyntheticDiD +from diff_diff.trop import TROP, trop + +# ============================================================================= +# Shared Fixtures +# ============================================================================= + + +@pytest.fixture +def sdid_survey_data(): + """Balanced panel for SDID with survey design columns. + + 20 units (5 treated, 15 control), 10 periods, block treatment at period 6. + Unit-constant weight column that varies across units. + """ + np.random.seed(42) + n_units = 20 + n_periods = 10 + n_treated = 5 + + units = list(range(n_units)) + periods = list(range(n_periods)) + + rows = [] + for u in units: + is_treated = 1 if u < n_treated else 0 + base = np.random.randn() * 2 + for t in periods: + y = base + 0.5 * t + np.random.randn() * 0.5 + if is_treated and t >= 6: + y += 2.0 # treatment effect + rows.append({"unit": u, "time": t, "outcome": y, "treated": is_treated}) + + data = pd.DataFrame(rows) + + # Unit-constant survey columns + unit_weight = 1.0 + np.arange(n_units) * 0.1 # [1.0, 1.1, ..., 2.9] + unit_stratum = np.arange(n_units) // 10 + unit_psu = np.arange(n_units) // 5 + unit_map = {u: i for i, u in enumerate(units)} + idx = data["unit"].map(unit_map).values + + data["weight"] = unit_weight[idx] + data["stratum"] = unit_stratum[idx] + data["psu"] = unit_psu[idx] + + return data + + +@pytest.fixture +def trop_survey_data(): + """Panel data for TROP with absorbing-state D and survey columns. + + 20 units (5 treated starting at period 5), 10 periods. 
+ """ + np.random.seed(123) + n_units = 20 + n_periods = 10 + n_treated = 5 + + units = list(range(n_units)) + periods = list(range(n_periods)) + + rows = [] + for u in units: + is_treated_unit = u < n_treated + base = np.random.randn() * 2 + for t in periods: + y = base + 0.3 * t + np.random.randn() * 0.5 + # Absorbing state: D=1 for t >= 5 if treated unit + d = 1 if (is_treated_unit and t >= 5) else 0 + if d == 1: + y += 1.5 # treatment effect + rows.append({"unit": u, "time": t, "outcome": y, "D": d}) + + data = pd.DataFrame(rows) + + # Unit-constant survey columns + unit_weight = 1.0 + np.arange(n_units) * 0.15 + unit_stratum = np.arange(n_units) // 10 + unit_psu = np.arange(n_units) // 5 + unit_map = {u: i for i, u in enumerate(units)} + idx = data["unit"].map(unit_map).values + + data["weight"] = unit_weight[idx] + data["stratum"] = unit_stratum[idx] + data["psu"] = unit_psu[idx] + + return data + + +@pytest.fixture +def survey_design_weights(): + return SurveyDesign(weights="weight") + + +@pytest.fixture +def survey_design_full(): + return SurveyDesign(weights="weight", strata="stratum", psu="psu") + + +# ============================================================================= +# SyntheticDiD Survey Tests +# ============================================================================= + + +class TestSyntheticDiDSurvey: + """Survey support tests for SyntheticDiD.""" + + def test_smoke_weights_only(self, sdid_survey_data, survey_design_weights): + """Fit completes and survey_metadata is populated.""" + est = SyntheticDiD(n_bootstrap=50, seed=42) + result = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=survey_design_weights, + ) + assert result.survey_metadata is not None + assert np.isfinite(result.att) + assert np.isfinite(result.se) + + def test_uniform_weights_match_unweighted(self, sdid_survey_data): + """Uniform weights (all 1.0) produce same ATT as 
unweighted.""" + sdid_survey_data = sdid_survey_data.copy() + sdid_survey_data["uniform_w"] = 1.0 + + est = SyntheticDiD(variance_method="placebo", n_bootstrap=50, seed=42) + result_no_survey = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + ) + result_uniform = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=SurveyDesign(weights="uniform_w"), + ) + assert result_uniform.att == pytest.approx(result_no_survey.att, abs=1e-10) + + def test_survey_metadata_fields(self, sdid_survey_data, survey_design_weights): + """Metadata has correct weight_type, effective_n, design_effect.""" + est = SyntheticDiD(n_bootstrap=50, seed=42) + result = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=survey_design_weights, + ) + sm = result.survey_metadata + assert sm.weight_type == "pweight" + assert sm.effective_n > 0 + assert sm.design_effect > 0 + + def test_strata_psu_fpc_raises(self, sdid_survey_data, survey_design_full): + """Full design raises NotImplementedError.""" + est = SyntheticDiD(n_bootstrap=50, seed=42) + with pytest.raises(NotImplementedError, match="strata/PSU/FPC"): + est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=survey_design_full, + ) + + def test_fweight_aweight_raises(self, sdid_survey_data): + """Non-pweight raises ValueError.""" + est = SyntheticDiD(n_bootstrap=50, seed=42) + sd = SurveyDesign(weights="weight", weight_type="fweight") + with pytest.raises(ValueError, match="pweight"): + est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=sd, + ) + + def test_weighted_att_differs(self, 
sdid_survey_data, survey_design_weights): + """Non-uniform weights produce different ATT than unweighted.""" + est = SyntheticDiD(variance_method="placebo", n_bootstrap=50, seed=42) + result_no_survey = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + ) + result_survey = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=survey_design_weights, + ) + # ATTs should differ since weights are non-uniform + assert result_survey.att != pytest.approx(result_no_survey.att, abs=1e-6) + + def test_summary_includes_survey(self, sdid_survey_data, survey_design_weights): + """summary() output contains Survey Design section.""" + est = SyntheticDiD(n_bootstrap=50, seed=42) + result = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=survey_design_weights, + ) + summary = result.summary() + assert "Survey Design" in summary + assert "pweight" in summary + + def test_bootstrap_with_survey(self, sdid_survey_data, survey_design_weights): + """variance_method='bootstrap' completes with survey weights.""" + est = SyntheticDiD(variance_method="bootstrap", n_bootstrap=50, seed=42) + result = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=survey_design_weights, + ) + assert np.isfinite(result.se) + assert result.se > 0 + + def test_placebo_with_survey(self, sdid_survey_data, survey_design_weights): + """variance_method='placebo' completes with survey weights.""" + est = SyntheticDiD(variance_method="placebo", n_bootstrap=50, seed=42) + result = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=survey_design_weights, + 
) + assert np.isfinite(result.se) + assert result.se > 0 + + def test_weight_scale_invariance(self, sdid_survey_data, survey_design_weights): + """Multiplying all weights by constant produces same ATT.""" + sdid_survey_data = sdid_survey_data.copy() + sdid_survey_data["weight_2x"] = sdid_survey_data["weight"] * 2.0 + + est = SyntheticDiD(variance_method="placebo", n_bootstrap=50, seed=42) + result_1x = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=survey_design_weights, + ) + result_2x = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=SurveyDesign(weights="weight_2x"), + ) + assert result_2x.att == pytest.approx(result_1x.att, rel=1e-6) + + def test_unit_varying_survey_raises(self, sdid_survey_data): + """Time-varying weight column raises ValueError.""" + sdid_survey_data = sdid_survey_data.copy() + sdid_survey_data["bad_weight"] = sdid_survey_data["weight"] + sdid_survey_data["time"] * 0.1 + est = SyntheticDiD(n_bootstrap=50, seed=42) + with pytest.raises(ValueError): + est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=SurveyDesign(weights="bad_weight"), + ) + + def test_to_dict_includes_survey(self, sdid_survey_data, survey_design_weights): + """to_dict() output includes survey metadata fields.""" + est = SyntheticDiD(n_bootstrap=50, seed=42) + result = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=survey_design_weights, + ) + d = result.to_dict() + assert "weight_type" in d + assert d["weight_type"] == "pweight" + + def test_covariates_with_survey(self, sdid_survey_data, survey_design_weights): + """Covariates + survey_design smoke test (WLS residualization).""" + 
sdid_survey_data = sdid_survey_data.copy() + sdid_survey_data["x1"] = np.random.randn(len(sdid_survey_data)) + est = SyntheticDiD(n_bootstrap=50, seed=42) + result = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + covariates=["x1"], + survey_design=survey_design_weights, + ) + assert np.isfinite(result.att) + assert result.survey_metadata is not None + + def test_effective_weights_returned(self, sdid_survey_data, survey_design_weights): + """unit_weights returns composed ω_eff (not raw ω) under survey weighting.""" + est = SyntheticDiD(variance_method="placebo", n_bootstrap=50, seed=42) + result = est.fit( + sdid_survey_data, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=survey_design_weights, + ) + weights = result.unit_weights + # Effective weights should sum to 1 (renormalized) + assert sum(weights.values()) == pytest.approx(1.0, abs=1e-10) + # With non-uniform survey weights, effective weights should differ + # from what uniform survey weights would produce + sdid_survey_data_u = sdid_survey_data.copy() + sdid_survey_data_u["uniform_w"] = 1.0 + result_u = est.fit( + sdid_survey_data_u, + outcome="outcome", + treatment="treated", + unit="unit", + time="time", + post_periods=[6, 7, 8, 9], + survey_design=SurveyDesign(weights="uniform_w"), + ) + # Non-uniform weights should change the returned weight distribution + eff_vals = sorted(weights.values(), reverse=True) + uni_vals = sorted(result_u.unit_weights.values(), reverse=True) + assert eff_vals != pytest.approx(uni_vals, abs=1e-6) + + +# ============================================================================= +# TROP Survey Tests +# ============================================================================= + + +class TestTROPSurvey: + """Survey support tests for TROP (local and global methods).""" + + def test_smoke_local_weights_only(self, 
trop_survey_data, survey_design_weights): + """Local method completes with survey weights.""" + est = TROP(method="local", n_bootstrap=10, seed=42, max_iter=5) + result = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + survey_design=survey_design_weights, + ) + assert result.survey_metadata is not None + assert np.isfinite(result.att) + + def test_smoke_global_weights_only(self, trop_survey_data, survey_design_weights): + """Global method completes with survey weights.""" + est = TROP(method="global", n_bootstrap=10, seed=42, max_iter=5) + result = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + survey_design=survey_design_weights, + ) + assert result.survey_metadata is not None + assert np.isfinite(result.att) + + def test_uniform_weights_match_local(self, trop_survey_data): + """Uniform weights produce same ATT as unweighted (local).""" + trop_survey_data = trop_survey_data.copy() + trop_survey_data["uniform_w"] = 1.0 + + est = TROP(method="local", n_bootstrap=10, seed=42, max_iter=5) + result_no_survey = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + ) + result_uniform = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + survey_design=SurveyDesign(weights="uniform_w"), + ) + assert result_uniform.att == pytest.approx(result_no_survey.att, abs=1e-10) + + def test_uniform_weights_match_global(self, trop_survey_data): + """Uniform weights produce same ATT as unweighted (global).""" + trop_survey_data = trop_survey_data.copy() + trop_survey_data["uniform_w"] = 1.0 + + est = TROP(method="global", n_bootstrap=10, seed=42, max_iter=5) + result_no_survey = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + ) + result_uniform = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + 
survey_design=SurveyDesign(weights="uniform_w"), + ) + assert result_uniform.att == pytest.approx(result_no_survey.att, abs=1e-10) + + def test_survey_metadata_fields(self, trop_survey_data, survey_design_weights): + """Metadata has correct fields.""" + est = TROP(method="local", n_bootstrap=10, seed=42, max_iter=5) + result = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + survey_design=survey_design_weights, + ) + sm = result.survey_metadata + assert sm.weight_type == "pweight" + assert sm.effective_n > 0 + + def test_strata_psu_fpc_raises(self, trop_survey_data, survey_design_full): + """Full design raises NotImplementedError.""" + est = TROP(method="local", n_bootstrap=10, seed=42) + with pytest.raises(NotImplementedError, match="strata/PSU/FPC"): + est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + survey_design=survey_design_full, + ) + + def test_fweight_aweight_raises(self, trop_survey_data): + """Non-pweight raises ValueError.""" + est = TROP(method="local", n_bootstrap=10, seed=42) + sd = SurveyDesign(weights="weight", weight_type="aweight") + with pytest.raises(ValueError, match="pweight"): + est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + survey_design=sd, + ) + + def test_weighted_att_differs(self, trop_survey_data, survey_design_weights): + """Non-uniform weights change ATT.""" + est = TROP(method="local", n_bootstrap=10, seed=42, max_iter=5) + result_no = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + ) + result_survey = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + survey_design=survey_design_weights, + ) + assert result_survey.att != pytest.approx(result_no.att, abs=1e-6) + + def test_weighted_att_differs_global(self, trop_survey_data, survey_design_weights): + """Non-uniform weights change ATT for 
method='global'."""
+        est = TROP(method="global", n_bootstrap=10, seed=42, max_iter=5)
+        result_no = est.fit(
+            trop_survey_data,
+            outcome="outcome",
+            treatment="D",
+            unit="unit",
+            time="time",
+        )
+        result_survey = est.fit(
+            trop_survey_data,
+            outcome="outcome",
+            treatment="D",
+            unit="unit",
+            time="time",
+            survey_design=survey_design_weights,
+        )
+        assert result_survey.att != pytest.approx(result_no.att, abs=1e-6)
+
+    def test_summary_includes_survey(self, trop_survey_data, survey_design_weights):
+        """summary() contains survey section."""
+        est = TROP(method="local", n_bootstrap=10, seed=42, max_iter=5)
+        result = est.fit(
+            trop_survey_data,
+            outcome="outcome",
+            treatment="D",
+            unit="unit",
+            time="time",
+            survey_design=survey_design_weights,
+        )
+        summary = result.summary()
+        assert "Survey Design" in summary
+
+    def test_weight_scale_invariance(self, trop_survey_data, survey_design_weights):
+        """Scale invariance: 3x weights produce same ATT."""
+        trop_survey_data = trop_survey_data.copy()
+        trop_survey_data["weight_3x"] = trop_survey_data["weight"] * 3.0
+
+        est = TROP(method="local", n_bootstrap=10, seed=42, max_iter=5)
+        result_1x = est.fit(
+            trop_survey_data,
+            outcome="outcome",
+            treatment="D",
+            unit="unit",
+            time="time",
+            survey_design=survey_design_weights,
+        )
+        result_3x = est.fit(
+            trop_survey_data,
+            outcome="outcome",
+            treatment="D",
+            unit="unit",
+            time="time",
+            survey_design=SurveyDesign(weights="weight_3x"),
+        )
+        assert result_3x.att == pytest.approx(result_1x.att, rel=1e-6)
+
+    def test_unit_varying_survey_raises(self, trop_survey_data):
+        """Validation catches time-varying weights."""
+        trop_survey_data = trop_survey_data.copy()
+        trop_survey_data["bad_weight"] = trop_survey_data["weight"] + trop_survey_data["time"] * 0.1
+        est = TROP(method="local", n_bootstrap=10, seed=42)
+        with pytest.raises(ValueError):
+            est.fit(
+                trop_survey_data,
+                outcome="outcome",
+                treatment="D",
+                unit="unit",
+                time="time",
+                
survey_design=SurveyDesign(weights="bad_weight"), + ) + + def test_convenience_function_with_survey(self, trop_survey_data, survey_design_weights): + """trop() convenience function accepts survey_design.""" + result = trop( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + survey_design=survey_design_weights, + n_bootstrap=10, + seed=42, + max_iter=5, + ) + assert result.survey_metadata is not None + + def test_to_dict_includes_survey(self, trop_survey_data, survey_design_weights): + """to_dict() includes survey metadata fields.""" + est = TROP(method="local", n_bootstrap=10, seed=42, max_iter=5) + result = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + survey_design=survey_design_weights, + ) + d = result.to_dict() + assert "weight_type" in d + assert d["weight_type"] == "pweight" + + def test_local_bootstrap_nan_treated_outcomes(self, trop_survey_data): + """Bootstrap handles NaN treated outcomes without poisoning SE.""" + trop_survey_data = trop_survey_data.copy() + # Set some treated post-treatment outcomes to NaN + mask = (trop_survey_data["D"] == 1) & (trop_survey_data["time"] == 7) + trop_survey_data.loc[mask, "outcome"] = np.nan + + est = TROP(method="local", n_bootstrap=10, seed=42, max_iter=5) + result = est.fit( + trop_survey_data, + outcome="outcome", + treatment="D", + unit="unit", + time="time", + ) + # Point estimate should use finite cells only + assert np.isfinite(result.att) + # SE should remain finite (not poisoned by NaN) + assert np.isfinite(result.se) + + def test_local_bootstrap_nan_with_survey(self, trop_survey_data, survey_design_weights): + """Bootstrap + survey handles NaN treated outcomes correctly.""" + trop_survey_data = trop_survey_data.copy() + mask = (trop_survey_data["D"] == 1) & (trop_survey_data["time"] == 8) + trop_survey_data.loc[mask, "outcome"] = np.nan + + est = TROP(method="local", n_bootstrap=10, seed=42, max_iter=5) + result = 
est.fit(
+            trop_survey_data,
+            outcome="outcome",
+            treatment="D",
+            unit="unit",
+            time="time",
+            survey_design=survey_design_weights,
+        )
+        assert np.isfinite(result.att)
+        assert np.isfinite(result.se)
+
+
+# =============================================================================
+# Pinned Numerical Tests
+# =============================================================================
+
+
+class TestPinnedNumerical:
+    """Deterministic numerical tests for exact weighted formulas."""
+
+    def test_sdid_weighted_att_manual(self):
+        """Manual ATT check: survey-weighted treated means + ω∘w_co composition."""
+        # Tiny 3x3 balanced panel: 2 control, 1 treated, 2 pre + 1 post
+        np.random.seed(99)
+        data = pd.DataFrame(
+            {
+                "unit": [0, 0, 0, 1, 1, 1, 2, 2, 2],
+                "time": [0, 1, 2, 0, 1, 2, 0, 1, 2],
+                "outcome": [1.0, 2.0, 3.0, 2.0, 3.0, 4.5, 5.0, 6.0, 10.0],
+                "treated": [0, 0, 0, 0, 0, 0, 1, 1, 1],
+                "weight": [1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 2.0, 2.0, 2.0],
+            }
+        )
+        # Single treated unit → treated means are trivially that unit's outcomes
+        # (survey weight doesn't change a single-unit mean)
+        est = SyntheticDiD(variance_method="placebo", n_bootstrap=20, seed=42)
+        result = est.fit(
+            data,
+            outcome="outcome",
+            treatment="treated",
+            unit="unit",
+            time="time",
+            post_periods=[2],
+            survey_design=SurveyDesign(weights="weight"),
+        )
+        # Verify unit_weights sum to 1 (composed with survey)
+        assert sum(result.unit_weights.values()) == pytest.approx(1.0, abs=1e-10)
+        assert np.isfinite(result.att)
+
+    def test_trop_weighted_att_aggregation(self):
+        """Verify TROP ATT = weighted mean of tau values."""
+        # Create data where we can predict directional effect of weighting
+        np.random.seed(77)
+        n_units = 15
+        n_periods = 6
+        n_treated = 3
+
+        units = list(range(n_units))
+        periods = list(range(n_periods))
+
+        rows = []
+        for u in units:
+            is_treated = u < n_treated
+            base = u * 0.5
+            for t in periods:
+                y = base + 0.2 * t + np.random.randn() * 0.3
+                d = 1 if (is_treated 
and t >= 3) else 0 + if d == 1: + # Different effect per unit: unit 0 gets +1, unit 1 gets +3, unit 2 gets +5 + y += 1.0 + 2.0 * u + rows.append({"unit": u, "time": t, "outcome": y, "D": d}) + + data = pd.DataFrame(rows) + # Weight unit 2 (biggest effect) heavily + weights = np.ones(n_units) + weights[2] = 10.0 # unit 2 has effect ~5, heavily weighted + unit_map = {u: i for i, u in enumerate(units)} + data["weight"] = weights[data["unit"].map(unit_map).values] + + est_no = TROP(method="local", n_bootstrap=5, seed=42, max_iter=3) + result_no = est_no.fit(data, "outcome", "D", "unit", "time") + + est_w = TROP(method="local", n_bootstrap=5, seed=42, max_iter=3) + result_w = est_w.fit( + data, + "outcome", + "D", + "unit", + "time", + survey_design=SurveyDesign(weights="weight"), + ) + + # Weighted ATT should be pulled toward unit 2's larger effect + assert result_w.att > result_no.att + + def test_sdid_to_dict_schema_matches_did(self): + """SyntheticDiDResults.to_dict() survey fields match DiDResults schema.""" + np.random.seed(42) + data = pd.DataFrame( + { + "unit": [0, 0, 1, 1, 2, 2], + "time": [0, 1, 0, 1, 0, 1], + "outcome": [1.0, 2.0, 2.0, 3.0, 5.0, 8.0], + "treated": [0, 0, 0, 0, 1, 1], + "weight": [1.0, 1.0, 2.0, 2.0, 1.5, 1.5], + } + ) + est = SyntheticDiD(n_bootstrap=10, seed=42) + result = est.fit( + data, + "outcome", + "treated", + "unit", + "time", + post_periods=[1], + survey_design=SurveyDesign(weights="weight"), + ) + d = result.to_dict() + # Schema alignment: all these fields should be present + for key in [ + "weight_type", + "effective_n", + "design_effect", + "sum_weights", + "n_strata", + "n_psu", + "df_survey", + ]: + assert key in d, f"Missing key: {key}" diff --git a/tests/test_trop.py b/tests/test_trop.py index 0372f155..fd3762e1 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -2137,9 +2137,9 @@ def test_variance_estimation_uses_converted_params(self, simple_panel_data): original_fit_with_fixed = TROP._fit_with_fixed_lambda 
captured_lambda = [] - def tracking_fit(self, data, outcome, treatment, unit, time, fixed_lambda): + def tracking_fit(self, data, outcome, treatment, unit, time, fixed_lambda, **kwargs): captured_lambda.append(fixed_lambda) - return original_fit_with_fixed(self, data, outcome, treatment, unit, time, fixed_lambda) + return original_fit_with_fixed(self, data, outcome, treatment, unit, time, fixed_lambda, **kwargs) with patch.object(TROP, '_fit_with_fixed_lambda', tracking_fit): results = trop_est.fit(