# coding=utf8
import numpy as np
import torch
from torch import Tensor
from typing import Dict, Optional, Tuple

from utils.hparams import hparams
from utils.infer_utils import resample_align_curve


class BaseSVSInfer:
    """
    Base class for SVS inference models.
    Subclasses should define:
    1. *build_model*: how to build the model;
    2. *run_model*: how to run the model (typically, generate a mel-spectrogram and
       pass it to the pre-built vocoder);
    3. *preprocess_input*: how to preprocess user input.
    4. *infer_once*: infer from raw inputs to the final outputs.
    """

    def __init__(self, device=None):
        # Default to GPU when available; callers may pin a device explicitly.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        # Duration of one mel frame (hop) in seconds.
        self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
        # speaker name -> integer speaker id; populated by subclasses.
        self.spk_map: Dict[str, int] = {}
        # Built lazily by subclasses via build_model(); None until then.
        self.model: Optional[torch.nn.Module] = None

    def build_model(self, ckpt_steps=None) -> torch.nn.Module:
        raise NotImplementedError()

    def load_speaker_mix(self, param_src: dict, summary_dst: dict,
                         mix_mode: str = 'frame', mix_length: Optional[int] = None) -> Tuple[Tensor, Tensor]:
        """
        Load and validate a (possibly dynamic) speaker mix from user parameters.

        A "static" mix maps each speaker to a constant proportion; a "dynamic"
        mix gives at least one speaker a whitespace-separated string of
        per-token or per-frame proportions. Proportions are validated to be
        non-negative and normalized so they sum to 1 at every position.

        :param param_src: param dict (read; 'spk_mix'/'ph_spk_mix' and, for
            dynamic frame-level mixes, 'spk_mix_timestep' are consulted)
        :param summary_dst: summary dict (updated in place with a human-readable
            description of the chosen mix)
        :param mix_mode: 'token' or 'frame'
        :param mix_length: total tokens or frames to mix
        :return: spk_mix_id [B=1, 1, N], spk_mix_value [B=1, T, N]
            (T == 1 when the mix is fully static)
        """
        assert mix_mode == 'token' or mix_mode == 'frame'
        param_key = 'spk_mix' if mix_mode == 'frame' else 'ph_spk_mix'
        summary_solo_key = 'spk' if mix_mode == 'frame' else 'ph_spk'
        spk_mix_map = param_src.get(param_key)  # { spk_name: value } or { spk_name: "value value value ..." }
        dynamic = False
        if spk_mix_map is None:
            # No mix specified: fall back to the first speaker in the map.
            # Guard against an empty speaker map so the failure is explicit
            # instead of an AttributeError further down.
            assert len(self.spk_map) > 0, 'Speaker map is empty.'
            spk_mix_map = {next(iter(self.spk_map)): 1.0}
        else:
            for name in spk_mix_map:
                assert name in self.spk_map, f'Speaker \'{name}\' not found.'

        if len(spk_mix_map) == 1:
            summary_dst[summary_solo_key] = next(iter(spk_mix_map))
        elif any(isinstance(val, str) for val in spk_mix_map.values()):
            # At least one speaker carries a proportion curve => dynamic mix.
            print_mix = '|'.join(spk_mix_map.keys())
            summary_dst[param_key] = f'dynamic({print_mix})'
            dynamic = True
        else:
            print_mix = '|'.join(f'{n}:{spk_mix_map[n]:.3f}' for n in spk_mix_map)
            summary_dst[param_key] = f'static({print_mix})'

        spk_mix_id_list = []
        spk_mix_value_list = []
        if dynamic:
            for name, values in spk_mix_map.items():
                spk_mix_id_list.append(self.spk_map[name])
                if isinstance(values, str):
                    # this speaker has a variable proportion
                    if mix_mode == 'token':
                        cur_spk_mix_value = values.split()
                        assert len(cur_spk_mix_value) == mix_length, \
                            'Speaker mix checks failed. In dynamic token-level mix, ' \
                            'number of proportion values must equal number of tokens.'
                        cur_spk_mix_value = torch.from_numpy(
                            np.array(cur_spk_mix_value, dtype=np.float32)
                        ).to(self.device)[None]  # => [B=1, T]
                    else:
                        # Frame-level: resample the user curve from its own
                        # timestep onto the frame timestep and align lengths.
                        cur_spk_mix_value = torch.from_numpy(resample_align_curve(
                            np.array(values.split(), dtype=np.float32),
                            original_timestep=float(param_src['spk_mix_timestep']),
                            target_timestep=self.timestep,
                            align_length=mix_length
                        )).to(self.device)[None]  # => [B=1, T]
                    assert torch.all(cur_spk_mix_value >= 0.), \
                        f'Speaker mix checks failed.\n' \
                        f'Proportions of speaker \'{name}\' on some {mix_mode}s are negative.'
                else:
                    # this speaker has a constant proportion
                    assert values >= 0., f'Speaker mix checks failed.\n' \
                                         f'Proportion of speaker \'{name}\' is negative.'
                    cur_spk_mix_value = torch.full(
                        (1, mix_length), fill_value=values,
                        dtype=torch.float32, device=self.device
                    )
                spk_mix_value_list.append(cur_spk_mix_value)
            spk_mix_id = torch.LongTensor(spk_mix_id_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value = torch.stack(spk_mix_value_list, dim=2)  # [B=1, T] => [B=1, T, N]
            spk_mix_value_sum = torch.sum(spk_mix_value, dim=2, keepdim=True)  # => [B=1, T, 1]
            assert torch.all(spk_mix_value_sum > 0.), \
                f'Speaker mix checks failed.\n' \
                f'Proportions of speaker mix on some frames sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize
        else:
            for name, value in spk_mix_map.items():
                spk_mix_id_list.append(self.spk_map[name])
                assert value >= 0., f'Speaker mix checks failed.\n' \
                                    f'Proportion of speaker \'{name}\' is negative.'
                spk_mix_value_list.append(value)
            spk_mix_id = torch.LongTensor(spk_mix_id_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value = torch.FloatTensor(spk_mix_value_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value_sum = spk_mix_value.sum()
            assert spk_mix_value_sum > 0., f'Speaker mix checks failed.\n' \
                                           f'Proportions of speaker mix sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize
        return spk_mix_id, spk_mix_value

    def preprocess_input(self, param: dict, idx=0) -> Dict[str, torch.Tensor]:
        raise NotImplementedError()

    def forward_model(self, sample: Dict[str, torch.Tensor]):
        raise NotImplementedError()

    def run_inference(self, params, **kwargs):
        raise NotImplementedError()