"""
decision tree component for selecting causal inference methods
this module implements the decision tree logic to select the most appropriate
causal inference method based on dataset characteristics and available variables
"""
import logging
from typing import Dict, List, Any, Optional
import pandas as pd
# Define method names
BACKDOOR_ADJUSTMENT = "backdoor_adjustment"
LINEAR_REGRESSION = "linear_regression"
DIFF_IN_MEANS = "diff_in_means"
DIFF_IN_DIFF = "difference_in_differences"
REGRESSION_DISCONTINUITY = "regression_discontinuity_design"
PROPENSITY_SCORE_MATCHING = "propensity_score_matching"
INSTRUMENTAL_VARIABLE = "instrumental_variable"
CORRELATION_ANALYSIS = "correlation_analysis"
PROPENSITY_SCORE_WEIGHTING = "propensity_score_weighting"
GENERALIZED_PROPENSITY_SCORE = "generalized_propensity_score"
FRONTDOOR_ADJUSTMENT = "frontdoor_adjustment"
logger = logging.getLogger(__name__)
# Method assumptions mapping
METHOD_ASSUMPTIONS = {
BACKDOOR_ADJUSTMENT: [
"no unmeasured confounders (conditional ignorability given covariates)",
"correct model specification for outcome conditional on treatment and covariates",
"positivity/overlap (for all covariate values, units could potentially receive either treatment level)"
],
LINEAR_REGRESSION: [
"linear relationship between treatment, covariates, and outcome",
"no unmeasured confounders (if observational)",
"correct model specification",
"homoscedasticity of errors",
"normally distributed errors (for inference)"
],
DIFF_IN_MEANS: [
"treatment is randomly assigned (or as-if random)",
"no spillover effects",
"stable unit treatment value assumption (SUTVA)"
],
DIFF_IN_DIFF: [
"parallel trends between treatment and control groups before treatment",
"no spillover effects between groups",
"no anticipation effects before treatment",
"stable composition of treatment and control groups",
"treatment timing is exogenous"
],
REGRESSION_DISCONTINUITY: [
"units cannot precisely manipulate the running variable around the cutoff",
"continuity of conditional expectation functions of potential outcomes at the cutoff",
"no other changes occurring precisely at the cutoff"
],
PROPENSITY_SCORE_MATCHING: [
"no unmeasured confounders (conditional ignorability)",
"sufficient overlap (common support) between treatment and control groups",
"correct propensity score model specification"
],
INSTRUMENTAL_VARIABLE: [
"instrument is correlated with treatment (relevance)",
"instrument affects outcome only through treatment (exclusion restriction)",
"instrument is independent of unmeasured confounders (exogeneity/independence)"
],
CORRELATION_ANALYSIS: [
"data represents a sample from the population of interest",
"variables are measured appropriately"
],
PROPENSITY_SCORE_WEIGHTING: [
"no unmeasured confounders (conditional ignorability)",
"sufficient overlap (common support) between treatment and control groups",
"correct propensity score model specification",
"weights correctly specified (e.g., ATE, ATT)"
],
GENERALIZED_PROPENSITY_SCORE: [
"conditional mean independence",
"positivity/common support for GPS",
"correct specification of the GPS model",
"correct specification of the outcome model",
"no unmeasured confounders affecting both treatment and outcome, given X",
"treatment variable is continuous"
],
    FRONTDOOR_ADJUSTMENT: [
        "mediator fully transmits the treatment effect (no direct treatment-outcome path bypassing it)",
        "no unblocked backdoor path from treatment to mediator",
        "all backdoor paths from mediator to outcome are blocked by treatment"
    ]
}
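# Note: BACKDOOR_ADJUSTMENT and GENERALIZED_PROPENSITY_SCORE are defined in the constants
# and assumption mapping above but are never proposed as candidates by the selection logic
# below; backdoor_adjustment is only referenced again in DecisionTreeEngine._get_decision_path.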
def select_method(dataset_properties: Dict[str, Any], excluded_methods: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Select the most appropriate causal inference method for a dataset.

    Candidate methods are collected from structural signals in dataset_properties
    (RCT status, instrument, running variable and cutoff, temporal structure,
    front-door criterion, covariate overlap); excluded methods are then removed and
    the highest-priority remaining candidate is returned.

    Args:
        dataset_properties: dataset characteristics and variable roles
        excluded_methods: method names to exclude from selection

    Returns:
        Dict with keys: selected_method, method_justification, method_assumptions,
        alternatives, excluded_methods.
    """
excluded_methods = set(excluded_methods or [])
logger.info(f"Excluded methods: {sorted(excluded_methods)}")
treatment = dataset_properties.get("treatment_variable")
outcome = dataset_properties.get("outcome_variable")
if not treatment or not outcome:
raise ValueError("Both treatment and outcome variables must be specified")
instrument_var = dataset_properties.get("instrument_variable")
running_var = dataset_properties.get("running_variable")
cutoff_val = dataset_properties.get("cutoff_value")
time_var = dataset_properties.get("time_variable")
is_rct = dataset_properties.get("is_rct", False)
has_temporal = dataset_properties.get("has_temporal_structure", False)
frontdoor = dataset_properties.get("frontdoor_criterion", False)
covariate_overlap_result = dataset_properties.get("covariate_overlap_score")
covariates = dataset_properties.get("covariates", [])
treatment_variable_type = dataset_properties.get("treatment_variable_type", "binary")
# Helpers to collect candidates
candidates = [] # list of (method, priority_index)
justifications: Dict[str, str] = {}
assumptions: Dict[str, List[str]] = {}
def add(method: str, justification: str, prio_order: List[str]):
if method in justifications: # already added
return
justifications[method] = justification
assumptions[method] = METHOD_ASSUMPTIONS[method]
# priority index from provided order (fallback large if not present)
try:
idx = prio_order.index(method)
except ValueError:
idx = 10**6
candidates.append((method, idx))
# ----- Build candidate set (no returns here) -----
# RCT branch
if is_rct:
logger.info("Dataset is from a randomized controlled trial (RCT)")
rct_priority = [INSTRUMENTAL_VARIABLE, LINEAR_REGRESSION, DIFF_IN_MEANS]
if instrument_var and instrument_var != treatment:
add(INSTRUMENTAL_VARIABLE,
f"RCT encouragement: instrument '{instrument_var}' differs from treatment '{treatment}'.",
rct_priority)
if covariates:
add(LINEAR_REGRESSION,
"RCT with covariates—use OLS for precision.",
rct_priority)
else:
add(DIFF_IN_MEANS,
"Pure RCT without covariates—difference-in-means.",
rct_priority)
    # Observational-branch candidates (also collected for RCTs, where they can surface as alternatives)
obs_priority_binary = [
INSTRUMENTAL_VARIABLE,
PROPENSITY_SCORE_MATCHING,
PROPENSITY_SCORE_WEIGHTING,
FRONTDOOR_ADJUSTMENT,
LINEAR_REGRESSION,
]
obs_priority_nonbinary = [
INSTRUMENTAL_VARIABLE,
FRONTDOOR_ADJUSTMENT,
LINEAR_REGRESSION,
]
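    # Lower index in a priority list means higher preference; each candidate keeps the
    # index it receives from the ordering passed to add(), and ties are broken by
    # insertion order when the stable sort runs below.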
# Common early structural signals first (still only add as candidates)
if has_temporal and time_var:
add(DIFF_IN_DIFF,
f"Temporal structure via '{time_var}'—consider Difference-in-Differences (assumes parallel trends).",
            [DIFF_IN_DIFF])  # single-item priority list: top priority within its own ordering
if running_var and cutoff_val is not None:
add(REGRESSION_DISCONTINUITY,
f"Running variable '{running_var}' with cutoff {cutoff_val}—consider RDD.",
[REGRESSION_DISCONTINUITY])
# Binary vs non-binary pathways
if treatment_variable_type == "binary":
if instrument_var:
add(INSTRUMENTAL_VARIABLE,
f"Instrumental variable '{instrument_var}' available.",
obs_priority_binary)
# Propensity score methods only if covariates exist
if covariates:
if covariate_overlap_result is not None:
ps_method = (PROPENSITY_SCORE_WEIGHTING
if covariate_overlap_result < 0.1
else PROPENSITY_SCORE_MATCHING)
else:
ps_method = PROPENSITY_SCORE_MATCHING
add(ps_method,
"Covariates observed; PS method chosen based on overlap.",
obs_priority_binary)
if frontdoor:
add(FRONTDOOR_ADJUSTMENT,
"Front-door criterion satisfied.",
obs_priority_binary)
add(LINEAR_REGRESSION,
"OLS as a fallback specification.",
obs_priority_binary)
else:
logger.info(f"Non-binary treatment variable detected: {treatment_variable_type}")
if instrument_var:
add(INSTRUMENTAL_VARIABLE,
f"Instrument '{instrument_var}' candidate for non-binary treatment.",
obs_priority_nonbinary)
if frontdoor:
add(FRONTDOOR_ADJUSTMENT,
"Front-door criterion satisfied.",
obs_priority_nonbinary)
add(LINEAR_REGRESSION,
"Fallback for non-binary treatment without stronger identification.",
obs_priority_nonbinary)
# ----- Centralized exclusion handling -----
# Remove excluded
filtered = [(m, p) for (m, p) in candidates if m not in excluded_methods]
# If nothing survives, attempt a safe fallback not excluded
if not filtered:
logger.warning(f"All candidates excluded. Candidates were: {[m for m,_ in candidates]}. Excluded: {sorted(excluded_methods)}")
fallback_order = [
LINEAR_REGRESSION,
DIFF_IN_MEANS,
PROPENSITY_SCORE_MATCHING,
PROPENSITY_SCORE_WEIGHTING,
DIFF_IN_DIFF,
REGRESSION_DISCONTINUITY,
INSTRUMENTAL_VARIABLE,
FRONTDOOR_ADJUSTMENT,
]
        fallback = next((m for m in fallback_order if m not in excluded_methods), None)
        if not fallback:
            # truly nothing left; raise with context
            raise RuntimeError("No viable method remains after exclusions.")
        selected_method = fallback
        alternatives = []
        # The fallback may not have been a candidate, so populate its justification
        # and assumptions before building the return value.
        justifications.setdefault(selected_method, "Fallback after exclusions.")
        assumptions.setdefault(selected_method, METHOD_ASSUMPTIONS[selected_method])
else:
# Pick by smallest priority index, then stable by insertion
filtered.sort(key=lambda x: x[1])
selected_method = filtered[0][0]
alternatives = [m for (m, _) in filtered[1:] if m != selected_method]
logger.info(f"Selected method: {selected_method}; alternatives: {alternatives}")
return {
"selected_method": selected_method,
"method_justification": justifications[selected_method],
"method_assumptions": assumptions[selected_method],
"alternatives": alternatives,
"excluded_methods": sorted(excluded_methods),
}
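# Illustrative call (hypothetical variable names): for an observational dataset with a
# binary treatment, observed covariates, and an instrument, the instrument wins the
# priority ordering, with propensity score matching and OLS remaining as alternatives.
#
#     props = {
#         "treatment_variable": "treated",
#         "outcome_variable": "earnings",
#         "covariates": ["age", "education"],
#         "instrument_variable": "draft_lottery",
#         "treatment_variable_type": "binary",
#     }
#     select_method(props)["selected_method"]  # -> "instrumental_variable"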
def rule_based_select_method(dataset_analysis, variables, is_rct, llm, dataset_description, original_query, excluded_methods=None):
"""
Wrapped function to select causal method based on dataset properties and query
Args:
dataset_analysis (Dict): results of dataset analysis
variables (Dict): dictionary of variable names and types
is_rct (bool): whether the dataset is from a randomized controlled trial
llm (BaseChatModel): language model instance for generating prompts
dataset_description (str): description of the dataset
original_query (str): the original user query
excluded_methods (List[str], optional): list of methods to exclude from selection
"""
logger.info("Running rule-based method selection")
properties = {"treatment_variable": variables.get("treatment_variable"), "instrument_variable":variables.get("instrument_variable"),
"covariates": variables.get("covariates", []), "outcome_variable": variables.get("outcome_variable"),
"time_variable": variables.get("time_variable"), "running_variable": variables.get("running_variable"),
"treatment_variable_type": variables.get("treatment_variable_type", "binary"),
"has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False),
"frontdoor_criterion": variables.get("frontdoor_criterion", False),
"cutoff_value": variables.get("cutoff_value"),
"covariate_overlap_score": variables.get("covariate_overlap_result", 0)}
properties["is_rct"] = is_rct
logger.info(f"Dataset properties for method selection: {properties}")
return select_method(properties, excluded_methods)
class DecisionTreeEngine:
"""
Engine for applying decision trees to select appropriate causal methods.
This class wraps the functional decision tree implementation to provide
an object-oriented interface for method selection.
"""
def __init__(self, verbose=False):
self.verbose = verbose
def select_method(self, df: pd.DataFrame, treatment: str, outcome: str, covariates: List[str],
dataset_analysis: Dict[str, Any], query_details: Dict[str, Any]) -> Dict[str, Any]:
"""
Apply decision tree to select appropriate causal method.
"""
if self.verbose:
print(f"Applying decision tree for treatment: {treatment}, outcome: {outcome}")
print(f"Available covariates: {covariates}")
treatment_variable_type = query_details.get("treatment_variable_type")
covariate_overlap_result = query_details.get("covariate_overlap_result")
info = {"treatment_variable": treatment, "outcome_variable": outcome,
"covariates": covariates, "time_variable": query_details.get("time_variable"),
"group_variable": query_details.get("group_variable"),
"instrument_variable": query_details.get("instrument_variable"),
"running_variable": query_details.get("running_variable"),
"cutoff_value": query_details.get("cutoff_value"),
"is_rct": query_details.get("is_rct", False),
"has_temporal_structure": dataset_analysis.get("temporal_structure", False).get("has_temporal_structure", False),
"frontdoor_criterion": query_details.get("frontdoor_criterion", False),
"covariate_overlap_score": covariate_overlap_result,
"treatment_variable_type": treatment_variable_type}
result = select_method(info)
if self.verbose:
print(f"Selected method: {result['selected_method']}")
print(f"Justification: {result['method_justification']}")
result["decision_path"] = self._get_decision_path(result["selected_method"])
return result
    def _get_decision_path(self, method):
        """Return a human-readable decision path for the selected method."""
        decision_paths = {
            LINEAR_REGRESSION: ["Check if randomized experiment",
                                "Data appears to be from a randomized experiment with covariates"],
            PROPENSITY_SCORE_MATCHING: ["Check if randomized experiment", "Data is observational",
                                        "Check for sufficient covariate overlap", "Sufficient overlap exists"],
            PROPENSITY_SCORE_WEIGHTING: ["Check if randomized experiment", "Data is observational",
                                         "Check for sufficient covariate overlap", "Low overlap—weighting preferred"],
            BACKDOOR_ADJUSTMENT: ["Check if randomized experiment", "Data is observational",
                                  "Check for sufficient covariate overlap", "Adjusting for covariates"],
            INSTRUMENTAL_VARIABLE: ["Check if randomized experiment", "Data is observational",
                                    "Check for instrumental variables", "Instrument is available"],
            REGRESSION_DISCONTINUITY: ["Check if randomized experiment", "Data is observational",
                                       "Check for discontinuity", "Discontinuity exists"],
            DIFF_IN_DIFF: ["Check if randomized experiment", "Data is observational",
                           "Check for temporal structure", "Panel data structure exists"],
            FRONTDOOR_ADJUSTMENT: ["Check if randomized experiment", "Data is observational",
                                   "Check front-door criterion", "Front-door path identified"],
            DIFF_IN_MEANS: ["Check if randomized experiment", "Pure RCT without covariates"],
        }
        return decision_paths.get(method, ["Default method selection"])