Yeefei commited on
Commit
9631b4b
·
verified ·
1 Parent(s): eaf4f4f

Upload flow_pgm.py

Browse files
Files changed (1) hide show
  1. pgm/flow_pgm.py +67 -21
pgm/flow_pgm.py CHANGED
@@ -28,6 +28,23 @@ class Hparams:
28
  setattr(self, k, v)
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  class BasePGM(nn.Module):
32
  def __init__(self):
33
  super().__init__()
@@ -461,13 +478,13 @@ class ChestPGM(BasePGM):
461
  self.variables = {
462
  "race": "categorical",
463
  "sex": "binary",
464
- "finding": "binary",
465
  "age": "continuous",
466
  }
467
  # Discrete variables that are not root nodes
468
- self.discrete_variables = {"finding": "binary"}
469
  # define base distributions
470
- for k in ["a"]: # , "f"]:
471
  self.register_buffer(f"{k}_base_loc", torch.zeros(1))
472
  self.register_buffer(f"{k}_base_scale", torch.ones(1))
473
  # age spline flow
@@ -481,26 +498,51 @@ class ChestPGM(BasePGM):
481
  # self.age_constraints,
482
  ]
483
  )
484
- # Finding (conditional) via MLP, a -> f
485
- finding_net = DenseNN(1, [8, 16], param_dims=[2], nonlinearity=nn.Sigmoid())
486
  self.finding_transform_GumbelMax = ConditionalGumbelMax(
487
  context_nn=finding_net, event_dim=0
488
  )
489
  # log space for sex and race
490
  self.sex_logit = nn.Parameter(np.log(1 / 2) * torch.ones(1))
491
  self.race_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3))
 
492
 
493
- input_shape = (args.input_channels, args.input_res, args.input_res)
 
494
 
495
- if args.enc_net == "cnn":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  # q(s | x) ~ Bernoulli(f(x))
497
- self.encoder_s = CNN(input_shape, num_outputs=1)
498
- # q(r | x) ~ OneHotCategorical(logits=f(x))
499
- self.encoder_r = CNN(input_shape, num_outputs=3)
500
  # q(f | x) ~ Bernoulli(f(x))
501
- self.encoder_f = CNN(input_shape, num_outputs=1)
502
  # q(a | x, f) ~ Normal(mu(x), sigma(x))
503
- self.encoder_a = CNN(input_shape, num_outputs=1, context_dim=1)
 
 
 
 
 
504
 
505
  def model(self) -> Dict[str, Tensor]:
506
  pyro.module("ChestPGM", self)
@@ -521,11 +563,14 @@ class ChestPGM(BasePGM):
521
 
522
  # p(f | a), finding as OneHotCategorical conditioned on age
523
  # finding_dist_base = dist.Gumbel(self.f_base_loc, self.f_base_scale).to_event(1)
524
- finding_dist_base = dist.Gumbel(torch.zeros(1), torch.ones(1)).to_event(1)
525
-
 
 
 
526
  finding_dist = ConditionalTransformedDistributionGumbelMax(
527
- finding_dist_base, [self.finding_transform_GumbelMax]
528
- ).condition(age)
529
  finding = pyro.sample("finding", finding_dist)
530
 
531
  return {
@@ -548,8 +593,8 @@ class ChestPGM(BasePGM):
548
  pyro.sample("race", qr_x)
549
  # q(f | x)
550
  if obs["finding"] is None:
551
- f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
552
- qf_x = dist.Bernoulli(probs=f_prob).to_event(1)
553
  obs["finding"] = pyro.sample("finding", qf_x)
554
  # q(a | x, f)
555
  if obs["age"] is None:
@@ -576,8 +621,9 @@ class ChestPGM(BasePGM):
576
  pyro.sample("race_aux", qr_x, obs=obs["race"])
577
 
578
  # q(f | x)
579
- f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
580
- qf_x = dist.Bernoulli(probs=f_prob).to_event(1)
 
581
  pyro.sample("finding_aux", qf_x, obs=obs["finding"])
582
 
583
  # q(a | x, f)
@@ -594,7 +640,7 @@ class ChestPGM(BasePGM):
594
  # q(r | x)
595
  r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
596
  # q(f | x)
597
- f_prob = torch.sigmoid(self.encoder_f(obs["x"]))
598
  # q(a | x, f)
599
  a_loc, _ = self.encoder_a(obs["x"], y=obs["finding"]).chunk(2, dim=-1)
600
 
 
28
  setattr(self, k, v)
29
 
30
 
31
+ def is_one_hot(tensor):
32
+ """
33
+ Check if the given tensor is a valid one-hot tensor.
34
+
35
+ Args:
36
+ tensor (torch.Tensor): A tensor to check.
37
+
38
+ Returns:
39
+ bool: True if tensor is one-hot, False otherwise.
40
+ """
41
+ if tensor.ndim != 2:
42
+ return False
43
+
44
+ # Check if there is exactly one '1' in each row and all other elements are '0'
45
+ return torch.all((tensor.sum(dim=1) == 1) & (tensor.max(dim=1).values == 1) & (tensor.min(dim=1).values == 0))
46
+
47
+
48
  class BasePGM(nn.Module):
49
  def __init__(self):
50
  super().__init__()
 
478
  self.variables = {
479
  "race": "categorical",
480
  "sex": "binary",
481
+ "finding": "categorical",
482
  "age": "continuous",
483
  }
484
  # Discrete variables that are not root nodes
485
+ self.discrete_variables = {"finding": "categorical"}
486
  # define base distributions
487
+ for k in ["a", "f"]:
488
  self.register_buffer(f"{k}_base_loc", torch.zeros(1))
489
  self.register_buffer(f"{k}_base_scale", torch.ones(1))
490
  # age spline flow
 
498
  # self.age_constraints,
499
  ]
500
  )
501
+ # Finding (conditional) via MLP, a -> fss
502
+ finding_net = DenseNN(1, [8, 16], param_dims=[3], nonlinearity=nn.Softmax())
503
  self.finding_transform_GumbelMax = ConditionalGumbelMax(
504
  context_nn=finding_net, event_dim=0
505
  )
506
  # log space for sex and race
507
  self.sex_logit = nn.Parameter(np.log(1 / 2) * torch.ones(1))
508
  self.race_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3))
509
+ self.finding_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3))
510
 
511
+ if args.setup != "sup_pgm":
512
+ from resnet import CustomBlock, ResNet, ResNet18
513
 
514
+ shared_model = ResNet(
515
+ CustomBlock,
516
+ layers=[2, 2, 2, 2],
517
+ widths=[64, 128, 256, 512],
518
+ norm_layer=lambda c: nn.GroupNorm(min(32, c // 4), c),
519
+ )
520
+ # shared_model = torchvision.models.resnet18(weights=None)
521
+ shared_model.conv1 = nn.Conv2d(
522
+ args.input_channels,
523
+ 64,
524
+ kernel_size=7,
525
+ stride=2,
526
+ padding=3,
527
+ bias=False,
528
+ )
529
+ kwargs = {
530
+ "in_shape": (args.input_channels, *(args.input_res,) * 2),
531
+ "base_model": shared_model,
532
+ }
533
  # q(s | x) ~ Bernoulli(f(x))
534
+ self.encoder_s = ResNet18(num_outputs=1, **kwargs)
535
+ # q(r | x) ~ OneHotCategorical(f(x))
536
+ self.encoder_r = ResNet18(num_outputs=3, **kwargs)
537
  # q(f | x) ~ Bernoulli(f(x))
538
+ self.encoder_f = ResNet18(num_outputs=3, **kwargs)
539
  # q(a | x, f) ~ Normal(mu(x), sigma(x))
540
+ self.encoder_a = ResNet18(num_outputs=2, context_dim=3, **kwargs)
541
+ self.f = (
542
+ lambda x: args.std_fixed * torch.ones_like(x)
543
+ if args.std_fixed > 0
544
+ else F.softplus(x)
545
+ )
546
 
547
  def model(self) -> Dict[str, Tensor]:
548
  pyro.module("ChestPGM", self)
 
563
 
564
  # p(f | a), finding as OneHotCategorical conditioned on age
565
  # finding_dist_base = dist.Gumbel(self.f_base_loc, self.f_base_scale).to_event(1)
566
+ # finding_dist = ConditionalTransformedDistributionGumbelMax(
567
+ # finding_dist_base, [self.finding_transform_GumbelMax]
568
+ # ).condition(age)
569
+ finding_dist_base = dist.OneHotCategorical(logits=self.finding_logits) #.to_event(1)
570
+ # finding_dist_base = dist.RelaxedOneHotCategoricalStraightThrough(temperature =1e-5,logits=self.finding_logits) #.to_event(1)
571
  finding_dist = ConditionalTransformedDistributionGumbelMax(
572
+ finding_dist_base, [self.finding_transform_GumbelMax]).condition(age)
573
+
574
  finding = pyro.sample("finding", finding_dist)
575
 
576
  return {
 
593
  pyro.sample("race", qr_x)
594
  # q(f | x)
595
  if obs["finding"] is None:
596
+ f_prob = F.softmax(self.encoder_f(obs["x"]),dim=-1)
597
+ qf_x = dist.OneHotCategorical(probs=f_prob).to_event(1)
598
  obs["finding"] = pyro.sample("finding", qf_x)
599
  # q(a | x, f)
600
  if obs["age"] is None:
 
621
  pyro.sample("race_aux", qr_x, obs=obs["race"])
622
 
623
  # q(f | x)
624
+ f_probs = F.softmax(self.encoder_f(obs["x"]), dim=-1)
625
+ qf_x= dist.OneHotCategorical(probs=f_probs) # .to_event(1)
626
+ # with pyro.poutine.scale(scale=0.5):
627
  pyro.sample("finding_aux", qf_x, obs=obs["finding"])
628
 
629
  # q(a | x, f)
 
640
  # q(r | x)
641
  r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
642
  # q(f | x)
643
+ f_prob = F.softmax(self.encoder_f(obs["x"]),dim=-1)
644
  # q(a | x, f)
645
  a_loc, _ = self.encoder_a(obs["x"], y=obs["finding"]).chunk(2, dim=-1)
646