Yeefei commited on
Commit
aec687f
·
verified ·
1 Parent(s): 0073b09

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -342,8 +342,7 @@ def infer_brain_cf(*args):
342
  cf_[k] = (out["cf_pa"][k].mean(0).item() + 1) / 2 * (k_max - k_min) + k_min
343
  # plots
344
  # plt.close('all')
345
- # effect = cf_x - rec_x
346
- effect = cf_x - obs["x"]
347
  effect = get_fig_arr(
348
  effect,
349
  cmap="RdBu_r",
@@ -370,6 +369,7 @@ def infer_chest_cf(*args):
370
  n_particles = 16
371
  # preprocessing
372
  obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
 
373
  for k, v in obs.items():
374
  obs[k] = v.to(MODELS[dataset_id]["vae_args"].device).float()
375
  if n_particles > 1:
@@ -415,10 +415,8 @@ def infer_chest_cf(*args):
415
  cf_a = (cf_a.item() + 1) * 50
416
  # plots
417
  # plt.close('all')
418
- effect = cf_x - rec_x
419
- # effect = cf_x - obs["x"]
420
- print(effect.max())
421
- print(effect.min())
422
  effect = get_fig_arr(
423
  effect,
424
  cmap="RdBu_r",
 
342
  cf_[k] = (out["cf_pa"][k].mean(0).item() + 1) / 2 * (k_max - k_min) + k_min
343
  # plots
344
  # plt.close('all')
345
+ effect = cf_x - rec_x
 
346
  effect = get_fig_arr(
347
  effect,
348
  cmap="RdBu_r",
 
369
  n_particles = 16
370
  # preprocessing
371
  obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
372
+ observation = obs['x']
373
  for k, v in obs.items():
374
  obs[k] = v.to(MODELS[dataset_id]["vae_args"].device).float()
375
  if n_particles > 1:
 
415
  cf_a = (cf_a.item() + 1) * 50
416
  # plots
417
  # plt.close('all')
418
+ # effect = cf_x - rec_x
419
+ effect = cf_x - postprocess(observation)
 
 
420
  effect = get_fig_arr(
421
  effect,
422
  cmap="RdBu_r",