Yeefei commited on
Commit
7462b5c
·
verified ·
1 Parent(s): 87034d1

Upload layers.py

Browse files
Files changed (1) hide show
  1. pgm/layers.py +2 -2
pgm/layers.py CHANGED
@@ -173,8 +173,8 @@ class ArgMaxGumbelMax(Transform):
173
 
174
  def sample_gumbel(shape, eps=1e-20):
175
  U = torch.rand(shape)
176
-
177
- U = U.to(torch.device('cuda:1'))
178
  return -torch.log(-torch.log(U + eps) + eps)
179
  def gumbel_softmax_sample(logits, temperature):
180
  y = logits + sample_gumbel(logits.shape)
 
173
 
174
  def sample_gumbel(shape, eps=1e-20):
175
  U = torch.rand(shape)
176
+ U = U.to(torch.device('cpu'))
177
+ # U = U.to(torch.device('cuda:1'))
178
  return -torch.log(-torch.log(U + eps) + eps)
179
  def gumbel_softmax_sample(logits, temperature):
180
  y = logits + sample_gumbel(logits.shape)