Spaces:
Running
Running
"""
Difference-in-Differences estimator built on statsmodels OLS.

The DoWhy-based estimation path is currently commented out, so all
estimates come from statsmodels.
"""
import logging | |
import pandas as pd | |
import numpy as np | |
from typing import Dict, List, Optional, Any, Tuple | |
from auto_causal.config import get_llm_client # IMPORT LLM Client Factory | |
# DoWhy imports (Commented out for simplification) | |
# from dowhy import CausalModel | |
# from dowhy.causal_estimators import CausalEstimator | |
# from dowhy.causal_estimator import CausalEstimate | |
# Statsmodels import for estimation | |
import statsmodels.formula.api as smf | |
# Local imports | |
from .llm_assist import ( | |
identify_time_variable, | |
determine_treatment_period, | |
identify_treatment_group, | |
interpret_did_results | |
) | |
from .diagnostics import validate_parallel_trends # Import diagnostics | |
# Import from the new utils module | |
from .utils import create_post_indicator | |
# Module-level logger named after this module (standard `logging` hierarchy).
logger = logging.getLogger(name=__name__)

# --- Helper functions ---
def format_did_results(statsmodels_results: Any, interaction_term_key: str,
                       validation_results: Dict[str, Any],
                       method_details: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
    """Format a fitted statsmodels DiD regression into the standard results dictionary.

    Args:
        statsmodels_results: Fitted statsmodels results object exposing ``params``,
            ``bse``, ``pvalues``, ``conf_int()`` and ``summary()``.
        interaction_term_key: Name of the coefficient that holds the DiD estimate.
        validation_results: Diagnostics embedded under ``"diagnostics"``.
        method_details: Short description of the estimation method, stored under
            ``"method_details"``.
        parameters: Estimation parameters embedded under ``"parameters"``.

    Returns:
        Dict with ``effect_estimate``, ``standard_error``, ``p_value``,
        ``confidence_interval`` (``[lower, upper]``), ``diagnostics``,
        ``parameters``, ``method_details`` and ``details`` (rendered summary).
        Numeric fields are NaN when the coefficient cannot be extracted; extraction
        and summary-rendering failures are logged rather than raised.
    """
    try:
        effect = float(statsmodels_results.params[interaction_term_key])
        stderr = float(statsmodels_results.bse[interaction_term_key])
        pval = float(statsmodels_results.pvalues[interaction_term_key])
        ci = statsmodels_results.conf_int().loc[interaction_term_key].values.tolist()
        ci_lower, ci_upper = float(ci[0]), float(ci[1])
        logger.info(f"Extracted effect for '{interaction_term_key}'")
    except KeyError:
        # Listing the available names is best-effort; it must not mask the fallback.
        try:
            available = statsmodels_results.params.index.tolist()
        except Exception:
            available = "unavailable"
        logger.error(f"Interaction term '{interaction_term_key}' not found in statsmodels results. Available params: {available}")
        # Fallback to NaN if term not found
        effect = stderr = pval = ci_lower = ci_upper = np.nan
    except Exception as e:
        logger.error(f"Error extracting results from statsmodels object: {e}")
        effect = stderr = pval = ci_lower = ci_upper = np.nan

    # Rendering the summary can fail for the same malformed objects handled above,
    # so it is guarded too instead of defeating the NaN fallback.
    try:
        details = str(statsmodels_results.summary())
    except Exception as e:
        logger.error(f"Error rendering statsmodels summary: {e}")
        details = f"Could not render statsmodels summary: {e}"

    return {
        "effect_estimate": effect,
        "standard_error": stderr,
        "p_value": pval,
        "confidence_interval": [ci_lower, ci_upper],
        "diagnostics": validation_results,
        "parameters": parameters,
        "method_details": method_details,
        "details": details,
    }
# Comment out unused DoWhy result formatter | |
# def format_dowhy_results(estimate: CausalEstimate, | |
# validation_results: Dict[str, Any], | |
# parameters: Dict[str, Any]) -> Dict[str, Any]: | |
# '''Formats the DiD results from DoWhy causal estimate into a standard dictionary.''' | |
# try: | |
# # Extract values from DoWhy estimate | |
# effect = float(estimate.value) | |
# stderr = float(estimate.get_standard_error()) if hasattr(estimate, 'get_standard_error') else np.nan | |
# ci_lower, ci_upper = estimate.get_confidence_intervals() if hasattr(estimate, 'get_confidence_intervals') else (np.nan, np.nan) | |
# # Extract p-value if available, otherwise use NaN | |
# pval = estimate.get_significance_test_results().get('p_value', np.nan) if hasattr(estimate, 'get_significance_test_results') else np.nan | |
# # Get available details from estimate | |
# details = str(estimate) | |
# if hasattr(estimate, 'summary'): | |
# details = str(estimate.summary()) | |
# logger.info(f"Extracted effect from DoWhy estimate: {effect}") | |
# except Exception as e: | |
# logger.error(f"Error extracting results from DoWhy estimate: {e}") | |
# effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan | |
# details = f"Error extracting DoWhy results: {e}" | |
# # Create a standardized results dictionary | |
# results = { | |
# "effect_estimate": effect, | |
# "effect_se": stderr, | |
# "p_value": pval, | |
# "confidence_interval": [ci_lower, ci_upper], | |
# "diagnostics": validation_results, | |
# "parameters": parameters, | |
# "details": details, | |
# "estimator": "dowhy" | |
# } | |
# return results | |
# --- Private helpers for `estimate_effect` ---
def _resolve_setting(kwargs: Dict[str, Any], keys: Tuple[str, ...], resolver) -> Any:
    """Return the value of the first key in ``keys`` present in ``kwargs``, else ``resolver()``.

    ``resolver`` is a zero-argument callable (here an LLM-assisted identification
    helper). It is invoked only when the caller supplied none of ``keys``, so explicit
    overrides do not trigger an unnecessary LLM call. A key present with value
    ``None`` still counts as supplied (same as ``dict.get`` with that key present).
    """
    for key in keys:
        if key in kwargs:
            return kwargs[key]
    return resolver()


def _is_binary_01(series: Any) -> bool:
    """Return True if ``series`` is numeric/boolean and its non-null values are a subset of {0, 1}."""
    ptypes = pd.api.types
    if not (ptypes.is_numeric_dtype(series) or ptypes.is_bool_dtype(series)):
        return False
    return set(series.dropna().unique()).issubset({0, 1})


def _find_treated_group_indicator(df: pd.DataFrame, treatment: str, outcome: str,
                                  time_var: str, group_var: str,
                                  covariates: Optional[List[str]] = None) -> str:
    """Return the name of the binary (0/1) column that marks treated units.

    Search order:
      1. the ``treatment`` column itself;
      2. a column literally named ``'group'``;
      3. the first other binary column that is not the outcome, time variable,
         unit variable, a declared covariate, or a known post/interaction helper.
    If nothing qualifies, ``group_var`` is returned and an error is logged, since
    DiD estimates computed from it are very likely invalid.
    """
    if treatment in df.columns and _is_binary_01(df[treatment]):
        logger.info(f"Using the provided 'treatment' argument '{treatment}' as binary group indicator.")
        return treatment
    if 'group' in df.columns and _is_binary_01(df['group']):
        logger.info("Using column 'group' as binary group indicator.")
        return 'group'

    logger.warning(
        f"Provided 'treatment' arg '{treatment}' is not binary 0/1 and no binary 'group' column found. "
        "Searching other columns..."
    )
    # Declared covariates are excluded as well: a binary covariate is a control,
    # not the treated-group marker.
    excluded = {outcome, time_var, group_var, 'post', 'is_post_treatment', 'did_interaction'}
    excluded.update(covariates or [])
    for col_name in df.columns:
        if col_name in excluded:
            continue
        try:
            col_data = df[col_name]
            # Duplicate column labels make df[col] return a DataFrame; only a
            # single-column slice can be interpreted as one indicator.
            if isinstance(col_data, pd.DataFrame):
                if col_data.shape[1] != 1:
                    logger.warning(f"Skipping multi-column DataFrame slice for '{col_name}'.")
                    continue
                col_data = col_data.iloc[:, 0]
            if _is_binary_01(col_data):
                logger.info(f"Using column '{col_name}' found during search as binary group indicator.")
                return col_name
        except AttributeError as ae:
            logger.warning(f"Attribute error checking column '{col_name}': {ae}. Skipping.")
        except Exception as e:
            logger.warning(f"Unexpected error checking column '{col_name}' during group ID search: {e}")

    logger.error(
        f"CRITICAL WARNING: Could not find suitable binary treatment group indicator. Using '{group_var}', "
        "but this is likely incorrect and will produce invalid DiD estimates."
    )
    return group_var


def _build_did_formula(outcome: str, covariates: Optional[List[str]], num_time_periods: int,
                       treated_col: str, post_col: str, interaction_col: str,
                       group_var: str, time_var: str) -> Dict[str, Any]:
    """Build the OLS specification for the DiD regression.

    Two time periods -> 2x2 DiD ``outcome ~ treated * post`` (estimate on the
    ``treated:post`` coefficient). More periods -> TWFE
    ``outcome ~ interaction + C(unit) + C(time)`` (estimate on ``interaction``).
    Covariates that are already part of the core specification are not repeated.

    Returns:
        Dict with ``formula``, ``coef_name`` (coefficient holding the DiD estimate),
        the ``estimation_method`` / ``method_details`` labels, and ``columns``
        (variable names the formula references).
    """
    if num_time_periods == 2:
        logger.info(
            f"Number of unique time periods is 2. Using 2x2 DiD formula: "
            f"{outcome} ~ {treated_col} * {post_col}"
        )
        # ``A * B`` expands to A + B + A:B; the A:B coefficient is the DiD estimate.
        core_terms = [f"{treated_col} * {post_col}"]
        coef_name = f"{treated_col}:{post_col}"
        core_columns = [outcome, treated_col, post_col]
        estimation_method = "Statsmodels OLS for 2x2 DiD (Group * Post interaction)"
        method_details = "DiD via Statsmodels 2x2 (Group * Post interaction)"
    else:
        logger.info(
            f"Number of unique time periods is {num_time_periods} (>2). "
            f"Using TWFE DiD formula: {outcome} ~ {interaction_col} + C({group_var}) + C({time_var})"
        )
        # Unit and time fixed effects via C(); the treated*post column is the treatment regressor.
        core_terms = [interaction_col, f"C({group_var})", f"C({time_var})"]
        coef_name = interaction_col
        core_columns = [outcome, interaction_col, group_var, time_var]
        estimation_method = "Statsmodels OLS with TWFE (C() Notation)"
        method_details = "DiD via Statsmodels TWFE (C() Notation)"

    extra_covariates = [c for c in (covariates or []) if c not in core_columns]
    return {
        "formula": f"{outcome} ~ {' + '.join(core_terms + extra_covariates)}",
        "coef_name": coef_name,
        "estimation_method": estimation_method,
        "method_details": method_details,
        "columns": core_columns + extra_covariates,
    }


# --- Main `estimate_effect` function ---
def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str,
                    covariates: List[str],
                    dataset_description: Optional[str] = None,
                    query: Optional[str] = None,
                    **kwargs) -> Dict[str, Any]:
    """Estimate a Difference-in-Differences effect with statsmodels OLS.

    With exactly two distinct time periods a 2x2 specification
    ``outcome ~ group * post`` is fitted; with more periods a two-way fixed
    effects specification ``outcome ~ did_interaction + C(unit) + C(time)`` is
    used. Standard errors are clustered by the panel unit variable. (The DoWhy
    estimation path is currently disabled.)

    Args:
        df: Dataset containing the causal variables (not modified; a copy is used).
        treatment: Name of the treatment variable, ideally a binary 0/1 indicator
            of the treated group.
        outcome: Name of the outcome variable.
        covariates: Covariate names added to the regression (names already in the
            core specification are skipped).
        dataset_description: Optional description of the dataset, forwarded to the
            LLM-assisted helpers.
        query: Optional user query forwarded to the LLM-assisted helpers;
            ``kwargs['query_str']`` takes precedence when present.
        **kwargs: Optional overrides that skip LLM-assisted identification:
            ``time_variable`` (alias ``time_var``), ``group_variable`` (alias
            ``group_var``), ``treatment_period_start`` (alias ``treatment_period``),
            ``query_str`` and ``llm`` (client passed to the helpers).

    Returns:
        Dictionary produced by ``format_did_results`` plus ``"estimator"`` and
        ``"interpretation"`` keys.

    Raises:
        ValueError: If required variables cannot be identified or are missing from
            ``df``, if fewer than two time periods exist, or if estimation fails.
    """
    # Previously this unconditionally replaced the `query` argument; now
    # `query_str` only overrides it when actually supplied.
    query = kwargs.get('query_str', query)
    df_processed = df.copy()  # Work on a copy; helper columns are added below
    logger.info("Starting DiD estimation using Statsmodels OLS...")

    # Prefer an explicitly supplied LLM client; otherwise use the configured one.
    llm_instance = kwargs.get('llm')
    if llm_instance is None:
        llm_instance = get_llm_client()

    # --- Step 1: Identify key variables (explicit kwargs first, LLM-assisted fallback) ---
    time_var = _resolve_setting(
        kwargs, ('time_variable', 'time_var'),
        lambda: identify_time_variable(df_processed, query, dataset_description, llm=llm_instance))
    if time_var is None:
        raise ValueError("Time variable could not be identified for DiD.")
    if time_var not in df_processed.columns:
        raise ValueError(f"Identified time variable '{time_var}' not found in DataFrame.")

    # Panel unit variable, used for unit fixed effects and for clustering.
    group_var = _resolve_setting(
        kwargs, ('group_variable', 'group_var'),
        lambda: identify_treatment_group(df_processed, treatment, query, dataset_description, llm=llm_instance))
    if group_var is None:
        raise ValueError("Group/Unit variable could not be identified for DiD.")
    if group_var not in df_processed.columns:
        raise ValueError(f"Identified group/unit variable '{group_var}' not found in DataFrame.")

    if outcome not in df_processed.columns:
        raise ValueError(f"Outcome variable '{outcome}' not found in DataFrame.")

    treatment_period = _resolve_setting(
        kwargs, ('treatment_period_start', 'treatment_period'),
        lambda: determine_treatment_period(df_processed, time_var, treatment, query,
                                           dataset_description, llm=llm_instance))

    # --- Identify the binary treated-group indicator column ---
    treated_group_col = _find_treated_group_indicator(
        df_processed, treatment, outcome, time_var, group_var, covariates)
    if treated_group_col not in df_processed.columns:
        # Defensive: the search above only returns existing columns or group_var.
        raise ValueError(f"Determined treatment group column '{treated_group_col}' not found in DataFrame.")
    n_group_values = df_processed[treated_group_col].nunique(dropna=True)
    if n_group_values > 2:
        logger.warning(f"Selected treatment group column '{treated_group_col}' is not binary (has {n_group_values} unique values). DiD requires binary treatment group.")

    # --- Step 2: Create indicator variables ---
    post_indicator_col = 'post'
    if post_indicator_col not in df_processed.columns:
        df_processed[post_indicator_col] = create_post_indicator(df_processed, time_var, treatment_period)
    # Boolean columns are treated as categorical by the formula engine (coefficients
    # named e.g. 'post[T.True]'), which would make the coefficient lookup below miss
    # and silently return NaN. Convert boolean indicators to numeric 0/1.
    for indicator_col in (treated_group_col, post_indicator_col):
        col = df_processed[indicator_col]
        if pd.api.types.is_bool_dtype(col):
            df_processed[indicator_col] = col.astype(float) if col.isna().any() else col.astype(int)
    interaction_term_col = 'did_interaction'
    df_processed[interaction_term_col] = df_processed[treated_group_col] * df_processed[post_indicator_col]

    # --- Step 3: Validate parallel trends (using the group column) ---
    parallel_trends_validation = validate_parallel_trends(df_processed, time_var, outcome,
                                                          treated_group_col, treatment_period, dataset_description)
    if not parallel_trends_validation.get('valid', False):
        logger.warning("Parallel trends assumption potentially violated (based on placeholder check). Proceeding with estimation, but results may be biased.")

    # NOTE: a DoWhy-based estimation path previously lived here and is disabled;
    # estimation uses statsmodels OLS only.
    parameters = {
        "time_var": time_var,
        "group_var": group_var,  # Unit ID
        "treatment_indicator": treated_group_col,  # Group indicator used in the formula
        "post_indicator": post_indicator_col,
        "treatment_period_start": treatment_period,
        "covariates": covariates,
    }
    did_diagnostics = {
        "parallel_trends": parallel_trends_validation,
    }

    # --- Step 4: Statsmodels OLS ---
    logger.info("Determining Statsmodels OLS formula based on number of time periods...")
    num_time_periods = df_processed[time_var].nunique()
    if num_time_periods < 2:
        raise ValueError(
            f"DiD requires at least two distinct time periods in '{time_var}'; found {num_time_periods}."
        )
    spec = _build_did_formula(outcome, covariates, num_time_periods, treated_group_col,
                              post_indicator_col, interaction_term_col, group_var, time_var)
    formula = spec["formula"]
    interaction_term_key = spec["coef_name"]
    parameters["estimation_method"] = spec["estimation_method"]

    try:
        # The formula API drops rows with missing values, while the cluster groups
        # are passed separately; drop incomplete rows first so both stay aligned.
        na_subset = [c for c in dict.fromkeys(spec["columns"] + [group_var]) if c in df_processed.columns]
        df_reg = df_processed.dropna(subset=na_subset)
        n_dropped = len(df_processed) - len(df_reg)
        if n_dropped:
            logger.warning(f"Dropped {n_dropped} row(s) with missing values in model or clustering columns before estimation.")

        logger.info(f"Using formula: {formula}")
        logger.debug(f"Data head for statsmodels:\n{df_reg.head().to_string()}")
        logger.debug(f"Regression DataFrame shape: {df_reg.shape}, Columns: {df_reg.columns.tolist()}")
        ols_model = smf.ols(formula=formula, data=df_reg)
        logger.debug(f"Clustering standard errors by: {group_var}")
        results = ols_model.fit(cov_type='cluster', cov_kwds={'groups': df_reg[group_var]})
        logger.info("Statsmodels estimation complete.")
        logger.info(f"Statsmodels Results Summary:\n{results.summary()}")
        logger.debug(f"Extracting results using interaction term key: {interaction_term_key}")
        parameters["final_formula"] = formula
        parameters["interaction_term_coefficient_name"] = interaction_term_key
        formatted_results = format_did_results(results, interaction_term_key,
                                               did_diagnostics,
                                               method_details=spec["method_details"],
                                               parameters=parameters)
        formatted_results["estimator"] = "statsmodels"
    except Exception as e:
        logger.error(f"Statsmodels OLS estimation failed: {e}", exc_info=True)
        raise ValueError(f"DiD estimation failed (Statsmodels OLS): {e}") from e

    # --- Interpretation (best effort; failures do not discard the estimate) ---
    try:
        interpretation = interpret_did_results(formatted_results, did_diagnostics, dataset_description, llm=llm_instance)
        formatted_results['interpretation'] = interpretation
    except Exception as interp_e:
        logger.error(f"DiD Interpretation failed: {interp_e}")
        formatted_results['interpretation'] = "Interpretation failed."
    return formatted_results