Spaces:
Running
Running
import pytest | |
import pandas as pd | |
import numpy as np | |
from auto_causal.methods.regression_discontinuity.diagnostics import run_rdd_diagnostics | |
# --- Fixture for RDD data --- | |
def sample_rdd_data(): | |
"""Generates synthetic data suitable for RDD testing.""" | |
np.random.seed(123) | |
n_samples = 200 | |
cutoff = 50.0 | |
treatment_effect = 10.0 | |
running_var = np.random.uniform(cutoff - 20, cutoff + 20, n_samples) | |
treatment = (running_var >= cutoff).astype(int) | |
# Covariate correlated with running variable (potential imbalance) | |
covariate1 = 0.5 * running_var + np.random.normal(0, 5, n_samples) | |
# Covariate uncorrelated (should be balanced) | |
covariate2 = np.random.normal(10, 2, n_samples) | |
error = np.random.normal(0, 5, n_samples) | |
outcome = (10 + 0.8 * running_var + | |
treatment_effect * treatment + | |
1.2 * treatment * (running_var - cutoff) + | |
2.0 * covariate1 + 1.0 * covariate2 + error) | |
df = pd.DataFrame({ | |
'outcome': outcome, | |
'treatment_indicator': treatment, | |
'running_var': running_var, | |
'covariate1': covariate1, | |
'covariate2': covariate2 | |
}) | |
return df | |
# --- Test Cases --- | |
def test_run_rdd_diagnostics_success(sample_rdd_data): | |
"""Test the diagnostics function with covariates.""" | |
covariates = ['covariate1', 'covariate2'] | |
results = run_rdd_diagnostics( | |
sample_rdd_data, | |
'outcome', | |
'running_var', | |
cutoff=50.0, | |
covariates=covariates, | |
bandwidth=10.0 # Use a reasonable bandwidth | |
) | |
assert results["status"] == "Success (Partial Implementation)" | |
assert "details" in results | |
details = results["details"] | |
assert "covariate_balance" in details | |
balance = details['covariate_balance'] | |
assert isinstance(balance, dict) | |
assert 'covariate1' in balance | |
assert 'covariate2' in balance | |
# Check structure of balance results | |
assert 't_statistic' in balance['covariate1'] | |
assert 'p_value' in balance['covariate1'] | |
assert 'balanced' in balance['covariate1'] | |
assert 't_statistic' in balance['covariate2'] | |
assert 'p_value' in balance['covariate2'] | |
assert 'balanced' in balance['covariate2'] | |
# Check expected balance (covariate1 likely unbalanced, covariate2 likely balanced) | |
# Due to random noise, these might occasionally fail, but should usually hold | |
assert balance['covariate1']['balanced'].startswith("No") | |
assert balance['covariate2']['balanced'] == "Yes" | |
# Check placeholders | |
assert details['continuity_density_test'] == "Not Implemented (Requires specialized libraries like rdd)" | |
assert details['visual_inspection'] == "Recommended (Plot outcome vs running variable with fits)" | |
def test_run_rdd_diagnostics_no_covariates(sample_rdd_data): | |
"""Test diagnostics when no covariates are provided.""" | |
results = run_rdd_diagnostics( | |
sample_rdd_data, 'outcome', 'running_var', cutoff=50.0, covariates=None, bandwidth=10.0 | |
) | |
assert results["status"] == "Success (Partial Implementation)" | |
assert results["details"]['covariate_balance'] == "No covariates provided to check." | |
def test_run_rdd_diagnostics_small_bandwidth(sample_rdd_data): | |
"""Test diagnostics handles cases with insufficient data in bandwidth.""" | |
# Bandwidth so small it likely excludes one side | |
results = run_rdd_diagnostics( | |
sample_rdd_data, 'outcome', 'running_var', cutoff=50.0, covariates=['covariate1'], bandwidth=0.1 | |
) | |
assert results["status"] == "Skipped" | |
assert "Insufficient data near cutoff" in results["reason"] | |
def test_run_rdd_diagnostics_missing_covariate(sample_rdd_data): | |
"""Test diagnostics handles missing covariate columns gracefully.""" | |
results = run_rdd_diagnostics( | |
sample_rdd_data, 'outcome', 'running_var', cutoff=50.0, covariates=['covariate1', 'missing_cov'], bandwidth=10.0 | |
) | |
assert results["status"] == "Success (Partial Implementation)" | |
balance = results["details"]['covariate_balance'] | |
assert balance['missing_cov']['status'] == "Column Not Found" | |
assert 't_statistic' in balance['covariate1'] # Check other covariate was still processed | |