Yeefei commited on
Commit
f8e446a
·
verified ·
1 Parent(s): 9ecb0d4

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. vae.py +6 -6
app.py CHANGED
@@ -221,7 +221,7 @@ def get_chest_obs(idx=None):
221
  idx, obs = get_obs_item(dataset_id, idx)
222
  x = get_fig_arr(postprocess(obs["x"].clone()))
223
  s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
224
- f = FIND_CAT[int(obs["finding"].clone().squeeze().numpy())]
225
  r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
226
  a = (obs["age"].clone().squeeze().numpy() + 1) * 50
227
  return (idx, x, r, s, f, float(np.round(a, 1)))
 
221
  idx, obs = get_obs_item(dataset_id, idx)
222
  x = get_fig_arr(postprocess(obs["x"].clone()))
223
  s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
224
+ f = FIND_CAT[obs["finding"].clone().squeeze().numpy().argmax(-1)]
225
  r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
226
  a = (obs["age"].clone().squeeze().numpy() + 1) * 50
227
  return (idx, x, r, s, f, float(np.round(a, 1)))
vae.py CHANGED
@@ -6,7 +6,7 @@ import torch.distributions as dist
6
  import torch.nn.functional as F
7
  from torch import Tensor, nn
8
 
9
- from hps import Hparams
10
 
11
  EPS = -9 # minimum logscale
12
 
@@ -85,7 +85,7 @@ class Block(nn.Module):
85
 
86
 
87
  class Encoder(nn.Module):
88
- def __init__(self, args: Hparams):
89
  super().__init__()
90
  # parse architecture
91
  stages = []
@@ -135,7 +135,7 @@ class Encoder(nn.Module):
135
 
136
 
137
  class DecoderBlock(nn.Module):
138
- def __init__(self, args: Hparams, in_width: int, out_width: int, resolution: int):
139
  super().__init__()
140
  bottleneck = int(in_width / args.bottleneck)
141
  self.res = resolution
@@ -207,7 +207,7 @@ class DecoderBlock(nn.Module):
207
 
208
 
209
  class Decoder(nn.Module):
210
- def __init__(self, args: Hparams):
211
  super().__init__()
212
  # parse architecture
213
  stages = []
@@ -335,7 +335,7 @@ class Decoder(nn.Module):
335
 
336
 
337
  class DGaussNet(nn.Module):
338
- def __init__(self, args: Hparams):
339
  super(DGaussNet, self).__init__()
340
  self.x_loc = nn.Conv2d(
341
  args.widths[0], args.input_channels, kernel_size=1, stride=1
@@ -438,7 +438,7 @@ class DGaussNet(nn.Module):
438
 
439
 
440
  class HVAE(nn.Module):
441
- def __init__(self, args: Hparams):
442
  super().__init__()
443
  args.vr = "light" if "ukbb" in args.hps else None # hacky
444
  self.encoder = Encoder(args)
 
6
  import torch.nn.functional as F
7
  from torch import Tensor, nn
8
 
9
+ # from hps import Hparams
10
 
11
  EPS = -9 # minimum logscale
12
 
 
85
 
86
 
87
  class Encoder(nn.Module):
88
+ def __init__(self, args):
89
  super().__init__()
90
  # parse architecture
91
  stages = []
 
135
 
136
 
137
  class DecoderBlock(nn.Module):
138
+ def __init__(self, args, in_width, out_width, resolution):
139
  super().__init__()
140
  bottleneck = int(in_width / args.bottleneck)
141
  self.res = resolution
 
207
 
208
 
209
  class Decoder(nn.Module):
210
+ def __init__(self, args):
211
  super().__init__()
212
  # parse architecture
213
  stages = []
 
335
 
336
 
337
  class DGaussNet(nn.Module):
338
+ def __init__(self, args):
339
  super(DGaussNet, self).__init__()
340
  self.x_loc = nn.Conv2d(
341
  args.widths[0], args.input_channels, kernel_size=1, stride=1
 
438
 
439
 
440
  class HVAE(nn.Module):
441
+ def __init__(self, args):
442
  super().__init__()
443
  args.vr = "light" if "ukbb" in args.hps else None # hacky
444
  self.encoder = Encoder(args)