import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from transformers import AutoTokenizer

from vui.fluac import Fluac
from vui.utils import load_what_you_can

from .config import Config
from .patterns import DelayedPatternProvider
from .rope import apply_rotary_emb, precompute_freqs_cis


class KVCache(nn.Module):
    def __init__(
        self,
        batch_size: int,
        max_seqlen: int,
        n_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        cache_shape = (batch_size, n_kv_heads, max_seqlen, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: (T,), k_val: (B, nh, T, d)
        assert input_pos.size(0) == k_val.size(-2)
        k_out = self.k_cache
        v_out = self.v_cache
        input_pos = input_pos.int()
        k_out[:, :, input_pos] = k_val.to(k_out.dtype)
        v_out[:, :, input_pos] = v_val.to(v_out.dtype)
        return k_out, v_out


def repeat_kv(x: torch.Tensor, n_reps: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_reps)."""
    bs, n_kv_heads, T, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_reps, T, head_dim)
        .reshape(bs, n_kv_heads * n_reps, T, head_dim)
    )


class MHA(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        *,
        block_idx: int,
        bias: bool = False,
        dropout: float = 0.0,
        causal: bool = False,
        use_rotary_emb: bool = True,
    ):
        super().__init__()
        head_dim = dim // n_heads
        self.use_rotary_emb = use_rotary_emb
        self.block_idx = block_idx
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.dropout = dropout
        self.causal = causal
        # Grouped-query attention: each KV head serves n_reps query heads.
        self.n_reps = n_heads // n_kv_heads
        qkv_dim = (n_heads + 2 * n_kv_heads) * head_dim
        self.Wqkv = nn.Linear(dim, qkv_dim, bias=bias)
        self.out_proj = nn.Linear(dim, dim, bias=bias)
        self.kv_cache = None

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor | None = None,
        input_pos: Tensor | None = None,
        attn_mask: Tensor | None = None,
    ):
        B, T, d = x.size()
        dropout_p = self.dropout if self.training else 0.0
        qkv = self.Wqkv(x).to(x.dtype)

        if self.n_heads == self.n_kv_heads:
            qkv = rearrange(
                qkv, "B T (three h d) -> B three h T d", three=3, h=self.n_heads
            )
            q, k, v = qkv.unbind(dim=1)  # (B, h, T, d)
        else:
            # GQA: split the packed projection along the feature dim.
            q, k, v = torch.split(
                qkv,
                [
                    self.head_dim * self.n_heads,
                    self.head_dim * self.n_kv_heads,
                    self.head_dim * self.n_kv_heads,
                ],
                dim=-1,
            )
            q, k, v = map(
                lambda t: rearrange(t, "B T (h d) -> B h T d", d=self.head_dim),
                (q, k, v),
            )

        if self.use_rotary_emb:
            q = apply_rotary_emb(freqs_cis, q)
            k = apply_rotary_emb(freqs_cis, k)

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        if self.n_reps > 1:
            k = repeat_kv(k, self.n_reps)
            v = repeat_kv(v, self.n_reps)

        q, k, v = q.to(x.dtype), k.to(x.dtype), v.to(x.dtype)

        # With a KV cache the keys span the full cache length, so SDPA's
        # is_causal shortcut no longer applies; an explicit mask is used instead.
        is_causal = self.causal and self.kv_cache is None
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=dropout_p,
            is_causal=is_causal,
            attn_mask=attn_mask,
        )
        out = self.out_proj(rearrange(out, "B h T d -> B T (h d)"))
        return out


class MLP(nn.Module):
    def __init__(
        self, *, d_model: int, bias: bool, dropout: float, act=nn.GELU, **kwargs
    ):
        super().__init__()
        self.fc1 = nn.Linear(d_model, 4 * d_model, bias=bias)
        self.act = act()
        self.fc2 = nn.Linear(4 * d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.fc2(self.act(self.fc1(x))))
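
# Illustrative shape check for the grouped-query path above (sizes are made
# up, not a checkpointed configuration). With 8 query heads and 2 KV heads,
# n_reps = 8 // 2 = 4, and repeat_kv expands K/V to match the query head
# count before F.scaled_dot_product_attention:
#
#     k = torch.randn(1, 2, 16, 64)  # (B, n_kv_heads, T, head_dim)
#     assert repeat_kv(k, 4).shape == (1, 8, 16, 64)
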
class LlamaMLP(nn.Module):
    def __init__(
        self, *, d_model: int, multiple_of: int = 256, bias: bool = False, **kwargs
    ) -> None:
        super().__init__()
        # SwiGLU sizing as in Llama: 2/3 of 4*d_model, rounded up to a multiple.
        hidden_dim = 4 * d_model
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=bias)
        self.w3 = nn.Linear(d_model, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, d_model, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class Block(nn.Module):
    def __init__(
        self,
        *,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        block_idx: int,
        bias: bool,
        dropout: float,
        norm_eps: float = 1e-5,  # use 1e-6 for rms
        use_rotary_emb: bool = True,
    ):
        super().__init__()
        self.block_idx = block_idx
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        self.attn_norm = RMSNorm(d_model, eps=norm_eps)
        self.attn = MHA(
            d_model,
            n_heads,
            n_kv_heads,
            block_idx=block_idx,
            bias=bias,
            dropout=dropout,
            causal=True,
            use_rotary_emb=use_rotary_emb,
        )
        self.mlp_norm = RMSNorm(d_model, eps=norm_eps)
        self.mlp = LlamaMLP(d_model=d_model, bias=bias, dropout=dropout)

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor | None = None,
        input_pos: Tensor | None = None,
        attn_mask: Tensor | None = None,
    ):
        x = x + self.attn(
            self.attn_norm(x),
            freqs_cis=freqs_cis,
            input_pos=input_pos,
            attn_mask=attn_mask,
        )
        x = x + self.mlp(self.mlp_norm(x))
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        n_layers: int,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        bias: bool,
        dropout: float,
        max_seqlen: int = 4096,
        rope_theta: float = 10000.0,
        rope_theta_rescale_factor: float = 1.0,
        norm_eps: float = 1e-5,
        use_rotary_emb: bool = True,
        rope_dim: int | None = None,
    ):
        super().__init__()
        assert d_model % n_heads == 0
        self.use_rotary_emb = use_rotary_emb
        self.max_seqlen = max_seqlen
        self.blocks = nn.ModuleList(
            [
                Block(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    block_idx=block_idx,
                    bias=bias,
                    dropout=dropout,
                    norm_eps=norm_eps,
                    use_rotary_emb=use_rotary_emb,
                )
                for block_idx in range(n_layers)
            ]
        )
        self.norm = RMSNorm(d_model, eps=norm_eps)
        self.attn_mask = None
        head_dim = d_model // n_heads
        rope_dim = rope_dim or head_dim
        assert rope_dim <= head_dim  # apply RoPE to a fraction of embeddings
        freqs_cis = precompute_freqs_cis(
            rope_dim,
            max_seqlen,
            theta=rope_theta,
            theta_rescale_factor=rope_theta_rescale_factor,
        )
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def allocate_inference_cache(
        self, batch_size: int, device: str, dtype=torch.bfloat16
    ):
        for block in self.blocks:
            block.attn.kv_cache = KVCache(
                batch_size, self.max_seqlen, block.n_kv_heads, block.head_dim, dtype
            ).to(device)
        # With a preallocated cache, keys/values span all max_seqlen slots while
        # queries cover only the current positions, so SDPA's is_causal flag
        # would mask the wrong entries; rows of this tril mask are selected by
        # input_pos at each step instead.
        self.attn_mask = torch.tril(
            torch.ones(
                self.max_seqlen, self.max_seqlen, dtype=torch.bool, device=device
            )
        )

    def deallocate_kv_cache(self):
        for block in self.blocks:
            block.attn.kv_cache = None
        self.attn_mask = None

    def forward(self, x: Tensor, input_pos: Tensor):
        if self.use_rotary_emb:
            freqs_cis = self.freqs_cis[input_pos]
        else:
            freqs_cis = None

        attn_mask = (
            self.attn_mask[None, None, input_pos]
            if self.attn_mask is not None
            else None
        )

        for block in self.blocks:
            x = block(x, freqs_cis=freqs_cis, input_pos=input_pos, attn_mask=attn_mask)

        x = self.norm(x)
        return x
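
# Sketch of incremental decoding with the preallocated cache (hypothetical
# sizes, not a checkpointed configuration; RoPE defaults assumed). Each step
# feeds only the new positions and indexes the causal mask via input_pos:
#
#     dec = Decoder(n_layers=2, d_model=256, n_heads=4, n_kv_heads=4,
#                   bias=False, dropout=0.0, max_seqlen=128)
#     dec.allocate_inference_cache(batch_size=1, device="cpu", dtype=torch.float32)
#     x = torch.randn(1, 1, 256)               # embedding of one new token
#     y = dec(x, input_pos=torch.tensor([0]))  # (1, 1, 256); cache slot 0 filled
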
class Vui(nn.Module):
    BASE = "vui-100m-base.pt"
    COHOST = "vui-cohost-100m.pt"
    ABRAHAM = "vui-abraham-100m.pt"

    def __init__(self, config: Config = Config()):
        super().__init__()
        self.codec = Fluac.from_pretrained()
        self.config = config
        cfg = config.model
        self.tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
        self.use_rotary_emb = cfg.use_rotary_emb
        self.token_emb = nn.Embedding(self.tokenizer.vocab_size, cfg.d_model)
        self.pattern_provider = DelayedPatternProvider(n_q=cfg.n_quantizers)
        self.audio_embeddings = nn.ModuleList(
            [
                nn.Embedding(cfg.codebook_size + 8, cfg.d_model)
                for _ in range(cfg.n_quantizers)
            ]
        )
        n_kv_heads = cfg.n_heads
        max_seqlen = cfg.max_text_tokens + cfg.max_audio_tokens
        self.decoder = Decoder(
            n_layers=cfg.n_layers,
            d_model=cfg.d_model,
            n_heads=cfg.n_heads,
            n_kv_heads=n_kv_heads,
            bias=cfg.bias,
            dropout=cfg.dropout,
            max_seqlen=max_seqlen + cfg.n_quantizers,
            rope_dim=cfg.rope_dim,
            rope_theta=cfg.rope_theta,
            rope_theta_rescale_factor=cfg.rope_theta_rescale_factor,
        )
        self.audio_heads = nn.ModuleList(
            [
                nn.Linear(cfg.d_model, cfg.codebook_size + 8, bias=cfg.bias)
                for _ in range(cfg.n_quantizers)
            ]
        )
        self.apply(self._init_weights)
        # Scaled init for residual projections, as in GPT-2.
        for pn, p in self.named_parameters():
            if pn.endswith("out_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * cfg.n_layers)
                )

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    @staticmethod
    def from_pretrained(
        checkpoint_path: str | dict = ABRAHAM,
        **config_kwargs,
    ):
        if isinstance(checkpoint_path, dict):
            checkpoint = checkpoint_path
        else:
            if not os.path.exists(checkpoint_path):
                from huggingface_hub import hf_hub_download

                checkpoint_path = hf_hub_download(
                    "fluxions/vui",
                    checkpoint_path,
                )
            checkpoint = torch.load(
                checkpoint_path, map_location="cpu", weights_only=True
            )
        config = {**checkpoint["config"], **config_kwargs}
        config = Config(**config)
        state_dict = checkpoint["model"]
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        state_dict = {
            k.replace("text_embedding.", "token_emb."): v
            for k, v in state_dict.items()
        }
        model = Vui(config)
        load_what_you_can(state_dict, model)
        return model

    @staticmethod
    def from_pretrained_inf(
        checkpoint_path: str | dict,
        **config_kwargs,
    ):
        return Vui.from_pretrained(checkpoint_path, **config_kwargs).eval()

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype
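
# Minimal usage sketch (assumes network access to fetch the checkpoint from
# the "fluxions/vui" hub repo on first run; generation/sampling helpers live
# elsewhere in the package):
#
#     model = Vui.from_pretrained_inf(Vui.ABRAHAM)
#     model.decoder.allocate_inference_cache(batch_size=1, device=str(model.device))
#     print(model.device, model.dtype)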