# CantusSVS-hf / basics / base_svs_infer.py
# Snapshot note: clean deploy snapshot by liampond (commit c42fe7e)
# coding=utf8
import numpy as np
import torch
from torch import Tensor
from typing import Tuple, Dict
from utils.hparams import hparams
from utils.infer_utils import resample_align_curve
class BaseSVSInfer:
    """
    Base class for SVS inference models.
    Subclasses should define:
    1. *build_model*:
        how to build the model;
    2. *run_model*:
        how to run the model (typically, generate a mel-spectrogram and
        pass it to the pre-built vocoder);
    3. *preprocess_input*:
        how to preprocess user input.
    4. *infer_once*
        infer from raw inputs to the final outputs
    """

    def __init__(self, device=None):
        """
        :param device: torch device spec; defaults to 'cuda' when available, else 'cpu'.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        # Duration of one frame in seconds; used to resample dynamic mix curves
        # onto the model's frame grid in load_speaker_mix().
        self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
        self.spk_map = {}  # speaker name -> integer speaker id
        self.model: torch.nn.Module = None  # built by subclasses via build_model()

    def build_model(self, ckpt_steps=None) -> torch.nn.Module:
        """Build and return the inference model (subclass responsibility)."""
        raise NotImplementedError()

    def load_speaker_mix(self, param_src: dict, summary_dst: dict,
                         mix_mode: str = 'frame', mix_length: int = None) -> Tuple[Tensor, Tensor]:
        """
        Build speaker-mix id and proportion tensors from user parameters.

        :param param_src: param dict
        :param summary_dst: summary dict (updated in place with a readable mix summary)
        :param mix_mode: 'token' or 'frame'
        :param mix_length: total tokens or frames to mix
        :return: spk_mix_id [B=1, 1, N], spk_mix_value [B=1, T, N]
        """
        assert mix_mode == 'token' or mix_mode == 'frame'
        param_key = 'spk_mix' if mix_mode == 'frame' else 'ph_spk_mix'
        summary_solo_key = 'spk' if mix_mode == 'frame' else 'ph_spk'
        spk_mix_map = param_src.get(param_key)  # { spk_name: value } or { spk_name: "value value value ..." }
        dynamic = False
        if spk_mix_map is None:
            # Fall back to the first registered speaker at full proportion.
            assert len(self.spk_map) > 0, 'Speaker mix checks failed. No speakers available.'
            spk_mix_map = {next(iter(self.spk_map)): 1.0}
        else:
            for name in spk_mix_map:
                assert name in self.spk_map, f'Speaker \'{name}\' not found.'

        if len(spk_mix_map) == 1:
            summary_dst[summary_solo_key] = next(iter(spk_mix_map))
        elif any(isinstance(val, str) for val in spk_mix_map.values()):
            # At least one speaker carries a per-token/per-frame proportion curve.
            print_mix = '|'.join(spk_mix_map.keys())
            summary_dst[param_key] = f'dynamic({print_mix})'
            dynamic = True
        else:
            print_mix = '|'.join(f'{n}:{spk_mix_map[n]:.3f}' for n in spk_mix_map)
            summary_dst[param_key] = f'static({print_mix})'

        spk_mix_id_list = []
        spk_mix_value_list = []
        if dynamic:
            for name, values in spk_mix_map.items():
                spk_mix_id_list.append(self.spk_map[name])
                if isinstance(values, str):
                    # This speaker has a variable proportion curve.
                    if mix_mode == 'token':
                        cur_spk_mix_value = values.split()
                        assert len(cur_spk_mix_value) == mix_length, \
                            'Speaker mix checks failed. In dynamic token-level mix, ' \
                            'number of proportion values must equal number of tokens.'
                        cur_spk_mix_value = torch.from_numpy(
                            np.array(cur_spk_mix_value, 'float32')
                        ).to(self.device)[None]  # => [B=1, T]
                    else:
                        # Resample the user-supplied curve onto the model's frame grid.
                        cur_spk_mix_value = torch.from_numpy(resample_align_curve(
                            np.array(values.split(), 'float32'),
                            original_timestep=float(param_src['spk_mix_timestep']),
                            target_timestep=self.timestep,
                            align_length=mix_length
                        )).to(self.device)[None]  # => [B=1, T]
                    assert torch.all(cur_spk_mix_value >= 0.), \
                        f'Speaker mix checks failed.\n' \
                        f'Proportions of speaker \'{name}\' on some {mix_mode}s are negative.'
                else:
                    # This speaker has a constant proportion, broadcast over all steps.
                    assert values >= 0., f'Speaker mix checks failed.\n' \
                                         f'Proportion of speaker \'{name}\' is negative.'
                    cur_spk_mix_value = torch.full(
                        (1, mix_length), fill_value=values,
                        dtype=torch.float32, device=self.device
                    )
                spk_mix_value_list.append(cur_spk_mix_value)
            spk_mix_id = torch.LongTensor(spk_mix_id_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value = torch.stack(spk_mix_value_list, dim=2)  # [B=1, T] => [B=1, T, N]
            spk_mix_value_sum = torch.sum(spk_mix_value, dim=2, keepdim=True)  # => [B=1, T, 1]
            assert torch.all(spk_mix_value_sum > 0.), \
                f'Speaker mix checks failed.\n' \
                f'Proportions of speaker mix on some frames sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize
        else:
            for name, value in spk_mix_map.items():
                spk_mix_id_list.append(self.spk_map[name])
                assert value >= 0., f'Speaker mix checks failed.\n' \
                                    f'Proportion of speaker \'{name}\' is negative.'
                spk_mix_value_list.append(value)
            spk_mix_id = torch.LongTensor(spk_mix_id_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value = torch.FloatTensor(spk_mix_value_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value_sum = spk_mix_value.sum()
            assert spk_mix_value_sum > 0., f'Speaker mix checks failed.\n' \
                                           f'Proportions of speaker mix sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize
        return spk_mix_id, spk_mix_value

    def preprocess_input(self, param: dict, idx=0) -> Dict[str, torch.Tensor]:
        """Convert raw user input into model-ready tensors (subclass responsibility)."""
        raise NotImplementedError()

    def forward_model(self, sample: Dict[str, torch.Tensor]):
        """Run the model on a preprocessed sample (subclass responsibility)."""
        raise NotImplementedError()

    def run_inference(self, params, **kwargs):
        """End-to-end inference entry point (subclass responsibility)."""
        raise NotImplementedError()