import pandas as pd import statsmodels.api as sm from statsmodels.sandbox.regression.gmm import IV2SLS from dowhy import CausalModel # Primary path from typing import Dict, Any, List, Union, Optional import logging from langchain.chat_models.base import BaseChatModel from .diagnostics import run_iv_diagnostics from .llm_assist import identify_instrument_variable, validate_instrument_assumptions_qualitative, interpret_iv_results logger = logging.getLogger(__name__) def build_iv_graph_gml(treatment: str, outcome: str, instruments: List[str], covariates: List[str]) -> str: """ Constructs a GML string representing the causal graph for IV. Assumptions: - Instruments cause Treatment - Covariates cause Treatment and Outcome - Treatment causes Outcome - Instruments do NOT directly cause Outcome (Exclusion) - Instruments are NOT caused by Covariates (can be relaxed if needed) - Unobserved Confounder (U) affects Treatment and Outcome Args: treatment: Name of the treatment variable. outcome: Name of the outcome variable. instruments: List of instrument variable names. covariates: List of covariate names. Returns: A GML graph string. """ nodes = [] edges = [] # Define nodes - ensure no duplicates if a variable is both instrument and covariate (SHOULD NOT HAPPEN) # Use a set to ensure unique variable names all_vars_set = set([treatment, outcome] + instruments + covariates + ['U']) all_vars = list(all_vars_set) for var in all_vars: nodes.append(f'node [ id "{var}" label "{var}" ]') # Define edges # Instruments -> Treatment for inst in instruments: edges.append(f'edge [ source "{inst}" target "{treatment}" ]') # Covariates -> Treatment for cov in covariates: # Ensure we don't add self-loops or duplicate edges if cov == treatment (shouldn't happen) if cov != treatment: edges.append(f'edge [ source "{cov}" target "{treatment}" ]') # Covariates -> Outcome for cov in covariates: if cov != outcome: edges.append(f'edge [ source "{cov}" target "{outcome}" ]') # Treatment -> Outcome edges.append(f'edge [ source "{treatment}" target "{outcome}" ]') # Unobserved Confounder -> Treatment and Outcome edges.append(f'edge [ source "U" target "{treatment}" ]') edges.append(f'edge [ source "U" target "{outcome}" ]') # Core IV Assumption: Instruments are NOT caused by U (implicitly handled by not adding edge) # Core IV Assumption: Instruments do NOT directly cause Outcome (handled by not adding edge) # Format nodes and edges with indentation before inserting into f-string formatted_nodes = '\n '.join(nodes) formatted_edges = '\n '.join(edges) gml_string = f""" graph [ directed 1 {formatted_nodes} {formatted_edges} ] """ # Convert print to logger logger.debug("\n--- Generated GML Graph ---") logger.debug(gml_string) logger.debug("-------------------------\n") return gml_string def format_iv_results(estimate: Optional[float], raw_results: Dict, diagnostics: Dict, treatment: str, outcome: str, instrument: List[str], method_used: str, llm: Optional[BaseChatModel] = None) -> Dict[str, Any]: """ Formats the results from IV estimation into a standardized dictionary. Args: estimate: The point estimate of the causal effect. raw_results: Dictionary containing raw outputs from DoWhy/statsmodels. diagnostics: Dictionary containing diagnostic results. treatment: Name of the treatment variable. outcome: Name of the outcome variable. instrument: List of instrument variable names. method_used: 'dowhy' or 'statsmodels'. llm: Optional LLM instance for interpretation. Returns: Standardized results dictionary. """ formatted = { "effect_estimate": estimate, "treatment_variable": treatment, "outcome_variable": outcome, "instrument_variables": instrument, "method_used": method_used, "diagnostics": diagnostics, "raw_results": {k: str(v) for k, v in raw_results.items() if "object" not in k}, # Avoid serializing large objects "confidence_interval": None, "standard_error": None, "p_value": None, "interpretation": "Placeholder" } # Extract details from statsmodels results if available sm_results = raw_results.get('statsmodels_results_object') if method_used == 'statsmodels' and sm_results: try: # Use .bse for standard error in statsmodels results formatted["standard_error"] = float(sm_results.bse[treatment]) formatted["p_value"] = float(sm_results.pvalues[treatment]) conf_int = sm_results.conf_int().loc[treatment].tolist() formatted["confidence_interval"] = [float(ci) for ci in conf_int] except AttributeError as e: logger.warning(f"Could not extract all details from statsmodels results object (likely missing attribute): {e}") except Exception as e: logger.warning(f"Error extracting details from statsmodels results: {e}") # Extract details from DoWhy results if available # Note: DoWhy's CausalEstimate object structure needs inspection dw_results = raw_results.get('dowhy_results_object') if method_used == 'dowhy' and dw_results: try: # Attempt common attributes, may need adjustment based on DoWhy version/output if hasattr(dw_results, 'stderr'): formatted["standard_error"] = float(dw_results.stderr) if hasattr(dw_results, 'p_value'): formatted["p_value"] = float(dw_results.p_value) if hasattr(dw_results, 'conf_intervals'): # Assuming it's stored similarly to statsmodels, might need adjustment ci = dw_results.conf_intervals().loc[treatment].tolist() # Fictional attribute/method - check DoWhy docs! formatted["confidence_interval"] = [float(c) for c in ci] elif hasattr(dw_results, 'get_confidence_intervals'): ci = dw_results.get_confidence_intervals() # Check DoWhy docs for format # Check format of ci before converting if isinstance(ci, (list, tuple)) and len(ci) == 2: formatted["confidence_interval"] = [float(c) for c in ci] # Adapt parsing else: logger.warning(f"Could not parse confidence intervals from DoWhy object: {ci}") except Exception as e: logger.warning(f"Could not extract all details from DoWhy results: {e}. Structure might be different.", exc_info=True) # Avoid printing dir in production code, use logger.debug if needed for dev # logger.debug(f"DoWhy result object dir(): {dir(dw_results)}") # Generate LLM interpretation - pass llm object if estimate is not None: formatted["interpretation"] = interpret_iv_results(formatted, diagnostics, llm=llm) else: formatted["interpretation"] = "Estimation failed, cannot interpret results." return formatted def estimate_effect( df: pd.DataFrame, treatment: str, outcome: str, covariates: List[str], query: Optional[str] = None, dataset_description: Optional[str] = None, llm: Optional[BaseChatModel] = None, **kwargs ) -> Dict[str, Any]: instrument = kwargs.get('instrument_variable') if not instrument: return {"error": "Instrument variable ('instrument_variable') not found in kwargs.", "method_used": "none", "diagnostics": {}} instrument_list = [instrument] if isinstance(instrument, str) else instrument valid_instruments = [inst for inst in instrument_list if isinstance(inst, str)] clean_covariates = [cov for cov in covariates if cov not in valid_instruments] logger.info(f"\n--- Starting Instrumental Variable Estimation ---") logger.info(f"Treatment: {treatment}, Outcome: {outcome}, Instrument(s): {valid_instruments}, Original Covariates: {covariates}, Cleaned Covariates: {clean_covariates}") results = {} method_used = "none" sm_results_obj = None dw_results_obj = None identified_estimand = None # Initialize model = None # Initialize refutation_results = {} # Initialize # --- Input Validation --- required_cols = [treatment, outcome] + valid_instruments + clean_covariates missing_cols = [col for col in required_cols if col not in df.columns] if missing_cols: return {"error": f"Missing required columns in DataFrame: {missing_cols}", "method_used": method_used, "diagnostics": {}} if not valid_instruments: return {"error": "Instrument variable(s) must be provided and valid.", "method_used": method_used, "diagnostics": {}} # --- LLM Pre-Checks --- if query and llm: qualitative_check = validate_instrument_assumptions_qualitative(treatment, outcome, valid_instruments, clean_covariates, query, llm=llm) results['llm_assumption_check'] = qualitative_check logger.info(f"LLM Qualitative Assumption Check: {qualitative_check}") # --- Build Graph and Instantiate CausalModel (Do this before estimation attempts) --- # This allows using identify_effect and refute_estimate even if DoWhy estimation fails try: graph = build_iv_graph_gml(treatment, outcome, valid_instruments, clean_covariates) if not graph: raise ValueError("Failed to build GML graph for DoWhy.") model = CausalModel(data=df, treatment=treatment, outcome=outcome, graph=graph) # Identify Effect (essential for refutation later) identified_estimand = model.identify_effect(proceed_when_unidentifiable=True) logger.debug("\nDoWhy Identified Estimand:") logger.debug(identified_estimand) if not identified_estimand: raise ValueError("DoWhy could not identify a valid estimand.") except Exception as model_init_e: logger.error(f"Failed to initialize CausalModel or identify effect: {model_init_e}", exc_info=True) # Cannot proceed without model/estimand for DoWhy or refutation results['error'] = f"Failed to initialize CausalModel: {model_init_e}" # Attempt statsmodels anyway? Or return error? Let's try statsmodels. pass # Allow falling through to statsmodels if desired # --- Primary Path: DoWhy Estimation --- if model and identified_estimand and not kwargs.get('force_statsmodels', False): logger.info("\nAttempting estimation with DoWhy...") try: dw_results_obj = model.estimate_effect( identified_estimand, method_name="iv.instrumental_variable", method_params={'iv_instrument_name': valid_instruments} ) logger.debug("\nDoWhy Estimation Result:") logger.debug(dw_results_obj) results['dowhy_estimate'] = dw_results_obj.value results['dowhy_results_object'] = dw_results_obj method_used = 'dowhy' logger.info("DoWhy estimation successful.") except Exception as e: logger.error(f"DoWhy IV estimation failed: {e}", exc_info=True) results['dowhy_error'] = str(e) if not kwargs.get('allow_fallback', True): logger.warning("Fallback to statsmodels disabled. Estimation failed.") method_used = "dowhy_failed" # Still run diagnostics and format output else: logger.info("Proceeding to statsmodels fallback.") elif not model or not identified_estimand: logger.warning("Skipping DoWhy estimation due to CausalModel initialization/identification failure.") # Ensure we proceed to statsmodels if fallback is allowed if not kwargs.get('allow_fallback', True): logger.error("Cannot estimate effect: CausalModel failed and fallback disabled.") method_used = "dowhy_failed" else: logger.info("Proceeding to statsmodels fallback.") # --- Fallback Path: statsmodels IV2SLS --- if method_used not in ['dowhy', 'dowhy_failed']: logger.info("\nAttempting estimation with statsmodels IV2SLS...") try: df_copy = df.copy().dropna(subset=required_cols) if df_copy.empty: raise ValueError("DataFrame becomes empty after dropping NAs in required columns.") df_copy['intercept'] = 1 exog_regressors = ['intercept'] + clean_covariates endog_var = treatment all_instruments_for_sm = list(dict.fromkeys(exog_regressors + valid_instruments)) endog_data = df_copy[outcome] exog_data_sm_cols = list(dict.fromkeys(exog_regressors + [endog_var])) exog_data_sm = df_copy[exog_data_sm_cols] instrument_data_sm = df_copy[all_instruments_for_sm] num_endog = 1 num_external_iv = len(valid_instruments) if num_endog > num_external_iv: raise ValueError(f"Model underidentified: More endogenous regressors ({num_endog}) than unique external instruments ({num_external_iv}).") iv_model = IV2SLS(endog=endog_data, exog=exog_data_sm, instrument=instrument_data_sm) sm_results_obj = iv_model.fit() logger.info("\nStatsmodels Estimation Summary:") logger.info(f" Estimate for {treatment}: {sm_results_obj.params[treatment]}") logger.info(f" Std Error: {sm_results_obj.bse[treatment]}") logger.info(f" P-value: {sm_results_obj.pvalues[treatment]}") results['statsmodels_estimate'] = sm_results_obj.params[treatment] results['statsmodels_results_object'] = sm_results_obj method_used = 'statsmodels' logger.info("Statsmodels estimation successful.") except Exception as sm_e: logger.error(f"Statsmodels IV estimation also failed: {sm_e}", exc_info=True) results['statsmodels_error'] = str(sm_e) method_used = 'statsmodels_failed' if method_used == "none" else "dowhy_failed_sm_failed" # --- Diagnostics --- logger.info("\nRunning diagnostics...") diagnostics = run_iv_diagnostics(df, treatment, outcome, valid_instruments, clean_covariates, sm_results_obj, dw_results_obj) results['diagnostics'] = diagnostics # --- Refutation Step --- final_estimate_value = results.get('dowhy_estimate') if method_used == 'dowhy' else results.get('statsmodels_estimate') # Only run permute refuter if estimate is valid AND came from DoWhy if method_used == 'dowhy' and dw_results_obj and final_estimate_value is not None: logger.info("\nRunning refutation test (Placebo Treatment - Permute - requires DoWhy estimate object)...") try: # Pass the actual DoWhy estimate object refuter_result = model.refute_estimate( identified_estimand, dw_results_obj, # Pass the original DoWhy result object method_name="placebo_treatment_refuter", placebo_type="permute" # Necessary for IV according to docs/examples ) logger.info("Refutation test completed.") logger.debug(f"Refuter Result:\n{refuter_result}") # Store relevant info from refuter_result (check its structure) refutation_results = { "refuter": "placebo_treatment_refuter", "new_effect": getattr(refuter_result, 'new_effect', 'N/A'), "p_value": getattr(refuter_result, 'refutation_result', {}).get('p_value', 'N/A') if hasattr(refuter_result, 'refutation_result') else 'N/A', # Passed if p-value > 0.05 (or not statistically significant) "passed": getattr(refuter_result, 'refutation_result', {}).get('is_statistically_significant', None) == False if hasattr(refuter_result, 'refutation_result') else None } except Exception as refute_e: logger.error(f"Refutation test failed: {refute_e}", exc_info=True) refutation_results = {"error": f"Refutation failed: {refute_e}"} elif final_estimate_value is not None and method_used == 'statsmodels': logger.warning("Skipping placebo permutation refuter: Estimate was generated by statsmodels, not DoWhy's IV estimator.") refutation_results = {"status": "skipped_wrong_estimator_for_permute"} elif final_estimate_value is None: logger.warning("Skipping refutation test because estimation failed.") refutation_results = {"status": "skipped_due_to_failed_estimation"} else: # Model or estimand failed earlier, or unknown method_used logger.warning(f"Skipping refutation test due to earlier failure (method_used: {method_used}).") refutation_results = {"status": "skipped_due_to_model_failure_or_unknown"} results['refutation_results'] = refutation_results # Add to main results # --- Formatting Results --- if final_estimate_value is None and method_used not in ['dowhy', 'statsmodels']: logger.error("ERROR: Both estimation methods failed.") # Ensure error key exists if not set earlier if 'error' not in results: results['error'] = "Both DoWhy and statsmodels IV estimation failed." logger.info("\n--- Formatting Final Results ---") formatted_results = format_iv_results( final_estimate_value, # Pass the numeric value results, # Pass the dict containing estimate objects and refutation results diagnostics, treatment, outcome, valid_instruments, method_used, llm=llm ) logger.info("--- Instrumental Variable Estimation Complete ---\n") return formatted_results