|
import torch |
|
import numpy as np |
|
import gradio as gr |
|
import matplotlib.pylab as plt |
|
import torch.nn.functional as F |
|
|
|
from vae import HVAE |
|
from datasets import morphomnist, ukbb, mimic, get_attr_max_min |
|
from pgm.flow_pgm import MorphoMNISTPGM, FlowPGM, ChestPGM |
|
from app_utils import ( |
|
mnist_graph, |
|
brain_graph, |
|
chest_graph, |
|
vae_preprocess, |
|
normalize, |
|
preprocess_brain, |
|
get_fig_arr, |
|
postprocess, |
|
MidpointNormalize, |
|
) |
|
|
|
# Per-dataset registries keyed by the dataset display name.
# DATA[name] receives the data loaders and MODELS[name] the loaded models
# together with their hyperparameters; both start out empty.
DATA = {name: {} for name in ("Morpho-MNIST", "Brain MRI", "Chest X-ray")}
MODELS = {name: {} for name in DATA}
|
|
|
|
|
# Category labels shown in the UI. The position of a label in its list is the
# class index used when encoding/decoding model inputs and outputs.
DIGITS = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]

MRISEQ_CAT = ["T1", "T2-FLAIR"]
SEX_CAT = ["female", "male"]

# Display size used for the image widgets.
HEIGHT, WIDTH = 500, 500

SEX_CAT_CHEST = ["female", "male"]
RACE_CAT = ["white", "black", "asian"]

FIND_CAT = ["no disease", "effusion", "pneumonia"]

if torch.cuda.is_available():
    # Keep preferring the second GPU, but fall back to the first one on
    # single-GPU hosts where "cuda:1" does not exist.
    DEVICE = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    DEVICE = torch.device("cpu")
|
|
|
|
|
class Hparams:
    """Plain attribute container for hyperparameters."""

    def update(self, dict):
        """Set every key/value pair of the mapping ``dict`` as an attribute.

        The parameter name shadows the builtin; it is kept unchanged so that
        existing callers are unaffected.
        """
        for name in dict:
            setattr(self, name, dict[name])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_paths(dataset_id):
    """Return the data directory and checkpoint locations for a dataset.

    Args:
        dataset_id: Dataset display name. It is matched by substring
            ("MNIST", "Brain" or "Chest"), in that order.

    Returns:
        A tuple ``(data_path, vae_path, pgm_path)``. ``vae_path`` is a single
        path string, except for the chest dataset where it is a two-element
        list of checkpoint paths.

    Raises:
        ValueError: If ``dataset_id`` matches none of the known datasets
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if "MNIST" in dataset_id:
        data_path = "./data/morphomnist"
        pgm_path = "./checkpoints/t_i_d/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/t_i_d/dgauss_cond_big_beta1_dropexo/checkpoint.pt"
    elif "Brain" in dataset_id:
        data_path = "./data/ukbb_subset"
        pgm_path = "./checkpoints/m_b_v_s/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
    elif "Chest" in dataset_id:
        data_path = "./data/mimic_subset"
        pgm_path = "./checkpoints/a_r_s_f/sup_pgm_mimic/60k_checkpoint.pt"
        # Two checkpoints: the first is read for hyperparameters, the second
        # supplies the weights (see how load_vae unpacks this list).
        vae_path = [
            "./checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/60k_checkpoint.pt",
            "./checkpoints/a_r_s_f/mimic_dscm_lr_1e5_lagrange_lr_1_damping_10/60k_checkpoint.pt",
        ]
    else:
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")
    return data_path, vae_path, pgm_path
|
|
|
|
|
|
|
def load_pgm(dataset_id, pgm_path):
    """Load the PGM for ``dataset_id`` and register it in ``MODELS``.

    The model class is chosen by substring match on ``dataset_id``; its
    hyperparameters come from ``checkpoint["hparams"]`` and its weights from
    ``checkpoint["ema_model_state_dict"]``. On success
    ``MODELS[dataset_id]["pgm"]`` and ``MODELS[dataset_id]["pgm_args"]`` are set.

    Args:
        dataset_id: Dataset display name containing "MNIST", "Brain" or "Chest".
        pgm_path: Path of the PGM checkpoint file.

    Raises:
        ValueError: If ``dataset_id`` matches no known dataset (previously an
            ``UnboundLocalError`` after the checkpoint had already been read).
    """
    # Resolve the model class first so that a bad id fails before any I/O.
    if "MNIST" in dataset_id:
        pgm_cls = MorphoMNISTPGM
    elif "Brain" in dataset_id:
        pgm_cls = FlowPGM
    elif "Chest" in dataset_id:
        pgm_cls = ChestPGM
    else:
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")

    checkpoint = torch.load(pgm_path, map_location=DEVICE)
    args = Hparams()
    args.update(checkpoint["hparams"])
    args.device = DEVICE

    pgm = pgm_cls(args).to(args.device)
    pgm.load_state_dict(checkpoint["ema_model_state_dict"])
    MODELS[dataset_id]["pgm"] = pgm
    MODELS[dataset_id]["pgm_args"] = args
|
|
|
|
|
def load_vae(dataset_id, vae_path):
    """Build the HVAE for ``dataset_id``, load its weights and register it.

    For the chest dataset ``vae_path`` is a two-element sequence: the first
    checkpoint supplies the hyperparameters, and the weights are taken from the
    ``"vae."``-prefixed entries of the second checkpoint's
    ``"ema_model_state_dict"``. For every other dataset ``vae_path`` is a single
    checkpoint providing both. On success ``MODELS[dataset_id]["vae"]`` and
    ``MODELS[dataset_id]["vae_args"]`` are set.

    Args:
        dataset_id: Dataset display name.
        vae_path: Checkpoint path, or ``[vae_path, dscm_path]`` for chest.
    """
    is_chest = "Chest" in dataset_id
    if is_chest:
        vae_path, dscm_path = vae_path[0], vae_path[1]

    checkpoint = torch.load(vae_path, map_location=DEVICE)
    args = Hparams()
    args.update(checkpoint["hparams"])

    # Back-fill attributes that may be absent from the stored hparams.
    if not hasattr(args, "vae"):
        args.vae = "hierarchical"
    if not hasattr(args, "cond_prior"):
        args.cond_prior = False
    if hasattr(args, "free_bits"):
        args.kl_free_bits = args.free_bits
    args.device = DEVICE

    vae = HVAE(args).to(args.device)

    if is_chest:
        dscm_ckpt = torch.load(dscm_path, map_location=DEVICE)
        # Keep only keys that START with the prefix and strip exactly that
        # prefix. The earlier `"vae." in k` test combined with `k[4:]` would
        # have mangled any key containing "vae." elsewhere in its name.
        prefix = "vae."
        state_dict = {
            key[len(prefix):]: value
            for key, value in dscm_ckpt["ema_model_state_dict"].items()
            if key.startswith(prefix)
        }
    else:
        state_dict = checkpoint["ema_model_state_dict"]
    vae.load_state_dict(state_dict)

    MODELS[dataset_id]["vae"] = vae
    MODELS[dataset_id]["vae_args"] = args
    print(MODELS[dataset_id]["vae_args"])
|
|
|
|
|
def get_dataloader(dataset_id, data_path):
    """Create the test-split ``DataLoader`` and store it in ``DATA``.

    ``MODELS[dataset_id]["pgm_args"]`` must already exist; its ``data_dir`` is
    overwritten with ``data_path`` and it is then passed to the dataset factory
    chosen by substring match on ``dataset_id``. The loader is stored as
    ``DATA[dataset_id]["test"]``.

    Args:
        dataset_id: Dataset display name containing "MNIST", "Brain" or "Chest".
        data_path: Directory the dataset is read from.

    Raises:
        ValueError: If ``dataset_id`` matches no known dataset (previously an
            ``UnboundLocalError``).
    """
    args = MODELS[dataset_id]["pgm_args"]
    args.data_dir = data_path

    if "MNIST" in dataset_id:
        datasets = morphomnist(args)
    elif "Brain" in dataset_id:
        datasets = ukbb(args)
    elif "Chest" in dataset_id:
        datasets = mimic(args)
    else:
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")

    DATA[dataset_id]["test"] = torch.utils.data.DataLoader(
        datasets["test"], shuffle=False, batch_size=args.bs, num_workers=4
    )
|
|
|
|
|
def load_dataset(dataset_id):
    """Prepare the test data loader for ``dataset_id``.

    Reads only the hyperparameters from the PGM checkpoint, stores them as
    ``MODELS[dataset_id]["pgm_args"]`` and then builds the loader through
    ``get_dataloader``.
    """
    data_path, _, pgm_path = get_paths(dataset_id)
    stored_hparams = torch.load(pgm_path, map_location=DEVICE)["hparams"]

    args = Hparams()
    args.update(stored_hparams)
    args.device = DEVICE
    MODELS[dataset_id]["pgm_args"] = args

    get_dataloader(dataset_id, data_path)
|
|
|
|
|
def load_model(dataset_id):
    """Load the PGM first and then the VAE for ``dataset_id``."""
    paths = get_paths(dataset_id)
    vae_ckpt, pgm_ckpt = paths[1], paths[2]
    load_pgm(dataset_id, pgm_ckpt)
    load_vae(dataset_id, vae_ckpt)
|
|
|
|
|
@torch.no_grad()
def counterfactual_inference(dataset_id, obs, do_pa):
    """Compute a counterfactual image for ``obs`` under the intervention ``do_pa``.

    Args:
        dataset_id: Key into ``MODELS`` selecting the PGM and VAE to use.
        obs: Mapping holding the image under ``"x"`` plus the parent variables.
        do_pa: Mapping of intervened parent variables.

    Returns:
        Dict with ``"cf_x"`` (counterfactual image, clamped to [-1, 1]),
        ``"rec_x"`` (the VAE output location for the observed parents) and
        ``"cf_pa"`` (counterfactual parents returned by the PGM).
    """
    models = MODELS[dataset_id]

    # Counterfactual parents from the PGM; the image itself is not a parent.
    parents = {name: val.clone() for name, val in obs.items() if name != "x"}
    cf_parents = models["pgm"].counterfactual(
        obs=parents, intervention=do_pa, num_particles=1
    )

    args, vae = models["vae_args"], models["vae"]
    vae_pa = vae_preprocess(args, {name: val.clone() for name, val in parents.items()})
    vae_cf_pa = vae_preprocess(
        args, {name: val.clone() for name, val in cf_parents.items()}
    )

    # The same temperature rule is used for abduction and for the noise scale.
    temperature = 0.1 if "mnist" in args.hps else 1.0

    latents = vae.abduct(x=obs["x"], parents=vae_pa, t=temperature)
    if vae.cond_prior:
        # With a conditional prior each entry is a mapping; keep only its "z".
        latents = [entry["z"] for entry in latents]

    px_loc, px_scale = vae.forward_latents(latents=latents, parents=vae_pa)
    cf_loc, cf_scale = vae.forward_latents(latents=latents, parents=vae_cf_pa)

    # Standardised residual of the observation; the clamp avoids division by zero.
    noise = (obs["x"] - px_loc) / px_scale.clamp(min=1e-12)
    cf_scale = cf_scale * temperature
    cf_x = torch.clamp(cf_loc + cf_scale * noise, min=-1, max=1)

    return {"cf_x": cf_x, "rec_x": px_loc, "cf_pa": cf_parents}
|
|
|
|
|
def get_obs_item(dataset_id, idx=None):
    """Fetch one test-set item, choosing a random index when ``idx`` is None.

    Returns:
        ``(idx, item)`` where ``idx`` is the integer index that was used.
    """
    test_set = DATA[dataset_id]["test"].dataset
    if idx is None:
        idx = torch.randperm(len(test_set))[0]
    idx = int(idx)
    return idx, test_set[idx]
|
|
|
|
|
def get_mnist_obs(idx=None):
    """Return a Morpho-MNIST observation formatted for the UI.

    The dataset is loaded on first use. Thickness and intensity are mapped back
    with ``(v + 1) / 2 * (max - min) + min`` and rounded to two decimals.

    Returns:
        ``(idx, image, thickness, intensity, digit_label)``.
    """
    dataset_id = "Morpho-MNIST"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)

    # Value ranges used for de-normalisation (presumably the dataset min/max).
    t_min, t_max = 0.87598526, 6.255515
    i_min, i_max = 66.601204, 254.90317

    image = get_fig_arr(obs["x"].clone().squeeze().numpy())
    thickness = (obs["thickness"].clone() + 1) / 2 * (t_max - t_min) + t_min
    intensity = (obs["intensity"].clone() + 1) / 2 * (i_max - i_min) + i_min
    digit = DIGITS[obs["digit"].clone().argmax(-1)]

    return (
        idx,
        image,
        float(np.round(thickness, 2)),
        float(np.round(intensity, 2)),
        digit,
    )
|
|
|
|
|
def get_brain_obs(idx=None):
    """Return a brain-MRI observation formatted for the UI.

    The dataset is loaded on first use. Both volumes are divided by 1000 and
    rounded to two decimals for display.

    Returns:
        ``(idx, image, mri_seq_label, sex_label, age, brain_volume,
        ventricle_volume)``.
    """
    dataset_id = "Brain MRI"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)

    def scalar(key):
        # Python scalar held by a single-element tensor.
        return obs[key].clone().item()

    image = get_fig_arr(obs["x"].clone().squeeze().numpy())
    mri_seq = MRISEQ_CAT[int(scalar("mri_seq"))]
    sex = SEX_CAT[int(scalar("sex"))]
    age = scalar("age")
    brain_vol = scalar("brain_volume") / 1000
    ventricle_vol = scalar("ventricle_volume") / 1000

    return (
        idx,
        image,
        mri_seq,
        sex,
        age,
        float(np.round(brain_vol, 2)),
        float(np.round(ventricle_vol, 2)),
    )
|
|
|
|
|
def get_chest_obs(idx=None):
    """Return a chest X-ray observation formatted for the UI.

    The dataset is loaded on first use. Age is mapped back with
    ``(v + 1) * 50`` and rounded to one decimal.

    Returns:
        ``(idx, image, race_label, sex_label, finding_label, age)``.
    """
    dataset_id = "Chest X-ray"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, obs = get_obs_item(dataset_id, idx)

    def as_numpy(key):
        # Squeezed NumPy copy of the attribute tensor.
        return obs[key].clone().squeeze().numpy()

    image = get_fig_arr(postprocess(obs["x"].clone()))
    sex = SEX_CAT_CHEST[int(as_numpy("sex"))]
    finding = FIND_CAT[as_numpy("finding").argmax(-1)]
    race = RACE_CAT[as_numpy("race").argmax(-1)]
    age = (as_numpy("age") + 1) * 50

    return (idx, image, race, sex, finding, float(np.round(age, 1)))
|
|
|
|
|
def infer_mnist_cf(*args):
    """Run counterfactual inference for a Morpho-MNIST observation.

    Args:
        *args: ``(idx, image, thickness, intensity, digit, do_thickness,
            do_intensity, do_digit)``. The image entry is ignored; the
            observation is re-read from the test set using ``idx``. Each
            ``do_*`` flag enables the intervention on the matching value.

    Returns:
        ``(cf_image, cf_std_image, effect_image, cf_thickness, cf_intensity,
        cf_digit_label)``.
    """
    dataset_id = "Morpho-MNIST"
    idx, _, t, i, y, do_t, do_i, do_y = args
    n_particles = 32

    obs = DATA[dataset_id]["test"].dataset[int(idx)]
    obs["x"] = (obs["x"] - 127.5) / 127.5  # maps 0..255 pixel values onto [-1, 1]

    device = MODELS[dataset_id]["vae_args"].device
    for name in list(obs):
        val = obs[name]
        # Add a leading batch dimension (0-d tensors become shape (1, 1)).
        val = val.view(1, 1) if len(val.shape) < 1 else val.unsqueeze(0)
        val = val.to(device).float()
        if n_particles > 1:
            reps = (1,) * 3 if name == "x" else (1,)
            val = val.repeat(n_particles, *reps)
        obs[name] = val

    # Interventions, encoded the same way as the observed parents.
    do_pa = {}
    if do_t:
        t_norm = normalize(t, x_max=6.255515, x_min=0.87598526)
        do_pa["thickness"] = torch.tensor(t_norm).view(1, 1)
    if do_i:
        i_norm = normalize(i, x_max=254.90317, x_min=66.601204)
        do_pa["intensity"] = torch.tensor(i_norm).view(1, 1)
    if do_y:
        digit_idx = torch.tensor(DIGITS.index(y))
        do_pa["digit"] = F.one_hot(digit_idx, num_classes=10).view(1, 10)
    do_pa = {
        name: val.to(device).float().repeat(n_particles, 1)
        for name, val in do_pa.items()
    }

    out = counterfactual_inference(dataset_id, obs, do_pa)

    # Aggregate over the particle dimension.
    cf_parents = out["cf_pa"]
    cf_x = out["cf_x"].mean(0)
    cf_x_std = out["cf_x"].std(0)
    rec_x = out["rec_x"].mean(0)
    cf_t = cf_parents["thickness"].mean(0)
    cf_i = cf_parents["intensity"].mean(0)
    cf_y = cf_parents["digit"].mean(0)

    cf_x = postprocess(cf_x)
    cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
    rec_x = postprocess(rec_x)
    cf_t = np.round((cf_t.item() + 1) / 2 * (6.255515 - 0.87598526) + 0.87598526, 2)
    cf_i = np.round((cf_i.item() + 1) / 2 * (254.90317 - 66.601204) + 66.601204, 2)
    cf_y = DIGITS[cf_y.argmax(-1)]

    # Effect map: counterfactual minus reconstruction, on a fixed colour scale.
    effect = get_fig_arr(
        cf_x - rec_x,
        cmap="RdBu_r",
        norm=MidpointNormalize(vmin=-255, midpoint=0, vmax=255),
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
    return (cf_x, cf_x_std, effect, cf_t, cf_i, cf_y)
|
|
|
|
|
def infer_brain_cf(*args):
    """Run counterfactual inference for a brain-MRI observation.

    Args:
        *args: ``(idx, image, mri_seq, sex, age, brain_volume,
            ventricle_volume)`` followed by the five matching ``do_*`` flags.
            The image entry is ignored; the observation is re-read from the
            test set using ``idx``.

    Returns:
        ``(cf_image, cf_std_image, effect_image, cf_mri_seq_label,
        cf_sex_label, cf_age, cf_brain_volume, cf_ventricle_volume)``.
    """
    dataset_id = "Brain MRI"
    idx, _, m, s, a, b, v = args[:7]
    do_m, do_s, do_a, do_b, do_v = args[7:]
    n_particles = 16
    continuous = ["age", "brain_volume", "ventricle_volume"]
    vae_args = MODELS[dataset_id]["vae_args"]

    obs = DATA[dataset_id]["test"].dataset[int(idx)]
    obs = preprocess_brain(vae_args, obs)
    if n_particles > 1:
        for name, val in obs.items():
            reps = (1,) * 3 if name == "x" else (1,)
            obs[name] = val.repeat(n_particles, *reps)

    # Raw intervention values; volumes are multiplied by 1000 to undo the
    # division applied for display.
    do_pa = {}
    if do_m:
        do_pa["mri_seq"] = torch.tensor(MRISEQ_CAT.index(m)).view(1, 1)
    if do_s:
        do_pa["sex"] = torch.tensor(SEX_CAT.index(s)).view(1, 1)
    if do_a:
        do_pa["age"] = torch.tensor(a).view(1, 1)
    if do_b:
        do_pa["brain_volume"] = torch.tensor(b * 1000).view(1, 1)
    if do_v:
        do_pa["ventricle_volume"] = torch.tensor(v * 1000).view(1, 1)

    # Min-max scale the continuous interventions to [0, 1], then to [-1, 1].
    for name in continuous:
        if name in do_pa:
            k_max, k_min = get_attr_max_min(name)
            unit = (do_pa[name] - k_min) / (k_max - k_min)
            do_pa[name] = 2 * unit - 1

    do_pa = {
        name: val.to(vae_args.device).float().repeat(n_particles, 1)
        for name, val in do_pa.items()
    }

    out = counterfactual_inference(dataset_id, obs, do_pa)

    # Aggregate over the particle dimension.
    cf_parents = out["cf_pa"]
    cf_x = out["cf_x"].mean(0)
    cf_x_std = out["cf_x"].std(0)
    rec_x = out["rec_x"].mean(0)
    cf_m = cf_parents["mri_seq"].mean(0)
    cf_s = cf_parents["sex"].mean(0)

    cf_x = postprocess(cf_x)
    cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
    rec_x = postprocess(rec_x)
    cf_m = MRISEQ_CAT[int(cf_m.item())]
    cf_s = SEX_CAT[int(cf_s.item())]

    # Map the continuous counterfactual parents back to their original range.
    cf_ = {}
    for name in continuous:
        k_max, k_min = get_attr_max_min(name)
        cf_[name] = (cf_parents[name].mean(0).item() + 1) / 2 * (k_max - k_min) + k_min

    # Effect map: counterfactual minus reconstruction, scaled to its own range.
    effect = cf_x - rec_x
    effect = get_fig_arr(
        effect,
        cmap="RdBu_r",
        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")

    return (
        cf_x,
        cf_x_std,
        effect,
        cf_m,
        cf_s,
        np.round(cf_["age"], 1),
        np.round(cf_["brain_volume"] / 1000, 2),
        np.round(cf_["ventricle_volume"] / 1000, 2),
    )
|
|
|
|
|
def infer_chest_cf(*args):
    """Run counterfactual inference for a chest X-ray observation.

    Args:
        *args: ``(idx, image, race, sex, finding, age)`` followed by the flags
            ``(do_race, do_sex, do_finding, do_age)``. The image entry is
            ignored; the observation is re-read from the test set using ``idx``.

    Returns:
        ``(cf_image, cf_std_image, effect_image, cf_race_label, cf_sex_label,
        cf_finding_label, cf_age)``.
    """
    dataset_id = "Chest X-ray"
    idx, _, r, s, f, a = args[:6]
    do_r, do_s, do_f, do_a = args[6:]
    n_particles = 16
    device = MODELS[dataset_id]["vae_args"].device

    obs = DATA[dataset_id]["test"].dataset[int(idx)]
    # Keep the untouched input image; it is the baseline of the effect map.
    observation = obs["x"]
    for name in list(obs):
        val = obs[name].to(device).float()
        if n_particles > 1:
            reps = (1,) * 3 if name == "x" else (1,)
            val = val.repeat(n_particles, *reps)
        obs[name] = val

    do_pa = {}
    with torch.no_grad():
        if do_s:
            do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
        if do_f:
            finding_idx = torch.tensor(FIND_CAT.index(f))
            do_pa["finding"] = F.one_hot(finding_idx, num_classes=3).view(1, 3)
        if do_r:
            race_idx = torch.tensor(RACE_CAT.index(r))
            do_pa["race"] = F.one_hot(race_idx, num_classes=3).view(1, 3)
        if do_a:
            # Inverse of the (v + 1) * 50 mapping applied to the output age.
            do_pa["age"] = torch.tensor(a / 100 * 2 - 1).view(1, 1)
        for name in list(do_pa):
            do_pa[name] = do_pa[name].to(device).float().repeat(n_particles, 1)

    out = counterfactual_inference(dataset_id, obs, do_pa)

    # Aggregate over the particle dimension.
    cf_parents = out["cf_pa"]
    cf_x = out["cf_x"].mean(0)
    cf_x_std = out["cf_x"].std(0)
    rec_x = out["rec_x"].mean(0)
    cf_r = cf_parents["race"].mean(0)
    cf_s = cf_parents["sex"].mean(0)
    cf_f = cf_parents["finding"].mean(0)
    cf_a = cf_parents["age"].mean(0)

    cf_x = postprocess(cf_x)
    cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
    # NOTE(review): rec_x is post-processed but never used below. The effect
    # map here is taken against the raw observation, whereas the MNIST and
    # brain handlers subtract the reconstruction. Confirm this is intended.
    rec_x = postprocess(rec_x)
    cf_r = RACE_CAT[cf_r.argmax(-1)]
    cf_s = SEX_CAT_CHEST[int(cf_s.item())]
    cf_f = FIND_CAT[cf_f.argmax(-1)]
    cf_a = (cf_a.item() + 1) * 50

    effect = get_fig_arr(
        cf_x - postprocess(observation),
        cmap="RdBu_r",
        norm=MidpointNormalize(midpoint=0),
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
    return (cf_x, cf_x_std, effect, cf_r, cf_s, cf_f, np.round(cf_a, 1))
|
|
|
|
|
# Gradio UI: a single "Chest X-ray" tab with the observation, the
# counterfactual outputs, the intervention controls and the causal graph.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    with gr.Tabs():
        with gr.TabItem("Chest X-ray") as chest_tab:
            # Hidden textbox carrying the tab label, used as the dataset id
            # when the models are loaded.
            chest_id = gr.Textbox(value=chest_tab.label, visible=False)

            # Image row: observation plus the three counterfactual outputs.
            with gr.Row():
                idx_chest = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x_chest = gr.Image(
                        label="Observation", interactive=False, height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_chest = gr.Image(
                        label="Counterfactual", interactive=False, height=HEIGHT
                    )
                with gr.Column(scale=1, min_width=200):
                    cf_x_std_chest = gr.Image(
                        label="Counterfactual Uncertainty",
                        interactive=False,
                        height=HEIGHT,
                    )
                with gr.Column(scale=1, min_width=200):
                    effect_chest = gr.Image(
                        label="Direct Causal Effect",
                        interactive=False,
                        height=HEIGHT,
                    )

            # Control row: intervention widgets on the left, graph on the right.
            with gr.Row():
                with gr.Column(scale=2.55):
                    with gr.Row(equal_height=True):
                        with gr.Column(min_width=200):
                            do_f_chest = gr.Checkbox(label="do(disease)", value=False)
                            f_chest = gr.Radio(FIND_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_s_chest = gr.Checkbox(label="do(sex)", value=False)
                            s_chest = gr.Radio(
                                SEX_CAT_CHEST, label="", interactive=False
                            )

                    with gr.Row():
                        with gr.Column(min_width=200):
                            do_r_chest = gr.Checkbox(label="do(race)", value=False)
                            r_chest = gr.Radio(RACE_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_a_chest = gr.Checkbox(label="do(age)", value=False)
                            # Locked until do(age) is ticked, like the three
                            # radios above. Without interactive=False the
                            # slider started out editable although its
                            # intervention was disabled.
                            a_chest = gr.Slider(
                                label="\u00A0",
                                minimum=18,
                                maximum=98,
                                step=1,
                                interactive=False,
                            )

                    with gr.Row():
                        new_chest = gr.Button("New Observation")
                        reset_chest = gr.Button("Reset", variant="stop")
                        submit_chest = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    causal_graph_chest = gr.Image(
                        label="Causal Graph", interactive=False, height=345
                    )

    # Component groups. The order of do_chest must match both the attribute
    # list zipped against it below and the flag order infer_chest_cf unpacks.
    do_chest = [do_r_chest, do_s_chest, do_f_chest, do_a_chest]
    obs_chest = [idx_chest, x_chest, r_chest, s_chest, f_chest, a_chest]
    cf_out_chest = [cf_x_chest, cf_x_std_chest, effect_chest]

    # On page load: show an observation, load the models, draw the graph.
    demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
    demo.load(fn=load_model, inputs=chest_id)
    demo.load(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)

    # "New Observation": new sample, redraw the graph, clear old outputs.
    new_chest.click(fn=get_chest_obs, inputs=None, outputs=obs_chest)
    new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
    new_chest.click(
        fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=cf_out_chest
    )

    # "Reset": restore the current observation, untick every do(...) box,
    # close open matplotlib figures and clear the outputs.
    reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)
    reset_chest.click(
        fn=lambda: (gr.update(value=False),) * len(do_chest),
        inputs=None,
        outputs=do_chest,
    )
    reset_chest.click(fn=lambda: plt.close("all"), inputs=None, outputs=None)
    reset_chest.click(
        fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=cf_out_chest
    )

    # Each do(...) checkbox toggles whether its attribute widget is editable
    # and triggers a redraw of the causal graph.
    for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
        _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
        _k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)

    # "Submit": run inference; outputs are the three images followed by the
    # counterfactual attribute values written back into the attribute widgets.
    submit_chest.click(
        fn=infer_chest_cf,
        inputs=obs_chest + do_chest,
        outputs=cf_out_chest + [r_chest, s_chest, f_chest, a_chest],
    )
|
|
|
if __name__ == "__main__":
    # Enable the request queue, then serve without creating a public share link.
    demo.queue()
    demo.launch(share=False)