kernel
drbh
feat: pass vars into fwd and include build
39b4aba
raw
history blame
703 Bytes
from typing import Optional
import torch
from ._ops import ops
def mha_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    alibi_slopes: torch.Tensor,
    p_dropout: float,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    return_softmax: bool,
    gen: Optional[torch.Generator],
) -> torch.Tensor:
    """Invoke the compiled ``mha_fwd`` (multi-head attention forward) op.

    This is a thin Python wrapper: every argument is forwarded to
    ``ops.mha_fwd`` positionally, in declaration order, and the op's result
    is returned unchanged. No validation or conversion happens here.

    Args:
        q: Query tensor, passed through to the op.
        k: Key tensor, passed through to the op.
        v: Value tensor, passed through to the op.
            NOTE(review): required shapes/dtypes/devices for q, k and v are
            defined by the native op, not here — confirm against its schema.
        out: Tensor passed through to the op; presumably a preallocated
            output buffer — TODO confirm against the op's schema.
        alibi_slopes: Passed through (presumably ALiBi bias slopes).
        p_dropout: Passed through (presumably the dropout probability).
        softmax_scale: Passed through (presumably the pre-softmax scale).
        is_causal: Passed through (presumably enables causal masking).
        window_size_left: Passed through (presumably the left bound of a
            local-attention window; sentinel values are defined by the op).
        window_size_right: Passed through (presumably the right bound).
        softcap: Passed through (presumably a logit soft-capping value).
        return_softmax: Passed through (presumably asks the op to also
            produce softmax values).
        gen: Optional generator passed through to the op (presumably used
            for dropout randomness).

    Returns:
        Exactly what ``ops.mha_fwd`` returns. NOTE(review): the
        ``torch.Tensor`` annotation is kept for compatibility but has not been
        verified against the native op's actual return type.
    """
    # Hand the op's result straight back; nothing after this call runs, so
    # `out` is not returned separately.
    return ops.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )