Spaces:
Sleeping
Sleeping
import matplotlib | |
import torch | |
import torch.distributions | |
import torch.optim | |
import torch.utils.data | |
import utils | |
import utils.infer_utils | |
from basics.base_dataset import BaseDataset | |
from basics.base_task import BaseTask | |
from modules.losses import DurationLoss, DiffusionLoss, RectifiedFlowLoss | |
from modules.metrics.curve import RawCurveAccuracy | |
from modules.metrics.duration import RhythmCorrectness, PhonemeDurationAccuracy | |
from modules.toplevel import DiffSingerVariance | |
from utils.hparams import hparams | |
from utils.plot import dur_to_figure, pitch_note_to_figure, curve_to_figure | |
# Select the non-interactive 'Agg' backend before any figures are created: the
# validation plots in this module are rendered off-screen and handed to the
# experiment logger via `add_figure`, so no GUI/display is needed.
matplotlib.use('Agg')
class VarianceDataset(BaseDataset):
    """Dataset for the variance model that pads and batches per-sample tensors.

    The fields placed into a batch depend on ``hparams``: ``use_spk_id``,
    ``predict_dur``, ``predict_pitch`` (with ``use_glide_embed``) and the four
    ``predict_energy`` / ``predict_breathiness`` / ``predict_voicing`` /
    ``predict_tension`` switches.
    """

    def __init__(self, prefix, preload=False):
        super().__init__(prefix, hparams['dataset_size_key'], preload)
        # All four switches are read up front (a missing key raises here).
        energy_on, breathiness_on, voicing_on, tension_on = (
            hparams[key] for key in (
                'predict_energy', 'predict_breathiness', 'predict_voicing', 'predict_tension'
            )
        )
        # Truthy when at least one variance curve has to be collated.
        self.predict_variances = energy_on or breathiness_on or voicing_on or tension_on

    def collater(self, samples):
        """Merge a list of sample dicts into one padded batch dict.

        Starts from the base-class batch. An empty batch (``size == 0``) is
        returned unchanged; otherwise each enabled field is padded across
        samples with ``utils.collate_nd`` and added under its own key.
        """
        batch = super().collater(samples)
        if batch['size'] == 0:
            return batch

        def padded(key, pad_value=0):
            # Collect `key` from every sample and pad/stack them with `pad_value`.
            return utils.collate_nd([item[key] for item in samples], pad_value)

        batch['tokens'] = padded('tokens')
        batch['ph_dur'] = padded('ph_dur')

        if hparams['use_spk_id']:
            batch['spk_ids'] = torch.LongTensor([item['spk_id'] for item in samples])

        if hparams['predict_dur']:
            batch['ph2word'] = padded('ph2word')
            batch['midi'] = padded('midi')

        if hparams['predict_pitch']:
            # note_midi is padded with -1 and note_rest with True, unlike the 0-padded fields.
            batch['note_midi'] = padded('note_midi', -1)
            batch['note_rest'] = padded('note_rest', True)
            batch['note_dur'] = padded('note_dur')
            if hparams['use_glide_embed']:
                batch['note_glide'] = padded('note_glide')
            batch['mel2note'] = padded('mel2note')
            batch['base_pitch'] = padded('base_pitch')

        if hparams['predict_pitch'] or self.predict_variances:
            # Frame-level alignment and pitch are needed by both the pitch and variance heads.
            for key, pad_value in (('mel2ph', 0), ('pitch', 0), ('uv', True)):
                batch[key] = padded(key, pad_value)

        for switch, key in (
                ('predict_energy', 'energy'),
                ('predict_breathiness', 'breathiness'),
                ('predict_voicing', 'voicing'),
                ('predict_tension', 'tension'),
        ):
            if hparams[switch]:
                batch[key] = padded(key)

        return batch
def random_retake_masks(b, t, device):
    """Sample random retake masks for a batch of ``b`` sequences of length ``t``.

    Two masks are combined with ``|``:
    - a per-sequence mask of shape ``[b, 1]`` that is True for about 1/4 of
      the sequences (retake the whole sequence);
    - a frame mask from ``utils.random_continuous_masks`` that is True for
      about 1/3 of the frames on average (presumably shape ``[b, t]`` --
      confirm against ``utils``).
    On average about half of all frames end up True (1/4 + 3/4 * 1/3 = 1/2).
    """
    # Draw from {0, 1, 2, 3}; hitting 0 (probability 1/4) retakes the whole row.
    whole_row = torch.randint(0, 4, (b, 1), dtype=torch.long, device=device).eq(0)
    span = utils.random_continuous_masks(b, t, dim=1, device=device)
    return whole_row | span
class VarianceTask(BaseTask):
    """Task that trains and validates the ``DiffSingerVariance`` model.

    The active prediction heads are selected by ``hparams``:
    ``predict_dur`` (phoneme durations), ``predict_pitch`` and the
    ``predict_energy`` / ``predict_breathiness`` / ``predict_voicing`` /
    ``predict_tension`` switches. The pitch and variance heads are trained with
    ``DiffusionLoss`` ('ddpm') or ``RectifiedFlowLoss`` ('reflow'), chosen by
    ``hparams['diffusion_type']``.
    """

    def __init__(self):
        super().__init__()
        self.dataset_cls = VarianceDataset
        self.diffusion_type = hparams['diffusion_type']
        self.use_spk_id = hparams['use_spk_id']

        self.predict_dur = hparams['predict_dur']
        if self.predict_dur:
            self.lambda_dur_loss = hparams['lambda_dur_loss']

        self.predict_pitch = hparams['predict_pitch']
        if self.predict_pitch:
            self.lambda_pitch_loss = hparams['lambda_pitch_loss']

        predict_energy = hparams['predict_energy']
        predict_breathiness = hparams['predict_breathiness']
        predict_voicing = hparams['predict_voicing']
        predict_tension = hparams['predict_tension']
        # Names of the variance curves trained by this task, in a fixed order. The
        # same names are used as batch keys and as keys of the model's variance
        # predictions (see `_validation_step`).
        self.variance_prediction_list = []
        if predict_energy:
            self.variance_prediction_list.append('energy')
        if predict_breathiness:
            self.variance_prediction_list.append('breathiness')
        if predict_voicing:
            self.variance_prediction_list.append('voicing')
        if predict_tension:
            self.variance_prediction_list.append('tension')
        self.predict_variances = len(self.variance_prediction_list) > 0
        self.lambda_var_loss = hparams['lambda_var_loss']
        # Called last, after the attributes above exist; presumably the base class
        # builds the model/losses here and reads them -- confirm in BaseTask.
        super()._finish_init()

    def _build_model(self):
        """Construct the variance model sized to the phoneme vocabulary."""
        return DiffSingerVariance(
            vocab_size=len(self.phone_encoder),
        )

    def _build_diffusion_loss(self):
        """Return a new main-loss module matching ``self.diffusion_type``.

        :raises ValueError: if ``diffusion_type`` is neither 'ddpm' nor 'reflow'.
        """
        if self.diffusion_type == 'ddpm':
            return DiffusionLoss(loss_type=hparams['main_loss_type'])
        if self.diffusion_type == 'reflow':
            return RectifiedFlowLoss(
                loss_type=hparams['main_loss_type'], log_norm=hparams['main_loss_log_norm']
            )
        raise ValueError(f'Unknown diffusion type: {self.diffusion_type}')

    # noinspection PyAttributeOutsideInit
    def build_losses_and_metrics(self):
        """Create loss modules and register validation losses/metrics for the active heads."""
        if self.predict_dur:
            dur_hparams = hparams['dur_prediction_args']
            self.dur_loss = DurationLoss(
                offset=dur_hparams['log_offset'],
                loss_type=dur_hparams['loss_type'],
                lambda_pdur=dur_hparams['lambda_pdur_loss'],
                lambda_wdur=dur_hparams['lambda_wdur_loss'],
                lambda_sdur=dur_hparams['lambda_sdur_loss']
            )
            self.register_validation_loss('dur_loss')
            self.register_validation_metric('rhythm_corr', RhythmCorrectness(tolerance=0.05))
            self.register_validation_metric('ph_dur_acc', PhonemeDurationAccuracy(tolerance=0.2))
        if self.predict_pitch:
            # A separate loss instance per head.
            self.pitch_loss = self._build_diffusion_loss()
            self.register_validation_loss('pitch_loss')
            self.register_validation_metric('pitch_acc', RawCurveAccuracy(tolerance=0.5))
        if self.predict_variances:
            self.var_loss = self._build_diffusion_loss()
            self.register_validation_loss('var_loss')

    def _compute_diffusion_loss(self, loss_fn, pred, non_padding):
        """Apply ``loss_fn`` to the raw training output ``pred`` of a pitch/variance head.

        For 'ddpm', ``pred`` is unpacked as ``(x_recon, noise)``. For 'reflow', it is
        unpacked as ``(v_pred, v_gt, t)`` and ``t`` is forwarded by keyword.

        :raises ValueError: if ``diffusion_type`` is neither 'ddpm' nor 'reflow'.
        """
        if self.diffusion_type == 'ddpm':
            x_recon, noise = pred
            return loss_fn(x_recon, noise, non_padding=non_padding)
        if self.diffusion_type == 'reflow':
            v_pred, v_gt, timestep = pred
            return loss_fn(v_pred, v_gt, t=timestep, non_padding=non_padding)
        raise ValueError(f"Unknown diffusion type: {self.diffusion_type}")

    def run_model(self, sample, infer=False):
        """Run the model on a collated batch.

        :param sample: batch dict as built by ``VarianceDataset.collater``.
        :param infer: if True, run inference and return predictions; otherwise
            return training losses.
        :return: if ``infer``, the tuple ``(dur_pred, pitch_pred, variances_pred)``.
            Any entry may be None; ``dur_pred`` is rounded to integers and
            ``variances_pred`` is a dict keyed by variance name. Otherwise, a dict
            of weighted losses with keys among 'dur_loss', 'pitch_loss' and 'var_loss'.
        """
        spk_ids = sample['spk_ids'] if self.use_spk_id else None  # [B,]
        txt_tokens = sample['tokens']  # [B, T_ph]
        ph_dur = sample['ph_dur']  # [B, T_ph]
        ph2word = sample.get('ph2word')  # [B, T_ph]
        midi = sample.get('midi')  # [B, T_ph]
        mel2ph = sample.get('mel2ph')  # [B, T_s]
        note_midi = sample.get('note_midi')  # [B, T_n]
        note_rest = sample.get('note_rest')  # [B, T_n]
        note_dur = sample.get('note_dur')  # [B, T_n]
        note_glide = sample.get('note_glide')  # [B, T_n]
        mel2note = sample.get('mel2note')  # [B, T_s]
        base_pitch = sample.get('base_pitch')  # [B, T_s]
        pitch = sample.get('pitch')  # [B, T_s]
        energy = sample.get('energy')  # [B, T_s]
        breathiness = sample.get('breathiness')  # [B, T_s]
        voicing = sample.get('voicing')  # [B, T_s]
        tension = sample.get('tension')  # [B, T_s]

        pitch_retake = variance_retake = None
        if (self.predict_pitch or self.predict_variances) and not infer:
            # Randomly select continuous retaking regions (training only).
            # One independent mask is drawn per variance curve.
            b = sample['size']
            t = mel2ph.shape[1]
            device = mel2ph.device
            if self.predict_pitch:
                pitch_retake = random_retake_masks(b, t, device)
            if self.predict_variances:
                variance_retake = {
                    v_name: random_retake_masks(b, t, device)
                    for v_name in self.variance_prediction_list
                }

        output = self.model(
            txt_tokens, midi=midi, ph2word=ph2word,
            ph_dur=ph_dur, mel2ph=mel2ph,
            note_midi=note_midi, note_rest=note_rest,
            note_dur=note_dur, note_glide=note_glide, mel2note=mel2note,
            base_pitch=base_pitch, pitch=pitch,
            energy=energy, breathiness=breathiness, voicing=voicing, tension=tension,
            pitch_retake=pitch_retake, variance_retake=variance_retake,
            spk_id=spk_ids, infer=infer
        )
        dur_pred, pitch_pred, variances_pred = output

        if infer:
            if dur_pred is not None:
                dur_pred = dur_pred.round().long()
            return dur_pred, pitch_pred, variances_pred  # Tensor, Tensor, Dict[str, Tensor]

        losses = {}
        if dur_pred is not None:
            losses['dur_loss'] = self.lambda_dur_loss * self.dur_loss(dur_pred, ph_dur, ph2word=ph2word)
        # Frames with mel2ph > 0 are treated as real (non-padded) frames.
        non_padding = (mel2ph > 0).unsqueeze(-1) if mel2ph is not None else None
        if pitch_pred is not None:
            pitch_loss = self._compute_diffusion_loss(self.pitch_loss, pitch_pred, non_padding)
            losses['pitch_loss'] = self.lambda_pitch_loss * pitch_loss
        if variances_pred is not None:
            var_loss = self._compute_diffusion_loss(self.var_loss, variances_pred, non_padding)
            losses['var_loss'] = self.lambda_var_loss * var_loss
        return losses

    def _validation_step(self, sample, batch_idx):
        """Compute validation losses; for plotted samples, also update metrics and log figures.

        A sample is plotted when its dataset index is below
        ``hparams['num_valid_plots']``. Inference is run only if at least one
        sample in the batch qualifies.

        :return: ``(losses, sample['size'])``.
        """
        losses = self.run_model(sample, infer=False)
        if min(sample['indices']) < hparams['num_valid_plots']:
            def sample_get(key, idx, abs_idx):
                # Trim the padded batch entry back to the sample's stored length; keep a batch dim.
                return sample[key][idx][:self.valid_dataset.metadata[key][abs_idx]].unsqueeze(0)

            dur_preds, pitch_preds, variances_preds = self.run_model(sample, infer=True)
            for i, data_idx in enumerate(sample['indices']):
                if data_idx >= hparams['num_valid_plots']:
                    continue
                if dur_preds is not None:
                    dur_len = self.valid_dataset.metadata['ph_dur'][data_idx]
                    tokens = sample_get('tokens', i, data_idx)
                    gt_dur = sample_get('ph_dur', i, data_idx)
                    pred_dur = dur_preds[i][:dur_len].unsqueeze(0)
                    ph2word = sample_get('ph2word', i, data_idx)
                    mask = tokens != 0
                    self.valid_metrics['rhythm_corr'].update(
                        pdur_pred=pred_dur, pdur_target=gt_dur, ph2word=ph2word, mask=mask
                    )
                    self.valid_metrics['ph_dur_acc'].update(
                        pdur_pred=pred_dur, pdur_target=gt_dur, ph2word=ph2word, mask=mask
                    )
                    self.plot_dur(data_idx, gt_dur, pred_dur, tokens)
                if pitch_preds is not None:
                    pitch_len = self.valid_dataset.metadata['pitch'][data_idx]
                    # The pitch head's output is added on top of base_pitch.
                    pred_pitch = sample_get('base_pitch', i, data_idx) + pitch_preds[i][:pitch_len].unsqueeze(0)
                    gt_pitch = sample_get('pitch', i, data_idx)
                    # Score only frames inside a phoneme that are not flagged by `uv`
                    # (presumably the unvoiced flag -- confirm).
                    mask = (sample_get('mel2ph', i, data_idx) > 0) & ~sample_get('uv', i, data_idx)
                    self.valid_metrics['pitch_acc'].update(pred=pred_pitch, target=gt_pitch, mask=mask)
                    self.plot_pitch(
                        data_idx,
                        gt_pitch=gt_pitch,
                        pred_pitch=pred_pitch,
                        note_midi=sample_get('note_midi', i, data_idx),
                        note_dur=sample_get('note_dur', i, data_idx),
                        note_rest=sample_get('note_rest', i, data_idx)
                    )
                for name in self.variance_prediction_list:
                    variance_len = self.valid_dataset.metadata[name][data_idx]
                    gt_variances = sample_get(name, i, data_idx)
                    pred_variances = variances_preds[name][i][:variance_len].unsqueeze(0)
                    self.plot_curve(
                        data_idx,
                        gt_curve=gt_variances,
                        pred_curve=pred_variances,
                        curve_name=name
                    )
        return losses, sample['size']

    ############
    # validation plots
    ############
    def _plot_title(self, data_idx):
        """Return the '<speaker> - <item name>' title used by all validation figures."""
        metadata = self.valid_dataset.metadata
        return f"{metadata['spk_names'][data_idx]} - {metadata['names'][data_idx]}"

    def plot_dur(self, data_idx, gt_dur, pred_dur, txt=None):
        """Log a ground-truth vs. predicted phoneme-duration figure.

        ``gt_dur``/``pred_dur``/``txt`` carry a leading batch dim of size 1 (only
        index 0 is used). ``txt`` holds phoneme token ids and is decoded into phoneme
        labels; when it is None, no labels are decoded.
        """
        gt_dur = gt_dur[0].cpu().numpy()
        pred_dur = pred_dur[0].cpu().numpy()
        if txt is not None:
            txt = self.phone_encoder.decode(txt[0].cpu().numpy()).split()
        # NOTE(review): with txt=None, None is forwarded to dur_to_figure; confirm it
        # accepts missing labels. The previous code crashed on txt[0] in that case.
        self.logger.all_rank_experiment.add_figure(f'dur_{data_idx}', dur_to_figure(
            gt_dur, pred_dur, txt, self._plot_title(data_idx)
        ), self.global_step)

    def plot_pitch(self, data_idx, gt_pitch, pred_pitch, note_midi, note_dur, note_rest):
        """Log a ground-truth vs. predicted pitch figure overlaid with the note sequence."""
        gt_pitch = gt_pitch[0].cpu().numpy()
        pred_pitch = pred_pitch[0].cpu().numpy()
        note_midi = note_midi[0].cpu().numpy()
        note_dur = note_dur[0].cpu().numpy()
        note_rest = note_rest[0].cpu().numpy()
        self.logger.all_rank_experiment.add_figure(f'pitch_{data_idx}', pitch_note_to_figure(
            gt_pitch, pred_pitch, note_midi, note_dur, note_rest, self._plot_title(data_idx)
        ), self.global_step)

    def plot_curve(self, data_idx, gt_curve, pred_curve, base_curve=None, grid=None, curve_name='curve'):
        """Log a ground-truth vs. predicted curve figure tagged '<curve_name>_<data_idx>'."""
        gt_curve = gt_curve[0].cpu().numpy()
        pred_curve = pred_curve[0].cpu().numpy()
        if base_curve is not None:
            base_curve = base_curve[0].cpu().numpy()
        self.logger.all_rank_experiment.add_figure(f'{curve_name}_{data_idx}', curve_to_figure(
            gt_curve, pred_curve, base_curve, grid=grid, title=self._plot_title(data_idx)
        ), self.global_step)