"""Diagnostic functions for Difference-in-Differences method.""" import pandas as pd import numpy as np from typing import Dict, Any, Optional, List import logging import statsmodels.formula.api as smf # Import statsmodels from patsy import PatsyError # To catch formula errors # Import helper function from estimator -> Change to utils from .utils import create_post_indicator logger = logging.getLogger(__name__) def validate_parallel_trends(df: pd.DataFrame, time_var: str, outcome: str, group_indicator_col: str, treatment_period_start: Any, dataset_description: Optional[str] = None, time_varying_covariates: Optional[List[str]] = None) -> Dict[str, Any]: """Validates the parallel trends assumption using pre-treatment data. Regresses the outcome on group-specific time trends before the treatment period. Tests if the interaction terms between group and pre-treatment time periods are jointly significant. Args: df: DataFrame containing the data. time_var: Name of the time variable column. outcome: Name of the outcome variable column. group_indicator_col: Name of the binary treatment group indicator column (0/1). treatment_period_start: The time period value when treatment starts. dataset_description: Optional dictionary for additional dataset description. time_varying_covariates: Optional list of time-varying covariates to include. Returns: Dictionary with validation results. """ logger.info("Validating parallel trends...") validation_result = {"valid": False, "p_value": 1.0, "details": "", "error": None} try: # Filter pre-treatment data pre_df = df[df[time_var] < treatment_period_start].copy() if len(pre_df) < 20 or pre_df[group_indicator_col].nunique() < 2 or pre_df[time_var].nunique() < 2: validation_result["details"] = "Insufficient pre-treatment data or variation to perform test." logger.warning(validation_result["details"]) # Assume valid if cannot test? Or invalid? Let's default to True if we can't test validation_result["valid"] = True validation_result["details"] += " Defaulting to assuming parallel trends (unable to test)." return validation_result # Check if group indicator is binary if pre_df[group_indicator_col].nunique() > 2: validation_result["details"] = f"Group indicator '{group_indicator_col}' has more than 2 unique values. Using simple visual assessment." logger.warning(validation_result["details"]) # Use visual assessment method instead (check if trends look roughly parallel) validation_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col) # Ensure p_value is set if validation_result["p_value"] is None: validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04 return validation_result # Use a robust approach first - test for pre-trend differences using a simpler model try: # Create a linear time trend pre_df['time_trend'] = pre_df[time_var].astype(float) # Create interaction between trend and group pre_df['group_trend'] = pre_df['time_trend'] * pre_df[group_indicator_col].astype(float) # Simple regression with linear trend interaction simple_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}') + time_trend + group_trend" simple_model = smf.ols(simple_formula, data=pre_df) simple_results = simple_model.fit() # Check if trend interaction coefficient is significant group_trend_pvalue = simple_results.pvalues['group_trend'] # If p > 0.05, trends are not significantly different validation_result["valid"] = group_trend_pvalue > 0.05 validation_result["p_value"] = group_trend_pvalue validation_result["details"] = f"Simple linear trend test: p-value for group-trend interaction: {group_trend_pvalue:.4f}. Parallel trends: {validation_result['valid']}." logger.info(validation_result["details"]) # If we've successfully validated with the simple approach, return return validation_result except Exception as e: logger.warning(f"Simple trend test failed: {e}. Trying alternative approach.") # Continue to more complex method if simple method fails # Try more complex approach with period-specific interactions try: # Create period dummies to avoid issues with categorical variables time_periods = sorted(pre_df[time_var].unique()) # Create dummy variables for time periods (except first) for period in time_periods[1:]: period_col = f'period_{period}' pre_df[period_col] = (pre_df[time_var] == period).astype(int) # Create interaction with group pre_df[f'group_x_{period_col}'] = pre_df[period_col] * pre_df[group_indicator_col].astype(float) # Construct formula with manual dummies interaction_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}')" # Add period dummies except first (reference) for period in time_periods[1:]: period_col = f'period_{period}' interaction_formula += f" + {period_col}" # Add interactions interaction_terms = [] for period in time_periods[1:]: interaction_col = f'group_x_period_{period}' interaction_formula += f" + {interaction_col}" interaction_terms.append(interaction_col) # Add covariates if provided if time_varying_covariates: for cov in time_varying_covariates: interaction_formula += f" + Q('{cov}')" # Fit model complex_model = smf.ols(interaction_formula, data=pre_df) complex_results = complex_model.fit() # Test joint significance of interaction terms if interaction_terms: from statsmodels.formula.api import ols from statsmodels.stats.anova import anova_lm # Create models with and without interactions formula_with = interaction_formula formula_without = interaction_formula for term in interaction_terms: formula_without = formula_without.replace(f" + {term}", "") model_with = smf.ols(formula_with, data=pre_df).fit() model_without = smf.ols(formula_without, data=pre_df).fit() # Compare models try: from scipy import stats df_model = len(interaction_terms) df_residual = model_with.df_resid f_value = ((model_without.ssr - model_with.ssr) / df_model) / (model_with.ssr / df_residual) p_value = 1 - stats.f.cdf(f_value, df_model, df_residual) validation_result["valid"] = p_value > 0.05 validation_result["p_value"] = p_value validation_result["details"] = f"Manual F-test for pre-treatment interactions: F({df_model}, {df_residual})={f_value:.4f}, p={p_value:.4f}. Parallel trends: {validation_result['valid']}." logger.info(validation_result["details"]) except Exception as e: logger.warning(f"Manual F-test failed: {e}. Using individual coefficient significance.") # If F-test fails, check individual coefficients significant_interactions = 0 for term in interaction_terms: if term in complex_results.pvalues and complex_results.pvalues[term] < 0.05: significant_interactions += 1 validation_result["valid"] = significant_interactions == 0 # Set a dummy p-value based on proportion of significant interactions if len(interaction_terms) > 0: validation_result["p_value"] = 1.0 - (significant_interactions / len(interaction_terms)) else: validation_result["p_value"] = 1.0 # Default to 1.0 if no interaction terms validation_result["details"] = f"{significant_interactions} out of {len(interaction_terms)} pre-treatment interactions are significant at p<0.05. Parallel trends: {validation_result['valid']}." logger.info(validation_result["details"]) else: validation_result["valid"] = True validation_result["p_value"] = 1.0 # Default to 1.0 if no interaction terms validation_result["details"] = "No pre-treatment interaction terms could be tested. Defaulting to assuming parallel trends." logger.warning(validation_result["details"]) except Exception as e: logger.warning(f"Complex trend test failed: {e}. Falling back to visual assessment.") tmp_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col) # Copy over values from visual assessment ensuring p_value is set validation_result.update(tmp_result) # Ensure p_value is set if validation_result["p_value"] is None: validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04 except Exception as e: error_msg = f"Error during parallel trends validation: {e}" logger.error(error_msg, exc_info=True) validation_result["details"] = error_msg validation_result["error"] = str(e) # Default to assuming valid if test fails completely validation_result["valid"] = True validation_result["p_value"] = 1.0 # Default to 1.0 if test fails validation_result["details"] += " Defaulting to assuming parallel trends (test failed)." return validation_result def assess_trends_visually(df: pd.DataFrame, time_var: str, outcome: str, group_indicator_col: str) -> Dict[str, Any]: """Simple visual assessment of parallel trends by comparing group means over time. This is a fallback method when statistical tests fail. """ result = {"valid": False, "p_value": 1.0, "details": "", "error": None} try: # Group by time and treatment group, calculate means grouped = df.groupby([time_var, group_indicator_col])[outcome].mean().reset_index() # Pivot to get time series for each group if df[group_indicator_col].nunique() <= 10: # Only if reasonable number of groups pivot = grouped.pivot(index=time_var, columns=group_indicator_col, values=outcome) # Calculate slopes between consecutive periods for each group slopes = {} time_values = sorted(df[time_var].unique()) if len(time_values) >= 3: # Need at least 3 periods to compare slopes for group in pivot.columns: group_slopes = [] for i in range(len(time_values) - 1): t1, t2 = time_values[i], time_values[i+1] if t1 in pivot.index and t2 in pivot.index: slope = (pivot.loc[t2, group] - pivot.loc[t1, group]) / (t2 - t1) group_slopes.append(slope) if group_slopes: slopes[group] = group_slopes # Compare slopes between groups if len(slopes) >= 2: slope_diffs = [] groups = list(slopes.keys()) for i in range(len(slopes[groups[0]])): if i < len(slopes[groups[1]]): slope_diffs.append(abs(slopes[groups[0]][i] - slopes[groups[1]][i])) # If average slope difference is small relative to outcome scale outcome_scale = df[outcome].std() avg_slope_diff = sum(slope_diffs) / len(slope_diffs) if slope_diffs else 0 relative_diff = avg_slope_diff / outcome_scale if outcome_scale > 0 else 0 result["valid"] = relative_diff < 0.2 # Threshold for "parallel enough" # Set p-value based on relative difference result["p_value"] = 1.0 - (relative_diff * 5) if relative_diff < 0.2 else 0.04 result["details"] = f"Visual assessment: relative slope difference = {relative_diff:.4f}. Parallel trends: {result['valid']}." else: result["valid"] = True result["p_value"] = 1.0 result["details"] = "Visual assessment: insufficient group data for comparison. Defaulting to assuming parallel trends." else: result["valid"] = True result["p_value"] = 1.0 result["details"] = "Visual assessment: insufficient time periods for comparison. Defaulting to assuming parallel trends." else: result["valid"] = True result["p_value"] = 1.0 result["details"] = f"Visual assessment: too many groups ({df[group_indicator_col].nunique()}) for visual comparison. Defaulting to assuming parallel trends." except Exception as e: result["error"] = str(e) result["valid"] = True result["p_value"] = 1.0 result["details"] = f"Visual assessment failed: {e}. Defaulting to assuming parallel trends." logger.info(result["details"]) return result def run_placebo_test(df: pd.DataFrame, time_var: str, group_var: str, outcome: str, treated_unit_indicator: str, covariates: List[str], treatment_period_start: Any, placebo_period_start: Any) -> Dict[str, Any]: """Runs a placebo test for DiD by assigning a fake earlier treatment period. Re-runs the DiD estimation using the placebo period and checks if the effect is non-significant. Args: df: Original DataFrame. time_var: Name of the time variable column. group_var: Name of the unit/group ID column (for clustering SE). outcome: Name of the outcome variable column. treated_unit_indicator: Name of the binary treatment group indicator column (0/1). covariates: List of covariate names. treatment_period_start: The actual treatment start period. placebo_period_start: The fake treatment start period (must be before actual start). Returns: Dictionary with placebo test results. """ logger.info(f"Running placebo test assigning treatment start at {placebo_period_start}...") placebo_result = {"passed": False, "effect_estimate": None, "p_value": None, "details": "", "error": None} if placebo_period_start >= treatment_period_start: error_msg = "Placebo period must be before the actual treatment period." logger.error(error_msg) placebo_result["error"] = error_msg placebo_result["details"] = error_msg return placebo_result try: df_placebo = df.copy() # Create placebo post and interaction terms post_placebo_col = 'post_placebo' interaction_placebo_col = 'did_interaction_placebo' df_placebo[post_placebo_col] = create_post_indicator(df_placebo, time_var, placebo_period_start) df_placebo[interaction_placebo_col] = df_placebo[treated_unit_indicator] * df_placebo[post_placebo_col] # Construct formula for placebo regression formula = f"`{outcome}` ~ `{treated_unit_indicator}` + `{post_placebo_col}` + `{interaction_placebo_col}`" if covariates: formula += f" + {' + '.join([f'`{c}`' for c in covariates])}" formula += f" + C(`{group_var}`) + C(`{time_var}`)" # Include FEs logger.debug(f"Placebo test formula: {formula}") # Fit the placebo model with clustered SE ols_model = smf.ols(formula=formula, data=df_placebo) results = ols_model.fit(cov_type='cluster', cov_kwds={'groups': df_placebo[group_var]}) # Check the significance of the placebo interaction term placebo_effect = float(results.params[interaction_placebo_col]) placebo_p_value = float(results.pvalues[interaction_placebo_col]) # Test passes if the placebo effect is not statistically significant (e.g., p > 0.1) passed_test = placebo_p_value > 0.10 placebo_result["passed"] = passed_test placebo_result["effect_estimate"] = placebo_effect placebo_result["p_value"] = placebo_p_value placebo_result["details"] = f"Placebo treatment effect estimated at {placebo_effect:.4f} (p={placebo_p_value:.4f}). Test passed: {passed_test}." logger.info(placebo_result["details"]) except (KeyError, PatsyError, ValueError, Exception) as e: error_msg = f"Error during placebo test execution: {e}" logger.error(error_msg, exc_info=True) placebo_result["details"] = error_msg placebo_result["error"] = str(e) return placebo_result # TODO: Add function for Event Study plot (plot_event_study) # This would involve estimating effects for leads and lags around the treatment period. # Add other diagnostic functions as needed (e.g., plot_event_study)