|
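"""Fluac: a convolutional audio codec with residual finite scalar quantization.

An Encoder downsamples the waveform by prod(encoder_rates), a stack of FSQ

quantizers discretizes the latents residually, and a Decoder reconstructs the

audio. The FSQ scheme follows "Finite Scalar Quantization: VQ-VAE Made Simple"

(Mentzer et al., 2023).

"""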
import math |
|
from contextlib import nullcontext |
|
from functools import partial, wraps |
|
from os import path |
|
from typing import List, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import pack, rearrange, unpack |
|
from einops.layers.torch import Rearrange |
|
from pydantic import BaseModel |
|
from torch import Tensor, int32 |
|
from torch.amp import autocast |
|
from torch.nn import Module |
|
from torch.nn.utils.parametrizations import weight_norm |
|
|
|
from vui.utils import decompile_state_dict |
|
|
|
|
|
def exists(v): |
|
return v is not None |
|
|
|
|
|
def default(*args): |
|
for arg in args: |
|
if exists(arg): |
|
return arg |
|
return None |
|
|
|
|
|
def maybe(fn): |
|
@wraps(fn) |
|
def inner(x, *args, **kwargs): |
|
if not exists(x): |
|
return x |
|
return fn(x, *args, **kwargs) |
|
|
|
return inner |
|
|
|
|
|
def pack_one(t, pattern): |
|
return pack([t], pattern) |
|
|
|
|
|
def unpack_one(t, ps, pattern): |
|
return unpack(t, ps, pattern)[0] |
|
|
|
|
|
def round_ste(z: Tensor) -> Tensor: |
|
"""Round with straight through gradients.""" |
|
zhat = z.round() |
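    # forward value is the rounded zhat; the detach routes gradients straight through z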
|
return z + (zhat - z).detach() |
|
|
|
|
|
class FSQ(Module): |
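    """Finite Scalar Quantization: each latent channel is bounded and rounded

    to one of `levels[i]` values, giving an implicit codebook of prod(levels)

    entries with no codebook loss to train.

    A minimal usage sketch (illustrative shapes only):

        fsq = FSQ(levels=[8, 5, 5, 5])   # 8 * 5 * 5 * 5 = 1000 codes

        z = torch.randn(2, 4, 100)       # (batch, len(levels), time)

        out, indices = fsq(z)            # out: (2, 4, 100), indices: (2, 100)

    """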
|
def __init__( |
|
self, |
|
levels: List[int], |
|
dim: int | None = None, |
|
num_codebooks: int = 1, |
|
keep_num_codebooks_dim: bool | None = None, |
|
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64), |
|
channel_first: bool = True, |
|
projection_has_bias: bool = True, |
|
        return_indices: bool = True,
|
force_quantization_f32: bool = True, |
|
): |
|
super().__init__() |
|
|
|
_levels = torch.tensor(levels, dtype=int32) |
|
self.register_buffer("_levels", _levels, persistent=False) |
|
|
|
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) |
|
self.register_buffer("_basis", _basis, persistent=False) |
|
|
|
codebook_dim = len(levels) |
|
self.codebook_dim = codebook_dim |
|
|
|
effective_codebook_dim = codebook_dim * num_codebooks |
|
self.num_codebooks = num_codebooks |
|
self.effective_codebook_dim = effective_codebook_dim |
|
|
|
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) |
|
assert not (num_codebooks > 1 and not keep_num_codebooks_dim) |
|
self.keep_num_codebooks_dim = keep_num_codebooks_dim |
|
|
|
self.dim = default(dim, len(_levels) * num_codebooks) |
|
|
|
self.channel_first = channel_first |
|
|
|
has_projections = self.dim != effective_codebook_dim |
|
self.project_in = ( |
|
nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias) |
|
if has_projections |
|
else nn.Identity() |
|
) |
|
self.project_out = ( |
|
nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias) |
|
if has_projections |
|
else nn.Identity() |
|
) |
|
|
|
self.has_projections = has_projections |
|
|
|
self.return_indices = return_indices |
|
if return_indices: |
|
self.codebook_size = self._levels.prod().item() |
|
implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size)) |
|
self.register_buffer( |
|
"implicit_codebook", implicit_codebook, persistent=False |
|
) |
|
|
|
self.allowed_dtypes = allowed_dtypes |
|
self.force_quantization_f32 = force_quantization_f32 |
|
|
|
def bound(self, z, eps: float = 1e-3): |
|
"""Bound `z`, an array of shape (..., d).""" |
|
half_l = (self._levels - 1) * (1 + eps) / 2 |
|
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) |
|
shift = (offset / half_l).atanh() |
|
return (z + shift).tanh() * half_l - offset |
|
|
|
def quantize(self, z): |
|
"""Quantizes z, returns quantized zhat, same shape as z.""" |
|
quantized = round_ste(self.bound(z)) |
|
half_width = self._levels // 2 |
|
return quantized / half_width |
|
|
|
def _scale_and_shift(self, zhat_normalized): |
|
half_width = self._levels // 2 |
|
return (zhat_normalized * half_width) + half_width |
|
|
|
def _scale_and_shift_inverse(self, zhat): |
|
half_width = self._levels // 2 |
|
return (zhat - half_width) / half_width |
|
|
|
def _indices_to_codes(self, indices): |
|
level_indices = self.indices_to_level_indices(indices) |
|
codes = self._scale_and_shift_inverse(level_indices) |
|
return codes |
|
|
|
def codes_to_indices(self, zhat): |
|
"""Converts a `code` to an index in the codebook.""" |
|
assert zhat.shape[-1] == self.codebook_dim |
|
zhat = self._scale_and_shift(zhat) |
|
return (zhat * self._basis).sum(dim=-1).to(int32) |
|
|
|
def indices_to_level_indices(self, indices): |
|
"""Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings""" |
|
indices = rearrange(indices, "... -> ... 1") |
|
codes_non_centered = (indices // self._basis) % self._levels |
|
return codes_non_centered |
|
|
|
def indices_to_codes(self, indices): |
|
"""Inverse of `codes_to_indices`.""" |
|
assert exists(indices) |
|
|
|
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) |
|
|
|
codes = self._indices_to_codes(indices) |
|
|
|
if self.keep_num_codebooks_dim: |
|
codes = rearrange(codes, "... c d -> ... (c d)") |
|
|
|
codes = self.project_out(codes) |
|
|
|
if is_img_or_video or self.channel_first: |
|
codes = rearrange(codes, "b ... d -> b d ...") |
|
|
|
return codes |
|
|
|
def forward(self, z: Tensor): |
|
""" |
|
einstein notation |
|
b - batch |
|
n - sequence (or flattened spatial dimensions) |
|
d - feature dimension |
|
        c - number of codebooks
|
""" |
|
device_type = z.device.type |
|
|
|
        with autocast(device_type=device_type, enabled=False):
|
if self.channel_first: |
|
z = rearrange(z, "b d ... -> b ... d") |
|
z, ps = pack_one(z, "b * d") |
|
|
|
assert ( |
|
z.shape[-1] == self.dim |
|
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" |
|
|
|
z = self.project_in(z) |
|
|
|
z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) |
|
|
|
force_f32 = self.force_quantization_f32 |
|
quantization_context = ( |
|
partial(autocast, device_type=device_type, enabled=False) |
|
if force_f32 |
|
else nullcontext |
|
) |
|
|
|
with quantization_context(): |
|
orig_dtype = z.dtype |
|
|
|
if force_f32 and orig_dtype not in self.allowed_dtypes: |
|
z = z.float() |
|
|
|
codes = self.quantize(z) |
|
|
|
indices = None |
|
|
|
if self.return_indices: |
|
indices = self.codes_to_indices(codes) |
|
|
|
codes = rearrange(codes, "b n c d -> b n (c d)") |
|
|
|
codes = codes.type(orig_dtype) |
|
|
|
out = self.project_out(codes) |
|
|
|
if self.channel_first: |
|
out = unpack_one(out, ps, "b * d") |
|
out = rearrange(out, "b ... d -> b d ...") |
|
|
|
indices = maybe(unpack_one)(indices, ps, "b * c") |
|
|
|
if not self.keep_num_codebooks_dim and self.return_indices: |
|
indices = maybe(rearrange)(indices, "... 1 -> ...") |
|
|
|
return out, indices |
|
|
|
|
|
def WNConv1d(*args, **kwargs): |
|
return weight_norm(nn.Conv1d(*args, **kwargs)) |
|
|
|
|
|
def WNConvTranspose1d(*args, **kwargs): |
|
return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) |
|
|
|
|
|
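# Snake activation: x + (1/alpha) * sin^2(alpha * x); the 1e-9 guards against

# division by zero for very small alpha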
@torch.jit.script |
|
def snake(x, alpha): |
|
shape = x.shape |
|
x = x.reshape(shape[0], shape[1], -1) |
|
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) |
|
x = x.reshape(shape) |
|
return x |
|
|
|
|
|
class Snake1d(nn.Module): |
|
def __init__(self, channels): |
|
super().__init__() |
|
self.alpha = nn.Parameter(torch.ones(1, channels, 1)) |
|
|
|
def forward(self, x): |
|
return snake(x, self.alpha) |
|
|
|
|
|
def init_weights(m):

    if isinstance(m, nn.Conv1d):

        nn.init.trunc_normal_(m.weight, std=0.02)

        if m.bias is not None:

            nn.init.constant_(m.bias, 0)
|
|
|
|
|
class ResidualUnit(nn.Module): |
|
def __init__(self, dim: int = 16, dilation: int = 1): |
|
super().__init__() |
|
pad = ((7 - 1) * dilation) // 2 |
|
self.block = nn.Sequential( |
|
Snake1d(dim), |
|
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), |
|
Snake1d(dim), |
|
WNConv1d(dim, dim, kernel_size=1), |
|
) |
|
|
|
def forward(self, x): |
|
y = self.block(x) |
|
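        # center-trim x to match y before the residual add (a no-op for "same" padding)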
pad = (x.shape[-1] - y.shape[-1]) // 2 |
|
if pad > 0: |
|
x = x[..., pad:-pad] |
|
return x + y |
|
|
|
|
|
class EncoderBlock(nn.Module): |
|
def __init__(self, dim: int = 16, stride: int = 1): |
|
super().__init__() |
|
self.block = nn.Sequential( |
|
ResidualUnit(dim // 2, dilation=1), |
|
ResidualUnit(dim // 2, dilation=3), |
|
ResidualUnit(dim // 2, dilation=9), |
|
Snake1d(dim // 2), |
|
WNConv1d( |
|
dim // 2, |
|
dim, |
|
kernel_size=2 * stride, |
|
stride=stride, |
|
padding=math.ceil(stride / 2), |
|
), |
|
) |
|
|
|
def forward(self, x): |
|
return self.block(x) |
|
|
|
|
|
class Encoder(nn.Module): |
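    """Strided convolutional encoder: channels double at every EncoderBlock,

    so the total downsampling factor is prod(strides)."""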
|
def __init__( |
|
self, |
|
d_model: int = 64, |
|
strides: list = [2, 4, 8, 8], |
|
d_latent: int = 64, |
|
): |
|
super().__init__() |
|
|
|
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] |
|
|
|
|
|
for stride in strides: |
|
d_model *= 2 |
|
self.block += [EncoderBlock(d_model, stride=stride)] |
|
|
|
|
|
self.block += [ |
|
Snake1d(d_model), |
|
WNConv1d(d_model, d_latent, kernel_size=3, padding=1), |
|
] |
|
|
|
|
|
self.block = nn.Sequential(*self.block) |
|
self.enc_dim = d_model |
|
|
|
def forward(self, x): |
|
return self.block(x) |
|
|
|
|
|
class DecoderBlock(nn.Module): |
|
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): |
|
super().__init__() |
|
self.block = nn.Sequential( |
|
Snake1d(input_dim), |
|
WNConvTranspose1d( |
|
input_dim, |
|
output_dim, |
|
kernel_size=2 * stride, |
|
stride=stride, |
|
padding=math.ceil(stride / 2), |
|
), |
|
ResidualUnit(output_dim, dilation=1), |
|
ResidualUnit(output_dim, dilation=3), |
|
ResidualUnit(output_dim, dilation=9), |
|
) |
|
|
|
def forward(self, x): |
|
return self.block(x) |
|
|
|
|
|
class Decoder(nn.Module): |
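    """Mirror of the Encoder: transposed-conv blocks halve the channel count

    at each rate, followed by a final conv and tanh."""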
|
def __init__( |
|
self, |
|
input_channel: int, |
|
channels: int, |
|
rates: list[int], |
|
d_out: int = 1, |
|
): |
|
super().__init__() |
|
|
|
|
|
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] |
|
|
|
|
|
for i, stride in enumerate(rates): |
|
input_dim = channels // 2**i |
|
output_dim = channels // 2 ** (i + 1) |
|
layers += [DecoderBlock(input_dim, output_dim, stride)] |
|
|
|
|
|
layers += [ |
|
Snake1d(output_dim), |
|
WNConv1d(output_dim, d_out, kernel_size=7, padding=3), |
|
nn.Tanh(), |
|
] |
|
|
|
self.model = nn.Sequential(*layers) |
|
|
|
|
|
def forward(self, z: Tensor): |
|
return self.model(z) |
|
|
|
|
|
class FiniteScalarQuantize(nn.Module): |
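    """Single FSQ stage: project latents down to len(levels) channels, quantize

    with FSQ, then project back up. `stride > 1` quantizes at a lower temporal

    rate (avg-pool down, repeat-interleave back up)."""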
|
def __init__( |
|
self, latent_dim: int, levels: list[int], *, stride: int = 1, mlp: bool = False |
|
): |
|
super().__init__() |
|
|
|
self.stride = stride |
|
|
|
codebook_dim = len(levels) |
|
|
|
self.in_proj = WNConv1d(latent_dim, codebook_dim, kernel_size=1) |
|
self.quantize = FSQ(levels=levels, channel_first=True) |
|
self.out_proj = WNConv1d(codebook_dim, latent_dim, kernel_size=1) |
|
|
|
if mlp: |
|
self.mlp = nn.Sequential( |
|
Rearrange("B C T -> B T C"), |
|
nn.Linear(latent_dim, 4 * latent_dim), |
|
nn.GELU(), |
|
nn.Linear(4 * latent_dim, latent_dim), |
|
Rearrange("B T C -> B C T"), |
|
) |
|
else: |
|
self.mlp = None |
|
|
|
def from_indices(self, indices: Tensor): |
|
B, T = indices.size() |
|
z_q = self.quantize.indices_to_codes(indices) |
|
z_q = self.out_proj(z_q) |
|
return z_q |
|
|
|
def forward(self, z: Tensor, *args): |
|
if self.stride > 1: |
|
z = F.avg_pool1d(z, self.stride, stride=self.stride) |
|
|
|
z_e = self.in_proj(z) |
|
|
|
z_q, indices = self.quantize(z_e) |
|
|
|
z_q = self.out_proj(z_q) |
|
|
|
if self.stride > 1: |
|
z_e = z_e.repeat_interleave(self.stride, dim=-1) |
|
z_q = z_q.repeat_interleave(self.stride, dim=-1) |
|
indices = indices.repeat_interleave(self.stride, dim=-1) |
|
|
|
if self.mlp is not None: |
|
z_q = self.mlp(z_q) |
|
|
|
return z_q, indices, z_e |
|
|
|
|
|
class ResidualFiniteScalarQuantize(nn.Module): |
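    """Residual FSQ: each stage quantizes the residual left by the previous

    stages, analogous to residual VQ but with FSQ codebooks."""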
|
def __init__( |
|
self, |
|
*, |
|
latent_dim: int, |
|
n_quantizers: int, |
|
levels: list[int], |
|
strides: list[int] | None = None, |
|
quantizer_dropout: float = 0.0, |
|
mlp: bool = False, |
|
): |
|
super().__init__() |
|
|
|
self.n_quantizers = n_quantizers |
|
self.quantizer_dropout = quantizer_dropout |
|
|
|
strides = [1] * n_quantizers if strides is None else strides |
|
|
|
assert ( |
|
len(strides) == n_quantizers |
|
), "Strides must be provided for each codebook" |
|
|
|
scales = [] |
|
quantizers = [] |
|
levels_tensor = torch.tensor(levels, dtype=torch.float32) |
|
|
|
for i in range(n_quantizers): |
|
scales.append((levels_tensor - 1) ** -i) |
|
quantizers.append( |
|
FiniteScalarQuantize( |
|
latent_dim=latent_dim, levels=levels, stride=strides[i], mlp=mlp |
|
) |
|
) |
|
|
|
self.quantizers = nn.ModuleList(quantizers) |
|
|
|
self.register_buffer("scales", torch.stack(scales), persistent=False) |
|
|
|
codebooks = [ |
|
quantizer.quantize.implicit_codebook for quantizer in self.quantizers |
|
] |
|
        self.register_buffer("codebooks", torch.stack(codebooks, dim=0), persistent=False)
|
|
|
def from_indices(self, indices: Tensor): |
|
B, Q, T = indices.size() |
|
|
|
z_q = 0.0 |
|
|
|
for i, quantizer in enumerate(self.quantizers): |
|
z_q_i = quantizer.from_indices(indices[:, i]) |
|
z_q = z_q + z_q_i |
|
|
|
return z_q |
|
|
|
def forward(self, z: Tensor, n_quantizers: int | None = None): |
|
"""Quantized the input tensor using a fixed set of `n` codebooks and returns |
|
the corresponding codebook vectors |
|
Parameters |
|
---------- |
|
z : Tensor[B x D x T] |
|
n_quantizers : int, optional |
|
No. of quantizers to use |
|
(n_quantizers < self.n_codebooks ex: for quantizer dropout) |
|
Note: if `self.quantizer_dropout` is True, this argument is ignored |
|
when in training mode, and a random number of quantizers is used. |
|
Returns |
|
------- |
|
dict |
|
A dictionary with the following keys: |
|
|
|
"z" : Tensor[B x D x T] |
|
Quantized continuous representation of input |
|
"codes" : Tensor[B x N x T] |
|
Codebook indices for each codebook |
|
(quantized discrete representation of input) |
|
"latents" : Tensor[B x N*D x T] |
|
Projected latents (continuous representation of input before quantization) |
|
""" |
|
B = z.shape[0] |
|
z_q = 0 |
|
residual = z |
|
|
|
indices = [] |
|
latents = [] |
|
|
|
if n_quantizers is None: |
|
n_quantizers = self.n_quantizers |
|
|
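        # quantizer dropout: a random fraction of the batch trains with a random

        # number of quantizers; the +1 keeps every stage active for the rest of

        # the batch under the `i < n_quantizers` mask below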
|
if self.training: |
|
n_quantizers = torch.ones((B,)) * self.n_quantizers + 1 |
|
dropout = torch.randint(1, self.n_quantizers + 1, (B,)) |
|
n_dropout = int(B * self.quantizer_dropout) |
|
n_quantizers[:n_dropout] = dropout[:n_dropout] |
|
n_quantizers = n_quantizers.to(z.device) |
|
|
|
for i, quantizer in enumerate(self.quantizers): |
|
if not self.training and i >= n_quantizers: |
|
break |
|
|
|
z_q_i, indices_i, z_e_i = quantizer(residual) |
|
|
|
residual = residual - z_q_i.detach() |
|
|
|
mask = torch.full((B,), fill_value=i, device=z.device) < n_quantizers |
|
z_q = z_q + z_q_i * mask[:, None, None] |
|
|
|
indices.append(indices_i) |
|
latents.append(z_e_i) |
|
|
|
indices = torch.stack(indices, dim=1) |
|
latents = torch.cat(latents, dim=1) |
|
|
|
return z_q, indices, latents |
|
|
|
|
|
class FluacConfig(BaseModel): |
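    """Codec hyperparameters. With these defaults: hop_length = 2*4*8*8 = 512

    samples, latent_dim = 64 * 2**4 = 1024, and each quantizer's codebook has

    8*5*5*5 = 1000 entries."""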
|
sample_rate: int = 44100 |
|
|
|
codebook_size: int | None = None |
|
|
|
encoder_dim: int = 64 |
|
encoder_rates: list[int] = [2, 4, 8, 8] |
|
|
|
quantizer_strides: list[int] | None = None |
|
n_quantizers: int = 1 |
|
fsq_levels: list[int] | None = [8, 5, 5, 5] |
|
decoder_dim: int = 1536 |
|
decoder_rates: list[int] = [8, 8, 4, 2] |
|
|
|
@property |
|
def hop_length(self) -> int: |
|
return math.prod(self.encoder_rates) |
|
|
|
@property |
|
def latent_dim(self) -> int: |
|
return self.encoder_dim * (2 ** len(self.encoder_rates)) |
|
|
|
@property |
|
def effective_codebook_size(self) -> int: |
|
return math.prod(self.fsq_levels) |
|
|
|
|
|
class Fluac(nn.Module): |
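    """Encoder -> ResidualFiniteScalarQuantize -> Decoder audio codec.

    A minimal round-trip sketch (assumes the pretrained checkpoint is reachable):

        codec = Fluac.from_pretrained()

        codes = codec.encode(torch.randn(1, 1, 22050))  # (1, n_quantizers, T')

        recon = codec.from_indices(codes)

    """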
|
Q9_22KHZ = "fluac-22hz-22khz.pt" |
|
|
|
def __init__(self, config: FluacConfig): |
|
super().__init__() |
|
|
|
self.config = config |
|
|
|
self.encoder = Encoder( |
|
config.encoder_dim, config.encoder_rates, config.latent_dim |
|
) |
|
|
|
self.quantizer = ResidualFiniteScalarQuantize( |
|
latent_dim=config.latent_dim, |
|
n_quantizers=config.n_quantizers, |
|
levels=config.fsq_levels, |
|
strides=config.quantizer_strides, |
|
) |
|
|
|
self.decoder = Decoder( |
|
config.latent_dim, |
|
config.decoder_dim, |
|
config.decoder_rates, |
|
) |
|
|
|
self.apply(init_weights) |
|
|
|
@staticmethod |
|
def from_pretrained(name: str = Q9_22KHZ): |
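        # resolve a local checkpoint path first, otherwise fetch from the HF hub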
|
if path.exists(name): |
|
checkpoint_path = name |
|
else: |
|
from huggingface_hub import hf_hub_download |
|
|
|
checkpoint_path = hf_hub_download( |
|
"fluxions/vui", |
|
name, |
|
) |
|
|
|
checkpoint = torch.load(checkpoint_path, weights_only=True, map_location="cpu") |
|
config = checkpoint["config"] |
|
if "model" in config: |
|
model_config = FluacConfig(**config["model"]) |
|
else: |
|
model_config = FluacConfig(**config) |
|
|
|
generator = Fluac(model_config).eval() |
|
ckpt = decompile_state_dict(checkpoint["generator"]) |
|
generator.load_state_dict(ckpt) |
|
return generator |
|
|
|
def pad(self, waveform: Tensor): |
|
T = waveform.size(-1) |
|
right_pad = math.ceil(T / self.config.hop_length) * self.config.hop_length - T |
|
waveform = F.pad(waveform, (0, right_pad)) |
|
return waveform |
|
|
|
@torch.inference_mode() |
|
def from_indices(self, indices: Tensor): |
|
z_q = self.quantizer.from_indices(indices) |
|
waveform = self.decoder(z_q) |
|
return waveform |
|
|
|
@torch.inference_mode() |
|
def encode(self, waveforms: Tensor, n_quantizers: int | None = None): |
|
|
|
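        # collapse any input layout to a single mono waveform of shape (1, 1, T)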
        waveforms = waveforms.flatten()[None, None]
|
waveforms = self.pad(waveforms) |
|
B, C, T = waveforms.size() |
|
z = self.encoder(waveforms) |
|
z_q, codes, latents = self.quantizer(z, n_quantizers=n_quantizers) |
|
return codes |
|
|
|
def forward(self, waveforms: Tensor, n_quantizers: int | None = None): |
|
B, C, T = waveforms.size() |
|
waveforms = self.pad(waveforms) |
|
z = self.encoder(waveforms) |
|
z_q, codes, latents = self.quantizer(z, n_quantizers=n_quantizers) |
|
|
|
recons = self.decoder(z_q) |
|
recons = recons[..., :T] |
|
return { |
|
"recons": recons, |
|
"codes": codes, |
|
} |
|
|
|
@property |
|
def device(self): |
|
return next(self.parameters()).device |
|
|
|
@property |
|
def dtype(self): |
|
return next(self.parameters()).dtype |
|
|
|
@property |
|
def hz(self): |
|
        return self.config.sample_rate / math.prod(self.config.encoder_rates)
|
|
|
|
|
if __name__ == "__main__": |
|
codec = Fluac.from_pretrained(Fluac.Q9_22KHZ) |
|
print(codec.config) |
|
wav = torch.rand(1, 1, 22050) |
|
wav = codec.pad(wav) |
|
codes = codec.encode(wav) |
|
    print(codes.shape)
|
|