# CantusSVS-hf: training/variance_task.py
import matplotlib
import torch
import torch.distributions
import torch.optim
import torch.utils.data
import utils
import utils.infer_utils
from basics.base_dataset import BaseDataset
from basics.base_task import BaseTask
from modules.losses import DurationLoss, DiffusionLoss, RectifiedFlowLoss
from modules.metrics.curve import RawCurveAccuracy
from modules.metrics.duration import RhythmCorrectness, PhonemeDurationAccuracy
from modules.toplevel import DiffSingerVariance
from utils.hparams import hparams
from utils.plot import dur_to_figure, pitch_note_to_figure, curve_to_figure
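# Use a non-interactive backend so validation figures can be rendered without a display.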
matplotlib.use('Agg')
class VarianceDataset(BaseDataset):
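    """
    Dataset for the variance task. Collates phoneme tokens and durations into
    padded batch tensors, plus the optional duration/pitch/variance targets
    (energy, breathiness, voicing, tension) enabled in hparams.
    """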
def __init__(self, prefix, preload=False):
super(VarianceDataset, self).__init__(prefix, hparams['dataset_size_key'], preload)
need_energy = hparams['predict_energy']
need_breathiness = hparams['predict_breathiness']
need_voicing = hparams['predict_voicing']
need_tension = hparams['predict_tension']
self.predict_variances = need_energy or need_breathiness or need_voicing or need_tension
def collater(self, samples):
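        """
        Pad and stack per-sample fields into batch tensors; which optional
        fields are collated depends on the prediction flags in hparams
        (dur, pitch, energy, breathiness, voicing, tension).
        """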
batch = super().collater(samples)
if batch['size'] == 0:
return batch
tokens = utils.collate_nd([s['tokens'] for s in samples], 0)
ph_dur = utils.collate_nd([s['ph_dur'] for s in samples], 0)
batch.update({
'tokens': tokens,
'ph_dur': ph_dur
})
if hparams['use_spk_id']:
batch['spk_ids'] = torch.LongTensor([s['spk_id'] for s in samples])
if hparams['predict_dur']:
batch['ph2word'] = utils.collate_nd([s['ph2word'] for s in samples], 0)
batch['midi'] = utils.collate_nd([s['midi'] for s in samples], 0)
if hparams['predict_pitch']:
batch['note_midi'] = utils.collate_nd([s['note_midi'] for s in samples], -1)
batch['note_rest'] = utils.collate_nd([s['note_rest'] for s in samples], True)
batch['note_dur'] = utils.collate_nd([s['note_dur'] for s in samples], 0)
if hparams['use_glide_embed']:
batch['note_glide'] = utils.collate_nd([s['note_glide'] for s in samples], 0)
batch['mel2note'] = utils.collate_nd([s['mel2note'] for s in samples], 0)
batch['base_pitch'] = utils.collate_nd([s['base_pitch'] for s in samples], 0)
if hparams['predict_pitch'] or self.predict_variances:
batch['mel2ph'] = utils.collate_nd([s['mel2ph'] for s in samples], 0)
batch['pitch'] = utils.collate_nd([s['pitch'] for s in samples], 0)
batch['uv'] = utils.collate_nd([s['uv'] for s in samples], True)
if hparams['predict_energy']:
batch['energy'] = utils.collate_nd([s['energy'] for s in samples], 0)
if hparams['predict_breathiness']:
batch['breathiness'] = utils.collate_nd([s['breathiness'] for s in samples], 0)
if hparams['predict_voicing']:
batch['voicing'] = utils.collate_nd([s['voicing'] for s in samples], 0)
if hparams['predict_tension']:
batch['tension'] = utils.collate_nd([s['tension'] for s in samples], 0)
return batch
def random_retake_masks(b, t, device):
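    """
    Build a random (b, t) boolean mask by combining a segment-level mask
    (whole segments marked True) with a frame-level continuous-span mask.
    Used as the pitch/variance retake masks during training.
    """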
    # on average, 1/4 of segments are entirely True
B_masks = torch.randint(low=0, high=4, size=(b, 1), dtype=torch.long, device=device) == 0
    # on average, 1/3 of frames are True
T_masks = utils.random_continuous_masks(b, t, dim=1, device=device)
    # combined: on average, 1/2 of frames are True (1/4 + 3/4 * 1/3 = 1/2)
return B_masks | T_masks
class VarianceTask(BaseTask):
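    """
    Training task for DiffSingerVariance: predicts phoneme durations, pitch and
    optional variance curves (energy, breathiness, voicing, tension), with each
    head enabled through hparams and weighted by its own lambda coefficient.
    """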
def __init__(self):
super().__init__()
self.dataset_cls = VarianceDataset
self.diffusion_type = hparams['diffusion_type']
self.use_spk_id = hparams['use_spk_id']
self.predict_dur = hparams['predict_dur']
if self.predict_dur:
self.lambda_dur_loss = hparams['lambda_dur_loss']
self.predict_pitch = hparams['predict_pitch']
if self.predict_pitch:
self.lambda_pitch_loss = hparams['lambda_pitch_loss']
predict_energy = hparams['predict_energy']
predict_breathiness = hparams['predict_breathiness']
predict_voicing = hparams['predict_voicing']
predict_tension = hparams['predict_tension']
self.variance_prediction_list = []
if predict_energy:
self.variance_prediction_list.append('energy')
if predict_breathiness:
self.variance_prediction_list.append('breathiness')
if predict_voicing:
self.variance_prediction_list.append('voicing')
if predict_tension:
self.variance_prediction_list.append('tension')
self.predict_variances = len(self.variance_prediction_list) > 0
self.lambda_var_loss = hparams['lambda_var_loss']
super()._finish_init()
def _build_model(self):
return DiffSingerVariance(
vocab_size=len(self.phone_encoder),
)
# noinspection PyAttributeOutsideInit
def build_losses_and_metrics(self):
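        """
        Instantiate the duration, pitch and variance losses matching the enabled
        prediction heads and the configured diffusion type (ddpm or reflow), and
        register the corresponding validation losses and metrics.
        """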
if self.predict_dur:
dur_hparams = hparams['dur_prediction_args']
self.dur_loss = DurationLoss(
offset=dur_hparams['log_offset'],
loss_type=dur_hparams['loss_type'],
lambda_pdur=dur_hparams['lambda_pdur_loss'],
lambda_wdur=dur_hparams['lambda_wdur_loss'],
lambda_sdur=dur_hparams['lambda_sdur_loss']
)
self.register_validation_loss('dur_loss')
self.register_validation_metric('rhythm_corr', RhythmCorrectness(tolerance=0.05))
self.register_validation_metric('ph_dur_acc', PhonemeDurationAccuracy(tolerance=0.2))
if self.predict_pitch:
if self.diffusion_type == 'ddpm':
self.pitch_loss = DiffusionLoss(loss_type=hparams['main_loss_type'])
elif self.diffusion_type == 'reflow':
self.pitch_loss = RectifiedFlowLoss(
loss_type=hparams['main_loss_type'], log_norm=hparams['main_loss_log_norm']
)
else:
raise ValueError(f'Unknown diffusion type: {self.diffusion_type}')
self.register_validation_loss('pitch_loss')
self.register_validation_metric('pitch_acc', RawCurveAccuracy(tolerance=0.5))
if self.predict_variances:
if self.diffusion_type == 'ddpm':
self.var_loss = DiffusionLoss(loss_type=hparams['main_loss_type'])
elif self.diffusion_type == 'reflow':
self.var_loss = RectifiedFlowLoss(
loss_type=hparams['main_loss_type'], log_norm=hparams['main_loss_log_norm']
)
else:
raise ValueError(f'Unknown diffusion type: {self.diffusion_type}')
self.register_validation_loss('var_loss')
def run_model(self, sample, infer=False):
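        """
        Run the variance model on a collated batch.

        With infer=False, returns a dict of weighted training losses; with
        infer=True, returns (dur_pred, pitch_pred, variances_pred), where
        dur_pred is rounded to integer durations and variances_pred is a dict
        keyed by variance name.
        """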
spk_ids = sample['spk_ids'] if self.use_spk_id else None # [B,]
txt_tokens = sample['tokens'] # [B, T_ph]
ph_dur = sample['ph_dur'] # [B, T_ph]
ph2word = sample.get('ph2word') # [B, T_ph]
midi = sample.get('midi') # [B, T_ph]
mel2ph = sample.get('mel2ph') # [B, T_s]
note_midi = sample.get('note_midi') # [B, T_n]
note_rest = sample.get('note_rest') # [B, T_n]
note_dur = sample.get('note_dur') # [B, T_n]
note_glide = sample.get('note_glide') # [B, T_n]
mel2note = sample.get('mel2note') # [B, T_s]
base_pitch = sample.get('base_pitch') # [B, T_s]
pitch = sample.get('pitch') # [B, T_s]
energy = sample.get('energy') # [B, T_s]
breathiness = sample.get('breathiness') # [B, T_s]
voicing = sample.get('voicing') # [B, T_s]
tension = sample.get('tension') # [B, T_s]
pitch_retake = variance_retake = None
if (self.predict_pitch or self.predict_variances) and not infer:
            # randomly select continuous retake regions
b = sample['size']
t = mel2ph.shape[1]
device = mel2ph.device
if self.predict_pitch:
pitch_retake = random_retake_masks(b, t, device)
if self.predict_variances:
variance_retake = {
v_name: random_retake_masks(b, t, device)
for v_name in self.variance_prediction_list
}
output = self.model(
txt_tokens, midi=midi, ph2word=ph2word,
ph_dur=ph_dur, mel2ph=mel2ph,
note_midi=note_midi, note_rest=note_rest,
note_dur=note_dur, note_glide=note_glide, mel2note=mel2note,
base_pitch=base_pitch, pitch=pitch,
energy=energy, breathiness=breathiness, voicing=voicing, tension=tension,
pitch_retake=pitch_retake, variance_retake=variance_retake,
spk_id=spk_ids, infer=infer
)
dur_pred, pitch_pred, variances_pred = output
if infer:
if dur_pred is not None:
dur_pred = dur_pred.round().long()
return dur_pred, pitch_pred, variances_pred # Tensor, Tensor, Dict[str, Tensor]
else:
losses = {}
if dur_pred is not None:
losses['dur_loss'] = self.lambda_dur_loss * self.dur_loss(dur_pred, ph_dur, ph2word=ph2word)
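            # Frames where mel2ph == 0 are padding; mask them out of the pitch/variance losses below.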
non_padding = (mel2ph > 0).unsqueeze(-1) if mel2ph is not None else None
if pitch_pred is not None:
if self.diffusion_type == 'ddpm':
pitch_x_recon, pitch_noise = pitch_pred
pitch_loss = self.pitch_loss(
pitch_x_recon, pitch_noise, non_padding=non_padding
)
elif self.diffusion_type == 'reflow':
pitch_v_pred, pitch_v_gt, t = pitch_pred
pitch_loss = self.pitch_loss(
pitch_v_pred, pitch_v_gt, t=t, non_padding=non_padding
)
else:
raise ValueError(f"Unknown diffusion type: {self.diffusion_type}")
losses['pitch_loss'] = self.lambda_pitch_loss * pitch_loss
if variances_pred is not None:
if self.diffusion_type == 'ddpm':
var_x_recon, var_noise = variances_pred
var_loss = self.var_loss(
var_x_recon, var_noise, non_padding=non_padding
)
elif self.diffusion_type == 'reflow':
var_v_pred, var_v_gt, t = variances_pred
var_loss = self.var_loss(
var_v_pred, var_v_gt, t=t, non_padding=non_padding
)
else:
raise ValueError(f"Unknown diffusion type: {self.diffusion_type}")
losses['var_loss'] = self.lambda_var_loss * var_loss
return losses
def _validation_step(self, sample, batch_idx):
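        """
        Compute validation losses and, for samples whose dataset index is below
        hparams['num_valid_plots'], additionally run inference to update the
        duration/pitch metrics and log comparison figures.
        """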
losses = self.run_model(sample, infer=False)
if min(sample['indices']) < hparams['num_valid_plots']:
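            # Helper: take item `idx` from a padded batch field and trim it back to its
            # unpadded length stored in the validation metadata (indexed by `abs_idx`).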
def sample_get(key, idx, abs_idx):
return sample[key][idx][:self.valid_dataset.metadata[key][abs_idx]].unsqueeze(0)
dur_preds, pitch_preds, variances_preds = self.run_model(sample, infer=True)
for i in range(len(sample['indices'])):
data_idx = sample['indices'][i]
if data_idx < hparams['num_valid_plots']:
if dur_preds is not None:
dur_len = self.valid_dataset.metadata['ph_dur'][data_idx]
tokens = sample_get('tokens', i, data_idx)
gt_dur = sample_get('ph_dur', i, data_idx)
pred_dur = dur_preds[i][:dur_len].unsqueeze(0)
ph2word = sample_get('ph2word', i, data_idx)
mask = tokens != 0
self.valid_metrics['rhythm_corr'].update(
pdur_pred=pred_dur, pdur_target=gt_dur, ph2word=ph2word, mask=mask
)
self.valid_metrics['ph_dur_acc'].update(
pdur_pred=pred_dur, pdur_target=gt_dur, ph2word=ph2word, mask=mask
)
self.plot_dur(data_idx, gt_dur, pred_dur, tokens)
if pitch_preds is not None:
pitch_len = self.valid_dataset.metadata['pitch'][data_idx]
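                        # The model predicts pitch as a delta relative to base_pitch, so add
                        # base_pitch back to obtain the absolute pitch curve for metrics and plots.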
pred_pitch = sample_get('base_pitch', i, data_idx) + pitch_preds[i][:pitch_len].unsqueeze(0)
gt_pitch = sample_get('pitch', i, data_idx)
mask = (sample_get('mel2ph', i, data_idx) > 0) & ~sample_get('uv', i, data_idx)
self.valid_metrics['pitch_acc'].update(pred=pred_pitch, target=gt_pitch, mask=mask)
self.plot_pitch(
data_idx,
gt_pitch=gt_pitch,
pred_pitch=pred_pitch,
note_midi=sample_get('note_midi', i, data_idx),
note_dur=sample_get('note_dur', i, data_idx),
note_rest=sample_get('note_rest', i, data_idx)
)
for name in self.variance_prediction_list:
variance_len = self.valid_dataset.metadata[name][data_idx]
gt_variances = sample[name][i][:variance_len].unsqueeze(0)
pred_variances = variances_preds[name][i][:variance_len].unsqueeze(0)
self.plot_curve(
data_idx,
gt_curve=gt_variances,
pred_curve=pred_variances,
curve_name=name
)
return losses, sample['size']
############
# validation plots
############
def plot_dur(self, data_idx, gt_dur, pred_dur, txt=None):
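        """Log a ground-truth vs. predicted phoneme duration figure for one validation item."""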
gt_dur = gt_dur[0].cpu().numpy()
pred_dur = pred_dur[0].cpu().numpy()
txt = self.phone_encoder.decode(txt[0].cpu().numpy()).split()
title_text = f"{self.valid_dataset.metadata['spk_names'][data_idx]} - {self.valid_dataset.metadata['names'][data_idx]}"
self.logger.all_rank_experiment.add_figure(f'dur_{data_idx}', dur_to_figure(
gt_dur, pred_dur, txt, title_text
), self.global_step)
def plot_pitch(self, data_idx, gt_pitch, pred_pitch, note_midi, note_dur, note_rest):
gt_pitch = gt_pitch[0].cpu().numpy()
pred_pitch = pred_pitch[0].cpu().numpy()
note_midi = note_midi[0].cpu().numpy()
note_dur = note_dur[0].cpu().numpy()
note_rest = note_rest[0].cpu().numpy()
title_text = f"{self.valid_dataset.metadata['spk_names'][data_idx]} - {self.valid_dataset.metadata['names'][data_idx]}"
self.logger.all_rank_experiment.add_figure(f'pitch_{data_idx}', pitch_note_to_figure(
gt_pitch, pred_pitch, note_midi, note_dur, note_rest, title_text
), self.global_step)
def plot_curve(self, data_idx, gt_curve, pred_curve, base_curve=None, grid=None, curve_name='curve'):
gt_curve = gt_curve[0].cpu().numpy()
pred_curve = pred_curve[0].cpu().numpy()
if base_curve is not None:
base_curve = base_curve[0].cpu().numpy()
title_text = f"{self.valid_dataset.metadata['spk_names'][data_idx]} - {self.valid_dataset.metadata['names'][data_idx]}"
self.logger.all_rank_experiment.add_figure(f'{curve_name}_{data_idx}', curve_to_figure(
gt_curve, pred_curve, base_curve, grid=grid, title=title_text
), self.global_step)