# vui/utils.py
import math
import time
from functools import partial
import torch
import torch.nn.functional as F
from torch import Tensor
def load_what_you_can(checkpoint: dict, model: torch.nn.Module):
"""
This method takes a checkpoint and loads as many weights from it as possible:
If they are the same shape, there's nothing to do
Will load the smallest shape otherwise.
"""
import torch
model_state_dict = model.state_dict()
checkpoint_state_dict = checkpoint
for name, param in checkpoint_state_dict.items():
if name not in model_state_dict:
print(f"Ignoring parameter '{name}' because it is not found in the model")
continue
model_state = model_state_dict[name]
mshape = model_state.shape
pshape = param.shape
if pshape == mshape:
model_state.copy_(param)
continue
if len(pshape) != len(mshape):
# Completely different shapes so probably unwise to merge
continue
min_shape = [
min(param.shape[i], model_state.shape[i]) for i in range(len(param.shape))
]
print(name, "model:", mshape, "chkpt:", pshape, "loading:", min_shape)
        # Use slices (views) so copy_ writes into the model tensor; advanced
        # indexing via meshgrid would return a copy and silently drop the update
        slices = tuple(slice(0, s) for s in min_shape)
        model_state[slices].copy_(param[slices])
return model.load_state_dict(model_state_dict)
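
# Usage sketch (hedged; "ckpt.pt" is a hypothetical path):
#
#     ckpt = torch.load("ckpt.pt", map_location="cpu", weights_only=True)
#     load_what_you_can(ckpt, model)  # overlapping regions copied where shapes differ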
def multimap(
items: list, func: callable, workers=4, desc=None, thread=False, chunk_size=128
) -> list:
"""
Quick and dirty multiprocessing that will return the result of func if it returns None
"""
from tqdm.contrib.concurrent import process_map, thread_map
m = thread_map if thread else process_map
length = None
try:
length = len(items)
    except Exception as e:
        print(f"Could not determine length of items: {e}")
results = m(
func,
items,
leave=False,
desc=desc,
max_workers=workers,
total=length,
chunksize=chunk_size,
)
return list(filter(lambda x: x is not None, results))
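
# Usage sketch (hedged): with thread=False, `func` must be picklable (defined
# at module top level), since process_map fans work out to worker processes.
#
#     def square(x):
#         return x * x
#
#     multimap(list(range(10)), square, workers=2)  # -> [0, 1, 4, ..., 81]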
def round_up(num: float, factor: int):
return factor * math.ceil(num / factor)
def left_padding_mask(lengths, max_len=None, device=None, dtype=None):
masks = []
if not max_len:
max_len = max(lengths)
for l in lengths:
mask = torch.empty(l, l, device=device, dtype=dtype).fill_(-torch.inf).triu_(1)
diff = max_len - l
mask = F.pad(mask, (diff, 0, diff, 0), value=-torch.inf)
masks.append(mask)
masks = torch.stack(masks)
return masks[:, None]
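
# Shape sketch (hedged): for lengths [2, 3] and max_len=3 this returns a
# (2, 1, 3, 3) additive attention mask, causal over each right-aligned valid
# region and -inf over the left padding.
#
#     mask = left_padding_mask([2, 3], 3)
#     assert mask.shape == (2, 1, 3, 3)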
def seed_all(seed: int):
import random
import numpy as np
import torch
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def split_bucket_path(url: str) -> tuple[str, str]:
url = url.replace("s3://", "")
url = url.replace("sj://", "")
url = url.replace("r2://", "")
bucket = url.split("/")[0]
path = "/".join(url.split("/")[1:])
return bucket, path
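
# Example (hypothetical key):
#
#     split_bucket_path("s3://fluxions/audio/clip.ogg")
#     # -> ("fluxions", "audio/clip.ogg")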
def prob_mask_like(shape, prob: float, device):
if prob == 1:
return torch.ones(shape, device=device, dtype=torch.bool)
elif prob == 0:
return torch.zeros(shape, device=device, dtype=torch.bool)
else:
return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
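
# Usage sketch (hedged): handy for classifier-free-guidance style condition
# dropout, e.g. a per-example boolean keep mask (`batch_size` and `device`
# are placeholders).
#
#     keep = prob_mask_like((batch_size,), 0.9, device)  # ~90% True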
def round_up_to_multiple(n: int, multiple: int) -> int:
if n % multiple != 0:
n += multiple - (n % multiple)
return n
def warmup_then_cosine_decay(
step: int, *, warmup_steps: int, steps: int, min_lr: float, max_lr: float
):
eps = 1e-9
cooldown_steps = warmup_steps
if step < warmup_steps:
return min_lr + step * (max_lr - min_lr) / (warmup_steps)
elif step > steps:
return min_lr
elif step < steps - cooldown_steps:
decay_ratio = (step - warmup_steps) / (steps - warmup_steps - cooldown_steps)
# assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
return min_lr + coeff * (max_lr - min_lr)
else:
# decay from min_lr to 0
return min_lr * (steps - step) / cooldown_steps + eps
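
# Schedule sketch (hedged): linear warmup to max_lr, cosine decay to min_lr,
# then a final linear cooldown (warmup_steps long) towards zero.
#
#     get_lr = partial(
#         warmup_then_cosine_decay,
#         warmup_steps=1_000, steps=100_000, min_lr=1e-5, max_lr=3e-4,
#     )
#     get_lr(0)        # 1e-5
#     get_lr(1_000)    # 3e-4
#     get_lr(100_000)  # ~0.0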
def decay_to_zero(step: int, *, decay_steps: int, steps: int, max_lr: float):
if step > steps:
return 0.0
else:
gradient = -max_lr / decay_steps
return max_lr + gradient * step
def cross_entropy_loss(logits, mask, targets):
B, Q, T, _ = logits.size()
assert logits.shape[:-1] == targets.shape
assert mask.shape == targets.shape
loss = torch.zeros([], device=targets.device)
codebook_losses = []
for q in range(Q):
logits_q = (
logits[:, q, ...].contiguous().view(-1, logits.size(-1))
) # [B x T, card]
targets_q = targets[:, q, ...].contiguous().view(-1) # [B x T]
mask_q = mask[:, q, ...].contiguous().view(-1) # [B x T]
ce_targets = targets_q[mask_q]
ce_logits = logits_q[mask_q]
q_ce = F.cross_entropy(ce_logits, ce_targets)
loss += q_ce
codebook_losses.append(q_ce.detach())
# average cross entropy across codebooks
loss = loss / Q
return loss, codebook_losses
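
# Shape sketch (hedged): logits [B, Q, T, card], targets and mask [B, Q, T];
# mask must be boolean so it can index into the flattened logits/targets.
#
#     loss, per_codebook = cross_entropy_loss(logits, mask.bool(), targets)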
def build_optimizer(
module, *, weight_decay: float, lr: float, betas: tuple[float, float]
):
param_dict = {pn: p for pn, p in module.named_parameters() if p.requires_grad}
    # Create optim groups: any parameter that is at least 2D gets weight decay,
    # i.e. all weight tensors in matmuls + embeddings decay; biases and layernorms don't.
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
{"params": decay_params, "weight_decay": weight_decay},
{"params": nodecay_params, "weight_decay": 0.0},
]
# num_decay_params = sum(p.numel() for p in decay_params)
# num_nodecay_params = sum(p.numel() for p in nodecay_params)
# print(
# f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
# )
# print(
# f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
# )
optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=True)
return optimizer
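
# Usage sketch (hedged): fused AdamW assumes parameters live on a CUDA device.
#
#     opt = build_optimizer(model, weight_decay=0.1, lr=3e-4, betas=(0.9, 0.95))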
def pad_or_cut_right(t: Tensor, padlen: int, value=0) -> Tensor:
current_len = t.shape[-1]
if current_len == padlen:
return t
if current_len < padlen:
# Need to pad
pad_size = (0, padlen - current_len)
return F.pad(t, pad_size, value=value)
    # Need to cut (along the last dim, matching the padding above)
    return t[..., :padlen]
def pad_or_cut_left(t: Tensor, value: int) -> Tensor:
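    # NOTE: `value` is the target length along dim 0, not a fill value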
dims = t.ndim
current_len = t.shape[0]
if current_len == value:
return t
if current_len < value:
# Need to pad
pad_size = (0,) * (2 * (dims - 1)) + (value - current_len, 0)
return F.pad(t, pad_size)
# Need to cut
return t[-value:]
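
# Example (hedged): the right variant operates on the last dim, the left
# variant on dim 0.
#
#     pad_or_cut_right(torch.ones(2, 3), 5).shape  # -> (2, 5)
#     pad_or_cut_left(torch.ones(3, 2), 5).shape   # -> (5, 2)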
def dl_pt(orig: str):
from os.path import exists
from vui.storage import s3, split_bucket_path
if not orig.endswith(".pt"):
orig = orig + ".pt"
load = partial(torch.load, weights_only=True)
if exists(orig):
return load(orig)
url = "/data/" + orig
if exists(url):
return load(url)
url = "s3://fluxions/" + orig
bucket, key = split_bucket_path(url)
response = s3.get_object(Bucket=bucket, Key=key)
return load(response["Body"])
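
# Lookup order sketch (hedged; "emb/foo" is a hypothetical key): a local path,
# then /data/, then the "fluxions" S3 bucket, appending ".pt" if missing.
#
#     t = dl_pt("emb/foo")  # resolves emb/foo.pt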
def dl_ogg(url: str, start=0, end=-1, sr=None):
import re
from os.path import exists
import soundfile as sf
search_sr = re.search(r"(\d+)/", url)
if search_sr:
sr = int(search_sr.group(1))
local_file = exists(url)
if exists("/data/audio/" + url):
local_file = True
url = "/data/audio/" + url
    if not local_file:
        import io

        from vui.storage import s3

        url = "s3://fluxions/" + url
        b, p = split_bucket_path(url)
        # Buffer the object so it can be probed for metadata and then re-read
        url = io.BytesIO(s3.get_object(Bucket=b, Key=p)["Body"].read())
    if sr is None:
        sr = sf.info(url).samplerate
        if not local_file:
            url.seek(0)
start_frame = int(start * sr)
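    # A negative frame count (end defaults to -1) makes soundfile read to the end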
num_frames = int(end * sr) - start_frame
wav, _ = sf.read(url, frames=num_frames, start=start_frame, always_2d=True)
wav = torch.from_numpy(wav).float()
wav = wav.T.mean(0, keepdim=True)
return wav, sr
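
# Usage sketch (hedged; the path is a hypothetical key): a leading "<sr>/"
# path component overrides `sr`, and the output is mono of shape [1, T].
#
#     wav, sr = dl_ogg("22050/speech/clip.ogg", start=0, end=5)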
class timer:
def __init__(self, name=""):
self.name = name
def __enter__(self):
self.t = time.perf_counter()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
elapsed = time.perf_counter() - self.t
print(f"{self.name} {elapsed:.4f}")
@torch.inference_mode()
def decode_audio_from_indices(model, indices, chunk_size=64):
"""
Decodes audio from indices in batches to avoid memory issues.
Args:
model: Codec
indices: Tensor of shape (1, n_quantizers, sequence_length)
        chunk_size: Number of index frames to decode per batch
Returns:
Tensor of reconstructed audio
"""
device = model.device
indices = indices.to(device)
_, _, seq_len = indices.shape
chunks = seq_len // chunk_size + (1 if seq_len % chunk_size != 0 else 0)
audio_chunks = []
for i in range(chunks):
start_idx = i * chunk_size
end_idx = min(start_idx + chunk_size, seq_len)
chunk_indices = indices[:, :, start_idx:end_idx]
chunk_audio = model.from_indices(chunk_indices)
audio_chunks.append(chunk_audio.cpu())
full_audio = torch.cat(audio_chunks, dim=-1)
return full_audio.flatten()
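
# Usage sketch (hedged): assumes `codec` exposes .device and .from_indices as
# used above.
#
#     audio = decode_audio_from_indices(codec, indices, chunk_size=64)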
def normalize_loudness(waveform, sample_rate: int, lufs: float = -12.0):
"""
Normalize the loudness of an audio tensor using torchaudio.transforms.Loudness.
    Args:
        waveform (torch.Tensor): Input audio tensor of shape (channels, samples)
        sample_rate (int): Sampling rate of the audio
        lufs (float): Target loudness in LUFS (default: -12.0)
Returns:
torch.Tensor: Loudness-normalized audio tensor
"""
import torchaudio
# Ensure the input tensor is 2D (add channel dimension if it's 1D)
if waveform.ndim == 1:
waveform = waveform.unsqueeze(0)
# Create a Loudness transform
loudness_transform = torchaudio.transforms.Loudness(sample_rate)
# Measure the current loudness
current_loudness = loudness_transform(waveform)
# Calculate the required gain
gain_db = lufs - current_loudness
# Convert gain from dB to linear scale
gain_linear = torch.pow(10, gain_db / 20)
# Apply the gain to normalize loudness
normalized_audio = waveform * gain_linear
return normalized_audio
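
# Usage sketch (hedged): torchaudio's Loudness transform implements ITU-R
# BS.1770 loudness measurement; very short clips may not yield a valid reading.
#
#     wav = normalize_loudness(wav, 22050, lufs=-18.0)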
def get_basename_without_extension(file_path):
from pathlib import Path
p = Path(file_path)
return p.stem
def ollama(prompt, MODEL=None):
import os
import requests
OLLAMA_HOST = "http://localhost:11434"
API = f"{OLLAMA_HOST}/api/generate"
if MODEL is None:
MODEL = os.environ.get("OLLAMA_MODEL", "gemma:1b")
payload = {
"model": MODEL,
"prompt": prompt,
"stream": False,
"options": {"temperature": 0.9, "top_p": 0.9, "max_tokens": 1000},
}
try:
response = requests.post(API, json=payload)
response.raise_for_status() # Raise exception for HTTP errors
result = response.json()
return result.get("response", "")
except requests.exceptions.RequestException as e:
print(f"Error calling Ollama API: {e}")
return ""
def decompile_state_dict(state_dict):
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
# state_dict = convert_old_weight_norm_to_new(state_dict)
return {k.replace("module.", ""): v for k, v in state_dict.items()}
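
# Usage sketch (hedged; "ckpt.pt" is a hypothetical path): strips the
# torch.compile ("_orig_mod.") and DDP ("module.") prefixes so the weights
# load into an uncompiled, unwrapped module.
#
#     sd = decompile_state_dict(torch.load("ckpt.pt", weights_only=True))
#     model.load_state_dict(sd)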