"""
"""

import torch
from kernels import get_kernel


# Fetch the FlashAttention-3 kernel from the Hugging Face Hub.
_flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func


# Register the kernel as a PyTorch custom op so torch.compile can trace calls
# to it without graph breaks.
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # The kernel returns (output, logsumexp); only the attention output is needed here.
    outputs, lse = _flash_attn_func(q, k, v)
    return outputs

# Fake implementation used during tracing/compilation: it only describes the
# output's shape, dtype, and device without launching the kernel.
@flash_attn_func.register_fake
def _(q, k, v):
    return torch.empty_like(q).contiguous()
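

# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes a CUDA GPU supported by the kernel, half-precision inputs, and the
# flash-attn layout (batch, seqlen, num_heads, head_dim); shapes and dtypes
# below are examples only.
if __name__ == "__main__":
    q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    def attn(q, k, v):
        # Calls go through the registered custom op, so torch.compile traces
        # them via the fake implementation instead of breaking the graph.
        return flash_attn_func(q, k, v)

    out = torch.compile(attn, fullgraph=True)(q, k, v)
    print(out.shape)  # -> torch.Size([1, 128, 8, 64])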