fffiloni committed
Commit b4d03cb · verified · 1 Parent(s): f4d3f52

forgot device type

Files changed (1)
  1. wan/modules/model.py +4 -4
wan/modules/model.py CHANGED
@@ -294,7 +294,7 @@ class WanAttentionBlock(nn.Module):
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         assert e.dtype == torch.float32
-        with amp.autocast(dtype=torch.float32):
+        with amp.autocast("cuda", dtype=torch.float32):
             e = (self.modulation + e).chunk(6, dim=1)
         assert e[0].dtype == torch.float32
 
@@ -309,7 +309,7 @@ class WanAttentionBlock(nn.Module):
         def cross_attn_ffn(x, context, context_lens, e):
             x = x + self.cross_attn(self.norm3(x), context, context_lens)
             y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
-            with amp.autocast(dtype=torch.float32):
+            with amp.autocast("cuda", dtype=torch.float32):
                 x = x + y * e[5]
             return x
 
@@ -341,7 +341,7 @@ class Head(nn.Module):
             e(Tensor): Shape [B, C]
         """
         assert e.dtype == torch.float32
-        with amp.autocast(dtype=torch.float32):
+        with amp.autocast("cuda", dtype=torch.float32):
             e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
             x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
         return x
@@ -542,7 +542,7 @@ class WanModel(ModelMixin, ConfigMixin):
         ])
 
         # time embeddings
-        with amp.autocast(dtype=torch.float32):
+        with amp.autocast("cuda", dtype=torch.float32):
             e = self.time_embedding(
                 sinusoidal_embedding_1d(self.freq_dim, t).float())
             e0 = self.time_projection(e).unflatten(1, (6, self.dim))
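
All four hunks apply the same one-line fix: torch.amp.autocast takes the device type as a required first positional argument, whereas the deprecated torch.cuda.amp.autocast(dtype=...) form implied CUDA. Below is a minimal sketch of the failure mode and the fix, assuming a recent PyTorch; the modulate helper is hypothetical and only mirrors the WanAttentionBlock pattern, it is not code from this repository.

import torch
from torch import amp  # torch.amp, not the deprecated torch.cuda.amp

def modulate(modulation: torch.Tensor, e: torch.Tensor):
    # Hypothetical helper mirroring the WanAttentionBlock pattern: pin the
    # modulation arithmetic to float32 regardless of the outer autocast policy.
    assert e.dtype == torch.float32
    # The pre-commit form, amp.autocast(dtype=torch.float32), raises a
    # TypeError here because device_type is a required positional argument.
    with amp.autocast("cuda", dtype=torch.float32):
        return (modulation + e).chunk(6, dim=1)

mod = torch.zeros(1, 6, 16)
e = torch.randn(1, 6, 16, dtype=torch.float32)
chunks = modulate(mod, e)
print(len(chunks), chunks[0].dtype)  # 6 torch.float32

On a machine without CUDA, recent PyTorch disables the "cuda" autocast context with a warning rather than erroring, so the sketch runs either way; the arithmetic stays in float32 in both cases.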