File size: 11,975 Bytes
88afac1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
import math
import time
from functools import partial

import torch
import torch.nn.functional as F
from torch import Tensor


def load_what_you_can(checkpoint: dict, model: torch.nn.Module):
    """
    Load as many weights from ``checkpoint`` into ``model`` as possible.

    For every ``name -> tensor`` entry of ``checkpoint`` (a state dict):

    * names not present in the model's state dict are reported and skipped;
    * tensors whose shape matches the model's are copied in full;
    * tensors with a different number of dimensions are skipped silently;
    * tensors with the same rank but different sizes are partially loaded:
      the leading ``min(checkpoint_dim, model_dim)`` slice along every axis
      is copied and the rest of the model tensor keeps its current values.

    Args:
        checkpoint: State dict to load from.
        model: Module to load into.

    Returns:
        The result of ``model.load_state_dict`` on the merged state dict.
    """
    model_state_dict = model.state_dict()

    for name, param in checkpoint.items():
        if name not in model_state_dict:
            print(f"Ignoring parameter '{name}' because it is not found in the model")
            continue

        model_state = model_state_dict[name]
        mshape = model_state.shape
        pshape = param.shape

        if pshape == mshape:
            model_state.copy_(param)
            continue

        if len(pshape) != len(mshape):
            # Completely different shapes so probably unwise to merge
            continue

        min_shape = [min(p, m) for p, m in zip(pshape, mshape)]
        print(name, "model:", mshape, "chkpt:", pshape, "loading:", min_shape)
        # Basic slicing yields a *view*, so copy_ writes into the model tensor.
        # Advanced indexing (index tensors, e.g. from meshgrid) would yield a
        # copy and the write would be silently discarded.
        region = tuple(slice(0, s) for s in min_shape)
        model_state[region].copy_(param[region])

    return model.load_state_dict(model_state_dict)


def multimap(
    items: list, func: callable, workers=4, desc=None, thread=False, chunk_size=128
) -> list:
    """
    Map ``func`` over ``items`` in parallel with a tqdm progress bar.

    Results equal to ``None`` are dropped, so ``func`` can return ``None`` to
    mean "skip this item". The remaining results are returned in input order
    (tqdm's ``process_map``/``thread_map`` are built on ``Executor.map``).

    Args:
        items: Inputs to map over. If ``len(items)`` is unsupported (e.g. a
            generator), the progress bar runs without a known total.
        func: Callable applied to each item. With a process pool
            (``thread=False``) it must be picklable.
        workers: Maximum number of worker processes/threads.
        desc: Progress-bar description.
        thread: Use a thread pool instead of a process pool.
        chunk_size: Number of items handed to a worker per task.

    Returns:
        List of the non-``None`` results.
    """
    from tqdm.contrib.concurrent import process_map, thread_map

    mapper = thread_map if thread else process_map
    try:
        length = len(items)
    except TypeError:
        # Lazy iterables have no len(); tqdm handles an unknown total.
        length = None

    results = mapper(
        func,
        items,
        leave=False,
        desc=desc,
        max_workers=workers,
        total=length,
        chunksize=chunk_size,
    )
    return [result for result in results if result is not None]


def round_up(num: float, factor: int):
    """Return ``num`` rounded up to a whole multiple of ``factor`` (for positive ``factor``)."""
    multiples = math.ceil(num / factor)
    return multiples * factor


def left_padding_mask(lengths, max_len, device=None, dtype=None):
    """
    Build a batch of causal attention masks for left-padded sequences.

    For each length ``n`` an ``n x n`` mask is created with ``-inf`` strictly
    above the diagonal and ``0`` elsewhere. It is then padded on the top and
    left with ``-inf`` up to ``max_len`` (``max(lengths)`` when ``max_len`` is
    falsy).

    Returns:
        Tensor of shape ``(len(lengths), 1, max_len, max_len)``.
    """
    target_len = max_len if max_len else max(lengths)

    def _causal_left_padded(length):
        block = torch.empty(length, length, device=device, dtype=dtype)
        block.fill_(-torch.inf).triu_(1)
        offset = target_len - length
        return F.pad(block, (offset, 0, offset, 0), value=-torch.inf)

    stacked = torch.stack([_causal_left_padded(length) for length in lengths])
    return stacked.unsqueeze(1)


def seed_all(seed: int):
    """Seed PyTorch, NumPy's global RNG and Python's ``random`` with ``seed``."""
    import random

    import numpy as np
    import torch

    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)


def split_bucket_path(url: str) -> tuple[str, str]:
    """
    Split ``[scheme://]bucket/key`` into ``(bucket, key)``.

    A leading ``s3://``, ``sj://`` or ``r2://`` scheme is removed. Only a
    *leading* scheme is stripped, so object keys that happen to contain one
    of those substrings are returned intact. ``key`` is ``""`` when the input
    has no ``/`` after the bucket.
    """
    for scheme in ("s3://", "sj://", "r2://"):
        if url.startswith(scheme):
            url = url[len(scheme):]
            break
    bucket, _, path = url.partition("/")
    return bucket, path


def prob_mask_like(shape, prob: float, device):
    """
    Return a boolean tensor of ``shape`` where each entry is True with probability ``prob``.

    ``prob == 1`` and ``prob == 0`` short-circuit to all-True / all-False
    without consuming random numbers. Otherwise entries are drawn as
    ``U(0, 1) < prob`` in float32.
    """
    import torch

    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    noise = torch.empty(shape, device=device, dtype=torch.float32).uniform_(0, 1)
    return noise < prob


def round_up_to_multiple(n: int, multiple: int) -> int:
    """Round ``n`` up to the next multiple of ``multiple``; ``n`` is returned as-is if already one."""
    remainder = n % multiple
    return n if remainder == 0 else n + (multiple - remainder)


def warmup_then_cosine_decay(
    step: int, *, warmup_steps: int, steps: int, min_lr: float, max_lr: float
):
    """
    Learning-rate schedule: linear warmup, cosine decay, then linear cooldown.

    The cooldown length equals ``warmup_steps``. Phases:

    * ``step < warmup_steps``: linear ramp from ``min_lr`` to ``max_lr``.
    * ``warmup_steps <= step < steps - cooldown``: cosine from ``max_lr``
      down to ``min_lr``.
    * ``steps - cooldown <= step <= steps``: linear from ``min_lr`` towards 0
      (plus a tiny epsilon so the value never hits exactly 0).
    * ``step > steps``: ``min_lr``.

    With ``warmup_steps == 0`` there is no cooldown phase, and
    ``step == steps`` returns ``min_lr`` (the end of the cosine) instead of
    dividing by zero.
    """
    eps = 1e-9
    cooldown_steps = warmup_steps
    if step < warmup_steps:
        return min_lr + step * (max_lr - min_lr) / warmup_steps
    if step > steps:
        return min_lr
    if step < steps - cooldown_steps:
        decay_ratio = (step - warmup_steps) / (steps - warmup_steps - cooldown_steps)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (max_lr - min_lr)
    if cooldown_steps <= 0:
        # Empty cooldown phase: nothing to interpolate over.
        return min_lr
    # decay from min_lr to (almost) 0
    return min_lr * (steps - step) / cooldown_steps + eps


def decay_to_zero(step: int, *, decay_steps: int, steps: int, max_lr: float):
    """
    Linear learning-rate decay from ``max_lr`` at step 0 to 0 at ``decay_steps``.

    The value is clamped at 0 once ``step >= decay_steps``. The unclamped
    line would go negative whenever ``decay_steps < step <= steps``. Any
    ``step > steps`` also returns 0. ``decay_steps <= 0`` means "already
    decayed" and returns 0 rather than dividing by zero.
    """
    if step > steps or step >= decay_steps:
        return 0.0
    gradient = -max_lr / decay_steps
    return max_lr + gradient * step


def cross_entropy_loss(logits, mask, targets):
    """
    Masked cross-entropy averaged over codebooks.

    Args:
        logits: Tensor of shape ``(B, Q, T, card)`` with unnormalised scores
            for each codebook ``q`` and time step.
        mask: Boolean mask of shape ``(B, Q, T)``; only True positions
            contribute to the loss.
        targets: Class indices of shape ``(B, Q, T)``.

    Returns:
        ``(loss, codebook_losses)``. ``loss`` is the sum of the per-codebook
        mean masked cross-entropies divided by ``Q``. ``codebook_losses`` is
        the list of those per-codebook losses, detached. A codebook with no
        unmasked positions contributes 0. ``F.cross_entropy`` would return
        NaN for the empty selection and poison the total.
    """
    import torch
    import torch.nn.functional as F

    _, Q, _, card = logits.size()
    assert logits.shape[:-1] == targets.shape
    assert mask.shape == targets.shape
    loss = torch.zeros([], device=targets.device)
    codebook_losses = []
    for q in range(Q):
        logits_q = logits[:, q].reshape(-1, card)  # [B x T, card]
        targets_q = targets[:, q].reshape(-1)  # [B x T]
        mask_q = mask[:, q].reshape(-1)  # [B x T]
        ce_targets = targets_q[mask_q]
        ce_logits = logits_q[mask_q]
        if ce_targets.numel() == 0:
            # Zero that stays attached to the autograd graph, so backward()
            # still works when every position of this codebook is masked.
            q_ce = ce_logits.sum() * 0.0
        else:
            q_ce = F.cross_entropy(ce_logits, ce_targets)
        loss += q_ce
        codebook_losses.append(q_ce.detach())
    # average cross entropy across codebooks
    loss = loss / Q
    return loss, codebook_losses


def build_optimizer(
    module,
    *,
    weight_decay: float,
    lr: float,
    betas: tuple[float, float],
    fused: bool = True,
):
    """
    Build an AdamW optimizer with separate weight-decay groups.

    Trainable parameters with ``dim() >= 2`` (e.g. matmul weights and
    embeddings) receive ``weight_decay``. Parameters with fewer dimensions
    (e.g. biases and norm scales) get none. Parameters with
    ``requires_grad=False`` are excluded.

    Args:
        module: Module whose parameters are optimised.
        weight_decay: Weight decay for the ``dim() >= 2`` group.
        lr: Learning rate.
        betas: AdamW ``(beta1, beta2)``.
        fused: Forwarded to ``torch.optim.AdamW`` (default True). Pass False
            where the fused implementation is unavailable for the
            parameters' device.

    Returns:
        The configured ``torch.optim.AdamW``.
    """
    import torch

    trainable = [p for p in module.parameters() if p.requires_grad]
    decay_params = [p for p in trainable if p.dim() >= 2]
    nodecay_params = [p for p in trainable if p.dim() < 2]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=fused)


def pad_or_cut_right(t: Tensor, padlen: int, value=0) -> Tensor:
    """
    Make the last dimension of ``t`` exactly ``padlen`` long.

    Shorter tensors are right-padded with ``value``. Longer ones are
    truncated on the right. All other dimensions are left untouched.
    """
    current_len = t.shape[-1]

    if current_len == padlen:
        return t

    if current_len < padlen:
        # Need to pad
        return F.pad(t, (0, padlen - current_len), value=value)
    # Need to cut: slice the *last* axis, the same axis measured and padded above.
    return t[..., :padlen]


def pad_or_cut_left(t: Tensor, value: int) -> Tensor:
    """
    Make the first dimension of ``t`` exactly ``value`` long.

    ``value`` is the target *length*, not a fill value. Shorter tensors are
    zero-padded at the front of dim 0. Longer ones keep their last ``value``
    entries along dim 0. All other dimensions are left untouched.
    """
    dims = t.ndim
    current_len = t.shape[0]

    if current_len == value:
        return t

    if current_len < value:
        # F.pad takes (left, right) pairs starting from the LAST dim, so dim 0
        # is the final pair: zeros for every trailing dim, then (pad, 0).
        pad_size = (0,) * (2 * (dims - 1)) + (value - current_len, 0)
        return F.pad(t, pad_size)
    # Need to cut. t[-value:] would return the whole tensor for value == 0
    # (because -0 == 0), so index from the front instead.
    return t[current_len - value :]


def dl_pt(orig: str):
    """
    Load a ``.pt`` file from the first location where it is found.

    ``.pt`` is appended to ``orig`` if missing. Lookup order: ``orig`` as a
    local path, then ``/data/<orig>``, then the object ``<orig>`` in the
    ``fluxions`` bucket via ``vui.storage.s3``. Files are read with
    ``torch.load(..., weights_only=True)``.
    """
    import io
    from os.path import exists

    import torch

    if not orig.endswith(".pt"):
        orig = orig + ".pt"

    load = partial(torch.load, weights_only=True)
    for candidate in (orig, "/data/" + orig):
        if exists(candidate):
            return load(candidate)

    # Imported only for the remote fallback so local loads do not depend on
    # the storage backend being importable/configured.
    from vui.storage import s3, split_bucket_path

    bucket, key = split_bucket_path("s3://fluxions/" + orig)
    response = s3.get_object(Bucket=bucket, Key=key)
    # torch.load requires a seekable file object; the S3 response body is a
    # stream without seek(), so buffer it in memory first.
    return load(io.BytesIO(response["Body"].read()))


def dl_ogg(url: str, start=0, end=-1, sr=None):
    """
    Load (a segment of) an audio file as a mono float tensor.

    Resolution order: ``/data/audio/<url>`` if it exists, else ``url`` as a
    local path, otherwise the object ``<url>`` in the ``fluxions`` bucket
    via ``vui.storage.s3``.

    Args:
        url: Path/key of the audio file. If it contains ``<digits>/``, the
            first such number is used as the sample rate and overrides ``sr``.
        start: Segment start, in seconds.
        end: Segment end, in seconds. With the default ``-1`` the computed
            frame count is negative, which ``soundfile.read`` treats as
            "read to the end of the file".
        sr: Sample rate used to convert seconds to frames. When neither given
            nor encoded in ``url`` it is read from the file header.

    Returns:
        ``(wav, sr)``: ``wav`` has shape ``(1, num_frames)`` (channels
        averaged to mono, float32) and ``sr`` is the rate used above.
    """
    import io
    import re
    from os.path import exists

    import soundfile as sf
    import torch

    search_sr = re.search(r"(\d+)/", url)
    if search_sr:
        sr = int(search_sr.group(1))

    local_file = exists(url)

    if exists("/data/audio/" + url):
        local_file = True
        url = "/data/audio/" + url

    if local_file:
        source = url
    else:
        from vui.storage import s3

        b, p = split_bucket_path("s3://fluxions/" + url)
        # soundfile needs a seekable file object, and the header probe below
        # must not consume the data, so buffer the S3 stream in memory.
        source = io.BytesIO(s3.get_object(Bucket=b, Key=p)["Body"].read())

    if sr is None:
        sr = sf.info(source).samplerate
        if not local_file:
            source.seek(0)  # rewind after the header probe

    start_frame = int(start * sr)
    num_frames = int(end * sr) - start_frame
    wav, _ = sf.read(source, frames=num_frames, start=start_frame, always_2d=True)
    wav = torch.from_numpy(wav).float()
    wav = wav.T.mean(0, keepdim=True)
    return wav, sr


class timer:
    """
    Context manager that prints the wall-clock time spent in its body.

    On exit it prints ``"<name> <seconds>"`` with four decimal places, using
    ``time.perf_counter``. It does not suppress exceptions raised in the body.
    """

    def __init__(self, name=""):
        self.name = name

    def __enter__(self):
        self.t = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        finished_at = time.perf_counter()
        print(f"{self.name} {finished_at - self.t:.4f}")


@torch.inference_mode()
def decode_audio_from_indices(model, indices, chunk_size=64):
    """
    Decode codec indices to audio in fixed-size time chunks to bound memory.

    Args:
        model: Codec exposing ``.device`` and ``.from_indices(indices)``.
        indices: 3-D tensor ``(batch, n_quantizers, sequence_length)``;
            chunks are cut along the last axis.
        chunk_size: Number of index steps decoded per ``from_indices`` call.

    Returns:
        The decoded chunks moved to CPU, concatenated along the last axis and
        flattened to 1-D.
    """
    indices = indices.to(model.device)
    _, _, seq_len = indices.shape

    decoded = [
        model.from_indices(indices[:, :, begin : begin + chunk_size]).cpu()
        for begin in range(0, seq_len, chunk_size)
    ]
    return torch.cat(decoded, dim=-1).flatten()


def normalize_loudness(waveform, sample_rate: int, lufs: float = -12.0):
    """
    Normalize the loudness of an audio tensor using torchaudio.transforms.Loudness.

    Args:
        waveform (torch.Tensor): Audio of shape (channels, samples). A 1-D
            (samples,) tensor is treated as a single channel.
        sample_rate (int): Sampling rate of the audio.
        lufs (float): Target loudness in LUFS (default: -12.0).

    Returns:
        torch.Tensor: ``waveform`` multiplied by the gain that moves its
        measured loudness to ``lufs``. The result is always at least 2-D: a
        1-D input comes back as (1, samples). If the measured loudness is not
        finite (e.g. it can be for silent input), the input is returned
        unscaled rather than producing NaN/inf samples.
    """
    import torchaudio

    # Ensure the input tensor is 2D (add channel dimension if it's 1D)
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)

    current_loudness = torchaudio.transforms.Loudness(sample_rate)(waveform)

    # A non-finite measurement would give an infinite/NaN gain below.
    if not torch.isfinite(current_loudness).all():
        return waveform

    # Required gain in dB, converted to a linear amplitude factor
    gain_db = lufs - current_loudness
    gain_linear = torch.pow(10, gain_db / 20)

    return waveform * gain_linear


def get_basename_without_extension(file_path):
    """Return the last path component of ``file_path`` with its final suffix removed."""
    from pathlib import Path

    return Path(file_path).stem


def ollama(prompt, MODEL=None, timeout=300):
    """
    Generate a completion from a local Ollama server (non-streaming).

    Args:
        prompt: Prompt text sent to ``/api/generate``.
        MODEL: Model tag. Defaults to the ``OLLAMA_MODEL`` environment
            variable, or ``"gemma:1b"`` when that is unset.
        timeout: Seconds to wait for the HTTP request (``None`` waits
            forever). A timeout is reported and yields ``""`` like any other
            request error.

    Returns:
        The generated text, or ``""`` if the request fails (the error is
        printed).
    """
    import os

    import requests

    OLLAMA_HOST = "http://localhost:11434"
    API = f"{OLLAMA_HOST}/api/generate"

    if MODEL is None:
        MODEL = os.environ.get("OLLAMA_MODEL", "gemma:1b")

    payload = {
        "model": MODEL,
        "prompt": prompt,
        "stream": False,
        # Ollama caps generation length with ``num_predict``; an
        # OpenAI-style ``max_tokens`` key is not an Ollama model option.
        "options": {"temperature": 0.9, "top_p": 0.9, "num_predict": 1000},
    }

    try:
        response = requests.post(API, json=payload, timeout=timeout)
        response.raise_for_status()  # Raise exception for HTTP errors
        result = response.json()
        return result.get("response", "")
    except requests.exceptions.RequestException as e:
        print(f"Error calling Ollama API: {e}")
        return ""


def decompile_state_dict(state_dict):
    """
    Strip wrapper path components from state-dict keys.

    Every ``_orig_mod`` component (added by ``torch.compile``) and every
    ``module`` component (added by ``DataParallel``/``DistributedDataParallel``)
    is removed from each key, wherever it occurs in the path. For example,
    ``module._orig_mod.enc.w`` becomes ``enc.w``. Only whole dot-separated
    components are dropped, so names such as ``submodule.weight`` stay
    intact. The final component (the parameter/buffer name) is never removed.
    """
    wrapper_parts = {"_orig_mod", "module"}

    def _clean(key):
        *path, leaf = key.split(".")
        return ".".join([part for part in path if part not in wrapper_parts] + [leaf])

    return {_clean(k): v for k, v in state_dict.items()}