Spaces:
Runtime error
Runtime error
forgot device type
Browse files- wan/modules/model.py +4 -4
wan/modules/model.py
CHANGED
@@ -294,7 +294,7 @@ class WanAttentionBlock(nn.Module):
|
|
294 |
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
295 |
"""
|
296 |
assert e.dtype == torch.float32
|
297 |
-
with amp.autocast(dtype=torch.float32):
|
298 |
e = (self.modulation + e).chunk(6, dim=1)
|
299 |
assert e[0].dtype == torch.float32
|
300 |
|
@@ -309,7 +309,7 @@ class WanAttentionBlock(nn.Module):
|
|
309 |
def cross_attn_ffn(x, context, context_lens, e):
|
310 |
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
311 |
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
|
312 |
-
with amp.autocast(dtype=torch.float32):
|
313 |
x = x + y * e[5]
|
314 |
return x
|
315 |
|
@@ -341,7 +341,7 @@ class Head(nn.Module):
|
|
341 |
e(Tensor): Shape [B, C]
|
342 |
"""
|
343 |
assert e.dtype == torch.float32
|
344 |
-
with amp.autocast(dtype=torch.float32):
|
345 |
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
346 |
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
347 |
return x
|
@@ -542,7 +542,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
542 |
])
|
543 |
|
544 |
# time embeddings
|
545 |
-
with amp.autocast(dtype=torch.float32):
|
546 |
e = self.time_embedding(
|
547 |
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
548 |
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
|
|
294 |
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
295 |
"""
|
296 |
assert e.dtype == torch.float32
|
297 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
298 |
e = (self.modulation + e).chunk(6, dim=1)
|
299 |
assert e[0].dtype == torch.float32
|
300 |
|
|
|
309 |
def cross_attn_ffn(x, context, context_lens, e):
|
310 |
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
311 |
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
|
312 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
313 |
x = x + y * e[5]
|
314 |
return x
|
315 |
|
|
|
341 |
e(Tensor): Shape [B, C]
|
342 |
"""
|
343 |
assert e.dtype == torch.float32
|
344 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
345 |
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
346 |
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
347 |
return x
|
|
|
542 |
])
|
543 |
|
544 |
# time embeddings
|
545 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
546 |
e = self.time_embedding(
|
547 |
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
548 |
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|