|
|
|
|
|
import sys |
|
import warnings |
|
import os |
|
import re |
|
import ast |
|
import glob |
|
import shutil |
|
from pathlib import Path |
|
from packaging.version import parse, Version |
|
import platform |
|
|
|
from setuptools import setup, find_packages |
|
import subprocess |
|
|
|
import urllib.request |
|
import urllib.error |
|
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel |
|
|
|
import torch |
|
from torch.utils.cpp_extension import ( |
|
BuildExtension, |
|
CppExtension, |
|
CUDAExtension, |
|
CUDA_HOME, |
|
ROCM_HOME, |
|
IS_HIP_EXTENSION, |
|
) |
|
|
|
|
|
with open("README.md", "r", encoding="utf-8") as fh: |
|
long_description = fh.read() |
|
|
|
|
|
|
|
this_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto") |
|
|
|
if BUILD_TARGET == "auto":
    IS_ROCM = IS_HIP_EXTENSION
elif BUILD_TARGET == "cuda":
    IS_ROCM = False
elif BUILD_TARGET == "rocm":
    IS_ROCM = True
else:
    raise ValueError(f"Unsupported BUILD_TARGET: {BUILD_TARGET!r} (expected 'auto', 'cuda', or 'rocm')")
|
|
|
PACKAGE_NAME = "flash_attn" |
|
|
|
BASE_WHEEL_URL = ( |
|
"https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" |
|
) |
|
|
|
|
|
|
|
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" |
|
SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" |
|
|
|
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" |
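# The three FLASH_ATTENTION_* switches above are read from the environment:
# FORCE_BUILD makes CachedWheelsCommand compile from source even when a matching
# pre-built wheel could be downloaded, SKIP_CUDA_BUILD skips compiling the
# C++/CUDA extension entirely, and FORCE_CXX11_ABI forces the extension to be
# built with the C++11 ABI to match torch builds that use it.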
|
|
|
|
|
def get_platform(): |
|
""" |
|
Returns the platform name as used in wheel filenames. |
|
""" |
|
if sys.platform.startswith("linux"): |
|
return f'linux_{platform.uname().machine}' |
|
elif sys.platform == "darwin": |
|
mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) |
|
return f"macosx_{mac_version}_x86_64" |
|
elif sys.platform == "win32": |
|
return "win_amd64" |
|
else: |
|
raise ValueError("Unsupported platform: {}".format(sys.platform)) |
|
|
|
|
|
def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [os.path.join(cuda_dir, "bin", "nvcc"), "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])

    return raw_output, bare_metal_version
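# `nvcc -V` prints a line like "Cuda compilation tools, release 12.3, V12.3.107";
# the parsing above takes the token after "release" and strips the trailing comma,
# yielding Version("12.3").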
|
|
|
|
|
def check_if_cuda_home_none(global_option: str) -> None: |
|
if CUDA_HOME is not None: |
|
return |
|
|
|
|
|
warnings.warn( |
|
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " |
|
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " |
|
"only images whose names contain 'devel' will provide nvcc." |
|
) |
|
|
|
|
|
def check_if_rocm_home_none(global_option: str) -> None: |
|
if ROCM_HOME is not None: |
|
return |
|
|
|
|
|
    warnings.warn(
        f"{global_option} was requested, but hipcc was not found. "
        "Are you sure your environment has hipcc available?"
    )
|
|
|
|
|
def append_nvcc_threads(nvcc_extra_args): |
|
nvcc_threads = os.getenv("NVCC_THREADS") or "4" |
|
return nvcc_extra_args + ["--threads", nvcc_threads] |
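# NVCC_THREADS controls nvcc's own `--threads` parallelism within a single
# compilation unit (default 4 here); overall build parallelism is governed
# separately by MAX_JOBS in NinjaBuildExtension below.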
|
|
|
|
|
def rename_cpp_to_cu(cpp_files): |
|
for entry in cpp_files: |
|
shutil.copy(entry, os.path.splitext(entry)[0] + ".cu") |
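# The .cpp sources are copied (not renamed) to .cu so the originals stay in place
# while CUDAExtension treats the .cu copies as device code for the ROCm build.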
|
|
|
|
|
def validate_and_update_archs(archs):
    allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]

    assert all(
        arch in allowed_archs for arch in archs
    ), f"One of the GPU archs in {archs} is invalid or not supported by FlashAttention"
|
|
|
|
|
cmdclass = {} |
|
ext_modules = [] |
|
|
|
|
|
|
|
if IS_ROCM: |
|
subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"]) |
|
else: |
|
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) |
|
|
|
if not SKIP_CUDA_BUILD and not IS_ROCM: |
|
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) |
|
TORCH_MAJOR = int(torch.__version__.split(".")[0]) |
|
TORCH_MINOR = int(torch.__version__.split(".")[1]) |
|
|
|
|
|
|
|
generator_flag = [] |
|
torch_dir = torch.__path__[0] |
|
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): |
|
generator_flag = ["-DOLD_GENERATOR_PATH"] |
|
|
|
check_if_cuda_home_none("flash_attn") |
|
|
|
cc_flag = [] |
|
if CUDA_HOME is not None: |
|
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) |
|
if bare_metal_version < Version("11.6"): |
|
raise RuntimeError( |
|
"FlashAttention is only supported on CUDA 11.6 and above. " |
|
"Note: make sure nvcc has a supported version by running nvcc -V." |
|
) |
|
|
|
|
|
cc_flag.append("-gencode") |
|
cc_flag.append("arch=compute_80,code=sm_80") |
|
if CUDA_HOME is not None: |
|
if bare_metal_version >= Version("11.8"): |
|
cc_flag.append("-gencode") |
|
cc_flag.append("arch=compute_90,code=sm_90") |
|
|
|
|
|
|
|
|
|
if FORCE_CXX11_ABI: |
|
torch._C._GLIBCXX_USE_CXX11_ABI = True |
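    # Overriding torch._C._GLIBCXX_USE_CXX11_ABI above makes torch.utils.cpp_extension
    # pass -D_GLIBCXX_USE_CXX11_ABI=1, which is needed when the target torch build uses
    # the C++11 ABI rather than the pre-cxx11 ABI of the PyPI wheels.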
|
ext_modules.append( |
|
CUDAExtension( |
|
name="flash_attn_2_cuda", |
|
sources=[ |
|
"csrc/flash_attn/flash_api.cpp", |
|
"csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu", |
|
"csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu", |
|
], |
|
extra_compile_args={ |
|
"cxx": ["-O3", "-std=c++17"] + generator_flag, |
|
"nvcc": append_nvcc_threads( |
|
[ |
|
"-O3", |
|
"-std=c++17", |
|
"-U__CUDA_NO_HALF_OPERATORS__", |
|
"-U__CUDA_NO_HALF_CONVERSIONS__", |
|
"-U__CUDA_NO_HALF2_OPERATORS__", |
|
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__", |
|
"--expt-relaxed-constexpr", |
|
"--expt-extended-lambda", |
|
"--use_fast_math", |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
] |
|
+ generator_flag |
|
+ cc_flag |
|
), |
|
}, |
|
include_dirs=[ |
|
Path(this_dir) / "csrc" / "flash_attn", |
|
Path(this_dir) / "csrc" / "flash_attn" / "src", |
|
Path(this_dir) / "csrc" / "cutlass" / "include", |
|
], |
|
) |
|
) |
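    # The kernel sources above are split into one translation unit per head dimension,
    # dtype (fp16/bf16), causal flag and split-KV variant, which keeps individual nvcc
    # invocations smaller and lets them run in parallel.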
|
elif not SKIP_CUDA_BUILD and IS_ROCM: |
|
ck_dir = "csrc/composable_kernel" |
|
|
|
|
|
    os.makedirs("build", exist_ok=True)

    generate_py = f"{ck_dir}/example/ck_tile/01_fmha/generate.py"
    for direction in ("fwd", "bwd"):
        subprocess.run(
            [sys.executable, generate_py, "-d", direction, "--output_dir", "build", "--receipt", "2"],
            check=True,
        )
|
|
|
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) |
|
TORCH_MAJOR = int(torch.__version__.split(".")[0]) |
|
TORCH_MINOR = int(torch.__version__.split(".")[1]) |
|
|
|
|
|
|
|
generator_flag = [] |
|
torch_dir = torch.__path__[0] |
|
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): |
|
generator_flag = ["-DOLD_GENERATOR_PATH"] |
|
|
|
check_if_rocm_home_none("flash_attn") |
|
cc_flag = [] |
|
|
|
archs = os.getenv("GPU_ARCHS", "native").split(";") |
|
validate_and_update_archs(archs) |
|
|
|
cc_flag = [f"--offload-arch={arch}" for arch in archs] |
|
|
|
|
|
|
|
|
|
if FORCE_CXX11_ABI: |
|
torch._C._GLIBCXX_USE_CXX11_ABI = True |
|
|
|
sources = ["csrc/flash_attn_ck/flash_api.cpp", |
|
"csrc/flash_attn_ck/mha_bwd.cpp", |
|
"csrc/flash_attn_ck/mha_fwd.cpp", |
|
"csrc/flash_attn_ck/mha_varlen_bwd.cpp", |
|
"csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob( |
|
f"build/fmha_*wd*.cpp" |
|
) |
|
|
|
rename_cpp_to_cu(sources) |
|
|
|
renamed_sources = ["csrc/flash_attn_ck/flash_api.cu", |
|
"csrc/flash_attn_ck/mha_bwd.cu", |
|
"csrc/flash_attn_ck/mha_fwd.cu", |
|
"csrc/flash_attn_ck/mha_varlen_bwd.cu", |
|
"csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu") |
|
extra_compile_args = { |
|
"cxx": ["-O3", "-std=c++17"] + generator_flag, |
|
"nvcc": |
|
[ |
|
"-O3","-std=c++17", |
|
"-mllvm", "-enable-post-misched=0", |
|
"-DCK_TILE_FMHA_FWD_FAST_EXP2=1", |
|
"-fgpu-flush-denormals-to-zero", |
|
"-DCK_ENABLE_BF16", |
|
"-DCK_ENABLE_BF8", |
|
"-DCK_ENABLE_FP16", |
|
"-DCK_ENABLE_FP32", |
|
"-DCK_ENABLE_FP64", |
|
"-DCK_ENABLE_FP8", |
|
"-DCK_ENABLE_INT8", |
|
"-DCK_USE_XDL", |
|
"-DUSE_PROF_API=1", |
|
"-D__HIP_PLATFORM_HCC__=1", |
|
|
|
] |
|
+ generator_flag |
|
+ cc_flag |
|
, |
|
} |
|
|
|
include_dirs = [ |
|
Path(this_dir) / "csrc" / "composable_kernel" / "include", |
|
Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include", |
|
Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha", |
|
] |
|
|
|
ext_modules.append( |
|
CUDAExtension( |
|
name="flash_attn_2_cuda", |
|
sources=renamed_sources, |
|
extra_compile_args=extra_compile_args, |
|
include_dirs=include_dirs, |
|
) |
|
) |
|
|
|
|
|
def get_package_version(): |
|
with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f: |
|
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) |
|
public_version = ast.literal_eval(version_match.group(1)) |
|
local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") |
|
if local_version: |
|
return f"{public_version}+{local_version}" |
|
else: |
|
return str(public_version) |
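# Illustrative example: with __version__ = "2.6.3" and FLASH_ATTN_LOCAL_VERSION=rocm63
# this returns "2.6.3+rocm63"; without the env var it is just "2.6.3".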
|
|
|
|
|
def get_wheel_url(): |
|
torch_version_raw = parse(torch.__version__) |
|
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" |
|
platform_name = get_platform() |
|
flash_version = get_package_version() |
|
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" |
|
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() |
|
|
|
if IS_ROCM: |
|
torch_hip_version = parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+')) |
|
hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}" |
|
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" |
|
else: |
|
|
|
|
|
|
|
torch_cuda_version = parse(torch.version.cuda) |
|
|
|
|
|
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") |
|
|
|
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" |
|
|
|
|
|
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" |
|
|
|
wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename) |
|
|
|
return wheel_url, wheel_filename |
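# Illustrative wheel name produced by the format above (actual values depend on the
# local torch/CUDA/Python setup):
#   flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl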
|
|
|
|
|
class CachedWheelsCommand(_bdist_wheel): |
|
""" |
|
    The CachedWheelsCommand plugs into the default bdist_wheel command, which is run by pip when it
    cannot find an existing wheel (currently the case for all flash-attention installs). We use the
    environment parameters to detect whether a pre-built compatible wheel is already available and,
    if so, short-circuit the standard full build pipeline.
|
""" |
|
|
|
def run(self): |
|
if FORCE_BUILD: |
|
return super().run() |
|
|
|
wheel_url, wheel_filename = get_wheel_url() |
|
print("Guessing wheel URL: ", wheel_url) |
|
try: |
|
urllib.request.urlretrieve(wheel_url, wheel_filename) |
|
|
|
|
|
|
|
|
|
if not os.path.exists(self.dist_dir): |
|
os.makedirs(self.dist_dir) |
|
|
|
impl_tag, abi_tag, plat_tag = self.get_tag() |
|
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" |
|
|
|
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") |
|
print("Raw wheel path", wheel_path) |
|
os.rename(wheel_filename, wheel_path) |
|
except (urllib.error.HTTPError, urllib.error.URLError): |
|
print("Precompiled wheel not found. Building from source...") |
|
|
|
super().run() |
|
|
|
|
|
class NinjaBuildExtension(BuildExtension): |
|
def __init__(self, *args, **kwargs) -> None: |
|
|
|
if not os.environ.get("MAX_JOBS"): |
|
import psutil |
|
|
|
|
|
            max_num_jobs_cores = max(1, (os.cpu_count() or 1) // 2)
|
|
|
|
|
free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) |
|
max_num_jobs_memory = int(free_memory_gb / 9) |
|
|
|
|
|
max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory)) |
|
os.environ["MAX_JOBS"] = str(max_jobs) |
|
|
|
super().__init__(*args, **kwargs) |
|
|
|
|
|
setup( |
|
name=PACKAGE_NAME, |
|
version=get_package_version(), |
|
packages=find_packages( |
|
exclude=( |
|
"build", |
|
"csrc", |
|
"include", |
|
"tests", |
|
"dist", |
|
"docs", |
|
"benchmarks", |
|
"flash_attn.egg-info", |
|
) |
|
), |
|
author="Tri Dao", |
|
author_email="tri@tridao.me", |
|
description="Flash Attention: Fast and Memory-Efficient Exact Attention", |
|
long_description=long_description, |
|
long_description_content_type="text/markdown", |
|
url="https://github.com/Dao-AILab/flash-attention", |
|
classifiers=[ |
|
"Programming Language :: Python :: 3", |
|
"License :: OSI Approved :: BSD License", |
|
"Operating System :: Unix", |
|
], |
|
ext_modules=ext_modules, |
|
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension} |
|
if ext_modules |
|
else { |
|
"bdist_wheel": CachedWheelsCommand, |
|
}, |
|
python_requires=">=3.8", |
|
install_requires=[ |
|
"torch", |
|
"einops", |
|
], |
|
setup_requires=[ |
|
"packaging", |
|
"psutil", |
|
"ninja", |
|
], |
|
) |