kernel
danieldk's picture
danieldk HF Staff
Import mamba-ssm kernels
23d26f4
raw
history blame contribute delete
676 Bytes
import torch
from functools import partial
from typing import Callable
def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool) -> Callable:
    """Wrap an AMP decorator so it can be called the same way on either code path.

    Args:
        dec: The underlying callable to forward to (in this file, a
            ``custom_fwd`` / ``custom_bwd`` implementation).
        cuda_amp_deprecated: When true, ``device_type="cuda"`` is injected
            into the keyword arguments before forwarding.

    Returns:
        A callable that forwards all positional and keyword arguments to
        ``dec`` and returns its result. The wrapper carries ``dec``'s
        metadata (``__name__``, ``__doc__``, ``__wrapped__``).
    """
    # Function-scope import: the file's top-level imports only bring in
    # ``partial`` from functools.
    from functools import wraps

    @wraps(dec)
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            # Always forces "cuda", overwriting any caller-supplied value.
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)

    return decorator
# Pick the source of custom_fwd/custom_bwd based on what this torch build
# exposes. ``deprecated`` is True when torch.amp provides custom_fwd, in which
# case the wrappers below inject device_type="cuda" on every call.
# NOTE(review): presumably torch.cuda.amp's variants are the deprecated ones
# on such builds -- confirm against the PyTorch release notes.
deprecated = hasattr(torch.amp, "custom_fwd")  # type: ignore[attr-defined]

if deprecated:
    from torch.amp import custom_bwd, custom_fwd  # type: ignore[attr-defined]
else:
    from torch.cuda.amp import custom_bwd, custom_fwd

# Rebind both public names to their wrapped versions in one pass.
custom_fwd, custom_bwd = (
    custom_amp_decorator(raw_dec, deprecated)
    for raw_dec in (custom_fwd, custom_bwd)
)