import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
torch.backends.cuda.enable_mem_efficient_sdp(True)
def create_sin_embedding(positions,
dim,
max_period=10000
):
    # dim must be even: the output concatenates cos / sin halves of size dim // 2
half_dim = dim // 2
positions = positions.to(torch.float)
adim = torch.arange(half_dim, device=positions.device,
dtype=torch.float).view(1, 1, -1)
max_period_tensor = torch.full([],
max_period,
device=positions.device,
dtype=torch.float) # avoid sync point
phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    # The official implementation keeps this in torch.float32, while self_attn.in_proj_weight may be torch.float16.
return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
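# Usage sketch (hypothetical values, shapes only): positions of shape (B, T, 1) give an
# embedding of shape (B, T, dim), as added to the activations in StreamingTransformer.forward:
#   pos = torch.full((2, 1, 1), 7.0)       # batch of 2 at decoding step 7
#   emb = create_sin_embedding(pos, 1536)  # -> (2, 1, 1536)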
class StreamingMultiheadAttention(nn.Module):
def __init__(self,
embed_dim,
num_heads,
cross_attention=False,
):
super().__init__()
self.cross_attention = cross_attention
        # If not self.cross_attention, this module keeps a kv cache (k_history / v_history).
self.k_history = None
        # The history is cleared by the LM during generation; each of the 48 MHA layers keeps its own kv history.
self.v_history = None
self.num_heads = num_heads
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
dtype=torch.float))
def forward(self,
query,
key=None,
value=None):
layout = "b h t d"
if self.cross_attention:
            # Cross-attention: separate query / key / value inputs, so split in_proj_weight into three blocks.
dim = self.in_proj_weight.shape[0] // 3
q = nn.functional.linear(query, self.in_proj_weight[:dim])
k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
q, k, v = [
rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
else:
            # Self-attention: audio attends to itself (the branch above is cross-attention over the text).
            # A single projection yields q, k, v for the current (instantaneous) step.
            # The kv history is different for each transformer layer.
            # Floating-point values here can differ slightly from the official implementation.
projected = nn.functional.linear(query, self.in_proj_weight, None)
            # print(query.sum(), projected.sum(), self.in_proj_weight.sum())  # verified against official AudioGen values
bound_layout = "b h p t d"
packed = rearrange(
projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
q, k, v = packed.unbind(dim=2)
if self.k_history is not None:
                # If generation is interrupted (e.g. Ctrl+C in the live demo) between the two assignments
                # below, k_history and v_history can end up with mismatched lengths, and the next call
                # would then try to continue with incompatible k/v dims.
self.k_history = torch.cat([self.k_history, k], 2)
self.v_history = torch.cat([self.v_history, v], 2)
else:
self.k_history = k
self.v_history = v
            # Use the full cached history as k / v for attention.
k = self.k_history
v = self.v_history
        # The kv cache above only applies when not self.cross_attention.
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, is_causal=False, dropout_p=0.0)
x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
x = self.out_proj(x)
return x
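# Cache-behaviour sketch (hypothetical shapes): in self-attention mode each call appends the new
# k/v along the time axis, so repeated one-step calls grow this layer's history, while
# cross_attention=True never touches k_history / v_history:
#   mha = StreamingMultiheadAttention(embed_dim=1536, num_heads=24)
#   mha(torch.randn(1, 1, 1536))  # k_history time length == 1
#   mha(torch.randn(1, 1, 1536))  # k_history time length == 2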
class StreamingTransformerLayer(nn.Module):
def __init__(self,
d_model,
num_heads,
dim_feedforward):
super().__init__()
self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
num_heads=num_heads)
self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
num_heads=num_heads,
cross_attention=True)
self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
def forward(self,
x,
cross_attention_src=None):
x = x + self.self_attn(self.norm1(x))
x = x + self.cross_attention(query=self.norm_cross(x),
key=cross_attention_src,
                                     value=cross_attention_src)  # text conditioning
x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
return x
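# Layer sketch (pre-norm residual block, hypothetical shapes): self-attention over the audio
# stream, cross-attention over the text condition, then the GELU feed-forward:
#   layer = StreamingTransformerLayer(d_model=1536, num_heads=24, dim_feedforward=6144)
#   out = layer(torch.randn(1, 1, 1536), cross_attention_src=torch.randn(1, 7, 1536))  # -> (1, 1, 1536)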
class StreamingTransformer(nn.Module):
def __init__(self,
d_model=1536,
num_heads=24,
num_layers=48,
dim_feedforward=6144):
        super().__init__()
        self.d_model = d_model
self.layers = nn.ModuleList(
[
StreamingTransformerLayer(d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward) for _ in range(num_layers)
]
)
def forward(self,
x,
cache_position=None,
cross_attention_src=None):
        # Add the sinusoidal embedding for the current cache position (broadcast over the batch).
        x = x + create_sin_embedding(
            torch.zeros(x.shape[0], 1, 1, device=x.device) + cache_position, self.d_model)
for lay in self.layers:
x = lay(x,
cross_attention_src=cross_attention_src)
return x
def _flush(self,
n_preserve=None):
for lay in self.layers:
if n_preserve is not None:
                # Keep only the first n_preserve positions; preserving kv from the end as well would make the cache position hard to choose.
lay.self_attn.k_history = lay.self_attn.k_history[:, :, :n_preserve, :]
lay.self_attn.v_history = lay.self_attn.v_history[:, :, :n_preserve, :]
else:
lay.self_attn.k_history = None
lay.self_attn.v_history = None
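if __name__ == "__main__":
    # Smoke test only (assumption: real use would load AudioGen weights into these modules;
    # here in_proj_weight is the all-ones buffer and the Linear layers are randomly
    # initialised, so the numbers are meaningless - this just exercises shapes, the kv
    # cache and _flush()).
    model = StreamingTransformer(d_model=1536, num_heads=24, num_layers=2, dim_feedforward=6144)
    txt = torch.randn(1, 7, 1536)   # stand-in for text-conditioning hidden states
    x = torch.randn(1, 1, 1536)     # one token embedding per decoding step
    with torch.no_grad():
        for step in range(3):
            x = model(x, cache_position=step, cross_attention_src=txt)
    # Expect (1, 1, 1536) and a self-attention cache of 3 positions, i.e. (1, 24, 3, 64).
    print(x.shape, model.layers[0].self_attn.k_history.shape)
    model._flush(n_preserve=1)      # keep only the first cached position
    print(model.layers[0].self_attn.k_history.shape)  # expect (1, 24, 1, 64)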