import torch
import torch.nn as nn
from torch.nn import functional as F

from modules.commons.common_layers import (
    NormalInitEmbedding as Embedding,
    XavierUniformInitLinear as Linear,
)
from modules.fastspeech.tts_modules import FastSpeech2Encoder, DurationPredictor
from utils.hparams import hparams
from utils.text_encoder import PAD_INDEX


class FastSpeech2Variance(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.predict_dur = hparams['predict_dur']
        self.linguistic_mode = 'word' if hparams['predict_dur'] else 'phoneme'

        self.txt_embed = Embedding(vocab_size, hparams['hidden_size'], PAD_INDEX)

        if self.predict_dur:
            # word-level linguistic inputs: word onset flags and total word durations
            self.onset_embed = Embedding(2, hparams['hidden_size'])
            self.word_dur_embed = Linear(1, hparams['hidden_size'])
        else:
            # phoneme-level linguistic input: per-phoneme durations
            self.ph_dur_embed = Linear(1, hparams['hidden_size'])

        self.encoder = FastSpeech2Encoder(
            hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'],
            ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
            dropout=hparams['dropout'], num_heads=hparams['num_heads'],
            use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
            use_rope=hparams.get('use_rope', False)
        )

        dur_hparams = hparams['dur_prediction_args']
        if self.predict_dur:
            self.midi_embed = Embedding(128, hparams['hidden_size'])
            self.dur_predictor = DurationPredictor(
                in_dims=hparams['hidden_size'],
                n_chans=dur_hparams['hidden_size'],
                n_layers=dur_hparams['num_layers'],
                dropout_rate=dur_hparams['dropout'],
                kernel_size=dur_hparams['kernel_size'],
                offset=dur_hparams['log_offset'],
                dur_loss_type=dur_hparams['loss_type']
            )
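
    # The duration predictor above is configured from hparams['dur_prediction_args'];
    # a typical entry might look like the following (hypothetical values, shown only
    # to illustrate the keys read by the constructor call above):
    #   dur_prediction_args:
    #     hidden_size: 512
    #     num_layers: 5
    #     dropout: 0.1
    #     kernel_size: 3
    #     log_offset: 1.0
    #     loss_type: mse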

    def forward(self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, spk_embed=None, infer=True):
        """
        :param txt_tokens: (train, infer) [B, T_ph]
        :param midi: (train, infer) [B, T_ph]
        :param ph2word: (train, infer) [B, T_ph]
        :param ph_dur: (train, [infer]) [B, T_ph]
        :param word_dur: (infer) [B, T_w]
        :param spk_embed: (train) [B, T_ph, H]
        :param infer: whether to run in inference mode
        :return: encoder_out [B, T_ph, H], ph_dur_pred [B, T_ph] (None if duration prediction is disabled)
        """
        txt_embed = self.txt_embed(txt_tokens)
        if self.linguistic_mode == 'word':
            b = txt_tokens.shape[0]
            # a phoneme is a word onset wherever ph2word increases
            onset = torch.diff(ph2word, dim=1, prepend=ph2word.new_zeros(b, 1)) > 0
            onset_embed = self.onset_embed(onset.long())  # [B, T_ph, H]
            if word_dur is None or not infer:
                word_dur = ph_dur.new_zeros(b, ph2word.max() + 1).scatter_add(
                    1, ph2word, ph_dur
                )[:, 1:]  # [B, T_ph] => [B, T_w]
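            # The scatter_add above sums phoneme durations into word durations (word index 0 is
            # reserved for padding and dropped); the gather below broadcasts each word duration
            # back to every phoneme of that word. For example, ph2word = [1, 1, 2, 2, 2] with
            # ph_dur = [3, 2, 4, 1, 5] gives word_dur = [5, 10] and, after the gather,
            # [5, 5, 10, 10, 10]. See the illustrative check at the bottom of this file.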
            word_dur = torch.gather(F.pad(word_dur, [1, 0], value=0), 1, ph2word)  # [B, T_w] => [B, T_ph]
            word_dur_embed = self.word_dur_embed(word_dur.float()[:, :, None])
            encoder_out = self.encoder(txt_embed, onset_embed + word_dur_embed, txt_tokens == PAD_INDEX)
        else:
            ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
            encoder_out = self.encoder(txt_embed, ph_dur_embed, txt_tokens == PAD_INDEX)

        if self.predict_dur:
            midi_embed = self.midi_embed(midi)  # => [B, T_ph, H]
            dur_cond = encoder_out + midi_embed
            if spk_embed is not None:
                dur_cond += spk_embed
            ph_dur_pred = self.dur_predictor(dur_cond, x_masks=txt_tokens == PAD_INDEX, infer=infer)
            return encoder_out, ph_dur_pred
        else:
            return encoder_out, None
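

# Usage sketch for FastSpeech2Variance.forward (illustrative comments only; the variable names
# below are hypothetical, with B = batch size, T_ph = number of phonemes and
# H = hparams['hidden_size']):
#
#   model = FastSpeech2Variance(vocab_size=len(phoneme_dictionary))
#   encoder_out, ph_dur_pred = model(
#       txt_tokens,        # int64 [B, T_ph], phoneme ids padded with PAD_INDEX
#       midi,              # int64 [B, T_ph], per-phoneme MIDI pitch in [0, 127]
#       ph2word,           # int64 [B, T_ph], 1-based word index per phoneme (0 = padding)
#       ph_dur=ph_dur,     # int64 [B, T_ph], ground-truth phoneme durations (training)
#       infer=False,
#   )
#   # encoder_out: [B, T_ph, H]; ph_dur_pred: [B, T_ph] when predict_dur is enabled, else None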


class MelodyEncoder(nn.Module):
    def __init__(self, enc_hparams: dict):
        super().__init__()

        def get_hparam(key):
            # fall back to the global hparams when the key is absent from enc_hparams
            return enc_hparams.get(key, hparams.get(key))

        # MIDI inputs
        hidden_size = get_hparam('hidden_size')
        self.note_midi_embed = Linear(1, hidden_size)
        self.note_dur_embed = Linear(1, hidden_size)

        # ornament inputs
        self.use_glide_embed = hparams['use_glide_embed']
        self.glide_embed_scale = hparams['glide_embed_scale']
        if self.use_glide_embed:
            # 0: none, 1: up, 2: down
            self.note_glide_embed = Embedding(len(hparams['glide_types']) + 1, hidden_size, padding_idx=0)

        self.encoder = FastSpeech2Encoder(
            hidden_size=hidden_size, num_layers=get_hparam('enc_layers'),
            ffn_kernel_size=get_hparam('enc_ffn_kernel_size'), ffn_act=get_hparam('ffn_act'),
            dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'),
            use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos'),
            use_rope=get_hparam('use_rope')
        )
        self.out_proj = Linear(hidden_size, hparams['hidden_size'])

    def forward(self, note_midi, note_rest, note_dur, glide=None):
        """
        :param note_midi: float32 [B, T_n], -1: padding
        :param note_rest: bool [B, T_n]
        :param note_dur: int64 [B, T_n]
        :param glide: int64 [B, T_n]
        :return: [B, T_n, H]
        """
        # rest notes carry no pitch, so their MIDI embedding is zeroed out
        midi_embed = self.note_midi_embed(note_midi[:, :, None]) * ~note_rest[:, :, None]
        dur_embed = self.note_dur_embed(note_dur.float()[:, :, None])
        ornament_embed = 0
        if self.use_glide_embed:
            ornament_embed += self.note_glide_embed(glide) * self.glide_embed_scale
        encoder_out = self.encoder(
            midi_embed, dur_embed + ornament_embed,
            padding_mask=note_midi < 0
        )
        encoder_out = self.out_proj(encoder_out)
        return encoder_out
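

if __name__ == '__main__':
    # Illustrative check of the phoneme-to-word duration aggregation used in
    # FastSpeech2Variance.forward. The toy tensors below are hypothetical: 5 phonemes mapped
    # to 2 words, with ph2word using 1-based word indices (0 is reserved for padding).
    # Runs with plain torch and does not require hparams.
    ph2word = torch.tensor([[1, 1, 2, 2, 2]])
    ph_dur = torch.tensor([[3, 2, 4, 1, 5]])
    # sum phoneme durations into word durations, dropping the padding slot
    word_dur = ph_dur.new_zeros(1, ph2word.max() + 1).scatter_add(1, ph2word, ph_dur)[:, 1:]
    print(word_dur.tolist())  # [[5, 10]]
    # broadcast each word duration back to the phonemes that make up the word
    print(torch.gather(F.pad(word_dur, [1, 0], value=0), 1, ph2word).tolist())  # [[5, 5, 10, 10, 10]]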