# training/acoustic_task.py
import matplotlib
import torch
import torch.distributions
import torch.optim
import torch.utils.data

import utils
import utils.infer_utils
from basics.base_dataset import BaseDataset
from basics.base_task import BaseTask
from basics.base_vocoder import BaseVocoder
from modules.aux_decoder import build_aux_loss
from modules.losses import DiffusionLoss, RectifiedFlowLoss
from modules.toplevel import DiffSingerAcoustic, ShallowDiffusionOutput
from modules.vocoders.registry import get_vocoder_cls
from utils.hparams import hparams
from utils.plot import spec_to_figure

matplotlib.use('Agg')


class AcousticDataset(BaseDataset):
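    """Dataset of pre-processed acoustic features.

    Each sample provides phoneme tokens, a ground-truth mel spectrogram, f0
    and a mel2ph alignment, plus whichever variance curves (energy,
    breathiness, voicing, tension) and embedding inputs (key shift, speed,
    speaker id) are enabled in hparams.
    """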
def __init__(self, prefix, preload=False):
        super().__init__(prefix, hparams['dataset_size_key'], preload)
self.required_variances = {} # key: variance name, value: padding value
if hparams['use_energy_embed']:
self.required_variances['energy'] = 0.0
if hparams['use_breathiness_embed']:
self.required_variances['breathiness'] = 0.0
if hparams['use_voicing_embed']:
self.required_variances['voicing'] = 0.0
if hparams['use_tension_embed']:
self.required_variances['tension'] = 0.0
self.need_key_shift = hparams['use_key_shift_embed']
self.need_speed = hparams['use_speed_embed']
        self.need_spk_id = hparams['use_spk_id']

    def collater(self, samples):
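        """Pad per-sample tensors to a common length and stack them into a batch.

        Sequence fields (tokens, f0, mel2ph, mel, variance curves) are padded
        with their respective padding values; scalar fields become [B, 1]
        (key_shift, speed) or [B] (spk_ids) tensors.
        """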
batch = super().collater(samples)
if batch['size'] == 0:
return batch
tokens = utils.collate_nd([s['tokens'] for s in samples], 0)
f0 = utils.collate_nd([s['f0'] for s in samples], 0.0)
mel2ph = utils.collate_nd([s['mel2ph'] for s in samples], 0)
mel = utils.collate_nd([s['mel'] for s in samples], 0.0)
batch.update({
'tokens': tokens,
'mel2ph': mel2ph,
'mel': mel,
'f0': f0,
})
for v_name, v_pad in self.required_variances.items():
batch[v_name] = utils.collate_nd([s[v_name] for s in samples], v_pad)
if self.need_key_shift:
batch['key_shift'] = torch.FloatTensor([s['key_shift'] for s in samples])[:, None]
if self.need_speed:
batch['speed'] = torch.FloatTensor([s['speed'] for s in samples])[:, None]
if self.need_spk_id:
spk_ids = torch.LongTensor([s['spk_id'] for s in samples])
batch['spk_ids'] = spk_ids
        return batch


class AcousticTask(BaseTask):
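    """Training task for the DiffSinger acoustic model.

    Handles loss setup for DDPM or rectified-flow diffusion, optional shallow
    diffusion with an auxiliary decoder, and validation-time audio and mel
    logging through a vocoder.
    """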
def __init__(self):
super().__init__()
self.dataset_cls = AcousticDataset
self.diffusion_type = hparams['diffusion_type']
assert self.diffusion_type in ['ddpm', 'reflow'], f"Unknown diffusion type: {self.diffusion_type}"
self.use_shallow_diffusion = hparams['use_shallow_diffusion']
if self.use_shallow_diffusion:
self.shallow_args = hparams['shallow_diffusion_args']
self.train_aux_decoder = self.shallow_args['train_aux_decoder']
self.train_diffusion = self.shallow_args['train_diffusion']
self.use_vocoder = hparams['infer'] or hparams['val_with_vocoder']
if self.use_vocoder:
self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
self.logged_gt_wav = set()
self.required_variances = []
if hparams['use_energy_embed']:
self.required_variances.append('energy')
if hparams['use_breathiness_embed']:
self.required_variances.append('breathiness')
if hparams['use_voicing_embed']:
self.required_variances.append('voicing')
if hparams['use_tension_embed']:
self.required_variances.append('tension')
        super()._finish_init()

    def _build_model(self):
return DiffSingerAcoustic(
vocab_size=len(self.phone_encoder),
out_dims=hparams['audio_num_mel_bins']
        )

    # noinspection PyAttributeOutsideInit
    def build_losses_and_metrics(self):
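        """Build the main mel loss (DDPM or rectified flow) and, when shallow
        diffusion is enabled, the auxiliary decoder's mel loss."""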
if self.use_shallow_diffusion:
self.aux_mel_loss = build_aux_loss(self.shallow_args['aux_decoder_arch'])
self.lambda_aux_mel_loss = hparams['lambda_aux_mel_loss']
self.register_validation_loss('aux_mel_loss')
if self.diffusion_type == 'ddpm':
self.mel_loss = DiffusionLoss(loss_type=hparams['main_loss_type'])
elif self.diffusion_type == 'reflow':
self.mel_loss = RectifiedFlowLoss(
loss_type=hparams['main_loss_type'], log_norm=hparams['main_loss_log_norm']
)
else:
raise ValueError(f"Unknown diffusion type: {self.diffusion_type}")
        self.register_validation_loss('mel_loss')

    def run_model(self, sample, infer=False):
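        """Run the acoustic model on one batch.

        With infer=True, returns the ShallowDiffusionOutput carrying the
        auxiliary and diffusion mel predictions; otherwise returns a dict of
        weighted training losses.
        """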
txt_tokens = sample['tokens'] # [B, T_ph]
target = sample['mel'] # [B, T_s, M]
mel2ph = sample['mel2ph'] # [B, T_s]
f0 = sample['f0']
variances = {
v_name: sample[v_name]
for v_name in self.required_variances
}
key_shift = sample.get('key_shift')
speed = sample.get('speed')
if hparams['use_spk_id']:
spk_embed_id = sample['spk_ids']
else:
spk_embed_id = None
output: ShallowDiffusionOutput = self.model(
txt_tokens, mel2ph=mel2ph, f0=f0, **variances,
key_shift=key_shift, speed=speed, spk_embed_id=spk_embed_id,
gt_mel=target, infer=infer
)
if infer:
return output
else:
losses = {}
if output.aux_out is not None:
aux_out = output.aux_out
norm_gt = self.model.aux_decoder.norm_spec(target)
aux_mel_loss = self.lambda_aux_mel_loss * self.aux_mel_loss(aux_out, norm_gt)
losses['aux_mel_loss'] = aux_mel_loss
            non_padding = (mel2ph > 0).unsqueeze(-1).float()  # [B, T_s, 1] mask of non-padded frames
if output.diff_out is not None:
if self.diffusion_type == 'ddpm':
x_recon, x_noise = output.diff_out
mel_loss = self.mel_loss(x_recon, x_noise, non_padding=non_padding)
elif self.diffusion_type == 'reflow':
v_pred, v_gt, t = output.diff_out
mel_loss = self.mel_loss(v_pred, v_gt, t=t, non_padding=non_padding)
else:
raise ValueError(f"Unknown diffusion type: {self.diffusion_type}")
losses['mel_loss'] = mel_loss
            return losses

    def on_train_start(self):
if self.use_vocoder and self.vocoder.get_device() != self.device:
            self.vocoder.to_device(self.device)

    def _on_validation_start(self):
if self.use_vocoder and self.vocoder.get_device() != self.device:
            self.vocoder.to_device(self.device)

    def _validation_step(self, sample, batch_idx):
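        """Compute validation losses and, for the first num_valid_plots data
        indices, log vocoded audio and mel-spectrogram figures."""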
losses = self.run_model(sample, infer=False)
if sample['size'] > 0 and min(sample['indices']) < hparams['num_valid_plots']:
mel_out: ShallowDiffusionOutput = self.run_model(sample, infer=True)
for i in range(len(sample['indices'])):
data_idx = sample['indices'][i].item()
if data_idx < hparams['num_valid_plots']:
if self.use_vocoder:
self.plot_wav(
data_idx, sample['mel'][i],
                            mel_out.aux_out[i] if mel_out.aux_out is not None else None,
                            mel_out.diff_out[i] if mel_out.diff_out is not None else None,
sample['f0'][i]
)
if mel_out.aux_out is not None:
self.plot_mel(data_idx, sample['mel'][i], mel_out.aux_out[i], 'auxmel')
if mel_out.diff_out is not None:
self.plot_mel(data_idx, sample['mel'][i], mel_out.diff_out[i], 'diffmel')
        return losses, sample['size']

    ############
    # validation plots
    ############

    def plot_wav(self, data_idx, gt_mel, aux_mel, diff_mel, f0):
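        """Vocode the ground-truth, auxiliary and diffusion mels and log them
        as audio. Ground-truth audio is logged only once per data index."""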
f0_len = self.valid_dataset.metadata['f0'][data_idx]
mel_len = self.valid_dataset.metadata['mel'][data_idx]
gt_mel = gt_mel[:mel_len].unsqueeze(0)
if aux_mel is not None:
aux_mel = aux_mel[:mel_len].unsqueeze(0)
if diff_mel is not None:
diff_mel = diff_mel[:mel_len].unsqueeze(0)
f0 = f0[:f0_len].unsqueeze(0)
if data_idx not in self.logged_gt_wav:
gt_wav = self.vocoder.spec2wav_torch(gt_mel, f0=f0)
self.logger.all_rank_experiment.add_audio(
f'gt_{data_idx}', gt_wav,
sample_rate=hparams['audio_sample_rate'],
global_step=self.global_step
)
self.logged_gt_wav.add(data_idx)
if aux_mel is not None:
aux_wav = self.vocoder.spec2wav_torch(aux_mel, f0=f0)
self.logger.all_rank_experiment.add_audio(
f'aux_{data_idx}', aux_wav,
sample_rate=hparams['audio_sample_rate'],
global_step=self.global_step
)
if diff_mel is not None:
diff_wav = self.vocoder.spec2wav_torch(diff_mel, f0=f0)
self.logger.all_rank_experiment.add_audio(
f'diff_{data_idx}', diff_wav,
sample_rate=hparams['audio_sample_rate'],
global_step=self.global_step
            )

    def plot_mel(self, data_idx, gt_spec, out_spec, name_prefix='mel'):
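        """Log a figure that concatenates |out - gt|, the ground-truth mel and
        the predicted mel along the frequency axis, trimmed to the true mel
        length."""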
vmin = hparams['mel_vmin']
vmax = hparams['mel_vmax']
mel_len = self.valid_dataset.metadata['mel'][data_idx]
spec_cat = torch.cat([(out_spec - gt_spec).abs() + vmin, gt_spec, out_spec], -1)
title_text = f"{self.valid_dataset.metadata['spk_names'][data_idx]} - {self.valid_dataset.metadata['names'][data_idx]}"
self.logger.all_rank_experiment.add_figure(f'{name_prefix}_{data_idx}', spec_to_figure(
spec_cat[:mel_len], vmin, vmax, title_text
), global_step=self.global_step)