File size: 703 Bytes
39b4aba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from typing import Optional
import torch
from ._ops import ops
def mha_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    alibi_slopes: torch.Tensor,
    p_dropout: float,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    return_softmax: bool,
    gen: Optional[torch.Generator],
) -> torch.Tensor:
    """Call the compiled ``ops.mha_fwd`` kernel and return its result.

    This is a thin Python-level wrapper: every argument is passed through
    positionally, unchanged and in declaration order, to ``ops.mha_fwd``.
    No validation, conversion, or defaulting happens here, so the meaning
    and constraints of each argument are defined by the underlying op.

    Args:
        q: Forwarded as the 1st positional argument.
        k: Forwarded as the 2nd positional argument.
        v: Forwarded as the 3rd positional argument.
        out: Forwarded as the 4th positional argument (presumably an
            output buffer the op writes into -- TODO confirm).
        alibi_slopes: Forwarded as the 5th positional argument.
        p_dropout: Forwarded as the 6th positional argument.
        softmax_scale: Forwarded as the 7th positional argument.
        is_causal: Forwarded as the 8th positional argument.
        window_size_left: Forwarded as the 9th positional argument.
        window_size_right: Forwarded as the 10th positional argument.
        softcap: Forwarded as the 11th positional argument.
        return_softmax: Forwarded as the 12th positional argument.
        gen: Forwarded as the 13th positional argument; may be ``None``.

    Returns:
        Whatever ``ops.mha_fwd`` returns.
    """
    # NOTE(review): a second statement, ``return out``, used to follow this
    # return and could never execute. It was removed as dead code; runtime
    # behaviour is unchanged. If callers were meant to receive ``out``
    # rather than the op's own return value (the ``-> torch.Tensor``
    # annotation hints at that), confirm against the op's schema and
    # change this deliberately.
    return ops.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
|