fffiloni committed on
Commit
f4d3f52
·
verified ·
1 Parent(s): a2395b4

Change deprecated cuda amp calls

Browse files
Files changed (1) hide show
  1. wan/modules/model.py +3 -3
wan/modules/model.py CHANGED
@@ -2,7 +2,7 @@
2
  import math
3
 
4
  import torch
5
- import torch.cuda.amp as amp
6
  import torch.nn as nn
7
  from diffusers.configuration_utils import ConfigMixin, register_to_config
8
  from diffusers.models.modeling_utils import ModelMixin
@@ -28,7 +28,7 @@ def sinusoidal_embedding_1d(dim, position):
28
  return x
29
 
30
 
31
- @amp.autocast(enabled=False)
32
  def rope_params(max_seq_len, dim, theta=10000):
33
  assert dim % 2 == 0
34
  freqs = torch.outer(
@@ -39,7 +39,7 @@ def rope_params(max_seq_len, dim, theta=10000):
39
  return freqs
40
 
41
 
42
- @amp.autocast(enabled=False)
43
  def rope_apply(x, grid_sizes, freqs):
44
  n, c = x.size(2), x.size(3) // 2
45
 
 
2
  import math
3
 
4
  import torch
5
+ import torch.amp as amp
6
  import torch.nn as nn
7
  from diffusers.configuration_utils import ConfigMixin, register_to_config
8
  from diffusers.models.modeling_utils import ModelMixin
 
28
  return x
29
 
30
 
31
+ @amp.autocast("cuda", enabled=False)
32
  def rope_params(max_seq_len, dim, theta=10000):
33
  assert dim % 2 == 0
34
  freqs = torch.outer(
 
39
  return freqs
40
 
41
 
42
+ @amp.autocast("cuda", enabled=False)
43
  def rope_apply(x, grid_sizes, freqs):
44
  n, c = x.size(2), x.size(3) // 2
45