Spaces:
Sleeping
Sleeping
import copy | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch import Tensor | |
from deployment.modules.diffusion import ( | |
GaussianDiffusionONNX, PitchDiffusionONNX, MultiVarianceDiffusionONNX | |
) | |
from deployment.modules.rectified_flow import ( | |
RectifiedFlowONNX, PitchRectifiedFlowONNX, MultiVarianceRectifiedFlowONNX | |
) | |
from deployment.modules.fastspeech2 import FastSpeech2AcousticONNX, FastSpeech2VarianceONNX | |
from modules.toplevel import DiffSingerAcoustic, DiffSingerVariance | |
from utils.hparams import hparams | |
class DiffSingerAcousticONNX(DiffSingerAcoustic):
    """Acoustic model whose submodules are swapped for ONNX-export variants.

    The parent constructor builds the regular model first; this subclass then
    replaces ``fs2`` and ``diffusion`` with their ``*ONNX`` counterparts and
    offers ``view_as_*`` helpers that return deep copies of the model with
    ``forward`` rebound to one specific stage.
    """

    def __init__(self, vocab_size, out_dims):
        """Build the parent model, then install the ONNX-ready submodules.

        Raises:
            ValueError: if ``self.diffusion_type`` is neither ``'ddpm'`` nor
                ``'reflow'``.
        """
        super().__init__(vocab_size, out_dims)
        # Remove the parent-built submodules before registering replacements.
        del self.fs2
        del self.diffusion
        self.fs2 = FastSpeech2AcousticONNX(vocab_size=vocab_size)

        # Pick the sampler class together with its sampler-specific arguments;
        # everything else is shared between the two branches.
        diffusion_type = self.diffusion_type
        if diffusion_type == 'ddpm':
            sampler_cls = GaussianDiffusionONNX
            sampler_kwargs = {
                'timesteps': hparams['timesteps'],
                'k_step': hparams['K_step'],
            }
        elif diffusion_type == 'reflow':
            sampler_cls = RectifiedFlowONNX
            sampler_kwargs = {
                't_start': hparams['T_start'],
                'time_scale_factor': hparams['time_scale_factor'],
            }
        else:
            raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
        self.diffusion = sampler_cls(
            out_dims=out_dims,
            num_feats=1,
            **sampler_kwargs,
            backbone_type=self.backbone_type,
            backbone_args=self.backbone_args,
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )
        self.mel_base = hparams.get('mel_base', '10')

    def ensure_mel_base(self, mel):
        """Return ``mel`` unchanged for base ``'e'``, otherwise rescaled.

        Any other ``mel_base`` is treated as log10 and converted to natural
        log by multiplying with 2.30259 (approximately ln 10).
        """
        if self.mel_base == 'e':
            return mel
        # log10 mel to log mel
        return mel * 2.30259

    def forward_fs2_aux(
            self,
            tokens: Tensor,
            durations: Tensor,
            f0: Tensor,
            variances: dict,
            gender: Tensor = None,
            velocity: Tensor = None,
            spk_embed: Tensor = None
    ):
        """Run ``fs2`` and, with shallow diffusion enabled, the aux decoder.

        Returns:
            ``condition`` alone, or ``(condition, aux_mel_pred)`` when
            ``self.use_shallow_diffusion`` is truthy.
        """
        condition = self.fs2(
            tokens, durations, f0, variances=variances,
            gender=gender, velocity=velocity, spk_embed=spk_embed
        )
        if not self.use_shallow_diffusion:
            return condition
        aux_mel_pred = self.aux_decoder(condition, infer=True)
        return condition, aux_mel_pred

    def forward_shallow_diffusion(
            self, condition: Tensor, x_start: Tensor,
            depth, steps: int
    ) -> Tensor:
        """Run the diffusion module from ``x_start`` and fix the mel base."""
        return self.ensure_mel_base(
            self.diffusion(condition, x_start=x_start, depth=depth, steps=steps)
        )

    def forward_diffusion(self, condition: Tensor, steps: int):
        """Run the diffusion module on ``condition`` and fix the mel base."""
        return self.ensure_mel_base(self.diffusion(condition, steps=steps))

    def forward_shallow_reflow(
            self, condition: Tensor, x_end: Tensor,
            depth, steps: int
    ):
        """Run the diffusion module with ``x_end`` given and fix the mel base."""
        return self.ensure_mel_base(
            self.diffusion(condition, x_end=x_end, depth=depth, steps=steps)
        )

    def forward_reflow(self, condition: Tensor, steps: int):
        """Run the diffusion module on ``condition`` and fix the mel base."""
        return self.ensure_mel_base(self.diffusion(condition, steps=steps))

    def view_as_fs2_aux(self) -> nn.Module:
        """Return a deep copy without ``diffusion`` that runs ``forward_fs2_aux``."""
        view = copy.deepcopy(self)
        del view.diffusion
        view.forward = view.forward_fs2_aux
        return view

    def view_as_diffusion(self) -> nn.Module:
        """Return a deep copy without ``fs2`` whose forward is the diffusion stage.

        With shallow diffusion enabled the copy also drops ``aux_decoder`` and
        uses ``forward_shallow_diffusion``; otherwise ``forward_diffusion``.
        """
        view = copy.deepcopy(self)
        del view.fs2
        shallow = self.use_shallow_diffusion
        if shallow:
            del view.aux_decoder
        view.forward = view.forward_shallow_diffusion if shallow else view.forward_diffusion
        return view

    def view_as_reflow(self) -> nn.Module:
        """Return a deep copy without ``fs2`` whose forward is the reflow stage.

        With shallow diffusion enabled the copy also drops ``aux_decoder`` and
        uses ``forward_shallow_reflow``; otherwise ``forward_reflow``.
        """
        view = copy.deepcopy(self)
        del view.fs2
        shallow = self.use_shallow_diffusion
        if shallow:
            del view.aux_decoder
        view.forward = view.forward_shallow_reflow if shallow else view.forward_reflow
        return view
class DiffSingerVarianceONNX(DiffSingerVariance):
    """Variance model prepared for stage-by-stage ONNX export.

    The parent constructor builds the regular model; this subclass replaces
    ``fs2`` and, when enabled, the pitch and variance predictors with their
    ``*ONNX`` counterparts. The ``forward_*`` methods implement individual
    stages, and each ``view_as_*`` method returns a deep copy with the
    submodules that stage does not need removed and ``forward`` rebound to it.
    """

    def __init__(self, vocab_size):
        """Build the parent model, then install the ONNX-ready submodules.

        Raises:
            ValueError: pitch prediction is enabled and ``diffusion_type`` is
                neither ``'ddpm'`` nor ``'reflow'``.
            NotImplementedError: variance prediction is enabled and
                ``diffusion_type`` is neither ``'ddpm'`` nor ``'reflow'``.
        """
        super().__init__(vocab_size=vocab_size)
        del self.fs2
        self.fs2 = FastSpeech2VarianceONNX(
            vocab_size=vocab_size
        )
        self.hidden_size = hparams['hidden_size']
        if self.predict_pitch:
            del self.pitch_predictor
            # Placeholder only: the real operator is created by build_smooth_op().
            self.smooth: nn.Conv1d = None
            pitch_hparams = hparams['pitch_prediction_args']
            if self.diffusion_type == 'ddpm':
                self.pitch_predictor = PitchDiffusionONNX(
                    vmin=pitch_hparams['pitd_norm_min'],
                    vmax=pitch_hparams['pitd_norm_max'],
                    cmin=pitch_hparams['pitd_clip_min'],
                    cmax=pitch_hparams['pitd_clip_max'],
                    repeat_bins=pitch_hparams['repeat_bins'],
                    timesteps=hparams['timesteps'],
                    k_step=hparams['K_step'],
                    backbone_type=self.pitch_backbone_type,
                    backbone_args=self.pitch_backbone_args
                )
            elif self.diffusion_type == 'reflow':
                self.pitch_predictor = PitchRectifiedFlowONNX(
                    vmin=pitch_hparams['pitd_norm_min'],
                    vmax=pitch_hparams['pitd_norm_max'],
                    cmin=pitch_hparams['pitd_clip_min'],
                    cmax=pitch_hparams['pitd_clip_max'],
                    repeat_bins=pitch_hparams['repeat_bins'],
                    time_scale_factor=hparams['time_scale_factor'],
                    backbone_type=self.pitch_backbone_type,
                    backbone_args=self.pitch_backbone_args
                )
            else:
                raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
        if self.predict_variances:
            del self.variance_predictor
            if self.diffusion_type == 'ddpm':
                self.variance_predictor = self.build_adaptor(cls=MultiVarianceDiffusionONNX)
            elif self.diffusion_type == 'reflow':
                self.variance_predictor = self.build_adaptor(cls=MultiVarianceRectifiedFlowONNX)
            else:
                raise NotImplementedError(self.diffusion_type)

    def build_smooth_op(self, device):
        """Create the MIDI smoothing convolution and store it as ``self.smooth``.

        The kernel is a half-period sine window normalized to sum to one. Its
        length in frames is ``midi_smooth_width * audio_sample_rate / hop_size``
        rounded to the nearest integer.

        Args:
            device: Device the resulting ``nn.Conv1d`` is moved to.
        """
        smooth_kernel_size = round(hparams['midi_smooth_width'] * hparams['audio_sample_rate'] / hparams['hop_size'])
        smooth = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=smooth_kernel_size,
            bias=False,
            padding='same',
            padding_mode='replicate'
        ).eval()
        smooth_kernel = torch.sin(torch.from_numpy(
            np.linspace(0, 1, smooth_kernel_size).astype(np.float32) * np.pi
        ))
        smooth_kernel /= smooth_kernel.sum()
        # Conv1d weights are (out_channels, in_channels, kernel_size).
        smooth.weight.data = smooth_kernel[None, None]
        self.smooth = smooth.to(device)

    def embed_frozen_spk(self, encoder_out):
        """Add the frozen speaker embedding to ``encoder_out`` when available.

        The addition is in place and only happens if ``hparams['use_spk_id']``
        is truthy and the instance has a ``frozen_spk_embed`` attribute.
        """
        if hparams['use_spk_id'] and hasattr(self, 'frozen_spk_embed'):
            encoder_out += self.frozen_spk_embed
        return encoder_out

    def forward_linguistic_encoder_word(self, tokens, word_div, word_dur):
        """Run the word-level encoder of ``fs2`` and apply the frozen speaker."""
        encoder_out, x_masks = self.fs2.forward_encoder_word(tokens, word_div, word_dur)
        encoder_out = self.embed_frozen_spk(encoder_out)
        return encoder_out, x_masks

    def forward_linguistic_encoder_phoneme(self, tokens, ph_dur):
        """Run the phoneme-level encoder of ``fs2`` and apply the frozen speaker."""
        encoder_out, x_masks = self.fs2.forward_encoder_phoneme(tokens, ph_dur)
        encoder_out = self.embed_frozen_spk(encoder_out)
        return encoder_out, x_masks

    def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
        """Delegate to the duration predictor of ``fs2``."""
        return self.fs2.forward_dur_predictor(encoder_out, x_masks, ph_midi, spk_embed=spk_embed)

    def forward_mel2x_gather(self, x_src, x_dur, x_dim=None):
        """Expand ``x_src`` along dim 1 using the index map built from ``x_dur``.

        ``self.lr(x_dur)`` yields the gather indices. ``x_src`` gets one padding
        entry prepended along the gathered axis, so index 0 selects padding.

        Args:
            x_src: Source values; treated as having a trailing feature axis
                when ``x_dim`` is given, otherwise as having none.
            x_dur: Durations passed to ``self.lr``.
            x_dim: Size of the trailing feature axis of ``x_src``, or ``None``.
        """
        mel2x = self.lr(x_dur)
        if x_dim is not None:
            x_src = F.pad(x_src, [0, 0, 1, 0])
            # torch.gather needs an index of the same rank as the source.
            mel2x = mel2x[..., None].repeat([1, 1, x_dim])
        else:
            x_src = F.pad(x_src, [1, 0])
        x_cond = torch.gather(x_src, 1, mel2x)
        return x_cond

    def forward_pitch_preprocess(
            self, encoder_out, ph_dur,
            note_midi=None, note_rest=None, note_dur=None, note_glide=None,
            pitch=None, expr=None, retake=None, spk_embed=None
    ):
        """Build the pitch predictor condition and the smoothed base pitch.

        Although they default to ``None``, ``note_midi``, ``note_dur``,
        ``pitch`` and ``retake`` are used unconditionally and must be given.
        ``note_rest`` is only forwarded to the melody encoder. ``retake`` is
        combined with ``~`` and ``.long()``, so it is presumably a boolean
        mask (TODO confirm against the exporter).

        Returns:
            ``(pitch_cond, base_pitch)``.

        Raises:
            RuntimeError: ``build_smooth_op`` has not been called, so there is
                no smoothing operator to compute the base pitch with.
        """
        # Fail early and clearly instead of "'NoneType' object is not callable"
        # further down; getattr also covers models built without pitch prediction.
        if getattr(self, 'smooth', None) is None:
            raise RuntimeError(
                'Smoothing operator is not built; call build_smooth_op(device) '
                'before forward_pitch_preprocess.'
            )
        condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
        if self.use_melody_encoder:
            if self.melody_encoder.use_glide_embed and note_glide is None:
                # Single zero glide id as a stand-in when none is supplied.
                note_glide = torch.LongTensor([[0]]).to(encoder_out.device)
            melody_encoder_out = self.melody_encoder(
                note_midi, note_rest, note_dur,
                glide=note_glide
            )
            melody_encoder_out = self.forward_mel2x_gather(melody_encoder_out, note_dur, x_dim=self.hidden_size)
            condition += melody_encoder_out
        if expr is None:
            retake_embed = self.pitch_retake_embed(retake.long())
        else:
            # Blend the "retake" and "keep" embeddings by expr, restricted to
            # the retaken frames.
            retake_true_embed = self.pitch_retake_embed(
                torch.ones(1, 1, dtype=torch.long, device=encoder_out.device)
            )  # [B=1, T=1] => [B=1, T=1, H]
            retake_false_embed = self.pitch_retake_embed(
                torch.zeros(1, 1, dtype=torch.long, device=encoder_out.device)
            )  # [B=1, T=1] => [B=1, T=1, H]
            expr = (expr * retake)[:, :, None]  # [B, T, 1]
            retake_embed = expr * retake_true_embed + (1. - expr) * retake_false_embed
        pitch_cond = condition + retake_embed
        frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None)
        base_pitch = self.smooth(frame_midi_pitch)
        if self.use_melody_encoder:
            # Condition on the known pitch deviation of frames that are kept.
            delta_pitch = (pitch - base_pitch) * ~retake
            pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None])
        else:
            # Kept frames use the given pitch; retaken frames use the smoothed MIDI.
            base_pitch = base_pitch * retake + pitch * ~retake
            pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
        if hparams['use_spk_id'] and spk_embed is not None:
            pitch_cond += spk_embed
        return pitch_cond, base_pitch

    def forward_pitch_reflow(
            self, pitch_cond, steps: int = 10
    ):
        """Run the pitch predictor on ``pitch_cond`` for ``steps`` steps."""
        x_pred = self.pitch_predictor(pitch_cond, steps=steps)
        return x_pred

    def forward_pitch_postprocess(self, x_pred, base_pitch):
        """Clamp the predictor output and add the base pitch back."""
        pitch_pred = self.pitch_predictor.clamp_spec(x_pred) + base_pitch
        return pitch_pred

    def forward_variance_preprocess(
            self, encoder_out, ph_dur, pitch,
            variances: dict = None, retake=None, spk_embed=None
    ):
        """Build the condition for the variance predictor.

        Although they default to ``None``, ``variances`` and ``retake`` are
        used unconditionally and must be given. ``variances`` is indexed by
        every name in ``self.variance_prediction_list``; ``retake`` is split
        along dim 2 into one mask per variance.
        """
        condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
        variance_cond = condition + self.pitch_embed(pitch[:, :, None])
        non_retake_masks = [
            v_retake.float()  # [B, T, 1]
            for v_retake in (~retake).split(1, dim=2)
        ]
        # Only frames that are not retaken contribute their known variance values.
        variance_embeds = [
            self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks
            for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks)
        ]
        variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1)
        if hparams['use_spk_id'] and spk_embed is not None:
            variance_cond += spk_embed
        return variance_cond

    def forward_variance_reflow(self, variance_cond, steps: int = 10):
        """Run the variance predictor on ``variance_cond`` for ``steps`` steps."""
        xs_pred = self.variance_predictor(variance_cond, steps=steps)
        return xs_pred

    def forward_variance_postprocess(self, xs_pred):
        """Split the prediction per variance, clamp it and return a tuple."""
        if self.variance_predictor.num_feats == 1:
            xs_pred = [xs_pred]
        else:
            xs_pred = xs_pred.unbind(dim=1)
        variance_pred = self.variance_predictor.clamp_spec(xs_pred)
        return tuple(variance_pred)

    def view_as_linguistic_encoder(self):
        """Return a deep copy reduced to the linguistic encoder stage.

        ``forward`` becomes the word-level encoder when ``predict_dur`` is
        truthy and the phoneme-level encoder otherwise.
        """
        model = copy.deepcopy(self)
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.fs2 = model.fs2.view_as_encoder()
        if self.predict_dur:
            model.forward = model.forward_linguistic_encoder_word
        else:
            model.forward = model.forward_linguistic_encoder_phoneme
        return model

    def view_as_dur_predictor(self):
        """Return a deep copy reduced to the duration predictor stage."""
        assert self.predict_dur
        model = copy.deepcopy(self)
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.fs2 = model.fs2.view_as_dur_predictor()
        model.forward = model.forward_dur_predictor
        return model

    def view_as_pitch_preprocess(self):
        """Return a deep copy whose forward is ``forward_pitch_preprocess``.

        The melody encoder is kept because the preprocessing stage calls it.
        """
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_preprocess
        return model

    def view_as_pitch_predictor(self):
        """Return a deep copy whose forward is ``forward_pitch_reflow``."""
        assert self.predict_pitch
        model = copy.deepcopy(self)
        del model.fs2
        del model.lr
        if self.use_melody_encoder:
            del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_reflow
        return model

    def view_as_pitch_postprocess(self):
        """Return a deep copy whose forward is ``forward_pitch_postprocess``."""
        model = copy.deepcopy(self)
        del model.fs2
        if self.use_melody_encoder:
            del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_postprocess
        return model

    def view_as_variance_preprocess(self):
        """Return a deep copy whose forward is ``forward_variance_preprocess``."""
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_variance_preprocess
        return model

    def view_as_variance_predictor(self):
        """Return a deep copy whose forward is ``forward_variance_reflow``."""
        assert self.predict_variances
        model = copy.deepcopy(self)
        del model.fs2
        del model.lr
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        model.forward = model.forward_variance_reflow
        return model

    def view_as_variance_postprocess(self):
        """Return a deep copy whose forward is ``forward_variance_postprocess``.

        The variance predictor is kept because the postprocessing stage uses
        its ``num_feats`` and ``clamp_spec``.
        """
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        model.forward = model.forward_variance_postprocess
        return model