import torch
import torch.nn as nn
from torch.nn import functional as F
from modules.commons.common_layers import (
    NormalInitEmbedding as Embedding,
    XavierUniformInitLinear as Linear,
)
from modules.fastspeech.tts_modules import FastSpeech2Encoder, DurationPredictor
from utils.hparams import hparams
from utils.text_encoder import PAD_INDEX


class FastSpeech2Variance(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.predict_dur = hparams['predict_dur']
        self.linguistic_mode = 'word' if hparams['predict_dur'] else 'phoneme'
        self.txt_embed = Embedding(vocab_size, hparams['hidden_size'], PAD_INDEX)
        if self.predict_dur:
            self.onset_embed = Embedding(2, hparams['hidden_size'])
            self.word_dur_embed = Linear(1, hparams['hidden_size'])
        else:
            self.ph_dur_embed = Linear(1, hparams['hidden_size'])
        self.encoder = FastSpeech2Encoder(
            hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'],
            ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
            dropout=hparams['dropout'], num_heads=hparams['num_heads'],
            use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
            use_rope=hparams.get('use_rope', False)
        )

        dur_hparams = hparams['dur_prediction_args']
        if self.predict_dur:
            self.midi_embed = Embedding(128, hparams['hidden_size'])
            self.dur_predictor = DurationPredictor(
                in_dims=hparams['hidden_size'],
                n_chans=dur_hparams['hidden_size'],
                n_layers=dur_hparams['num_layers'],
                dropout_rate=dur_hparams['dropout'],
                kernel_size=dur_hparams['kernel_size'],
                offset=dur_hparams['log_offset'],
                dur_loss_type=dur_hparams['loss_type']
            )

    def forward(self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, spk_embed=None, infer=True):
        """
        :param txt_tokens: (train, infer) [B, T_ph]
        :param midi: (train, infer) [B, T_ph]
        :param ph2word: (train, infer) [B, T_ph]
        :param ph_dur: (train, [infer]) [B, T_ph]
        :param word_dur: (infer) [B, T_w]
        :param spk_embed: (train) [B, T_ph, H]
        :param infer: whether inference
        :return: encoder_out, ph_dur_pred
        """
        txt_embed = self.txt_embed(txt_tokens)
        if self.linguistic_mode == 'word':
            b = txt_tokens.shape[0]
            # A word onset is marked wherever the word index increases.
            onset = torch.diff(ph2word, dim=1, prepend=ph2word.new_zeros(b, 1)) > 0
            onset_embed = self.onset_embed(onset.long())  # [B, T_ph, H]
            if word_dur is None or not infer:
                # Aggregate phoneme durations into word durations; slot 0
                # collects padding (ph2word == 0) and is dropped.
                word_dur = ph_dur.new_zeros(b, ph2word.max() + 1).scatter_add(
                    1, ph2word, ph_dur
                )[:, 1:]  # [B, T_ph] => [B, T_w]
            # Broadcast each word duration back onto its phonemes.
            word_dur = torch.gather(F.pad(word_dur, [1, 0], value=0), 1, ph2word)  # [B, T_w] => [B, T_ph]
            word_dur_embed = self.word_dur_embed(word_dur.float()[:, :, None])
            encoder_out = self.encoder(txt_embed, onset_embed + word_dur_embed, txt_tokens == 0)
        else:
            ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
            encoder_out = self.encoder(txt_embed, ph_dur_embed, txt_tokens == 0)

        if self.predict_dur:
            midi_embed = self.midi_embed(midi)  # => [B, T_ph, H]
            dur_cond = encoder_out + midi_embed
            if spk_embed is not None:
                dur_cond += spk_embed
            ph_dur_pred = self.dur_predictor(dur_cond, x_masks=txt_tokens == PAD_INDEX, infer=infer)
            return encoder_out, ph_dur_pred
        else:
            return encoder_out, None
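
# A minimal usage sketch for FastSpeech2Variance (illustrative, not part of the
# original module). It assumes `hparams` has already been populated from a
# config (hidden_size, enc_layers, predict_dur, dur_prediction_args, etc.);
# the shapes follow the docstring above, and all concrete values are made up.
#
#     model = FastSpeech2Variance(vocab_size=61)
#     txt_tokens = torch.randint(1, 61, (2, 10))    # [B, T_ph]
#     midi = torch.randint(0, 128, (2, 10))         # [B, T_ph]
#     ph2word = torch.arange(1, 11).expand(2, -1)   # [B, T_ph], 1-based word ids
#     ph_dur = torch.randint(1, 20, (2, 10))        # [B, T_ph], frames per phoneme
#     encoder_out, ph_dur_pred = model(txt_tokens, midi, ph2word, ph_dur=ph_dur, infer=True)
#     # encoder_out: [B, T_ph, hidden_size]; ph_dur_pred: [B, T_ph] if predict_dur else None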


class MelodyEncoder(nn.Module):
    def __init__(self, enc_hparams: dict):
        super().__init__()

        def get_hparam(key):
            # Fall back to the global hparams when a key is absent from enc_hparams.
            return enc_hparams.get(key, hparams.get(key))

        # MIDI inputs
        hidden_size = get_hparam('hidden_size')
        self.note_midi_embed = Linear(1, hidden_size)
        self.note_dur_embed = Linear(1, hidden_size)
        # ornament inputs
        self.use_glide_embed = hparams['use_glide_embed']
        self.glide_embed_scale = hparams['glide_embed_scale']
        if self.use_glide_embed:
            # 0: none, 1: up, 2: down
            self.note_glide_embed = Embedding(len(hparams['glide_types']) + 1, hidden_size, padding_idx=0)
        self.encoder = FastSpeech2Encoder(
            hidden_size=hidden_size, num_layers=get_hparam('enc_layers'),
            ffn_kernel_size=get_hparam('enc_ffn_kernel_size'), ffn_act=get_hparam('ffn_act'),
            dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'),
            use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos'),
            use_rope=get_hparam('use_rope')
        )
        self.out_proj = Linear(hidden_size, hparams['hidden_size'])

    def forward(self, note_midi, note_rest, note_dur, glide=None):
        """
        :param note_midi: float32 [B, T_n], -1: padding
        :param note_rest: bool [B, T_n]
        :param note_dur: int64 [B, T_n]
        :param glide: int64 [B, T_n]
        :return: [B, T_n, H]
        """
        midi_embed = self.note_midi_embed(note_midi[:, :, None]) * ~note_rest[:, :, None]
        dur_embed = self.note_dur_embed(note_dur.float()[:, :, None])
        ornament_embed = 0
        if self.use_glide_embed:
            ornament_embed += self.note_glide_embed(glide) * self.glide_embed_scale
        encoder_out = self.encoder(
            midi_embed, dur_embed + ornament_embed,
            padding_mask=note_midi < 0
        )
        encoder_out = self.out_proj(encoder_out)
        return encoder_out
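
# A minimal usage sketch for MelodyEncoder (illustrative, not part of the
# original module). Keys missing from enc_hparams fall back to the global
# hparams, so both are assumed to be configured; the values below are made up.
#
#     melody_encoder = MelodyEncoder(enc_hparams={'hidden_size': 128, 'enc_layers': 2})
#     note_midi = torch.full((2, 8), 60.0)             # [B, T_n], MIDI pitch; -1 marks padding
#     note_rest = torch.zeros(2, 8, dtype=torch.bool)  # [B, T_n], True on rest notes
#     note_dur = torch.randint(1, 40, (2, 8))          # [B, T_n], frames per note
#     out = melody_encoder(note_midi, note_rest, note_dur)  # [B, T_n, hparams['hidden_size']]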