# -*- coding: utf-8 -*-
"""
Render a JSON-aware visualization of CAIS's rule-based method selector.

- Parses a CAIS run payload (dict) and highlights ALL plausible candidates (green).
- The actually selected method receives a thicker border.
- The traversed decision path edges are colored.

Usage:
    render_from_json(payload_dict, out_stem="artifacts/decision_tree")

(Optional) CLI:
    python decision_tree.py [payload.json]

    When no path is given, ``sample_output.json`` in the current working
    directory is read.
"""
import json
import sys
from typing import Dict, Any, List, Set, Tuple, Optional

from graphviz import Digraph

from auto_causal.components.decision_tree import (
    DIFF_IN_MEANS,
    LINEAR_REGRESSION,
    DIFF_IN_DIFF,
    REGRESSION_DISCONTINUITY,
    INSTRUMENTAL_VARIABLE,
    PROPENSITY_SCORE_MATCHING,
    PROPENSITY_SCORE_WEIGHTING,
    GENERALIZED_PROPENSITY_SCORE,
    BACKDOOR_ADJUSTMENT,
    FRONTDOOR_ADJUSTMENT,
)

# Human-readable leaf labels. Insertion order is also the order in which the
# leaf nodes are emitted into the graph.
LABEL = {
    DIFF_IN_MEANS: "Diff-in-Means (RCT)",
    LINEAR_REGRESSION: "Linear Regression",
    DIFF_IN_DIFF: "Difference-in-Differences",
    REGRESSION_DISCONTINUITY: "Regression Discontinuity",
    INSTRUMENTAL_VARIABLE: "Instrumental Variables",
    PROPENSITY_SCORE_MATCHING: "PS Matching",
    PROPENSITY_SCORE_WEIGHTING: "PS Weighting",
    GENERALIZED_PROPENSITY_SCORE: "Generalized PS (continuous T)",
    BACKDOOR_ADJUSTMENT: "Backdoor Adjustment",
    FRONTDOOR_ADJUSTMENT: "Frontdoor Adjustment",
}

# Payload spellings of ``method_used`` mapped onto the selector constants.
_METHOD_ALIASES = {
    "linear_regression": LINEAR_REGRESSION,
    "diff_in_means": DIFF_IN_MEANS,
    "difference_in_differences": DIFF_IN_DIFF,
    "regression_discontinuity": REGRESSION_DISCONTINUITY,
    "instrumental_variable": INSTRUMENTAL_VARIABLE,
    "propensity_score_matching": PROPENSITY_SCORE_MATCHING,
    "propensity_score_weighting": PROPENSITY_SCORE_WEIGHTING,
    "generalized_propensity_score": GENERALIZED_PROPENSITY_SCORE,
    "backdoor_adjustment": BACKDOOR_ADJUSTMENT,
    "frontdoor_adjustment": FRONTDOOR_ADJUSTMENT,
}

# Static tree topology as (tail, head, edge label) in emission order.
_TOPOLOGY = [
    ("start", "is_rct", None),
    # RCT branch
    ("is_rct", "has_instr_rct", "Yes"),
    ("has_instr_rct", INSTRUMENTAL_VARIABLE, "Yes"),
    ("has_instr_rct", "has_cov_rct", "No"),
    ("has_cov_rct", LINEAR_REGRESSION, "Yes"),
    ("has_cov_rct", DIFF_IN_MEANS, "No"),
    # Observational branch
    ("is_rct", "has_temporal", "No"),
    ("has_temporal", DIFF_IN_DIFF, "Yes"),
    ("has_temporal", "has_rv", "No"),
    ("has_rv", REGRESSION_DISCONTINUITY, "Yes"),
    ("has_rv", "has_instr", "No"),
    ("has_instr", INSTRUMENTAL_VARIABLE, "Yes"),
    ("has_instr", "frontdoor", "No"),
    ("frontdoor", FRONTDOOR_ADJUSTMENT, "Yes"),
    ("frontdoor", "t_cont", "No"),
    ("t_cont", GENERALIZED_PROPENSITY_SCORE, "Yes"),
    ("t_cont", "has_cov", "No"),
    ("has_cov", "overlap", "Yes"),
    ("has_cov", BACKDOOR_ADJUSTMENT, "No"),
    ("overlap", PROPENSITY_SCORE_MATCHING, "Yes"),
    ("overlap", PROPENSITY_SCORE_WEIGHTING, "No"),
]


# -------- Heuristic extractors from payload -------- #

def _get(d: Dict, path: List[str], default=None):
    """Walk nested dicts along ``path``; return ``default`` on any miss.

    A miss is either a missing key or an intermediate value that is not a dict.
    """
    cur = d
    for k in path:
        if not isinstance(cur, dict) or k not in cur:
            return default
        cur = cur[k]
    return cur


def extract_signals(p: Dict[str, Any]) -> Dict[str, Any]:
    """Derive the boolean/categorical signals that drive the decision tree.

    Variables and dataset analysis are read from ``p["results"]`` first and
    from the top level of ``p`` as a fallback.

    Returns:
        A dict with keys ``treatment``, ``t_type``, ``is_rct``,
        ``has_temporal``, ``rdd_ready``, ``has_valid_instrument``,
        ``has_covariates``, ``frontdoor_ok`` and ``strong_overlap``
        (``None`` when the overlap assessment is absent or not a dict).
    """
    vars_ = _get(p, ["results", "variables"], {}) or _get(p, ["variables"], {}) or {}
    da = (
        _get(p, ["results", "dataset_analysis"], {})
        or _get(p, ["dataset_analysis"], {})
        or {}
    )

    treatment = vars_.get("treatment_variable")
    t_type = vars_.get("treatment_variable_type")  # "binary"/"continuous"
    is_rct = bool(vars_.get("is_rct", False))

    # Temporal / panel
    temporal_detected = bool(da.get("temporal_structure_detected", False))
    time_var = vars_.get("time_variable")
    group_var = vars_.get("group_variable")
    has_temporal = temporal_detected or bool(time_var) or bool(group_var)

    # RDD
    running_variable = vars_.get("running_variable")
    cutoff_value = vars_.get("cutoff_value")
    rdd_ready = running_variable is not None and cutoff_value is not None
    # (Some detectors raise 'discontinuities_detected', but we still require
    # running var + cutoff.) If you want permissive behavior, flip rdd_ready to
    # also consider da.get("discontinuities_detected").

    # Instruments
    instrument = vars_.get("instrument_variable")
    pot_instr = da.get("potential_instruments") or []
    # Consider an instrument valid only if it exists and is NOT the treatment itself
    has_valid_instrument = (
        instrument is not None and instrument != treatment
    ) or any(pi and pi != treatment for pi in pot_instr)

    covariates = vars_.get("covariates") or []
    has_covariates = len(covariates) > 0

    # Frontdoor: only mark if explicitly provided (else too speculative).
    # Read from ``da`` so the top-level ``dataset_analysis`` fallback applies
    # here exactly as it does for the other dataset-analysis signals.
    frontdoor_ok = bool(da.get("frontdoor_satisfied", False))

    # Overlap: if explicitly known, use it; else unknown -> both PS variants
    # remain plausible.
    overlap_assessment = da.get("overlap_assessment")
    strong_overlap = None
    if isinstance(overlap_assessment, dict):
        # accept typical keys like {"strong_overlap": true}
        strong_overlap = overlap_assessment.get("strong_overlap")

    return dict(
        treatment=treatment,
        t_type=t_type,
        is_rct=is_rct,
        has_temporal=has_temporal,
        rdd_ready=rdd_ready,
        has_valid_instrument=has_valid_instrument,
        has_covariates=has_covariates,
        frontdoor_ok=frontdoor_ok,
        strong_overlap=strong_overlap,
    )


# -------- Candidate inference (green leaves) -------- #

def infer_candidate_methods(signals: Dict[str, Any]) -> Set[str]:
    """Return every method leaf that is plausible given ``signals``.

    Unlike :func:`infer_decision_path`, this does not stop at the first
    matching rule on the observational side: each satisfied condition adds
    its leaf.
    """
    cands: Set[str] = set()

    # RCT branch: Diff-in-Means always; LR when covariates exist; IV only if a
    # valid instrument exists (e.g., randomized encouragement).
    if signals["is_rct"]:
        cands.add(DIFF_IN_MEANS)
        if signals["has_covariates"]:
            cands.add(LINEAR_REGRESSION)
        if signals["has_valid_instrument"]:
            cands.add(INSTRUMENTAL_VARIABLE)
        return cands  # stop here; the observational tree is not needed

    # Observational branch
    if signals["has_temporal"]:
        cands.add(DIFF_IN_DIFF)
    if signals["rdd_ready"]:
        cands.add(REGRESSION_DISCONTINUITY)
    if signals["has_valid_instrument"]:
        cands.add(INSTRUMENTAL_VARIABLE)
    if signals["frontdoor_ok"]:
        cands.add(FRONTDOOR_ADJUSTMENT)

    # Treatment type
    if str(signals["t_type"]).lower() == "continuous":
        cands.add(GENERALIZED_PROPENSITY_SCORE)

    # Backdoor / PS (need covariates)
    if signals["has_covariates"]:
        # If overlap is known, choose one; if unknown, mark both as plausible.
        if signals["strong_overlap"] is True:
            cands.add(PROPENSITY_SCORE_MATCHING)
        elif signals["strong_overlap"] is False:
            cands.add(PROPENSITY_SCORE_WEIGHTING)
        else:
            cands.add(PROPENSITY_SCORE_MATCHING)
            cands.add(PROPENSITY_SCORE_WEIGHTING)
        cands.add(BACKDOOR_ADJUSTMENT)

    return cands


# -------- Compute the single realized path to the chosen leaf (for edge coloring) -------- #

def infer_decision_path(
    signals: Dict[str, Any], selected_method: Optional[str]
) -> List[Tuple[str, str]]:
    """Return the (tail, head) edges walked from ``start`` to a single leaf.

    The walk follows the first satisfied rule at each decision node.
    ``selected_method`` is accepted for interface compatibility but is not
    consulted: the path is derived from ``signals`` alone.
    """
    path: List[Tuple[str, str]] = [("start", "is_rct")]

    if signals["is_rct"]:
        path.append(("is_rct", "has_instr_rct"))
        if signals["has_valid_instrument"]:
            path.append(("has_instr_rct", INSTRUMENTAL_VARIABLE))
            return path
        path.append(("has_instr_rct", "has_cov_rct"))
        rct_leaf = LINEAR_REGRESSION if signals["has_covariates"] else DIFF_IN_MEANS
        path.append(("has_cov_rct", rct_leaf))
        return path

    # Observational
    path.append(("is_rct", "has_temporal"))
    if signals["has_temporal"]:
        path.append(("has_temporal", DIFF_IN_DIFF))
        return path

    path.append(("has_temporal", "has_rv"))
    if signals["rdd_ready"]:
        path.append(("has_rv", REGRESSION_DISCONTINUITY))
        return path

    path.append(("has_rv", "has_instr"))
    if signals["has_valid_instrument"]:
        path.append(("has_instr", INSTRUMENTAL_VARIABLE))
        return path

    path.append(("has_instr", "frontdoor"))
    if signals["frontdoor_ok"]:
        path.append(("frontdoor", FRONTDOOR_ADJUSTMENT))
        return path

    path.append(("frontdoor", "t_cont"))
    if str(signals["t_type"]).lower() == "continuous":
        path.append(("t_cont", GENERALIZED_PROPENSITY_SCORE))
        return path

    path.append(("t_cont", "has_cov"))
    if not signals["has_covariates"]:
        # Mirrors the drawn topology: the "No covariates" edge ends at backdoor.
        path.append(("has_cov", BACKDOOR_ADJUSTMENT))
        return path

    path.append(("has_cov", "overlap"))
    # If overlap known to be strong, pick matching; else default to weighting.
    if signals["strong_overlap"] is True:
        path.append(("overlap", PROPENSITY_SCORE_MATCHING))
    else:
        path.append(("overlap", PROPENSITY_SCORE_WEIGHTING))
    return path


# -------- Graph building -------- #

def _resolve_selected_method(payload: Dict[str, Any]) -> Optional[str]:
    """Map the payload's reported method onto a selector constant, or ``None``.

    The method string is looked up at ``results.results.method_used``, then
    ``results.method_used``, then ``method``. Known aliases win; otherwise the
    string is compared case-insensitively with the constants' own values.
    """
    raw = (
        _get(payload, ["results", "results", "method_used"])
        or _get(payload, ["results", "method_used"])
        or _get(payload, ["method"])
    )
    key = str(raw or "").strip().lower()
    if key in _METHOD_ALIASES:
        return _METHOD_ALIASES[key]
    for const in LABEL:
        if str(const).lower() == key:
            return const
    return None


def build_graph(payload: Dict[str, Any]) -> Digraph:
    """Build the decision-tree ``Digraph`` annotated for ``payload``.

    Candidate leaves are filled ``palegreen``, the selected method's leaf gets
    a thicker border, and edges on the inferred decision path are drawn in
    ``forestgreen``.
    """
    g = Digraph("CAISDecisionTree", format="svg")
    g.attr(rankdir="LR", nodesep="0.4", ranksep="0.35", fontsize="11")

    # Decisions
    g.node("start", "Start", shape="circle")
    g.node("is_rct", "Is RCT?", shape="diamond")
    g.node("has_instr_rct", "Instrument available?", shape="diamond")
    g.node("has_cov_rct", "Covariates observed?", shape="diamond")
    g.node("has_temporal", "Temporal structure?", shape="diamond")
    g.node("has_rv", "Running var & cutoff?", shape="diamond")
    g.node("has_instr", "Instrument available?", shape="diamond")
    g.node("frontdoor", "Frontdoor criterion satisfied?", shape="diamond")
    g.node("has_cov", "Covariates observed?", shape="diamond")
    g.node("overlap", "Strong overlap?\n(overlap ≥ 0.1)", shape="diamond")
    g.node("t_cont", "Treatment continuous?", shape="diamond")

    # Compute signals, candidates, selected method
    signals = extract_signals(payload)
    candidates = infer_candidate_methods(signals)
    selected_method = _resolve_selected_method(payload)

    # Leaves, with coloring
    for method in LABEL:
        attrs = {"shape": "box", "style": "rounded"}
        if method in candidates:
            attrs.update(style="rounded,filled", fillcolor="palegreen")
        if method == selected_method:
            attrs.update(penwidth="2")
        g.node(method, LABEL[method], **attrs)

    # Edges with optional path highlighting
    path_edges = set(infer_decision_path(signals, selected_method))
    for tail, head, label in _TOPOLOGY:
        highlight = {}
        if (tail, head) in path_edges:
            highlight = {"color": "forestgreen", "penwidth": "2"}
        g.edge(tail, head, label=label, **highlight)

    # Optional legend
    g.node(
        "legend",
        "Legend:\nGreen = plausible candidate(s)\nBold border = method used",
        shape="note",
    )
    g.edge("legend", "start", style="dashed", arrowhead="none")

    return g


def render_from_json(payload: Dict[str, Any], out_stem: str = "artifacts/decision_tree"):
    """Write ``<out_stem>.dot``, ``<out_stem>.svg`` and ``<out_stem>.png``."""
    g = build_graph(payload)
    g.save(filename=f"{out_stem}.dot")
    g.render(filename=out_stem, cleanup=True)  # SVG
    g.format = "png"
    g.render(filename=out_stem, cleanup=True)  # PNG


def main():
    """CLI entry point: load a payload JSON file and render the tree.

    Uses the first command-line argument as the payload path; falls back to
    ``sample_output.json`` when none is supplied.
    """
    payload_path = sys.argv[1] if len(sys.argv) >= 2 else "sample_output.json"
    with open(payload_path, "r", encoding="utf-8") as f:
        payload = json.load(f)
    render_from_json(payload)


if __name__ == "__main__":
    main()