|
|
|
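"""Minimal reference implementation of SSD (structured state space duality).

This is the same as Listing 1 from the Mamba-2 paper ("Transformers are SSMs:
Generalized Models and Efficient Algorithms Through Structured State Space
Duality").
"""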
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from einops import rearrange, repeat |
|
|
|
from ..ops.triton.ssd_combined import mamba_chunk_scan_combined |
|
|
|
|
|
def segsum_unstable(x):
    """Segment sum of ``x`` along its last dimension (naive version).

    For an input whose last dimension has length ``T``, returns a tensor with
    two trailing dimensions of size ``T`` where
    ``out[..., i, j] = cumsum(x)[..., i] - cumsum(x)[..., j]`` (the sum of
    ``x[..., j+1 : i+1]``) when ``j <= i``, and ``-inf`` when ``j > i``.

    Subtracting two running totals can lose precision when the totals are
    large; ``segsum`` computes the same quantity without that cancellation.
    """
    length = x.size(-1)
    running = torch.cumsum(x, dim=-1)

    # pairwise[..., i, j] = running[..., i] - running[..., j]
    pairwise = running.unsqueeze(-1) - running.unsqueeze(-2)

    # Only j <= i (lower triangle, diagonal included) is a valid segment.
    lower = torch.ones(length, length, device=x.device, dtype=bool).tril()
    return pairwise.masked_fill(lower.logical_not(), -torch.inf)
|
|
|
|
|
def segsum(x):
    """Segment sum of ``x`` along its last dimension, computed stably.

    Mathematically equal to ``segsum_unstable``: for an input whose last
    dimension has length ``T``, the output has two trailing dimensions of size
    ``T`` with ``out[..., i, j] = x[..., j+1] + ... + x[..., i]`` when
    ``j <= i`` (zero on the diagonal) and ``-inf`` when ``j > i``.

    Rather than subtracting two running totals, which can cancel
    catastrophically, each entry is accumulated directly from the terms that
    belong to its segment.
    """
    length = x.size(-1)
    ones = torch.ones(length, length, device=x.device, dtype=bool)

    # grid[..., d, e] = x[..., d]. Only terms strictly below the diagonal
    # (d > e) belong to a segment starting after e, so zero the rest before
    # accumulating down axis -2.
    grid = x.unsqueeze(-1).expand(*x.shape, length)
    grid = grid.masked_fill(~torch.tril(ones, diagonal=-1), 0)
    totals = grid.cumsum(dim=-2)

    # Entries above the diagonal (j > i) are not valid segments.
    return totals.masked_fill(~torch.tril(ones), -torch.inf)
|
|
|
|
|
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Chunked SSD scan written with plain PyTorch ops (no fused kernels).

    The sequence is split into chunks of ``block_len`` steps. Outputs within
    a chunk are computed in masked (quadratic) form. States are carried across
    chunks with a recurrence over chunk boundaries.

    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads) per-step log-decay; the decay applied
            between steps is ``exp`` of segment sums of A (test_correctness
            passes the already-discretized ``A * dt``).
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
            NOTE(review): test_correctness passes B/C with a size-1 group axis
            in place of n_heads and relies on torch.einsum broadcasting it.
        block_len: chunk size; ``length`` must be a multiple of it.
        initial_states: optional (batch, 1, n_heads, d_head, d_state) state
            entering the first chunk; zeros when None.
    Return:
        Y: (batch, length, n_heads, d_head)
        final_state: (batch, n_heads, d_head, d_state), the state after the
            last chunk.
    """
    # All inputs must share one dtype, and the sequence length must split
    # evenly into chunks of `block_len` steps.
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Split the length axis into (num_chunks, block_len).
    X, A, B, C = [
        rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
    ]

    # A becomes (batch, n_heads, num_chunks, block_len) so the cumsum runs over
    # positions within each chunk.
    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Intra-chunk outputs (diagonal blocks).
    #    L[..., l, s] = exp(A[s+1] + ... + A[l]) for s <= l and 0 for s > l
    #    (exp(-inf)). Y_diag[l] therefore sums (C[l] . B[s]) * L[l, s] * X[s]
    #    over positions s <= l within the same chunk.
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Per-chunk states.
    #    Each input is decayed from its position to the end of its chunk by
    #    exp(remaining log-decay). It is then accumulated into one
    #    (d_head, d_state) state per chunk. Contributions from earlier chunks
    #    are not included here.
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Inter-chunk recurrence over chunk boundaries.
    #    The initial state is prepended at index 0, and each chunk's total
    #    log-decay is left-padded with 0. Then decay_chunk[z, c] is exp of the
    #    summed log-decay of chunks c..z-1: 1 on the diagonal, 0 above it.
    #    new_states[z] is the state entering chunk z. The last entry is the
    #    state after the final chunk.
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. State -> output (off-diagonal blocks).
    #    The state entering each chunk is decayed to position l by
    #    exp(A_cumsum[l]) and read out through C[l].
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)

    # Sum both contributions and merge the chunks back into one length axis.
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state
|
|
|
|
|
|
|
def test_correctness():
    """Check ``ssd_minimal_discrete`` against the fused Triton implementation.

    Runs ``mamba_chunk_scan_combined`` and the minimal reference on the same
    random inputs and asserts that their outputs agree. Needs a CUDA device
    (all tensors are allocated with ``device="cuda"``).
    """
    torch.manual_seed(42)

    # Problem sizes.
    batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
    nheads = dim // headdim
    ngroups = 1  # B and C carry a single group shared by all heads
    dstate = 64
    dtype = torch.float32
    device = "cuda"

    x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
    dt = F.softplus(
        torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4
    ).requires_grad_()
    A = (
        -torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))
    ).requires_grad_()
    B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
    C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)

    # Neither path uses a D term (the fused call gets D=None).
    y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
    # The minimal version takes already-discretized inputs: X * dt and A * dt.
    y_min, _ = ssd_minimal_discrete(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size)

    # Previously both outputs were computed but never compared, so this test
    # could not fail. Use a loose tolerance because the fused kernel
    # accumulates in a different order and may use reduced-precision (TF32)
    # matmuls on recent GPUs.
    # NOTE(review): tighten rtol/atol once the observed max error is measured.
    torch.testing.assert_close(y_min.detach(), y.detach(), rtol=1e-2, atol=1e-2)
|
|