Spaces:
Runtime error
Runtime error
Change deprecated cuda amp calls
Browse files- wan/modules/model.py +3 -3
wan/modules/model.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
import math
|
3 |
|
4 |
import torch
|
5 |
-
import torch.
|
6 |
import torch.nn as nn
|
7 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
8 |
from diffusers.models.modeling_utils import ModelMixin
|
@@ -28,7 +28,7 @@ def sinusoidal_embedding_1d(dim, position):
|
|
28 |
return x
|
29 |
|
30 |
|
31 |
-
@amp.autocast(enabled=False)
|
32 |
def rope_params(max_seq_len, dim, theta=10000):
|
33 |
assert dim % 2 == 0
|
34 |
freqs = torch.outer(
|
@@ -39,7 +39,7 @@ def rope_params(max_seq_len, dim, theta=10000):
|
|
39 |
return freqs
|
40 |
|
41 |
|
42 |
-
@amp.autocast(enabled=False)
|
43 |
def rope_apply(x, grid_sizes, freqs):
|
44 |
n, c = x.size(2), x.size(3) // 2
|
45 |
|
|
|
2 |
import math
|
3 |
|
4 |
import torch
|
5 |
+
import torch.amp as amp
|
6 |
import torch.nn as nn
|
7 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
8 |
from diffusers.models.modeling_utils import ModelMixin
|
|
|
28 |
return x
|
29 |
|
30 |
|
31 |
+
@amp.autocast("cuda", enabled=False)
|
32 |
def rope_params(max_seq_len, dim, theta=10000):
|
33 |
assert dim % 2 == 0
|
34 |
freqs = torch.outer(
|
|
|
39 |
return freqs
|
40 |
|
41 |
|
42 |
+
@amp.autocast("cuda", enabled=False)
|
43 |
def rope_apply(x, grid_sizes, freqs):
|
44 |
n, c = x.size(2), x.size(3) // 2
|
45 |
|