File size: 6,329 Bytes
c42fe7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# coding=utf8
import numpy as np
import torch
from torch import Tensor
from typing import Tuple, Dict

from utils.hparams import hparams
from utils.infer_utils import resample_align_curve


class BaseSVSInfer:
    """
        Base class for SVS inference models.
        Subclasses should override the following methods, which only raise
        NotImplementedError here:
        1. *build_model*:
            how to build the model (returns a torch.nn.Module);
        2. *preprocess_input*:
            how to preprocess user input into a dict of tensors;
        3. *forward_model*:
            how to run the model on a preprocessed sample (typically, generate
            a mel-spectrogram and pass it to the pre-built vocoder);
        4. *run_inference*:
            infer from raw inputs to the final outputs.
        Shared helper provided by this class: *load_speaker_mix*.
    """

    def __init__(self, device=None):
        """
        :param device: torch device to run on; defaults to 'cuda' when
            available, otherwise 'cpu'
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        # Duration of one frame: hop_size / audio_sample_rate (global hparams).
        self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
        # Speaker name -> integer speaker id. Empty here; presumably populated
        # by subclasses (load_speaker_mix looks speakers up in it).
        self.spk_map = {}
        # Model instance; None here, presumably assigned by subclasses
        # (e.g. from build_model) - confirm in subclass.
        self.model: torch.nn.Module = None

    def build_model(self, ckpt_steps=None) -> torch.nn.Module:
        """Build and return the model. Must be overridden by subclasses."""
        raise NotImplementedError()

    def load_speaker_mix(self, param_src: dict, summary_dst: dict,
                         mix_mode: str = 'frame', mix_length: int = None) -> Tuple[Tensor, Tensor]:
        """
        Build speaker-mix id/proportion tensors from user parameters.

        The mix is read from ``param_src['spk_mix']`` (frame mode) or
        ``param_src['ph_spk_mix']`` (token mode) as ``{speaker_name: proportion}``,
        where a proportion is either a number (constant) or a whitespace-separated
        string of numbers (a curve). When the key is absent, the first speaker in
        ``self.spk_map`` is used with proportion 1.0. If any proportion is a curve
        the mix is dynamic; proportions are normalized to sum to 1 on every
        token/frame (dynamic) or overall (static).

        In frame mode, curve strings are resampled from
        ``param_src['spk_mix_timestep']`` to ``self.timestep`` and aligned to
        ``mix_length`` via ``resample_align_curve``.

        :param param_src: param dict
        :param summary_dst: summary dict; receives the single speaker name under
            'spk'/'ph_spk', or a 'dynamic(...)'/'static(...)' description under
            'spk_mix'/'ph_spk_mix'
        :param mix_mode: 'token' or 'frame'
        :param mix_length: total tokens or frames to mix (used by dynamic mixes)
        :return: spk_mix_id [B=1, 1, N], spk_mix_value [B=1, T, N]
            (T == 1 for a static mix)
        :raises AssertionError: on an invalid mix_mode, unknown speaker, negative
            proportion, proportions summing to zero, token-count mismatch, or
            when no speaker is available at all
        """
        assert mix_mode in ('token', 'frame'), \
            f'mix_mode must be \'token\' or \'frame\', got {mix_mode!r}.'
        param_key = 'spk_mix' if mix_mode == 'frame' else 'ph_spk_mix'
        summary_solo_key = 'spk' if mix_mode == 'frame' else 'ph_spk'
        spk_mix_map = param_src.get(param_key)  # { spk_name: value } or { spk_name: "value value value ..." }
        if spk_mix_map is None:
            # Fall back to the first registered speaker at full proportion.
            first_spk = next(iter(self.spk_map), None)
            assert first_spk is not None, \
                'Speaker mix checks failed. No speakers are available.'
            spk_mix_map = {first_spk: 1.0}
        else:
            for name in spk_mix_map:
                assert name in self.spk_map, f'Speaker \'{name}\' not found.'

        # Dynamic-ness depends only on the value types. A lone speaker given a
        # curve string must still be parsed as a curve; the static branch would
        # otherwise compare a str against 0. and raise TypeError.
        dynamic = any(isinstance(val, str) for val in spk_mix_map.values())
        if len(spk_mix_map) == 1:
            summary_dst[summary_solo_key] = next(iter(spk_mix_map))
        elif dynamic:
            print_mix = '|'.join(spk_mix_map.keys())
            summary_dst[param_key] = f'dynamic({print_mix})'
        else:
            print_mix = '|'.join([f'{n}:{spk_mix_map[n]:.3f}' for n in spk_mix_map])
            summary_dst[param_key] = f'static({print_mix})'

        # All names were validated above, so every lookup succeeds.
        spk_mix_id = torch.tensor(
            [self.spk_map[name] for name in spk_mix_map],
            dtype=torch.long, device=self.device
        )[None, None]  # => [B=1, 1, N]

        if dynamic:
            spk_mix_value_list = []
            for name, values in spk_mix_map.items():
                if isinstance(values, str):
                    # this speaker has a variable proportion
                    if mix_mode == 'token':
                        cur_spk_mix_value = values.split()
                        assert len(cur_spk_mix_value) == mix_length, \
                            'Speaker mix checks failed. In dynamic token-level mix, ' \
                            'number of proportion values must equal number of tokens.'
                        cur_spk_mix_value = torch.from_numpy(
                            np.array(cur_spk_mix_value, 'float32')
                        ).to(self.device)[None]  # => [B=1, T]
                    else:
                        cur_spk_mix_value = torch.from_numpy(resample_align_curve(
                            np.array(values.split(), 'float32'),
                            original_timestep=float(param_src['spk_mix_timestep']),
                            target_timestep=self.timestep,
                            align_length=mix_length
                        )).to(self.device)[None]  # => [B=1, T]
                    assert torch.all(cur_spk_mix_value >= 0.), \
                        f'Speaker mix checks failed.\n' \
                        f'Proportions of speaker \'{name}\' on some {mix_mode}s are negative.'
                else:
                    # this speaker has a constant proportion, broadcast over all
                    # tokens/frames so it can be stacked with the curves
                    assert values >= 0., f'Speaker mix checks failed.\n' \
                                         f'Proportion of speaker \'{name}\' is negative.'
                    cur_spk_mix_value = torch.full(
                        (1, mix_length), fill_value=values,
                        dtype=torch.float32, device=self.device
                    )
                spk_mix_value_list.append(cur_spk_mix_value)
            spk_mix_value = torch.stack(spk_mix_value_list, dim=2)  # [B=1, T] => [B=1, T, N]
            spk_mix_value_sum = torch.sum(spk_mix_value, dim=2, keepdim=True)  # => [B=1, T, 1]
            assert torch.all(spk_mix_value_sum > 0.), \
                f'Speaker mix checks failed.\n' \
                f'Proportions of speaker mix on some {mix_mode}s sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize per token/frame
        else:
            for name, value in spk_mix_map.items():
                assert value >= 0., f'Speaker mix checks failed.\n' \
                                    f'Proportion of speaker \'{name}\' is negative.'
            spk_mix_value = torch.tensor(
                list(spk_mix_map.values()), dtype=torch.float32, device=self.device
            )[None, None]  # => [B=1, 1, N]
            spk_mix_value_sum = spk_mix_value.sum()
            assert spk_mix_value_sum > 0., f'Speaker mix checks failed.\n' \
                                           f'Proportions of speaker mix sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize
        return spk_mix_id, spk_mix_value

    def preprocess_input(self, param: dict, idx=0) -> Dict[str, torch.Tensor]:
        """Convert one user input into a dict of tensors. Must be overridden."""
        raise NotImplementedError()

    def forward_model(self, sample: Dict[str, torch.Tensor]):
        """Run the model on a preprocessed sample. Must be overridden."""
        raise NotImplementedError()

    def run_inference(self, params, **kwargs):
        """Run end-to-end inference from raw params. Must be overridden."""
        raise NotImplementedError()