import pyro
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict
from pyro.distributions.conditional import (
    ConditionalTransformModule,
    ConditionalTransformedDistribution,
    TransformedDistribution,
)
from pyro.distributions.torch_distribution import TorchDistributionMixin
from torch.distributions import constraints
from torch.distributions.utils import _sum_rightmost
from torch.distributions.transforms import Transform


class ConditionalAffineTransform(ConditionalTransformModule):
    def __init__(self, context_nn, event_dim=0, **kwargs):
        super().__init__(**kwargs)
        self.event_dim = event_dim
        self.context_nn = context_nn

    def condition(self, context):
        # The context network predicts (loc, log_scale); exponentiate for a positive scale.
        loc, log_scale = self.context_nn(context)
        return torch.distributions.transforms.AffineTransform(
            loc, log_scale.exp(), event_dim=self.event_dim
        )


class MLP(nn.Module):
    def __init__(self, num_inputs=1, width=32, num_outputs=1):
        super().__init__()
        activation = nn.LeakyReLU()
        self.mlp = nn.Sequential(
            nn.Linear(num_inputs, width, bias=False),
            nn.BatchNorm1d(width),
            activation,
            nn.Linear(width, width, bias=False),
            nn.BatchNorm1d(width),
            activation,
            nn.Linear(width, num_outputs),
        )

    def forward(self, x):
        return self.mlp(x)


class CNN(nn.Module):
    def __init__(self, in_shape=(1, 192, 192), width=16, num_outputs=1, context_dim=0):
        super().__init__()
        in_channels = in_shape[0]
        res = in_shape[1]
        s = 2 if res > 64 else 1
        activation = nn.LeakyReLU()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, width, 7, s, 3, bias=False),
            nn.BatchNorm2d(width),
            activation,
            (nn.MaxPool2d(2, 2) if res > 32 else nn.Identity()),
            nn.Conv2d(width, 2 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(2 * width),
            activation,
            nn.Conv2d(2 * width, 2 * width, 3, 1, 1, bias=False),
            nn.BatchNorm2d(2 * width),
            activation,
            nn.Conv2d(2 * width, 4 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(4 * width),
            activation,
            nn.Conv2d(4 * width, 4 * width, 3, 1, 1, bias=False),
            nn.BatchNorm2d(4 * width),
            activation,
            nn.Conv2d(4 * width, 8 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(8 * width),
            activation,
        )
        self.fc = nn.Sequential(
            nn.Linear(8 * width + context_dim, 8 * width, bias=False),
            nn.BatchNorm1d(8 * width),
            activation,
            nn.Linear(8 * width, num_outputs),
        )

    def forward(self, x, y=None):
        x = self.cnn(x)
        x = x.mean(dim=(-2, -1))  # global average pool
        if y is not None:
            x = torch.cat([x, y], dim=-1)
        return self.fc(x)


class ArgMaxGumbelMax(Transform):
    r"""ArgMax as a Transform; the inverse is conditioned on the logits."""

    def __init__(self, logits, event_dim=0, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.logits = logits
        self._event_dim = event_dim
        self._categorical = pyro.distributions.torch.Categorical(
            logits=self.logits
        ).to_event(0)

    @property
    def event_dim(self):
        return self._event_dim

    def __call__(self, gumbels):
        """Compute the forward transform."""
        assert self.logits is not None, "Logits not defined."
        return self._call(gumbels)

    def _call(self, gumbels):
        """Forward transformation: argmax over Gumbel noise plus logits."""
        assert self.logits is not None, "Logits not defined."
        y = gumbels + self.logits
        return y.argmax(-1, keepdim=True)

    @property
    def domain(self):
        """Domain of the input (Gumbel variables): real."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    @property
    def codomain(self):
        """Codomain of the output (categorical variables): natural numbers,
        but kept as real for now."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    def inv(self, k):
        """Infer the Gumbel noise given the observed class k and the logits."""
        assert self.logits is not None, "Logits not defined."
        # Alternative analytical inversion (truncated top-down Gumbel sampling),
        # kept for reference:
        # uniforms = torch.rand(
        #     self.logits.shape, dtype=self.logits.dtype, device=self.logits.device
        # )
        # gumbels = -((-(uniforms.log())).log())
        # # (batch_size, num_classes) mask to select the kth class
        # mask = F.one_hot(
        #     k.squeeze(-1).to(torch.int64), num_classes=self.logits.shape[-1]
        # )
        # # (batch_size, 1) top Gumbel used to truncate the other classes
        # topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - (
        #     mask * self.logits
        # ).sum(dim=-1, keepdim=True)
        # mask = 1 - mask  # invert mask to select the classes != k
        # g = gumbels + self.logits
        # # (batch_size, num_classes)
        # epsilons = -torch.log(mask * torch.exp(-g) + torch.exp(-topgumbel)) - (
        #     mask * self.logits
        # )

        def sample_gumbel(shape, eps=1e-20):
            # Sample on the same device/dtype as the logits to avoid device mismatches.
            U = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
            return -torch.log(-torch.log(U + eps) + eps)

        def gumbel_softmax_sample(logits, temperature):
            y = logits + sample_gumbel(logits.shape)
            return F.softmax(y / temperature, dim=-1)

        def gumbel_softmax(logits, temperature, k, hard=False):
            """Straight-through Gumbel-softmax.
            input: [*, n_class]
            return: flattened [*, n_class] one-hot vector
            """
            y = gumbel_softmax_sample(logits, temperature)
            if not hard:
                return y.view(-1, logits.shape[-1])
            shape = y.size()
            _, ind = k.max(dim=-1)
            y_hard = torch.zeros_like(y).view(-1, shape[-1])
            y_hard.scatter_(1, ind.view(-1, 1), 1)
            y_hard = y_hard.view(*shape)
            # Set gradients w.r.t. y_hard to gradients w.r.t. y
            y_hard = (y_hard - y).detach() + y
            return y_hard.view(-1, logits.shape[-1])

        # Low-temperature soft sample; note that with hard=False the observed k
        # is not used, so this is a relaxed, re-sampled approximation of the noise.
        epsilons = gumbel_softmax(self.logits, 1e-3, k)
        return epsilons

    def log_abs_det_jacobian(self, x, y):
        """Use log_abs_det_jacobian to account for the categorical probability.
        x: Gumbels; y: argmax(x + logits); P(y) = softmax(logits).
        """
        return -self._categorical.log_prob(y.squeeze(-1)).unsqueeze(-1)


class ConditionalGumbelMax(ConditionalTransformModule):
    r"""Given gumbels + logits, output the one-hot categorical."""

    def __init__(self, context_nn, event_dim=0, **kwargs):
        # context_nn predicts the logits given the context (e.g. age).
        super().__init__(**kwargs)
        self.context_nn = context_nn
        self.event_dim = event_dim

    def condition(self, context):
        """Given the context (age), return the ArgMaxGumbelMax transform."""
        logits = self.context_nn(context)  # logits for argmax(Gumbel + logits)
        return ArgMaxGumbelMax(logits)

    def _logits(self, context):
        """Return the logits given the context."""
        return self.context_nn(context)

    @property
    def domain(self):
        """Domain of the input (Gumbel variables): real."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    @property
    def codomain(self):
        """Codomain of the output (categorical variables): natural numbers,
        but kept as real for now."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)


class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributionMixin):
    r"""TransformedDistribution for Gumbel-max sampling."""

    arg_constraints: Dict[str, constraints.Constraint] = {}

    def log_prob(self, value):
        """Do not use log_prob() of the base Gumbel distribution: under Gumbel-max
        sampling the likelihood of each class is determined by the logits alone.
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - _sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x
        return log_prob


class ConditionalTransformedDistributionGumbelMax(ConditionalTransformedDistribution):
    def condition(self, context):
        base_dist = self.base_dist.condition(context)
        transforms = [t.condition(context) for t in self.transforms]
        return TransformedDistributionGumbelMax(base_dist, transforms)

    def clear_cache(self):
        pass
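
# --------------------------------------------------------------------------- #
# Usage sketch (not part of the original module; a minimal, hedged example).
# It shows how the pieces above are typically wired together: an MLP predicts
# class logits from a continuous context, and the Gumbel-max transform turns a
# Gumbel base distribution into a categorical variable whose log_prob is the
# categorical likelihood implied by those logits. Names such as `num_classes`,
# `context`, and `net` are illustrative assumptions only.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    import pyro.distributions as dist

    num_classes, batch_size = 3, 8
    net = MLP(num_inputs=1, num_outputs=num_classes).eval()  # context -> logits

    # Standard Gumbel base noise, one i.i.d. component per class.
    base = dist.Gumbel(torch.zeros(num_classes), torch.ones(num_classes)).to_event(1)
    transform = ConditionalGumbelMax(context_nn=net, event_dim=0)
    cond_dist = ConditionalTransformedDistributionGumbelMax(base, [transform])

    context = torch.randn(batch_size, 1)  # e.g. a normalised age covariate
    d = cond_dist.condition(context)      # TransformedDistributionGumbelMax
    k = d.sample()                        # (batch_size, 1) class indices
    print(k.shape, d.log_prob(k).shape)   # log_prob is the categorical log-likelihood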