""" Difference-in-Differences Estimator using DoWhy with Statsmodels fallback. """ import logging import pandas as pd import numpy as np from typing import Dict, List, Optional, Any, Tuple from auto_causal.config import get_llm_client # IMPORT LLM Client Factory # DoWhy imports (Commented out for simplification) # from dowhy import CausalModel # from dowhy.causal_estimators import CausalEstimator # from dowhy.causal_estimator import CausalEstimate # Statsmodels import for estimation import statsmodels.formula.api as smf # Local imports from .llm_assist import ( identify_time_variable, determine_treatment_period, identify_treatment_group, interpret_did_results ) from .diagnostics import validate_parallel_trends # Import diagnostics # Import from the new utils module from .utils import create_post_indicator logger = logging.getLogger(__name__) # --- Helper functions moved from old file --- def format_did_results(statsmodels_results: Any, interaction_term_key: str, validation_results: Dict[str, Any], method_details: str, parameters: Dict[str, Any]) -> Dict[str, Any]: '''Formats the DiD results from statsmodels results into a standard dictionary.''' try: # Use the interaction_term_key passed directly effect = float(statsmodels_results.params[interaction_term_key]) stderr = float(statsmodels_results.bse[interaction_term_key]) pval = float(statsmodels_results.pvalues[interaction_term_key]) ci = statsmodels_results.conf_int().loc[interaction_term_key].values.tolist() ci_lower, ci_upper = float(ci[0]), float(ci[1]) logger.info(f"Extracted effect for '{interaction_term_key}'") except KeyError: logger.error(f"Interaction term '{interaction_term_key}' not found in statsmodels results. Available params: {statsmodels_results.params.index.tolist()}") # Fallback to NaN if term not found effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan except Exception as e: logger.error(f"Error extracting results from statsmodels object: {e}") effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan # Create a standardized results dictionary results = { "effect_estimate": effect, "standard_error": stderr, "p_value": pval, "confidence_interval": [ci_lower, ci_upper], "diagnostics": validation_results, "parameters": parameters, "details": str(statsmodels_results.summary()) } return results # Comment out unused DoWhy result formatter # def format_dowhy_results(estimate: CausalEstimate, # validation_results: Dict[str, Any], # parameters: Dict[str, Any]) -> Dict[str, Any]: # '''Formats the DiD results from DoWhy causal estimate into a standard dictionary.''' # try: # # Extract values from DoWhy estimate # effect = float(estimate.value) # stderr = float(estimate.get_standard_error()) if hasattr(estimate, 'get_standard_error') else np.nan # ci_lower, ci_upper = estimate.get_confidence_intervals() if hasattr(estimate, 'get_confidence_intervals') else (np.nan, np.nan) # # Extract p-value if available, otherwise use NaN # pval = estimate.get_significance_test_results().get('p_value', np.nan) if hasattr(estimate, 'get_significance_test_results') else np.nan # # Get available details from estimate # details = str(estimate) # if hasattr(estimate, 'summary'): # details = str(estimate.summary()) # logger.info(f"Extracted effect from DoWhy estimate: {effect}") # except Exception as e: # logger.error(f"Error extracting results from DoWhy estimate: {e}") # effect, stderr, pval, ci_lower, ci_upper = np.nan, np.nan, np.nan, np.nan, np.nan # details = f"Error extracting DoWhy results: 
{e}" # # Create a standardized results dictionary # results = { # "effect_estimate": effect, # "effect_se": stderr, # "p_value": pval, # "confidence_interval": [ci_lower, ci_upper], # "diagnostics": validation_results, # "parameters": parameters, # "details": details, # "estimator": "dowhy" # } # return results # --- Main `estimate_effect` function --- def estimate_effect(df: pd.DataFrame, treatment: str, outcome: str, covariates: List[str], dataset_description: Optional[str] = None, query: Optional[str] = None, **kwargs) -> Dict[str, Any]: """Difference-in-Differences estimation using DoWhy with Statsmodels fallback. Args: df: Dataset containing causal variables treatment: Name of treatment variable (or variable indicating treated group) outcome: Name of outcome variable covariates: List of covariate names dataset_description: Optional dictionary describing the dataset **kwargs: Method-specific parameters (e.g., time_var, group_var, query, llm instance if needed) Returns: Dictionary with effect estimate and diagnostics """ query = kwargs.get('query_str') # llm_instance = kwargs.get('llm') # Pass llm if helpers need it df_processed = df.copy() # Work on a copy logger.info("Starting DiD estimation using DoWhy with Statsmodels fallback...") # --- Step 1: Identify Key Variables (using LLM Assist placeholders) --- # Pass llm_instance to helpers if they are implemented to use it llm_instance = get_llm_client() # Get llm instance if passed time_var = kwargs.get('time_variable', identify_time_variable(df_processed, query, dataset_description, llm=llm_instance)) if time_var is None: raise ValueError("Time variable could not be identified for DiD.") if time_var not in df_processed.columns: raise ValueError(f"Identified time variable '{time_var}' not found in DataFrame.") # Determine the variable that identifies the panel unit (for grouping/FE) group_var = kwargs.get('group_variable', identify_treatment_group(df_processed, treatment, query, dataset_description, llm=llm_instance)) if group_var is None: raise ValueError("Group/Unit variable could not be identified for DiD.") if group_var not in df_processed.columns: raise ValueError(f"Identified group/unit variable '{group_var}' not found in DataFrame.") # Check outcome exists before proceeding further if outcome not in df_processed.columns: raise ValueError(f"Outcome variable '{outcome}' not found in DataFrame.") # Determine treatment period start treatment_period = kwargs.get('treatment_period_start', kwargs.get('treatment_period', determine_treatment_period(df_processed, time_var, treatment, query, dataset_description, llm=llm_instance))) # --- Identify the TRUE binary treatment group indicator column --- treated_group_col_for_formula = None # Priority 1: Check if the 'treatment' argument itself is a valid binary indicator if treatment in df_processed.columns and pd.api.types.is_numeric_dtype(df_processed[treatment]): unique_treat_vals = set(df_processed[treatment].dropna().unique()) if unique_treat_vals.issubset({0, 1}): treated_group_col_for_formula = treatment logger.info(f"Using the provided 'treatment' argument '{treatment}' as binary group indicator.") # Priority 2: Check if a column explicitly named 'group' exists and is binary if treated_group_col_for_formula is None and 'group' in df_processed.columns and pd.api.types.is_numeric_dtype(df_processed['group']): unique_group_vals = set(df_processed['group'].dropna().unique()) if unique_group_vals.issubset({0, 1}): treated_group_col_for_formula = 'group' logger.info(f"Using column 'group' as 

    # Priority 3: Fallback - Search other columns (excluding known roles and time-related ones)
    if treated_group_col_for_formula is None:
        logger.warning(f"Provided 'treatment' arg '{treatment}' is not binary 0/1 and no 'group' column found. "
                       f"Searching other columns...")
        potential_group_cols = []
        # Exclude outcome, time var, unit ID var, and common time indicators like 'post'
        excluded_cols = [outcome, time_var, group_var, 'post', 'is_post_treatment', 'did_interaction']
        for col_name in df_processed.columns:
            if col_name in excluded_cols:
                continue
            try:
                col_data = df_processed[col_name]
                # Ensure we are working with a Series
                if isinstance(col_data, pd.DataFrame):
                    if col_data.shape[1] == 1:
                        col_data = col_data.iloc[:, 0]  # Extract the Series
                    else:
                        logger.warning(f"Skipping multi-column DataFrame slice for '{col_name}'.")
                        continue
                # Check if the Series can be interpreted as binary 0/1
                if not pd.api.types.is_numeric_dtype(col_data) and not pd.api.types.is_bool_dtype(col_data):
                    continue  # Skip non-numeric/non-boolean columns
                unique_vals = set(col_data.dropna().unique())
                # Simplified check: directly test if unique values are a subset of {0, 1}
                if unique_vals.issubset({0, 1}):
                    logger.info(f"  Found potential binary indicator: {col_name}")
                    potential_group_cols.append(col_name)
            except AttributeError as ae:
                # Catch attribute errors likely due to unexpected types
                logger.warning(f"Attribute error checking column '{col_name}': {ae}. Skipping.")
            except Exception as e:
                logger.warning(f"Unexpected error checking column '{col_name}' during group ID search: {e}")

        if potential_group_cols:
            treated_group_col_for_formula = potential_group_cols[0]  # Take the first suitable one found
            logger.info(f"Using column '{treated_group_col_for_formula}' found during search as binary group indicator.")
        else:
            # Final fallback: Use the originally identified group_var, but warn heavily
            treated_group_col_for_formula = group_var
            logger.error(f"CRITICAL WARNING: Could not find suitable binary treatment group indicator. "
                         f"Using '{group_var}', but this is likely incorrect and will produce invalid DiD estimates.")

    # --- Final Check ---
    if treated_group_col_for_formula not in df_processed.columns:
        # This case should ideally not happen with the logic above but added defensively
        raise ValueError(f"Determined treatment group column '{treated_group_col_for_formula}' not found in DataFrame.")
    if df_processed[treated_group_col_for_formula].nunique(dropna=True) > 2:
        logger.warning(f"Selected treatment group column '{treated_group_col_for_formula}' is not binary "
                       f"(has {df_processed[treated_group_col_for_formula].nunique()} unique values). "
                       f"DiD requires binary treatment group.")

    # --- Step 2: Create Indicator Variables ---
    post_indicator_col = 'post'
    if post_indicator_col not in df_processed.columns:
        # Create the post indicator if it doesn't exist
        df_processed[post_indicator_col] = create_post_indicator(df_processed, time_var, treatment_period)

    # Interaction term is treatment group * post
    interaction_term_col = 'did_interaction'  # Keep explicit interaction term
    df_processed[interaction_term_col] = df_processed[treated_group_col_for_formula] * df_processed[post_indicator_col]

    # --- Step 3: Validate Parallel Trends (using the group column) ---
    parallel_trends_validation = validate_parallel_trends(df_processed, time_var, outcome,
                                                          treated_group_col_for_formula, treatment_period,
                                                          dataset_description)
    # Note: The validation result is currently just a placeholder
    if not parallel_trends_validation.get('valid', False):
        logger.warning("Parallel trends assumption potentially violated (based on placeholder check). "
                       "Proceeding with estimation, but results may be biased.")
        # Add this info to the final results diagnostics

    # --- Step 4: Prepare for Statsmodels Estimation ---
    # (DoWhy section commented out for simplicity)
    # all_common_causes = covariates + [time_var, group_var]  # group_var is unit ID
    # use_dowhy_estimate = False
    # dowhy_estimate = None
    # try:
    #     # Create DoWhy CausalModel
    #     model = CausalModel(
    #         data=df_processed,
    #         treatment=treated_group_col_for_formula,  # Use group indicator here
    #         outcome=outcome,
    #         common_causes=all_common_causes,
    #     )
    #     logger.info("DoWhy CausalModel created for DiD estimation.")
    #     # Identify estimand
    #     identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
    #     logger.info(f"DoWhy identified estimand: {identified_estimand.estimand_type}")
    #     # Try to estimate using DiD estimator if available in DoWhy
    #     try:
    #         logger.info("Attempting to use DoWhy's DiD estimator...")
    #         # Debug info - print DataFrame info to help diagnose possible issues
    #         logger.debug(f"DataFrame shape before DoWhy DiD: {df_processed.shape}")
    #         # ... (rest of DoWhy debug logs commented out) ...
    #         # Create params dictionary for DoWhy DiD estimator
    #         did_params = {
    #             'time_var': time_var,
    #             'treatment_period': treatment_period,
    #             'unit_var': group_var
    #         }
    #         # Add control variables if available
    #         if covariates:
    #             did_params['control_vars'] = covariates
    #         logger.debug(f"DoWhy DiD params: {did_params}")
    #         # Try to use DiD estimator from DoWhy (requires recent version of DoWhy)
    #         if hasattr(model, 'estimate_effect'):
    #             try:
    #                 # First check if difference_in_differences method is available
    #                 available_methods = model.get_available_effect_estimators() if hasattr(model, 'get_available_effect_estimators') else []
    #                 logger.debug(f"Available DoWhy estimators: {available_methods}")
    #                 if "difference_in_differences" not in str(available_methods):
    #                     logger.warning("'difference_in_differences' estimator not found in available DoWhy estimators. Falling back to statsmodels.")
    #                 else:
    #                     # Try the estimation with more error handling
    #                     logger.info("Calling DoWhy DiD estimator...")
    #                     estimate = model.estimate_effect(
    #                         identified_estimand,
    #                         method_name="difference_in_differences",
    #                         method_params=did_params
    #                     )
    #                     if estimate:
    #                         # Extra check to verify estimate has expected attributes
    #                         if hasattr(estimate, 'value') and not pd.isna(estimate.value):
    #                             dowhy_estimate = estimate
    #                             use_dowhy_estimate = True
    #                             logger.info(f"Successfully used DoWhy's DiD estimator. Effect estimate: {estimate.value}")
    #                         else:
    #                             logger.warning(f"DoWhy's DiD estimator returned invalid estimate: {estimate}. Falling back to statsmodels.")
    #                     else:
    #                         logger.warning("DoWhy's DiD estimator returned None. Falling back to statsmodels.")
    #             except IndexError as idx_err:
    #                 # Handle specific IndexError that's occurring
    #                 logger.error(f"IndexError in DoWhy DiD estimator: {idx_err}. Check input data structure.")
    #                 # Trace more details about the error
    #                 import traceback
    #                 logger.error(f"Error traceback: {traceback.format_exc()}")
    #                 logger.warning("Falling back to statsmodels due to IndexError in DoWhy.")
    #         else:
    #             logger.warning("DoWhy model does not have estimate_effect method. Falling back to statsmodels.")
    #     except (ImportError, AttributeError) as e:
    #         logger.warning(f"DoWhy DiD estimator not available or not implemented: {e}. Falling back to statsmodels.")
    #     except ValueError as ve:
    #         logger.error(f"ValueError in DoWhy DiD estimator: {ve}. Likely issue with data formatting. Falling back to statsmodels.")
    #     except Exception as e:
    #         logger.error(f"Error using DoWhy's DiD estimator: {e}. Falling back to statsmodels.")
    #         # Add traceback for better debugging
    #         import traceback
    #         logger.error(f"Full error traceback: {traceback.format_exc()}")
    # except Exception as e:
    #     logger.error(f"Failed to create DoWhy CausalModel: {e}", exc_info=True)
    #     # model = None  # Set model to None if creation fails

    # Create parameters dictionary for formatting results
    parameters = {
        "time_var": time_var,
        "group_var": group_var,  # Unit ID
        "treatment_indicator": treated_group_col_for_formula,  # Group indicator used in formula basis
        "post_indicator": post_indicator_col,
        "treatment_period_start": treatment_period,
        "covariates": covariates,
    }

    # Group diagnostics for formatting
    did_diagnostics = {
        "parallel_trends": parallel_trends_validation,
        # "placebo_test": run_placebo_test(...)
    }

    # If DoWhy estimation was successful, use those results (Section Commented Out)
    # if use_dowhy_estimate and dowhy_estimate:
    #     logger.info("Using DoWhy DiD estimation results.")
    #     parameters["estimation_method"] = "DoWhy Difference-in-Differences"
    #     # Format the results
    #     formatted_results = format_dowhy_results(dowhy_estimate, did_diagnostics, parameters)
    # else:

    # --- Step 5: Use Statsmodels OLS ---
    logger.info("Determining Statsmodels OLS formula based on number of time periods...")
    num_time_periods = df_processed[time_var].nunique()

    interaction_term_key_for_results: str
    method_details_str: str
    formula: str

    if num_time_periods == 2:
        logger.info(
            f"Number of unique time periods is 2. Using 2x2 DiD formula: "
            f"{outcome} ~ {treated_group_col_for_formula} * {post_indicator_col}"
        )
        # For 2x2 DiD: outcome ~ group * post_indicator
        # The interaction term A:B in statsmodels gives the DiD estimate.
        formula_core = f"{treated_group_col_for_formula} * {post_indicator_col}"
        interaction_term_key_for_results = f"{treated_group_col_for_formula}:{post_indicator_col}"
        formula_parts = [formula_core]
        main_model_terms = {outcome, treated_group_col_for_formula, post_indicator_col}
        if covariates:
            filtered_covs = [c for c in covariates if c not in main_model_terms]
            if filtered_covs:
                formula_parts.extend(filtered_covs)
        formula = f"{outcome} ~ {' + '.join(formula_parts)}"
        parameters["estimation_method"] = "Statsmodels OLS for 2x2 DiD (Group * Post interaction)"
        method_details_str = "DiD via Statsmodels 2x2 (Group * Post interaction)"
    else:  # num_time_periods > 2
        logger.info(
            f"Number of unique time periods is {num_time_periods} (>2). "
" f"Using TWFE DiD formula: {outcome} ~ {interaction_term_col} + C({group_var}) + C({time_var})" ) # For TWFE: outcome ~ actual_treatment_variable + UnitFE + TimeFE # actual_treatment_variable is interaction_term_col (e.g., treated_group * post_indicator) # UnitFE is C(group_var), TimeFE is C(time_var) formula_parts = [ interaction_term_col, f"C({group_var})", f"C({time_var})" ] interaction_term_key_for_results = interaction_term_col main_model_terms = {outcome, interaction_term_col, group_var, time_var} if covariates: filtered_covs = [ c for c in covariates if c not in main_model_terms ] if filtered_covs: formula_parts.extend(filtered_covs) formula = f"{outcome} ~ {' + '.join(formula_parts)}" parameters["estimation_method"] = "Statsmodels OLS with TWFE (C() Notation)" method_details_str = "DiD via Statsmodels TWFE (C() Notation)" try: logger.info(f"Using formula: {formula}") logger.debug(f"Data head for statsmodels:\n{df_processed.head().to_string()}") logger.debug(f"Regression DataFrame shape: {df_processed.shape}, Columns: {df_processed.columns.tolist()}") ols_model = smf.ols(formula=formula, data=df_processed) if group_var not in df_processed.columns: # This check is mainly for clustering but good to ensure group_var exists. # For 2x2, group_var (unit ID) might not be in formula but needed for clustering. raise ValueError(f"Clustering variable '{group_var}' (panel unit ID) not found in regression data.") logger.debug(f"Clustering standard errors by: {group_var}") results = ols_model.fit(cov_type='cluster', cov_kwds={'groups': df_processed[group_var]}) logger.info("Statsmodels estimation complete.") logger.info(f"Statsmodels Results Summary:\n{results.summary()}") logger.debug(f"Extracting results using interaction term key: {interaction_term_key_for_results}") parameters["final_formula"] = formula parameters["interaction_term_coefficient_name"] = interaction_term_key_for_results formatted_results = format_did_results(results, interaction_term_key_for_results, did_diagnostics, method_details=method_details_str, parameters=parameters) formatted_results["estimator"] = "statsmodels" except Exception as e: logger.error(f"Statsmodels OLS estimation failed: {e}", exc_info=True) raise ValueError(f"DiD estimation failed (both DoWhy and Statsmodels): {e}") # --- Add Interpretation --- (Now add interpretation to the formatted results) try: # Use the llm_instance fetched earlier interpretation = interpret_did_results(formatted_results, did_diagnostics, dataset_description, llm=llm_instance) formatted_results['interpretation'] = interpretation except Exception as interp_e: logger.error(f"DiD Interpretation failed: {interp_e}") formatted_results['interpretation'] = "Interpretation failed." return formatted_results