import pyro
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Dict
from pyro.distributions.conditional import (
    ConditionalTransformModule,
    ConditionalTransformedDistribution,
    TransformedDistribution,
)
from pyro.distributions.torch_distribution import TorchDistributionMixin

from torch.distributions import constraints
from torch.distributions.utils import _sum_rightmost
from torch.distributions.transforms import Transform


class ConditionalAffineTransform(ConditionalTransformModule):
    def __init__(self, context_nn, event_dim=0, **kwargs):
        super().__init__(**kwargs)
        self.event_dim = event_dim
        self.context_nn = context_nn

    def condition(self, context):
        loc, log_scale = self.context_nn(context)
        return torch.distributions.transforms.AffineTransform(
            loc, log_scale.exp(), event_dim=self.event_dim
        )
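

# Illustrative usage sketch (added commentary, not part of the original API): any
# context network that returns a (loc, log_scale) pair can parameterise this transform;
# `two_headed_nn` below is a hypothetical example. The conditioned transform applies
# x -> loc + exp(log_scale) * x.
#
#   affine_t = ConditionalAffineTransform(context_nn=two_headed_nn, event_dim=0)
#   affine = affine_t.condition(context)  # torch AffineTransform with scale = exp(log_scale)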


class MLP(nn.Module):
    def __init__(self, num_inputs=1, width=32, num_outputs=1):
        super().__init__()
        activation = nn.LeakyReLU()
        self.mlp = nn.Sequential(
            nn.Linear(num_inputs, width, bias=False),
            nn.BatchNorm1d(width),
            activation,
            nn.Linear(width, width, bias=False),
            nn.BatchNorm1d(width),
            activation,
            nn.Linear(width, num_outputs),
        )

    def forward(self, x):
        return self.mlp(x)


class CNN(nn.Module):
    def __init__(self, in_shape=(1, 192, 192), width=16, num_outputs=1, context_dim=0):
        super().__init__()
        in_channels = in_shape[0]
        res = in_shape[1]
        s = 2 if res > 64 else 1
        activation = nn.LeakyReLU()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, width, 7, s, 3, bias=False),
            nn.BatchNorm2d(width),
            activation,
            (nn.MaxPool2d(2, 2) if res > 32 else nn.Identity()),
            nn.Conv2d(width, 2 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(2 * width),
            activation,
            nn.Conv2d(2 * width, 2 * width, 3, 1, 1, bias=False),
            nn.BatchNorm2d(2 * width),
            activation,
            nn.Conv2d(2 * width, 4 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(4 * width),
            activation,
            nn.Conv2d(4 * width, 4 * width, 3, 1, 1, bias=False),
            nn.BatchNorm2d(4 * width),
            activation,
            nn.Conv2d(4 * width, 8 * width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(8 * width),
            activation,
        )
        self.fc = nn.Sequential(
            nn.Linear(8 * width + context_dim, 8 * width, bias=False),
            nn.BatchNorm1d(8 * width),
            activation,
            nn.Linear(8 * width, num_outputs),
        )

    def forward(self, x, y=None):
        x = self.cnn(x)
        x = x.mean(dim=(-2, -1))
        if y is not None:
            x = torch.cat([x, y], dim=-1)
        return self.fc(x)
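
# Shape note (derived from the layers above, stated here as a sketch): with the default
# in_shape=(1, 192, 192) the convolutional stack downsamples spatially, and the mean over
# the last two dims yields features of shape (N, 8 * width). When a context tensor `y`
# of shape (N, context_dim) is given it is concatenated before the fully connected head,
# which is why the first nn.Linear expects 8 * width + context_dim inputs.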


class ArgMaxGumbelMax(Transform):
    r"""ArgMax as a Transform, with the inverse conditioned on the logits."""

    def __init__(self, logits, event_dim=0, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.logits = logits
        self._event_dim = event_dim
        self._categorical = pyro.distributions.torch.Categorical(
            logits=self.logits
        ).to_event(0)

    @property
    def event_dim(self):
        return self._event_dim

    def __call__(self, gumbels):
        """Compute the forward transform."""
        assert self.logits is not None, "Logits not defined."
        if self._cache_size == 0:
            return self._call(gumbels)
        y = self._call(gumbels)
        return y

    def _call(self, gumbels):
        """Compute the forward transformation: argmax over gumbels + logits."""
        assert self.logits is not None, "Logits not defined."
        y = gumbels + self.logits
        return y.argmax(-1, keepdim=True)

    @property
    def domain(self):
        """Domain of the input (Gumbel variables): real-valued."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    @property
    def codomain(self):
        """Codomain of the output (categorical variables); should be the natural
        numbers, but is set to real for now."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    def inv(self, k):
        """Infer the Gumbel noises given the observed class k and the logits."""
        assert self.logits is not None, "Logits not defined."

        def sample_gumbel(shape, eps=1e-20):
            # Standard Gumbel(0, 1) noise via -log(-log(U)), on the logits' device.
            U = torch.rand(shape, device=self.logits.device)
            return -torch.log(-torch.log(U + eps) + eps)

        def gumbel_softmax_sample(logits, temperature):
            y = logits + sample_gumbel(logits.shape)
            return F.softmax(y / temperature, dim=-1)

        def gumbel_softmax(logits, temperature, k, hard=False):
            """
            Straight-through (ST) Gumbel-Softmax.
            input: [*, n_class]
            return: flattened [*, n_class] one-hot (or soft) vector
            """
            y = gumbel_softmax_sample(logits, temperature)
            if not hard:
                return y.view(-1, logits.shape[-1])

            shape = y.size()
            _, ind = k.max(dim=-1)
            y_hard = torch.zeros_like(y).view(-1, shape[-1])
            y_hard.scatter_(1, ind.view(-1, 1), 1)
            y_hard = y_hard.view(*shape)
            # Straight-through estimator: forward pass uses the hard one-hot,
            # backward pass uses the gradient of the soft sample y.
            y_hard = (y_hard - y).detach() + y
            return y_hard.view(-1, logits.shape[-1])

        epsilons = gumbel_softmax(self.logits, 1e-3, k)
        return epsilons
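
    # Design note (added commentary): inv() does not invert argmax exactly; it draws
    # fresh Gumbel noise and returns a low-temperature (1e-3) Gumbel-Softmax sample of
    # the logits, which keeps the inferred noise differentiable with respect to the
    # logits. The observed class k is only used to build the hard one-hot in the
    # straight-through (hard=True) branch.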

    def log_abs_det_jacobian(self, x, y):
        """Use the log_abs_det_jacobian to account for the categorical probability.
        x: Gumbels; y: argmax(x + logits); P(y) = softmax(logits).
        """
        return -self._categorical.log_prob(y.squeeze(-1)).unsqueeze(-1)


class ConditionalGumbelMax(ConditionalTransformModule):
    r"""Given gumbels + logits, output the one-hot categorical."""

    def __init__(self, context_nn, event_dim=0, **kwargs):
        super().__init__(**kwargs)
        self.context_nn = context_nn
        self.event_dim = event_dim

    def condition(self, context):
        """Given the context (e.g. age), return an ArgMaxGumbelMax transform over the
        predicted logits."""
        logits = self.context_nn(context)
        return ArgMaxGumbelMax(logits)

    def _logits(self, context):
        """Return the logits given the context."""
        return self.context_nn(context)

    @property
    def domain(self):
        """Domain of the input (Gumbel variables): real-valued."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    @property
    def codomain(self):
        """Codomain of the output (categorical variables); should be the natural
        numbers, but is set to real for now."""
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)


class TransformedDistributionGumbelMax(TransformedDistribution, TorchDistributionMixin):
    r"""A TransformedDistribution for Gumbel-Max sampling."""
    arg_constraints: Dict[str, constraints.Constraint] = {}

    def log_prob(self, value):
        """
        The log_prob() of the base Gumbel distribution is not used here, because the
        likelihood of each class under Gumbel-Max sampling is determined by the logits
        alone (via the transforms' log_abs_det_jacobian).
        """
        if self._validate_args:
            self._validate_sample(value)
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            event_dim += transform.domain.event_dim - transform.codomain.event_dim
            log_prob = log_prob - _sum_rightmost(
                transform.log_abs_det_jacobian(x, y),
                event_dim - transform.domain.event_dim,
            )
            y = x

        return log_prob
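
    # Worked consequence (added commentary): with a single ArgMaxGumbelMax transform,
    # log_abs_det_jacobian returns -Categorical(logits).log_prob(k), so the loop above
    # reduces to log_prob(k) = log softmax(logits)[k], i.e. the exact categorical
    # log-likelihood, independent of the particular Gumbel noise inferred by inv().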


class ConditionalTransformedDistributionGumbelMax(ConditionalTransformedDistribution):
    def condition(self, context):
        base_dist = self.base_dist.condition(context)
        transforms = [t.condition(context) for t in self.transforms]
        return TransformedDistributionGumbelMax(base_dist, transforms)

    def clear_cache(self):
        pass
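

# Illustrative end-to-end sketch (added commentary; `num_classes` and `context` are
# hypothetical names, and the exact base-distribution setup is an assumption following
# the standard Gumbel-Max construction, not a statement of how this module is used
# elsewhere):
#
#   base = pyro.distributions.Gumbel(
#       torch.zeros(num_classes), torch.ones(num_classes)
#   ).to_event(1)
#   transform = ConditionalGumbelMax(
#       context_nn=MLP(num_inputs=1, num_outputs=num_classes)
#   )
#   cond = ConditionalTransformedDistributionGumbelMax(base, [transform])
#   dist = cond.condition(context)   # context: (N, 1) tensor, e.g. normalised age
#   k = dist.sample()                # class index from argmax(gumbels + logits)
#   logp = dist.log_prob(k)          # categorical log-likelihood under the predicted logits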