import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.commons.common_layers import NormalInitEmbedding as Embedding
from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
from modules.fastspeech.variance_encoder import FastSpeech2Variance
from utils.hparams import hparams
from utils.text_encoder import PAD_INDEX

f0_bin = 256
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)


def f0_to_coarse(f0):
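    """Quantize continuous f0 (Hz) into one of `f0_bin` mel-scale bins.

    Voiced frames map linearly on the mel scale into bins 1..f0_bin - 1;
    the final clip keeps index 0 free for the embedding's padding index.
    """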
    f0_mel = 1127 * (1 + f0 / 700).log()
    a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
    b = f0_mel_min * a - 1.
    f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
    torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
    f0_coarse = torch.round(f0_mel).long()
    return f0_coarse


class LengthRegulator(nn.Module):
    # noinspection PyMethodMayBeStatic
    def forward(self, dur):
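        """Expand per-token durations into a frame-level mel2ph index map.

        dur: [B, T_txt] integer durations per token. Returns [B, T_mel],
        where each frame holds the 1-based index of the token it belongs
        to (0 marks frames beyond a sequence's total duration).
        """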
        token_idx = torch.arange(1, dur.shape[1] + 1, device=dur.device)[None, :, None]
        dur_cumsum = torch.cumsum(dur, dim=1)
        dur_cumsum_prev = F.pad(dur_cumsum, (1, -1), mode='constant', value=0)
        pos_idx = torch.arange(dur.sum(dim=1).max(), device=dur.device)[None, None]
        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
        mel2ph = (token_idx * token_mask).sum(dim=1)
        return mel2ph


class FastSpeech2AcousticONNX(FastSpeech2Acoustic):
    def __init__(self, vocab_size):
        super().__init__(vocab_size=vocab_size)

        # for temporary compatibility; will be completely removed in the future
        self.f0_embed_type = hparams.get('f0_embed_type', 'continuous')
        if self.f0_embed_type == 'discrete':
            self.pitch_embed = Embedding(300, hparams['hidden_size'], PAD_INDEX)

        self.lr = LengthRegulator()
        if hparams['use_key_shift_embed']:
            self.shift_min, self.shift_max = hparams['augmentation_args']['random_pitch_shifting']['range']
        if hparams['use_speed_embed']:
            self.speed_min, self.speed_max = hparams['augmentation_args']['random_time_stretching']['range']

    # noinspection PyMethodOverriding
    def forward(self, tokens, durations, f0, variances: dict, gender=None, velocity=None, spk_embed=None):
        txt_embed = self.txt_embed(tokens)
        durations = durations * (tokens > 0)
        mel2ph = self.lr(durations)
        f0 = f0 * (mel2ph > 0)
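        # Expand mel2ph to the hidden size so torch.gather can index the
        # encoder output frame by frame along the time axis.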
        mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size']))
        dur_embed = self.dur_embed(durations.float()[:, :, None])
        encoded = self.encoder(txt_embed, dur_embed, tokens == PAD_INDEX)
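        # Prepend one all-zero frame so that gathering with index 0
        # (padding frames in mel2ph) yields a zero condition vector.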
        encoded = F.pad(encoded, (0, 0, 1, 0))
        condition = torch.gather(encoded, 1, mel2ph)

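        # Pitch conditioning: either an embedding lookup over quantized
        # mel-scale f0 bins, or a projection of the log-compressed f0 curve.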
        if self.f0_embed_type == 'discrete':
            pitch = f0_to_coarse(f0)
            pitch_embed = self.pitch_embed(pitch)
        else:
            f0_mel = (1 + f0 / 700).log()
            pitch_embed = self.pitch_embed(f0_mel[:, :, None])
        condition += pitch_embed

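        # Optional frame-level variance curves (energy, breathiness, etc.)
        # are embedded individually and summed into the condition.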
        if self.use_variance_embeds:
            variance_embeds = torch.stack([
                self.variance_embeds[v_name](variances[v_name][:, :, None])
                for v_name in self.variance_embed_list
            ], dim=-1).sum(-1)
            condition += variance_embeds

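        # Key shift conditioning. A frozen_key_shift attribute, presumably
        # attached by the exporter before tracing, bakes a fixed shift into
        # the graph; otherwise the shift is derived from the gender input.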
        if hparams['use_key_shift_embed']:
            if hasattr(self, 'frozen_key_shift'):
                key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None])
            else:
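                # Map gender in [-1, 1] onto a key shift: positive values
                # scale toward shift_max, negative values toward shift_min.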
                gender = torch.clip(gender, min=-1., max=1.)
                gender_mask = (gender < 0.).float()
                key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min))
                key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
            condition += key_shift_embed

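        # Speed (velocity) conditioning, defaulting to 1.0 when no
        # velocity input is given.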
        if hparams['use_speed_embed']:
            if velocity is not None:
                velocity = torch.clip(velocity, min=self.speed_min, max=self.speed_max)
                speed_embed = self.speed_embed(velocity[:, :, None])
            else:
                speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None])
            condition += speed_embed

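        # Speaker conditioning: a pre-baked (frozen) speaker embedding if
        # present, otherwise the runtime spk_embed input.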
        if hparams['use_spk_id']:
            if hasattr(self, 'frozen_spk_embed'):
                condition += self.frozen_spk_embed
            else:
                condition += spk_embed
        return condition


class FastSpeech2VarianceONNX(FastSpeech2Variance):
    def __init__(self, vocab_size):
        super().__init__(vocab_size=vocab_size)
        self.lr = LengthRegulator()

    def forward_encoder_word(self, tokens, word_div, word_dur):
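        """Encode tokens with word-level duration conditioning.

        word_div holds the number of phonemes in each word; the length
        regulator expands it into ph2word, the 1-based word index of
        each phoneme.
        """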
        txt_embed = self.txt_embed(tokens)
        ph2word = self.lr(word_div)
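        # A phoneme is a word onset wherever its word index increases
        # relative to the previous phoneme.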
        onset = ph2word > F.pad(ph2word, [1, -1])
        onset_embed = self.onset_embed(onset.long())
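        # Broadcast each word's duration onto its phonemes; the zero pad
        # keeps padding positions (ph2word == 0) at duration 0.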
        ph_word_dur = torch.gather(F.pad(word_dur, [1, 0]), 1, ph2word)
        word_dur_embed = self.word_dur_embed(ph_word_dur.float()[:, :, None])
        x_masks = tokens == PAD_INDEX
        return self.encoder(txt_embed, onset_embed + word_dur_embed, x_masks), x_masks

    def forward_encoder_phoneme(self, tokens, ph_dur):
        txt_embed = self.txt_embed(tokens)
        ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
        x_masks = tokens == PAD_INDEX
        return self.encoder(txt_embed, ph_dur_embed, x_masks), x_masks

    def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
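        # Condition duration prediction on the encoder output plus MIDI
        # pitch, and optionally a speaker embedding.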
        midi_embed = self.midi_embed(ph_midi)
        dur_cond = encoder_out + midi_embed
        if hparams['use_spk_id'] and spk_embed is not None:
            dur_cond += spk_embed
        ph_dur = self.dur_predictor(dur_cond, x_masks=x_masks)
        return ph_dur

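    # The two views below split the model into separately exportable
    # subgraphs, an encoder and a duration predictor, each a deep copy
    # with forward rebound to the corresponding method.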
    def view_as_encoder(self):
        model = copy.deepcopy(self)
        if self.predict_dur:
            del model.dur_predictor
            model.forward = model.forward_encoder_word
        else:
            model.forward = model.forward_encoder_phoneme
        return model

    def view_as_dur_predictor(self):
        model = copy.deepcopy(self)
        del model.encoder
        model.forward = model.forward_dur_predictor
        return model