from typing import Optional | |
import torch | |
from ._ops import ops | |
def mha_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    alibi_slopes: torch.Tensor,
    p_dropout: float,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    return_softmax: bool,
    gen: Optional[torch.Generator],
) -> torch.Tensor:
    """Run the compiled multi-head-attention forward op and return its result.

    Every argument is forwarded positionally, in exactly this order, to
    ``ops.mha_fwd``. The op's return value is handed back unchanged; this
    wrapper does no validation or post-processing of its own.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        out: Passed straight through to the op. Presumably a preallocated
            output buffer the kernel writes into (TODO confirm against the
            op's implementation).
        alibi_slopes: Passed straight through to the op.
        p_dropout: Passed straight through to the op (presumably a dropout
            probability; verify).
        softmax_scale: Passed straight through to the op.
        is_causal: Passed straight through to the op.
        window_size_left: Passed straight through to the op.
        window_size_right: Passed straight through to the op.
        softcap: Passed straight through to the op.
        return_softmax: Passed straight through to the op.
        gen: Optional random generator, passed straight through to the op.

    Returns:
        Whatever ``ops.mha_fwd`` returns.

    NOTE(review): the return annotation says ``torch.Tensor``, but this
    function returns the op's result, not ``out``. If the compiled op
    returns a sequence (for example, the output plus auxiliary tensors),
    the annotation should be tightened. Confirm against the op schema.
    """
    # Return the op's result directly. Callers should not assume it is the
    # same object as ``out`` unless the op guarantees that.
    return ops.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )