Spaces:
Running
Running
import pandas as pd | |
import statsmodels.api as sm | |
from statsmodels.sandbox.regression.gmm import IV2SLS | |
from dowhy import CausalModel # Primary path | |
from typing import Dict, Any, List, Union, Optional | |
import logging | |
from langchain.chat_models.base import BaseChatModel | |
from .diagnostics import run_iv_diagnostics | |
from .llm_assist import identify_instrument_variable, validate_instrument_assumptions_qualitative, interpret_iv_results | |
logger = logging.getLogger(__name__) | |
def build_iv_graph_gml(treatment: str, outcome: str, instruments: List[str], covariates: List[str]) -> str:
    """
    Constructs a GML string representing the causal graph for IV.

    Assumptions encoded in the graph:
    - Instruments cause Treatment
    - Covariates cause Treatment and Outcome
    - Treatment causes Outcome
    - Instruments do NOT directly cause Outcome (exclusion; no such edge is emitted)
    - Instruments are NOT caused by Covariates or by the unobserved confounder
    - An unobserved confounder affects Treatment and Outcome

    Nodes and edges are emitted in a deterministic, first-seen order, and a name
    is never emitted twice. A name that appears in several roles keeps only one:
    treatment/outcome win over instrument, and instrument wins over covariate.
    The unobserved confounder is named "U", with underscores appended if that
    name is already taken by an input variable.

    Args:
        treatment: Name of the treatment variable.
        outcome: Name of the outcome variable.
        instruments: List of instrument variable names.
        covariates: List of covariate names.

    Returns:
        A GML graph string.
    """
    # dict.fromkeys de-duplicates while preserving order; a set would make the
    # node order (and therefore the returned string) vary between runs.
    instrument_names = [
        name for name in dict.fromkeys(instruments)
        if name not in (treatment, outcome)
    ]
    reserved = {treatment, outcome, *instrument_names}
    # A covariate that is also the treatment/outcome would create a self-loop or
    # a cycle; one that is also an instrument would add an instrument -> outcome
    # edge and break the exclusion restriction.
    covariate_names = [name for name in dict.fromkeys(covariates) if name not in reserved]

    observed = list(dict.fromkeys([treatment, outcome, *instrument_names, *covariate_names]))

    # Keep the latent confounder distinct from every input variable.
    latent = "U"
    while latent in observed:
        latent += "_"

    nodes = [f'node [ id "{var}" label "{var}" ]' for var in observed + [latent]]

    edge_pairs = [(inst, treatment) for inst in instrument_names]    # Instruments -> Treatment
    edge_pairs += [(cov, treatment) for cov in covariate_names]      # Covariates -> Treatment
    edge_pairs += [(cov, outcome) for cov in covariate_names]        # Covariates -> Outcome
    edge_pairs.append((treatment, outcome))                          # Treatment -> Outcome
    edge_pairs += [(latent, treatment), (latent, outcome)]           # Unobserved confounder
    # Core IV assumptions (no latent -> instrument, no instrument -> outcome)
    # are expressed by the absence of those edges.
    edges = [f'edge [ source "{src}" target "{dst}" ]' for src, dst in edge_pairs]

    body_lines = ["directed 1", *nodes, *edges]
    gml_string = "\ngraph [\n" + "".join(f"    {line}\n" for line in body_lines) + "]\n"

    logger.debug("\n--- Generated GML Graph ---")
    logger.debug(gml_string)
    logger.debug("-------------------------\n")
    return gml_string
def _safe_extract(description: str, getter):
    """Return getter(), or None after logging a warning if it raises."""
    try:
        return getter()
    except Exception as e:
        logger.warning(f"Could not extract {description}: {e}")
        return None


def _coerce_interval(ci) -> Optional[List[float]]:
    """Convert a two-element interval (possibly array-like or nested one level) to [low, high] floats, else None."""
    if ci is None:
        return None
    if hasattr(ci, "tolist"):
        # Array-like containers become plain Python lists.
        ci = ci.tolist()
    # Unwrap single-row shapes such as [[low, high]].
    while isinstance(ci, (list, tuple)) and len(ci) == 1 and isinstance(ci[0], (list, tuple)):
        ci = ci[0]
    if isinstance(ci, (list, tuple)) and len(ci) == 2:
        try:
            return [float(ci[0]), float(ci[1])]
        except (TypeError, ValueError):
            return None
    return None


def format_iv_results(estimate: Optional[float], raw_results: Dict, diagnostics: Dict, treatment: str, outcome: str, instrument: List[str], method_used: str, llm: Optional["BaseChatModel"] = None) -> Dict[str, Any]:
    """
    Formats the results from IV estimation into a standardized dictionary.

    Standard error, p-value and confidence interval are extracted independently
    of each other, so a field that cannot be read stays None without preventing
    the remaining fields from being filled in.

    Args:
        estimate: The point estimate of the causal effect (None if estimation failed).
        raw_results: Dictionary containing raw outputs from DoWhy/statsmodels.
            Entries whose key contains "object" are read for details but are not
            copied into the output; all other entries are stringified.
        diagnostics: Dictionary containing diagnostic results.
        treatment: Name of the treatment variable.
        outcome: Name of the outcome variable.
        instrument: List of instrument variable names.
        method_used: 'dowhy' or 'statsmodels' (other values skip detail extraction).
        llm: Optional LLM instance for interpretation.

    Returns:
        Standardized results dictionary.
    """
    formatted = {
        "effect_estimate": estimate,
        "treatment_variable": treatment,
        "outcome_variable": outcome,
        "instrument_variables": instrument,
        "method_used": method_used,
        "diagnostics": diagnostics,
        # Avoid serializing large result objects; str(k) keeps non-string keys from raising.
        "raw_results": {k: str(v) for k, v in raw_results.items() if "object" not in str(k)},
        "confidence_interval": None,
        "standard_error": None,
        "p_value": None,
        "interpretation": "Placeholder"
    }

    # --- statsmodels details ---
    sm_results = raw_results.get('statsmodels_results_object')
    if method_used == 'statsmodels' and sm_results is not None:
        std_err = _safe_extract(
            "standard error from statsmodels results",
            lambda: float(sm_results.bse[treatment]),
        )
        if std_err is not None:
            formatted["standard_error"] = std_err

        p_val = _safe_extract(
            "p-value from statsmodels results",
            lambda: float(sm_results.pvalues[treatment]),
        )
        if p_val is not None:
            formatted["p_value"] = p_val

        sm_ci = _safe_extract(
            "confidence interval from statsmodels results",
            lambda: [float(bound) for bound in sm_results.conf_int().loc[treatment].tolist()],
        )
        if sm_ci is not None:
            formatted["confidence_interval"] = sm_ci

    # --- DoWhy details ---
    # NOTE(review): the attribute names probed here ('stderr', 'p_value',
    # 'conf_intervals') come from the original code and may not exist on the
    # installed DoWhy version; absent attributes are simply skipped.
    dw_results = raw_results.get('dowhy_results_object')
    if method_used == 'dowhy' and dw_results is not None:
        stderr_value = getattr(dw_results, 'stderr', None)
        if stderr_value is not None:
            std_err = _safe_extract("standard error from DoWhy results", lambda: float(stderr_value))
            if std_err is not None:
                formatted["standard_error"] = std_err

        p_value_raw = getattr(dw_results, 'p_value', None)
        if p_value_raw is not None:
            p_val = _safe_extract("p-value from DoWhy results", lambda: float(p_value_raw))
            if p_val is not None:
                formatted["p_value"] = p_val

        raw_ci = None
        if hasattr(dw_results, 'conf_intervals'):
            raw_ci = _safe_extract(
                "confidence interval from DoWhy results (conf_intervals)",
                lambda: dw_results.conf_intervals().loc[treatment].tolist(),
            )
        # Fall back to the getter when the first accessor is missing or failed.
        if raw_ci is None and hasattr(dw_results, 'get_confidence_intervals'):
            raw_ci = _safe_extract(
                "confidence interval from DoWhy results (get_confidence_intervals)",
                dw_results.get_confidence_intervals,
            )
        if raw_ci is not None:
            interval = _coerce_interval(raw_ci)
            if interval is None:
                logger.warning(f"Could not parse confidence intervals from DoWhy object: {raw_ci}")
            else:
                formatted["confidence_interval"] = interval

    # --- LLM interpretation ---
    if estimate is not None:
        try:
            formatted["interpretation"] = interpret_iv_results(formatted, diagnostics, llm=llm)
        except Exception as e:
            # The narrative is optional; keep the numeric results if it fails.
            logger.warning(f"Could not generate interpretation of IV results: {e}", exc_info=True)
            formatted["interpretation"] = f"Interpretation unavailable: {e}"
    else:
        formatted["interpretation"] = "Estimation failed, cannot interpret results."
    return formatted
def estimate_effect(
    df: pd.DataFrame,
    treatment: str,
    outcome: str,
    covariates: List[str],
    query: Optional[str] = None,
    dataset_description: Optional[str] = None,
    llm: Optional["BaseChatModel"] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Estimates a causal effect using instrumental variables.

    DoWhy's IV estimator is the primary path; statsmodels IV2SLS is used as a
    fallback (or exclusively when forced). Diagnostics are then run, a placebo
    refuter is applied to DoWhy estimates, and everything is passed through
    format_iv_results.

    Args:
        df: Input data.
        treatment: Name of the treatment (endogenous) variable.
        outcome: Name of the outcome variable.
        covariates: Covariate names. Instruments, the treatment and the outcome
            are removed from this list, as are duplicates; None means no covariates.
        query: Optional user query; with `llm`, triggers a qualitative LLM check
            of the instrument assumptions.
        dataset_description: Accepted for interface compatibility; not used here.
        llm: Optional LLM instance for the pre-check and the interpretation.
        **kwargs:
            instrument_variable (str or list of str): Required instrument name(s).
            force_statsmodels (bool): Skip DoWhy estimation. Default False.
            allow_fallback (bool): Allow statsmodels when DoWhy fails. Default True.

    Returns:
        The dictionary built by format_iv_results. When no estimate could be
        produced it additionally carries a top-level "error" key. Input
        validation failures return {"error", "method_used", "diagnostics"} only.
    """
    instrument = kwargs.get('instrument_variable')
    if not instrument:
        return {"error": "Instrument variable ('instrument_variable') not found in kwargs.", "method_used": "none", "diagnostics": {}}

    force_statsmodels = bool(kwargs.get('force_statsmodels', False))
    allow_fallback = bool(kwargs.get('allow_fallback', True))

    # Accept a single name or any iterable of names; anything else yields no instruments.
    if isinstance(instrument, str):
        instrument_list = [instrument]
    else:
        try:
            instrument_list = list(instrument)
        except TypeError:
            instrument_list = []
    # De-duplicate (order preserved) so the identification count below is honest.
    valid_instruments = list(dict.fromkeys(inst for inst in instrument_list if isinstance(inst, str)))

    # A covariate must not double as instrument, treatment or outcome.
    excluded_from_covariates = set(valid_instruments) | {treatment, outcome}
    clean_covariates = [cov for cov in dict.fromkeys(covariates or []) if cov not in excluded_from_covariates]

    logger.info("\n--- Starting Instrumental Variable Estimation ---")
    logger.info(f"Treatment: {treatment}, Outcome: {outcome}, Instrument(s): {valid_instruments}, Original Covariates: {covariates}, Cleaned Covariates: {clean_covariates}")

    results = {}
    method_used = "none"
    sm_results_obj = None
    dw_results_obj = None
    identified_estimand = None
    model = None
    refutation_results = {}

    # --- Input Validation ---
    required_cols = list(dict.fromkeys([treatment, outcome] + valid_instruments + clean_covariates))
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        return {"error": f"Missing required columns in DataFrame: {missing_cols}", "method_used": method_used, "diagnostics": {}}
    if not valid_instruments:
        return {"error": "Instrument variable(s) must be provided and valid.", "method_used": method_used, "diagnostics": {}}
    overlapping = [inst for inst in valid_instruments if inst in (treatment, outcome)]
    if overlapping:
        return {"error": f"Instrument variable(s) cannot also be the treatment or outcome: {overlapping}", "method_used": method_used, "diagnostics": {}}

    # --- LLM Pre-Checks ---
    if query and llm:
        try:
            qualitative_check = validate_instrument_assumptions_qualitative(treatment, outcome, valid_instruments, clean_covariates, query, llm=llm)
            results['llm_assumption_check'] = qualitative_check
            logger.info(f"LLM Qualitative Assumption Check: {qualitative_check}")
        except Exception as llm_e:
            # The qualitative check is advisory; do not abort estimation over it.
            logger.warning(f"LLM qualitative assumption check failed: {llm_e}", exc_info=True)
            results['llm_assumption_check_error'] = str(llm_e)

    # --- Build Graph and Instantiate CausalModel (before any estimation attempt) ---
    # The model and estimand are also needed for the refutation step.
    try:
        graph = build_iv_graph_gml(treatment, outcome, valid_instruments, clean_covariates)
        if not graph:
            raise ValueError("Failed to build GML graph for DoWhy.")
        # Flatten newlines: GML is whitespace-insensitive, and a single-line
        # string is the form used in DoWhy's own examples.
        model = CausalModel(data=df, treatment=treatment, outcome=outcome, graph=graph.replace("\n", " "))
        identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
        logger.debug("\nDoWhy Identified Estimand:")
        logger.debug(identified_estimand)
        if not identified_estimand:
            raise ValueError("DoWhy could not identify a valid estimand.")
    except Exception as model_init_e:
        logger.error(f"Failed to initialize CausalModel or identify effect: {model_init_e}", exc_info=True)
        # Recorded for the caller; statsmodels may still be attempted below.
        results['error'] = f"Failed to initialize CausalModel: {model_init_e}"

    # --- Primary Path: DoWhy Estimation ---
    dowhy_failed = False  # True when DoWhy was wanted but could not deliver an estimate
    if model and identified_estimand and not force_statsmodels:
        logger.info("\nAttempting estimation with DoWhy...")
        try:
            dw_estimate = model.estimate_effect(
                identified_estimand,
                method_name="iv.instrumental_variable",
                method_params={'iv_instrument_name': valid_instruments}
            )
            logger.debug("\nDoWhy Estimation Result:")
            logger.debug(dw_estimate)
            # An estimate object without a value is a failure, not a success;
            # raising here routes it through the normal fallback handling.
            if getattr(dw_estimate, 'value', None) is None:
                raise ValueError("DoWhy returned no effect value for the IV estimand.")
            dw_results_obj = dw_estimate
            results['dowhy_estimate'] = dw_results_obj.value
            results['dowhy_results_object'] = dw_results_obj
            method_used = 'dowhy'
            logger.info("DoWhy estimation successful.")
        except Exception as e:
            dowhy_failed = True
            logger.error(f"DoWhy IV estimation failed: {e}", exc_info=True)
            results['dowhy_error'] = str(e)
            if not allow_fallback:
                logger.warning("Fallback to statsmodels disabled. Estimation failed.")
                method_used = "dowhy_failed"
                # Diagnostics and formatting still run below.
            else:
                logger.info("Proceeding to statsmodels fallback.")
    elif not force_statsmodels:
        # DoWhy was the intended path but the model/estimand is unavailable.
        dowhy_failed = True
        logger.warning("Skipping DoWhy estimation due to CausalModel initialization/identification failure.")
        if not allow_fallback:
            logger.error("Cannot estimate effect: CausalModel failed and fallback disabled.")
            method_used = "dowhy_failed"
        else:
            logger.info("Proceeding to statsmodels fallback.")
    # When force_statsmodels is set, fall straight through to statsmodels.

    # --- Fallback Path: statsmodels IV2SLS ---
    if method_used not in ['dowhy', 'dowhy_failed']:
        logger.info("\nAttempting estimation with statsmodels IV2SLS...")
        try:
            df_copy = df.dropna(subset=required_cols).copy()
            if df_copy.empty:
                raise ValueError("DataFrame becomes empty after dropping NAs in required columns.")
            df_copy['intercept'] = 1
            exog_regressors = ['intercept'] + clean_covariates
            endog_var = treatment
            # Instrument matrix = exogenous regressors + external instruments.
            all_instruments_for_sm = list(dict.fromkeys(exog_regressors + valid_instruments))
            # statsmodels naming: 'endog' is the dependent variable (the outcome);
            # 'exog' holds all regressors, including the endogenous treatment.
            endog_data = df_copy[outcome]
            exog_data_sm_cols = list(dict.fromkeys(exog_regressors + [endog_var]))
            exog_data_sm = df_copy[exog_data_sm_cols]
            instrument_data_sm = df_copy[all_instruments_for_sm]
            num_endog = 1
            num_external_iv = len(valid_instruments)
            if num_endog > num_external_iv:
                raise ValueError(f"Model underidentified: More endogenous regressors ({num_endog}) than unique external instruments ({num_external_iv}).")
            iv_model = IV2SLS(endog=endog_data, exog=exog_data_sm, instrument=instrument_data_sm)
            sm_results_obj = iv_model.fit()
            logger.info("\nStatsmodels Estimation Summary:")
            logger.info(f" Estimate for {treatment}: {sm_results_obj.params[treatment]}")
            logger.info(f" Std Error: {sm_results_obj.bse[treatment]}")
            logger.info(f" P-value: {sm_results_obj.pvalues[treatment]}")
            results['statsmodels_estimate'] = sm_results_obj.params[treatment]
            results['statsmodels_results_object'] = sm_results_obj
            method_used = 'statsmodels'
            logger.info("Statsmodels estimation successful.")
        except Exception as sm_e:
            logger.error(f"Statsmodels IV estimation also failed: {sm_e}", exc_info=True)
            results['statsmodels_error'] = str(sm_e)
            # method_used is still "none" here even after a DoWhy failure, so
            # the label is chosen from the explicit flag.
            method_used = "dowhy_failed_sm_failed" if dowhy_failed else 'statsmodels_failed'

    # --- Diagnostics ---
    logger.info("\nRunning diagnostics...")
    diagnostics = run_iv_diagnostics(df, treatment, outcome, valid_instruments, clean_covariates, sm_results_obj, dw_results_obj)
    results['diagnostics'] = diagnostics

    # --- Refutation Step ---
    final_estimate_value = results.get('dowhy_estimate') if method_used == 'dowhy' else results.get('statsmodels_estimate')
    # The placebo refuter needs DoWhy's own estimate object.
    if method_used == 'dowhy' and dw_results_obj is not None and final_estimate_value is not None:
        logger.info("\nRunning refutation test (Placebo Treatment - Permute - requires DoWhy estimate object)...")
        try:
            refuter_result = model.refute_estimate(
                identified_estimand,
                dw_results_obj,
                method_name="placebo_treatment_refuter",
                placebo_type="permute"
            )
            logger.info("Refutation test completed.")
            logger.debug(f"Refuter Result:\n{refuter_result}")
            # The significance details may be absent or unset; treat both as "no information".
            significance = getattr(refuter_result, 'refutation_result', None)
            if not isinstance(significance, dict):
                significance = {}
            is_significant = significance.get('is_statistically_significant')
            refutation_results = {
                "refuter": "placebo_treatment_refuter",
                "new_effect": getattr(refuter_result, 'new_effect', 'N/A'),
                "p_value": significance.get('p_value', 'N/A'),
                # Passed when the placebo effect is NOT statistically significant;
                # None when significance was not reported.
                "passed": None if is_significant is None else not is_significant,
            }
        except Exception as refute_e:
            logger.error(f"Refutation test failed: {refute_e}", exc_info=True)
            refutation_results = {"error": f"Refutation failed: {refute_e}"}
    elif final_estimate_value is not None and method_used == 'statsmodels':
        logger.warning("Skipping placebo permutation refuter: Estimate was generated by statsmodels, not DoWhy's IV estimator.")
        refutation_results = {"status": "skipped_wrong_estimator_for_permute"}
    elif final_estimate_value is None:
        logger.warning("Skipping refutation test because estimation failed.")
        refutation_results = {"status": "skipped_due_to_failed_estimation"}
    else:
        logger.warning(f"Skipping refutation test due to earlier failure (method_used: {method_used}).")
        refutation_results = {"status": "skipped_due_to_model_failure_or_unknown"}
    results['refutation_results'] = refutation_results

    # --- Formatting Results ---
    if final_estimate_value is None and method_used not in ['dowhy', 'statsmodels']:
        logger.error("ERROR: Both estimation methods failed.")
        if 'error' not in results:
            results['error'] = "Both DoWhy and statsmodels IV estimation failed."
    logger.info("\n--- Formatting Final Results ---")
    formatted_results = format_iv_results(
        final_estimate_value,
        results,
        diagnostics,
        treatment,
        outcome,
        valid_instruments,
        method_used,
        llm=llm
    )
    # Surface failure at the top level, matching the early-return error dicts.
    if final_estimate_value is None:
        formatted_results['error'] = results.get('error', "IV estimation failed.")
    logger.info("--- Instrumental Variable Estimation Complete ---\n")
    return formatted_results