Upload layers.py
Browse files- pgm/layers.py +51 -21
pgm/layers.py
CHANGED
@@ -149,27 +149,57 @@ class ArgMaxGumbelMax(Transform):
|
|
149 |
"""Infer the gumbels noises given k and logits."""
|
150 |
assert self.logits != None, "Logits not defined."
|
151 |
|
152 |
-
uniforms = torch.rand(
|
153 |
-
|
154 |
-
)
|
155 |
-
gumbels = -((-(uniforms.log())).log())
|
156 |
-
# print(f'gumbels: {gumbels.size()}, {gumbels.dtype}')
|
157 |
-
# (batch_size, num_classes) mask to select kth class
|
158 |
-
# print(f'k : {k.size()}')
|
159 |
-
mask = F.one_hot(
|
160 |
-
|
161 |
-
)
|
162 |
-
# print(f'mask: {mask.size()}, {mask.dtype}')
|
163 |
-
# (batch_size, 1) select topgumbel for truncation of other classes
|
164 |
-
topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - (
|
165 |
-
|
166 |
-
).sum(dim=-1, keepdim=True)
|
167 |
-
mask = 1 - mask # invert mask to select other != k classes
|
168 |
-
g = gumbels + self.logits
|
169 |
-
# (batch_size, num_classes)
|
170 |
-
epsilons = -torch.log(mask * torch.exp(-g) + torch.exp(-topgumbel)) - (
|
171 |
-
|
172 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
return epsilons
|
174 |
|
175 |
def log_abs_det_jacobian(self, x, y):
|
|
|
149 |
"""Infer the gumbels noises given k and logits."""
|
150 |
assert self.logits != None, "Logits not defined."
|
151 |
|
152 |
+
# uniforms = torch.rand(
|
153 |
+
# self.logits.shape, dtype=self.logits.dtype, device=self.logits.device
|
154 |
+
# )
|
155 |
+
# gumbels = -((-(uniforms.log())).log())
|
156 |
+
# # print(f'gumbels: {gumbels.size()}, {gumbels.dtype}')
|
157 |
+
# # (batch_size, num_classes) mask to select kth class
|
158 |
+
# # print(f'k : {k.size()}')
|
159 |
+
# mask = F.one_hot(
|
160 |
+
# k.squeeze(-1).to(torch.int64), num_classes=self.logits.shape[-1]
|
161 |
+
# )
|
162 |
+
# # print(f'mask: {mask.size()}, {mask.dtype}')
|
163 |
+
# # (batch_size, 1) select topgumbel for truncation of other classes
|
164 |
+
# topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - (
|
165 |
+
# mask * self.logits
|
166 |
+
# ).sum(dim=-1, keepdim=True)
|
167 |
+
# mask = 1 - mask # invert mask to select other != k classes
|
168 |
+
# g = gumbels + self.logits
|
169 |
+
# # (batch_size, num_classes)
|
170 |
+
# epsilons = -torch.log(mask * torch.exp(-g) + torch.exp(-topgumbel)) - (
|
171 |
+
# mask * self.logits
|
172 |
+
# )
|
173 |
+
|
174 |
+
def sample_gumbel(shape, eps=1e-20):
    """Draw i.i.d. standard Gumbel(0, 1) noise of the given shape.

    Uses the inverse-CDF construction g = -log(-log(U)) with U ~ Uniform(0, 1).

    Args:
        shape: desired output shape (any torch-compatible size).
        eps: small constant guarding both logarithms against log(0).

    Returns:
        Tensor of Gumbel(0, 1) samples with the requested shape.
    """
    U = torch.rand(shape)
    # Fix: the original called U.cuda() unconditionally, which raises on
    # CPU-only machines. Only move to the GPU when one is actually present.
    if torch.cuda.is_available():
        U = U.cuda()
    return -torch.log(-torch.log(U + eps) + eps)
|
179 |
+
def gumbel_softmax_sample(logits, temperature):
    """Return a soft Gumbel-softmax sample: softmax((logits + g) / T).

    Perturbs the logits with fresh Gumbel(0, 1) noise and relaxes the
    resulting argmax into a differentiable softmax at temperature T.
    """
    perturbed = logits + sample_gumbel(logits.shape)
    return F.softmax(perturbed / temperature, dim=-1)
|
182 |
+
def gumbel_softmax(logits, temperature, k, hard=False):
    """Straight-through Gumbel-softmax sample.

    input: logits of shape [*, n_class]; k selects the class forced in
    the hard branch.
    return: tensor flattened to [*, n_class] — a soft relaxation when
    ``hard`` is False, otherwise a one-hot vector for the class given by
    ``k`` whose gradients flow through the soft sample.
    """
    soft = gumbel_softmax_sample(logits, temperature)
    n_class = logits.shape[-1]

    if not hard:
        return soft.view(-1, n_class)

    full_shape = soft.size()
    # Force the observed class from k (not the sampled argmax of `soft`).
    _, forced = k.max(dim=-1)
    one_hot = torch.zeros_like(soft).view(-1, full_shape[-1])
    one_hot.scatter_(1, forced.view(-1, 1), 1)
    one_hot = one_hot.view(*full_shape)
    # Straight-through estimator: forward pass emits the one-hot vector,
    # backward pass uses the gradients of the soft sample.
    one_hot = (one_hot - soft).detach() + soft
    return one_hot.view(-1, n_class)
|
201 |
+
epsilons = gumbel_softmax(self.logits,1e-3,k)
|
202 |
+
|
203 |
return epsilons
|
204 |
|
205 |
def log_abs_det_jacobian(self, x, y):
|