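"""
Top-level DiffSinger models: DiffSingerAcoustic predicts mel-spectrograms with a
diffusion (or rectified-flow) decoder, optionally preceded by an auxiliary decoder
for shallow diffusion; DiffSingerVariance predicts phoneme durations, pitch and
other variance parameters.
"""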
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import modules.compat as compat
from basics.base_module import CategorizedModule
from modules.aux_decoder import AuxDecoderAdaptor
from modules.commons.common_layers import (
    XavierUniformInitLinear as Linear,
    NormalInitEmbedding as Embedding
)
from modules.core import (
    GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion,
    RectifiedFlow, PitchRectifiedFlow, MultiVarianceRectifiedFlow
)
from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
from modules.fastspeech.param_adaptor import ParameterAdaptorModule
from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator
from modules.fastspeech.variance_encoder import FastSpeech2Variance, MelodyEncoder
from utils.hparams import hparams


class ShallowDiffusionOutput:
    """Container for shallow diffusion results: the auxiliary decoder output
    and the diffusion decoder output."""

    def __init__(self, *, aux_out=None, diff_out=None):
        self.aux_out = aux_out
        self.diff_out = diff_out


class DiffSingerAcoustic(CategorizedModule, ParameterAdaptorModule):
    def category(self):
        return 'acoustic'

    def __init__(self, vocab_size, out_dims):
        CategorizedModule.__init__(self)
        ParameterAdaptorModule.__init__(self)
        # Linguistic/prosodic encoder that produces the frame-level condition.
        self.fs2 = FastSpeech2Acoustic(
            vocab_size=vocab_size
        )

        # Shallow diffusion: an auxiliary decoder predicts a coarse mel-spectrogram
        # that the diffusion decoder refines from an intermediate step.
        self.use_shallow_diffusion = hparams.get('use_shallow_diffusion', False)
        self.shallow_args = hparams.get('shallow_diffusion_args', {})
        if self.use_shallow_diffusion:
            self.train_aux_decoder = self.shallow_args['train_aux_decoder']
            self.train_diffusion = self.shallow_args['train_diffusion']
            self.aux_decoder_grad = self.shallow_args['aux_decoder_grad']
            self.aux_decoder = AuxDecoderAdaptor(
                in_dims=hparams['hidden_size'], out_dims=out_dims, num_feats=1,
                spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
                aux_decoder_arch=self.shallow_args['aux_decoder_arch'],
                aux_decoder_args=self.shallow_args['aux_decoder_args']
            )

        # Main mel-spectrogram decoder: DDPM diffusion or rectified flow.
        self.diffusion_type = hparams.get('diffusion_type', 'ddpm')
        self.backbone_type = compat.get_backbone_type(hparams)
        self.backbone_args = compat.get_backbone_args(hparams, self.backbone_type)
        if self.diffusion_type == 'ddpm':
            self.diffusion = GaussianDiffusion(
                out_dims=out_dims,
                num_feats=1,
                timesteps=hparams['timesteps'],
                k_step=hparams['K_step'],
                backbone_type=self.backbone_type,
                backbone_args=self.backbone_args,
                spec_min=hparams['spec_min'],
                spec_max=hparams['spec_max']
            )
        elif self.diffusion_type == 'reflow':
            self.diffusion = RectifiedFlow(
                out_dims=out_dims,
                num_feats=1,
                t_start=hparams['T_start'],
                time_scale_factor=hparams['time_scale_factor'],
                backbone_type=self.backbone_type,
                backbone_args=self.backbone_args,
                spec_min=hparams['spec_min'],
                spec_max=hparams['spec_max']
            )
        else:
            raise NotImplementedError(self.diffusion_type)
    def forward(
            self, txt_tokens, mel2ph, f0, key_shift=None, speed=None,
            spk_embed_id=None, gt_mel=None, infer=True, **kwargs
    ) -> ShallowDiffusionOutput:
        condition = self.fs2(
            txt_tokens, mel2ph, f0, key_shift=key_shift, speed=speed,
            spk_embed_id=spk_embed_id, **kwargs
        )
        if infer:
            # Inference: optionally predict a coarse mel with the aux decoder and use it
            # (or the ground-truth mel during validation) as the shallow diffusion source.
            if self.use_shallow_diffusion:
                aux_mel_pred = self.aux_decoder(condition, infer=True)
                aux_mel_pred *= ((mel2ph > 0).float()[:, :, None])
                if gt_mel is not None and self.shallow_args['val_gt_start']:
                    src_mel = gt_mel
                else:
                    src_mel = aux_mel_pred
            else:
                aux_mel_pred = src_mel = None
            mel_pred = self.diffusion(condition, src_spec=src_mel, infer=True)
            mel_pred *= ((mel2ph > 0).float()[:, :, None])
            return ShallowDiffusionOutput(aux_out=aux_mel_pred, diff_out=mel_pred)
        else:
            # Training: run the aux decoder and/or the diffusion decoder as configured.
            if self.use_shallow_diffusion:
                if self.train_aux_decoder:
                    # Scale the gradient flowing from the aux decoder back into the encoder.
                    aux_cond = condition * self.aux_decoder_grad + condition.detach() * (1 - self.aux_decoder_grad)
                    aux_out = self.aux_decoder(aux_cond, infer=False)
                else:
                    aux_out = None
                if self.train_diffusion:
                    diff_out = self.diffusion(condition, gt_spec=gt_mel, infer=False)
                else:
                    diff_out = None
                return ShallowDiffusionOutput(aux_out=aux_out, diff_out=diff_out)
            else:
                aux_out = None
                diff_out = self.diffusion(condition, gt_spec=gt_mel, infer=False)
                return ShallowDiffusionOutput(aux_out=aux_out, diff_out=diff_out)
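
# Illustrative usage sketch (not part of the original module). Assuming a loaded
# hparams config and prepared input tensors, acoustic inference would look roughly
# like this; the hparams key 'audio_num_mel_bins' is a typical choice for out_dims
# and may differ in a given config:
#
#   model = DiffSingerAcoustic(vocab_size=vocab_size, out_dims=hparams['audio_num_mel_bins'])
#   out = model(txt_tokens, mel2ph, f0, infer=True)  # -> ShallowDiffusionOutput
#   mel_pred = out.diff_out                          # [B, T, out_dims]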


class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
    def category(self):
        return 'variance'

    def __init__(self, vocab_size):
        CategorizedModule.__init__(self)
        ParameterAdaptorModule.__init__(self)
        self.predict_dur = hparams['predict_dur']
        self.predict_pitch = hparams['predict_pitch']

        self.use_spk_id = hparams['use_spk_id']
        if self.use_spk_id:
            self.spk_embed = Embedding(hparams['num_spk'], hparams['hidden_size'])

        # Phoneme encoder and duration predictor.
        self.fs2 = FastSpeech2Variance(
            vocab_size=vocab_size
        )
        self.rr = RhythmRegulator()
        self.lr = LengthRegulator()

        self.diffusion_type = hparams.get('diffusion_type', 'ddpm')

        if self.predict_pitch:
            # Pitch is predicted as a deviation from a base pitch curve; with the
            # melody encoder enabled, note-level information is encoded separately.
            self.use_melody_encoder = hparams.get('use_melody_encoder', False)
            if self.use_melody_encoder:
                self.melody_encoder = MelodyEncoder(enc_hparams=hparams['melody_encoder_args'])
                self.delta_pitch_embed = Linear(1, hparams['hidden_size'])
            else:
                self.base_pitch_embed = Linear(1, hparams['hidden_size'])

            self.pitch_retake_embed = Embedding(2, hparams['hidden_size'])
            pitch_hparams = hparams['pitch_prediction_args']
            self.pitch_backbone_type = compat.get_backbone_type(hparams, nested_config=pitch_hparams)
            self.pitch_backbone_args = compat.get_backbone_args(pitch_hparams, backbone_type=self.pitch_backbone_type)
            if self.diffusion_type == 'ddpm':
                self.pitch_predictor = PitchDiffusion(
                    vmin=pitch_hparams['pitd_norm_min'],
                    vmax=pitch_hparams['pitd_norm_max'],
                    cmin=pitch_hparams['pitd_clip_min'],
                    cmax=pitch_hparams['pitd_clip_max'],
                    repeat_bins=pitch_hparams['repeat_bins'],
                    timesteps=hparams['timesteps'],
                    k_step=hparams['K_step'],
                    backbone_type=self.pitch_backbone_type,
                    backbone_args=self.pitch_backbone_args
                )
            elif self.diffusion_type == 'reflow':
                self.pitch_predictor = PitchRectifiedFlow(
                    vmin=pitch_hparams['pitd_norm_min'],
                    vmax=pitch_hparams['pitd_norm_max'],
                    cmin=pitch_hparams['pitd_clip_min'],
                    cmax=pitch_hparams['pitd_clip_max'],
                    repeat_bins=pitch_hparams['repeat_bins'],
                    time_scale_factor=hparams['time_scale_factor'],
                    backbone_type=self.pitch_backbone_type,
                    backbone_args=self.pitch_backbone_args
                )
            else:
                raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")

        if self.predict_variances:
            # One input embedding per predicted variance parameter, plus a shared pitch embedding.
            self.pitch_embed = Linear(1, hparams['hidden_size'])
            self.variance_embeds = nn.ModuleDict({
                v_name: Linear(1, hparams['hidden_size'])
                for v_name in self.variance_prediction_list
            })
            if self.diffusion_type == 'ddpm':
                self.variance_predictor = self.build_adaptor(cls=MultiVarianceDiffusion)
            elif self.diffusion_type == 'reflow':
                self.variance_predictor = self.build_adaptor(cls=MultiVarianceRectifiedFlow)
            else:
                raise NotImplementedError(self.diffusion_type)
    def forward(
            self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None,
            note_midi=None, note_rest=None, note_dur=None, note_glide=None, mel2note=None,
            base_pitch=None, pitch=None, pitch_expr=None, pitch_retake=None,
            variance_retake: Dict[str, Tensor] = None,
            spk_id=None, infer=True, **kwargs
    ):
        if self.use_spk_id:
            ph_spk_mix_embed = kwargs.get('ph_spk_mix_embed')
            spk_mix_embed = kwargs.get('spk_mix_embed')
            if ph_spk_mix_embed is not None and spk_mix_embed is not None:
                ph_spk_embed = ph_spk_mix_embed
                spk_embed = spk_mix_embed
            else:
                ph_spk_embed = spk_embed = self.spk_embed(spk_id)[:, None, :]  # [B,] => [B, T=1, H]
        else:
            ph_spk_embed = spk_embed = None

        encoder_out, dur_pred_out = self.fs2(
            txt_tokens, midi=midi, ph2word=ph2word,
            ph_dur=ph_dur, word_dur=word_dur,
            spk_embed=ph_spk_embed, infer=infer
        )

        if not self.predict_pitch and not self.predict_variances:
            return dur_pred_out, None, ({} if infer else None)

        if mel2ph is None and word_dur is not None:  # inference from file
            dur_pred_align = self.rr(dur_pred_out, ph2word, word_dur)
            mel2ph = self.lr(dur_pred_align)
            mel2ph = F.pad(mel2ph, [0, base_pitch.shape[1] - mel2ph.shape[1]])

        # Prepend a zero frame so that index 0 in mel2ph (padding) maps to zeros,
        # then expand phoneme-level encoder outputs to frame level.
        encoder_out = F.pad(encoder_out, [0, 0, 1, 0])
        mel2ph_ = mel2ph[..., None].repeat([1, 1, hparams['hidden_size']])
        condition = torch.gather(encoder_out, 1, mel2ph_)

        if self.use_spk_id:
            condition += spk_embed

        if self.predict_pitch:
            if self.use_melody_encoder:
                melody_encoder_out = self.melody_encoder(
                    note_midi, note_rest, note_dur,
                    glide=note_glide
                )
                melody_encoder_out = F.pad(melody_encoder_out, [0, 0, 1, 0])
                mel2note_ = mel2note[..., None].repeat([1, 1, hparams['hidden_size']])
                melody_condition = torch.gather(melody_encoder_out, 1, mel2note_)
                pitch_cond = condition + melody_condition
            else:
                pitch_cond = condition.clone()  # preserve the original tensor to avoid further in-place operations

            # Without an explicit retake mask, treat the whole utterance as retaken
            # (i.e. predict pitch everywhere).
            retake_unset = pitch_retake is None
            if retake_unset:
                pitch_retake = torch.ones_like(mel2ph, dtype=torch.bool)

            if pitch_expr is None:
                pitch_retake_embed = self.pitch_retake_embed(pitch_retake.long())
            else:
                # Expressiveness control: interpolate between the retake and non-retake embeddings.
                retake_true_embed = self.pitch_retake_embed(
                    torch.ones(1, 1, dtype=torch.long, device=txt_tokens.device)
                )  # [B=1, T=1] => [B=1, T=1, H]
                retake_false_embed = self.pitch_retake_embed(
                    torch.zeros(1, 1, dtype=torch.long, device=txt_tokens.device)
                )  # [B=1, T=1] => [B=1, T=1, H]
                pitch_expr = (pitch_expr * pitch_retake)[:, :, None]  # [B, T, 1]
                pitch_retake_embed = pitch_expr * retake_true_embed + (1. - pitch_expr) * retake_false_embed

            pitch_cond += pitch_retake_embed

            if self.use_melody_encoder:
                if retake_unset:  # generate from scratch
                    delta_pitch_in = torch.zeros_like(base_pitch)
                else:
                    delta_pitch_in = (pitch - base_pitch) * ~pitch_retake
                pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None])
            else:
                if not retake_unset:  # retake
                    base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake
                pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])

            if infer:
                pitch_pred_out = self.pitch_predictor(pitch_cond, infer=True)
            else:
                pitch_pred_out = self.pitch_predictor(pitch_cond, pitch - base_pitch, infer=False)
        else:
            pitch_pred_out = None

        if not self.predict_variances:
            return dur_pred_out, pitch_pred_out, ({} if infer else None)

        if pitch is None:
            pitch = base_pitch + pitch_pred_out

        var_cond = condition + self.pitch_embed(pitch[:, :, None])
        variance_inputs = self.collect_variance_inputs(**kwargs)
        if variance_retake is not None:
            # Keep input embeddings only for frames that are NOT being retaken.
            variance_embeds = [
                self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None]
                for v_name, v_input in zip(self.variance_prediction_list, variance_inputs)
            ]
            var_cond += torch.stack(variance_embeds, dim=-1).sum(-1)

        variance_outputs = self.variance_predictor(var_cond, variance_inputs, infer=infer)
        if infer:
            variances_pred_out = self.collect_variance_outputs(variance_outputs)
        else:
            variances_pred_out = variance_outputs

        return dur_pred_out, pitch_pred_out, variances_pred_out
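
# DiffSingerVariance.forward always returns a 3-tuple (dur_pred_out, pitch_pred_out,
# variances_pred_out): pitch_pred_out is None when pitch prediction is disabled, and
# the variance slot is an empty dict at inference (or None during training) when no
# variance parameters are predicted.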