"""Minimal streaming transformer: per-layer key/value caches and sinusoidal positions."""
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange

# Prefer the memory-efficient scaled-dot-product-attention backend on CUDA.
torch.backends.cuda.enable_mem_efficient_sdp(True)

def create_sin_embedding(positions, dim, max_period=10000):
    """Sinusoidal positional embedding of size `dim` for the given positions."""
    half_dim = dim // 2
    positions = positions.to(torch.float)
    adim = torch.arange(half_dim, device=positions.device,
                        dtype=torch.float).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period,
                                   device=positions.device, dtype=torch.float)
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

class StreamingMultiheadAttention(nn.Module):
    """Multi-head attention that caches keys/values across streaming steps."""

    def __init__(self, embed_dim, num_heads, cross_attention=False):
        super().__init__()
        self.cross_attention = cross_attention
        # Key/value cache, grown along the time axis on every self-attention call.
        self.k_history = None
        self.v_history = None
        self.num_heads = num_heads
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        # Packed Q/K/V projection weight, stored as a buffer; the ones-init is a
        # placeholder until real weights are loaded into the module.
        self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
                                                          dtype=torch.float))

    def forward(self, query, key=None, value=None):
        layout = "b h t d"
        if self.cross_attention:
            # Cross-attention: project the query and the external key/value
            # source with the three slices of the packed weight. No caching here.
            dim = self.in_proj_weight.shape[0] // 3
            q = nn.functional.linear(query, self.in_proj_weight[:dim])
            k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
            v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
            q, k, v = [
                rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
        else:
            # Self-attention: one packed projection, then split into Q/K/V.
            projected = nn.functional.linear(query, self.in_proj_weight, None)
            bound_layout = "b h p t d"
            packed = rearrange(
                projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
            q, k, v = packed.unbind(dim=2)
            # Append the new keys/values to the streaming cache and attend over
            # the full history.
            if self.k_history is not None:
                self.k_history = torch.cat([self.k_history, k], 2)
                self.v_history = torch.cat([self.v_history, v], 2)
            else:
                self.k_history = k
                self.v_history = v
            k = self.k_history
            v = self.v_history

        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=False, dropout_p=0.0)
        x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
        x = self.out_proj(x)
        return x

class StreamingTransformerLayer(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention, feed-forward."""

    def __init__(self, d_model, num_heads, dim_feedforward):
        super().__init__()
        self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                     num_heads=num_heads)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
        self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
                                                           num_heads=num_heads,
                                                           cross_attention=True)
        self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

    def forward(self, x, cross_attention_src=None):
        x = x + self.self_attn(self.norm1(x))
        x = x + self.cross_attention(query=self.norm_cross(x),
                                     key=cross_attention_src,
                                     value=cross_attention_src)
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x

class StreamingTransformer(nn.Module):
    """Stack of streaming transformer layers with sinusoidal position embeddings."""

    def __init__(self, d_model=1536, num_heads=24, num_layers=48, dim_feedforward=6144):
        super().__init__()
        self.d_model = d_model
        self.layers = nn.ModuleList(
            [
                StreamingTransformerLayer(d_model=d_model,
                                          num_heads=num_heads,
                                          dim_feedforward=dim_feedforward)
                for _ in range(num_layers)
            ]
        )

    def forward(self, x, cache_position=None, cross_attention_src=None):
        # Offset the position embedding by the current cache position so each
        # streaming step is embedded at its absolute position.
        x = x + create_sin_embedding(
            torch.zeros(x.shape[0], 1, 1, device=x.device) + cache_position,
            self.d_model)
        for lay in self.layers:
            x = lay(x, cross_attention_src=cross_attention_src)
        return x

    def _flush(self, n_preserve=None):
        """Reset the per-layer self-attention caches, optionally keeping the
        first `n_preserve` time steps."""
        for lay in self.layers:
            if n_preserve is not None:
                lay.self_attn.k_history = lay.self_attn.k_history[:, :, :n_preserve, :]
                lay.self_attn.v_history = lay.self_attn.v_history[:, :, :n_preserve, :]
            else:
                lay.self_attn.k_history = None
                lay.self_attn.v_history = None
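
# Usage sketch (not part of the original module): a minimal smoke test showing
# the streaming call pattern. The shapes, the tiny num_layers, and the random
# conditioning tensor are illustrative assumptions; in practice real weights
# would be loaded from a checkpoint before inference.
if __name__ == "__main__":
    model = StreamingTransformer(d_model=1536, num_heads=24, num_layers=2,
                                 dim_feedforward=6144)
    step = torch.randn(1, 1, 1536)        # one streaming step: (batch, time, d_model)
    cond = torch.randn(1, 4, 1536)        # cross-attention source, e.g. conditioning tokens
    with torch.no_grad():
        for position in range(3):         # consecutive steps share the KV cache
            out = model(step, cache_position=position,
                        cross_attention_src=cond)
    print(out.shape)                      # torch.Size([1, 1, 1536])
    model._flush()                        # drop the accumulated KV caches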