File size: 5,598 Bytes
1721aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import pytest

# Import the function to test and constants
from auto_causal.components.decision_tree import (
    select_method,
    METHOD_ASSUMPTIONS, # Import assumptions map
    REGRESSION_ADJUSTMENT, LINEAR_REGRESSION, LINEAR_REGRESSION_COV,
    DIFF_IN_DIFF, REGRESSION_DISCONTINUITY, PROPENSITY_SCORE_MATCHING,
    INSTRUMENTAL_VARIABLE
)

# --- Test Data Fixtures (Optional, but good practice) ---
# Using simple dicts for now

@pytest.fixture
def base_variables():
    """Return a fresh baseline mapping of variable roles for a test.

    Only the treatment, outcome and covariates entries carry values; every
    design-specific role (time, group, instrument, running variable, cutoff)
    starts as ``None`` so each test can switch on exactly the ones it needs.
    """
    # Roles that individual tests opt into; all default to None.
    unset_roles = (
        "time_variable",
        "group_variable",
        "instrument_variable",
        "running_variable",
        "cutoff_value",
    )
    return {
        "treatment_variable": "T",
        "outcome_variable": "Y",
        "covariates": ["X1", "X2"],
        **dict.fromkeys(unset_roles),
    }

@pytest.fixture
def base_dataset_analysis():
    """Return a fresh baseline dataset-analysis mapping with no temporal structure.

    Tests that need further keys (e.g. potential_instruments) or a different
    ``temporal_structure`` flag are expected to set them on their own copy.
    """
    return dict(temporal_structure=False)

# --- Test Cases ---

def test_no_covariates(base_dataset_analysis, base_variables):
    """Test: observational data with an empty covariate list -> Regression Adjustment.

    Also checks that the justification mentions "no covariates" and that the
    returned assumptions match the METHOD_ASSUMPTIONS entry for the method.
    """
    # Same roles as the baseline, but with the covariate list emptied.
    variables = {**base_variables, "covariates": []}

    outcome = select_method(base_dataset_analysis, variables, is_rct=False)

    assert outcome["selected_method"] == REGRESSION_ADJUSTMENT
    justification = outcome["method_justification"].lower()
    assert "no covariates" in justification
    expected_assumptions = METHOD_ASSUMPTIONS[REGRESSION_ADJUSTMENT]
    assert outcome["method_assumptions"] == expected_assumptions

def test_rct_no_covariates(base_dataset_analysis, base_variables):
    """Test: RCT flag set but no covariates -> still Regression Adjustment.

    This exercises the ``is_rct=True`` call with an explicitly empty covariate
    list. Per the original author's note, select_method checks for missing
    covariates before it looks at the RCT flag, so plain Linear Regression is
    never reached on this path (select_method itself is not visible here).
    """
    # Explicitly empty covariates on top of the baseline roles.
    variables = {**base_variables, "covariates": []}

    outcome = select_method(base_dataset_analysis, variables, is_rct=True)

    # The no-covariates outcome is expected to win over the RCT branch.
    assert outcome["selected_method"] == REGRESSION_ADJUSTMENT

def test_rct_with_covariates(base_dataset_analysis, base_variables):
    """Test: RCT with the baseline covariates -> Linear Regression with Covariates.

    The justification must mention both the RCT setting and that covariates
    are provided, and the assumptions must match the METHOD_ASSUMPTIONS entry.
    """
    variables = dict(base_variables)  # baseline already carries covariates

    outcome = select_method(base_dataset_analysis, variables, is_rct=True)

    assert outcome["selected_method"] == LINEAR_REGRESSION_COV
    justification = outcome["method_justification"].lower()
    for expected_phrase in ("rct", "covariates are provided"):
        assert expected_phrase in justification
    expected_assumptions = METHOD_ASSUMPTIONS[LINEAR_REGRESSION_COV]
    assert outcome["method_assumptions"] == expected_assumptions

def test_observational_temporal(base_dataset_analysis, base_variables):
    """Test: observational data flagged with temporal structure -> Difference-in-Differences.

    A time variable and a group variable are supplied alongside the
    ``temporal_structure`` flag; the justification must mention
    "temporal structure".
    """
    variables = {
        **base_variables,
        "time_variable": "time",
        "group_variable": "unit",  # Often needed for DiD context
    }
    dataset_analysis = {**base_dataset_analysis, "temporal_structure": True}

    outcome = select_method(dataset_analysis, variables, is_rct=False)

    assert outcome["selected_method"] == DIFF_IN_DIFF
    justification = outcome["method_justification"].lower()
    assert "temporal structure" in justification
    expected_assumptions = METHOD_ASSUMPTIONS[DIFF_IN_DIFF]
    assert outcome["method_assumptions"] == expected_assumptions

def test_observational_rdd(base_dataset_analysis, base_variables):
    """Test: observational data with a running variable and cutoff -> Regression Discontinuity.

    The justification must mention both the running variable and the cutoff.
    """
    variables = {
        **base_variables,
        "running_variable": "score",
        "cutoff_value": 50,
    }

    outcome = select_method(base_dataset_analysis, variables, is_rct=False)

    assert outcome["selected_method"] == REGRESSION_DISCONTINUITY
    justification = outcome["method_justification"].lower()
    for expected_phrase in ("running variable", "cutoff"):
        assert expected_phrase in justification
    expected_assumptions = METHOD_ASSUMPTIONS[REGRESSION_DISCONTINUITY]
    assert outcome["method_assumptions"] == expected_assumptions

def test_observational_iv(base_dataset_analysis, base_variables):
    """Test: observational data with an instrument supplied -> Instrumental Variable.

    The justification must mention "instrumental variable".
    """
    variables = {**base_variables, "instrument_variable": "Z"}

    outcome = select_method(base_dataset_analysis, variables, is_rct=False)

    assert outcome["selected_method"] == INSTRUMENTAL_VARIABLE
    justification = outcome["method_justification"].lower()
    assert "instrumental variable" in justification
    expected_assumptions = METHOD_ASSUMPTIONS[INSTRUMENTAL_VARIABLE]
    assert outcome["method_assumptions"] == expected_assumptions

def test_observational_confounders_default_psm(base_dataset_analysis, base_variables):
    """Test: observational data with covariates and no special design -> PSM (default).

    The design-specific inputs are re-stated as unset so the test does not
    silently depend on the fixture defaults. The justification must mention
    observed confounders and that the method was selected as the default.
    """
    # Baseline covariates are kept; every other design trigger is pinned off.
    variables = dict(base_variables)
    variables.update(
        time_variable=None,
        running_variable=None,
        instrument_variable=None,
    )
    dataset_analysis = {**base_dataset_analysis, "temporal_structure": False}

    outcome = select_method(dataset_analysis, variables, is_rct=False)

    assert outcome["selected_method"] == PROPENSITY_SCORE_MATCHING
    justification = outcome["method_justification"].lower()
    for expected_phrase in ("observed confounders", "selected as the default method"):
        assert expected_phrase in justification
    expected_assumptions = METHOD_ASSUMPTIONS[PROPENSITY_SCORE_MATCHING]
    assert outcome["method_assumptions"] == expected_assumptions

# Note: A specific test for the final fallback (Reg Adjustment for observational 
# with covariates but somehow no other method fits) might be hard to trigger 
# given the current logic defaults to PSM if covariates exist and IV/RDD/DiD don't apply.
# The initial 'no covariates' test effectively covers the main Reg Adjustment path.

if __name__ == '__main__':
    # Support running this module directly with `python <this file>`.
    import sys

    # Target this file explicitly: a bare pytest.main() takes its arguments
    # from the command line and is not limited to this module. Extra
    # command-line arguments (e.g. -v, -k) are still forwarded, and pytest's
    # exit code becomes the process exit status so failures are visible to
    # shells and CI.
    sys.exit(pytest.main([__file__, *sys.argv[1:]]))