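"""ONNX-friendly variants of the FastSpeech2 acoustic and variance encoders.

These wrappers restate alignment and conditioning with traceable tensor ops
so that each part can be exported as a standalone ONNX graph.
"""
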
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.commons.common_layers import NormalInitEmbedding as Embedding
from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
from modules.fastspeech.variance_encoder import FastSpeech2Variance
from utils.hparams import hparams
from utils.text_encoder import PAD_INDEX
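
# f0 quantization constants: f0 in Hz is converted to the mel scale with
# 1127 * ln(1 + f / 700) and bucketed into f0_bin coarse bins.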
f0_bin = 256
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)


def f0_to_coarse(f0):
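    """Quantize continuous f0 (Hz) into integer mel-scale bins in [1, f0_bin - 1].

    Unvoiced frames (f0 == 0) are clipped into bin 1.
    """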
f0_mel = 1127 * (1 + f0 / 700).log()
a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
b = f0_mel_min * a - 1.
f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
f0_coarse = torch.round(f0_mel).long()
return f0_coarse


class LengthRegulator(nn.Module):
# noinspection PyMethodMayBeStatic
def forward(self, dur):
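        """Expand per-token durations (in frames) into a frame-to-token map.

        Given dur of shape [B, T_txt], returns mel2ph of shape [B, T_frames]
        containing 1-based token indices, with 0 for frames past the end of
        a sequence.
        """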
token_idx = torch.arange(1, dur.shape[1] + 1, device=dur.device)[None, :, None]
dur_cumsum = torch.cumsum(dur, dim=1)
dur_cumsum_prev = F.pad(dur_cumsum, (1, -1), mode='constant', value=0)
pos_idx = torch.arange(dur.sum(dim=1).max(), device=dur.device)[None, None]
token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
mel2ph = (token_idx * token_mask).sum(dim=1)
return mel2ph


class FastSpeech2AcousticONNX(FastSpeech2Acoustic):
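    """Acoustic encoder restructured for ONNX export: the length regulator and
    all conditioning paths are expressed with traceable tensor operations.
    """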
def __init__(self, vocab_size):
super().__init__(vocab_size=vocab_size)
# for temporary compatibility; will be completely removed in the future
self.f0_embed_type = hparams.get('f0_embed_type', 'continuous')
if self.f0_embed_type == 'discrete':
self.pitch_embed = Embedding(300, hparams['hidden_size'], PAD_INDEX)
self.lr = LengthRegulator()
if hparams['use_key_shift_embed']:
self.shift_min, self.shift_max = hparams['augmentation_args']['random_pitch_shifting']['range']
if hparams['use_speed_embed']:
self.speed_min, self.speed_max = hparams['augmentation_args']['random_time_stretching']['range']

    # noinspection PyMethodOverriding
def forward(self, tokens, durations, f0, variances: dict, gender=None, velocity=None, spk_embed=None):
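        """Build the frame-level condition sequence from phoneme-level inputs."""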
txt_embed = self.txt_embed(tokens)
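        # zero out durations at padding tokens, expand them into the
        # frame-to-phoneme alignment (mel2ph), and mask f0 accordingly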
durations = durations * (tokens > 0)
mel2ph = self.lr(durations)
f0 = f0 * (mel2ph > 0)
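        # prepend a zero row to the encoder output so that frames with
        # mel2ph == 0 gather an all-zero condition vector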
mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size']))
dur_embed = self.dur_embed(durations.float()[:, :, None])
encoded = self.encoder(txt_embed, dur_embed, tokens == PAD_INDEX)
encoded = F.pad(encoded, (0, 0, 1, 0))
condition = torch.gather(encoded, 1, mel2ph)
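        # pitch conditioning: discrete mel-scale bins or continuous mel-scale f0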
if self.f0_embed_type == 'discrete':
pitch = f0_to_coarse(f0)
pitch_embed = self.pitch_embed(pitch)
else:
f0_mel = (1 + f0 / 700).log()
pitch_embed = self.pitch_embed(f0_mel[:, :, None])
condition += pitch_embed
if self.use_variance_embeds:
variance_embeds = torch.stack([
self.variance_embeds[v_name](variances[v_name][:, :, None])
for v_name in self.variance_embed_list
], dim=-1).sum(-1)
condition += variance_embeds
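        # map the gender control in [-1, 1] onto a key shift in
        # [shift_min, shift_max] (negative values scale by |shift_min|,
        # positive values by shift_max)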
if hparams['use_key_shift_embed']:
if hasattr(self, 'frozen_key_shift'):
key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None])
else:
gender = torch.clip(gender, min=-1., max=1.)
gender_mask = (gender < 0.).float()
key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min))
key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
condition += key_shift_embed
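        # time-stretching (velocity) control; falls back to 1.0 (original speed)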
if hparams['use_speed_embed']:
if velocity is not None:
velocity = torch.clip(velocity, min=self.speed_min, max=self.speed_max)
speed_embed = self.speed_embed(velocity[:, :, None])
else:
speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None])
condition += speed_embed
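        # speaker conditioning: a frozen embedding baked into the export,
        # or a runtime speaker embedding input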
if hparams['use_spk_id']:
if hasattr(self, 'frozen_spk_embed'):
condition += self.frozen_spk_embed
else:
condition += spk_embed
return condition


class FastSpeech2VarianceONNX(FastSpeech2Variance):
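    """Variance encoder whose sub-networks can be split into separate modules
    (see view_as_encoder / view_as_dur_predictor) for ONNX export.
    """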
def __init__(self, vocab_size):
super().__init__(vocab_size=vocab_size)
self.lr = LengthRegulator()

    def forward_encoder_word(self, tokens, word_div, word_dur):
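        """Encode with word-level timing: word_div holds the phoneme count of
        each word, so length-regulating it maps each phoneme to its word index.
        """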
txt_embed = self.txt_embed(tokens)
ph2word = self.lr(word_div)
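        # mark the first phoneme of each word, and give every phoneme the
        # total duration of the word it belongs to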
onset = ph2word > F.pad(ph2word, [1, -1])
onset_embed = self.onset_embed(onset.long())
ph_word_dur = torch.gather(F.pad(word_dur, [1, 0]), 1, ph2word)
word_dur_embed = self.word_dur_embed(ph_word_dur.float()[:, :, None])
x_masks = tokens == PAD_INDEX
return self.encoder(txt_embed, onset_embed + word_dur_embed, x_masks), x_masks

    def forward_encoder_phoneme(self, tokens, ph_dur):
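        """Encode with explicit phoneme-level durations (no word structure)."""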
txt_embed = self.txt_embed(tokens)
ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
x_masks = tokens == PAD_INDEX
return self.encoder(txt_embed, ph_dur_embed, x_masks), x_masks

    def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
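        """Predict phoneme durations from encoder output and MIDI pitch."""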
midi_embed = self.midi_embed(ph_midi)
dur_cond = encoder_out + midi_embed
if hparams['use_spk_id'] and spk_embed is not None:
dur_cond += spk_embed
ph_dur = self.dur_predictor(dur_cond, x_masks=x_masks)
return ph_dur

    def view_as_encoder(self):
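        """Return a deep copy that exposes only the encoder path as forward()."""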
model = copy.deepcopy(self)
if self.predict_dur:
del model.dur_predictor
model.forward = model.forward_encoder_word
else:
model.forward = model.forward_encoder_phoneme
return model

    def view_as_dur_predictor(self):
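        """Return a deep copy that exposes only the duration predictor."""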
model = copy.deepcopy(self)
del model.encoder
model.forward = model.forward_dur_predictor
return model
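

# Sketch of the intended export flow (an assumption based on the view_as_*
# API; `vocab_size` and `example_inputs` are hypothetical placeholders):
#
#   variance = FastSpeech2VarianceONNX(vocab_size)
#   encoder = variance.view_as_encoder()          # forward == forward_encoder_*
#   dur_predictor = variance.view_as_dur_predictor()
#   torch.onnx.export(encoder, example_inputs, 'encoder.onnx')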