Upload app.py
Browse files
app.py
CHANGED
@@ -30,9 +30,9 @@ MRISEQ_CAT = ["T1", "T2-FLAIR"] # 0,1
|
|
30 |
SEX_CAT = ["female", "male"] # 0,1
|
31 |
HEIGHT, WIDTH = 270, 270
|
32 |
# chest
|
33 |
-
SEX_CAT_CHEST = ["
|
34 |
-
RACE_CAT = ["white", "
|
35 |
-
FIND_CAT = ["no disease", "
|
36 |
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
37 |
|
38 |
|
@@ -42,6 +42,24 @@ class Hparams:
|
|
42 |
setattr(self, k, v)
|
43 |
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
def get_paths(dataset_id):
|
46 |
if "MNIST" in dataset_id:
|
47 |
data_path = "./data/morphomnist"
|
@@ -53,14 +71,15 @@ def get_paths(dataset_id):
|
|
53 |
vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
|
54 |
elif "Chest" in dataset_id:
|
55 |
data_path = "./data/mimic_subset"
|
56 |
-
pgm_path = "./checkpoints/a_r_s_f/
|
57 |
vae_path = [
|
58 |
-
"./checkpoints/a_r_s_f/
|
59 |
-
"./checkpoints/a_r_s_f/
|
60 |
]
|
61 |
return data_path, vae_path, pgm_path
|
62 |
|
63 |
|
|
|
64 |
def load_pgm(dataset_id, pgm_path):
|
65 |
checkpoint = torch.load(pgm_path, map_location=DEVICE)
|
66 |
args = Hparams()
|
@@ -202,7 +221,7 @@ def get_chest_obs(idx=None):
|
|
202 |
idx, obs = get_obs_item(dataset_id, idx)
|
203 |
x = get_fig_arr(postprocess(obs["x"].clone()))
|
204 |
s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
|
205 |
-
f = FIND_CAT[
|
206 |
r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
|
207 |
a = (obs["age"].clone().squeeze().numpy() + 1) * 50
|
208 |
return (idx, x, r, s, f, float(np.round(a, 1)))
|
@@ -360,7 +379,10 @@ def infer_chest_cf(*args):
|
|
360 |
if do_s:
|
361 |
do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
|
362 |
if do_f:
|
363 |
-
do_pa["finding"] =
|
|
|
|
|
|
|
364 |
if do_r:
|
365 |
do_pa["race"] = F.one_hot(
|
366 |
torch.tensor(RACE_CAT.index(r)), num_classes=3
|
@@ -387,7 +409,7 @@ def infer_chest_cf(*args):
|
|
387 |
rec_x = postprocess(rec_x)
|
388 |
cf_r = RACE_CAT[cf_r.argmax(-1)]
|
389 |
cf_s = SEX_CAT_CHEST[int(cf_s.item())]
|
390 |
-
cf_f = FIND_CAT[cf_f.argmax(-1)]
|
391 |
cf_a = (cf_a.item() + 1) * 50
|
392 |
# plots
|
393 |
# plt.close('all')
|
@@ -426,7 +448,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
426 |
label="Direct Causal Effect", interactive=False
|
427 |
, height=HEIGHT) #).style(height=HEIGHT)
|
428 |
with gr.Row(): #.style(equal_height=True):
|
429 |
-
with gr.Column(scale=1
|
430 |
gr.Markdown(
|
431 |
"**Intervention**"
|
432 |
+ 20 * " "
|
@@ -487,7 +509,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
487 |
label="Direct Causal Effect", interactive=False
|
488 |
, height=HEIGHT) #).style(height=HEIGHT)
|
489 |
with gr.Row():
|
490 |
-
with gr.Column(scale=2
|
491 |
gr.Markdown(
|
492 |
"**Intervention**"
|
493 |
+ 20 * " "
|
@@ -571,7 +593,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
571 |
, height=HEIGHT) #).style(height=HEIGHT)
|
572 |
|
573 |
with gr.Row():
|
574 |
-
with gr.Column(scale=2
|
575 |
gr.Markdown(
|
576 |
"**Intervention**"
|
577 |
+ 20 * " "
|
|
|
30 |
SEX_CAT = ["female", "male"] # 0,1
|
31 |
HEIGHT, WIDTH = 270, 270
|
32 |
# chest
|
33 |
+
SEX_CAT_CHEST = ["female", "male"] # 0,1
|
34 |
+
RACE_CAT = ["white", "black", "asian"] # 0,1,2
|
35 |
+
FIND_CAT = ["no disease", "consolidation", "opacity"] # 0,1,2
|
36 |
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
37 |
|
38 |
|
|
|
42 |
setattr(self, k, v)
|
43 |
|
44 |
|
45 |
+
# def get_paths(dataset_id):
|
46 |
+
# if "MNIST" in dataset_id:
|
47 |
+
# data_path = "./data/morphomnist"
|
48 |
+
# pgm_path = "./checkpoints/t_i_d/sup_pgm/checkpoint.pt"
|
49 |
+
# vae_path = "./checkpoints/t_i_d/dgauss_cond_big_beta1_dropexo/checkpoint.pt"
|
50 |
+
# elif "Brain" in dataset_id:
|
51 |
+
# data_path = "./data/ukbb_subset"
|
52 |
+
# pgm_path = "./checkpoints/m_b_v_s/sup_pgm/checkpoint.pt"
|
53 |
+
# vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
|
54 |
+
# elif "Chest" in dataset_id:
|
55 |
+
# data_path = "./data/mimic_subset"
|
56 |
+
# pgm_path = "./checkpoints/a_r_s_f/pgm/checkpoint.pt"
|
57 |
+
# vae_path = [
|
58 |
+
# "./checkpoints/a_r_s_f/base_vae/checkpoint.pt", # base vae
|
59 |
+
# "./checkpoints/a_r_s_f/cf_vae/6500_checkpoint.pt", # cf trained DSCM
|
60 |
+
# ]
|
61 |
+
# return data_path, vae_path, pgm_path
|
62 |
+
|
63 |
def get_paths(dataset_id):
|
64 |
if "MNIST" in dataset_id:
|
65 |
data_path = "./data/morphomnist"
|
|
|
71 |
vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
|
72 |
elif "Chest" in dataset_id:
|
73 |
data_path = "./data/mimic_subset"
|
74 |
+
pgm_path = "./checkpoints/a_r_s_f/sup_pgm_mimic/checkpoint.pt"
|
75 |
vae_path = [
|
76 |
+
"./checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/checkpoint.pt", # base vae
|
77 |
+
"./checkpoints/a_r_s_f/mimic_dscm_lr_1e5_lagrange_lr_1_damping_10/6500_checkpoint.pt", # cf trained DSCM
|
78 |
]
|
79 |
return data_path, vae_path, pgm_path
|
80 |
|
81 |
|
82 |
+
|
83 |
def load_pgm(dataset_id, pgm_path):
|
84 |
checkpoint = torch.load(pgm_path, map_location=DEVICE)
|
85 |
args = Hparams()
|
|
|
221 |
idx, obs = get_obs_item(dataset_id, idx)
|
222 |
x = get_fig_arr(postprocess(obs["x"].clone()))
|
223 |
s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
|
224 |
+
f = FIND_CAT[obs["finding"].clone().squeeze().numpy().argmax(-1)]
|
225 |
r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
|
226 |
a = (obs["age"].clone().squeeze().numpy() + 1) * 50
|
227 |
return (idx, x, r, s, f, float(np.round(a, 1)))
|
|
|
379 |
if do_s:
|
380 |
do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
|
381 |
if do_f:
|
382 |
+
do_pa["finding"] = F.one_hot(
|
383 |
+
torch.tensor(FIND_CAT.index(f)), num_classes=3
|
384 |
+
).view(1, 3)
|
385 |
+
# torch.tensor(FIND_CAT.index(f)).view(1, 1)
|
386 |
if do_r:
|
387 |
do_pa["race"] = F.one_hot(
|
388 |
torch.tensor(RACE_CAT.index(r)), num_classes=3
|
|
|
409 |
rec_x = postprocess(rec_x)
|
410 |
cf_r = RACE_CAT[cf_r.argmax(-1)]
|
411 |
cf_s = SEX_CAT_CHEST[int(cf_s.item())]
|
412 |
+
cf_f = FIND_CAT[cf_f.argmax(-1)] #FIND_CAT[int(cf_f.item())]
|
413 |
cf_a = (cf_a.item() + 1) * 50
|
414 |
# plots
|
415 |
# plt.close('all')
|
|
|
448 |
label="Direct Causal Effect", interactive=False
|
449 |
, height=HEIGHT) #).style(height=HEIGHT)
|
450 |
with gr.Row(): #.style(equal_height=True):
|
451 |
+
with gr.Column(scale=1):#.75):
|
452 |
gr.Markdown(
|
453 |
"**Intervention**"
|
454 |
+ 20 * " "
|
|
|
509 |
label="Direct Causal Effect", interactive=False
|
510 |
, height=HEIGHT) #).style(height=HEIGHT)
|
511 |
with gr.Row():
|
512 |
+
with gr.Column(scale=2):#.55):
|
513 |
gr.Markdown(
|
514 |
"**Intervention**"
|
515 |
+ 20 * " "
|
|
|
593 |
, height=HEIGHT) #).style(height=HEIGHT)
|
594 |
|
595 |
with gr.Row():
|
596 |
+
with gr.Column(scale=2):#.55):
|
597 |
gr.Markdown(
|
598 |
"**Intervention**"
|
599 |
+ 20 * " "
|