Yeefei commited on
Commit
cee1e90
·
verified ·
1 Parent(s): f8e446a

Upload layers.py

Browse files
Files changed (1) hide show
  1. pgm/layers.py +51 -21
pgm/layers.py CHANGED
@@ -149,27 +149,57 @@ class ArgMaxGumbelMax(Transform):
149
  """Infer the gumbels noises given k and logits."""
150
  assert self.logits != None, "Logits not defined."
151
 
152
- uniforms = torch.rand(
153
- self.logits.shape, dtype=self.logits.dtype, device=self.logits.device
154
- )
155
- gumbels = -((-(uniforms.log())).log())
156
- # print(f'gumbels: {gumbels.size()}, {gumbels.dtype}')
157
- # (batch_size, num_classes) mask to select kth class
158
- # print(f'k : {k.size()}')
159
- mask = F.one_hot(
160
- k.squeeze(-1).to(torch.int64), num_classes=self.logits.shape[-1]
161
- )
162
- # print(f'mask: {mask.size()}, {mask.dtype}')
163
- # (batch_size, 1) select topgumbel for truncation of other classes
164
- topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - (
165
- mask * self.logits
166
- ).sum(dim=-1, keepdim=True)
167
- mask = 1 - mask # invert mask to select other != k classes
168
- g = gumbels + self.logits
169
- # (batch_size, num_classes)
170
- epsilons = -torch.log(mask * torch.exp(-g) + torch.exp(-topgumbel)) - (
171
- mask * self.logits
172
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  return epsilons
174
 
175
  def log_abs_det_jacobian(self, x, y):
 
149
  """Infer the gumbels noises given k and logits."""
150
  assert self.logits != None, "Logits not defined."
151
 
152
+ # uniforms = torch.rand(
153
+ # self.logits.shape, dtype=self.logits.dtype, device=self.logits.device
154
+ # )
155
+ # gumbels = -((-(uniforms.log())).log())
156
+ # # print(f'gumbels: {gumbels.size()}, {gumbels.dtype}')
157
+ # # (batch_size, num_classes) mask to select kth class
158
+ # # print(f'k : {k.size()}')
159
+ # mask = F.one_hot(
160
+ # k.squeeze(-1).to(torch.int64), num_classes=self.logits.shape[-1]
161
+ # )
162
+ # # print(f'mask: {mask.size()}, {mask.dtype}')
163
+ # # (batch_size, 1) select topgumbel for truncation of other classes
164
+ # topgumbel = (mask * gumbels).sum(dim=-1, keepdim=True) - (
165
+ # mask * self.logits
166
+ # ).sum(dim=-1, keepdim=True)
167
+ # mask = 1 - mask # invert mask to select other != k classes
168
+ # g = gumbels + self.logits
169
+ # # (batch_size, num_classes)
170
+ # epsilons = -torch.log(mask * torch.exp(-g) + torch.exp(-topgumbel)) - (
171
+ # mask * self.logits
172
+ # )
173
+
174
+ def sample_gumbel(shape, eps=1e-20):
175
+ U = torch.rand(shape)
176
+
177
+ U = U.cuda()
178
+ return -torch.log(-torch.log(U + eps) + eps)
179
+ def gumbel_softmax_sample(logits, temperature):
180
+ y = logits + sample_gumbel(logits.shape)
181
+ return F.softmax(y / temperature, dim=-1)
182
+ def gumbel_softmax(logits, temperature,k, hard=False):
183
+ """
184
+ ST-gumple-softmax
185
+ input: [*, n_class]
186
+ return: flatten --> [*, n_class] an one-hot vector
187
+ """
188
+ y = gumbel_softmax_sample(logits, temperature)
189
+
190
+ if not hard:
191
+ return y.view(-1, logits.shape[-1])
192
+
193
+ shape = y.size()
194
+ _, ind = k.max(dim=-1)
195
+ y_hard = torch.zeros_like(y).view(-1, shape[-1])
196
+ y_hard.scatter_(1, ind.view(-1, 1), 1)
197
+ y_hard = y_hard.view(*shape)
198
+ # Set gradients w.r.t. y_hard gradients w.r.t. y
199
+ y_hard = (y_hard - y).detach() + y
200
+ return y_hard.view(-1, logits.shape[-1])
201
+ epsilons = gumbel_softmax(self.logits,1e-3,k)
202
+
203
  return epsilons
204
 
205
  def log_abs_det_jacobian(self, x, y):