File size: 21,200 Bytes
88afac1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
import math
from contextlib import nullcontext
from functools import partial, wraps
from os import path
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, unpack
from einops.layers.torch import Rearrange
from pydantic import BaseModel
from torch import Tensor, int32
from torch.amp import autocast
from torch.nn import Module
from torch.nn.utils.parametrizations import weight_norm

from vui.utils import decompile_state_dict


def exists(v):
    """Return ``True`` when ``v`` is anything other than ``None``."""
    if v is None:
        return False
    return True


def default(*args):
    """Return the first argument that is not ``None``.

    Falls back to ``None`` when every argument is ``None`` or when no
    arguments are given.
    """
    return next((candidate for candidate in args if exists(candidate)), None)


def maybe(fn):
    """Wrap ``fn`` so that a ``None`` first argument short-circuits.

    The returned callable forwards all arguments to ``fn`` when its first
    positional argument is not ``None``; otherwise it returns that ``None``
    without calling ``fn``.
    """

    @wraps(fn)
    def inner(x, *args, **kwargs):
        return fn(x, *args, **kwargs) if exists(x) else x

    return inner


def pack_one(t, pattern):
    """Pack the single tensor ``t`` with einops ``pack``.

    Returns exactly what ``pack`` returns for a one-element list; callers in
    this file unpack it as ``(packed_tensor, packed_shapes)``.
    """
    tensors = [t]
    return pack(tensors, pattern)


def unpack_one(t, ps, pattern):
    """Unpack ``t`` with einops ``unpack`` and return the first tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]


def round_ste(z: Tensor) -> Tensor:
    """Round ``z`` element-wise while passing gradients straight through.

    The forward value is ``z.round()``; the rounding correction is detached,
    so the gradient with respect to ``z`` is that of the identity.
    """
    rounding_correction = (torch.round(z) - z).detach()
    return z + rounding_correction


class FSQ(Module):
    """Finite scalar quantizer.

    Each code dimension is squashed into a bounded range with ``tanh``,
    rounded to an integer with straight-through gradients, and rescaled by
    ``levels // 2``. A code vector maps to a single integer index by treating
    its per-dimension level indices as digits of a mixed-radix number whose
    place values are stored in ``_basis``.
    """

    def __init__(
        self,
        levels: List[int],
        dim: int | None = None,
        num_codebooks: int = 1,
        keep_num_codebooks_dim: bool | None = None,
        allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
        channel_first: bool = True,
        projection_has_bias: bool = True,
        return_indices=True,
        force_quantization_f32: bool = True,
    ):
        """Build the quantizer.

        Args:
            levels: Number of quantization levels for each code dimension;
                ``len(levels)`` is the per-codebook dimension.
            dim: Feature size of the input/output. Defaults to
                ``len(levels) * num_codebooks``; when it differs from that
                value, linear input and output projections are created.
            num_codebooks: Number of codebooks the projected features are
                split into.
            keep_num_codebooks_dim: Whether returned indices keep the
                codebook axis. Defaults to ``num_codebooks > 1`` and must be
                true when more than one codebook is used.
            allowed_dtypes: Dtypes quantized as-is when
                ``force_quantization_f32`` is set; any other dtype is cast to
                float32 before quantization.
            channel_first: Whether inputs are laid out ``(b, d, ...)`` rather
                than ``(b, n, d)``.
            projection_has_bias: Whether the optional projections use a bias.
            return_indices: Whether ``forward`` computes codebook indices.
                Also controls creation of ``codebook_size`` and
                ``implicit_codebook``.
            force_quantization_f32: Whether the quantization step runs with
                autocast disabled and non-allowed dtypes upcast to float32.
        """
        super().__init__()

        # Per-dimension level counts, kept as a non-persistent int32 buffer.
        _levels = torch.tensor(levels, dtype=int32)
        self.register_buffer("_levels", _levels, persistent=False)

        # Mixed-radix place values: basis[k] = prod(levels[:k]).
        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
        self.register_buffer("_basis", _basis, persistent=False)

        codebook_dim = len(levels)
        self.codebook_dim = codebook_dim

        effective_codebook_dim = codebook_dim * num_codebooks
        self.num_codebooks = num_codebooks
        self.effective_codebook_dim = effective_codebook_dim

        # With several codebooks the codebook axis cannot be squeezed away.
        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.dim = default(dim, len(_levels) * num_codebooks)

        self.channel_first = channel_first

        # Projections are only needed when the feature size differs from the
        # total code dimension; otherwise they are identities.
        has_projections = self.dim != effective_codebook_dim
        self.project_in = (
            nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
            if has_projections
            else nn.Identity()
        )
        self.project_out = (
            nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
            if has_projections
            else nn.Identity()
        )

        self.has_projections = has_projections

        self.return_indices = return_indices
        if return_indices:
            # Total number of distinct codes is the product of all levels.
            self.codebook_size = self._levels.prod().item()
            # Table of the normalized code vector for every index.
            implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
            self.register_buffer(
                "implicit_codebook", implicit_codebook, persistent=False
            )

        self.allowed_dtypes = allowed_dtypes
        self.force_quantization_f32 = force_quantization_f32

    def bound(self, z, eps: float = 1e-3):
        """Bound `z`, an array of shape (..., d).

        Applies a shifted ``tanh`` scaled by ``(levels - 1) * (1 + eps) / 2``.
        Dimensions with an even number of levels are additionally offset by
        0.5 (and the ``tanh`` input shifted accordingly).
        """
        half_l = (self._levels - 1) * (1 + eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z):
        """Quantizes z, returns quantized zhat, same shape as z.

        The bounded input is rounded with straight-through gradients and then
        divided by ``levels // 2``.
        """
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized):
        """Map normalized codes to level indices: ``zhat * hw + hw``."""
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat):
        """Map level indices back to normalized codes: ``(zhat - hw) / hw``."""
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def _indices_to_codes(self, indices):
        """Convert flat indices to normalized code vectors (no projection)."""
        level_indices = self.indices_to_level_indices(indices)
        codes = self._scale_and_shift_inverse(level_indices)
        return codes

    def codes_to_indices(self, zhat):
        """Converts a `code` to an index in the codebook.

        ``zhat`` must have ``codebook_dim`` as its last dimension. The result
        is the int32 sum of the per-dimension level indices weighted by
        ``_basis``.
        """
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat)
        return (zhat * self._basis).sum(dim=-1).to(int32)

    def indices_to_level_indices(self, indices):
        """Convert flat indices to the level index of each code dimension.

        Perhaps needed for a transformer with factorized embeddings. A
        trailing axis is added, so the result has one more dimension than
        ``indices``.
        """
        indices = rearrange(indices, "... -> ... 1")
        # Extract each mixed-radix digit: divide by place value, mod the radix.
        codes_non_centered = (indices // self._basis) % self._levels
        return codes_non_centered

    def indices_to_codes(self, indices):
        """Inverse of `codes_to_indices`.

        Also merges the codebook axis (when kept), applies ``project_out``
        and, for channel-first or image/video-like inputs, moves the feature
        axis to position 1.
        """
        assert exists(indices)

        # Three or more index dims (plus the codebook axis, when it is kept)
        # are treated as image/video data.
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        codes = self._indices_to_codes(indices)

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, "... c d -> ... (c d)")

        codes = self.project_out(codes)

        if is_img_or_video or self.channel_first:
            codes = rearrange(codes, "b ... d -> b d ...")

        return codes

    def forward(self, z: Tensor):
        """Quantize ``z`` and optionally return codebook indices.

        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension
        c - number of codebook dim

        Returns:
            A tuple ``(out, indices)``. ``out`` is the quantized tensor after
            ``project_out``, restored to the input layout when
            ``channel_first`` is set. ``indices`` is ``None`` when
            ``return_indices`` is false.
        """
        device_type = z.device.type

        # Autocast is disabled for the whole forward pass, projections included.
        with torch.autocast(device_type=device_type, enabled=False):
            if self.channel_first:
                # Move features last and flatten everything between batch and
                # features; `ps` remembers the flattened shape for later.
                z = rearrange(z, "b d ... -> b ... d")
                z, ps = pack_one(z, "b * d")

            assert (
                z.shape[-1] == self.dim
            ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"

            z = self.project_in(z)

            z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)

            # whether to force quantization step to be full precision or not

            force_f32 = self.force_quantization_f32
            quantization_context = (
                partial(autocast, device_type=device_type, enabled=False)
                if force_f32
                else nullcontext
            )

            with quantization_context():
                orig_dtype = z.dtype

                if force_f32 and orig_dtype not in self.allowed_dtypes:
                    z = z.float()

                codes = self.quantize(z)

                # returning indices could be optional

                indices = None

                if self.return_indices:
                    indices = self.codes_to_indices(codes)

                codes = rearrange(codes, "b n c d -> b n (c d)")

                # Cast back to the dtype the input had before any upcast.
                codes = codes.type(orig_dtype)

            # project out

            out = self.project_out(codes)

            # reconstitute image or video dimensions

            if self.channel_first:
                out = unpack_one(out, ps, "b * d")
                out = rearrange(out, "b ... d -> b d ...")

                # `maybe` leaves `indices` as None when they were not computed.
                indices = maybe(unpack_one)(indices, ps, "b * c")

            if not self.keep_num_codebooks_dim and self.return_indices:
                # Drop the singleton codebook axis.
                indices = maybe(rearrange)(indices, "... 1 -> ...")

            # return quantized output and indices

            return out, indices


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# TorchScript-compiled; per the original author this makes the model ~1.4x faster.
@torch.jit.script
def snake(x, alpha):
    """Snake activation: ``x + sin(alpha * x)**2 / (alpha + 1e-9)``.

    ``x`` is temporarily viewed as ``(shape[0], shape[1], -1)`` so that
    ``alpha`` broadcasts against it, then restored to its original shape.
    """
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    # The small constant keeps the reciprocal finite when alpha is zero.
    inverse_alpha = (alpha + 1e-9).reciprocal()
    flat = flat + inverse_alpha * torch.sin(alpha * flat).pow(2)
    return flat.reshape(original_shape)


class Snake1d(nn.Module):
    """Snake activation module with one learnable ``alpha`` per channel."""

    def __init__(self, channels):
        super().__init__()
        # Shaped (1, channels, 1) and initialised to ones.
        initial_alpha = torch.ones(1, channels, 1)
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, x):
        """Apply ``snake`` to ``x`` using this module's ``alpha``."""
        activated = snake(x, self.alpha)
        return activated


def init_weights(m):
    """Initialise ``nn.Conv1d`` modules; leave every other module untouched.

    Weights are drawn from a truncated normal with ``std=0.02`` and the bias,
    when the layer has one, is set to zero. Intended to be passed to
    ``nn.Module.apply``.

    Args:
        m: Any module visited by ``apply``.
    """
    if not isinstance(m, nn.Conv1d):
        return

    # NOTE(review): for convs wrapped by `weight_norm`, `m.weight` is
    # presumably recomputed from the parametrization, so this in-place init
    # may not persist — confirm.
    nn.init.trunc_normal_(m.weight, std=0.02)

    # A Conv1d built with `bias=False` has `bias is None`; skip it instead of
    # crashing inside `constant_`.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    """Residual block: Snake, dilated kernel-7 conv, Snake, 1x1 conv, plus skip."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        kernel_size = 7
        # Padding that offsets the length reduction of the dilated convolution.
        same_padding = (kernel_size - 1) * dilation // 2
        layers = [
            Snake1d(dim),
            WNConv1d(
                dim,
                dim,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=same_padding,
            ),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + block(x)``, centre-cropping ``x`` if the block shortened it."""
        branch = self.block(x)
        trim = (x.shape[-1] - branch.shape[-1]) // 2
        skip = x[..., trim:-trim] if trim > 0 else x
        return skip + branch


class EncoderBlock(nn.Module):
    """Three residual units at ``dim // 2`` channels, then a strided conv to ``dim``."""

    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        in_dim = dim // 2

        # Residual stack with dilations 1, 3 and 9.
        stages = [ResidualUnit(in_dim, dilation=rate) for rate in (1, 3, 9)]
        stages.append(Snake1d(in_dim))
        # Strided convolution that widens the channels from dim // 2 to dim.
        stages.append(
            WNConv1d(
                in_dim,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        )
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the residual stack and the strided convolution."""
        downsampled = self.block(x)
        return downsampled


class Encoder(nn.Module):
    """Convolutional encoder from a 1-channel signal to ``d_latent`` channels.

    The channel width doubles at every ``EncoderBlock``; the final width is
    exposed as ``enc_dim``.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        # NOTE: the default `strides` list is only iterated, never mutated.

        # Stem: lift the single input channel to `d_model` channels.
        layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # One EncoderBlock per stride, doubling the width each time.
        width = d_model
        for stride in strides:
            width = width * 2
            layers.append(EncoderBlock(width, stride=stride))

        # Head: activation followed by the projection to the latent size.
        layers.append(Snake1d(width))
        layers.append(WNConv1d(width, d_latent, kernel_size=3, padding=1))

        self.block = nn.Sequential(*layers)
        self.enc_dim = width

    def forward(self, x):
        """Encode ``x`` by running it through the sequential stack."""
        encoded = self.block(x)
        return encoded


class DecoderBlock(nn.Module):
    """Snake, strided transposed conv, then three residual units."""

    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        stages = [Snake1d(input_dim)]
        # Transposed convolution mirroring the encoder's strided convolution.
        stages.append(
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        )
        # Residual stack with dilations 1, 3 and 9 at the output width.
        stages.extend(ResidualUnit(output_dim, dilation=rate) for rate in (1, 3, 9))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the transposed convolution and the residual stack."""
        upsampled = self.block(x)
        return upsampled


class Decoder(nn.Module):
    """Convolutional decoder from latent features to a ``tanh``-bounded output.

    Args:
        input_channel: Number of channels of the incoming latent.
        channels: Width after the first convolution; halved by every
            ``DecoderBlock``.
        rates: Stride of each ``DecoderBlock``, in order. May be empty, in
            which case the output head is applied directly at ``channels``
            width.
        d_out: Number of output channels.
    """

    def __init__(
        self,
        input_channel: int,
        channels: int,
        rates: list[int],
        d_out: int = 1,
    ):
        super().__init__()

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Start from the stem width so the output head below stays well
        # defined when `rates` is empty (previously an UnboundLocalError).
        output_dim = channels

        # Add upsampling + MRF blocks; block i maps channels / 2**i to
        # channels / 2**(i + 1).
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    # @torch.compile(dynamic=True)
    def forward(self, z: Tensor):
        """Decode ``z`` by running it through the sequential stack."""
        return self.model(z)


class FiniteScalarQuantize(nn.Module):
    """One FSQ stage with 1x1 convolutions in and out of the code space.

    Args:
        latent_dim: Channel count of the latent being quantized.
        levels: FSQ levels; ``len(levels)`` is the code dimension.
        stride: When greater than 1, the latent is average-pooled by this
            factor before quantization and the results are repeated back to
            the original rate.
        mlp: Whether to apply a two-layer MLP (over the channel axis) to the
            quantized output.
    """

    def __init__(
        self, latent_dim: int, levels: list[int], *, stride: int = 1, mlp: bool = False
    ):
        super().__init__()

        self.stride = stride

        codebook_dim = len(levels)

        self.in_proj = WNConv1d(latent_dim, codebook_dim, kernel_size=1)
        self.quantize = FSQ(levels=levels, channel_first=True)
        self.out_proj = WNConv1d(codebook_dim, latent_dim, kernel_size=1)

        if mlp:
            # Linear layers act on the last axis, hence the transposes.
            self.mlp = nn.Sequential(
                Rearrange("B C T -> B T C"),
                nn.Linear(latent_dim, 4 * latent_dim),
                nn.GELU(),
                nn.Linear(4 * latent_dim, latent_dim),
                Rearrange("B T C -> B C T"),
            )
        else:
            self.mlp = None

    def from_indices(self, indices: Tensor):
        """Rebuild the quantized latent from codebook indices.

        Args:
            indices: 2-D index tensor, unpacked as ``(B, T)``.

        Returns:
            The latent produced by ``out_proj`` (and the MLP, when enabled),
            matching what ``forward`` returns as ``z_q`` for the same indices.
        """
        # The unpacking enforces a 2-D input, as in the original code.
        _batch, _frames = indices.size()
        z_q = self.quantize.indices_to_codes(indices)
        z_q = self.out_proj(z_q)

        # `forward` applies the optional MLP after the output projection;
        # decoding from indices must do the same to stay consistent with it.
        if self.mlp is not None:
            z_q = self.mlp(z_q)

        return z_q

    def forward(self, z: Tensor, *args):
        """Quantize ``z``.

        Args:
            z: Latent tensor, channels first (``B x D x T``).
            *args: Ignored.

        Returns:
            A tuple ``(z_q, indices, z_e)``: the quantized latent projected
            back to ``latent_dim``, the codebook indices, and the projected
            latent before quantization. With ``stride > 1`` all three are
            repeated along the last axis by ``stride``.
        """
        if self.stride > 1:
            z = F.avg_pool1d(z, self.stride, stride=self.stride)

        z_e = self.in_proj(z)  # z_e : (B x D x T)

        z_q, indices = self.quantize(z_e)
        z_q = self.out_proj(z_q)

        if self.stride > 1:
            # Undo the pooling by repeating every frame `stride` times.
            z_e = z_e.repeat_interleave(self.stride, dim=-1)
            z_q = z_q.repeat_interleave(self.stride, dim=-1)
            indices = indices.repeat_interleave(self.stride, dim=-1)

        if self.mlp is not None:
            z_q = self.mlp(z_q)

        return z_q, indices, z_e


class ResidualFiniteScalarQuantize(nn.Module):
    """Stack of ``FiniteScalarQuantize`` stages applied to successive residuals.

    Args:
        latent_dim: Channel count of the latent being quantized.
        n_quantizers: Number of quantizer stages.
        levels: FSQ levels shared by every stage.
        strides: Optional per-stage stride; defaults to 1 for every stage.
        quantizer_dropout: Fraction of each training batch that uses a
            randomly chosen number of quantizers.
        mlp: Forwarded to every ``FiniteScalarQuantize`` stage.
    """

    def __init__(
        self,
        *,
        latent_dim: int,
        n_quantizers: int,
        levels: list[int],
        strides: list[int] | None = None,
        quantizer_dropout: float = 0.0,
        mlp: bool = False,
    ):
        super().__init__()

        self.n_quantizers = n_quantizers
        self.quantizer_dropout = quantizer_dropout

        strides = [1] * n_quantizers if strides is None else strides

        assert (
            len(strides) == n_quantizers
        ), "Strides must be provided for each codebook"

        scales = []
        quantizers = []
        levels_tensor = torch.tensor(levels, dtype=torch.float32)

        for i in range(n_quantizers):
            # Per-stage scale (levels - 1) ** -i; stored as a buffer but not
            # read anywhere in this class.
            scales.append((levels_tensor - 1) ** -i)
            quantizers.append(
                FiniteScalarQuantize(
                    latent_dim=latent_dim, levels=levels, stride=strides[i], mlp=mlp
                )
            )

        self.quantizers = nn.ModuleList(quantizers)

        self.register_buffer("scales", torch.stack(scales), persistent=False)

        codebooks = [
            quantizer.quantize.implicit_codebook for quantizer in self.quantizers
        ]
        # NOTE(review): stored as a plain tensor attribute rather than a
        # buffer, so it presumably does not follow `.to(device)` — confirm
        # before using it on an accelerator.
        self.codebooks = torch.stack(codebooks, dim=0)

    def from_indices(self, indices: Tensor):
        """Rebuild the quantized latent from per-stage codebook indices.

        Args:
            indices: 3-D tensor unpacked as ``(B, Q, T)``. ``Q`` may be
                smaller than the number of quantizers (for example when the
                codes were produced with a reduced ``n_quantizers``); only the
                first ``Q`` stages are then used. Rows beyond the number of
                quantizers are ignored.

        Returns:
            The sum of the latents reconstructed by the stages used.
        """
        # The unpacking enforces a 3-D input, as in the original code.
        _batch, _n_codebooks, _frames = indices.size()

        z_q = 0.0

        # `zip` stops at the shorter sequence, so fewer code rows than
        # quantizers no longer raises an IndexError.
        for quantizer, stage_indices in zip(self.quantizers, indices.unbind(dim=1)):
            z_q = z_q + quantizer.from_indices(stage_indices)

        return z_q

    def forward(self, z: Tensor, n_quantizers: int | None = None):
        """Quantize ``z`` with a cascade of residual quantizers.

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use in eval mode
            (``n_quantizers < self.n_quantizers``). Defaults to
            ``self.n_quantizers``.
            Note: in training mode this argument is ignored. Every stage is
            run, and the first ``int(B * self.quantizer_dropout)`` batch items
            only accumulate a random number of stages.

        Returns
        -------
        tuple
            ``(z_q, indices, latents)`` where

            z_q : Tensor[B x D x T]
                Quantized continuous representation of input
            indices : Tensor[B x N x T]
                Codebook indices for each quantizer that was run
                (quantized discrete representation of input)
            latents : Tensor
                Projected latents of every stage (continuous representation
                before quantization), concatenated along dim 1
        """
        B = z.shape[0]
        z_q = 0
        residual = z

        indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_quantizers

        if self.training:
            # Per-item quantizer budget: n_quantizers + 1 keeps every stage,
            # while the first `n_dropout` items get a random budget in
            # [1, n_quantizers].
            n_quantizers = torch.ones((B,)) * self.n_quantizers + 1
            dropout = torch.randint(1, self.n_quantizers + 1, (B,))
            n_dropout = int(B * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if not self.training and i >= n_quantizers:
                break

            z_q_i, indices_i, z_e_i = quantizer(residual)

            # The next stage quantizes what this stage failed to capture.
            residual = residual - z_q_i.detach()

            # Only add this stage for batch items whose budget exceeds i.
            mask = torch.full((B,), fill_value=i, device=z.device) < n_quantizers
            z_q = z_q + z_q_i * mask[:, None, None]

            indices.append(indices_i)
            latents.append(z_e_i)

        indices = torch.stack(indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, indices, latents


class FluacConfig(BaseModel):
    """Hyper-parameters describing a ``Fluac`` codec."""

    sample_rate: int = 44100

    codebook_size: int | None = None

    encoder_dim: int = 64
    encoder_rates: list[int] = [2, 4, 8, 8]

    quantizer_strides: list[int] | None = None  # SNAC style strides
    n_quantizers: int = 1
    fsq_levels: list[int] | None = [8, 5, 5, 5]  # 1000
    decoder_dim: int = 1536
    decoder_rates: list[int] = [8, 8, 4, 2]

    @property
    def hop_length(self) -> int:
        """Product of all encoder rates."""
        hop = 1
        for rate in self.encoder_rates:
            hop *= rate
        return hop

    @property
    def latent_dim(self) -> int:
        """Encoder width after doubling once per encoder rate."""
        n_doublings = len(self.encoder_rates)
        return self.encoder_dim * 2**n_doublings

    @property
    def effective_codebook_size(self) -> int:
        """Product of all FSQ levels."""
        levels = self.fsq_levels
        return math.prod(levels)


class Fluac(nn.Module):
    """Audio codec made of an encoder, a residual FSQ quantizer and a decoder."""

    # Default checkpoint file name used by `from_pretrained`.
    Q9_22KHZ = "fluac-22hz-22khz.pt"

    def __init__(self, config: FluacConfig):
        super().__init__()

        self.config = config
        latent_dim = config.latent_dim

        self.encoder = Encoder(config.encoder_dim, config.encoder_rates, latent_dim)
        self.quantizer = ResidualFiniteScalarQuantize(
            latent_dim=latent_dim,
            n_quantizers=config.n_quantizers,
            levels=config.fsq_levels,
            strides=config.quantizer_strides,
        )
        self.decoder = Decoder(latent_dim, config.decoder_dim, config.decoder_rates)

        self.apply(init_weights)

    @staticmethod
    def from_pretrained(name: str = Q9_22KHZ):
        """Load a checkpoint and return the model in eval mode.

        ``name`` is used as a local path when it exists; otherwise it is
        downloaded from the ``fluxions/vui`` Hugging Face repository.
        """
        checkpoint_path = name
        if not path.exists(name):
            from huggingface_hub import hf_hub_download

            checkpoint_path = hf_hub_download("fluxions/vui", name)

        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

        # The model settings are either nested under "model" or stored flat.
        raw_config = checkpoint["config"]
        model_kwargs = raw_config["model"] if "model" in raw_config else raw_config

        generator = Fluac(FluacConfig(**model_kwargs)).eval()
        state_dict = decompile_state_dict(checkpoint["generator"])
        generator.load_state_dict(state_dict)
        return generator

    def pad(self, waveform: Tensor):
        """Right-pad the last axis with zeros up to a multiple of the hop length."""
        hop = self.config.hop_length
        n_samples = waveform.size(-1)
        padded_length = math.ceil(n_samples / hop) * hop
        return F.pad(waveform, (0, padded_length - n_samples))

    @torch.inference_mode()
    def from_indices(self, indices: Tensor):
        """Decode codebook indices into a waveform."""
        latent = self.quantizer.from_indices(indices)
        return self.decoder(latent)

    @torch.inference_mode()
    def encode(self, waveforms: Tensor, n_quantizers: int | None = None):
        """Encode audio into codebook indices.

        Note that the input is flattened into a single ``(1, 1, T)`` signal,
        so any batch or channel axes are concatenated.
        """
        signal = waveforms.flatten()[None, None]
        signal = self.pad(signal)
        latent = self.encoder(signal)
        _, codes, _ = self.quantizer(latent, n_quantizers=n_quantizers)
        return codes

    def forward(self, waveforms: Tensor, n_quantizers: int | None = None):
        """Encode, quantize and decode ``waveforms``.

        Returns a dict with ``"recons"`` (the reconstruction cropped to the
        input length) and ``"codes"`` (the codebook indices).
        """
        # Unpacking enforces a 3-D input; only the length is needed later.
        _, _, n_samples = waveforms.size()
        latent = self.encoder(self.pad(waveforms))
        z_q, codes, _ = self.quantizer(latent, n_quantizers=n_quantizers)

        # Drop the samples that exist only because of the padding.
        recons = self.decoder(z_q)[..., :n_samples]
        return {"recons": recons, "codes": codes}

    @property
    def device(self):
        """Device of the first parameter."""
        first_param = next(self.parameters())
        return first_param.device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        first_param = next(self.parameters())
        return first_param.dtype

    @property
    def hz(self):
        """Sample rate divided by the product of the encoder rates."""
        return self.config.sample_rate / self.config.hop_length


if __name__ == "__main__":
    # Smoke test: load the pretrained codec, encode random noise and report
    # the shape of the resulting codes.
    codec = Fluac.from_pretrained(Fluac.Q9_22KHZ)
    print(codec.config)
    wav = torch.rand(1, 1, 22050)
    # `encode` pads to a multiple of the hop length itself, so no explicit
    # `codec.pad` call is needed here.
    codes = codec.encode(wav)
    # Print the result instead of dropping into the debugger, so the script
    # can run unattended.
    print(codes.shape)