Yeefei commited on
Commit
dfad5be
·
verified ·
1 Parent(s): ded9f97

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -12
app.py CHANGED
@@ -30,9 +30,9 @@ MRISEQ_CAT = ["T1", "T2-FLAIR"] # 0,1
30
  SEX_CAT = ["female", "male"] # 0,1
31
  HEIGHT, WIDTH = 270, 270
32
  # chest
33
- SEX_CAT_CHEST = ["male", "female"] # 0,1
34
- RACE_CAT = ["white", "asian", "black"] # 0,1,2
35
- FIND_CAT = ["no disease", "pleural effusion"]
36
  DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
37
 
38
 
@@ -42,6 +42,24 @@ class Hparams:
42
  setattr(self, k, v)
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def get_paths(dataset_id):
46
  if "MNIST" in dataset_id:
47
  data_path = "./data/morphomnist"
@@ -53,14 +71,15 @@ def get_paths(dataset_id):
53
  vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
54
  elif "Chest" in dataset_id:
55
  data_path = "./data/mimic_subset"
56
- pgm_path = "./checkpoints/a_r_s_f/pgm/checkpoint.pt"
57
  vae_path = [
58
- "./checkpoints/a_r_s_f/base_vae/checkpoint.pt", # base vae
59
- "./checkpoints/a_r_s_f/cf_vae/6500_checkpoint.pt", # cf trained DSCM
60
  ]
61
  return data_path, vae_path, pgm_path
62
 
63
 
 
64
  def load_pgm(dataset_id, pgm_path):
65
  checkpoint = torch.load(pgm_path, map_location=DEVICE)
66
  args = Hparams()
@@ -202,7 +221,7 @@ def get_chest_obs(idx=None):
202
  idx, obs = get_obs_item(dataset_id, idx)
203
  x = get_fig_arr(postprocess(obs["x"].clone()))
204
  s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
205
- f = FIND_CAT[int(obs["finding"].clone().squeeze().numpy())]
206
  r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
207
  a = (obs["age"].clone().squeeze().numpy() + 1) * 50
208
  return (idx, x, r, s, f, float(np.round(a, 1)))
@@ -360,7 +379,10 @@ def infer_chest_cf(*args):
360
  if do_s:
361
  do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
362
  if do_f:
363
- do_pa["finding"] = torch.tensor(FIND_CAT.index(f)).view(1, 1)
 
 
 
364
  if do_r:
365
  do_pa["race"] = F.one_hot(
366
  torch.tensor(RACE_CAT.index(r)), num_classes=3
@@ -387,7 +409,7 @@ def infer_chest_cf(*args):
387
  rec_x = postprocess(rec_x)
388
  cf_r = RACE_CAT[cf_r.argmax(-1)]
389
  cf_s = SEX_CAT_CHEST[int(cf_s.item())]
390
- cf_f = FIND_CAT[cf_f.argmax(-1)]
391
  cf_a = (cf_a.item() + 1) * 50
392
  # plots
393
  # plt.close('all')
@@ -426,7 +448,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
426
  label="Direct Causal Effect", interactive=False
427
  , height=HEIGHT) #).style(height=HEIGHT)
428
  with gr.Row(): #.style(equal_height=True):
429
- with gr.Column(scale=1.75):
430
  gr.Markdown(
431
  "**Intervention**"
432
  + 20 * " "
@@ -487,7 +509,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
487
  label="Direct Causal Effect", interactive=False
488
  , height=HEIGHT) #).style(height=HEIGHT)
489
  with gr.Row():
490
- with gr.Column(scale=2.55):
491
  gr.Markdown(
492
  "**Intervention**"
493
  + 20 * " "
@@ -571,7 +593,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
571
  , height=HEIGHT) #).style(height=HEIGHT)
572
 
573
  with gr.Row():
574
- with gr.Column(scale=2.55):
575
  gr.Markdown(
576
  "**Intervention**"
577
  + 20 * " "
 
30
  SEX_CAT = ["female", "male"] # 0,1
31
  HEIGHT, WIDTH = 270, 270
32
  # chest
33
+ SEX_CAT_CHEST = ["female", "male"] # 0,1
34
+ RACE_CAT = ["white", "black", "asian"] # 0,1,2
35
+ FIND_CAT = ["no disease", "consolidation", "opacity"] # 0,1,2
36
  DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
37
 
38
 
 
42
  setattr(self, k, v)
43
 
44
 
45
+ # def get_paths(dataset_id):
46
+ # if "MNIST" in dataset_id:
47
+ # data_path = "./data/morphomnist"
48
+ # pgm_path = "./checkpoints/t_i_d/sup_pgm/checkpoint.pt"
49
+ # vae_path = "./checkpoints/t_i_d/dgauss_cond_big_beta1_dropexo/checkpoint.pt"
50
+ # elif "Brain" in dataset_id:
51
+ # data_path = "./data/ukbb_subset"
52
+ # pgm_path = "./checkpoints/m_b_v_s/sup_pgm/checkpoint.pt"
53
+ # vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
54
+ # elif "Chest" in dataset_id:
55
+ # data_path = "./data/mimic_subset"
56
+ # pgm_path = "./checkpoints/a_r_s_f/pgm/checkpoint.pt"
57
+ # vae_path = [
58
+ # "./checkpoints/a_r_s_f/base_vae/checkpoint.pt", # base vae
59
+ # "./checkpoints/a_r_s_f/cf_vae/6500_checkpoint.pt", # cf trained DSCM
60
+ # ]
61
+ # return data_path, vae_path, pgm_path
62
+
63
  def get_paths(dataset_id):
64
  if "MNIST" in dataset_id:
65
  data_path = "./data/morphomnist"
 
71
  vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
72
  elif "Chest" in dataset_id:
73
  data_path = "./data/mimic_subset"
74
+ pgm_path = "./checkpoints/a_r_s_f/sup_pgm_mimic/checkpoint.pt"
75
  vae_path = [
76
+ "./checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/checkpoint.pt", # base vae
77
+ "./checkpoints/a_r_s_f/mimic_dscm_lr_1e5_lagrange_lr_1_damping_10/6500_checkpoint.pt", # cf trained DSCM
78
  ]
79
  return data_path, vae_path, pgm_path
80
 
81
 
82
+
83
  def load_pgm(dataset_id, pgm_path):
84
  checkpoint = torch.load(pgm_path, map_location=DEVICE)
85
  args = Hparams()
 
221
  idx, obs = get_obs_item(dataset_id, idx)
222
  x = get_fig_arr(postprocess(obs["x"].clone()))
223
  s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
224
+ f = FIND_CAT[obs["finding"].clone().squeeze().numpy().argmax(-1)]
225
  r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
226
  a = (obs["age"].clone().squeeze().numpy() + 1) * 50
227
  return (idx, x, r, s, f, float(np.round(a, 1)))
 
379
  if do_s:
380
  do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
381
  if do_f:
382
+ do_pa["finding"] = F.one_hot(
383
+ torch.tensor(FIND_CAT.index(f)), num_classes=3
384
+ ).view(1, 3)
385
+ # torch.tensor(FIND_CAT.index(f)).view(1, 1)
386
  if do_r:
387
  do_pa["race"] = F.one_hot(
388
  torch.tensor(RACE_CAT.index(r)), num_classes=3
 
409
  rec_x = postprocess(rec_x)
410
  cf_r = RACE_CAT[cf_r.argmax(-1)]
411
  cf_s = SEX_CAT_CHEST[int(cf_s.item())]
412
+ cf_f = FIND_CAT[cf_f.argmax(-1)] #FIND_CAT[int(cf_f.item())]
413
  cf_a = (cf_a.item() + 1) * 50
414
  # plots
415
  # plt.close('all')
 
448
  label="Direct Causal Effect", interactive=False
449
  , height=HEIGHT) #).style(height=HEIGHT)
450
  with gr.Row(): #.style(equal_height=True):
451
+ with gr.Column(scale=1):#.75):
452
  gr.Markdown(
453
  "**Intervention**"
454
  + 20 * " "
 
509
  label="Direct Causal Effect", interactive=False
510
  , height=HEIGHT) #).style(height=HEIGHT)
511
  with gr.Row():
512
+ with gr.Column(scale=2):#.55):
513
  gr.Markdown(
514
  "**Intervention**"
515
  + 20 * " "
 
593
  , height=HEIGHT) #).style(height=HEIGHT)
594
 
595
  with gr.Row():
596
+ with gr.Column(scale=2):#.55):
597
  gr.Markdown(
598
  "**Intervention**"
599
  + 20 * " "