"""Diagnostic functions for Difference-in-Differences method."""
import pandas as pd
import numpy as np
from typing import Dict, Any, Optional, List
import logging
import statsmodels.formula.api as smf
from patsy import PatsyError  # To catch formula construction errors
# Helper shared with the estimator module; assumed to return a 0/1 indicator
# marking periods at or after a given start period.
from .utils import create_post_indicator
logger = logging.getLogger(__name__)
def validate_parallel_trends(df: pd.DataFrame, time_var: str, outcome: str,
group_indicator_col: str, treatment_period_start: Any,
dataset_description: Optional[str] = None,
time_varying_covariates: Optional[List[str]] = None) -> Dict[str, Any]:
"""Validates the parallel trends assumption using pre-treatment data.
Regresses the outcome on group-specific time trends before the treatment period.
Tests if the interaction terms between group and pre-treatment time periods are jointly significant.
Args:
df: DataFrame containing the data.
time_var: Name of the time variable column.
outcome: Name of the outcome variable column.
group_indicator_col: Name of the binary treatment group indicator column (0/1).
treatment_period_start: The time period value when treatment starts.
        dataset_description: Optional text description of the dataset (currently unused here).
time_varying_covariates: Optional list of time-varying covariates to include.
    Returns:
        Dictionary with keys "valid", "p_value", "details", and "error".
"""
logger.info("Validating parallel trends...")
validation_result = {"valid": False, "p_value": 1.0, "details": "", "error": None}
try:
# Filter pre-treatment data
pre_df = df[df[time_var] < treatment_period_start].copy()
if len(pre_df) < 20 or pre_df[group_indicator_col].nunique() < 2 or pre_df[time_var].nunique() < 2:
validation_result["details"] = "Insufficient pre-treatment data or variation to perform test."
logger.warning(validation_result["details"])
            # Default to True when the assumption cannot be tested; callers should
            # treat this as "untested" rather than "confirmed"
            validation_result["valid"] = True
            validation_result["details"] += " Defaulting to assuming parallel trends (unable to test)."
return validation_result
# Check if group indicator is binary
if pre_df[group_indicator_col].nunique() > 2:
validation_result["details"] = f"Group indicator '{group_indicator_col}' has more than 2 unique values. Using simple visual assessment."
logger.warning(validation_result["details"])
# Use visual assessment method instead (check if trends look roughly parallel)
validation_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col)
            # Ensure p_value is set; 0.04 is a sentinel just below the 0.05
            # significance threshold, not an actual test statistic
            if validation_result["p_value"] is None:
                validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04
return validation_result
        # First try a simple, robust specification: a linear time trend interacted with the group indicator
try:
# Create a linear time trend
pre_df['time_trend'] = pre_df[time_var].astype(float)
# Create interaction between trend and group
pre_df['group_trend'] = pre_df['time_trend'] * pre_df[group_indicator_col].astype(float)
# Simple regression with linear trend interaction
simple_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}') + time_trend + group_trend"
simple_model = smf.ols(simple_formula, data=pre_df)
simple_results = simple_model.fit()
# Check if trend interaction coefficient is significant
group_trend_pvalue = simple_results.pvalues['group_trend']
# If p > 0.05, trends are not significantly different
validation_result["valid"] = group_trend_pvalue > 0.05
validation_result["p_value"] = group_trend_pvalue
validation_result["details"] = f"Simple linear trend test: p-value for group-trend interaction: {group_trend_pvalue:.4f}. Parallel trends: {validation_result['valid']}."
logger.info(validation_result["details"])
# If we've successfully validated with the simple approach, return
return validation_result
except Exception as e:
logger.warning(f"Simple trend test failed: {e}. Trying alternative approach.")
# Continue to more complex method if simple method fails
# Try more complex approach with period-specific interactions
try:
# Create period dummies to avoid issues with categorical variables
time_periods = sorted(pre_df[time_var].unique())
            # Create dummy variables for time periods (except the first, which is the reference).
            # The dummy names are used unquoted in the formula below, so this assumes
            # str(period) forms a valid identifier suffix (e.g., integer periods).
            for period in time_periods[1:]:
                period_col = f'period_{period}'
                pre_df[period_col] = (pre_df[time_var] == period).astype(int)
# Create interaction with group
pre_df[f'group_x_{period_col}'] = pre_df[period_col] * pre_df[group_indicator_col].astype(float)
# Construct formula with manual dummies
interaction_formula = f"Q('{outcome}') ~ Q('{group_indicator_col}')"
# Add period dummies except first (reference)
for period in time_periods[1:]:
period_col = f'period_{period}'
interaction_formula += f" + {period_col}"
# Add interactions
interaction_terms = []
for period in time_periods[1:]:
interaction_col = f'group_x_period_{period}'
interaction_formula += f" + {interaction_col}"
interaction_terms.append(interaction_col)
# Add covariates if provided
if time_varying_covariates:
for cov in time_varying_covariates:
interaction_formula += f" + Q('{cov}')"
# Fit model
complex_model = smf.ols(interaction_formula, data=pre_df)
complex_results = complex_model.fit()
# Test joint significance of interaction terms
if interaction_terms:
                # Build a restricted model with the interaction terms removed; the
                # unrestricted model (complex_results) was already fitted above
                formula_without = interaction_formula
                for term in interaction_terms:
                    formula_without = formula_without.replace(f" + {term}", "")
                model_with = complex_results
                model_without = smf.ols(formula_without, data=pre_df).fit()
# Compare models
try:
from scipy import stats
df_model = len(interaction_terms)
df_residual = model_with.df_resid
f_value = ((model_without.ssr - model_with.ssr) / df_model) / (model_with.ssr / df_residual)
                    p_value = stats.f.sf(f_value, df_model, df_residual)  # = 1 - cdf, numerically stabler
validation_result["valid"] = p_value > 0.05
validation_result["p_value"] = p_value
validation_result["details"] = f"Manual F-test for pre-treatment interactions: F({df_model}, {df_residual})={f_value:.4f}, p={p_value:.4f}. Parallel trends: {validation_result['valid']}."
logger.info(validation_result["details"])
except Exception as e:
logger.warning(f"Manual F-test failed: {e}. Using individual coefficient significance.")
# If F-test fails, check individual coefficients
significant_interactions = 0
for term in interaction_terms:
if term in complex_results.pvalues and complex_results.pvalues[term] < 0.05:
significant_interactions += 1
validation_result["valid"] = significant_interactions == 0
# Set a dummy p-value based on proportion of significant interactions
if len(interaction_terms) > 0:
validation_result["p_value"] = 1.0 - (significant_interactions / len(interaction_terms))
else:
validation_result["p_value"] = 1.0 # Default to 1.0 if no interaction terms
validation_result["details"] = f"{significant_interactions} out of {len(interaction_terms)} pre-treatment interactions are significant at p<0.05. Parallel trends: {validation_result['valid']}."
logger.info(validation_result["details"])
else:
validation_result["valid"] = True
validation_result["p_value"] = 1.0 # Default to 1.0 if no interaction terms
validation_result["details"] = "No pre-treatment interaction terms could be tested. Defaulting to assuming parallel trends."
logger.warning(validation_result["details"])
except Exception as e:
logger.warning(f"Complex trend test failed: {e}. Falling back to visual assessment.")
tmp_result = assess_trends_visually(pre_df, time_var, outcome, group_indicator_col)
# Copy over values from visual assessment ensuring p_value is set
validation_result.update(tmp_result)
            # Ensure p_value is set; 0.04 is again a below-threshold sentinel, not a real p-value
if validation_result["p_value"] is None:
validation_result["p_value"] = 1.0 if validation_result["valid"] else 0.04
except Exception as e:
error_msg = f"Error during parallel trends validation: {e}"
logger.error(error_msg, exc_info=True)
validation_result["details"] = error_msg
validation_result["error"] = str(e)
# Default to assuming valid if test fails completely
validation_result["valid"] = True
validation_result["p_value"] = 1.0 # Default to 1.0 if test fails
validation_result["details"] += " Defaulting to assuming parallel trends (test failed)."
return validation_result
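
# Usage sketch (illustrative only; the frame and the column names "unit",
# "period", "treated", and "y" are hypothetical, not part of this module):
#
#   rng = np.random.default_rng(0)
#   demo = pd.DataFrame({"unit": np.repeat(np.arange(40), 6),
#                        "period": np.tile(np.arange(6), 40)})
#   demo["treated"] = (demo["unit"] < 20).astype(int)
#   demo["y"] = 0.5 * demo["period"] + 0.2 * demo["treated"] + rng.normal(size=len(demo))
#   res = validate_parallel_trends(demo, time_var="period", outcome="y",
#                                  group_indicator_col="treated",
#                                  treatment_period_start=4)
#   # res["valid"] should be True: both groups share the same slope by construction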
def assess_trends_visually(df: pd.DataFrame, time_var: str, outcome: str,
group_indicator_col: str) -> Dict[str, Any]:
"""Simple visual assessment of parallel trends by comparing group means over time.
This is a fallback method when statistical tests fail.
"""
result = {"valid": False, "p_value": 1.0, "details": "", "error": None}
try:
# Group by time and treatment group, calculate means
grouped = df.groupby([time_var, group_indicator_col])[outcome].mean().reset_index()
# Pivot to get time series for each group
if df[group_indicator_col].nunique() <= 10: # Only if reasonable number of groups
pivot = grouped.pivot(index=time_var, columns=group_indicator_col, values=outcome)
# Calculate slopes between consecutive periods for each group
slopes = {}
time_values = sorted(df[time_var].unique())
if len(time_values) >= 3: # Need at least 3 periods to compare slopes
for group in pivot.columns:
group_slopes = []
for i in range(len(time_values) - 1):
t1, t2 = time_values[i], time_values[i+1]
if t1 in pivot.index and t2 in pivot.index:
slope = (pivot.loc[t2, group] - pivot.loc[t1, group]) / (t2 - t1)
group_slopes.append(slope)
if group_slopes:
slopes[group] = group_slopes
# Compare slopes between groups
if len(slopes) >= 2:
slope_diffs = []
groups = list(slopes.keys())
                    # Compare the first two groups only; any additional groups are ignored here
                    for i in range(len(slopes[groups[0]])):
                        if i < len(slopes[groups[1]]):
                            slope_diffs.append(abs(slopes[groups[0]][i] - slopes[groups[1]][i]))
# If average slope difference is small relative to outcome scale
outcome_scale = df[outcome].std()
avg_slope_diff = sum(slope_diffs) / len(slope_diffs) if slope_diffs else 0
relative_diff = avg_slope_diff / outcome_scale if outcome_scale > 0 else 0
result["valid"] = relative_diff < 0.2 # Threshold for "parallel enough"
                    # Map the relative difference onto a pseudo p-value (a heuristic, not a test statistic)
                    result["p_value"] = 1.0 - (relative_diff * 5) if relative_diff < 0.2 else 0.04
result["details"] = f"Visual assessment: relative slope difference = {relative_diff:.4f}. Parallel trends: {result['valid']}."
else:
result["valid"] = True
result["p_value"] = 1.0
result["details"] = "Visual assessment: insufficient group data for comparison. Defaulting to assuming parallel trends."
else:
result["valid"] = True
result["p_value"] = 1.0
result["details"] = "Visual assessment: insufficient time periods for comparison. Defaulting to assuming parallel trends."
else:
result["valid"] = True
result["p_value"] = 1.0
result["details"] = f"Visual assessment: too many groups ({df[group_indicator_col].nunique()}) for visual comparison. Defaulting to assuming parallel trends."
except Exception as e:
result["error"] = str(e)
result["valid"] = True
result["p_value"] = 1.0
result["details"] = f"Visual assessment failed: {e}. Defaulting to assuming parallel trends."
logger.info(result["details"])
return result
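
# Fallback usage sketch (reusing the hypothetical "demo" frame from the sketch
# above; in practice this helper is called internally on pre-treatment data):
#
#   pre = demo[demo["period"] < 4]
#   visual = assess_trends_visually(pre, "period", "y", "treated")
#   print(visual["valid"], visual["details"])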
def run_placebo_test(df: pd.DataFrame, time_var: str, group_var: str, outcome: str,
treated_unit_indicator: str, covariates: List[str],
treatment_period_start: Any,
placebo_period_start: Any) -> Dict[str, Any]:
"""Runs a placebo test for DiD by assigning a fake earlier treatment period.
Re-runs the DiD estimation using the placebo period and checks if the effect is non-significant.
Args:
df: Original DataFrame.
time_var: Name of the time variable column.
group_var: Name of the unit/group ID column (for clustering SE).
outcome: Name of the outcome variable column.
treated_unit_indicator: Name of the binary treatment group indicator column (0/1).
covariates: List of covariate names.
treatment_period_start: The actual treatment start period.
placebo_period_start: The fake treatment start period (must be before actual start).
    Returns:
        Dictionary with keys "passed", "effect_estimate", "p_value", "details", and "error".
"""
logger.info(f"Running placebo test assigning treatment start at {placebo_period_start}...")
placebo_result = {"passed": False, "effect_estimate": None, "p_value": None, "details": "", "error": None}
if placebo_period_start >= treatment_period_start:
error_msg = "Placebo period must be before the actual treatment period."
logger.error(error_msg)
placebo_result["error"] = error_msg
placebo_result["details"] = error_msg
return placebo_result
try:
df_placebo = df.copy()
# Create placebo post and interaction terms
post_placebo_col = 'post_placebo'
interaction_placebo_col = 'did_interaction_placebo'
df_placebo[post_placebo_col] = create_post_indicator(df_placebo, time_var, placebo_period_start)
df_placebo[interaction_placebo_col] = df_placebo[treated_unit_indicator] * df_placebo[post_placebo_col]
        # Construct the placebo regression formula. Patsy has no backtick quoting,
        # so Q() is used for column names that may not be valid identifiers.
        formula = f"Q('{outcome}') ~ Q('{treated_unit_indicator}') + {post_placebo_col} + {interaction_placebo_col}"
        if covariates:
            formula += " + " + " + ".join([f"Q('{c}')" for c in covariates])
        # Unit and time fixed effects; these absorb the group and post main effects,
        # but the placebo interaction remains identified
        formula += f" + C(Q('{group_var}')) + C(Q('{time_var}'))"
logger.debug(f"Placebo test formula: {formula}")
# Fit the placebo model with clustered SE
ols_model = smf.ols(formula=formula, data=df_placebo)
results = ols_model.fit(cov_type='cluster', cov_kwds={'groups': df_placebo[group_var]})
# Check the significance of the placebo interaction term
placebo_effect = float(results.params[interaction_placebo_col])
placebo_p_value = float(results.pvalues[interaction_placebo_col])
# Test passes if the placebo effect is not statistically significant (e.g., p > 0.1)
passed_test = placebo_p_value > 0.10
placebo_result["passed"] = passed_test
placebo_result["effect_estimate"] = placebo_effect
placebo_result["p_value"] = placebo_p_value
placebo_result["details"] = f"Placebo treatment effect estimated at {placebo_effect:.4f} (p={placebo_p_value:.4f}). Test passed: {passed_test}."
logger.info(placebo_result["details"])
    except Exception as e:  # Exception already subsumes KeyError, PatsyError, and ValueError
error_msg = f"Error during placebo test execution: {e}"
logger.error(error_msg, exc_info=True)
placebo_result["details"] = error_msg
placebo_result["error"] = str(e)
return placebo_result
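
# Placebo usage sketch (same hypothetical columns; the fake start at period 2
# precedes the real start at period 4, as the function requires):
#
#   placebo = run_placebo_test(demo, time_var="period", group_var="unit",
#                              outcome="y", treated_unit_indicator="treated",
#                              covariates=[], treatment_period_start=4,
#                              placebo_period_start=2)
#   print(placebo["passed"], placebo["p_value"])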
# TODO: Add an event-study plot (plot_event_study), estimating effects for leads
# and lags around the treatment period, and other diagnostics as needed.
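
# A minimal event-study sketch along the lines of the TODO above (lead/lag
# dummies interacted with the treated indicator, event time -1 omitted as the
# reference; illustrative only, not the planned plot_event_study API):
#
#   df_es = demo.copy()
#   df_es["event_time"] = df_es["period"] - 4      # 4 = treatment start
#   for k in sorted(df_es["event_time"].unique()):
#       if k == -1:                                # reference period
#           continue
#       name = f"lag_{k}" if k >= 0 else f"lead_{-k}"
#       df_es[name] = ((df_es["event_time"] == k) * df_es["treated"]).astype(int)
#   # Regress the outcome on these dummies plus unit and time fixed effects,
#   # then plot the estimated coefficients (with confidence intervals) against
#   # event time; effects at negative event times should be near zero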