import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.commons.common_layers import NormalInitEmbedding as Embedding
from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
from modules.fastspeech.variance_encoder import FastSpeech2Variance
from utils.hparams import hparams
from utils.text_encoder import PAD_INDEX

f0_bin = 256
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)


def f0_to_coarse(f0):
    # Map F0 (Hz) to coarse mel-scale bin indices in [1, f0_bin - 1].
    f0_mel = 1127 * (1 + f0 / 700).log()
    a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
    b = f0_mel_min * a - 1.
    f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
    torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
    f0_coarse = torch.round(f0_mel).long()
    return f0_coarse


class LengthRegulator(nn.Module):
    # noinspection PyMethodMayBeStatic
    def forward(self, dur):
        # Expand per-token durations into a frame-level index map (mel2ph),
        # where each frame holds the 1-based index of the token it belongs to (0 = padding).
        token_idx = torch.arange(1, dur.shape[1] + 1, device=dur.device)[None, :, None]
        dur_cumsum = torch.cumsum(dur, dim=1)
        dur_cumsum_prev = F.pad(dur_cumsum, (1, -1), mode='constant', value=0)
        pos_idx = torch.arange(dur.sum(dim=1).max(), device=dur.device)[None, None]
        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
        mel2ph = (token_idx * token_mask).sum(dim=1)
        return mel2ph


class FastSpeech2AcousticONNX(FastSpeech2Acoustic):
    def __init__(self, vocab_size):
        super().__init__(vocab_size=vocab_size)
        # for temporary compatibility; will be completely removed in the future
        self.f0_embed_type = hparams.get('f0_embed_type', 'continuous')
        if self.f0_embed_type == 'discrete':
            self.pitch_embed = Embedding(300, hparams['hidden_size'], PAD_INDEX)
        self.lr = LengthRegulator()
        if hparams['use_key_shift_embed']:
            self.shift_min, self.shift_max = hparams['augmentation_args']['random_pitch_shifting']['range']
        if hparams['use_speed_embed']:
            self.speed_min, self.speed_max = hparams['augmentation_args']['random_time_stretching']['range']

    # noinspection PyMethodOverriding
    def forward(self, tokens, durations, f0, variances: dict, gender=None, velocity=None, spk_embed=None):
        txt_embed = self.txt_embed(tokens)
        durations = durations * (tokens > 0)
        mel2ph = self.lr(durations)
        f0 = f0 * (mel2ph > 0)
        mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size']))
        dur_embed = self.dur_embed(durations.float()[:, :, None])
        encoded = self.encoder(txt_embed, dur_embed, tokens == PAD_INDEX)
        # Prepend a zero frame so that mel2ph == 0 (padding frames) gathers zeros.
        encoded = F.pad(encoded, (0, 0, 1, 0))
        condition = torch.gather(encoded, 1, mel2ph)

        # Pitch condition: discrete (coarse bins) or continuous (mel-scaled F0).
        if self.f0_embed_type == 'discrete':
            pitch = f0_to_coarse(f0)
            pitch_embed = self.pitch_embed(pitch)
        else:
            f0_mel = (1 + f0 / 700).log()
            pitch_embed = self.pitch_embed(f0_mel[:, :, None])
        condition += pitch_embed

        if self.use_variance_embeds:
            variance_embeds = torch.stack([
                self.variance_embeds[v_name](variances[v_name][:, :, None])
                for v_name in self.variance_embed_list
            ], dim=-1).sum(-1)
            condition += variance_embeds

        if hparams['use_key_shift_embed']:
            if hasattr(self, 'frozen_key_shift'):
                key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None])
            else:
                # Map gender in [-1, 1] to a key shift within the augmentation range.
                gender = torch.clip(gender, min=-1., max=1.)
                gender_mask = (gender < 0.).float()
                key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min))
                key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
            condition += key_shift_embed

        if hparams['use_speed_embed']:
            if velocity is not None:
                velocity = torch.clip(velocity, min=self.speed_min, max=self.speed_max)
                speed_embed = self.speed_embed(velocity[:, :, None])
            else:
                speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None])
            condition += speed_embed

        if hparams['use_spk_id']:
            if hasattr(self, 'frozen_spk_embed'):
                condition += self.frozen_spk_embed
            else:
                condition += spk_embed
        return condition


class FastSpeech2VarianceONNX(FastSpeech2Variance):
    def __init__(self, vocab_size):
        super().__init__(vocab_size=vocab_size)
        self.lr = LengthRegulator()

    def forward_encoder_word(self, tokens, word_div, word_dur):
        txt_embed = self.txt_embed(tokens)
        # word_div holds the number of phonemes per word; expanding it gives ph2word.
        ph2word = self.lr(word_div)
        onset = ph2word > F.pad(ph2word, [1, -1])
        onset_embed = self.onset_embed(onset.long())
        ph_word_dur = torch.gather(F.pad(word_dur, [1, 0]), 1, ph2word)
        word_dur_embed = self.word_dur_embed(ph_word_dur.float()[:, :, None])
        x_masks = tokens == PAD_INDEX
        return self.encoder(txt_embed, onset_embed + word_dur_embed, x_masks), x_masks

    def forward_encoder_phoneme(self, tokens, ph_dur):
        txt_embed = self.txt_embed(tokens)
        ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
        x_masks = tokens == PAD_INDEX
        return self.encoder(txt_embed, ph_dur_embed, x_masks), x_masks

    def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
        midi_embed = self.midi_embed(ph_midi)
        dur_cond = encoder_out + midi_embed
        if hparams['use_spk_id'] and spk_embed is not None:
            dur_cond += spk_embed
        ph_dur = self.dur_predictor(dur_cond, x_masks=x_masks)
        return ph_dur

    def view_as_encoder(self):
        # Return a copy whose forward() runs only the encoder, for separate ONNX export.
        model = copy.deepcopy(self)
        if self.predict_dur:
            del model.dur_predictor
            model.forward = model.forward_encoder_word
        else:
            model.forward = model.forward_encoder_phoneme
        return model

    def view_as_dur_predictor(self):
        # Return a copy whose forward() runs only the duration predictor.
        model = copy.deepcopy(self)
        del model.encoder
        model.forward = model.forward_dur_predictor
        return model
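
# ------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): one way the "view" wrappers
# above could be traced with torch.onnx.export. The vocabulary size, input shapes,
# tensor names and opset below are assumptions made for this example only; the
# actual export entry point lives elsewhere in the repository.
# ------------------------------------------------------------------------------
# model = FastSpeech2VarianceONNX(vocab_size=61).eval()
# encoder = model.view_as_encoder()
# tokens = torch.randint(1, 61, (1, 20))    # phoneme token ids
# word_div = torch.randint(1, 4, (1, 8))    # phoneme count per word
# word_dur = torch.randint(5, 20, (1, 8))   # word durations in frames
# torch.onnx.export(
#     encoder, (tokens, word_div, word_dur), 'variance_encoder.onnx',
#     input_names=['tokens', 'word_div', 'word_dur'],
#     output_names=['encoder_out', 'x_masks'],
#     dynamic_axes={
#         'tokens': {1: 'n_tokens'},
#         'word_div': {1: 'n_words'},
#         'word_dur': {1: 'n_words'},
#     },
#     opset_version=15
# )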