import os import torch import numpy as np import soundfile as sf def fast_cosine_dist(source_feats: torch.Tensor, matching_pool: torch.Tensor, device='cpu'): """ Computes the cosine distance between source features and a matching pool of features. Like torch.cdist, but fixed dim=-1 and for cosine distance. Args: source_feats (torch.Tensor): Tensor of source features with shape (n_source_feats, feat_dim). matching_pool (torch.Tensor): Tensor of matching pool features with shape (n_matching_feats, feat_dim). device (str, optional): Device to perform the computation on. Defaults to 'cpu'. Returns: torch.Tensor: Tensor of cosine distances between the source features and the matching pool features. """ source_feats = source_feats.to(device) matching_pool = matching_pool.to(device) source_norms = torch.norm(source_feats, p=2, dim=-1) matching_norms = torch.norm(matching_pool, p=2, dim=-1) dotprod = -torch.cdist(source_feats[None].to(device), matching_pool[None], p=2)[0]**2 + source_norms[:, None]**2 + matching_norms[None]**2 dotprod /= 2 dists = 1 - ( dotprod / (source_norms[:, None] * matching_norms[None]) ) return dists def load_wav(wav_path, sr=None): """ Loads a waveform from a wav file. Args: wav_path (str): Path to the wav file. sr (int, optional): Target sample rate. If `sr` is specified and the loaded audio has a different sample rate, an AssertionError is raised. Defaults to None. Returns: Tuple[np.ndarray, int]: Tuple containing the loaded waveform as a NumPy array and the sample rate. """ wav, fs = sf.read(wav_path) if wav.ndim != 1: print('The wav file %s has %d channels, select the first one to proceed.' %(wav_path, wav.ndim)) wav = wav[:,0] assert sr is None or fs == sr, f'{sr} kHz audio is required. Got {fs}' peak = np.abs(wav).max() if peak > 1.0: wav /= peak return wav, fs class ConfigWrapper(object): """ Wrapper dict class to avoid annoying key dict indexing like: `config.sample_rate` instead of `config["sample_rate"]`. """ def __init__(self, **kwargs): for k, v in kwargs.items(): if type(v) == dict: v = ConfigWrapper(**v) self[k] = v def keys(self): return self.__dict__.keys() def items(self): return self.__dict__.items() def values(self): return self.__dict__.values() def to_dict_type(self): return { key: (value if not isinstance(value, ConfigWrapper) else value.to_dict_type()) for key, value in dict(**self).items() } def __len__(self): return len(self.__dict__) def __getitem__(self, key): return getattr(self, key) def __setitem__(self, key, value): return setattr(self, key, value) def __contains__(self, key): return key in self.__dict__ def __repr__(self): return self.__dict__.__repr__() def save_checkpoint(steps, epochs, model, optimizer, scheduler, checkpoint_path, dst_train=False): """Save checkpoint. Args: checkpoint_path (str): Checkpoint path to be saved. """ state_dict = { "optimizer": { "generator": optimizer["generator"].state_dict(), "discriminator": optimizer["discriminator"].state_dict(), }, "scheduler": { "generator": scheduler["generator"].state_dict(), "discriminator": scheduler["discriminator"].state_dict(), }, "steps": steps, "epochs": epochs, } if dst_train: state_dict["model"] = { "generator": model["generator"].module.state_dict(), "discriminator": model["discriminator"].module.state_dict(), } else: state_dict["model"] = { "generator": model["generator"].state_dict(), "discriminator": model["discriminator"].state_dict(), } if not os.path.exists(os.path.dirname(checkpoint_path)): os.makedirs(os.path.dirname(checkpoint_path)) torch.save(state_dict, checkpoint_path) def load_checkpoint(model, optimizer, scheduler, checkpoint_path, load_only_params=False, dst_train=False): """Load checkpoint. Args: checkpoint_path (str): Checkpoint path to be loaded. load_only_params (bool): Whether to load only model parameters. """ state_dict = torch.load(checkpoint_path, map_location="cpu") if dst_train: model["generator"].module.load_state_dict( state_dict["model"]["generator"] ) model["discriminator"].module.load_state_dict( state_dict["model"]["discriminator"] ) else: model["generator"].load_state_dict(state_dict["model"]["generator"]) model["discriminator"].load_state_dict( state_dict["model"]["discriminator"] ) optimizer["generator"].load_state_dict( state_dict["optimizer"]["generator"] ) optimizer["discriminator"].load_state_dict( state_dict["optimizer"]["discriminator"] ) scheduler["generator"].load_state_dict( state_dict["scheduler"]["generator"] ) scheduler["discriminator"].load_state_dict( state_dict["scheduler"]["discriminator"] ) steps = state_dict["steps"] epochs = state_dict["epochs"] return steps, epochs