# CantusSVS-hf / utils / __init__.py
from __future__ import annotations
import pathlib
import re
import time
import types
from collections import OrderedDict
import numpy as np
import torch
import torch.nn.functional as F
from basics.base_module import CategorizedModule
from utils.hparams import hparams
from utils.training_utils import get_latest_checkpoint_path
def tensors_to_scalars(metrics):
    """Recursively convert tensor values in a (possibly nested) metrics dict to Python scalars."""
    new_metrics = {}
    for k, v in metrics.items():
        if isinstance(v, torch.Tensor):
            v = v.item()
        if isinstance(v, dict):
            v = tensors_to_scalars(v)
        new_metrics[k] = v
    return new_metrics
def collate_nd(values, pad_value=0, max_len=None):
"""
    Pad a list of N-d tensors on their first dimension and stack them into an (N+1)-d tensor.
"""
size = ((max(v.size(0) for v in values) if max_len is None else max_len), *values[0].shape[1:])
res = torch.full((len(values), *size), fill_value=pad_value, dtype=values[0].dtype, device=values[0].device)
for i, v in enumerate(values):
res[i, :len(v), ...] = v
return res
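# Illustrative usage sketch for collate_nd; the shapes below are made up, not from the codebase:
#   >>> a, b = torch.ones(3, 2), torch.ones(5, 2)
#   >>> collate_nd([a, b]).shape
#   torch.Size([2, 5, 2])  # padded to the longest first dimension, then stacked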
def random_continuous_masks(*shape: int, dim: int, device: str | torch.device = 'cpu'):
    """
    Generate a boolean mask of the given shape where, along `dim`, each slice
    is True on one random contiguous interval [start, end) and False elsewhere.
    """
start, end = torch.sort(
torch.randint(
low=0, high=shape[dim] + 1, size=(*shape[:dim], 2, *((1,) * (len(shape) - dim - 1))), device=device
).expand(*((-1,) * (dim + 1)), *shape[dim + 1:]), dim=dim
)[0].split(1, dim=dim)
idx = torch.arange(
0, shape[dim], dtype=torch.long, device=device
).reshape(*((1,) * dim), shape[dim], *((1,) * (len(shape) - dim - 1)))
masks = (idx >= start) & (idx < end)
return masks
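# Hedged sketch of random_continuous_masks behavior; the shape (2, 8) is arbitrary:
#   >>> m = random_continuous_masks(2, 8, dim=1)
#   >>> m.shape, m.dtype
#   (torch.Size([2, 8]), torch.bool)
# Each row is True on a single random contiguous span [start, end) and False elsewhere.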
def _is_batch_full(batch, num_frames, max_batch_frames, max_batch_size):
    """Return True if the batch cannot take another sample under the frame/size limits."""
    if len(batch) == 0:
        return False
    if len(batch) == max_batch_size:
        return True
    if num_frames > max_batch_frames:
        return True
    return False
def batch_by_size(
indices, num_frames_fn, max_batch_frames=80000, max_batch_size=48,
required_batch_size_multiple=1
):
"""
Yield mini-batches of indices bucketed by size. Batches may contain
sequences of different lengths.
Args:
indices (List[int]): ordered list of dataset indices
num_frames_fn (callable): function that returns the number of frames at
a given index
max_batch_frames (int, optional): max number of frames in each batch
(default: 80000).
max_batch_size (int, optional): max number of sentences in each
batch (default: 48).
required_batch_size_multiple: require the batch size to be multiple
of a given number
"""
bsz_mult = required_batch_size_multiple
if isinstance(indices, types.GeneratorType):
indices = np.fromiter(indices, dtype=np.int64, count=-1)
sample_len = 0
sample_lens = []
batch = []
batches = []
for i in range(len(indices)):
idx = indices[i]
num_frames = num_frames_fn(idx)
sample_lens.append(num_frames)
sample_len = max(sample_len, num_frames)
        assert sample_len <= max_batch_frames, (
            f"sentence at index {idx} of size {sample_len} exceeds "
            f"max_batch_frames limit of {max_batch_frames}!"
        )
        batch_num_frames = (len(batch) + 1) * sample_len
        if _is_batch_full(batch, batch_num_frames, max_batch_frames, max_batch_size):
mod_len = max(
bsz_mult * (len(batch) // bsz_mult),
len(batch) % bsz_mult,
)
batches.append(batch[:mod_len])
batch = batch[mod_len:]
sample_lens = sample_lens[mod_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
if len(batch) > 0:
batches.append(batch)
return batches
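# Worked example with hypothetical frame counts, showing how the frame budget splits batches:
#   >>> lengths = [10, 20, 30, 40]
#   >>> batch_by_size(list(range(4)), lambda i: lengths[i],
#   ...               max_batch_frames=60, max_batch_size=48)
#   [[0, 1], [2], [3]]
# Indices 0 and 1 fit together (2 samples * max len 20 = 40 <= 60), but adding
# index 2 would cost 3 * 30 = 90 frames, so a new batch is started.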
def make_positions(tensor, padding_idx):
"""Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols are ignored.
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA. In particular XLA
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
# how to handle the dtype kwarg in cumsum.
mask = tensor.ne(padding_idx).int()
return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
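# Quick sanity check (values are made up):
#   >>> t = torch.tensor([[5, 6, 0], [7, 0, 0]])
#   >>> make_positions(t, padding_idx=0)
#   tensor([[1, 2, 0],
#           [1, 0, 0]])
# Positions count from padding_idx + 1; padded slots keep the value padding_idx.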
def softmax(x, dim):
return F.softmax(x, dim=dim, dtype=torch.float32)
def unpack_dict_to_list(samples):
    """Split a dict of batched values into a list of per-sample dicts."""
    samples_ = []
    bsz = samples.get('outputs').size(0)
    for i in range(bsz):
        res = {}
        for k, v in samples.items():
            try:
                res[k] = v[i]
            except (TypeError, IndexError, KeyError):
                # skip values that cannot be indexed per sample
                pass
        samples_.append(res)
    return samples_
def filter_kwargs(dict_to_filter, kwarg_obj):
    """Keep only the entries of `dict_to_filter` that `kwarg_obj` accepts as keyword arguments."""
    import inspect
    sig = inspect.signature(kwarg_obj)
    if any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values()):
        # the signature contains definitions like **kwargs, so there is no need to filter
        return dict_to_filter.copy()
    filter_keys = [
        param.name
        for param in sig.parameters.values()
        if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    ]
    return {key: dict_to_filter[key] for key in filter_keys if key in dict_to_filter}
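# Example of the filtering behavior (the callable and keys here are hypothetical):
#   >>> def f(a, b=1): ...
#   >>> filter_kwargs({'a': 1, 'b': 2, 'c': 3}, f)
#   {'a': 1, 'b': 2}
# 'c' is dropped because f() neither declares it nor accepts **kwargs.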
def load_ckpt(
        cur_model, ckpt_base_dir, ckpt_steps=None,
        prefix_in_ckpt='model', ignored_prefixes=None, key_in_ckpt='state_dict',
        strict=True, device='cpu'
):
    """Load weights into `cur_model` from a checkpoint file, a specific step,
    or the latest `model_ckpt_steps_*.ckpt` found under `ckpt_base_dir`."""
    if ignored_prefixes is None:
        # NOTICE: this is for compatibility with old checkpoints which have a duplicate txt_embed layer in them.
        ignored_prefixes = ['model.fs2.encoder.embed_tokens']
if not isinstance(ckpt_base_dir, pathlib.Path):
ckpt_base_dir = pathlib.Path(ckpt_base_dir)
if ckpt_base_dir.is_file():
checkpoint_path = [ckpt_base_dir]
elif ckpt_steps is not None:
checkpoint_path = [ckpt_base_dir / f'model_ckpt_steps_{int(ckpt_steps)}.ckpt']
else:
base_dir = ckpt_base_dir
checkpoint_path = sorted(
[
ckpt_file
for ckpt_file in base_dir.iterdir()
if ckpt_file.is_file() and re.fullmatch(r'model_ckpt_steps_\d+\.ckpt', ckpt_file.name)
],
key=lambda x: int(re.search(r'\d+', x.name).group(0))
)
assert len(checkpoint_path) > 0, f'| ckpt not found in {ckpt_base_dir}.'
checkpoint_path = checkpoint_path[-1]
ckpt_loaded = torch.load(checkpoint_path, map_location=device)
if isinstance(cur_model, CategorizedModule):
cur_model.check_category(ckpt_loaded.get('category'))
if key_in_ckpt is None:
state_dict = ckpt_loaded
else:
state_dict = ckpt_loaded[key_in_ckpt]
if prefix_in_ckpt is not None:
        state_dict = OrderedDict({
            k[len(prefix_in_ckpt) + 1:]: v
            for k, v in state_dict.items()
            if k.startswith(f'{prefix_in_ckpt}.') and all(not k.startswith(p) for p in ignored_prefixes)
        })
if not strict:
cur_model_state_dict = cur_model.state_dict()
unmatched_keys = []
for key, param in state_dict.items():
if key in cur_model_state_dict:
new_param = cur_model_state_dict[key]
if new_param.shape != param.shape:
unmatched_keys.append(key)
print('| Unmatched keys: ', key, new_param.shape, param.shape)
for key in unmatched_keys:
del state_dict[key]
cur_model.load_state_dict(state_dict, strict=strict)
shown_model_name = 'state dict'
if prefix_in_ckpt is not None:
shown_model_name = f'\'{prefix_in_ckpt}\''
elif key_in_ckpt is not None:
shown_model_name = f'\'{key_in_ckpt}\''
print(f'| load {shown_model_name} from \'{checkpoint_path}\'.')
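# Hedged usage sketch; `MyModel` and the checkpoint directory are placeholders:
#   >>> model = MyModel()
#   >>> load_ckpt(model, 'checkpoints/my_exp', prefix_in_ckpt='model', strict=True)
# With ckpt_steps=None, the highest-numbered model_ckpt_steps_*.ckpt in the directory is loaded.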
def remove_padding(x, padding_idx=0):
    """Strip padded positions from a 1-d [T] or 2-d [T, H] array."""
    if x is None:
        return None
    assert len(x.shape) in [1, 2]
    if len(x.shape) == 2:  # [T, H]
        return x[np.abs(x).sum(-1) != padding_idx]
    elif len(x.shape) == 1:  # [T]
        return x[x != padding_idx]
class Timer:
    """Context manager that accumulates wall-clock time per name in `Timer.timer_map`."""
    timer_map = {}
    def __init__(self, name, print_time=False):
        if name not in Timer.timer_map:
            Timer.timer_map[name] = 0
        self.name = name
        self.print_time = print_time
    def __enter__(self):
        self.t = time.time()
        return self
def __exit__(self, exc_type, exc_val, exc_tb):
Timer.timer_map[self.name] += time.time() - self.t
if self.print_time:
print(self.name, Timer.timer_map[self.name])
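# Usage sketch; `model` and `x` are placeholders, not names from this module:
#   >>> with Timer('forward', print_time=True):
#   ...     y = model(x)
# Elapsed seconds accumulate across uses of the same name in Timer.timer_map['forward'].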
def print_arch(model, model_name='model'):
print(f"| {model_name} Arch: ", model)
# num_params(model, model_name=model_name)
def num_params(model, print_out=True, model_name="model"):
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum(np.prod(p.size()) for p in parameters) / 1_000_000
    if print_out:
        print(f'| {model_name} Trainable Parameters: {parameters:.3f}M')
    return parameters
def build_object_from_class_name(cls_str, parent_cls, *args, **kwargs):
    """Import the class named by the dotted path `cls_str` and instantiate it,
    passing only the kwargs its constructor accepts."""
    import importlib
    pkg, cls_name = cls_str.rsplit(".", 1)
    cls_type = getattr(importlib.import_module(pkg), cls_name)
    if parent_cls is not None:
        assert issubclass(cls_type, parent_cls), f'| {cls_type} is not a subclass of {parent_cls}.'
    return cls_type(*args, **filter_kwargs(kwargs, cls_type))
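# Example with a real PyTorch class; the extra key is made up to show the filtering,
# and `model` is a placeholder:
#   >>> opt = build_object_from_class_name(
#   ...     'torch.optim.AdamW', torch.optim.Optimizer,
#   ...     model.parameters(), lr=1e-4, not_a_real_kwarg=123)
# `not_a_real_kwarg` is silently dropped by filter_kwargs() before AdamW is constructed.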
def build_lr_scheduler_from_config(optimizer, scheduler_args):
    """Recursively resolve `scheduler_args` (any nested dict with a 'cls' key becomes
    a scheduler bound to `optimizer`) and build the top-level scheduler."""
    try:
        # PyTorch 2.0+
        from torch.optim.lr_scheduler import LRScheduler
except ImportError:
# PyTorch 1.X
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
def helper(params):
if isinstance(params, list):
return [helper(s) for s in params]
elif isinstance(params, dict):
resolved = {k: helper(v) for k, v in params.items()}
if 'cls' in resolved:
if (
resolved["cls"] == "torch.optim.lr_scheduler.ChainedScheduler"
and scheduler_args["scheduler_cls"] == "torch.optim.lr_scheduler.SequentialLR"
):
                    raise ValueError("ChainedScheduler cannot be part of a SequentialLR.")
resolved['optimizer'] = optimizer
obj = build_object_from_class_name(
resolved['cls'],
LRScheduler,
**resolved
)
return obj
return resolved
else:
return params
resolved = helper(scheduler_args)
resolved['optimizer'] = optimizer
return build_object_from_class_name(
scheduler_args['scheduler_cls'],
LRScheduler,
**resolved
)
def simulate_lr_scheduler(optimizer_args, scheduler_args, step_count, num_param_groups=1):
    """Build a throwaway optimizer/scheduler pair from the given configs, step the
    scheduler `step_count` times, and return the resulting scheduler state_dict."""
    optimizer = build_object_from_class_name(
        optimizer_args['optimizer_cls'],
        torch.optim.Optimizer,
        [{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)],
        **optimizer_args
    )
    scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args)
    # mark the optimizer as having already stepped so the scheduler does not warn
    # about scheduler.step() being called before optimizer.step()
    scheduler.optimizer._step_count = 1
for _ in range(step_count):
scheduler.step()
return scheduler.state_dict()
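# Sketch of simulating a schedule offline; optimizer and scheduler settings are made up:
#   >>> sd = simulate_lr_scheduler(
#   ...     {'optimizer_cls': 'torch.optim.AdamW', 'lr': 1e-4},
#   ...     {'scheduler_cls': 'torch.optim.lr_scheduler.StepLR', 'step_size': 100, 'gamma': 0.5},
#   ...     step_count=250)
# The returned state_dict reflects the learning-rate schedule after 250 steps.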
def remove_suffix(string: str, suffix: str):
    # Just for Python 3.8 compatibility, since the `str.removesuffix()` API is only available since Python 3.9
    if suffix and string.endswith(suffix):
        string = string[:-len(suffix)]
    return string
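# Example:
#   >>> remove_suffix('model_ckpt_steps_1000.ckpt', '.ckpt')
#   'model_ckpt_steps_1000'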