Spaces:
Configuration error
Configuration error
File size: 5,515 Bytes
cfdc687 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import os
import torch
import numpy as np
import soundfile as sf
def fast_cosine_dist(source_feats: torch.Tensor, matching_pool: torch.Tensor, device='cpu', eps=1e-8):
    """
    Computes the pairwise cosine distance between source features and a matching pool.

    Like torch.cdist, but with the feature dim fixed to -1 and cosine distance
    (1 - cosine similarity) instead of a p-norm distance.

    Args:
        source_feats (torch.Tensor): Tensor of source features with shape (n_source_feats, feat_dim).
        matching_pool (torch.Tensor): Tensor of matching pool features with shape (n_matching_feats, feat_dim).
        device (str, optional): Device to perform the computation on. Defaults to 'cpu'.
        eps (float, optional): Lower bound applied to every vector norm so that an
            all-zero vector yields a finite distance (1.0) instead of NaN. Defaults to 1e-8.

    Returns:
        torch.Tensor: Shape (n_source_feats, n_matching_feats). Entry [i, j] is
        1 - cos(source_feats[i], matching_pool[j]), which lies in [0, 2] up to rounding.
    """
    source_feats = source_feats.to(device)
    matching_pool = matching_pool.to(device)
    # Take the dot products directly. Recovering them from
    # ||a||^2 + ||b||^2 - ||a - b||^2 cancels catastrophically when a ~= b.
    dotprod = source_feats @ matching_pool.T
    # Clamp the norms so a zero vector gives 0 / eps = 0 similarity, not 0 / 0.
    source_norms = torch.norm(source_feats, p=2, dim=-1).clamp_min(eps)
    matching_norms = torch.norm(matching_pool, p=2, dim=-1).clamp_min(eps)
    dists = 1 - dotprod / (source_norms[:, None] * matching_norms[None])
    return dists
def load_wav(wav_path, sr=None):
    """
    Loads a waveform from an audio file as a mono NumPy array.

    Multi-channel audio is reduced to its first channel. If the absolute peak of
    the signal exceeds 1.0, the waveform is rescaled so that its peak is 1.0.

    Args:
        wav_path (str): Path to the wav file.
        sr (int, optional): Required sample rate in Hz. If given and the file's
            sample rate differs, an AssertionError is raised. Defaults to None.

    Returns:
        Tuple[np.ndarray, int]: Tuple containing the loaded waveform as a NumPy array and the sample rate.

    Raises:
        AssertionError: If `sr` is given and does not match the file's sample rate.
    """
    wav, fs = sf.read(wav_path)
    if wav.ndim != 1:
        # sf.read returns (frames, channels) for multi-channel audio, so the
        # channel count is shape[1], not ndim (which is always 2 here).
        print('The wav file %s has %d channels, select the first one to proceed.' % (wav_path, wav.shape[1]))
        wav = wav[:, 0]
    # Raised explicitly rather than via `assert` so the check survives `python -O`.
    # AssertionError is kept because it is the documented exception type.
    if sr is not None and fs != sr:
        raise AssertionError(f'{sr} Hz audio is required. Got {fs}')
    # An empty file yields a zero-length array, on which `.max()` would raise.
    if wav.size > 0:
        peak = np.abs(wav).max()
        if peak > 1.0:
            wav = wav / peak
    return wav, fs
class ConfigWrapper(object):
    """
    Attribute-access wrapper around a (possibly nested) dict.

    Allows `config.sample_rate` instead of `config["sample_rate"]`; item access
    works too. Nested dicts (including dict subclasses) are wrapped recursively
    at construction time.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance (not `type(v) == dict`) so that dict subclasses such as
            # OrderedDict are wrapped as well.
            if isinstance(v, dict):
                v = ConfigWrapper(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def to_dict_type(self):
        """Recursively convert this wrapper (and nested wrappers) back to plain dicts."""
        # Read __dict__ directly so that a config key named e.g. "keys" or "items"
        # (which would shadow the methods) cannot break the conversion.
        return {
            key: value.to_dict_type() if isinstance(value, ConfigWrapper) else value
            for key, value in self.__dict__.items()
        }

    def __len__(self):
        return len(self.__dict__)

    def __iter__(self):
        # Iterate over keys like a dict. Without this, `for k in config` falls
        # back to __getitem__(0), __getitem__(1), ... and raises TypeError.
        return iter(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
def save_checkpoint(steps, epochs, model, optimizer, scheduler, checkpoint_path, dst_train=False):
    """Save generator/discriminator training state to a single checkpoint file.

    Args:
        steps (int): Training step counter to store.
        epochs (int): Epoch counter to store.
        model (dict): Maps "generator" and "discriminator" to modules.
        optimizer (dict): Maps "generator" and "discriminator" to optimizers.
        scheduler (dict): Maps "generator" and "discriminator" to LR schedulers.
        checkpoint_path (str): Checkpoint path to be saved. Missing parent
            directories are created.
        dst_train (bool): If True, the state dict of each model's `.module`
            attribute is saved instead of the model itself (presumably for
            DistributedDataParallel-style wrappers; confirm against the caller).
    """
    names = ("generator", "discriminator")
    state_dict = {
        "optimizer": {name: optimizer[name].state_dict() for name in names},
        "scheduler": {name: scheduler[name].state_dict() for name in names},
        "steps": steps,
        "epochs": epochs,
    }
    state_dict["model"] = {
        name: (model[name].module if dst_train else model[name]).state_dict()
        for name in names
    }
    checkpoint_dir = os.path.dirname(checkpoint_path)
    # dirname is "" for a bare filename, and os.makedirs("") raises, so skip it.
    # exist_ok avoids the race between an exists() check and makedirs().
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(state_dict, checkpoint_path)
def load_checkpoint(model, optimizer, scheduler, checkpoint_path, load_only_params=False, dst_train=False):
    """Load generator/discriminator training state from a checkpoint file.

    Args:
        model (dict): Maps "generator" and "discriminator" to modules whose
            parameters are restored in place.
        optimizer (dict): Maps "generator" and "discriminator" to optimizers.
            Not accessed (and may be None) when `load_only_params` is True.
        scheduler (dict): Maps "generator" and "discriminator" to LR schedulers.
            Not accessed (and may be None) when `load_only_params` is True.
        checkpoint_path (str): Checkpoint path to be loaded.
        load_only_params (bool): Whether to load only model parameters. If True,
            optimizer and scheduler states are left untouched.
        dst_train (bool): If True, load into each model's `.module` attribute.

    Returns:
        Tuple[int, int]: The stored (steps, epochs). These are returned
        regardless of `load_only_params`.
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints from
    # trusted sources.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    names = ("generator", "discriminator")
    for name in names:
        net = model[name].module if dst_train else model[name]
        net.load_state_dict(state_dict["model"][name])
    if not load_only_params:
        for name in names:
            optimizer[name].load_state_dict(state_dict["optimizer"][name])
        for name in names:
            scheduler[name].load_state_dict(state_dict["scheduler"][name])
    return state_dict["steps"], state_dict["epochs"]
|