drbh committed
Commit 0518250 · 1 Parent(s): 876ac68

fix: bump readme

Files changed (1)
  1. README.md +104 -42
README.md CHANGED
@@ -6,18 +6,21 @@ tags:
 
 <!-- ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/flash-attn) -->
 
-> [!WARNING]
-> The latest build b58ed97 may contain an accuracy issue, which is currently being addressed. Please use with caution, and be aware that corrected outputs will be available soon.
-
 # Flash Attention
 
 Flash Attention is a fast and memory-efficient implementation of the attention mechanism, designed to work with large models and long sequences. This is a Hugging Face compliant kernel build of Flash Attention.
 
 Original code here [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).
 
+
+[`scripts/readme_example.py`](scripts/readme_example.py) provides a simple example of how to use the Flash Attention kernel in PyTorch. It demonstrates standard attention, causal attention, and variable-length sequences.
 ```python
 # /// script
-# dependencies = ["numpy", "torch", "kernels"]
+# dependencies = [
+#   "numpy",
+#   "torch",
+#   "kernels"
+# ]
 # ///
 import torch
 from kernels import get_kernel
@@ -27,23 +30,87 @@ torch.manual_seed(42)
 flash_attn = get_kernel("kernels-community/flash-attn")
 device = torch.device("cuda")
 
-# Show available functions
 print("Flash Attention functions:", [i for i in dir(flash_attn) if i.startswith("mha")])
 
-# 1. Standard attention
-print("\n1. Standard attention:")
+# Create test tensors
 B, S, H, D = 2, 5, 4, 8 # batch, seq_len, heads, head_dim
 q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
-out = flash_attn.mha_fwd(q=q, k=k, v=v, is_causal=False)[0]
-print(f"Output: {out.shape}")
 
-# 2. Variable length sequences
-print("\n2. Variable length sequences:")
+# Reference implementation using PyTorch SDPA
+def reference_attention(query, key, value, causal=False):
+    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
+        out = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=causal)
+    return out.transpose(1, 2).contiguous()
+
+# 1. Standard attention
+print("\n1. Standard attention:")
+out_ref = reference_attention(q, k, v)
+out_flash = flash_attn.mha_fwd(
+    q=q,
+    k=k,
+    v=v,
+    is_causal=False,
+    softmax_scale=1.0 / (D ** 0.5), # scale factor
+)[0]
+print(f"Reference output: {out_ref.shape}")
+print(f"Flash output: {out_flash.shape}")
+print(f"Outputs close: {torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-3)}")
+
+# 2. Causal attention (for autoregressive models)
+print("\n2. Causal attention:")
+
+out_ref_causal = reference_attention(q, k, v, causal=True)
+out_causal = flash_attn.mha_fwd(
+    q=q,
+    k=k,
+    v=v,
+    is_causal=True,
+    softmax_scale=1.0 / (D ** 0.5), # scale factor
+)[0]
+print(f"Reference causal output: {out_ref_causal.shape}")
+print(f"Flash causal output: {out_causal.shape}")
+print(f"Outputs close: {torch.allclose(out_causal, out_ref_causal, atol=1e-2, rtol=1e-3)}")
+
+def var_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=False):
+    batch_size = cu_seqlens_q.shape[0] - 1
+    # Return output in packed format
+    total_tokens_q = q.shape[0]
+    out = torch.zeros((total_tokens_q, q.shape[1], q.shape[2]), device=q.device, dtype=q.dtype)
+
+    for b in range(batch_size):
+        start_q, end_q = cu_seqlens_q[b], cu_seqlens_q[b + 1]
+        start_k, end_k = cu_seqlens_k[b], cu_seqlens_k[b + 1]
+
+        # Extract slices for this batch
+        q_slice = q[start_q:end_q] # Shape: (seq_len_q, H, D)
+        k_slice = k[start_k:end_k] # Shape: (seq_len_k, H, D)
+        v_slice = v[start_k:end_k] # Shape: (seq_len_k, H, D)
+
+        # Add batch dimension for reference_attention
+        q_slice = q_slice.unsqueeze(0) # Shape: (1, seq_len_q, H, D)
+        k_slice = k_slice.unsqueeze(0) # Shape: (1, seq_len_k, H, D)
+        v_slice = v_slice.unsqueeze(0) # Shape: (1, seq_len_k, H, D)
+
+        # Compute attention and remove batch dimension
+        attn_out = reference_attention(q_slice, k_slice, v_slice, causal=causal)
+        attn_out = attn_out.squeeze(0) # Shape: (seq_len_q, H, D)
+
+        # Place result in output tensor (packed format)
+        out[start_q:end_q] = attn_out
+
+    return out
+
+# 3. Variable length sequences (packed format)
+print("\n3. Variable length sequences:")
+# Pack sequences of lengths [3,4,3] for q and [4,5,3] for k into single tensors
 q_var = torch.randn(10, H, D, device=device, dtype=torch.float16) # total_q=10
 k_var = v_var = torch.randn(12, H, D, device=device, dtype=torch.float16) # total_k=12
-# For 3 sequences with lengths [3,4,3] for q and [4,5,3] for k
-cu_q = torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int32)
+cu_q = torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int32) # cumulative sequence lengths
 cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
+
+out_var_ref = var_reference_attention(q_var, k_var, v_var, cu_q, cu_k, max_seqlen_q=4, max_seqlen_k=5, causal=False)
+# Custom function to handle variable-length sequences
 out_var = flash_attn.mha_varlen_fwd(
     q=q_var,
     k=k_var,
@@ -52,42 +119,37 @@ out_var = flash_attn.mha_varlen_fwd(
     cu_seqlens_k=cu_k,
     max_seqlen_q=4,
     max_seqlen_k=5,
+    softmax_scale=1.0 / (D ** 0.5), # scale factor
 )[0]
-print(f"Output: {out_var.shape}")
-
-# 3. KV-cache for autoregressive generation
-print("\n3. KV-cache:")
-cache_len, new_len = 10, 2
-kcache = vcache = torch.randn(B, cache_len, H, D, device=device, dtype=torch.float16)
-q_new = k_new = v_new = torch.randn(
-    B, new_len, H, D, device=device, dtype=torch.float16
-)
-seqlens = torch.full((B,), cache_len + new_len, device=device, dtype=torch.int32)
-out_kv = flash_attn.mha_fwd_kvcache(
-    q=q_new,
-    kcache=kcache,
-    vcache=vcache,
-    k=k_new,
-    v=v_new,
-    seqlens_k=seqlens,
-    is_causal=True,
-)[0]
-print(f"Output: {out_kv.shape}")
+print(f"Variable length output: {out_var.shape}")
+print(f"Reference variable length output: {out_var_ref.shape}")
+print(f"Outputs close: {torch.allclose(out_var, out_var_ref, atol=1e-2, rtol=1e-3)}")
 ```
 
-expected output
+run it using the following command:
+
+```bash
+uv run scripts/readme_example.py
+```
 
 ```txt
-Fetching 3 files: 100%|█████████████████████████████████████████████████████| 3/3 [00:00<00:00, 16384.00it/s]
+Reading inline script metadata from `flash-attn/scripts/readme_example.py`
+Fetching 4 files: 100%|█████████████████████████████████████████| 4/4 [00:00<00:00, 33354.31it/s]
 Flash Attention functions: ['mha_bwd', 'mha_fwd', 'mha_fwd_kvcache', 'mha_varlen_bwd', 'mha_varlen_fwd']
 
 1. Standard attention:
-Output: torch.Size([2, 5, 4, 8])
-
-2. Variable length sequences:
-Output: torch.Size([10, 4, 8])
-
-3. KV-cache:
-Output: torch.Size([2, 2, 4, 8])
+Reference output: torch.Size([2, 5, 4, 8])
+Flash output: torch.Size([2, 5, 4, 8])
+Outputs close: True
+
+2. Causal attention:
+Reference causal output: torch.Size([2, 5, 4, 8])
+Flash causal output: torch.Size([2, 5, 4, 8])
+Outputs close: True
+
+3. Variable length sequences:
+Variable length output: torch.Size([10, 4, 8])
+Reference variable length output: torch.Size([10, 4, 8])
+Outputs close: True
 ```
 
155