# Page header captured together with this file (not part of the test code):
# Spaces status: Running; file size: 4,267 bytes; revision 1721aea.
import pytest
import pandas as pd
import numpy as np
from auto_causal.methods.regression_discontinuity.diagnostics import run_rdd_diagnostics
# --- Fixture for RDD data ---
@pytest.fixture
def sample_rdd_data():
    """Generate a synthetic sharp-RDD dataset for the diagnostics tests.

    The running variable is drawn uniformly from [30, 70) around a cutoff of
    50.0. Treatment is assigned as ``running_var >= 50.0``. The outcome has a
    jump of 10.0 at the cutoff plus an extra slope on the treated side.

    A local ``np.random.RandomState(123)`` is used instead of
    ``np.random.seed(123)``. It yields exactly the same sequence of draws, so
    the seed-dependent balance assertions elsewhere are unaffected. Unlike
    ``np.random.seed``, it does not reseed NumPy's global generator for every
    other test in the session.

    Returns:
        pd.DataFrame with 200 rows and columns ``outcome``,
        ``treatment_indicator``, ``running_var``, ``covariate1`` and
        ``covariate2``.
    """
    rng = np.random.RandomState(123)
    n_samples = 200
    cutoff = 50.0
    treatment_effect = 10.0

    running_var = rng.uniform(cutoff - 20, cutoff + 20, n_samples)
    treatment = (running_var >= cutoff).astype(int)
    # Covariate built from the running variable, so its mean differs across
    # the cutoff (potential imbalance).
    covariate1 = 0.5 * running_var + rng.normal(0, 5, n_samples)
    # Covariate drawn independently of the running variable (should be balanced).
    covariate2 = rng.normal(10, 2, n_samples)
    error = rng.normal(0, 5, n_samples)

    outcome = (
        10
        + 0.8 * running_var
        + treatment_effect * treatment
        # Additional slope applied only to treated units (right of the cutoff).
        + 1.2 * treatment * (running_var - cutoff)
        + 2.0 * covariate1
        + 1.0 * covariate2
        + error
    )

    return pd.DataFrame({
        'outcome': outcome,
        'treatment_indicator': treatment,
        'running_var': running_var,
        'covariate1': covariate1,
        'covariate2': covariate2,
    })
# --- Test Cases ---
def test_run_rdd_diagnostics_success(sample_rdd_data):
    """Diagnostics with two covariates return per-covariate balance results."""
    covariates = ['covariate1', 'covariate2']
    results = run_rdd_diagnostics(
        sample_rdd_data,
        'outcome',
        'running_var',
        cutoff=50.0,
        covariates=covariates,
        bandwidth=10.0,  # +/-10 around the cutoff keeps data on both sides
    )

    assert results["status"] == "Success (Partial Implementation)"
    assert "details" in results
    details = results["details"]

    assert "covariate_balance" in details
    balance = details['covariate_balance']
    assert isinstance(balance, dict)

    # Every requested covariate must have an entry exposing the same fields.
    expected_fields = ('t_statistic', 'p_value', 'balanced')
    for name in covariates:
        assert name in balance
        for field in expected_fields:
            assert field in balance[name]

    # These verdicts depend on the fixture's seeded sample. covariate1 is
    # derived from the running variable and is expected to be flagged;
    # covariate2 is independent noise. A different seed or sample size could
    # flip either result.
    assert balance['covariate1']['balanced'].startswith("No")
    assert balance['covariate2']['balanced'] == "Yes"

    # Placeholder entries are returned as fixed strings.
    assert details['continuity_density_test'] == "Not Implemented (Requires specialized libraries like rdd)"
    assert details['visual_inspection'] == "Recommended (Plot outcome vs running variable with fits)"
def test_run_rdd_diagnostics_no_covariates(sample_rdd_data):
    """Without covariates, the balance entry is an explanatory message."""
    results = run_rdd_diagnostics(
        sample_rdd_data,
        'outcome',
        'running_var',
        cutoff=50.0,
        covariates=None,
        bandwidth=10.0,
    )

    assert results["status"] == "Success (Partial Implementation)"

    balance_note = results["details"]['covariate_balance']
    assert balance_note == "No covariates provided to check."
def test_run_rdd_diagnostics_small_bandwidth(sample_rdd_data):
    """A tiny bandwidth leaves too little data, so diagnostics are skipped."""
    # With only +/-0.1 around the cutoff, the seeded sample is expected to
    # leave at least one side (nearly) empty.
    tiny_bandwidth = 0.1
    results = run_rdd_diagnostics(
        sample_rdd_data,
        'outcome',
        'running_var',
        cutoff=50.0,
        covariates=['covariate1'],
        bandwidth=tiny_bandwidth,
    )

    assert results["status"] == "Skipped"
    assert "Insufficient data near cutoff" in results["reason"]
def test_run_rdd_diagnostics_missing_covariate(sample_rdd_data):
    """An unknown covariate column is flagged without blocking the others."""
    requested = ['covariate1', 'missing_cov']
    results = run_rdd_diagnostics(
        sample_rdd_data,
        'outcome',
        'running_var',
        cutoff=50.0,
        covariates=requested,
        bandwidth=10.0,
    )

    assert results["status"] == "Success (Partial Implementation)"

    balance = results["details"]['covariate_balance']
    # The absent column gets a status marker instead of test statistics...
    assert balance['missing_cov']['status'] == "Column Not Found"
    # ...while the covariate that does exist is still evaluated.
    assert 't_statistic' in balance['covariate1']
|