Upload app.py
Browse files
app.py
CHANGED
@@ -83,7 +83,7 @@ def get_paths(dataset_id):
|
|
83 |
|
84 |
|
85 |
def load_pgm(dataset_id, pgm_path):
|
86 |
-
checkpoint = torch.load(pgm_path, map_location=DEVICE
|
87 |
args = Hparams()
|
88 |
args.update(checkpoint["hparams"])
|
89 |
args.device = DEVICE
|
@@ -101,7 +101,7 @@ def load_pgm(dataset_id, pgm_path):
|
|
101 |
def load_vae(dataset_id, vae_path):
|
102 |
if "Chest" in dataset_id:
|
103 |
vae_path, dscm_path = vae_path[0], vae_path[1]
|
104 |
-
checkpoint = torch.load(vae_path, map_location=DEVICE
|
105 |
args = Hparams()
|
106 |
args.update(checkpoint["hparams"])
|
107 |
# backwards compatibility hack
|
@@ -115,7 +115,7 @@ def load_vae(dataset_id, vae_path):
|
|
115 |
vae = HVAE(args).to(args.device)
|
116 |
|
117 |
if "Chest" in dataset_id:
|
118 |
-
dscm_ckpt = torch.load(dscm_path, map_location=DEVICE
|
119 |
vae.load_state_dict(
|
120 |
{
|
121 |
k[4:]: v
|
|
|
83 |
|
84 |
|
85 |
def load_pgm(dataset_id, pgm_path):
|
86 |
+
checkpoint = torch.load(pgm_path, map_location=DEVICE)
|
87 |
args = Hparams()
|
88 |
args.update(checkpoint["hparams"])
|
89 |
args.device = DEVICE
|
|
|
101 |
def load_vae(dataset_id, vae_path):
|
102 |
if "Chest" in dataset_id:
|
103 |
vae_path, dscm_path = vae_path[0], vae_path[1]
|
104 |
+
checkpoint = torch.load(vae_path, map_location=DEVICE)
|
105 |
args = Hparams()
|
106 |
args.update(checkpoint["hparams"])
|
107 |
# backwards compatibility hack
|
|
|
115 |
vae = HVAE(args).to(args.device)
|
116 |
|
117 |
if "Chest" in dataset_id:
|
118 |
+
dscm_ckpt = torch.load(dscm_path, map_location=DEVICE)
|
119 |
vae.load_state_dict(
|
120 |
{
|
121 |
k[4:]: v
|