# causal-agent/visualise.py: commit efbcc96 by FireShadow, "added decision tree visualization"
# -*- coding: utf-8 -*-
"""
Render a JSON-aware visualization of CAIS's rule-based method selector.
- Parses a CAIS run payload (dict) and highlights ALL plausible candidates (green).
- The actually selected method receives a thicker border.
- The traversed decision path edges are colored.
Usage:
render_from_json(payload_dict, out_stem="artifacts/decision_tree")
(Optional) CLI:
python visualise.py payload.json
"""
from graphviz import Digraph
import json, sys
from typing import Dict, Any, List, Set, Tuple, Optional
from auto_causal.components.decision_tree import (
DIFF_IN_MEANS, LINEAR_REGRESSION, DIFF_IN_DIFF, REGRESSION_DISCONTINUITY,
INSTRUMENTAL_VARIABLE, PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING,
GENERALIZED_PROPENSITY_SCORE, BACKDOOR_ADJUSTMENT, FRONTDOOR_ADJUSTMENT
)
# Display text for every method leaf of the rendered tree, keyed by the method
# constants imported from auto_causal.components.decision_tree. Insertion order
# is also the order in which leaves are added to the graph.
LABEL = {
    # Leaves of the RCT branch
    DIFF_IN_MEANS: "Diff-in-Means (RCT)",
    LINEAR_REGRESSION: "Linear Regression",
    # Leaves of the observational branch (IV is also reachable from the RCT branch)
    DIFF_IN_DIFF: "Difference-in-Differences",
    REGRESSION_DISCONTINUITY: "Regression Discontinuity",
    INSTRUMENTAL_VARIABLE: "Instrumental Variables",
    PROPENSITY_SCORE_MATCHING: "PS Matching",
    PROPENSITY_SCORE_WEIGHTING: "PS Weighting",
    GENERALIZED_PROPENSITY_SCORE: "Generalized PS (continuous T)",
    BACKDOOR_ADJUSTMENT: "Backdoor Adjustment",
    FRONTDOOR_ADJUSTMENT: "Frontdoor Adjustment",
}
# -------- Heuristic extractors from payload -------- #
def _get(d: Dict, path: List[str], default=None):
cur = d
for k in path:
if not isinstance(cur, dict) or k not in cur:
return default
cur = cur[k]
return cur
def extract_signals(p: Dict[str, Any]) -> Dict[str, Any]:
vars_ = _get(p, ["results", "variables"], {}) or _get(p, ["variables"], {}) or {}
da = _get(p, ["results", "dataset_analysis"], {}) or _get(p, ["dataset_analysis"], {}) or {}
treatment = vars_.get("treatment_variable")
t_type = vars_.get("treatment_variable_type") # "binary"/"continuous"
is_rct = bool(vars_.get("is_rct", False))
# Temporal / panel
temporal_detected = bool(da.get("temporal_structure_detected", False))
time_var = vars_.get("time_variable")
group_var = vars_.get("group_variable")
has_temporal = temporal_detected or bool(time_var) or bool(group_var)
# RDD
running_variable = vars_.get("running_variable")
cutoff_value = vars_.get("cutoff_value")
rdd_ready = running_variable is not None and cutoff_value is not None
# (Some detectors raise 'discontinuities_detected', but we still require running var + cutoff.)
# If you want permissive behavior, flip rdd_ready to also consider da.get("discontinuities_detected").
# Instruments
instrument = vars_.get("instrument_variable")
pot_instr = da.get("potential_instruments") or []
# Consider an instrument valid only if it exists and is NOT the treatment itself
has_valid_instrument = (
instrument is not None and instrument != treatment
) or any(pi and pi != treatment for pi in pot_instr)
covariates = vars_.get("covariates") or []
has_covariates = len(covariates) > 0
# Frontdoor: only mark if explicitly provided (else too speculative)
frontdoor_ok = bool(_get(p, ["results", "dataset_analysis", "frontdoor_satisfied"], False))
# Overlap: if explicitly known, use it; else unknown → both PS variants remain plausible.
overlap_assessment = da.get("overlap_assessment")
strong_overlap = None
if isinstance(overlap_assessment, dict):
# accept typical keys like {"strong_overlap": true}
strong_overlap = overlap_assessment.get("strong_overlap")
return dict(
treatment=treatment,
t_type=t_type,
is_rct=is_rct,
has_temporal=has_temporal,
rdd_ready=rdd_ready,
has_valid_instrument=has_valid_instrument,
has_covariates=has_covariates,
frontdoor_ok=frontdoor_ok,
strong_overlap=strong_overlap,
)
# -------- Candidate inference (green leaves) -------- #
def infer_candidate_methods(signals: Dict[str, Any]) -> Set[str]:
    """Return every method that is plausible for ``signals`` (drawn as green leaves).

    ``signals`` is the dict produced by ``extract_signals``.

    - RCT: Diff-in-Means always; Linear Regression when covariates exist; IV
      when a valid instrument exists (e.g. randomized encouragement). The
      observational rules are skipped.
    - Observational: each design whose precondition holds is added
      independently. The binary propensity-score methods (matching/weighting)
      are only offered for non-continuous treatments; a continuous treatment
      gets Generalized PS instead, matching the tree's "Treatment continuous?"
      split.
    """
    cands: Set[str] = set()

    if signals["is_rct"]:
        cands.add(DIFF_IN_MEANS)
        if signals["has_covariates"]:
            cands.add(LINEAR_REGRESSION)
        if signals["has_valid_instrument"]:
            cands.add(INSTRUMENTAL_VARIABLE)
        return cands  # the observational rules do not apply

    # Observational branch
    if signals["has_temporal"]:
        cands.add(DIFF_IN_DIFF)
    if signals["rdd_ready"]:
        cands.add(REGRESSION_DISCONTINUITY)
    if signals["has_valid_instrument"]:
        cands.add(INSTRUMENTAL_VARIABLE)
    if signals["frontdoor_ok"]:
        cands.add(FRONTDOOR_ADJUSTMENT)

    is_continuous = str(signals["t_type"]).lower() == "continuous"
    if is_continuous:
        cands.add(GENERALIZED_PROPENSITY_SCORE)

    # Backdoor / propensity-score methods all need observed covariates.
    if signals["has_covariates"]:
        if not is_continuous:
            # Known overlap selects one PS variant; unknown overlap keeps both.
            strong_overlap = signals["strong_overlap"]
            if strong_overlap is not False:
                cands.add(PROPENSITY_SCORE_MATCHING)
            if strong_overlap is not True:
                cands.add(PROPENSITY_SCORE_WEIGHTING)
        cands.add(BACKDOOR_ADJUSTMENT)

    return cands
# -------- Compute the single realized path to the chosen leaf (for edge coloring) -------- #
def infer_decision_path(signals: Dict[str, Any], selected_method: Optional[str]) -> List[Tuple[str, str]]:
    """Return the ordered (parent, child) edges the signals walk through the tree.

    The walk starts at "start" and follows the same node ids and Yes/No
    topology that ``build_graph`` draws, stopping at the first leaf reached.

    ``selected_method`` is accepted but never consulted. The path is derived
    from ``signals`` alone, so it may end at a different leaf than the one
    ``build_graph`` draws bold.
    """
    edges: List[Tuple[str, str]] = [("start", "is_rct")]

    if signals["is_rct"]:
        edges.append(("is_rct", "has_instr_rct"))
        if signals["has_valid_instrument"]:
            edges.append(("has_instr_rct", INSTRUMENTAL_VARIABLE))
            return edges
        edges.append(("has_instr_rct", "has_cov_rct"))
        rct_leaf = LINEAR_REGRESSION if signals["has_covariates"] else DIFF_IN_MEANS
        edges.append(("has_cov_rct", rct_leaf))
        return edges

    # Observational chain: (decision node, test, leaf on Yes, next node on No).
    # Tests are callables so each signal is only read once the walk reaches it.
    chain = (
        ("has_temporal", lambda s: s["has_temporal"], DIFF_IN_DIFF, "has_rv"),
        ("has_rv", lambda s: s["rdd_ready"], REGRESSION_DISCONTINUITY, "has_instr"),
        ("has_instr", lambda s: s["has_valid_instrument"], INSTRUMENTAL_VARIABLE, "frontdoor"),
        ("frontdoor", lambda s: s["frontdoor_ok"], FRONTDOOR_ADJUSTMENT, "t_cont"),
        ("t_cont", lambda s: str(s["t_type"]).lower() == "continuous", GENERALIZED_PROPENSITY_SCORE, "has_cov"),
    )
    edges.append(("is_rct", "has_temporal"))
    for node, test, yes_leaf, no_next in chain:
        if test(signals):
            edges.append((node, yes_leaf))
            return edges
        edges.append((node, no_next))

    if not signals["has_covariates"]:
        # Mirrors the drawn topology, where "Covariates observed? -> No" leads
        # to the Backdoor Adjustment leaf.
        edges.append(("has_cov", BACKDOOR_ADJUSTMENT))
        return edges

    edges.append(("has_cov", "overlap"))
    # Only an explicit True takes the matching branch; unknown overlap defaults to weighting.
    overlap_leaf = (
        PROPENSITY_SCORE_MATCHING if signals["strong_overlap"] is True else PROPENSITY_SCORE_WEIGHTING
    )
    edges.append(("overlap", overlap_leaf))
    return edges
# -------- Graph building -------- #
# Method names as they may appear in a CAIS payload (compared after
# normalisation to lower-case snake_case), mapped to the method constants.
_METHOD_BY_NAME = {
    "linear_regression": LINEAR_REGRESSION,
    "diff_in_means": DIFF_IN_MEANS,
    "difference_in_differences": DIFF_IN_DIFF,
    "regression_discontinuity": REGRESSION_DISCONTINUITY,
    "instrumental_variable": INSTRUMENTAL_VARIABLE,
    "propensity_score_matching": PROPENSITY_SCORE_MATCHING,
    "propensity_score_weighting": PROPENSITY_SCORE_WEIGHTING,
    "generalized_propensity_score": GENERALIZED_PROPENSITY_SCORE,
    "backdoor_adjustment": BACKDOOR_ADJUSTMENT,
    "frontdoor_adjustment": FRONTDOOR_ADJUSTMENT,
}


def _resolve_selected_method(payload: Dict[str, Any]) -> Optional[str]:
    """Return the method constant the payload reports as used, or None.

    Looks at ``results.results.method_used``, then ``results.method_used``,
    then ``method``. Matching ignores case and surrounding whitespace and
    treats spaces and hyphens as underscores.
    """
    raw = (
        _get(payload, ["results", "results", "method_used"])
        or _get(payload, ["results", "method_used"])
        or _get(payload, ["method"])
    )
    key = str(raw or "").strip().lower().replace("-", "_").replace(" ", "_")
    method = _METHOD_BY_NAME.get(key)
    if method is None and isinstance(raw, str) and raw in LABEL:
        # The payload may already carry one of the method constants verbatim.
        method = raw
    return method


def build_graph(payload: Dict[str, Any]) -> Digraph:
    """Build the annotated decision-tree graph for a CAIS run payload.

    The graph is rendered as follows:
    - Leaves that ``infer_candidate_methods`` marks as plausible are filled
      pale green.
    - The leaf matching the payload's reported method gets a thicker border.
    - Edges returned by ``infer_decision_path`` are drawn forest-green and
      thicker.

    The returned ``Digraph`` has SVG as its output format.
    """
    g = Digraph("CAISDecisionTree", format="svg")
    g.attr(rankdir="LR", nodesep="0.4", ranksep="0.35", fontsize="11")

    # Decision nodes (ids are referenced by infer_decision_path)
    g.node("start", "Start", shape="circle")
    g.node("is_rct", "Is RCT?", shape="diamond")
    g.node("has_instr_rct", "Instrument available?", shape="diamond")
    g.node("has_cov_rct", "Covariates observed?", shape="diamond")
    g.node("has_temporal", "Temporal structure?", shape="diamond")
    g.node("has_rv", "Running var & cutoff?", shape="diamond")
    g.node("has_instr", "Instrument available?", shape="diamond")
    g.node("frontdoor", "Frontdoor criterion satisfied?", shape="diamond")
    g.node("has_cov", "Covariates observed?", shape="diamond")
    g.node("overlap", "Strong overlap?\n(overlap ≥ 0.1)", shape="diamond")
    g.node("t_cont", "Treatment continuous?", shape="diamond")

    def leaf(name_const, fill=None, bold=False):
        """Add a method leaf, optionally filled with ``fill`` and/or bold-bordered."""
        attrs = {"shape": "box", "style": "rounded"}
        if fill:
            attrs.update(style="rounded,filled", fillcolor=fill)
        if bold:
            attrs.update(penwidth="2")
        g.node(name_const, LABEL[name_const], **attrs)

    signals = extract_signals(payload)
    candidates = infer_candidate_methods(signals)
    selected_method = _resolve_selected_method(payload)

    # One leaf per method; LABEL holds every method in display order.
    for m in LABEL:
        leaf(m,
             fill=("palegreen" if m in candidates else None),
             bold=(m == selected_method))

    path_edges = set(infer_decision_path(signals, selected_method))

    def e(u, v, label=None):
        """Add edge u -> v, highlighted when it lies on the inferred path."""
        attrs = {} if label is None else {"label": label}
        if (u, v) in path_edges:
            attrs.update(color="forestgreen", penwidth="2")
        g.edge(u, v, **attrs)

    e("start", "is_rct")
    # RCT branch
    e("is_rct", "has_instr_rct", label="Yes")
    e("has_instr_rct", INSTRUMENTAL_VARIABLE, label="Yes")
    e("has_instr_rct", "has_cov_rct", label="No")
    e("has_cov_rct", LINEAR_REGRESSION, label="Yes")
    e("has_cov_rct", DIFF_IN_MEANS, label="No")
    # Observational branch
    e("is_rct", "has_temporal", label="No")
    e("has_temporal", DIFF_IN_DIFF, label="Yes")
    e("has_temporal", "has_rv", label="No")
    e("has_rv", REGRESSION_DISCONTINUITY, label="Yes")
    e("has_rv", "has_instr", label="No")
    e("has_instr", INSTRUMENTAL_VARIABLE, label="Yes")
    e("has_instr", "frontdoor", label="No")
    e("frontdoor", FRONTDOOR_ADJUSTMENT, label="Yes")
    e("frontdoor", "t_cont", label="No")
    e("t_cont", GENERALIZED_PROPENSITY_SCORE, label="Yes")
    e("t_cont", "has_cov", label="No")
    e("has_cov", "overlap", label="Yes")
    e("has_cov", BACKDOOR_ADJUSTMENT, label="No")
    e("overlap", PROPENSITY_SCORE_MATCHING, label="Yes")
    e("overlap", PROPENSITY_SCORE_WEIGHTING, label="No")

    # Legend, attached to the start node with an undirected dashed edge
    g.node("legend", "Legend:\nGreen = plausible candidate(s)\nBold border = method used", shape="note")
    g.edge("legend", "start", style="dashed", arrowhead="none")
    return g
def render_from_json(
    payload: Dict[str, Any],
    out_stem: str = "artifacts/decision_tree",
    formats: Tuple[str, ...] = ("svg", "png"),
) -> List[str]:
    """Build the decision-tree graph for ``payload`` and write it to disk.

    Writes the DOT source to ``{out_stem}.dot``, then renders one file per
    entry in ``formats`` (by default ``{out_stem}.svg`` and ``{out_stem}.png``).

    Returns the paths of the rendered files, as returned by graphviz's
    ``render``.
    """
    g = build_graph(payload)
    g.save(filename=f"{out_stem}.dot")
    outputs: List[str] = []
    for fmt in formats:
        g.format = fmt
        # cleanup=True removes the intermediate source file render() writes at out_stem.
        outputs.append(g.render(filename=out_stem, cleanup=True))
    return outputs
def main(argv: Optional[List[str]] = None) -> None:
    """CLI entry point: render the decision tree for a CAIS payload JSON file.

    Usage: ``python visualise.py [payload.json]``. The payload path is the
    first argument in ``argv``, which defaults to ``sys.argv[1:]``. Without
    one, ``sample_output.json`` in the working directory is used. The output
    goes to ``render_from_json``'s default location.
    """
    args = sys.argv[1:] if argv is None else argv
    payload_path = args[0] if args else "sample_output.json"
    with open(payload_path, "r", encoding="utf-8") as f:
        payload = json.load(f)
    render_from_json(payload)


if __name__ == "__main__":
    main()