Upload layers.py
Browse files- pgm/layers.py +2 -2
pgm/layers.py
CHANGED
@@ -173,8 +173,8 @@ class ArgMaxGumbelMax(Transform):
|
|
173 |
|
174 |
def sample_gumbel(shape, eps=1e-20):
|
175 |
U = torch.rand(shape)
|
176 |
-
|
177 |
-
U = U.to(torch.device('cuda:1'))
|
178 |
return -torch.log(-torch.log(U + eps) + eps)
|
179 |
def gumbel_softmax_sample(logits, temperature):
|
180 |
y = logits + sample_gumbel(logits.shape)
|
|
|
173 |
|
174 |
def sample_gumbel(shape, eps=1e-20):
|
175 |
U = torch.rand(shape)
|
176 |
+
U = U.to(torch.device('cpu'))
|
177 |
+
# U = U.to(torch.device('cuda:1'))
|
178 |
return -torch.log(-torch.log(U + eps) + eps)
|
179 |
def gumbel_softmax_sample(logits, temperature):
|
180 |
y = logits + sample_gumbel(logits.shape)
|