CantusSVS-hf / augmentation / spec_stretch.py
from copy import deepcopy

import librosa
import numpy as np
import torch

from basics.base_augmentation import BaseAugmentation, require_same_keys
from basics.base_pe import BasePE
from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
from modules.fastspeech.tts_modules import LengthRegulator
from utils.binarizer_utils import get_mel_torch, get_mel2ph_torch
from utils.hparams import hparams
from utils.infer_utils import resample_align_curve


class SpectrogramStretchAugmentation(BaseAugmentation):
"""
This class contains methods for frequency-domain and time-domain stretching augmentation.
"""
def __init__(self, data_dirs: list, augmentation_args: dict, pe: BasePE = None):
super().__init__(data_dirs, augmentation_args)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
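        # LengthRegulator expands per-phoneme durations into a frame-level mel2ph alignment.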
self.lr = LengthRegulator().to(self.device)
self.pe = pe

    @require_same_keys
def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None) -> dict:
aug_item = deepcopy(item)
waveform, _ = librosa.load(aug_item['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
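        # Compute the mel spectrogram with the key shift and speed change applied
        # during spectral analysis, so no intermediate waveform needs to be written.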
mel = get_mel_torch(
waveform, hparams['audio_sample_rate'], num_mel_bins=hparams['audio_num_mel_bins'],
hop_size=hparams['hop_size'], win_size=hparams['win_size'], fft_size=hparams['fft_size'],
fmin=hparams['fmin'], fmax=hparams['fmax'],
keyshift=key_shift, speed=speed, device=self.device
)
aug_item['mel'] = mel
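
        # Time-stretch branch (also taken at speed == 1.0 when speed embedding is
        # enabled, so every item carries a 'speed' field): recompute all
        # frame-aligned features for the new frame count.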
if speed != 1. or hparams['use_speed_embed']:
aug_item['length'] = mel.shape[0]
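            # The effective speed is quantized: the stretched hop must be a whole
            # number of samples, so round it and divide back by the original hop size.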
aug_item['speed'] = int(np.round(hparams['hop_size'] * speed)) / hparams['hop_size'] # real speed
aug_item['seconds'] /= aug_item['speed']
aug_item['ph_dur'] /= aug_item['speed']
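            # Re-expand the scaled phoneme durations into a frame-to-phoneme
            # alignment matching the new mel length.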
aug_item['mel2ph'] = get_mel2ph_torch(
self.lr, torch.from_numpy(aug_item['ph_dur']), aug_item['length'], self.timestep, device=self.device
).cpu().numpy()
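            # Re-extract F0 at the stretched frame rate, interpolating unvoiced frames.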
f0, _ = self.pe.get_pitch(
waveform, samplerate=hparams['audio_sample_rate'], length=aug_item['length'],
hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
speed=speed, interp_uv=True
)
aug_item['f0'] = f0.astype(np.float32)
            # NOTE: variance curves are directly resampled according to speed,
            # regardless of how the frequency-domain features actually change
            # after the augmentation. For acoustic models, this can make it
            # somewhat (but not much) harder to learn how variance curves
            # affect the mel spectrograms, since the model must account for
            # the mismatch introduced by the augmentation.
            #
            # This is a simple way to combine augmentation with variances.
            # However, handling variance curves like this decreases the
            # accuracy of variance controls. In most situations, not being
            # ~100% accurate will not ruin the user experience. For example,
            # it does not matter if the energy does not exactly equal the RMS;
            # it is fine as long as higher energy brings higher loudness and
            # strength. The neural network itself cannot be 100% accurate
            # anyway.
            #
            # There are other ways to obtain the variance curves:
            #   1. Re-extract the features from resampled waveforms;
            #   2. Re-extract the features from waveforms reconstructed from
            #      the transformed mel spectrograms through the vocoder.
            # But there is actually no perfect way to make them both accurate
            # and stable.
for v_name in VARIANCE_CHECKLIST:
if v_name in item:
aug_item[v_name] = resample_align_curve(
aug_item[v_name],
original_timestep=self.timestep,
target_timestep=self.timestep * aug_item['speed'],
align_length=aug_item['length']
)
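
        # Key-shift branch: the mel above was already pitch-shifted, so only the
        # conditioning inputs (key_shift, or a replacement speaker id) and the
        # F0 curve need updating.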
if key_shift != 0. or hparams['use_key_shift_embed']:
if replace_spk_id is None:
aug_item['key_shift'] = key_shift
else:
aug_item['spk_id'] = replace_spk_id
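            # Scale F0 by 2^(key_shift/12), i.e. key_shift semitones in equal temperament.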
aug_item['f0'] *= 2 ** (key_shift / 12)
return aug_item
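

# Usage sketch (hypothetical; `data_dirs`, `augmentation_args`, `pe`, and the
# item schema come from the binarizer configuration and hparams, not from this file):
#
#     pe = ...  # some BasePE implementation, e.g. the one constructed by the binarizer
#     aug = SpectrogramStretchAugmentation(data_dirs, augmentation_args, pe=pe)
#     faster = aug.process_item(item, speed=1.2)      # time-domain stretch
#     higher = aug.process_item(item, key_shift=2.0)  # pitch up two semitones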