import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.commons.common_layers import NormalInitEmbedding as Embedding
from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
from modules.fastspeech.variance_encoder import FastSpeech2Variance
from utils.hparams import hparams
from utils.text_encoder import PAD_INDEX
# F0 quantization constants: pitch in Hz is mapped onto the mel scale and
# bucketed into f0_bin discrete bins between f0_min and f0_max.
f0_bin = 256
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)
def f0_to_coarse(f0):
    """Quantize continuous f0 (Hz) into integer mel-scale bins in [1, f0_bin - 1]."""
    f0_mel = 1127 * (1 + f0 / 700).log()
    # Rescale voiced mel values linearly so that f0_min maps to bin 1
    # and f0_max maps to bin f0_bin - 1.
    a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
    b = f0_mel_min * a - 1.
    f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
    # Unvoiced frames (f0 == 0) are clipped up to bin 1.
    torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
    f0_coarse = torch.round(f0_mel).long()
    return f0_coarse
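# A quick sanity check of the quantization above (illustrative values only;
# the middle bin is approximate and depends on rounding):
#
#     f0 = torch.tensor([0., 50., 440., 1100.])
#     f0_to_coarse(f0)  # -> tensor([  1,   1, 122, 255])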
class LengthRegulator(nn.Module):
    # noinspection PyMethodMayBeStatic
    def forward(self, dur):
        """Expand per-token durations (in frames) into a frame-level index map.

        Returns mel2ph of shape [B, T_frames], where each entry is the 1-based
        index of the token that covers that frame (0 marks padding).
        """
        token_idx = torch.arange(1, dur.shape[1] + 1, device=dur.device)[None, :, None]
        dur_cumsum = torch.cumsum(dur, dim=1)
        dur_cumsum_prev = F.pad(dur_cumsum, (1, -1), mode='constant', value=0)
        pos_idx = torch.arange(dur.sum(dim=1).max(), device=dur.device)[None, None]
        # Frame t belongs to token i iff dur_cumsum_prev[i] <= t < dur_cumsum[i].
        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
        mel2ph = (token_idx * token_mask).sum(dim=1)
        return mel2ph
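# A minimal sketch of LengthRegulator in action (made-up values):
#
#     lr = LengthRegulator()
#     dur = torch.tensor([[2, 1, 3]])  # three tokens spanning 6 frames
#     lr(dur)                          # -> tensor([[1, 1, 2, 3, 3, 3]])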
class FastSpeech2AcousticONNX(FastSpeech2Acoustic):
    """Export-oriented variant of FastSpeech2Acoustic with an ONNX-traceable forward()."""

    def __init__(self, vocab_size):
        super().__init__(vocab_size=vocab_size)
        # for temporary compatibility; will be completely removed in the future
        self.f0_embed_type = hparams.get('f0_embed_type', 'continuous')
        if self.f0_embed_type == 'discrete':
            self.pitch_embed = Embedding(300, hparams['hidden_size'], PAD_INDEX)
        self.lr = LengthRegulator()
        if hparams['use_key_shift_embed']:
            self.shift_min, self.shift_max = hparams['augmentation_args']['random_pitch_shifting']['range']
        if hparams['use_speed_embed']:
            self.speed_min, self.speed_max = hparams['augmentation_args']['random_time_stretching']['range']
    # noinspection PyMethodOverriding
    def forward(self, tokens, durations, f0, variances: dict, gender=None, velocity=None, spk_embed=None):
        txt_embed = self.txt_embed(tokens)
        durations = durations * (tokens > 0)  # zero out durations of padding tokens
        mel2ph = self.lr(durations)
        f0 = f0 * (mel2ph > 0)  # zero out f0 on padding frames
        mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size']))
        dur_embed = self.dur_embed(durations.float()[:, :, None])
        encoded = self.encoder(txt_embed, dur_embed, tokens == PAD_INDEX)
        # Prepend a zero vector so that frames with mel2ph == 0 gather zeros.
        encoded = F.pad(encoded, (0, 0, 1, 0))
        condition = torch.gather(encoded, 1, mel2ph)
        if self.f0_embed_type == 'discrete':
            # Discrete pitch: quantize f0 to coarse bins and look up an embedding.
            pitch = f0_to_coarse(f0)
            pitch_embed = self.pitch_embed(pitch)
        else:
            # Continuous pitch: project the mel-scaled f0 directly.
            f0_mel = (1 + f0 / 700).log()
            pitch_embed = self.pitch_embed(f0_mel[:, :, None])
        condition += pitch_embed
        if self.use_variance_embeds:
            # Sum the embeddings of all enabled variance curves.
            variance_embeds = torch.stack([
                self.variance_embeds[v_name](variances[v_name][:, :, None])
                for v_name in self.variance_embed_list
            ], dim=-1).sum(-1)
            condition += variance_embeds
        if hparams['use_key_shift_embed']:
            if hasattr(self, 'frozen_key_shift'):
                key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None])
            else:
                # Map gender in [-1, 1] onto a key shift: positive values scale
                # towards shift_max, negative values towards shift_min.
                gender = torch.clip(gender, min=-1., max=1.)
                gender_mask = (gender < 0.).float()
                key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min))
                key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
            condition += key_shift_embed
        if hparams['use_speed_embed']:
            if velocity is not None:
                velocity = torch.clip(velocity, min=self.speed_min, max=self.speed_max)
                speed_embed = self.speed_embed(velocity[:, :, None])
            else:
                # Fall back to normal speed (1.0) when no velocity curve is given.
                speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None])
            condition += speed_embed
        if hparams['use_spk_id']:
            # Speaker conditioning: prefer a frozen (baked-in) embedding if present.
            if hasattr(self, 'frozen_spk_embed'):
                condition += self.frozen_spk_embed
            else:
                condition += spk_embed
        return condition
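# A minimal sketch of driving the acoustic forward() above (illustrative only:
# vocab_size and the dummy shapes are assumptions, and the extra inputs
# (variances, gender, velocity, spk_embed) depend on the active hparams):
#
#     model = FastSpeech2AcousticONNX(vocab_size=61).eval()
#     tokens = torch.randint(1, 61, (1, 20))     # [B, T_tokens], 0 is padding
#     durations = torch.randint(1, 10, (1, 20))  # frames per token
#     f0 = torch.full((1, int(durations.sum())), 440.)  # [B, T_frames], in Hz
#     condition = model(tokens, durations, f0, variances={})  # [B, T_frames, hidden_size]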
class FastSpeech2VarianceONNX(FastSpeech2Variance):
    """Export-oriented variant of FastSpeech2Variance, split into exportable views."""

    def __init__(self, vocab_size):
        super().__init__(vocab_size=vocab_size)
        self.lr = LengthRegulator()
    def forward_encoder_word(self, tokens, word_div, word_dur):
        txt_embed = self.txt_embed(tokens)
        # word_div holds the number of phonemes in each word, so the length
        # regulator yields ph2word: the 1-based word index of each phoneme.
        ph2word = self.lr(word_div)
        onset = ph2word > F.pad(ph2word, [1, -1])  # True at the first phoneme of each word
        onset_embed = self.onset_embed(onset.long())
        # Broadcast each word's duration to all of its phonemes.
        ph_word_dur = torch.gather(F.pad(word_dur, [1, 0]), 1, ph2word)
        word_dur_embed = self.word_dur_embed(ph_word_dur.float()[:, :, None])
        x_masks = tokens == PAD_INDEX
        return self.encoder(txt_embed, onset_embed + word_dur_embed, x_masks), x_masks
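    # Illustrative word-level inputs (made-up values): with word_div = [[2, 3]]
    # (two words of 2 and 3 phonemes), self.lr gives ph2word = [[1, 1, 2, 2, 2]],
    # onset = [[1, 0, 1, 0, 0]], and every phoneme receives the duration of the
    # word it belongs to via the gather above.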
    def forward_encoder_phoneme(self, tokens, ph_dur):
        # Phoneme-level variant: condition the encoder on phoneme durations directly.
        txt_embed = self.txt_embed(tokens)
        ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
        x_masks = tokens == PAD_INDEX
        return self.encoder(txt_embed, ph_dur_embed, x_masks), x_masks
    def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
        # Condition duration prediction on MIDI pitch (and speaker, if enabled).
        midi_embed = self.midi_embed(ph_midi)
        dur_cond = encoder_out + midi_embed
        if hparams['use_spk_id'] and spk_embed is not None:
            dur_cond += spk_embed
        ph_dur = self.dur_predictor(dur_cond, x_masks=x_masks)
        return ph_dur
    def view_as_encoder(self):
        # Deep-copy the module and rebind forward() so the encoder can be
        # traced and exported as a standalone graph.
        model = copy.deepcopy(self)
        if self.predict_dur:
            del model.dur_predictor
            model.forward = model.forward_encoder_word
        else:
            model.forward = model.forward_encoder_phoneme
        return model

    def view_as_dur_predictor(self):
        # Same trick for the duration predictor; the encoder is dropped
        # so it does not end up in the exported graph.
        model = copy.deepcopy(self)
        del model.encoder
        model.forward = model.forward_dur_predictor
        return model
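# A minimal sketch of splitting the variance model into exportable views
# (illustrative only; construction arguments and usage are assumptions):
#
#     variance = FastSpeech2VarianceONNX(vocab_size=61).eval()
#     encoder = variance.view_as_encoder()  # word- or phoneme-level forward, per predict_dur
#     dur_predictor = variance.view_as_dur_predictor()
#     # each view can then be traced/exported on its own with matching dummy inputs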