Spaces:
Running
Running
"""
Regression Discontinuity Design (RDD) Estimator.
Tries the evan-magnusson/rdd package first, falling back to a basic
comparison of linear fits around the cutoff if that attempt fails.
A DoWhy-based estimator (estimate_effect_dowhy) is also defined in this
module, but estimate_effect does not call it.
"""
import pandas as pd | |
import statsmodels.api as sm | |
from dowhy import CausalModel | |
from typing import Dict, Any, List, Optional | |
import logging | |
from langchain.chat_models.base import BaseChatModel # For type hinting llm | |
from .diagnostics import run_rdd_diagnostics | |
from .llm_assist import interpret_rdd_results | |
logger = logging.getLogger(__name__) | |
# Optional dependency: the evan-magnusson/rdd package.
# ``from rdd import rdd`` binds the name ``rdd`` to the object the rest of this
# module calls as ``rdd.optimal_bandwidth``, ``rdd.truncated_data`` and
# ``rdd.rdd``. A preceding plain ``import rdd`` would be rebound immediately,
# so only the ``from`` form is used.
_rdd_estimator_func_em = None
_rdd_optimal_bw_func_em = None
_rdd_em_import_error_message = ""
try:
    from rdd import rdd
    # Populate the module-level handles so the success message below is accurate.
    _rdd_estimator_func_em = rdd.rdd
    _rdd_optimal_bw_func_em = rdd.optimal_bandwidth
    logger.info("Successfully imported 'rdd' and 'optimal_bandwidth' from evan-magnusson/rdd package.")
except ImportError as e:
    # Keep ``rdd`` defined so later references fail predictably instead of
    # raising NameError; a non-empty message marks the package as unavailable.
    rdd = None
    _rdd_em_import_error_message = f"ImportError for evan-magnusson/rdd: {e}. This package is needed for 'effect_estimate_rdd'."
    logger.warning(_rdd_em_import_error_message)
except Exception as e:  # Catch other potential errors during import
    rdd = None
    _rdd_estimator_func_em = None
    _rdd_optimal_bw_func_em = None
    _rdd_em_import_error_message = f"An unexpected error occurred during import from evan-magnusson/rdd: {e}"
    logger.warning(_rdd_em_import_error_message)
def estimate_effect_dowhy(df: pd.DataFrame, treatment: str, outcome: str, running_variable: str, cutoff_value: float, covariates: Optional[List[str]], **kwargs) -> Dict[str, Any]:
    """Run DoWhy's ``iv.regression_discontinuity`` estimator and package its output.

    Args:
        df: Data handed to ``CausalModel``.
        treatment: Treatment column name passed to ``CausalModel``.
        outcome: Outcome column name passed to ``CausalModel``.
        running_variable: Column forwarded as ``rd_variable_name``.
        cutoff_value: Threshold forwarded as ``rd_threshold_value``.
        covariates: Only triggers a warning; never forwarded to DoWhy.
        **kwargs: Only ``bandwidth`` is read. When absent or None, 10% of the
            running variable's range is used.

    Returns:
        Dict with ``effect_estimate``, ``p_value``, ``confidence_interval``,
        ``standard_error`` and ``method_details``. The middle three are None
        whenever the estimate object lacks the corresponding attribute.
    """
    logger.info("Attempting RDD estimation using DoWhy.")
    if covariates:
        logger.warning("Covariates provided but may not be used by the DoWhy RDD method_name='rdd'. Support varies.")

    # No graph or common causes are supplied; the running variable reaches the
    # estimator through method_params instead.
    causal_model = CausalModel(data=df, treatment=treatment, outcome=outcome)

    # estimate_effect takes the identified estimand as its first argument.
    estimand = causal_model.identify_effect(proceed_when_unidentifiable=True)

    # An explicit bandwidth wins; otherwise use a crude 10%-of-range default.
    bw = kwargs.get('bandwidth')
    if bw is None:
        rv_span = df[running_variable].max() - df[running_variable].min()
        bw = 0.1 * rv_span
        logger.warning(f"No bandwidth specified, using basic default: {bw:.3f}")

    rd_params = {
        'rd_variable_name': running_variable,
        'rd_threshold_value': cutoff_value,
        'rd_bandwidth': bw,
    }
    dowhy_estimate = causal_model.estimate_effect(
        estimand,
        method_name="iv.regression_discontinuity",
        method_params=rd_params,
        test_significance=True,
    )

    point_estimate = dowhy_estimate.value

    # NOTE(review): these attribute names are looked up defensively and default
    # to None; confirm against the installed DoWhy version that they exist.
    pval = getattr(dowhy_estimate, 'test_significance_pvalue', None)
    if isinstance(pval, (list, tuple)):
        pval = pval[0]  # unwrap a sequence-wrapped p-value
    interval = getattr(dowhy_estimate, 'confidence_interval', None)
    stderr = getattr(dowhy_estimate, 'standard_error', None)

    summary: Dict[str, Any] = {}
    summary['effect_estimate'] = point_estimate
    summary['p_value'] = pval
    summary['confidence_interval'] = interval
    summary['standard_error'] = stderr
    summary['method_details'] = f"DoWhy RDD (Bandwidth: {bw:.3f})"
    return summary
def estimate_effect_fallback(df: pd.DataFrame, treatment: str, outcome: str, running_variable: str, cutoff_value: float, covariates: Optional[List[str]], **kwargs) -> Dict[str, Any]:
    """Estimate the RDD jump with one interacted OLS fit inside the bandwidth.

    Regresses ``outcome`` on a constant, ``above_cutoff``, ``running_centered``
    and their product, using only rows whose running variable lies within
    ``bandwidth`` of the cutoff (both ends inclusive). The ``above_cutoff``
    coefficient, with HC1 robust standard errors, is reported as the effect.

    Args:
        df: Input data containing ``outcome`` and ``running_variable``.
        treatment: Accepted for signature parity; not used by this estimator.
        outcome: Name of the outcome column.
        running_variable: Name of the assignment (running) variable column.
        cutoff_value: Threshold; rows with running variable >= cutoff are
            coded ``above_cutoff = 1``.
        covariates: Ignored (a warning is logged when provided).
        **kwargs: Only ``bandwidth`` is read. When absent or None, 10% of the
            running variable's range is used.

    Returns:
        Dict with ``effect_estimate``, ``p_value``, ``confidence_interval``
        (two-element list), ``standard_error``, ``method_details``,
        ``formula`` and ``model_summary`` (the statsmodels summary object).

    Raises:
        ValueError: If no rows fall inside the bandwidth, a required column is
            missing, no rows survive NaN removal, or the surviving rows all
            lie on one side of the cutoff.
    """
    logger.warning("DoWhy RDD failed or not used. Falling back to simple linear regression comparison.")
    if covariates:
        logger.warning("Covariates provided but are ignored in the fallback RDD linear regression estimation.")

    bandwidth = kwargs.get('bandwidth')
    if bandwidth is None:
        range_rv = df[running_variable].max() - df[running_variable].min()
        bandwidth = 0.1 * range_rv
        logger.warning(f"No bandwidth specified for fallback, using basic default: {bandwidth:.3f}")

    # Keep only observations within [cutoff - bandwidth, cutoff + bandwidth].
    in_window = (df[running_variable] >= cutoff_value - bandwidth) & (df[running_variable] <= cutoff_value + bandwidth)
    df_bw = df[in_window].copy()
    if df_bw.empty:
        raise ValueError("No data within the specified bandwidth.")

    # The interaction term lets the slope differ on each side of the cutoff.
    df_bw['above_cutoff'] = (df_bw[running_variable] >= cutoff_value).astype(int)
    df_bw['running_centered'] = df_bw[running_variable] - cutoff_value
    df_bw['running_x_above'] = df_bw['running_centered'] * df_bw['above_cutoff']
    predictors = ['above_cutoff', 'running_centered', 'running_x_above']

    required_cols = [outcome] + predictors
    missing_cols = [col for col in required_cols if col not in df_bw.columns]
    if missing_cols:
        raise ValueError(f"Fallback RDD missing columns: {missing_cols}")

    df_analysis = df_bw[required_cols].dropna()
    if df_analysis.empty:
        raise ValueError("No data remaining after dropping NaNs for fallback RDD.")

    # A discontinuity needs data on both sides. If every row is at/above the
    # cutoff, 'above_cutoff' is a column of ones; add_constant's default
    # has_constant='skip' then adds no intercept and the 'above_cutoff'
    # coefficient becomes the intercept, silently reported as the effect.
    n_above = int(df_analysis['above_cutoff'].sum())
    n_below = len(df_analysis) - n_above
    if n_above == 0 or n_below == 0:
        raise ValueError(
            "Fallback RDD requires observations on both sides of the cutoff within the bandwidth "
            f"(below: {n_below}, at/above: {n_above})."
        )

    X = sm.add_constant(df_analysis[predictors])
    y = df_analysis[outcome]
    formula = f"{outcome} ~ {' + '.join(predictors)} + const"
    logger.info(f"Running fallback RDD regression: {formula}")

    # HC1 gives heteroskedasticity-robust standard errors.
    results = sm.OLS(y, X).fit(cov_type='HC1')

    # The 'above_cutoff' coefficient is the estimated jump at the cutoff.
    return {
        'effect_estimate': results.params['above_cutoff'],
        'p_value': results.pvalues['above_cutoff'],
        'confidence_interval': results.conf_int().loc['above_cutoff'].tolist(),
        'standard_error': results.bse['above_cutoff'],
        'method_details': f"Fallback Linear Interaction (Bandwidth: {bandwidth:.3f})",
        'formula': formula,
        'model_summary': results.summary()
    }
def _select_rdd_bandwidth(df, outcome, running_variable, cutoff_value, bandwidth):
    """Choose the bandwidth for ``effect_estimate_rdd`` and name the rule used.

    Order of preference: a positive user-supplied ``bandwidth``; the value
    from ``rdd.optimal_bandwidth``; 10% of the running variable's range.

    Returns:
        ``(bandwidth, selection_method)`` where ``selection_method`` is one of
        "user-specified", "ik_optimal (evan-magnusson/rdd)" or
        "default_10_percent_range".

    Raises:
        ValueError: If no positive bandwidth can be determined.
    """
    final_bandwidth = None
    bandwidth_selection_method = "unknown"

    if bandwidth is not None and bandwidth > 0:
        logger.info(f"Using user-specified bandwidth: {bandwidth:.4f}")
        return bandwidth, "user-specified"

    if bandwidth is not None and bandwidth <= 0:
        logger.warning(f"User-specified bandwidth {bandwidth} is not positive. Attempting IK optimal bandwidth selection.")

    # Best-effort: any failure here is logged and the 10% default is used.
    try:
        logger.info(f"Attempting IK optimal bandwidth selection using _rdd_optimal_bw_func_em for {outcome} ~ {running_variable} cut at {cutoff_value}.")
        optimal_bw_val = rdd.optimal_bandwidth(df[outcome], df[running_variable], cut=cutoff_value)
        if optimal_bw_val is not None and optimal_bw_val > 0:
            final_bandwidth = optimal_bw_val
            bandwidth_selection_method = "ik_optimal (evan-magnusson/rdd)"
            logger.info(f"IK optimal bandwidth from evan-magnusson/rdd: {final_bandwidth:.4f}")
        else:
            logger.warning(f"IK optimal bandwidth from evan-magnusson/rdd was None or non-positive: {optimal_bw_val}. Falling back to default.")
    except Exception as e:
        logger.warning(f"IK optimal bandwidth selection from evan-magnusson/rdd failed: {e}. Falling back to default.")

    if final_bandwidth is None:
        logger.info("Falling back to default bandwidth (10% of running variable range).")
        rv_range = df[running_variable].max() - df[running_variable].min()
        if rv_range > 0:
            final_bandwidth = 0.1 * rv_range
            bandwidth_selection_method = "default_10_percent_range"
            logger.info(f"Using default 10% range bandwidth: {final_bandwidth:.4f}")
        else:
            err_msg = "Running variable range is not positive. Cannot determine a default bandwidth for evan-magnusson/rdd."
            logger.error(err_msg)
            raise ValueError(err_msg)

    if final_bandwidth is None or final_bandwidth <= 0:
        raise ValueError(f"Could not determine a valid positive bandwidth for evan-magnusson/rdd. Last method: {bandwidth_selection_method}")
    return final_bandwidth, bandwidth_selection_method


def effect_estimate_rdd(
    df: pd.DataFrame,
    outcome: str,
    running_variable: str,
    cutoff_value: float,
    treatment: Optional[str] = None,  # Kept for API consistency, but unused by evan-magnusson/rdd
    covariates: Optional[List[str]] = None,
    bandwidth: Optional[float] = None,
    **kwargs
) -> Dict[str, Any]:
    """Estimate the RDD effect with the 'evan-magnusson/rdd' package.

    Truncates ``df`` to the chosen bandwidth via ``rdd.truncated_data``, builds
    the model with ``rdd.rdd``, fits it, and reports the ``TREATED`` term.

    Args:
        df: Input data containing ``outcome`` and ``running_variable``.
        outcome: Name of the outcome column.
        running_variable: Name of the running variable column.
        cutoff_value: Threshold passed to the package as ``cut``.
        treatment: Unused; accepted for API consistency (logged if given).
        covariates: Ignored (a warning is logged when provided).
        bandwidth: Positive bandwidth to use as-is. When None or non-positive,
            the package's ``optimal_bandwidth`` is tried, then 10% of the
            running variable's range.
        **kwargs: Accepted and ignored.

    Returns:
        Dict with ``effect_estimate``, ``standard_error``, ``p_value``,
        ``confidence_interval``, ``method_details``, ``bandwidth_used``,
        ``bandwidth_selection_method``, ``n_obs_in_bandwidth``, ``formula``
        and ``model_summary`` (text). The first three are None, and the
        interval is ``[None, None]``, if the fit has no ``TREATED`` term.

    Raises:
        ImportError: If the module-level import of the package failed.
        ValueError: If no positive bandwidth can be determined.
        Exception: Any error from truncation/fitting is logged and re-raised.
    """
    logger.info(f"Attempting RDD estimation using 'evan-magnusson/rdd' for outcome '{outcome}' and running variable '{running_variable}'.")

    # A non-empty message means the module-level import failed. Raising
    # ImportError (rather than hitting a NameError/AttributeError on ``rdd``
    # later) lets callers that catch ImportError take their dedicated branch.
    if _rdd_em_import_error_message:
        raise ImportError(_rdd_em_import_error_message)

    if treatment:
        logger.info(f"Treatment variable '{treatment}' provided but is not explicitly used by the evan-magnusson/rdd estimation function.")
    if covariates:
        logger.warning("Covariates provided but are ignored by this 'evan-magnusson/rdd' implementation.")

    # --- Bandwidth Selection ---
    final_bandwidth, bandwidth_selection_method = _select_rdd_bandwidth(
        df, outcome, running_variable, cutoff_value, bandwidth
    )

    # --- RDD Estimation ---
    try:
        logger.info(f"Running RDD estimation with evan-magnusson/rdd: y='{outcome}', x='{running_variable}', cut={cutoff_value}, bw={final_bandwidth:.4f}")
        data_rdd = rdd.truncated_data(df, running_variable, final_bandwidth, cut=cutoff_value)
        model = rdd.rdd(
            data_rdd,
            xname=running_variable,
            yname=outcome,
            cut=cutoff_value
        )
        # fit() yields the results object whose params/bse/pvalues are read below.
        sm_results = model.fit()
        # Route the summary through logging instead of printing to stdout.
        logger.debug("evan-magnusson/rdd regression summary:\n%s", sm_results.summary())

        # The treatment dummy is looked up under the name 'TREATED'; .get()
        # leaves the value as None if that term is absent.
        effect = sm_results.params.get('TREATED')
        std_err = sm_results.bse.get('TREATED')
        p_value = sm_results.pvalues.get('TREATED')
        conf_int_frame = sm_results.conf_int()
        conf_int = conf_int_frame.loc['TREATED'].tolist() if 'TREATED' in conf_int_frame.index else [None, None]
        n_obs = model.nobs

        # Descriptive only: the package builds the regression itself.
        formula_desc = f"Local linear RDD: {outcome} ~ TREATED + {running_variable}_centered + TREATED*{running_variable}_centered (implicit, from evan-magnusson/rdd)"
        return {
            'effect_estimate': effect,
            'standard_error': std_err,
            'p_value': p_value,
            'confidence_interval': conf_int,
            'method_details': f"RDD (evan-magnusson/rdd package, Bandwidth: {final_bandwidth:.4f})",
            'bandwidth_used': final_bandwidth,
            'bandwidth_selection_method': bandwidth_selection_method,
            'n_obs_in_bandwidth': int(n_obs) if n_obs is not None else None,
            'formula': formula_desc,
            'model_summary': sm_results.summary().as_text() if sm_results else "Summary not available."
        }
    except Exception as e:
        logger.error(f"RDD estimation using 'evan-magnusson/rdd' failed: {e}", exc_info=True)
        # Bare raise keeps the original traceback intact.
        raise
def estimate_effect(
    df: pd.DataFrame,
    treatment: str,
    outcome: str,
    running_variable: str,
    cutoff_value: float,
    covariates: Optional[List[str]] = None,
    bandwidth: Optional[float] = None,  # Optional bandwidth param
    query: Optional[str] = None,
    llm: Optional[BaseChatModel] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Estimates the causal effect using Regression Discontinuity Design.

    Tries ``effect_estimate_rdd`` (evan-magnusson/rdd package) first. If that
    raises, falls back to ``estimate_effect_fallback``. Diagnostics and an
    interpretation are then attached on a best-effort basis: a failure in
    either is logged and recorded in the result instead of being raised.

    Args:
        df: Input DataFrame.
        treatment: Name of the treatment variable; forwarded to both
            estimators.
        outcome: Name of the outcome variable.
        running_variable: Name of the variable determining treatment assignment.
        cutoff_value: The threshold value for the running variable.
        covariates: Optional list of covariate names; forwarded to the
            estimators and to the diagnostics.
        bandwidth: Optional bandwidth around the cutoff. If None, each
            estimator chooses its own.
        query: Optional user query for context (not used in this function).
        llm: Optional Language Model instance passed to the interpreter.
        **kwargs: Additional keyword arguments forwarded to both estimators.

    Returns:
        The successful estimator's result dict, extended with ``method_used``,
        ``diagnostics``, ``interpretation`` and, when the fallback was used,
        ``primary_rdd_em_error_info``.

    Raises:
        ValueError: If ``running_variable`` or ``cutoff_value`` is None, or if
            both estimators fail (chained to the fallback's exception).
    """
    if running_variable is None or cutoff_value is None:
        raise ValueError("Missing required RDD arguments: running_variable and cutoff must be provided.")

    results: Dict[str, Any] = {}
    rdd_em_estimation_error = None  # Error from effect_estimate_rdd (evan-magnusson)

    # --- Try effect_estimate_rdd (evan-magnusson/rdd) first ---
    try:
        logger.info("Attempting RDD estimation using 'effect_estimate_rdd' (evan-magnusson/rdd package).")
        results = effect_estimate_rdd(
            df,
            outcome,
            running_variable,
            cutoff_value,
            treatment=treatment,
            covariates=covariates,
            bandwidth=bandwidth,
            **kwargs
        )
        results['method_used'] = 'evan-magnusson/rdd'
        logger.info("Successfully estimated effect using 'effect_estimate_rdd'.")
    except ImportError as ie:
        logger.warning(f"'effect_estimate_rdd' could not run due to ImportError (likely evan-magnusson/rdd package not available/functional): {ie}")
        rdd_em_estimation_error = ie
    except Exception as e:
        logger.warning(f"'effect_estimate_rdd' failed during execution: {e}")
        rdd_em_estimation_error = e

    # --- Fall back to estimate_effect_fallback if the primary attempt failed ---
    if not results:
        logger.info("'effect_estimate_rdd' did not produce results. Attempting fallback using 'estimate_effect_fallback'.")
        try:
            fallback_results = estimate_effect_fallback(df, treatment, outcome, running_variable, cutoff_value, covariates, bandwidth=bandwidth, **kwargs)
            results.update(fallback_results)
            results['method_used'] = 'Fallback RDD (Linear Interaction with Robust Errors)'
            logger.info("Successfully estimated effect using 'estimate_effect_fallback'.")
        except Exception as e:
            # Both estimators failed. Report the fallback's error and chain it
            # so the underlying traceback is preserved.
            logger.error(f"Fallback RDD estimation ('estimate_effect_fallback') also failed: {e}")
            logger.error(f"All RDD estimation attempts failed. Last error (from fallback): {e}")
            raise ValueError(f"RDD estimation failed. Last error: {e}") from e

    # --- Diagnostics ---
    # NOTE(review): this passes the caller-supplied bandwidth (possibly None),
    # not the bandwidth the estimator actually used; confirm that is intended.
    try:
        diag_results = run_rdd_diagnostics(df, outcome, running_variable, cutoff_value, covariates, bandwidth)
        results['diagnostics'] = diag_results
    except Exception as diag_e:
        logger.error(f"RDD Diagnostics failed: {diag_e}")
        results['diagnostics'] = {"status": "Failed", "error": str(diag_e)}

    # --- Interpretation ---
    try:
        interpretation = interpret_rdd_results(results, results.get('diagnostics'), llm=llm)
        results['interpretation'] = interpretation
    except Exception as interp_e:
        logger.error(f"RDD Interpretation failed: {interp_e}")
        results['interpretation'] = "Interpretation failed."

    # Record why the primary estimator was skipped when the fallback was used.
    if rdd_em_estimation_error is not None and results.get('method_used', '').startswith('Fallback'):
        results['primary_rdd_em_error_info'] = str(rdd_em_estimation_error)
    return results