# coding=utf8
import numpy as np
import torch
from torch import Tensor
from typing import Tuple, Dict

from utils.hparams import hparams
from utils.infer_utils import resample_align_curve


class BaseSVSInfer:
""" | |
Base class for SVS inference models. | |
Subclasses should define: | |
1. *build_model*: | |
how to build the model; | |
2. *run_model*: | |
how to run the model (typically, generate a mel-spectrogram and | |
pass it to the pre-built vocoder); | |
3. *preprocess_input*: | |
how to preprocess user input. | |
4. *infer_once* | |
infer from raw inputs to the final outputs | |
""" | |

    def __init__(self, device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
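        # Duration of one frame in seconds (hop_size / audio_sample_rate);
        # the actual values come from the loaded hparams configuration.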
        self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
        self.spk_map = {}
        self.model: torch.nn.Module = None

    def build_model(self, ckpt_steps=None) -> torch.nn.Module:
        raise NotImplementedError()

    def load_speaker_mix(self, param_src: dict, summary_dst: dict,
                         mix_mode: str = 'frame', mix_length: int = None) -> Tuple[Tensor, Tensor]:
        """
        Load a speaker mix from an input parameter dict and write a printable summary of it.
        :param param_src: source parameter dict of one input segment
        :param summary_dst: destination dict receiving a summary of the mix
        :param mix_mode: 'token' or 'frame', the granularity of the mix curve
        :param mix_length: total number of tokens or frames to mix
        :return: spk_mix_id [B=1, 1, N], spk_mix_value [B=1, T, N]
        """
        assert mix_mode == 'token' or mix_mode == 'frame'
        param_key = 'spk_mix' if mix_mode == 'frame' else 'ph_spk_mix'
        summary_solo_key = 'spk' if mix_mode == 'frame' else 'ph_spk'
        spk_mix_map = param_src.get(param_key)  # { spk_name: value } or { spk_name: "value value value ..." }
        dynamic = False
        if spk_mix_map is None:
            # Get the first speaker
            for name in self.spk_map.keys():
                spk_mix_map = {name: 1.0}
                break
        else:
            for name in spk_mix_map:
                assert name in self.spk_map, f'Speaker \'{name}\' not found.'

        if len(spk_mix_map) == 1:
            summary_dst[summary_solo_key] = list(spk_mix_map.keys())[0]
        elif any(isinstance(val, str) for val in spk_mix_map.values()):
            print_mix = '|'.join(spk_mix_map.keys())
            summary_dst[param_key] = f'dynamic({print_mix})'
            dynamic = True
        else:
            print_mix = '|'.join([f'{n}:{spk_mix_map[n]:.3f}' for n in spk_mix_map])
            summary_dst[param_key] = f'static({print_mix})'
        spk_mix_id_list = []
        spk_mix_value_list = []
        if dynamic:
            for name, values in spk_mix_map.items():
                spk_mix_id_list.append(self.spk_map[name])
                if isinstance(values, str):
                    # this speaker has a variable proportion
                    if mix_mode == 'token':
                        cur_spk_mix_value = values.split()
                        assert len(cur_spk_mix_value) == mix_length, \
                            'Speaker mix checks failed. In dynamic token-level mix, ' \
                            'number of proportion values must equal number of tokens.'
                        cur_spk_mix_value = torch.from_numpy(
                            np.array(cur_spk_mix_value, 'float32')
                        ).to(self.device)[None]  # => [B=1, T]
                    else:
                        cur_spk_mix_value = torch.from_numpy(resample_align_curve(
                            np.array(values.split(), 'float32'),
                            original_timestep=float(param_src['spk_mix_timestep']),
                            target_timestep=self.timestep,
                            align_length=mix_length
                        )).to(self.device)[None]  # => [B=1, T]
                    assert torch.all(cur_spk_mix_value >= 0.), \
                        f'Speaker mix checks failed.\n' \
                        f'Proportions of speaker \'{name}\' on some {mix_mode}s are negative.'
                else:
                    # this speaker has a constant proportion
                    assert values >= 0., f'Speaker mix checks failed.\n' \
                                         f'Proportion of speaker \'{name}\' is negative.'
                    cur_spk_mix_value = torch.full(
                        (1, mix_length), fill_value=values,
                        dtype=torch.float32, device=self.device
                    )
                spk_mix_value_list.append(cur_spk_mix_value)
            spk_mix_id = torch.LongTensor(spk_mix_id_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value = torch.stack(spk_mix_value_list, dim=2)  # [B=1, T] => [B=1, T, N]
            spk_mix_value_sum = torch.sum(spk_mix_value, dim=2, keepdim=True)  # => [B=1, T, 1]
            assert torch.all(spk_mix_value_sum > 0.), \
                f'Speaker mix checks failed.\n' \
                f'Proportions of speaker mix on some frames sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize
        else:
            for name, value in spk_mix_map.items():
                spk_mix_id_list.append(self.spk_map[name])
                assert value >= 0., f'Speaker mix checks failed.\n' \
                                    f'Proportion of speaker \'{name}\' is negative.'
                spk_mix_value_list.append(value)
            spk_mix_id = torch.LongTensor(spk_mix_id_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value = torch.FloatTensor(spk_mix_value_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value_sum = spk_mix_value.sum()
            assert spk_mix_value_sum > 0., f'Speaker mix checks failed.\n' \
                                           f'Proportions of speaker mix sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize

        return spk_mix_id, spk_mix_value
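
    # Illustrative (hypothetical) use of the returned tensors: given a speaker
    # embedding layer `spk_embed_layer`, a lookup on spk_mix_id gives [B=1, 1, N, H];
    # broadcasting against spk_mix_value[..., None] ([B=1, T, N, 1]) and summing
    # over the speaker axis yields a per-frame mixed speaker embedding:
    #
    #     spk_embed = torch.sum(
    #         spk_embed_layer(spk_mix_id) * spk_mix_value[:, :, :, None], dim=2
    #     )  # => [B=1, T, H]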

    def preprocess_input(self, param: dict, idx=0) -> Dict[str, torch.Tensor]:
        raise NotImplementedError()

    def forward_model(self, sample: Dict[str, torch.Tensor]):
        raise NotImplementedError()

    def run_inference(self, params, **kwargs):
        raise NotImplementedError()
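

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the original API: a minimal subclass
# showing how the four hooks above are expected to fit together. The class
# name and all method bodies below are hypothetical placeholders.
class _ExampleSVSInfer(BaseSVSInfer):
    def build_model(self, ckpt_steps=None) -> torch.nn.Module:
        # A real implementation would construct the network and load a checkpoint.
        return torch.nn.Identity().eval().to(self.device)

    def preprocess_input(self, param: dict, idx=0) -> Dict[str, torch.Tensor]:
        # A real implementation would encode phonemes, durations, pitch, etc.
        return {'tokens': torch.zeros(1, 1, dtype=torch.long, device=self.device)}

    def forward_model(self, sample: Dict[str, torch.Tensor]):
        # A real implementation would generate a mel-spectrogram and run the vocoder.
        return self.model(sample['tokens'])

    def run_inference(self, params, **kwargs):
        self.model = self.build_model()
        return [self.forward_model(self.preprocess_input(param, idx=i))
                for i, param in enumerate(params)]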