import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from deployment.modules.diffusion import (
    GaussianDiffusionONNX, PitchDiffusionONNX, MultiVarianceDiffusionONNX
)
from deployment.modules.rectified_flow import (
    RectifiedFlowONNX, PitchRectifiedFlowONNX, MultiVarianceRectifiedFlowONNX
)
from deployment.modules.fastspeech2 import FastSpeech2AcousticONNX, FastSpeech2VarianceONNX
from modules.toplevel import DiffSingerAcoustic, DiffSingerVariance
from utils.hparams import hparams


class DiffSingerAcousticONNX(DiffSingerAcoustic):
    def __init__(self, vocab_size, out_dims):
        super().__init__(vocab_size, out_dims)
        del self.fs2
        del self.diffusion
        self.fs2 = FastSpeech2AcousticONNX(
            vocab_size=vocab_size
        )
        if self.diffusion_type == 'ddpm':
            self.diffusion = GaussianDiffusionONNX(
                out_dims=out_dims,
                num_feats=1,
                timesteps=hparams['timesteps'],
                k_step=hparams['K_step'],
                backbone_type=self.backbone_type,
                backbone_args=self.backbone_args,
                spec_min=hparams['spec_min'],
                spec_max=hparams['spec_max']
            )
        elif self.diffusion_type == 'reflow':
            self.diffusion = RectifiedFlowONNX(
                out_dims=out_dims,
                num_feats=1,
                t_start=hparams['T_start'],
                time_scale_factor=hparams['time_scale_factor'],
                backbone_type=self.backbone_type,
                backbone_args=self.backbone_args,
                spec_min=hparams['spec_min'],
                spec_max=hparams['spec_max']
            )
        else:
            raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
        self.mel_base = hparams.get('mel_base', '10')

    def ensure_mel_base(self, mel):
        if self.mel_base != 'e':
            # log10 mel to log mel
            mel = mel * 2.30259
        return mel

    def forward_fs2_aux(
            self,
            tokens: Tensor,
            durations: Tensor,
            f0: Tensor,
            variances: dict,
            gender: Tensor = None,
            velocity: Tensor = None,
            spk_embed: Tensor = None
    ):
        condition = self.fs2(
            tokens, durations, f0, variances=variances,
            gender=gender, velocity=velocity, spk_embed=spk_embed
        )
        if self.use_shallow_diffusion:
            aux_mel_pred = self.aux_decoder(condition, infer=True)
            return condition, aux_mel_pred
        else:
            return condition

    def forward_shallow_diffusion(
            self, condition: Tensor, x_start: Tensor,
            depth, steps: int
    ) -> Tensor:
        mel_pred = self.diffusion(condition, x_start=x_start, depth=depth, steps=steps)
        return self.ensure_mel_base(mel_pred)

    def forward_diffusion(self, condition: Tensor, steps: int):
        mel_pred = self.diffusion(condition, steps=steps)
        return self.ensure_mel_base(mel_pred)

    def forward_shallow_reflow(
            self, condition: Tensor, x_end: Tensor,
            depth, steps: int
    ):
        mel_pred = self.diffusion(condition, x_end=x_end, depth=depth, steps=steps)
        return self.ensure_mel_base(mel_pred)

    def forward_reflow(self, condition: Tensor, steps: int):
        mel_pred = self.diffusion(condition, steps=steps)
        return self.ensure_mel_base(mel_pred)

    def view_as_fs2_aux(self) -> nn.Module:
        model = copy.deepcopy(self)
        del model.diffusion
        model.forward = model.forward_fs2_aux
        return model

    def view_as_diffusion(self) -> nn.Module:
        model = copy.deepcopy(self)
        del model.fs2
        if self.use_shallow_diffusion:
            del model.aux_decoder
            model.forward = model.forward_shallow_diffusion
        else:
            model.forward = model.forward_diffusion
        return model

    def view_as_reflow(self) -> nn.Module:
        model = copy.deepcopy(self)
        del model.fs2
        if self.use_shallow_diffusion:
            del model.aux_decoder
            model.forward = model.forward_shallow_reflow
        else:
            model.forward = model.forward_reflow
        return model
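# Usage sketch (illustrative only, not part of this module): assuming hparams has
# been loaded and a trained checkpoint restored, the acoustic model is split into
# independently exportable sub-graphs via its view_as_* methods, e.g.:
#
#     acoustic = DiffSingerAcousticONNX(vocab_size=61, out_dims=128).eval()
#     frontend = acoustic.view_as_fs2_aux()      # linguistic frontend (+ aux decoder)
#     sampler = acoustic.view_as_diffusion()     # DDPM sampler sub-graph
#     torch.onnx.export(sampler, ...)            # each view is exported separately
#
# vocab_size=61 and out_dims=128 above are placeholder values, not values taken
# from this repository.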
class DiffSingerVarianceONNX(DiffSingerVariance):
    def __init__(self, vocab_size):
        super().__init__(vocab_size=vocab_size)
        del self.fs2
        self.fs2 = FastSpeech2VarianceONNX(
            vocab_size=vocab_size
        )
        self.hidden_size = hparams['hidden_size']
        if self.predict_pitch:
            del self.pitch_predictor
            self.smooth: nn.Conv1d = None  # built lazily by build_smooth_op()
            pitch_hparams = hparams['pitch_prediction_args']
            if self.diffusion_type == 'ddpm':
                self.pitch_predictor = PitchDiffusionONNX(
                    vmin=pitch_hparams['pitd_norm_min'],
                    vmax=pitch_hparams['pitd_norm_max'],
                    cmin=pitch_hparams['pitd_clip_min'],
                    cmax=pitch_hparams['pitd_clip_max'],
                    repeat_bins=pitch_hparams['repeat_bins'],
                    timesteps=hparams['timesteps'],
                    k_step=hparams['K_step'],
                    backbone_type=self.pitch_backbone_type,
                    backbone_args=self.pitch_backbone_args
                )
            elif self.diffusion_type == 'reflow':
                self.pitch_predictor = PitchRectifiedFlowONNX(
                    vmin=pitch_hparams['pitd_norm_min'],
                    vmax=pitch_hparams['pitd_norm_max'],
                    cmin=pitch_hparams['pitd_clip_min'],
                    cmax=pitch_hparams['pitd_clip_max'],
                    repeat_bins=pitch_hparams['repeat_bins'],
                    time_scale_factor=hparams['time_scale_factor'],
                    backbone_type=self.pitch_backbone_type,
                    backbone_args=self.pitch_backbone_args
                )
            else:
                raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
        if self.predict_variances:
            del self.variance_predictor
            if self.diffusion_type == 'ddpm':
                self.variance_predictor = self.build_adaptor(cls=MultiVarianceDiffusionONNX)
            elif self.diffusion_type == 'reflow':
                self.variance_predictor = self.build_adaptor(cls=MultiVarianceRectifiedFlowONNX)
            else:
                raise NotImplementedError(self.diffusion_type)

    def build_smooth_op(self, device):
        # Fixed (non-trainable) sine-window FIR filter, realized as a Conv1d, that
        # smooths the stepwise frame-level MIDI curve into a continuous base pitch.
        smooth_kernel_size = round(
            hparams['midi_smooth_width'] * hparams['audio_sample_rate'] / hparams['hop_size']
        )
        smooth = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=smooth_kernel_size,
            bias=False,
            padding='same',
            padding_mode='replicate'
        ).eval()
        smooth_kernel = torch.sin(torch.from_numpy(
            np.linspace(0, 1, smooth_kernel_size).astype(np.float32) * np.pi
        ))
        smooth_kernel /= smooth_kernel.sum()  # normalize so the filter preserves scale
        smooth.weight.data = smooth_kernel[None, None]
        self.smooth = smooth.to(device)

    def embed_frozen_spk(self, encoder_out):
        if hparams['use_spk_id'] and hasattr(self, 'frozen_spk_embed'):
            encoder_out += self.frozen_spk_embed
        return encoder_out

    def forward_linguistic_encoder_word(self, tokens, word_div, word_dur):
        encoder_out, x_masks = self.fs2.forward_encoder_word(tokens, word_div, word_dur)
        encoder_out = self.embed_frozen_spk(encoder_out)
        return encoder_out, x_masks

    def forward_linguistic_encoder_phoneme(self, tokens, ph_dur):
        encoder_out, x_masks = self.fs2.forward_encoder_phoneme(tokens, ph_dur)
        encoder_out = self.embed_frozen_spk(encoder_out)
        return encoder_out, x_masks

    def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
        return self.fs2.forward_dur_predictor(encoder_out, x_masks, ph_midi, spk_embed=spk_embed)

    def forward_mel2x_gather(self, x_src, x_dur, x_dim=None):
        mel2x = self.lr(x_dur)
        if x_dim is not None:
            x_src = F.pad(x_src, [0, 0, 1, 0])
            mel2x = mel2x[..., None].repeat([1, 1, x_dim])
        else:
            x_src = F.pad(x_src, [1, 0])
        x_cond = torch.gather(x_src, 1, mel2x)
        return x_cond
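    # Note on forward_mel2x_gather above: self.lr (the length regulator) maps each
    # output frame to a 1-based index along the phoneme/note axis, with index 0
    # reserved for unassigned frames; padding x_src with one leading zero entry
    # therefore makes torch.gather return zeros for those frames.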
    def forward_pitch_preprocess(
            self, encoder_out, ph_dur,
            note_midi=None, note_rest=None, note_dur=None, note_glide=None,
            pitch=None, expr=None, retake=None, spk_embed=None
    ):
        condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
        if self.use_melody_encoder:
            if self.melody_encoder.use_glide_embed and note_glide is None:
                note_glide = torch.LongTensor([[0]]).to(encoder_out.device)
            melody_encoder_out = self.melody_encoder(
                note_midi, note_rest, note_dur,
                glide=note_glide
            )
            melody_encoder_out = self.forward_mel2x_gather(
                melody_encoder_out, note_dur, x_dim=self.hidden_size
            )
            condition += melody_encoder_out
        if expr is None:
            retake_embed = self.pitch_retake_embed(retake.long())
        else:
            retake_true_embed = self.pitch_retake_embed(
                torch.ones(1, 1, dtype=torch.long, device=encoder_out.device)
            )  # [B=1, T=1] => [B=1, T=1, H]
            retake_false_embed = self.pitch_retake_embed(
                torch.zeros(1, 1, dtype=torch.long, device=encoder_out.device)
            )  # [B=1, T=1] => [B=1, T=1, H]
            expr = (expr * retake)[:, :, None]  # [B, T, 1]
            retake_embed = expr * retake_true_embed + (1. - expr) * retake_false_embed
        pitch_cond = condition + retake_embed
        frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None)
        base_pitch = self.smooth(frame_midi_pitch)
        if self.use_melody_encoder:
            delta_pitch = (pitch - base_pitch) * ~retake
            pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None])
        else:
            base_pitch = base_pitch * retake + pitch * ~retake
            pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
        if hparams['use_spk_id'] and spk_embed is not None:
            pitch_cond += spk_embed
        return pitch_cond, base_pitch

    def forward_pitch_reflow(
            self, pitch_cond,
            steps: int = 10
    ):
        x_pred = self.pitch_predictor(pitch_cond, steps=steps)
        return x_pred

    def forward_pitch_postprocess(self, x_pred, base_pitch):
        pitch_pred = self.pitch_predictor.clamp_spec(x_pred) + base_pitch
        return pitch_pred

    def forward_variance_preprocess(
            self, encoder_out, ph_dur, pitch,
            variances: dict = None, retake=None, spk_embed=None
    ):
        condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
        variance_cond = condition + self.pitch_embed(pitch[:, :, None])
        non_retake_masks = [
            v_retake.float()  # [B, T, 1]
            for v_retake in (~retake).split(1, dim=2)
        ]
        variance_embeds = [
            self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks
            for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks)
        ]
        variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1)
        if hparams['use_spk_id'] and spk_embed is not None:
            variance_cond += spk_embed
        return variance_cond

    def forward_variance_reflow(self, variance_cond, steps: int = 10):
        xs_pred = self.variance_predictor(variance_cond, steps=steps)
        return xs_pred

    def forward_variance_postprocess(self, xs_pred):
        if self.variance_predictor.num_feats == 1:
            xs_pred = [xs_pred]
        else:
            xs_pred = xs_pred.unbind(dim=1)
        variance_pred = self.variance_predictor.clamp_spec(xs_pred)
        return tuple(variance_pred)
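    # The view_as_* methods below all follow one pattern: deep-copy the whole
    # module, delete the sub-modules the target sub-graph does not use (so they
    # are not traced or serialized into the exported file), and rebind `forward`
    # to the matching forward_* method above.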
    def view_as_linguistic_encoder(self):
        model = copy.deepcopy(self)
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.fs2 = model.fs2.view_as_encoder()
        if self.predict_dur:
            model.forward = model.forward_linguistic_encoder_word
        else:
            model.forward = model.forward_linguistic_encoder_phoneme
        return model

    def view_as_dur_predictor(self):
        assert self.predict_dur
        model = copy.deepcopy(self)
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.fs2 = model.fs2.view_as_dur_predictor()
        model.forward = model.forward_dur_predictor
        return model

    def view_as_pitch_preprocess(self):
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_preprocess
        return model

    def view_as_pitch_predictor(self):
        assert self.predict_pitch
        model = copy.deepcopy(self)
        del model.fs2
        del model.lr
        if self.use_melody_encoder:
            del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_reflow
        return model

    def view_as_pitch_postprocess(self):
        model = copy.deepcopy(self)
        del model.fs2
        if self.use_melody_encoder:
            del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_postprocess
        return model

    def view_as_variance_preprocess(self):
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_variance_preprocess
        return model

    def view_as_variance_predictor(self):
        assert self.predict_variances
        model = copy.deepcopy(self)
        del model.fs2
        del model.lr
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        model.forward = model.forward_variance_reflow
        return model

    def view_as_variance_postprocess(self):
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        model.forward = model.forward_variance_postprocess
        return model
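# Usage sketch for the variance model (illustrative only; assumes hparams and a
# checkpoint have been loaded elsewhere, as the real export scripts do):
#
#     variance = DiffSingerVarianceONNX(vocab_size=61).eval()
#     variance.build_smooth_op(torch.device('cpu'))  # required before pitch preprocessing
#     encoder = variance.view_as_linguistic_encoder()
#     pitch_pre = variance.view_as_pitch_preprocess()
#     pitch_core = variance.view_as_pitch_predictor()
#     pitch_post = variance.view_as_pitch_postprocess()
#     # ...then pass each view to torch.onnx.export individually.
#
# vocab_size=61 is a placeholder, not a value taken from this repository.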