# vui/fluac.py
import math
from contextlib import nullcontext
from functools import partial, wraps
from os import path
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, unpack
from einops.layers.torch import Rearrange
from pydantic import BaseModel
from torch import Tensor, int32
from torch.amp import autocast
from torch.nn import Module
from torch.nn.utils.parametrizations import weight_norm
from vui.utils import decompile_state_dict
def exists(v):
return v is not None
def default(*args):
for arg in args:
if exists(arg):
return arg
return None
def maybe(fn):
@wraps(fn)
def inner(x, *args, **kwargs):
if not exists(x):
return x
return fn(x, *args, **kwargs)
return inner
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def round_ste(z: Tensor) -> Tensor:
"""Round with straight through gradients."""
zhat = z.round()
return z + (zhat - z).detach()
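# Illustrative: round_ste(torch.tensor([0.4, 1.6])) -> tensor([0., 2.]) in the
# forward pass, while the (zhat - z).detach() trick keeps d(out)/dz == 1, so
# gradients flow through the rounding as if it were the identity.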
class FSQ(Module):
def __init__(
self,
levels: List[int],
dim: int | None = None,
num_codebooks: int = 1,
keep_num_codebooks_dim: bool | None = None,
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
channel_first: bool = True,
projection_has_bias: bool = True,
        return_indices: bool = True,
force_quantization_f32: bool = True,
):
super().__init__()
_levels = torch.tensor(levels, dtype=int32)
self.register_buffer("_levels", _levels, persistent=False)
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
self.register_buffer("_basis", _basis, persistent=False)
codebook_dim = len(levels)
self.codebook_dim = codebook_dim
effective_codebook_dim = codebook_dim * num_codebooks
self.num_codebooks = num_codebooks
self.effective_codebook_dim = effective_codebook_dim
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
self.keep_num_codebooks_dim = keep_num_codebooks_dim
self.dim = default(dim, len(_levels) * num_codebooks)
self.channel_first = channel_first
has_projections = self.dim != effective_codebook_dim
self.project_in = (
nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
if has_projections
else nn.Identity()
)
self.project_out = (
nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
if has_projections
else nn.Identity()
)
self.has_projections = has_projections
self.return_indices = return_indices
if return_indices:
self.codebook_size = self._levels.prod().item()
implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
self.register_buffer(
"implicit_codebook", implicit_codebook, persistent=False
)
self.allowed_dtypes = allowed_dtypes
self.force_quantization_f32 = force_quantization_f32
def bound(self, z, eps: float = 1e-3):
"""Bound `z`, an array of shape (..., d)."""
half_l = (self._levels - 1) * (1 + eps) / 2
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
shift = (offset / half_l).atanh()
return (z + shift).tanh() * half_l - offset
def quantize(self, z):
"""Quantizes z, returns quantized zhat, same shape as z."""
quantized = round_ste(self.bound(z))
half_width = self._levels // 2 # Renormalize to [-1, 1].
return quantized / half_width
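    # Illustrative: with levels = [8, 5, 5, 5], each channel of z is bounded,
    # rounded to one of its level's integers, then divided by half_width, so
    # every code lands on a per-channel grid in roughly [-1, 1] and the full
    # codebook is the 8 * 5 * 5 * 5 = 1000-point lattice of grid combinations.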
def _scale_and_shift(self, zhat_normalized):
half_width = self._levels // 2
return (zhat_normalized * half_width) + half_width
def _scale_and_shift_inverse(self, zhat):
half_width = self._levels // 2
return (zhat - half_width) / half_width
def _indices_to_codes(self, indices):
level_indices = self.indices_to_level_indices(indices)
codes = self._scale_and_shift_inverse(level_indices)
return codes
def codes_to_indices(self, zhat):
"""Converts a `code` to an index in the codebook."""
assert zhat.shape[-1] == self.codebook_dim
zhat = self._scale_and_shift(zhat)
return (zhat * self._basis).sum(dim=-1).to(int32)
def indices_to_level_indices(self, indices):
"""Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
indices = rearrange(indices, "... -> ... 1")
codes_non_centered = (indices // self._basis) % self._levels
return codes_non_centered
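    # Illustrative: with levels [8, 5] the basis is [1, 8], so level indices
    # (3, 2) pack into flat index 3 * 1 + 2 * 8 = 19, and the floor-div/mod
    # above recovers them: 19 // 1 % 8 == 3 and 19 // 8 % 5 == 2.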
def indices_to_codes(self, indices):
"""Inverse of `codes_to_indices`."""
assert exists(indices)
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
codes = self._indices_to_codes(indices)
if self.keep_num_codebooks_dim:
codes = rearrange(codes, "... c d -> ... (c d)")
codes = self.project_out(codes)
if is_img_or_video or self.channel_first:
codes = rearrange(codes, "b ... d -> b d ...")
return codes
def forward(self, z: Tensor):
"""
einstein notation
b - batch
n - sequence (or flattened spatial dimensions)
d - feature dimension
        c - number of codebooks
"""
device_type = z.device.type
with torch.autocast(device_type=device_type, enabled=False):
if self.channel_first:
z = rearrange(z, "b d ... -> b ... d")
z, ps = pack_one(z, "b * d")
assert (
z.shape[-1] == self.dim
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
z = self.project_in(z)
z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
# whether to force quantization step to be full precision or not
force_f32 = self.force_quantization_f32
quantization_context = (
partial(autocast, device_type=device_type, enabled=False)
if force_f32
else nullcontext
)
with quantization_context():
orig_dtype = z.dtype
if force_f32 and orig_dtype not in self.allowed_dtypes:
z = z.float()
codes = self.quantize(z)
# returning indices could be optional
indices = None
if self.return_indices:
indices = self.codes_to_indices(codes)
codes = rearrange(codes, "b n c d -> b n (c d)")
codes = codes.type(orig_dtype)
# project out
out = self.project_out(codes)
# reconstitute image or video dimensions
if self.channel_first:
out = unpack_one(out, ps, "b * d")
out = rearrange(out, "b ... d -> b d ...")
indices = maybe(unpack_one)(indices, ps, "b * c")
if not self.keep_num_codebooks_dim and self.return_indices:
indices = maybe(rearrange)(indices, "... 1 -> ...")
# return quantized output and indices
return out, indices
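# Minimal usage sketch (illustrative; shapes follow the defaults above):
#   fsq = FSQ(levels=[8, 5, 5, 5], dim=512)    # 1000-entry implicit codebook
#   out, idx = fsq(torch.randn(2, 512, 50))    # channel-first (B, D, T) input
#   out.shape == (2, 512, 50); idx.shape == (2, 50)
#   fsq.indices_to_codes(idx) recovers `out` from the indices alone.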
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
# Scripting this gives a ~1.4x model speedup
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
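# Snake is the periodic activation snake(x) = x + sin^2(alpha * x) / alpha
# (Ziyin et al., 2020, "Neural Networks Fail to Learn Periodic Functions and
# How to Fix It"); the 1e-9 above guards the reciprocal when alpha -> 0.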
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
class ResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
Snake1d(dim),
WNConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
if pad > 0:
x = x[..., pad:-pad]
return x + y
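# Note: with kernel_size=7 and padding = 3 * dilation the block preserves
# sequence length, so the center-crop above is a no-op here; it only guards
# against configurations where the dilated conv shortens the sequence.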
class EncoderBlock(nn.Module):
def __init__(self, dim: int = 16, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
ResidualUnit(dim // 2, dilation=1),
ResidualUnit(dim // 2, dilation=3),
ResidualUnit(dim // 2, dilation=9),
Snake1d(dim // 2),
WNConv1d(
dim // 2,
dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
),
)
def forward(self, x):
return self.block(x)
class Encoder(nn.Module):
def __init__(
self,
d_model: int = 64,
strides: list = [2, 4, 8, 8],
d_latent: int = 64,
):
super().__init__()
# Create first convolution
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
self.block += [EncoderBlock(d_model, stride=stride)]
# Create last convolution
self.block += [
Snake1d(d_model),
WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
]
        # Wrap the block list into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
return self.block(x)
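# Illustrative: with the defaults (d_model=64, strides=[2, 4, 8, 8]) the
# encoder downsamples audio by 2 * 4 * 8 * 8 = 512 samples per latent frame,
# doubling the channel width at each stage up to 64 * 2**4 = 1024 channels
# before the final projection to `d_latent`.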
class DecoderBlock(nn.Module):
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
super().__init__()
self.block = nn.Sequential(
Snake1d(input_dim),
WNConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
),
ResidualUnit(output_dim, dilation=1),
ResidualUnit(output_dim, dilation=3),
ResidualUnit(output_dim, dilation=9),
)
def forward(self, x):
return self.block(x)
class Decoder(nn.Module):
def __init__(
self,
input_channel: int,
channels: int,
rates: list[int],
d_out: int = 1,
):
super().__init__()
# Add first conv layer
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
layers += [DecoderBlock(input_dim, output_dim, stride)]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
# @torch.compile(dynamic=True)
def forward(self, z: Tensor):
return self.model(z)
class FiniteScalarQuantize(nn.Module):
def __init__(
self, latent_dim: int, levels: list[int], *, stride: int = 1, mlp: bool = False
):
super().__init__()
self.stride = stride
codebook_dim = len(levels)
self.in_proj = WNConv1d(latent_dim, codebook_dim, kernel_size=1)
self.quantize = FSQ(levels=levels, channel_first=True)
self.out_proj = WNConv1d(codebook_dim, latent_dim, kernel_size=1)
if mlp:
self.mlp = nn.Sequential(
Rearrange("B C T -> B T C"),
nn.Linear(latent_dim, 4 * latent_dim),
nn.GELU(),
nn.Linear(4 * latent_dim, latent_dim),
Rearrange("B T C -> B C T"),
)
else:
self.mlp = None
def from_indices(self, indices: Tensor):
B, T = indices.size()
z_q = self.quantize.indices_to_codes(indices)
z_q = self.out_proj(z_q)
return z_q
def forward(self, z: Tensor, *args):
if self.stride > 1:
z = F.avg_pool1d(z, self.stride, stride=self.stride)
z_e = self.in_proj(z) # z_e : (B x D x T)
# we're channels first
# scale = scale.unsqueeze(-1)
# z_e = z_e / scale
z_q, indices = self.quantize(z_e)
# z_q = z_q * scale
z_q = self.out_proj(z_q)
if self.stride > 1:
z_e = z_e.repeat_interleave(self.stride, dim=-1)
z_q = z_q.repeat_interleave(self.stride, dim=-1)
indices = indices.repeat_interleave(self.stride, dim=-1)
if self.mlp is not None:
z_q = self.mlp(z_q)
return z_q, indices, z_e
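# When stride > 1 the latent is average-pooled before quantization and the
# codes/outputs are repeat-interleaved back to the input rate, i.e. this
# quantizer runs at a coarser temporal resolution (SNAC-style strides).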
class ResidualFiniteScalarQuantize(nn.Module):
def __init__(
self,
*,
latent_dim: int,
n_quantizers: int,
levels: list[int],
strides: list[int] | None = None,
quantizer_dropout: float = 0.0,
mlp: bool = False,
):
super().__init__()
self.n_quantizers = n_quantizers
self.quantizer_dropout = quantizer_dropout
strides = [1] * n_quantizers if strides is None else strides
assert (
len(strides) == n_quantizers
), "Strides must be provided for each codebook"
scales = []
quantizers = []
levels_tensor = torch.tensor(levels, dtype=torch.float32)
for i in range(n_quantizers):
scales.append((levels_tensor - 1) ** -i)
quantizers.append(
FiniteScalarQuantize(
latent_dim=latent_dim, levels=levels, stride=strides[i], mlp=mlp
)
)
self.quantizers = nn.ModuleList(quantizers)
self.register_buffer("scales", torch.stack(scales), persistent=False)
codebooks = [
quantizer.quantize.implicit_codebook for quantizer in self.quantizers
]
self.codebooks = torch.stack(codebooks, dim=0)
def from_indices(self, indices: Tensor):
B, Q, T = indices.size()
z_q = 0.0
for i, quantizer in enumerate(self.quantizers):
z_q_i = quantizer.from_indices(indices[:, i])
z_q = z_q + z_q_i
return z_q
def forward(self, z: Tensor, n_quantizers: int | None = None):
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
the corresponding codebook vectors
Parameters
----------
z : Tensor[B x D x T]
n_quantizers : int, optional
No. of quantizers to use
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
Note: if `self.quantizer_dropout` is True, this argument is ignored
when in training mode, and a random number of quantizers is used.
Returns
-------
dict
A dictionary with the following keys:
"z" : Tensor[B x D x T]
Quantized continuous representation of input
"codes" : Tensor[B x N x T]
Codebook indices for each codebook
(quantized discrete representation of input)
"latents" : Tensor[B x N*D x T]
Projected latents (continuous representation of input before quantization)
"""
B = z.shape[0]
z_q = 0
residual = z
indices = []
latents = []
if n_quantizers is None:
n_quantizers = self.n_quantizers
if self.training:
n_quantizers = torch.ones((B,)) * self.n_quantizers + 1
dropout = torch.randint(1, self.n_quantizers + 1, (B,))
n_dropout = int(B * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(z.device)
for i, quantizer in enumerate(self.quantizers):
if not self.training and i >= n_quantizers:
break
z_q_i, indices_i, z_e_i = quantizer(residual)
residual = residual - z_q_i.detach()
mask = torch.full((B,), fill_value=i, device=z.device) < n_quantizers
z_q = z_q + z_q_i * mask[:, None, None]
indices.append(indices_i)
latents.append(z_e_i)
indices = torch.stack(indices, dim=1)
latents = torch.cat(latents, dim=1)
return z_q, indices, latents
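# Residual FSQ in one line: stage i quantizes what the previous stages failed
# to reconstruct, residual_i = z - sum_{j < i} z_q_j, so summing the per-stage
# outputs (as `from_indices` does) yields progressively finer reconstructions.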
class FluacConfig(BaseModel):
sample_rate: int = 44100
codebook_size: int | None = None
encoder_dim: int = 64
encoder_rates: list[int] = [2, 4, 8, 8]
quantizer_strides: list[int] | None = None # SNAC style strides
n_quantizers: int = 1
fsq_levels: list[int] | None = [8, 5, 5, 5] # 1000
decoder_dim: int = 1536
decoder_rates: list[int] = [8, 8, 4, 2]
@property
def hop_length(self) -> int:
return math.prod(self.encoder_rates)
@property
def latent_dim(self) -> int:
return self.encoder_dim * (2 ** len(self.encoder_rates))
@property
def effective_codebook_size(self) -> int:
return math.prod(self.fsq_levels)
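# With the defaults above: hop_length = 2 * 4 * 8 * 8 = 512, latent_dim =
# 64 * 2**4 = 1024, and effective_codebook_size = 8 * 5 * 5 * 5 = 1000 codes
# per quantizer (hence the "# 1000" note on fsq_levels).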
class Fluac(nn.Module):
Q9_22KHZ = "fluac-22hz-22khz.pt"
def __init__(self, config: FluacConfig):
super().__init__()
self.config = config
self.encoder = Encoder(
config.encoder_dim, config.encoder_rates, config.latent_dim
)
self.quantizer = ResidualFiniteScalarQuantize(
latent_dim=config.latent_dim,
n_quantizers=config.n_quantizers,
levels=config.fsq_levels,
strides=config.quantizer_strides,
)
self.decoder = Decoder(
config.latent_dim,
config.decoder_dim,
config.decoder_rates,
)
self.apply(init_weights)
@staticmethod
def from_pretrained(name: str = Q9_22KHZ):
if path.exists(name):
checkpoint_path = name
else:
from huggingface_hub import hf_hub_download
checkpoint_path = hf_hub_download(
"fluxions/vui",
name,
)
checkpoint = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
config = checkpoint["config"]
if "model" in config:
model_config = FluacConfig(**config["model"])
else:
model_config = FluacConfig(**config)
generator = Fluac(model_config).eval()
ckpt = decompile_state_dict(checkpoint["generator"])
generator.load_state_dict(ckpt)
return generator
def pad(self, waveform: Tensor):
T = waveform.size(-1)
right_pad = math.ceil(T / self.config.hop_length) * self.config.hop_length - T
waveform = F.pad(waveform, (0, right_pad))
return waveform
@torch.inference_mode()
def from_indices(self, indices: Tensor):
z_q = self.quantizer.from_indices(indices)
waveform = self.decoder(z_q)
return waveform
@torch.inference_mode()
def encode(self, waveforms: Tensor, n_quantizers: int | None = None):
        # Flatten to a single mono waveform and reshape to 3 dims: (1, 1, T)
waveforms = waveforms.flatten()[None][None]
waveforms = self.pad(waveforms)
B, C, T = waveforms.size()
z = self.encoder(waveforms)
z_q, codes, latents = self.quantizer(z, n_quantizers=n_quantizers)
return codes
def forward(self, waveforms: Tensor, n_quantizers: int | None = None):
B, C, T = waveforms.size()
waveforms = self.pad(waveforms)
z = self.encoder(waveforms)
z_q, codes, latents = self.quantizer(z, n_quantizers=n_quantizers)
recons = self.decoder(z_q)
recons = recons[..., :T]
return {
"recons": recons,
"codes": codes,
}
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
    @property
    def hz(self):
        # Latent frame rate: audio sample rate divided by the encoder hop length
        return self.config.sample_rate / math.prod(self.config.encoder_rates)
if __name__ == "__main__":
codec = Fluac.from_pretrained(Fluac.Q9_22KHZ)
print(codec.config)
wav = torch.rand(1, 1, 22050)
wav = codec.pad(wav)
codes = codec.encode(wav)
    print("codes:", tuple(codes.shape))
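    # Round-trip sketch (assumes the checkpoint above loaded): decode the
    # codes back to a waveform via the residual quantizer and decoder.
    recon = codec.from_indices(codes)
    print("recon:", tuple(recon.shape))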