Yeefei committed on
Commit
5ebd65a
·
verified ·
1 Parent(s): 9aa68f2

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. vae.py +196 -155
app.py CHANGED
@@ -727,4 +727,4 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
727
 
728
  if __name__ == "__main__":
729
  demo.queue()
730
- demo.launch(share=True)
 
727
 
728
  if __name__ == "__main__":
729
  demo.queue()
730
+ demo.launch()
vae.py CHANGED
@@ -1,14 +1,20 @@
 
 
1
  import numpy as np
2
  import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
  import torch.distributions as dist
 
 
 
 
6
 
7
  EPS = -9 # minimum logscale
8
 
9
 
10
  @torch.jit.script
11
- def gaussian_kl(q_loc, q_logscale, p_loc, p_logscale):
 
 
12
  return (
13
  -0.5
14
  + p_logscale
@@ -20,27 +26,27 @@ def gaussian_kl(q_loc, q_logscale, p_loc, p_logscale):
20
 
21
 
22
  @torch.jit.script
23
- def sample_gaussian(loc, logscale):
24
  return loc + logscale.exp() * torch.randn_like(loc)
25
 
26
 
27
  class Block(nn.Module):
28
  def __init__(
29
  self,
30
- in_width,
31
- bottleneck,
32
- out_width,
33
- kernel_size=3,
34
- residual=True,
35
- down_rate=None,
36
- version=None,
37
  ):
38
  super().__init__()
39
  self.d = down_rate
40
  self.residual = residual
41
  padding = 0 if kernel_size == 1 else 1
42
 
43
- if version == "light": # for ukbb
44
  activation = nn.ReLU()
45
  self.conv = nn.Sequential(
46
  activation,
@@ -64,7 +70,7 @@ class Block(nn.Module):
64
  if self.residual and (self.d or in_width > out_width):
65
  self.width_proj = nn.Conv2d(in_width, out_width, 1, 1)
66
 
67
- def forward(self, x):
68
  out = self.conv(x)
69
  if self.residual:
70
  if x.shape[1] != out.shape[1]:
@@ -79,7 +85,7 @@ class Block(nn.Module):
79
 
80
 
81
  class Encoder(nn.Module):
82
- def __init__(self, args):
83
  super().__init__()
84
  # parse architecture
85
  stages = []
@@ -91,23 +97,17 @@ class Encoder(nn.Module):
91
  if i == 0: # define network stem
92
  if n_blocks == 0 and "d" not in stage:
93
  print("Using stride=2 conv encoder stem.")
94
- self.stem = nn.Conv2d(
95
- args.input_channels,
96
- args.widths[1],
97
- kernel_size=7,
98
- stride=2,
99
- padding=3,
100
- )
101
  continue
102
  else:
103
- self.stem = nn.Conv2d(
104
- args.input_channels,
105
- args.widths[0],
106
- kernel_size=7,
107
- stride=1,
108
- padding=3,
109
- )
110
-
111
  stages += [(args.widths[i], None) for _ in range(n_blocks)]
112
  if "d" in stage: # downsampling block
113
  stages += [(args.widths[i + 1], int(stage[stage.index("d") + 1]))]
@@ -118,12 +118,11 @@ class Encoder(nn.Module):
118
  blocks.append(
119
  Block(prev_width, bottleneck, width, down_rate=d, version=args.vr)
120
  )
121
- # scale weights of last conv layer in each block
122
  for b in blocks:
123
  b.conv[-1].weight.data *= np.sqrt(1 / len(blocks))
124
  self.blocks = nn.ModuleList(blocks)
125
 
126
- def forward(self, x):
127
  x = self.stem(x)
128
  acts = {}
129
  for block in self.blocks:
@@ -136,24 +135,18 @@ class Encoder(nn.Module):
136
 
137
 
138
  class DecoderBlock(nn.Module):
139
- def __init__(self, args, in_width, out_width, resolution):
140
  super().__init__()
141
  bottleneck = int(in_width / args.bottleneck)
142
  self.res = resolution
143
  self.stochastic = self.res <= args.z_max_res
144
  self.z_dim = args.z_dim
145
  self.cond_prior = args.cond_prior
 
146
  k = 3 if self.res > 2 else 1
147
 
148
- if self.cond_prior: # conditional prior
149
- p_in_width = in_width + args.context_dim
150
- else: # exogenous prior
151
- p_in_width = in_width
152
- # self.z_feat_proj = nn.Conv2d(self.z_dim + in_width, out_width, 1)
153
- self.z_feat_proj = nn.Conv2d(self.z_dim + in_width, out_width, 1)
154
-
155
  self.prior = Block(
156
- p_in_width,
157
  bottleneck,
158
  2 * self.z_dim + in_width,
159
  kernel_size=k,
@@ -170,11 +163,21 @@ class DecoderBlock(nn.Module):
170
  version=args.vr,
171
  )
172
  self.z_proj = nn.Conv2d(self.z_dim + args.context_dim, in_width, 1)
 
 
173
  self.conv = Block(
174
  in_width, bottleneck, out_width, kernel_size=k, version=args.vr
175
  )
176
 
177
- def forward_prior(self, z, pa=None, t=None):
 
 
 
 
 
 
 
 
178
  if self.cond_prior:
179
  z = torch.cat([z, pa], dim=1)
180
  z = self.prior(z)
@@ -185,8 +188,18 @@ class DecoderBlock(nn.Module):
185
  p_logscale = p_logscale + torch.tensor(t).to(z.device).log()
186
  return p_loc, p_logscale, p_features
187
 
188
- def forward_posterior(self, z, pa, x, t=None):
 
 
 
 
 
 
 
 
 
189
  h = torch.cat([z, pa, x], dim=1)
 
190
  q_loc, q_logscale = self.posterior(h).chunk(2, dim=1)
191
  if t is not None:
192
  q_logscale = q_logscale + torch.tensor(t).to(z.device).log()
@@ -194,7 +207,7 @@ class DecoderBlock(nn.Module):
194
 
195
 
196
  class Decoder(nn.Module):
197
- def __init__(self, args):
198
  super().__init__()
199
  # parse architecture
200
  stages = []
@@ -218,73 +231,58 @@ class Decoder(nn.Module):
218
  )
219
  self.bias = nn.ParameterList(bias)
220
  self.cond_prior = args.cond_prior
221
- self.is_drop_cond = True if "mnist" in args.hps else False # hacky
222
 
223
- def _scale_weights(self):
224
- scale = np.sqrt(1 / len(self.blocks))
225
- for b in self.blocks:
226
- b.z_proj.weight.data *= scale
227
- b.conv.conv[-1].weight.data *= scale
228
- b.prior.conv[-1].weight.data *= 0.0
229
-
230
- def forward(self, parents, x=None, t=None, abduct=False, latents=[]):
231
  # learnt params for each resolution r
232
  bias = {r.shape[2]: r for r in self.bias}
233
- h = bias[1].repeat(parents.shape[0], 1, 1, 1) # h_init
234
- z = h # for exogenous prior
235
- # for conditioning dropout, stochastic path (p1), deterministic path (p2)
236
- p1, p2 = self.drop_cond() if (self.training and self.cond_prior) else (1, 1)
 
237
 
238
  stats = []
239
  for i, block in enumerate(self.blocks):
240
  res = block.res # current block resolution, e.g. 64x64
241
  pa = parents[..., :res, :res].clone() # select parents @ res
242
 
243
- if (
244
- self.is_drop_cond
245
- ): # for morphomnist w/ conditioning dropout. Hacky, clean up later
246
- pa_drop1 = pa.clone()
247
- pa_drop1[:, 2:, ...] = pa_drop1[:, 2:, ...] * p1
248
- pa_drop2 = pa.clone()
249
- pa_drop2[:, 2:, ...] = pa_drop2[:, 2:, ...] * p2
250
- else: # for ukbb
251
- pa_drop1 = pa_drop2 = pa
252
 
253
  if h.size(-1) < res: # upsample previous layer output
254
  b = bias[res] if res in bias.keys() else 0 # broadcasting
255
  h = b + F.interpolate(h, scale_factor=res / h.shape[-1])
256
 
257
- if block.cond_prior: # conditional prior: p(z_i | z_<i, pa_x)
258
- # w/ posterior correction
259
- # p_loc, p_logscale, p_feat = block.forward_prior(h, pa_drop1, t=t)
260
- if z.size(-1) < res: # w/o posterior correction
261
- z = b + F.interpolate(z, scale_factor=res / z.shape[-1])
262
- p_loc, p_logscale, p_feat = block.forward_prior(z, pa_drop1, t=t)
263
- else: # exogenous prior: p(z_i | z_<i)
264
- if z.size(-1) < res:
265
- z = b + F.interpolate(z, scale_factor=res / z.shape[-1])
266
- p_loc, p_logscale, p_feat = block.forward_prior(z, t=t)
267
-
268
- # computation tree:
269
- # decoder block
270
- # / \
271
- # deterministic stochastic
272
- # | / \
273
- # forward z = p_loc given x not given x
274
- # / / \
275
- # abduct forward z or z* z ~ prior
276
- # / \ |
277
- # (prior: conditional exogenous) get p(z|pa*) if abduct
278
- # get z* get z
279
- #
280
 
281
  if block.stochastic:
282
- if x is not None: # z_i ~ q(z_i | z_<i, pa_x, x)
283
- q_loc, q_logscale = block.forward_posterior(h, pa, x[res], t=t)
 
284
  z = sample_gaussian(q_loc, q_logscale)
285
  stat = dict(kl=gaussian_kl(q_loc, q_logscale, p_loc, p_logscale))
286
- # abduct exogenous noise
287
- if abduct:
288
  if block.cond_prior: # z* if conditional prior
289
  stat.update(
290
  dict(
@@ -292,57 +290,52 @@ class Decoder(nn.Module):
292
  )
293
  )
294
  else: # z if exogenous prior
295
- # stat.update(dict(z=z.detach()))
296
- stat.update(dict(z=z)) # if cf training
297
  stats.append(stat)
298
  else:
299
- if latents[i] is None:
 
 
 
300
  z = sample_gaussian(p_loc, p_logscale)
301
-
302
  if abduct and block.cond_prior: # for abducting z*
303
  stats.append(
304
  dict(z={"p_loc": p_loc, "p_logscale": p_logscale})
305
  )
306
- else:
307
- try: # forward fixed latents z or z*
308
- z = latents[i]
309
- except: # sample prior
310
- z = sample_gaussian(p_loc, p_logscale)
311
-
312
- if abduct and block.cond_prior: # for abducting z*
313
- stats.append(
314
- dict(z={"p_loc": p_loc, "p_logscale": p_logscale})
315
- )
316
- else:
317
- z = p_loc # deterministic path
318
-
319
  h = h + p_feat # merge prior features
320
- h = self.forward_merge(block, h, z, pa_drop2)
321
-
322
- # if not block.cond_prior:
323
- if (i + 1) < len(self.blocks):
324
- # z independent of pa_x for next layer prior
325
- z = block.z_feat_proj(torch.cat([z, p_feat], dim=1))
 
 
326
  return h, stats
327
 
328
- def forward_merge(self, block, h, z, pa):
329
- # h_i = h_<i + f(z_i, pa_x)
330
- h = h + block.z_proj(torch.cat([z, pa], dim=1))
331
- return block.conv(h)
 
 
332
 
333
- def drop_cond(self):
 
334
  opt = dist.Categorical(1 / 3 * torch.ones(3)).sample()
335
  if opt == 0: # drop stochastic path
336
- p1, p2 = 0, 1
337
  elif opt == 1: # drop deterministic path
338
- p1, p2 = 1, 0
339
  elif opt == 2: # keep both
340
- p1, p2 = 1, 1
341
- return p1, p2
342
 
343
 
344
  class DGaussNet(nn.Module):
345
- def __init__(self, args):
346
  super(DGaussNet, self).__init__()
347
  self.x_loc = nn.Conv2d(
348
  args.widths[0], args.input_channels, kernel_size=1, stride=1
@@ -371,36 +364,48 @@ class DGaussNet(nn.Module):
371
  else:
372
  NotImplementedError(f"{args.x_like} not implemented.")
373
 
374
- def forward(self, h, x=None, t=None):
 
 
375
  loc, logscale = self.x_loc(h), self.x_logscale(h).clamp(min=EPS)
376
 
377
  # for RGB inputs
378
- # if hasattr(self, 'channel_coeffs'):
379
- # coeff = torch.tanh(self.channel_coeffs(h))
380
- # if x is None: # inference
381
- # # loc = loc + logscale.exp() * torch.randn_like(loc) # random sampling
382
- # f = lambda x: torch.clamp(x, min=-1, max=1)
383
- # loc_red = f(loc[:,0,...])
384
- # loc_green = f(loc[:,1,...] + coeff[:,0,...] * loc_red)
385
- # loc_blue = f(loc[:,2,...] + coeff[:,1,...] * loc_red + coeff[:,2,...] * loc_green)
386
- # else: # training
387
- # loc_red = loc[:,0,...]
388
- # loc_green = loc[:,1,...] + coeff[:,0,...] * x[:,0,...]
389
- # loc_blue = loc[:,2,...] + coeff[:,1,...] * x[:,0,...] + coeff[:,2,...] * x[:,1,...]
390
-
391
- # loc = torch.cat([loc_red.unsqueeze(1),
392
- # loc_green.unsqueeze(1), loc_blue.unsqueeze(1)], dim=1)
 
 
 
 
 
 
 
 
 
 
393
 
394
  if t is not None:
395
  logscale = logscale + torch.tensor(t).to(h.device).log()
396
  return loc, logscale
397
 
398
- def approx_cdf(self, x):
399
  return 0.5 * (
400
  1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))
401
  )
402
 
403
- def nll(self, h, x):
404
  loc, logscale = self.forward(h, x)
405
  centered_x = x - loc
406
  inv_stdv = torch.exp(-logscale)
@@ -420,7 +425,9 @@ class DGaussNet(nn.Module):
420
  )
421
  return -1.0 * log_probs.mean(dim=(1, 2, 3))
422
 
423
- def sample(self, h, return_loc=True, t=None):
 
 
424
  if return_loc:
425
  x, logscale = self.forward(h)
426
  else:
@@ -431,7 +438,7 @@ class DGaussNet(nn.Module):
431
 
432
 
433
  class HVAE(nn.Module):
434
- def __init__(self, args):
435
  super().__init__()
436
  args.vr = "light" if "ukbb" in args.hps else None # hacky
437
  self.encoder = Encoder(args)
@@ -442,10 +449,30 @@ class HVAE(nn.Module):
442
  NotImplementedError(f"{args.x_like} not implemented.")
443
  self.cond_prior = args.cond_prior
444
  self.free_bits = args.kl_free_bits
 
445
 
446
- def forward(self, x, parents, beta=1):
 
 
 
447
  acts = self.encoder(x)
 
 
 
 
 
 
 
 
448
  h, stats = self.decoder(parents=parents, x=acts)
 
 
 
 
 
 
 
 
449
  nll_pp = self.likelihood.nll(h, x)
450
  if self.free_bits > 0:
451
  free_bits = torch.tensor(self.free_bits).type_as(nll_pp)
@@ -456,17 +483,28 @@ class HVAE(nn.Module):
456
  ).sum()
457
  else:
458
  kl_pp = torch.zeros_like(nll_pp)
459
- for i, stat in enumerate(stats):
460
  kl_pp += stat["kl"].sum(dim=(1, 2, 3))
461
  kl_pp = kl_pp / np.prod(x.shape[1:]) # per pixel
462
- elbo = nll_pp.mean() + beta * kl_pp.mean() # negative elbo (free energy)
463
- return dict(elbo=elbo, nll=nll_pp.mean(), kl=kl_pp.mean())
464
-
465
- def sample(self, parents, return_loc=True, t=None):
 
 
 
 
466
  h, _ = self.decoder(parents=parents, t=t)
467
  return self.likelihood.sample(h, return_loc, t=t)
468
 
469
- def abduct(self, x, parents, cf_parents=None, alpha=0.5, t=None):
 
 
 
 
 
 
 
470
  acts = self.encoder(x)
471
  _, q_stats = self.decoder(
472
  x=acts, parents=parents, abduct=True, t=t
@@ -493,8 +531,9 @@ class HVAE(nn.Module):
493
  # Option1: mixture distribution: r(z_i | z_{<i}, x, pa, pa*)
494
  # = a*q(z_i | z_{<i}, x, pa) + (1-a)*p(z_i | z_{<i}, pa*)
495
  r_loc = alpha * q_loc + (1 - alpha) * p_loc
496
- # assumes independence
497
- r_var = alpha * q_scale.pow(2) + (1 - alpha) * p_var
 
498
  # r_var = a*(q_loc.pow(2) + q_var) + (1-a)*(p_loc.pow(2) + p_var) - r_loc.pow(2)
499
 
500
  # # Option 2: precision weighted distribution
@@ -512,6 +551,8 @@ class HVAE(nn.Module):
512
  else:
513
  return q_stats # zs
514
 
515
- def forward_latents(self, latents, parents, t=None):
 
 
516
  h, _ = self.decoder(latents=latents, parents=parents, t=t)
517
  return self.likelihood.sample(h, t=t)
 
1
+ from typing import Dict, List, Optional, Tuple
2
+
3
  import numpy as np
4
  import torch
 
 
5
  import torch.distributions as dist
6
+ import torch.nn.functional as F
7
+ from torch import Tensor, nn
8
+
9
+ from hps import Hparams
10
 
11
  EPS = -9 # minimum logscale
12
 
13
 
14
  @torch.jit.script
15
+ def gaussian_kl(
16
+ q_loc: Tensor, q_logscale: Tensor, p_loc: Tensor, p_logscale: Tensor
17
+ ) -> Tensor:
18
  return (
19
  -0.5
20
  + p_logscale
 
26
 
27
 
28
  @torch.jit.script
29
+ def sample_gaussian(loc: Tensor, logscale: Tensor) -> Tensor:
30
  return loc + logscale.exp() * torch.randn_like(loc)
31
 
32
 
33
  class Block(nn.Module):
34
  def __init__(
35
  self,
36
+ in_width: int,
37
+ bottleneck: int,
38
+ out_width: int,
39
+ kernel_size: int = 3,
40
+ residual: bool = True,
41
+ down_rate: Optional[int] = None,
42
+ version: Optional[str] = None,
43
  ):
44
  super().__init__()
45
  self.d = down_rate
46
  self.residual = residual
47
  padding = 0 if kernel_size == 1 else 1
48
 
49
+ if version == "light": # uses less VRAM
50
  activation = nn.ReLU()
51
  self.conv = nn.Sequential(
52
  activation,
 
70
  if self.residual and (self.d or in_width > out_width):
71
  self.width_proj = nn.Conv2d(in_width, out_width, 1, 1)
72
 
73
+ def forward(self, x: Tensor) -> Tensor:
74
  out = self.conv(x)
75
  if self.residual:
76
  if x.shape[1] != out.shape[1]:
 
85
 
86
 
87
  class Encoder(nn.Module):
88
+ def __init__(self, args: Hparams):
89
  super().__init__()
90
  # parse architecture
91
  stages = []
 
97
  if i == 0: # define network stem
98
  if n_blocks == 0 and "d" not in stage:
99
  print("Using stride=2 conv encoder stem.")
100
+ stem_width, stem_stride = args.widths[1], 2
 
 
 
 
 
 
101
  continue
102
  else:
103
+ stem_width, stem_stride = args.widths[0], 1
104
+ self.stem = nn.Conv2d(
105
+ args.input_channels,
106
+ stem_width,
107
+ kernel_size=7,
108
+ stride=stem_stride,
109
+ padding=3,
110
+ )
111
  stages += [(args.widths[i], None) for _ in range(n_blocks)]
112
  if "d" in stage: # downsampling block
113
  stages += [(args.widths[i + 1], int(stage[stage.index("d") + 1]))]
 
118
  blocks.append(
119
  Block(prev_width, bottleneck, width, down_rate=d, version=args.vr)
120
  )
 
121
  for b in blocks:
122
  b.conv[-1].weight.data *= np.sqrt(1 / len(blocks))
123
  self.blocks = nn.ModuleList(blocks)
124
 
125
+ def forward(self, x: Tensor) -> Dict[int, Tensor]:
126
  x = self.stem(x)
127
  acts = {}
128
  for block in self.blocks:
 
135
 
136
 
137
  class DecoderBlock(nn.Module):
138
+ def __init__(self, args: Hparams, in_width: int, out_width: int, resolution: int):
139
  super().__init__()
140
  bottleneck = int(in_width / args.bottleneck)
141
  self.res = resolution
142
  self.stochastic = self.res <= args.z_max_res
143
  self.z_dim = args.z_dim
144
  self.cond_prior = args.cond_prior
145
+ self.q_correction = args.q_correction
146
  k = 3 if self.res > 2 else 1
147
 
 
 
 
 
 
 
 
148
  self.prior = Block(
149
+ (in_width + args.context_dim if self.cond_prior else in_width),
150
  bottleneck,
151
  2 * self.z_dim + in_width,
152
  kernel_size=k,
 
163
  version=args.vr,
164
  )
165
  self.z_proj = nn.Conv2d(self.z_dim + args.context_dim, in_width, 1)
166
+ if not self.q_correction: # for no posterior correction
167
+ self.z_feat_proj = nn.Conv2d(self.z_dim + in_width, out_width, 1)
168
  self.conv = Block(
169
  in_width, bottleneck, out_width, kernel_size=k, version=args.vr
170
  )
171
 
172
+ def forward_prior(
173
+ self, z: Tensor, pa: Optional[Tensor] = None, t: Optional[float] = None
174
+ ) -> Tuple[Tensor, Tensor, Tensor]:
175
+ #print('Prior')
176
+ #print('z')
177
+ #print(z.shape)
178
+ #print('pa')
179
+ #print(pa.shape)
180
+
181
  if self.cond_prior:
182
  z = torch.cat([z, pa], dim=1)
183
  z = self.prior(z)
 
188
  p_logscale = p_logscale + torch.tensor(t).to(z.device).log()
189
  return p_loc, p_logscale, p_features
190
 
191
+ def forward_posterior(
192
+ self, z: Tensor, x: Tensor, pa: Tensor, t: Optional[float] = None
193
+ ) -> Tuple[Tensor, Tensor]:
194
+ #print('Posterior')
195
+ #print('z')
196
+ #print(z.shape)
197
+ #print('x')
198
+ #print(x.shape)
199
+ #print('pa')
200
+ #print(pa.shape)
201
  h = torch.cat([z, pa, x], dim=1)
202
+ #print('h shape: ', h.shape)
203
  q_loc, q_logscale = self.posterior(h).chunk(2, dim=1)
204
  if t is not None:
205
  q_logscale = q_logscale + torch.tensor(t).to(z.device).log()
 
207
 
208
 
209
  class Decoder(nn.Module):
210
+ def __init__(self, args: Hparams):
211
  super().__init__()
212
  # parse architecture
213
  stages = []
 
231
  )
232
  self.bias = nn.ParameterList(bias)
233
  self.cond_prior = args.cond_prior
234
+ self.is_drop_cond = True if "morphomnist" in args.hps else False # hacky
235
 
236
+ def forward(
237
+ self,
238
+ parents: Tensor,
239
+ x: Optional[Dict[int, Tensor]] = None,
240
+ t: Optional[float] = None,
241
+ abduct: bool = False,
242
+ latents: List[Tensor] = [],
243
+ ) -> Tuple[Tensor, List[Dict[str, Tensor]]]:
244
  # learnt params for each resolution r
245
  bias = {r.shape[2]: r for r in self.bias}
246
+ h = z = bias[1].repeat(parents.shape[0], 1, 1, 1) # initial state
247
+ # conditioning dropout: stochastic path (p_sto), deterministic path (p_det)
248
+ p_sto, p_det = (
249
+ self.drop_cond() if (self.training and self.cond_prior) else (1, 1)
250
+ )
251
 
252
  stats = []
253
  for i, block in enumerate(self.blocks):
254
  res = block.res # current block resolution, e.g. 64x64
255
  pa = parents[..., :res, :res].clone() # select parents @ res
256
 
257
+ # for morphomnist w/ conditioning dropout of y only, clean up later
258
+ if self.is_drop_cond:
259
+ pa_sto, pa_det = pa.clone(), pa.clone()
260
+ pa_sto[:, 2:, ...] = pa_sto[:, 2:, ...] * p_sto
261
+ pa_det[:, 2:, ...] = pa_det[:, 2:, ...] * p_det
262
+ else: # disabled otherwise
263
+ pa_sto = pa_det = pa
 
 
264
 
265
  if h.size(-1) < res: # upsample previous layer output
266
  b = bias[res] if res in bias.keys() else 0 # broadcasting
267
  h = b + F.interpolate(h, scale_factor=res / h.shape[-1])
268
 
269
+ if block.q_correction:
270
+ p_input = h # current prior depends on previous posterior
271
+ else: # current prior depends on previous prior only, upsample previous prior latent z
272
+ p_input = (
273
+ b + F.interpolate(z, scale_factor=res / z.shape[-1])
274
+ if z.size(-1) < res
275
+ else z
276
+ )
277
+ p_loc, p_logscale, p_feat = block.forward_prior(p_input, pa_sto, t=t)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  if block.stochastic:
280
+ if x is not None: # z_i ~ q(z_i | z_<i, x, pa_x)
281
+ # print(res)
282
+ q_loc, q_logscale = block.forward_posterior(h, x[res], pa, t=t)
283
  z = sample_gaussian(q_loc, q_logscale)
284
  stat = dict(kl=gaussian_kl(q_loc, q_logscale, p_loc, p_logscale))
285
+ if abduct: # abduct exogenous noise
 
286
  if block.cond_prior: # z* if conditional prior
287
  stat.update(
288
  dict(
 
290
  )
291
  )
292
  else: # z if exogenous prior
293
+ stat.update(dict(z=z)) # .detach() z if not cf training
 
294
  stats.append(stat)
295
  else:
296
+ try: # forward abducted latents
297
+ z = latents[i]
298
+ z = sample_gaussian(p_loc, p_logscale) if z is None else z
299
+ except: # sample prior
300
  z = sample_gaussian(p_loc, p_logscale)
 
301
  if abduct and block.cond_prior: # for abducting z*
302
  stats.append(
303
  dict(z={"p_loc": p_loc, "p_logscale": p_logscale})
304
  )
305
+ else: # deterministic block
306
+ z = p_loc
 
 
 
 
 
 
 
 
 
 
 
307
  h = h + p_feat # merge prior features
308
+ # h_i = h_<i + f(z_i, pa_x)
309
+ h = h + block.z_proj(torch.cat([z, pa], dim=1))
310
+ h = block.conv(h)
311
+
312
+ if not block.q_correction:
313
+ if (i + 1) < len(self.blocks):
314
+ # z independent of pa_x for next layer prior
315
+ z = block.z_feat_proj(torch.cat([z, p_feat], dim=1))
316
  return h, stats
317
 
318
+ def _scale_weights(self):
319
+ scale = np.sqrt(1 / len(self.blocks))
320
+ for b in self.blocks:
321
+ b.z_proj.weight.data *= scale
322
+ b.conv.conv[-1].weight.data *= scale
323
+ b.prior.conv[-1].weight.data *= 0.0
324
 
325
+ @torch.no_grad()
326
+ def drop_cond(self) -> Tuple[int, int]:
327
  opt = dist.Categorical(1 / 3 * torch.ones(3)).sample()
328
  if opt == 0: # drop stochastic path
329
+ p_sto, p_det = 0, 1
330
  elif opt == 1: # drop deterministic path
331
+ p_sto, p_det = 1, 0
332
  elif opt == 2: # keep both
333
+ p_sto, p_det = 1, 1
334
+ return p_sto, p_det
335
 
336
 
337
  class DGaussNet(nn.Module):
338
+ def __init__(self, args: Hparams):
339
  super(DGaussNet, self).__init__()
340
  self.x_loc = nn.Conv2d(
341
  args.widths[0], args.input_channels, kernel_size=1, stride=1
 
364
  else:
365
  NotImplementedError(f"{args.x_like} not implemented.")
366
 
367
+ def forward(
368
+ self, h: Tensor, x: Optional[Tensor] = None, t: Optional[float] = None
369
+ ) -> Tuple[Tensor, Tensor]:
370
  loc, logscale = self.x_loc(h), self.x_logscale(h).clamp(min=EPS)
371
 
372
  # for RGB inputs
373
+ if hasattr(self, "channel_coeffs"):
374
+ coeff = torch.tanh(self.channel_coeffs(h))
375
+ if x is None: # inference
376
+ # loc = loc + logscale.exp() * torch.randn_like(loc) # random sampling
377
+ f = lambda x: torch.clamp(x, min=-1, max=1)
378
+ loc_red = f(loc[:, 0, ...])
379
+ loc_green = f(loc[:, 1, ...] + coeff[:, 0, ...] * loc_red)
380
+ loc_blue = f(
381
+ loc[:, 2, ...]
382
+ + coeff[:, 1, ...] * loc_red
383
+ + coeff[:, 2, ...] * loc_green
384
+ )
385
+ else: # training
386
+ loc_red = loc[:, 0, ...]
387
+ loc_green = loc[:, 1, ...] + coeff[:, 0, ...] * x[:, 0, ...]
388
+ loc_blue = (
389
+ loc[:, 2, ...]
390
+ + coeff[:, 1, ...] * x[:, 0, ...]
391
+ + coeff[:, 2, ...] * x[:, 1, ...]
392
+ )
393
+
394
+ loc = torch.cat(
395
+ [loc_red.unsqueeze(1), loc_green.unsqueeze(1), loc_blue.unsqueeze(1)],
396
+ dim=1,
397
+ )
398
 
399
  if t is not None:
400
  logscale = logscale + torch.tensor(t).to(h.device).log()
401
  return loc, logscale
402
 
403
+ def approx_cdf(self, x: Tensor) -> Tensor:
404
  return 0.5 * (
405
  1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))
406
  )
407
 
408
+ def nll(self, h: Tensor, x: Tensor) -> Tensor:
409
  loc, logscale = self.forward(h, x)
410
  centered_x = x - loc
411
  inv_stdv = torch.exp(-logscale)
 
425
  )
426
  return -1.0 * log_probs.mean(dim=(1, 2, 3))
427
 
428
+ def sample(
429
+ self, h: Tensor, return_loc: bool = True, t: Optional[float] = None
430
+ ) -> Tuple[Tensor, Tensor]:
431
  if return_loc:
432
  x, logscale = self.forward(h)
433
  else:
 
438
 
439
 
440
  class HVAE(nn.Module):
441
+ def __init__(self, args: Hparams):
442
  super().__init__()
443
  args.vr = "light" if "ukbb" in args.hps else None # hacky
444
  self.encoder = Encoder(args)
 
449
  NotImplementedError(f"{args.x_like} not implemented.")
450
  self.cond_prior = args.cond_prior
451
  self.free_bits = args.kl_free_bits
452
+ self.register_buffer("log2", torch.tensor(2.0).log())
453
 
454
+ def forward(self, x: Tensor, parents: Tensor, beta: int = 1) -> Dict[str, Tensor]:
455
+ #print(f'Encoder Input:')
456
+ #print(type(x))
457
+ #print(x.shape)
458
  acts = self.encoder(x)
459
+ #print(type(acts))
460
+ #for key, i in acts.items():
461
+ #print(f'Encoder output key: {key}')
462
+ #print(type(i))
463
+ #print(i.shape)
464
+
465
+ #print('Parents')
466
+ #print(parents.shape)
467
  h, stats = self.decoder(parents=parents, x=acts)
468
+ #print('Decoder output shape: ', h.shape)
469
+ #print('Stats: ')
470
+ #for stat in stats:
471
+ #for key, i in stat.items():
472
+ #print(f'Key: {key}')
473
+ #print(type(i))
474
+ #print(i.shape)
475
+
476
  nll_pp = self.likelihood.nll(h, x)
477
  if self.free_bits > 0:
478
  free_bits = torch.tensor(self.free_bits).type_as(nll_pp)
 
483
  ).sum()
484
  else:
485
  kl_pp = torch.zeros_like(nll_pp)
486
+ for _, stat in enumerate(stats):
487
  kl_pp += stat["kl"].sum(dim=(1, 2, 3))
488
  kl_pp = kl_pp / np.prod(x.shape[1:]) # per pixel
489
+ kl_pp = kl_pp.mean() # / self.log2
490
+ nll_pp = nll_pp.mean() # / self.log2
491
+ nelbo = nll_pp + beta * kl_pp # negative elbo (free energy)
492
+ return dict(elbo=nelbo, nll=nll_pp, kl=kl_pp)
493
+
494
+ def sample(
495
+ self, parents: Tensor, return_loc: bool = True, t: Optional[float] = None
496
+ ) -> Tuple[Tensor, Tensor]:
497
  h, _ = self.decoder(parents=parents, t=t)
498
  return self.likelihood.sample(h, return_loc, t=t)
499
 
500
+ def abduct(
501
+ self,
502
+ x: Tensor,
503
+ parents: Tensor,
504
+ cf_parents: Optional[Tensor] = None,
505
+ alpha: float = 0.5,
506
+ t: Optional[float] = None,
507
+ ) -> List[Tensor]:
508
  acts = self.encoder(x)
509
  _, q_stats = self.decoder(
510
  x=acts, parents=parents, abduct=True, t=t
 
531
  # Option1: mixture distribution: r(z_i | z_{<i}, x, pa, pa*)
532
  # = a*q(z_i | z_{<i}, x, pa) + (1-a)*p(z_i | z_{<i}, pa*)
533
  r_loc = alpha * q_loc + (1 - alpha) * p_loc
534
+ r_var = (
535
+ alpha**2 * q_scale.pow(2) + (1 - alpha)**2 * p_var
536
+ ) # assumes independence
537
  # r_var = a*(q_loc.pow(2) + q_var) + (1-a)*(p_loc.pow(2) + p_var) - r_loc.pow(2)
538
 
539
  # # Option 2: precision weighted distribution
 
551
  else:
552
  return q_stats # zs
553
 
554
+ def forward_latents(
555
+ self, latents: List[Tensor], parents: Tensor, t: Optional[float] = None
556
+ ) -> Tuple[Tensor, Tensor]:
557
  h, _ = self.decoder(latents=latents, parents=parents, t=t)
558
  return self.likelihood.sample(h, t=t)