Spaces:
Sleeping
Sleeping
File size: 6,329 Bytes
c42fe7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# coding=utf8
import numpy as np
import torch
from torch import Tensor
from typing import Tuple, Dict
from utils.hparams import hparams
from utils.infer_utils import resample_align_curve
class BaseSVSInfer:
    """
    Base class for SVS inference models.

    Subclasses should define:
    1. *build_model*:
        how to build the model;
    2. *preprocess_input*:
        how to preprocess user input;
    3. *forward_model*:
        how to run the model on one preprocessed sample (typically, generate
        a mel-spectrogram and pass it to the pre-built vocoder);
    4. *run_inference*:
        infer from raw inputs to the final outputs.
    """

    def __init__(self, device=None):
        """
        :param device: torch device to place tensors on; when None, 'cuda' is
            used if available, otherwise 'cpu'
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        # Time step between consecutive frames (hop size over sample rate);
        # used as the target resolution when resampling frame-level curves.
        self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
        # Speaker name -> speaker id. Left empty here; presumably filled in
        # by subclasses before load_speaker_mix() is called.
        self.spk_map = {}
        self.model: torch.nn.Module = None

    def build_model(self, ckpt_steps=None) -> torch.nn.Module:
        """Build and return the model. Must be implemented by subclasses."""
        raise NotImplementedError()

    def load_speaker_mix(self, param_src: dict, summary_dst: dict,
                         mix_mode: str = 'frame', mix_length: int = None) -> Tuple[Tensor, Tensor]:
        """
        Read the speaker mix from the input params, validate it and convert
        it to normalized tensors.

        The mix is read from ``param_src['spk_mix']`` in frame mode or
        ``param_src['ph_spk_mix']`` in token mode. It maps each speaker name
        to either a constant proportion (number) or a whitespace-separated
        string of proportions (one per token in token mode; a curve sampled
        every ``param_src['spk_mix_timestep']`` in frame mode, which is
        resampled to ``self.timestep`` and aligned to ``mix_length``).
        When the key is absent, the first speaker in ``self.spk_map`` is used
        with proportion 1.0.

        A short description of the mix is written into ``summary_dst``: the
        speaker name under 'spk' / 'ph_spk' for a single speaker, otherwise
        ``dynamic(...)`` or ``static(...)`` under the same key the mix was
        read from.

        :param param_src: param dict
        :param summary_dst: summary dict (modified in place)
        :param mix_mode: 'token' or 'frame'
        :param mix_length: total tokens or frames to mix
        :return: spk_mix_id [B=1, 1, N], spk_mix_value [B=1, T, N]
            (T is 1 when all proportions are constants)
        :raises AssertionError: on unknown mix mode, unknown speaker, missing
            speakers, negative proportions, or proportions summing to zero
        """
        assert mix_mode == 'token' or mix_mode == 'frame', \
            f'Unknown mix mode \'{mix_mode}\'; expected \'token\' or \'frame\'.'
        frame_level = mix_mode == 'frame'
        param_key = 'spk_mix' if frame_level else 'ph_spk_mix'
        summary_solo_key = 'spk' if frame_level else 'ph_spk'

        # { spk_name: value } or { spk_name: "value value value ..." }
        spk_mix_map = param_src.get(param_key)
        if spk_mix_map is None:
            # No mix requested: fall back to the first known speaker. Fail
            # with a clear message instead of crashing on len(None) below.
            assert len(self.spk_map) > 0, \
                'Speaker mix checks failed.\n' \
                'No speaker mix is specified and no speakers are available.'
            spk_mix_map = {next(iter(self.spk_map)): 1.0}
        else:
            for name in spk_mix_map:
                assert name in self.spk_map, f'Speaker \'{name}\' not found.'

        # Any string value means per-token / per-frame proportions. This is
        # decided independently of the number of speakers so that a single
        # speaker given as a curve is parsed instead of being compared
        # against a float as a raw string.
        dynamic = any(isinstance(val, str) for val in spk_mix_map.values())
        if len(spk_mix_map) == 1:
            summary_dst[summary_solo_key] = next(iter(spk_mix_map))
        elif dynamic:
            print_mix = '|'.join(spk_mix_map.keys())
            summary_dst[param_key] = f'dynamic({print_mix})'
        else:
            # The summary shows the raw (not yet normalized) proportions.
            print_mix = '|'.join(f'{n}:{"%.3f" % v}' for n, v in spk_mix_map.items())
            summary_dst[param_key] = f'static({print_mix})'

        spk_mix_id_list = [self.spk_map[name] for name in spk_mix_map]
        spk_mix_id = torch.LongTensor(spk_mix_id_list).to(self.device)[None, None]  # => [B=1, 1, N]

        if dynamic:
            spk_mix_value_list = []
            for name, values in spk_mix_map.items():
                if isinstance(values, str):
                    # this speaker has a variable proportion
                    if frame_level:
                        curve = resample_align_curve(
                            np.array(values.split(), 'float32'),
                            original_timestep=float(param_src['spk_mix_timestep']),
                            target_timestep=self.timestep,
                            align_length=mix_length
                        )
                    else:
                        tokens = values.split()
                        assert len(tokens) == mix_length, \
                            'Speaker mix checks failed. In dynamic token-level mix, ' \
                            'number of proportion values must equal number of tokens.'
                        curve = np.array(tokens, 'float32')
                    cur_spk_mix_value = torch.from_numpy(curve).to(self.device)[None]  # => [B=1, T]
                    assert torch.all(cur_spk_mix_value >= 0.), \
                        f'Speaker mix checks failed.\n' \
                        f'Proportions of speaker \'{name}\' on some {mix_mode}s are negative.'
                else:
                    # this speaker has a constant proportion
                    assert values >= 0., \
                        f'Speaker mix checks failed.\n' \
                        f'Proportion of speaker \'{name}\' is negative.'
                    cur_spk_mix_value = torch.full(
                        (1, mix_length), fill_value=values,
                        dtype=torch.float32, device=self.device
                    )
                spk_mix_value_list.append(cur_spk_mix_value)
            spk_mix_value = torch.stack(spk_mix_value_list, dim=2)  # [B=1, T] => [B=1, T, N]
            spk_mix_value_sum = torch.sum(spk_mix_value, dim=2, keepdim=True)  # => [B=1, T, 1]
            assert torch.all(spk_mix_value_sum > 0.), \
                'Speaker mix checks failed.\n' \
                'Proportions of speaker mix on some frames sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize per step
        else:
            spk_mix_value_list = []
            for name, value in spk_mix_map.items():
                assert value >= 0., \
                    f'Speaker mix checks failed.\n' \
                    f'Proportion of speaker \'{name}\' is negative.'
                spk_mix_value_list.append(value)
            spk_mix_value = torch.FloatTensor(spk_mix_value_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value_sum = spk_mix_value.sum()
            assert spk_mix_value_sum > 0., \
                'Speaker mix checks failed.\n' \
                'Proportions of speaker mix sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize
        return spk_mix_id, spk_mix_value

    def preprocess_input(self, param: dict, idx=0) -> Dict[str, torch.Tensor]:
        """
        Convert one raw param dict into a dict of tensors.
        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def forward_model(self, sample: Dict[str, torch.Tensor]):
        """
        Run the model on a dict of tensors (presumably the output of
        preprocess_input). Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def run_inference(self, params, **kwargs):
        """
        Run inference from raw params to the final outputs.
        Must be implemented by subclasses.
        """
        raise NotImplementedError()
|