# modules/fastspeech/acoustic_encoder.py
import torch
import torch.nn as nn
from torch.nn import functional as F

from modules.commons.common_layers import (
    NormalInitEmbedding as Embedding,
    XavierUniformInitLinear as Linear,
)
from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur
from utils.hparams import hparams
from utils.text_encoder import PAD_INDEX


class FastSpeech2Acoustic(nn.Module):
    """FastSpeech2-based acoustic encoder.

    Embeds phoneme tokens together with per-phoneme durations, expands the
    result to frame level, and adds pitch, variance and speaker conditions.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Phoneme token embedding; the PAD_INDEX row is treated as padding.
        self.txt_embed = Embedding(vocab_size, hparams['hidden_size'], PAD_INDEX)
        # Scalar per-phoneme duration projected into the hidden space.
        self.dur_embed = Linear(1, hparams['hidden_size'])
        self.encoder = FastSpeech2Encoder(
            hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'],
            ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
            dropout=hparams['dropout'], num_heads=hparams['num_heads'],
            use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
            use_rope=hparams.get('use_rope', False)
        )
        self.pitch_embed = Linear(1, hparams['hidden_size'])

        # Optional frame-level variance conditions; each enabled variance gets
        # its own scalar-to-hidden projection.
        self.variance_embed_list = []
        self.use_energy_embed = hparams.get('use_energy_embed', False)
        self.use_breathiness_embed = hparams.get('use_breathiness_embed', False)
        self.use_voicing_embed = hparams.get('use_voicing_embed', False)
        self.use_tension_embed = hparams.get('use_tension_embed', False)
        if self.use_energy_embed:
            self.variance_embed_list.append('energy')
        if self.use_breathiness_embed:
            self.variance_embed_list.append('breathiness')
        if self.use_voicing_embed:
            self.variance_embed_list.append('voicing')
        if self.use_tension_embed:
            self.variance_embed_list.append('tension')

        self.use_variance_embeds = len(self.variance_embed_list) > 0
        if self.use_variance_embeds:
            self.variance_embeds = nn.ModuleDict({
                v_name: Linear(1, hparams['hidden_size'])
                for v_name in self.variance_embed_list
            })

        # Optional augmentation conditions (pitch shifting and time stretching).
        self.use_key_shift_embed = hparams.get('use_key_shift_embed', False)
        if self.use_key_shift_embed:
            self.key_shift_embed = Linear(1, hparams['hidden_size'])
        self.use_speed_embed = hparams.get('use_speed_embed', False)
        if self.use_speed_embed:
            self.speed_embed = Linear(1, hparams['hidden_size'])

        # Optional speaker identity embedding for multi-speaker models.
        self.use_spk_id = hparams['use_spk_id']
        if self.use_spk_id:
            self.spk_embed = Embedding(hparams['num_spk'], hparams['hidden_size'])

    def forward_variance_embedding(self, condition, key_shift=None, speed=None, **variances):
        # Project each enabled variance curve ([B, T] -> [B, T, H]) and add
        # the sum onto the frame-level condition.
        if self.use_variance_embeds:
            variance_embeds = torch.stack([
                self.variance_embeds[v_name](variances[v_name][:, :, None])
                for v_name in self.variance_embed_list
            ], dim=-1).sum(-1)
            condition += variance_embeds
        if self.use_key_shift_embed:
            key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
            condition += key_shift_embed
        if self.use_speed_embed:
            speed_embed = self.speed_embed(speed[:, :, None])
            condition += speed_embed
        return condition

    def forward(
            self, txt_tokens, mel2ph, f0,
            key_shift=None, speed=None,
            spk_embed_id=None, **kwargs
    ):
        txt_embed = self.txt_embed(txt_tokens)
        # Recover per-phoneme durations (in frames) from the frame-to-phoneme
        # map and embed them alongside the tokens.
        dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]).float()
        dur_embed = self.dur_embed(dur[:, :, None])
        encoder_out = self.encoder(txt_embed, dur_embed, txt_tokens == PAD_INDEX)

        # Length regulation: prepend a zero row at index 0 so that padding
        # frames (mel2ph == 0) select an all-zero vector, then gather each
        # mel frame's phoneme encoding.
        encoder_out = F.pad(encoder_out, [0, 0, 1, 0])
        mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
        condition = torch.gather(encoder_out, 1, mel2ph_)
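        # Illustrative sketch (comment only, not executed): with 3 phonemes
        # and mel2ph = [[1, 1, 2, 3, 3]], the pad above makes index 0 the
        # all-zero padding row, and the gather expands phoneme encodings to
        # frames: condition[b, t] = encoder_out[b, mel2ph[b, t]].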

        if self.use_spk_id:
            # A precomputed speaker-mix embedding (passed via kwargs) takes
            # precedence over the single-speaker lookup.
            spk_mix_embed = kwargs.get('spk_mix_embed')
            if spk_mix_embed is not None:
                spk_embed = spk_mix_embed
            else:
                spk_embed = self.spk_embed(spk_embed_id)[:, None, :]
            condition += spk_embed

        # Embed F0 on a log mel-like frequency scale: log(1 + f0 / 700).
        f0_mel = (1 + f0 / 700).log()
        pitch_embed = self.pitch_embed(f0_mel[:, :, None])
        condition += pitch_embed

        condition = self.forward_variance_embedding(
            condition, key_shift=key_shift, speed=speed, **kwargs
        )
        return condition
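

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the deployed module. It assumes
    # `hparams` is a plain dict that can be populated in-place, and all values
    # below are placeholders rather than the project's real configuration.
    hparams.update({
        'hidden_size': 256, 'enc_layers': 4, 'enc_ffn_kernel_size': 9,
        'ffn_act': 'gelu', 'dropout': 0.1, 'num_heads': 2,
        'use_pos_embed': True, 'use_spk_id': False,
    })
    model = FastSpeech2Acoustic(vocab_size=64)

    batch, n_txt, n_frames = 2, 8, 20
    txt_tokens = torch.randint(1, 64, (batch, n_txt))        # phoneme IDs
    mel2ph = torch.randint(1, n_txt + 1, (batch, n_frames))  # frame -> phoneme (1-based)
    f0 = torch.full((batch, n_frames), 220.0)                # F0 in Hz

    condition = model(txt_tokens, mel2ph, f0)
    print(condition.shape)  # expected: [batch, n_frames, hidden_size]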