Upload flow_pgm.py
Browse files- pgm/flow_pgm.py +67 -21
pgm/flow_pgm.py
CHANGED
@@ -28,6 +28,23 @@ class Hparams:
|
|
28 |
setattr(self, k, v)
|
29 |
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
class BasePGM(nn.Module):
|
32 |
def __init__(self):
|
33 |
super().__init__()
|
@@ -461,13 +478,13 @@ class ChestPGM(BasePGM):
|
|
461 |
self.variables = {
|
462 |
"race": "categorical",
|
463 |
"sex": "binary",
|
464 |
-
"finding": "
|
465 |
"age": "continuous",
|
466 |
}
|
467 |
# Discrete variables that are not root nodes
|
468 |
-
self.discrete_variables = {"finding": "
|
469 |
# define base distributions
|
470 |
-
for k in ["a"
|
471 |
self.register_buffer(f"{k}_base_loc", torch.zeros(1))
|
472 |
self.register_buffer(f"{k}_base_scale", torch.ones(1))
|
473 |
# age spline flow
|
@@ -481,26 +498,51 @@ class ChestPGM(BasePGM):
|
|
481 |
# self.age_constraints,
|
482 |
]
|
483 |
)
|
484 |
-
# Finding (conditional) via MLP, a ->
|
485 |
-
finding_net = DenseNN(1, [8, 16], param_dims=[
|
486 |
self.finding_transform_GumbelMax = ConditionalGumbelMax(
|
487 |
context_nn=finding_net, event_dim=0
|
488 |
)
|
489 |
# log space for sex and race
|
490 |
self.sex_logit = nn.Parameter(np.log(1 / 2) * torch.ones(1))
|
491 |
self.race_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3))
|
|
|
492 |
|
493 |
-
|
|
|
494 |
|
495 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
# q(s | x) ~ Bernoulli(f(x))
|
497 |
-
self.encoder_s =
|
498 |
-
# q(r | x) ~ OneHotCategorical(
|
499 |
-
self.encoder_r =
|
500 |
# q(f | x) ~ Bernoulli(f(x))
|
501 |
-
self.encoder_f =
|
502 |
# q(a | x, f) ~ Normal(mu(x), sigma(x))
|
503 |
-
self.encoder_a =
|
|
|
|
|
|
|
|
|
|
|
504 |
|
505 |
def model(self) -> Dict[str, Tensor]:
|
506 |
pyro.module("ChestPGM", self)
|
@@ -521,11 +563,14 @@ class ChestPGM(BasePGM):
|
|
521 |
|
522 |
# p(f | a), finding as OneHotCategorical conditioned on age
|
523 |
# finding_dist_base = dist.Gumbel(self.f_base_loc, self.f_base_scale).to_event(1)
|
524 |
-
|
525 |
-
|
|
|
|
|
|
|
526 |
finding_dist = ConditionalTransformedDistributionGumbelMax(
|
527 |
-
finding_dist_base, [self.finding_transform_GumbelMax]
|
528 |
-
|
529 |
finding = pyro.sample("finding", finding_dist)
|
530 |
|
531 |
return {
|
@@ -548,8 +593,8 @@ class ChestPGM(BasePGM):
|
|
548 |
pyro.sample("race", qr_x)
|
549 |
# q(f | x)
|
550 |
if obs["finding"] is None:
|
551 |
-
f_prob =
|
552 |
-
qf_x = dist.
|
553 |
obs["finding"] = pyro.sample("finding", qf_x)
|
554 |
# q(a | x, f)
|
555 |
if obs["age"] is None:
|
@@ -576,8 +621,9 @@ class ChestPGM(BasePGM):
|
|
576 |
pyro.sample("race_aux", qr_x, obs=obs["race"])
|
577 |
|
578 |
# q(f | x)
|
579 |
-
|
580 |
-
qf_x
|
|
|
581 |
pyro.sample("finding_aux", qf_x, obs=obs["finding"])
|
582 |
|
583 |
# q(a | x, f)
|
@@ -594,7 +640,7 @@ class ChestPGM(BasePGM):
|
|
594 |
# q(r | x)
|
595 |
r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
|
596 |
# q(f | x)
|
597 |
-
f_prob =
|
598 |
# q(a | x, f)
|
599 |
a_loc, _ = self.encoder_a(obs["x"], y=obs["finding"]).chunk(2, dim=-1)
|
600 |
|
|
|
28 |
setattr(self, k, v)
|
29 |
|
30 |
|
31 |
+
def is_one_hot(tensor):
|
32 |
+
"""
|
33 |
+
Check if the given tensor is a valid one-hot tensor.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
tensor (torch.Tensor): A tensor to check.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
bool: True if tensor is one-hot, False otherwise.
|
40 |
+
"""
|
41 |
+
if tensor.ndim != 2:
|
42 |
+
return False
|
43 |
+
|
44 |
+
# Check if there is exactly one '1' in each row and all other elements are '0'
|
45 |
+
return torch.all((tensor.sum(dim=1) == 1) & (tensor.max(dim=1).values == 1) & (tensor.min(dim=1).values == 0))
|
46 |
+
|
47 |
+
|
48 |
class BasePGM(nn.Module):
|
49 |
def __init__(self):
|
50 |
super().__init__()
|
|
|
478 |
self.variables = {
|
479 |
"race": "categorical",
|
480 |
"sex": "binary",
|
481 |
+
"finding": "categorical",
|
482 |
"age": "continuous",
|
483 |
}
|
484 |
# Discrete variables that are not root nodes
|
485 |
+
self.discrete_variables = {"finding": "categorical"}
|
486 |
# define base distributions
|
487 |
+
for k in ["a", "f"]:
|
488 |
self.register_buffer(f"{k}_base_loc", torch.zeros(1))
|
489 |
self.register_buffer(f"{k}_base_scale", torch.ones(1))
|
490 |
# age spline flow
|
|
|
498 |
# self.age_constraints,
|
499 |
]
|
500 |
)
|
501 |
+
# Finding (conditional) via MLP, a -> fss
|
502 |
+
finding_net = DenseNN(1, [8, 16], param_dims=[3], nonlinearity=nn.Softmax())
|
503 |
self.finding_transform_GumbelMax = ConditionalGumbelMax(
|
504 |
context_nn=finding_net, event_dim=0
|
505 |
)
|
506 |
# log space for sex and race
|
507 |
self.sex_logit = nn.Parameter(np.log(1 / 2) * torch.ones(1))
|
508 |
self.race_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3))
|
509 |
+
self.finding_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3))
|
510 |
|
511 |
+
if args.setup != "sup_pgm":
|
512 |
+
from resnet import CustomBlock, ResNet, ResNet18
|
513 |
|
514 |
+
shared_model = ResNet(
|
515 |
+
CustomBlock,
|
516 |
+
layers=[2, 2, 2, 2],
|
517 |
+
widths=[64, 128, 256, 512],
|
518 |
+
norm_layer=lambda c: nn.GroupNorm(min(32, c // 4), c),
|
519 |
+
)
|
520 |
+
# shared_model = torchvision.models.resnet18(weights=None)
|
521 |
+
shared_model.conv1 = nn.Conv2d(
|
522 |
+
args.input_channels,
|
523 |
+
64,
|
524 |
+
kernel_size=7,
|
525 |
+
stride=2,
|
526 |
+
padding=3,
|
527 |
+
bias=False,
|
528 |
+
)
|
529 |
+
kwargs = {
|
530 |
+
"in_shape": (args.input_channels, *(args.input_res,) * 2),
|
531 |
+
"base_model": shared_model,
|
532 |
+
}
|
533 |
# q(s | x) ~ Bernoulli(f(x))
|
534 |
+
self.encoder_s = ResNet18(num_outputs=1, **kwargs)
|
535 |
+
# q(r | x) ~ OneHotCategorical(f(x))
|
536 |
+
self.encoder_r = ResNet18(num_outputs=3, **kwargs)
|
537 |
# q(f | x) ~ Bernoulli(f(x))
|
538 |
+
self.encoder_f = ResNet18(num_outputs=3, **kwargs)
|
539 |
# q(a | x, f) ~ Normal(mu(x), sigma(x))
|
540 |
+
self.encoder_a = ResNet18(num_outputs=2, context_dim=3, **kwargs)
|
541 |
+
self.f = (
|
542 |
+
lambda x: args.std_fixed * torch.ones_like(x)
|
543 |
+
if args.std_fixed > 0
|
544 |
+
else F.softplus(x)
|
545 |
+
)
|
546 |
|
547 |
def model(self) -> Dict[str, Tensor]:
|
548 |
pyro.module("ChestPGM", self)
|
|
|
563 |
|
564 |
# p(f | a), finding as OneHotCategorical conditioned on age
|
565 |
# finding_dist_base = dist.Gumbel(self.f_base_loc, self.f_base_scale).to_event(1)
|
566 |
+
# finding_dist = ConditionalTransformedDistributionGumbelMax(
|
567 |
+
# finding_dist_base, [self.finding_transform_GumbelMax]
|
568 |
+
# ).condition(age)
|
569 |
+
finding_dist_base = dist.OneHotCategorical(logits=self.finding_logits) #.to_event(1)
|
570 |
+
# finding_dist_base = dist.RelaxedOneHotCategoricalStraightThrough(temperature =1e-5,logits=self.finding_logits) #.to_event(1)
|
571 |
finding_dist = ConditionalTransformedDistributionGumbelMax(
|
572 |
+
finding_dist_base, [self.finding_transform_GumbelMax]).condition(age)
|
573 |
+
|
574 |
finding = pyro.sample("finding", finding_dist)
|
575 |
|
576 |
return {
|
|
|
593 |
pyro.sample("race", qr_x)
|
594 |
# q(f | x)
|
595 |
if obs["finding"] is None:
|
596 |
+
f_prob = F.softmax(self.encoder_f(obs["x"]),dim=-1)
|
597 |
+
qf_x = dist.OneHotCategorical(probs=f_prob).to_event(1)
|
598 |
obs["finding"] = pyro.sample("finding", qf_x)
|
599 |
# q(a | x, f)
|
600 |
if obs["age"] is None:
|
|
|
621 |
pyro.sample("race_aux", qr_x, obs=obs["race"])
|
622 |
|
623 |
# q(f | x)
|
624 |
+
f_probs = F.softmax(self.encoder_f(obs["x"]), dim=-1)
|
625 |
+
qf_x= dist.OneHotCategorical(probs=f_probs) # .to_event(1)
|
626 |
+
# with pyro.poutine.scale(scale=0.5):
|
627 |
pyro.sample("finding_aux", qf_x, obs=obs["finding"])
|
628 |
|
629 |
# q(a | x, f)
|
|
|
640 |
# q(r | x)
|
641 |
r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
|
642 |
# q(f | x)
|
643 |
+
f_prob = F.softmax(self.encoder_f(obs["x"]),dim=-1)
|
644 |
# q(a | x, f)
|
645 |
a_loc, _ = self.encoder_a(obs["x"], y=obs["finding"]).chunk(2, dim=-1)
|
646 |
|