# NOTE: hosting-page header residue ("Spaces: Running") was here; it is not part of the module.
# -*- coding: utf-8 -*-
"""
Render a JSON-aware visualization of CAIS's rule-based method selector.

- Parses a CAIS run payload (dict) and highlights ALL plausible candidates (green).
- The actually selected method receives a thicker border.
- The traversed decision path edges are colored.

Usage:
    render_from_json(payload_dict, out_stem="artifacts/decision_tree")

(Optional) CLI:
    python decision_tree.py payload.json
"""
import json
import sys
from typing import Dict, Any, List, Set, Tuple, Optional

from graphviz import Digraph

from auto_causal.components.decision_tree import (
    DIFF_IN_MEANS, LINEAR_REGRESSION, DIFF_IN_DIFF, REGRESSION_DISCONTINUITY,
    INSTRUMENTAL_VARIABLE, PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING,
    GENERALIZED_PROPENSITY_SCORE, BACKDOOR_ADJUSTMENT, FRONTDOOR_ADJUSTMENT
)
# Display text for each leaf (method) node, keyed by the method constants
# imported from auto_causal.components.decision_tree.
LABEL = {
    DIFF_IN_MEANS: "Diff-in-Means (RCT)",
    LINEAR_REGRESSION: "Linear Regression",
    DIFF_IN_DIFF: "Difference-in-Differences",
    REGRESSION_DISCONTINUITY: "Regression Discontinuity",
    INSTRUMENTAL_VARIABLE: "Instrumental Variables",
    PROPENSITY_SCORE_MATCHING: "PS Matching",
    PROPENSITY_SCORE_WEIGHTING: "PS Weighting",
    GENERALIZED_PROPENSITY_SCORE: "Generalized PS (continuous T)",
    BACKDOOR_ADJUSTMENT: "Backdoor Adjustment",
    FRONTDOOR_ADJUSTMENT: "Frontdoor Adjustment",
}
# -------- Heuristic extractors from payload -------- #
def _get(d: Dict, path: List[str], default=None): | |
cur = d | |
for k in path: | |
if not isinstance(cur, dict) or k not in cur: | |
return default | |
cur = cur[k] | |
return cur | |
def extract_signals(p: Dict[str, Any]) -> Dict[str, Any]:
    """Derive the boolean/categorical signals that drive the decision tree.

    The ``variables`` and ``dataset_analysis`` sections are looked up under
    ``p["results"]`` first and at the top level of ``p`` second; a missing or
    empty section is treated as an empty dict.

    Args:
        p: Run payload (parsed JSON).

    Returns:
        A dict with the keys ``treatment``, ``t_type``, ``is_rct``,
        ``has_temporal``, ``rdd_ready``, ``has_valid_instrument``,
        ``has_covariates``, ``frontdoor_ok`` and ``strong_overlap``.
        ``strong_overlap`` is ``None`` when the payload carries no
        ``overlap_assessment`` dict (i.e. overlap is unknown).
    """
    vars_ = _get(p, ["results", "variables"], {}) or _get(p, ["variables"], {}) or {}
    da = _get(p, ["results", "dataset_analysis"], {}) or _get(p, ["dataset_analysis"], {}) or {}
    # A malformed payload (e.g. a list where a mapping is expected) should
    # degrade to "no information" instead of raising on ``.get`` below.
    if not isinstance(vars_, dict):
        vars_ = {}
    if not isinstance(da, dict):
        da = {}

    treatment = vars_.get("treatment_variable")
    t_type = vars_.get("treatment_variable_type")  # "binary"/"continuous"
    is_rct = bool(vars_.get("is_rct", False))

    # Temporal / panel: any one of the three hints is sufficient.
    temporal_detected = bool(da.get("temporal_structure_detected", False))
    time_var = vars_.get("time_variable")
    group_var = vars_.get("group_variable")
    has_temporal = temporal_detected or bool(time_var) or bool(group_var)

    # RDD: require BOTH a running variable and a cutoff. A detector flag such
    # as da["discontinuities_detected"] is deliberately not enough on its own;
    # OR it in here if permissive behavior is wanted.
    running_variable = vars_.get("running_variable")
    cutoff_value = vars_.get("cutoff_value")
    rdd_ready = running_variable is not None and cutoff_value is not None

    # Instruments: valid only if one exists and is NOT the treatment itself.
    instrument = vars_.get("instrument_variable")
    pot_instr = da.get("potential_instruments") or []
    has_valid_instrument = (
        instrument is not None and instrument != treatment
    ) or any(pi and pi != treatment for pi in pot_instr)

    covariates = vars_.get("covariates") or []
    has_covariates = len(covariates) > 0

    # Frontdoor: only mark if explicitly provided (else too speculative).
    # Read from ``da`` so the top-level ``dataset_analysis`` fallback applies
    # here exactly as it does for every other dataset-analysis signal.
    frontdoor_ok = bool(da.get("frontdoor_satisfied", False))

    # Overlap: if explicitly known, use it; else unknown -> both PS variants
    # remain plausible downstream.
    overlap_assessment = da.get("overlap_assessment")
    strong_overlap = None
    if isinstance(overlap_assessment, dict):
        # Accept typical keys like {"strong_overlap": true}.
        strong_overlap = overlap_assessment.get("strong_overlap")

    return dict(
        treatment=treatment,
        t_type=t_type,
        is_rct=is_rct,
        has_temporal=has_temporal,
        rdd_ready=rdd_ready,
        has_valid_instrument=has_valid_instrument,
        has_covariates=has_covariates,
        frontdoor_ok=frontdoor_ok,
        strong_overlap=strong_overlap,
    )
# -------- Candidate inference (green leaves) -------- #
def infer_candidate_methods(signals: Dict[str, Any]) -> Set[str]:
    """Return every method that is plausible given ``signals``.

    Unlike the single decision path, this collects ALL leaves whose
    preconditions hold, so several methods may be returned.

    Args:
        signals: Mapping with the keys produced by ``extract_signals``
            (``is_rct``, ``has_covariates``, ``has_valid_instrument``,
            ``has_temporal``, ``rdd_ready``, ``frontdoor_ok``, ``t_type``,
            ``strong_overlap``).

    Returns:
        Set of method constants to highlight as candidates.
    """
    cands: Set[str] = set()

    # RCT branch: Diff-in-Means is always valid; LR needs covariates; IV only
    # if a valid instrument exists (e.g., randomized encouragement).
    if signals["is_rct"]:
        cands.add(DIFF_IN_MEANS)
        if signals["has_covariates"]:
            cands.add(LINEAR_REGRESSION)
        if signals["has_valid_instrument"]:
            cands.add(INSTRUMENTAL_VARIABLE)
        return cands  # stop here; the observational tree is not needed

    # Observational branch: each design-based method is independent.
    if signals["has_temporal"]:
        cands.add(DIFF_IN_DIFF)
    if signals["rdd_ready"]:
        cands.add(REGRESSION_DISCONTINUITY)
    if signals["has_valid_instrument"]:
        cands.add(INSTRUMENTAL_VARIABLE)
    if signals["frontdoor_ok"]:
        cands.add(FRONTDOOR_ADJUSTMENT)

    # Treatment type (str() makes a missing/None type compare as "none").
    if str(signals["t_type"]).lower() == "continuous":
        cands.add(GENERALIZED_PROPENSITY_SCORE)

    # Backdoor / PS (need covariates).
    if signals["has_covariates"]:
        overlap = signals["strong_overlap"]
        # If overlap is known, choose one; if unknown, mark both as plausible.
        if overlap is True:
            cands.add(PROPENSITY_SCORE_MATCHING)
        elif overlap is False:
            cands.add(PROPENSITY_SCORE_WEIGHTING)
        else:
            cands.add(PROPENSITY_SCORE_MATCHING)
            cands.add(PROPENSITY_SCORE_WEIGHTING)
        # NOTE(review): the source this was recovered from had lost its
        # indentation; backdoor adjustment is placed under the covariates
        # check to match the "need covariates" section comment - confirm.
        cands.add(BACKDOOR_ADJUSTMENT)

    return cands
# -------- Compute the single realized path to the chosen leaf (for edge coloring) -------- #
def infer_decision_path(signals: Dict[str, Any], selected_method: Optional[str]) -> List[Tuple[str, str]]:
    """Walk the decision tree once and return the edges that were traversed.

    Each question is asked in tree order; the first one answered "yes" ends
    the walk at its leaf.

    Args:
        signals: Mapping with the keys produced by ``extract_signals``.
        selected_method: Currently unused; the path is derived from
            ``signals`` alone. Kept so existing callers keep working.

    Returns:
        Ordered list of ``(tail, head)`` node-name pairs, starting with
        ``("start", "is_rct")`` and ending at a leaf (method constant).
    """
    path: List[Tuple[str, str]] = [("start", "is_rct")]

    # ---- RCT branch ----
    if signals["is_rct"]:
        path.append(("is_rct", "has_instr_rct"))
        if signals["has_valid_instrument"]:
            path.append(("has_instr_rct", INSTRUMENTAL_VARIABLE))
            return path
        path.append(("has_instr_rct", "has_cov_rct"))
        if signals["has_covariates"]:
            path.append(("has_cov_rct", LINEAR_REGRESSION))
        else:
            path.append(("has_cov_rct", DIFF_IN_MEANS))
        return path

    # ---- Observational branch: a chain of yes/no gates ----
    path.append(("is_rct", "has_temporal"))
    if signals["has_temporal"]:
        path.append(("has_temporal", DIFF_IN_DIFF))
        return path

    path.append(("has_temporal", "has_rv"))
    if signals["rdd_ready"]:
        path.append(("has_rv", REGRESSION_DISCONTINUITY))
        return path

    path.append(("has_rv", "has_instr"))
    if signals["has_valid_instrument"]:
        path.append(("has_instr", INSTRUMENTAL_VARIABLE))
        return path

    path.append(("has_instr", "frontdoor"))
    if signals["frontdoor_ok"]:
        path.append(("frontdoor", FRONTDOOR_ADJUSTMENT))
        return path

    path.append(("frontdoor", "t_cont"))
    if str(signals["t_type"]).lower() == "continuous":
        path.append(("t_cont", GENERALIZED_PROPENSITY_SCORE))
        return path

    path.append(("t_cont", "has_cov"))
    if not signals["has_covariates"]:
        # Keeps the original topology: "no covariates" leads to the backdoor leaf.
        path.append(("has_cov", BACKDOOR_ADJUSTMENT))
        return path

    path.append(("has_cov", "overlap"))
    # If overlap is known to be strong, pick matching; when it is False or
    # unknown (None), default to weighting.
    if signals["strong_overlap"] is True:
        path.append(("overlap", PROPENSITY_SCORE_MATCHING))
    else:
        path.append(("overlap", PROPENSITY_SCORE_WEIGHTING))
    return path
# -------- Graph building -------- #
def build_graph(payload: Dict[str, Any]) -> Digraph:
    """Build the decision-tree graph for one run payload.

    The tree topology is fixed. The payload only controls styling:
    candidate leaves are filled green, the leaf for the method actually used
    gets a thicker border, and edges on the realized decision path are drawn
    in green.

    Args:
        payload: Run payload (parsed JSON). The method used is read from
            ``results.results.method_used``, then ``results.method_used``,
            then ``method``.

    Returns:
        A ``graphviz.Digraph`` with its format set to ``"svg"``.
    """
    g = Digraph("CAISDecisionTree", format="svg")
    g.attr(rankdir="LR", nodesep="0.4", ranksep="0.35", fontsize="11")

    # Decisions
    g.node("start", "Start", shape="circle")
    g.node("is_rct", "Is RCT?", shape="diamond")
    g.node("has_instr_rct", "Instrument available?", shape="diamond")
    g.node("has_cov_rct", "Covariates observed?", shape="diamond")
    g.node("has_temporal", "Temporal structure?", shape="diamond")
    g.node("has_rv", "Running var & cutoff?", shape="diamond")
    g.node("has_instr", "Instrument available?", shape="diamond")
    g.node("frontdoor", "Frontdoor criterion satisfied?", shape="diamond")
    g.node("has_cov", "Covariates observed?", shape="diamond")
    g.node("overlap", "Strong overlap?\n(overlap ≥ 0.1)", shape="diamond")
    g.node("t_cont", "Treatment continuous?", shape="diamond")

    # Leaves
    def leaf(name_const, fill=None, bold=False):
        """Add one method leaf, optionally filled and/or with a thick border."""
        attrs = {"shape": "box", "style": "rounded"}
        if fill:
            attrs.update(style="rounded,filled", fillcolor=fill)
        if bold:
            attrs.update(penwidth="2")
        g.node(name_const, LABEL[name_const], **attrs)

    # Compute signals, candidates, path
    signals = extract_signals(payload)
    candidates = infer_candidate_methods(signals)

    selected_method_str = (
        _get(payload, ["results", "results", "method_used"])
        or _get(payload, ["results", "method_used"])
        or _get(payload, ["method"])
    )
    method_key = str(selected_method_str or "").lower()
    selected_method = {
        "linear_regression": LINEAR_REGRESSION,
        "diff_in_means": DIFF_IN_MEANS,
        "difference_in_differences": DIFF_IN_DIFF,
        "regression_discontinuity": REGRESSION_DISCONTINUITY,
        "instrumental_variable": INSTRUMENTAL_VARIABLE,
        "propensity_score_matching": PROPENSITY_SCORE_MATCHING,
        "propensity_score_weighting": PROPENSITY_SCORE_WEIGHTING,
        "generalized_propensity_score": GENERALIZED_PROPENSITY_SCORE,
        "backdoor_adjustment": BACKDOOR_ADJUSTMENT,
        "frontdoor_adjustment": FRONTDOOR_ADJUSTMENT,
    }.get(method_key)
    if selected_method is None:
        # The payload may already carry the method constant itself rather
        # than one of the aliases above; accept it verbatim (or lower-cased).
        for candidate_name in (selected_method_str, method_key):
            if isinstance(candidate_name, str) and candidate_name in LABEL:
                selected_method = candidate_name
                break

    # Add leaves with coloring
    for m in [
        DIFF_IN_MEANS, LINEAR_REGRESSION, DIFF_IN_DIFF, REGRESSION_DISCONTINUITY,
        INSTRUMENTAL_VARIABLE, PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING,
        GENERALIZED_PROPENSITY_SCORE, BACKDOOR_ADJUSTMENT, FRONTDOOR_ADJUSTMENT,
    ]:
        leaf(
            m,
            fill=("palegreen" if m in candidates else None),
            bold=(m == selected_method),
        )

    # Edges with optional path highlighting
    path_edges = set(infer_decision_path(signals, selected_method))

    def e(u, v, label=None):
        """Add edge u -> v; highlight it when it lies on the realized path."""
        # Built as a plain dict (no ``dict | dict``) so this also runs on
        # Python versions older than 3.9.
        attrs = {}
        if label is not None:
            attrs["label"] = label
        if (u, v) in path_edges:
            attrs.update(color="forestgreen", penwidth="2")
        g.edge(u, v, **attrs)

    # Topology (unchanged)
    e("start", "is_rct")
    # RCT branch
    e("is_rct", "has_instr_rct", label="Yes")
    e("has_instr_rct", INSTRUMENTAL_VARIABLE, label="Yes")
    e("has_instr_rct", "has_cov_rct", label="No")
    e("has_cov_rct", LINEAR_REGRESSION, label="Yes")
    e("has_cov_rct", DIFF_IN_MEANS, label="No")
    # Observational branch
    e("is_rct", "has_temporal", label="No")
    e("has_temporal", DIFF_IN_DIFF, label="Yes")
    e("has_temporal", "has_rv", label="No")
    e("has_rv", REGRESSION_DISCONTINUITY, label="Yes")
    e("has_rv", "has_instr", label="No")
    e("has_instr", INSTRUMENTAL_VARIABLE, label="Yes")
    e("has_instr", "frontdoor", label="No")
    e("frontdoor", FRONTDOOR_ADJUSTMENT, label="Yes")
    e("frontdoor", "t_cont", label="No")
    e("t_cont", GENERALIZED_PROPENSITY_SCORE, label="Yes")
    e("t_cont", "has_cov", label="No")
    e("has_cov", "overlap", label="Yes")
    e("has_cov", BACKDOOR_ADJUSTMENT, label="No")
    e("overlap", PROPENSITY_SCORE_MATCHING, label="Yes")
    e("overlap", PROPENSITY_SCORE_WEIGHTING, label="No")

    # Optional legend, attached to the start node with an undirected dashed line
    g.node("legend", "Legend:\nGreen = plausible candidate(s)\nBold border = method used", shape="note")
    g.edge("legend", "start", style="dashed", arrowhead="none")
    return g
def render_from_json(payload: Dict[str, Any], out_stem: str = "artifacts/decision_tree") -> None:
    """Build the decision-tree graph for ``payload`` and write it to disk.

    Saves the DOT source as ``{out_stem}.dot`` and then renders the graph
    twice, first as SVG and then as PNG, using ``out_stem`` as the render
    filename.

    Args:
        payload: Run payload (parsed JSON) passed to ``build_graph``.
        out_stem: Output path without extension.
    """
    g = build_graph(payload)
    g.save(filename=f"{out_stem}.dot")
    # ``cleanup=True`` asks graphviz to delete the intermediate source file it
    # writes for each render; the explicit ``.dot`` saved above is separate.
    for fmt in ("svg", "png"):
        g.format = fmt
        g.render(filename=out_stem, cleanup=True)
def main():
    """CLI entry point: load a payload JSON file and render the decision tree.

    The payload path is taken from the first command-line argument. With no
    argument it falls back to ``sample_output.json`` in the current working
    directory. Output goes to ``render_from_json``'s default ``out_stem``.
    """
    payload_path = sys.argv[1] if len(sys.argv) >= 2 else "sample_output.json"
    with open(payload_path, "r", encoding="utf-8") as f:
        payload = json.load(f)
    render_from_json(payload)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()