File size: 676 Bytes
23d26f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
from functools import partial
from typing import Callable
def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool) -> Callable:
    """Adapt an AMP decorator so callers can use one spelling on any PyTorch.

    Args:
        dec: The underlying decorator to forward to (in this file,
            ``custom_fwd`` or ``custom_bwd``).
        cuda_amp_deprecated: When true, ``dec`` is called with a
            ``device_type`` keyword; when false, the arguments are forwarded
            untouched.

    Returns:
        A callable that forwards every positional and keyword argument to
        ``dec``. If ``cuda_amp_deprecated`` is true and the caller did not
        supply ``device_type``, ``device_type="cuda"`` is added.
    """
    # Local import: the file's import block only brings in ``partial``.
    from functools import wraps

    @wraps(dec)  # keep the wrapped decorator's name/docstring for introspection
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            # Default to CUDA, but do not clobber an explicit caller choice.
            kwargs.setdefault("device_type", "cuda")
        return dec(*args, **kwargs)

    return decorator
# Pick the AMP decorators from whichever namespace this PyTorch build offers.
# ``deprecated`` is True when ``torch.amp`` exposes ``custom_fwd``; in that case
# the wrappers below pass ``device_type`` to the decorators.
deprecated = hasattr(torch.amp, "custom_fwd")  # type: ignore[attr-defined]
if deprecated:
    from torch.amp import custom_fwd, custom_bwd  # type: ignore[attr-defined]
else:
    from torch.cuda.amp import custom_fwd, custom_bwd

# Rebind both public names to their version-agnostic wrappers in one pass.
custom_fwd, custom_bwd = (
    custom_amp_decorator(amp_dec, deprecated)
    for amp_dec in (custom_fwd, custom_bwd)
)
|