# CantusSVS-hf: modules/toplevel.py
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import modules.compat as compat
from basics.base_module import CategorizedModule
from modules.aux_decoder import AuxDecoderAdaptor
from modules.commons.common_layers import (
XavierUniformInitLinear as Linear,
NormalInitEmbedding as Embedding
)
from modules.core import (
GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion,
RectifiedFlow, PitchRectifiedFlow, MultiVarianceRectifiedFlow
)
from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
from modules.fastspeech.param_adaptor import ParameterAdaptorModule
from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator
from modules.fastspeech.variance_encoder import FastSpeech2Variance, MelodyEncoder
from utils.hparams import hparams
class ShallowDiffusionOutput:
    """Container for the two shallow-diffusion outputs: the auxiliary decoder
    prediction and the diffusion (or rectified-flow) prediction."""

    def __init__(self, *, aux_out=None, diff_out=None):
        self.aux_out = aux_out
        self.diff_out = diff_out
class DiffSingerAcoustic(CategorizedModule, ParameterAdaptorModule):
    """Acoustic model: FastSpeech2 encoder followed by a diffusion (or rectified-flow)
    mel decoder, optionally paired with a shallow-diffusion auxiliary decoder."""

    @property
    def category(self):
        return 'acoustic'
def __init__(self, vocab_size, out_dims):
CategorizedModule.__init__(self)
ParameterAdaptorModule.__init__(self)
self.fs2 = FastSpeech2Acoustic(
vocab_size=vocab_size
)
self.use_shallow_diffusion = hparams.get('use_shallow_diffusion', False)
self.shallow_args = hparams.get('shallow_diffusion_args', {})
if self.use_shallow_diffusion:
self.train_aux_decoder = self.shallow_args['train_aux_decoder']
self.train_diffusion = self.shallow_args['train_diffusion']
self.aux_decoder_grad = self.shallow_args['aux_decoder_grad']
self.aux_decoder = AuxDecoderAdaptor(
in_dims=hparams['hidden_size'], out_dims=out_dims, num_feats=1,
spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
aux_decoder_arch=self.shallow_args['aux_decoder_arch'],
aux_decoder_args=self.shallow_args['aux_decoder_args']
)
self.diffusion_type = hparams.get('diffusion_type', 'ddpm')
self.backbone_type = compat.get_backbone_type(hparams)
self.backbone_args = compat.get_backbone_args(hparams, self.backbone_type)
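        # Main mel decoder: either DDPM-style Gaussian diffusion or rectified flow,
        # both driven by the backbone type and arguments resolved above.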
if self.diffusion_type == 'ddpm':
self.diffusion = GaussianDiffusion(
out_dims=out_dims,
num_feats=1,
timesteps=hparams['timesteps'],
k_step=hparams['K_step'],
backbone_type=self.backbone_type,
backbone_args=self.backbone_args,
spec_min=hparams['spec_min'],
spec_max=hparams['spec_max']
)
elif self.diffusion_type == 'reflow':
self.diffusion = RectifiedFlow(
out_dims=out_dims,
num_feats=1,
t_start=hparams['T_start'],
time_scale_factor=hparams['time_scale_factor'],
backbone_type=self.backbone_type,
backbone_args=self.backbone_args,
spec_min=hparams['spec_min'],
spec_max=hparams['spec_max']
)
else:
raise NotImplementedError(self.diffusion_type)
def forward(
self, txt_tokens, mel2ph, f0, key_shift=None, speed=None,
spk_embed_id=None, gt_mel=None, infer=True, **kwargs
) -> ShallowDiffusionOutput:
condition = self.fs2(
txt_tokens, mel2ph, f0, key_shift=key_shift, speed=speed,
spk_embed_id=spk_embed_id, **kwargs
)
        if infer:
            if self.use_shallow_diffusion:
                aux_mel_pred = self.aux_decoder(condition, infer=True)
                aux_mel_pred *= ((mel2ph > 0).float()[:, :, None])  # zero out frames not aligned to any phoneme
                # Start shallow diffusion from the ground-truth mel during validation when
                # configured (val_gt_start); otherwise start from the aux decoder prediction.
                if gt_mel is not None and self.shallow_args['val_gt_start']:
                    src_mel = gt_mel
                else:
                    src_mel = aux_mel_pred
            else:
                aux_mel_pred = src_mel = None
            mel_pred = self.diffusion(condition, src_spec=src_mel, infer=True)
            mel_pred *= ((mel2ph > 0).float()[:, :, None])  # zero out padded frames
            return ShallowDiffusionOutput(aux_out=aux_mel_pred, diff_out=mel_pred)
else:
            if self.use_shallow_diffusion:
                if self.train_aux_decoder:
                    # Scale the gradient flowing from the aux decoder back into the encoder:
                    # aux_decoder_grad = 1 trains the encoder through it, 0 detaches it entirely.
                    aux_cond = condition * self.aux_decoder_grad + condition.detach() * (1 - self.aux_decoder_grad)
                    aux_out = self.aux_decoder(aux_cond, infer=False)
else:
aux_out = None
if self.train_diffusion:
diff_out = self.diffusion(condition, gt_spec=gt_mel, infer=False)
else:
diff_out = None
return ShallowDiffusionOutput(aux_out=aux_out, diff_out=diff_out)
else:
aux_out = None
diff_out = self.diffusion(condition, gt_spec=gt_mel, infer=False)
return ShallowDiffusionOutput(aux_out=aux_out, diff_out=diff_out)
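# --------------------------------------------------------------------------
# Hedged usage sketch (illustrative only, not part of the original module):
# assuming hparams has already been populated by utils.hparams with the keys
# read above (hidden_size, timesteps, K_step, spec_min/spec_max, ...), an
# acoustic inference call might look roughly like this. Shapes follow the
# forward() signature; vocab_size=61 and out_dims=128 are placeholders.
#
#     model = DiffSingerAcoustic(vocab_size=61, out_dims=128).eval()
#     with torch.no_grad():
#         out = model(
#             txt_tokens,   # [B, T_ph]  LongTensor of phoneme ids
#             mel2ph,       # [B, T_mel] LongTensor mapping frames to phonemes (0 = padding)
#             f0,           # [B, T_mel] FloatTensor pitch curve
#             infer=True    # optional: key_shift, speed, spk_embed_id, gt_mel
#         )
#     mel = out.diff_out    # [B, T_mel, out_dims] predicted mel-spectrogram
# --------------------------------------------------------------------------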
class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
    """Variance model: predicts phoneme durations, pitch and other variance
    parameters via diffusion or rectified flow."""

    @property
    def category(self):
        return 'variance'
def __init__(self, vocab_size):
CategorizedModule.__init__(self)
ParameterAdaptorModule.__init__(self)
self.predict_dur = hparams['predict_dur']
self.predict_pitch = hparams['predict_pitch']
self.use_spk_id = hparams['use_spk_id']
if self.use_spk_id:
self.spk_embed = Embedding(hparams['num_spk'], hparams['hidden_size'])
self.fs2 = FastSpeech2Variance(
vocab_size=vocab_size
)
self.rr = RhythmRegulator()
self.lr = LengthRegulator()
self.diffusion_type = hparams.get('diffusion_type', 'ddpm')
if self.predict_pitch:
self.use_melody_encoder = hparams.get('use_melody_encoder', False)
if self.use_melody_encoder:
self.melody_encoder = MelodyEncoder(enc_hparams=hparams['melody_encoder_args'])
self.delta_pitch_embed = Linear(1, hparams['hidden_size'])
else:
self.base_pitch_embed = Linear(1, hparams['hidden_size'])
self.pitch_retake_embed = Embedding(2, hparams['hidden_size'])
pitch_hparams = hparams['pitch_prediction_args']
self.pitch_backbone_type = compat.get_backbone_type(hparams, nested_config=pitch_hparams)
self.pitch_backbone_args = compat.get_backbone_args(pitch_hparams, backbone_type=self.pitch_backbone_type)
if self.diffusion_type == 'ddpm':
self.pitch_predictor = PitchDiffusion(
vmin=pitch_hparams['pitd_norm_min'],
vmax=pitch_hparams['pitd_norm_max'],
cmin=pitch_hparams['pitd_clip_min'],
cmax=pitch_hparams['pitd_clip_max'],
repeat_bins=pitch_hparams['repeat_bins'],
timesteps=hparams['timesteps'],
k_step=hparams['K_step'],
backbone_type=self.pitch_backbone_type,
backbone_args=self.pitch_backbone_args
)
elif self.diffusion_type == 'reflow':
self.pitch_predictor = PitchRectifiedFlow(
vmin=pitch_hparams['pitd_norm_min'],
vmax=pitch_hparams['pitd_norm_max'],
cmin=pitch_hparams['pitd_clip_min'],
cmax=pitch_hparams['pitd_clip_max'],
repeat_bins=pitch_hparams['repeat_bins'],
time_scale_factor=hparams['time_scale_factor'],
backbone_type=self.pitch_backbone_type,
backbone_args=self.pitch_backbone_args
)
else:
raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
if self.predict_variances:
self.pitch_embed = Linear(1, hparams['hidden_size'])
self.variance_embeds = nn.ModuleDict({
v_name: Linear(1, hparams['hidden_size'])
for v_name in self.variance_prediction_list
})
if self.diffusion_type == 'ddpm':
self.variance_predictor = self.build_adaptor(cls=MultiVarianceDiffusion)
elif self.diffusion_type == 'reflow':
self.variance_predictor = self.build_adaptor(cls=MultiVarianceRectifiedFlow)
else:
raise NotImplementedError(self.diffusion_type)
def forward(
self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None,
note_midi=None, note_rest=None, note_dur=None, note_glide=None, mel2note=None,
base_pitch=None, pitch=None, pitch_expr=None, pitch_retake=None,
variance_retake: Dict[str, Tensor] = None,
spk_id=None, infer=True, **kwargs
):
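        # Speaker conditioning: pre-mixed speaker embeddings may arrive via kwargs
        # (ph_spk_mix_embed / spk_mix_embed); otherwise a single spk_id is embedded
        # and broadcast over time.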
if self.use_spk_id:
ph_spk_mix_embed = kwargs.get('ph_spk_mix_embed')
spk_mix_embed = kwargs.get('spk_mix_embed')
if ph_spk_mix_embed is not None and spk_mix_embed is not None:
ph_spk_embed = ph_spk_mix_embed
spk_embed = spk_mix_embed
else:
ph_spk_embed = spk_embed = self.spk_embed(spk_id)[:, None, :] # [B,] => [B, T=1, H]
else:
ph_spk_embed = spk_embed = None
encoder_out, dur_pred_out = self.fs2(
txt_tokens, midi=midi, ph2word=ph2word,
ph_dur=ph_dur, word_dur=word_dur,
spk_embed=ph_spk_embed, infer=infer
)
if not self.predict_pitch and not self.predict_variances:
return dur_pred_out, None, ({} if infer else None)
if mel2ph is None and word_dur is not None: # inference from file
dur_pred_align = self.rr(dur_pred_out, ph2word, word_dur)
mel2ph = self.lr(dur_pred_align)
mel2ph = F.pad(mel2ph, [0, base_pitch.shape[1] - mel2ph.shape[1]])
        # Pad one zero frame at index 0 so that positions with mel2ph == 0 (padding)
        # gather an all-zero condition vector.
        encoder_out = F.pad(encoder_out, [0, 0, 1, 0])
        mel2ph_ = mel2ph[..., None].repeat([1, 1, hparams['hidden_size']])
        condition = torch.gather(encoder_out, 1, mel2ph_)
if self.use_spk_id:
condition += spk_embed
if self.predict_pitch:
if self.use_melody_encoder:
melody_encoder_out = self.melody_encoder(
note_midi, note_rest, note_dur,
glide=note_glide
)
melody_encoder_out = F.pad(melody_encoder_out, [0, 0, 1, 0])
mel2note_ = mel2note[..., None].repeat([1, 1, hparams['hidden_size']])
melody_condition = torch.gather(melody_encoder_out, 1, mel2note_)
pitch_cond = condition + melody_condition
else:
pitch_cond = condition.clone() # preserve the original tensor to avoid further inplace operations
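            # Retake control: if no retake mask is given, every frame is (re)generated.
            # With a mask, pitch_expr (expressiveness) blends per frame between the
            # "retake" and "keep" embeddings on the retaken frames.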
retake_unset = pitch_retake is None
if retake_unset:
pitch_retake = torch.ones_like(mel2ph, dtype=torch.bool)
if pitch_expr is None:
pitch_retake_embed = self.pitch_retake_embed(pitch_retake.long())
else:
retake_true_embed = self.pitch_retake_embed(
torch.ones(1, 1, dtype=torch.long, device=txt_tokens.device)
) # [B=1, T=1] => [B=1, T=1, H]
retake_false_embed = self.pitch_retake_embed(
torch.zeros(1, 1, dtype=torch.long, device=txt_tokens.device)
) # [B=1, T=1] => [B=1, T=1, H]
pitch_expr = (pitch_expr * pitch_retake)[:, :, None] # [B, T, 1]
pitch_retake_embed = pitch_expr * retake_true_embed + (1. - pitch_expr) * retake_false_embed
pitch_cond += pitch_retake_embed
            if self.use_melody_encoder:
                if retake_unset:  # generate from scratch
                    delta_pitch_in = torch.zeros_like(base_pitch)
                else:
                    # Keep the user's pitch deviation only on frames that are NOT retaken.
                    delta_pitch_in = (pitch - base_pitch) * ~pitch_retake
                pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None])
            else:
                if not retake_unset:  # retake
                    # Frames that are not retaken keep the user's pitch instead of the base pitch.
                    base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake
                pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
if infer:
pitch_pred_out = self.pitch_predictor(pitch_cond, infer=True)
else:
pitch_pred_out = self.pitch_predictor(pitch_cond, pitch - base_pitch, infer=False)
else:
pitch_pred_out = None
if not self.predict_variances:
return dur_pred_out, pitch_pred_out, ({} if infer else None)
        if pitch is None:
            # Inference path: compose the absolute pitch from the base pitch and the predicted delta.
            pitch = base_pitch + pitch_pred_out
var_cond = condition + self.pitch_embed(pitch[:, :, None])
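        # Variance prediction: condition on the pitch embedding plus, when selectively
        # retaking, the embeddings of the variance curves that are being kept.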
variance_inputs = self.collect_variance_inputs(**kwargs)
if variance_retake is not None:
variance_embeds = [
self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None]
for v_name, v_input in zip(self.variance_prediction_list, variance_inputs)
]
var_cond += torch.stack(variance_embeds, dim=-1).sum(-1)
variance_outputs = self.variance_predictor(var_cond, variance_inputs, infer=infer)
if infer:
variances_pred_out = self.collect_variance_outputs(variance_outputs)
else:
variances_pred_out = variance_outputs
return dur_pred_out, pitch_pred_out, variances_pred_out
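# --------------------------------------------------------------------------
# Hedged usage sketch (illustrative only, not part of the original module):
# a variance-model inference call returns a triple of duration, pitch and
# variance predictions; pitch is predicted as a deviation from base_pitch.
# Argument names follow the forward() signature above; depending on hparams,
# note-level inputs (note_midi, note_rest, note_dur, mel2note) and
# ph_dur / mel2ph may also be required. vocab_size=61 is a placeholder.
#
#     model = DiffSingerVariance(vocab_size=61).eval()
#     with torch.no_grad():
#         dur_pred, pitch_pred, variances = model(
#             txt_tokens, midi=midi, ph2word=ph2word,
#             word_dur=word_dur, base_pitch=base_pitch,
#             infer=True
#         )
#     # pitch_pred is a delta; the absolute curve is base_pitch + pitch_pred.
# --------------------------------------------------------------------------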