File size: 16,557 Bytes
1721aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
"""
decision tree component for selecting causal inference methods

this module implements the decision tree logic to select the most appropriate
causal inference method based on dataset characteristics and available variables
"""

import logging
from typing import Dict, List, Any, Optional
import pandas as pd

# Canonical method identifiers; these strings are the values returned in
# ``selected_method`` and the keys of METHOD_ASSUMPTIONS.
# NOTE: BACKDOOR_ADJUSTMENT, CORRELATION_ANALYSIS and GENERALIZED_PROPENSITY_SCORE
# are defined here but are never added as candidates by select_method.
BACKDOOR_ADJUSTMENT = "backdoor_adjustment"
LINEAR_REGRESSION = "linear_regression"
DIFF_IN_MEANS = "diff_in_means"
DIFF_IN_DIFF = "difference_in_differences"
REGRESSION_DISCONTINUITY = "regression_discontinuity_design"
PROPENSITY_SCORE_MATCHING = "propensity_score_matching"
INSTRUMENTAL_VARIABLE = "instrumental_variable"
CORRELATION_ANALYSIS = "correlation_analysis"
PROPENSITY_SCORE_WEIGHTING = "propensity_score_weighting"
GENERALIZED_PROPENSITY_SCORE = "generalized_propensity_score"
FRONTDOOR_ADJUSTMENT = "frontdoor_adjustment"


logger = logging.getLogger(__name__)

# Identification / estimation assumptions reported alongside each method.
METHOD_ASSUMPTIONS = {
    BACKDOOR_ADJUSTMENT: [
        "no unmeasured confounders (conditional ignorability given covariates)",
        "correct model specification for outcome conditional on treatment and covariates",
        "positivity/overlap (for all covariate values, units could potentially receive either treatment level)"
    ],
    LINEAR_REGRESSION: [
        "linear relationship between treatment, covariates, and outcome",
        "no unmeasured confounders (if observational)",
        "correct model specification",
        "homoscedasticity of errors",
        "normally distributed errors (for inference)"
    ],
    DIFF_IN_MEANS: [
        "treatment is randomly assigned (or as-if random)",
        "no spillover effects",
        "stable unit treatment value assumption (SUTVA)"
    ],
    DIFF_IN_DIFF: [
        "parallel trends between treatment and control groups before treatment",
        "no spillover effects between groups",
        "no anticipation effects before treatment",
        "stable composition of treatment and control groups",
        "treatment timing is exogenous"
    ],
    REGRESSION_DISCONTINUITY: [
        "units cannot precisely manipulate the running variable around the cutoff",
        "continuity of conditional expectation functions of potential outcomes at the cutoff",
        "no other changes occurring precisely at the cutoff"
    ],
    PROPENSITY_SCORE_MATCHING: [
        "no unmeasured confounders (conditional ignorability)",
        "sufficient overlap (common support) between treatment and control groups",
        "correct propensity score model specification"
    ],
    INSTRUMENTAL_VARIABLE: [
        "instrument is correlated with treatment (relevance)",
        "instrument affects outcome only through treatment (exclusion restriction)",
        "instrument is independent of unmeasured confounders (exogeneity/independence)"
    ],
    CORRELATION_ANALYSIS: [
        "data represents a sample from the population of interest",
        "variables are measured appropriately"
    ],
    PROPENSITY_SCORE_WEIGHTING: [
        "no unmeasured confounders (conditional ignorability)",
        "sufficient overlap (common support) between treatment and control groups",
        "correct propensity score model specification",
        "weights correctly specified (e.g., ATE, ATT)"
    ],
    GENERALIZED_PROPENSITY_SCORE: [
        "conditional mean independence",
        "positivity/common support for GPS",
        "correct specification of the GPS model",
        "correct specification of the outcome model",
        "no unmeasured confounders affecting both treatment and outcome, given X",
        "treatment variable is continuous"
    ],
    # The last three entries spell out the three conditions of Pearl's
    # front-door criterion, which the first two only partially cover.
    FRONTDOOR_ADJUSTMENT: [
        "mediator is affected by treatment and affects outcome",
        "mediator is not affected by any confounders of the treatment-outcome relationship",
        "mediator intercepts all directed paths from treatment to outcome (no direct effect bypassing it)",
        "no unblocked back-door path from treatment to mediator",
        "all back-door paths from mediator to outcome are blocked by the treatment"
    ]
}


def select_method(dataset_properties: Dict[str, Any], excluded_methods: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Select a causal inference method from dataset properties.

    Every applicable method is first collected as a candidate together with a
    priority index (its position in the priority list of the branch that added
    it). Excluded methods are then removed and the candidate with the smallest
    index wins; ties keep insertion order because the sort is stable.

    Args:
        dataset_properties: Dataset description. ``treatment_variable`` and
            ``outcome_variable`` are required; ``instrument_variable``,
            ``running_variable``, ``cutoff_value``, ``time_variable``,
            ``is_rct``, ``has_temporal_structure``, ``frontdoor_criterion``,
            ``covariate_overlap_score``, ``covariates`` and
            ``treatment_variable_type`` (default ``"binary"``) are optional.
        excluded_methods: Method names that must not be selected.

    Returns:
        Dict with ``selected_method``, ``method_justification``,
        ``method_assumptions`` (a copy of the METHOD_ASSUMPTIONS entry),
        ``alternatives`` (remaining candidates in priority order) and
        ``excluded_methods`` (sorted list).

    Raises:
        ValueError: If the treatment or outcome variable is missing.
        RuntimeError: If every candidate is excluded and no fallback method
            whose required inputs are present remains.
    """
    excluded_methods = set(excluded_methods or [])
    logger.info("Excluded methods: %s", sorted(excluded_methods))

    treatment = dataset_properties.get("treatment_variable")
    outcome = dataset_properties.get("outcome_variable")
    if not treatment or not outcome:
        raise ValueError("Both treatment and outcome variables must be specified")

    instrument_var = dataset_properties.get("instrument_variable")
    running_var = dataset_properties.get("running_variable")
    cutoff_val = dataset_properties.get("cutoff_value")
    time_var = dataset_properties.get("time_variable")
    is_rct = dataset_properties.get("is_rct", False)
    has_temporal = dataset_properties.get("has_temporal_structure", False)
    frontdoor = dataset_properties.get("frontdoor_criterion", False)
    covariate_overlap_result = dataset_properties.get("covariate_overlap_score")
    # Keys may be present with an explicit None value, which a dict.get
    # default does not cover; normalize them here.
    covariates = dataset_properties.get("covariates") or []
    treatment_variable_type = dataset_properties.get("treatment_variable_type") or "binary"
    # A variable identical to the treatment cannot act as its own instrument.
    has_valid_instrument = bool(instrument_var) and instrument_var != treatment

    candidates: List[tuple] = []  # (method, priority_index) in insertion order
    justifications: Dict[str, str] = {}
    assumptions: Dict[str, List[str]] = {}

    def add(method: str, justification: str, prio_order: List[str]) -> None:
        """Register *method* once as a candidate, ranked by its index in *prio_order*."""
        if method in justifications:  # first registration wins
            return
        justifications[method] = justification
        # Copy so callers mutating the result cannot alter METHOD_ASSUMPTIONS.
        assumptions[method] = list(METHOD_ASSUMPTIONS[method])
        idx = prio_order.index(method) if method in prio_order else 10**6
        candidates.append((method, idx))

    def is_usable_fallback(method: str) -> bool:
        """Return True if the inputs *method* needs are present in the properties."""
        is_binary = treatment_variable_type == "binary"
        if method == LINEAR_REGRESSION:
            return True
        if method == DIFF_IN_MEANS:
            return is_binary
        if method in (PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING):
            return is_binary and bool(covariates)
        if method == DIFF_IN_DIFF:
            return bool(has_temporal and time_var)
        if method == REGRESSION_DISCONTINUITY:
            return bool(running_var) and cutoff_val is not None
        if method == INSTRUMENTAL_VARIABLE:
            return has_valid_instrument
        if method == FRONTDOOR_ADJUSTMENT:
            return bool(frontdoor)
        return False

    # ----- Build candidate set (no returns here) -----

    # RCT branch
    if is_rct:
        logger.info("Dataset is from a randomized controlled trial (RCT)")
        rct_priority = [INSTRUMENTAL_VARIABLE, LINEAR_REGRESSION, DIFF_IN_MEANS]

        if has_valid_instrument:
            add(INSTRUMENTAL_VARIABLE,
                f"RCT encouragement: instrument '{instrument_var}' differs from treatment '{treatment}'.",
                rct_priority)

        if covariates:
            add(LINEAR_REGRESSION,
                "RCT with covariates—use OLS for precision.",
                rct_priority)
        else:
            add(DIFF_IN_MEANS,
                "Pure RCT without covariates—difference-in-means.",
                rct_priority)

    # Observational branch (also runs for RCTs, supplying alternatives)
    obs_priority_binary = [
        INSTRUMENTAL_VARIABLE,
        PROPENSITY_SCORE_MATCHING,
        PROPENSITY_SCORE_WEIGHTING,
        FRONTDOOR_ADJUSTMENT,
        LINEAR_REGRESSION,
    ]
    obs_priority_nonbinary = [
        INSTRUMENTAL_VARIABLE,
        FRONTDOOR_ADJUSTMENT,
        LINEAR_REGRESSION,
    ]

    # Structural designs get priority index 0 within their own one-item lists.
    if has_temporal and time_var:
        add(DIFF_IN_DIFF,
            f"Temporal structure via '{time_var}'—consider Difference-in-Differences (assumes parallel trends).",
            [DIFF_IN_DIFF])

    if running_var and cutoff_val is not None:
        add(REGRESSION_DISCONTINUITY,
            f"Running variable '{running_var}' with cutoff {cutoff_val}—consider RDD.",
            [REGRESSION_DISCONTINUITY])

    if treatment_variable_type == "binary":
        if has_valid_instrument:
            add(INSTRUMENTAL_VARIABLE,
                f"Instrumental variable '{instrument_var}' available.",
                obs_priority_binary)

        # Propensity score methods only make sense with observed covariates.
        if covariates:
            # Overlap below 0.1 selects weighting; unknown (None) or higher selects matching.
            if covariate_overlap_result is not None and covariate_overlap_result < 0.1:
                ps_method = PROPENSITY_SCORE_WEIGHTING
            else:
                ps_method = PROPENSITY_SCORE_MATCHING
            add(ps_method,
                "Covariates observed; PS method chosen based on overlap.",
                obs_priority_binary)

        if frontdoor:
            add(FRONTDOOR_ADJUSTMENT,
                "Front-door criterion satisfied.",
                obs_priority_binary)

        add(LINEAR_REGRESSION,
            "OLS as a fallback specification.",
            obs_priority_binary)

    else:
        logger.info("Non-binary treatment variable detected: %s", treatment_variable_type)
        if has_valid_instrument:
            add(INSTRUMENTAL_VARIABLE,
                f"Instrument '{instrument_var}' candidate for non-binary treatment.",
                obs_priority_nonbinary)
        if frontdoor:
            add(FRONTDOOR_ADJUSTMENT,
                "Front-door criterion satisfied.",
                obs_priority_nonbinary)
        add(LINEAR_REGRESSION,
            "Fallback for non-binary treatment without stronger identification.",
            obs_priority_nonbinary)

    # ----- Centralized exclusion handling -----
    filtered = [(m, p) for (m, p) in candidates if m not in excluded_methods]

    if not filtered:
        logger.warning("All candidates excluded. Candidates were: %s. Excluded: %s",
                       [m for m, _ in candidates], sorted(excluded_methods))
        fallback_order = [
            LINEAR_REGRESSION,
            DIFF_IN_MEANS,
            PROPENSITY_SCORE_MATCHING,
            PROPENSITY_SCORE_WEIGHTING,
            DIFF_IN_DIFF,
            REGRESSION_DISCONTINUITY,
            INSTRUMENTAL_VARIABLE,
            FRONTDOOR_ADJUSTMENT,
        ]
        # Every candidate is excluded at this point, so the fallback must come
        # from outside the candidate set; require its inputs to be present.
        fallback = next((m for m in fallback_order
                         if m not in excluded_methods and is_usable_fallback(m)), None)
        if fallback is None:
            raise RuntimeError("No viable method remains after exclusions.")
        selected_method = fallback
        alternatives: List[str] = []
        justifications.setdefault(selected_method, "Fallback after exclusions.")
        assumptions.setdefault(selected_method, list(METHOD_ASSUMPTIONS[selected_method]))
    else:
        # Smallest priority index wins; sort is stable so ties keep insertion order.
        filtered.sort(key=lambda item: item[1])
        selected_method = filtered[0][0]
        # add() never registers a method twice, so the rest are distinct alternatives.
        alternatives = [m for (m, _) in filtered[1:]]

    logger.info("Selected method: %s; alternatives: %s", selected_method, alternatives)

    return {
        "selected_method": selected_method,
        "method_justification": justifications[selected_method],
        "method_assumptions": assumptions[selected_method],
        "alternatives": alternatives,
        "excluded_methods": sorted(excluded_methods),
    }



def rule_based_select_method(dataset_analysis, variables, is_rct, llm, dataset_description, original_query, excluded_methods=None):
    """
    Select a causal method with the rule-based decision tree.

    Builds the ``dataset_properties`` mapping expected by ``select_method``
    from the variable-identification and dataset-analysis results, then
    delegates to it.

    Args:
      dataset_analysis (Dict): results of dataset analysis; only
        ``["temporal_structure"]["has_temporal_structure"]`` is read
      variables (Dict): dictionary of variable names and types
      is_rct (bool): whether the dataset is from a randomized controlled trial
      llm (BaseChatModel): not used by this function; kept for interface compatibility
      dataset_description (str): not used by this function; kept for interface compatibility
      original_query (str): not used by this function; kept for interface compatibility
      excluded_methods (List[str], optional): list of methods to exclude from selection

    Returns:
      Dict: the result dictionary produced by ``select_method``.
    """
    logger.info("Running rule-based method selection")

    # "temporal_structure" may be missing; calling .get on the old False default crashed.
    temporal_structure = dataset_analysis.get("temporal_structure")
    has_temporal_structure = (
        temporal_structure.get("has_temporal_structure", False)
        if isinstance(temporal_structure, dict)
        else False
    )

    properties = {
        "treatment_variable": variables.get("treatment_variable"),
        "instrument_variable": variables.get("instrument_variable"),
        "covariates": variables.get("covariates") or [],
        "outcome_variable": variables.get("outcome_variable"),
        "time_variable": variables.get("time_variable"),
        "running_variable": variables.get("running_variable"),
        "treatment_variable_type": variables.get("treatment_variable_type") or "binary",
        "has_temporal_structure": has_temporal_structure,
        "frontdoor_criterion": variables.get("frontdoor_criterion", False),
        "cutoff_value": variables.get("cutoff_value"),
        # Default None, not 0: select_method treats None as "overlap unknown",
        # whereas 0 falls below its 0.1 threshold and would force weighting.
        "covariate_overlap_score": variables.get("covariate_overlap_result"),
        "is_rct": is_rct,
    }
    logger.info("Dataset properties for method selection: %s", properties)

    return select_method(properties, excluded_methods)



class DecisionTreeEngine:
    """
    Engine for applying decision trees to select appropriate causal methods.

    This class wraps the functional decision tree implementation (the
    module-level ``select_method``) to provide an object-oriented interface
    for method selection. It also attaches a human-readable ``decision_path``
    to the result.
    """

    # Explanation steps shown for each selected method.
    _DECISION_PATHS: Dict[str, List[str]] = {
        "linear_regression": ["Check if randomized experiment",
                              "Data appears to be from a randomized experiment with covariates"],
        "propensity_score_matching": ["Check if randomized experiment", "Data is observational",
                                      "Check for sufficient covariate overlap", "Sufficient overlap exists"],
        "propensity_score_weighting": ["Check if randomized experiment", "Data is observational",
                                       "Check for sufficient covariate overlap", "Low overlap—weighting preferred"],
        "backdoor_adjustment": ["Check if randomized experiment", "Data is observational",
                                "Check for sufficient covariate overlap", "Adjusting for covariates"],
        "instrumental_variable": ["Check if randomized experiment", "Data is observational",
                                  "Check for instrumental variables", "Instrument is available"],
        "regression_discontinuity_design": ["Check if randomized experiment", "Data is observational",
                                            "Check for discontinuity", "Discontinuity exists"],
        "difference_in_differences": ["Check if randomized experiment", "Data is observational",
                                      "Check for temporal structure", "Panel data structure exists"],
        "frontdoor_adjustment": ["Check if randomized experiment", "Data is observational",
                                 "Check front-door criterion", "Front-door path identified"],
        "diff_in_means": ["Check if randomized experiment", "Pure RCT without covariates"],
    }
    # Linear regression is also the observational fallback in select_method.
    _OBSERVATIONAL_LINEAR_REGRESSION_PATH: List[str] = [
        "Check if randomized experiment", "Data is observational",
        "No stronger identification strategy available", "OLS as a fallback specification",
    ]

    def __init__(self, verbose=False):
        # When True, select_method prints its inputs and the chosen method.
        self.verbose = verbose

    def select_method(self, df: pd.DataFrame, treatment: str, outcome: str, covariates: List[str],
                      dataset_analysis: Dict[str, Any], query_details: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply decision tree to select appropriate causal method.

        Args:
            df: the dataset (not inspected by this method).
            treatment: treatment variable name.
            outcome: outcome variable name.
            covariates: covariate names.
            dataset_analysis: analysis results; only
                ``["temporal_structure"]["has_temporal_structure"]`` is read.
            query_details: optional design hints (time/group/instrument/running
                variables, cutoff, is_rct, front-door flag, overlap result,
                treatment type).

        Returns:
            The module-level ``select_method`` result plus a ``decision_path`` list.
        """
        if self.verbose:
            print(f"Applying decision tree for treatment: {treatment}, outcome: {outcome}")
            print(f"Available covariates: {covariates}")

        # Default to "binary" like the rest of the module; a None value would
        # otherwise route binary treatments down the non-binary path.
        treatment_variable_type = query_details.get("treatment_variable_type") or "binary"
        covariate_overlap_result = query_details.get("covariate_overlap_result")
        is_rct = bool(query_details.get("is_rct", False))
        # "temporal_structure" may be missing; calling .get on a False default crashed.
        temporal_structure = dataset_analysis.get("temporal_structure")
        has_temporal_structure = (
            temporal_structure.get("has_temporal_structure", False)
            if isinstance(temporal_structure, dict)
            else False
        )

        info = {"treatment_variable": treatment, "outcome_variable": outcome,
                "covariates": covariates, "time_variable": query_details.get("time_variable"),
                "group_variable": query_details.get("group_variable"),
                "instrument_variable": query_details.get("instrument_variable"),
                "running_variable": query_details.get("running_variable"),
                "cutoff_value": query_details.get("cutoff_value"),
                "is_rct": is_rct,
                "has_temporal_structure": has_temporal_structure,
                "frontdoor_criterion": query_details.get("frontdoor_criterion", False),
                "covariate_overlap_score": covariate_overlap_result,
                "treatment_variable_type": treatment_variable_type}

        result = select_method(info)

        if self.verbose:
            print(f"Selected method: {result['selected_method']}")
            print(f"Justification: {result['method_justification']}")

        result["decision_path"] = self._get_decision_path(result["selected_method"], is_rct=is_rct)
        return result

    def _get_decision_path(self, method: str, is_rct: Optional[bool] = None) -> List[str]:
        """
        Return the explanation steps for *method* as a new list.

        Args:
            method: selected method name.
            is_rct: whether the data is from an RCT. ``None`` (default) keeps
                the generic path. ``False`` with linear regression yields the
                observational-fallback path. ``True`` replaces a
                "Data is observational" step.

        Returns:
            The list of steps; unknown methods give ``["Default method selection"]``.
        """
        if method == "linear_regression" and is_rct is not None and not is_rct:
            return list(self._OBSERVATIONAL_LINEAR_REGRESSION_PATH)
        # Copy so callers cannot mutate the shared class-level table.
        path = list(self._DECISION_PATHS.get(method, ["Default method selection"]))
        if is_rct and len(path) > 1 and path[1] == "Data is observational":
            path[1] = "Data is from a randomized experiment"
        return path