Upload 2 files
Browse files
app.py
CHANGED
@@ -221,7 +221,7 @@ def get_chest_obs(idx=None):
|
|
221 |
idx, obs = get_obs_item(dataset_id, idx)
|
222 |
x = get_fig_arr(postprocess(obs["x"].clone()))
|
223 |
s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
|
224 |
-
f = FIND_CAT[
|
225 |
r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
|
226 |
a = (obs["age"].clone().squeeze().numpy() + 1) * 50
|
227 |
return (idx, x, r, s, f, float(np.round(a, 1)))
|
|
|
221 |
idx, obs = get_obs_item(dataset_id, idx)
|
222 |
x = get_fig_arr(postprocess(obs["x"].clone()))
|
223 |
s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
|
224 |
+
f = FIND_CAT[obs["finding"].clone().squeeze().numpy().argmax(-1)]
|
225 |
r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
|
226 |
a = (obs["age"].clone().squeeze().numpy() + 1) * 50
|
227 |
return (idx, x, r, s, f, float(np.round(a, 1)))
|
vae.py
CHANGED
@@ -6,7 +6,7 @@ import torch.distributions as dist
|
|
6 |
import torch.nn.functional as F
|
7 |
from torch import Tensor, nn
|
8 |
|
9 |
-
from hps import Hparams
|
10 |
|
11 |
EPS = -9 # minimum logscale
|
12 |
|
@@ -85,7 +85,7 @@ class Block(nn.Module):
|
|
85 |
|
86 |
|
87 |
class Encoder(nn.Module):
|
88 |
-
def __init__(self, args
|
89 |
super().__init__()
|
90 |
# parse architecture
|
91 |
stages = []
|
@@ -135,7 +135,7 @@ class Encoder(nn.Module):
|
|
135 |
|
136 |
|
137 |
class DecoderBlock(nn.Module):
|
138 |
-
def __init__(self, args
|
139 |
super().__init__()
|
140 |
bottleneck = int(in_width / args.bottleneck)
|
141 |
self.res = resolution
|
@@ -207,7 +207,7 @@ class DecoderBlock(nn.Module):
|
|
207 |
|
208 |
|
209 |
class Decoder(nn.Module):
|
210 |
-
def __init__(self, args
|
211 |
super().__init__()
|
212 |
# parse architecture
|
213 |
stages = []
|
@@ -335,7 +335,7 @@ class Decoder(nn.Module):
|
|
335 |
|
336 |
|
337 |
class DGaussNet(nn.Module):
|
338 |
-
def __init__(self, args
|
339 |
super(DGaussNet, self).__init__()
|
340 |
self.x_loc = nn.Conv2d(
|
341 |
args.widths[0], args.input_channels, kernel_size=1, stride=1
|
@@ -438,7 +438,7 @@ class DGaussNet(nn.Module):
|
|
438 |
|
439 |
|
440 |
class HVAE(nn.Module):
|
441 |
-
def __init__(self, args
|
442 |
super().__init__()
|
443 |
args.vr = "light" if "ukbb" in args.hps else None # hacky
|
444 |
self.encoder = Encoder(args)
|
|
|
6 |
import torch.nn.functional as F
|
7 |
from torch import Tensor, nn
|
8 |
|
9 |
+
# from hps import Hparams
|
10 |
|
11 |
EPS = -9 # minimum logscale
|
12 |
|
|
|
85 |
|
86 |
|
87 |
class Encoder(nn.Module):
|
88 |
+
def __init__(self, args):
|
89 |
super().__init__()
|
90 |
# parse architecture
|
91 |
stages = []
|
|
|
135 |
|
136 |
|
137 |
class DecoderBlock(nn.Module):
|
138 |
+
def __init__(self, args, in_width, out_width, resolution):
|
139 |
super().__init__()
|
140 |
bottleneck = int(in_width / args.bottleneck)
|
141 |
self.res = resolution
|
|
|
207 |
|
208 |
|
209 |
class Decoder(nn.Module):
|
210 |
+
def __init__(self, args):
|
211 |
super().__init__()
|
212 |
# parse architecture
|
213 |
stages = []
|
|
|
335 |
|
336 |
|
337 |
class DGaussNet(nn.Module):
|
338 |
+
def __init__(self, args):
|
339 |
super(DGaussNet, self).__init__()
|
340 |
self.x_loc = nn.Conv2d(
|
341 |
args.widths[0], args.input_channels, kernel_size=1, stride=1
|
|
|
438 |
|
439 |
|
440 |
class HVAE(nn.Module):
|
441 |
+
def __init__(self, args):
|
442 |
super().__init__()
|
443 |
args.vr = "light" if "ukbb" in args.hps else None # hacky
|
444 |
self.encoder = Encoder(args)
|