from copy import deepcopy

import librosa
import numpy as np
import torch

from basics.base_augmentation import BaseAugmentation, require_same_keys
from basics.base_pe import BasePE
from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
from modules.fastspeech.tts_modules import LengthRegulator
from utils.binarizer_utils import get_mel_torch, get_mel2ph_torch
from utils.hparams import hparams
from utils.infer_utils import resample_align_curve

class SpectrogramStretchAugmentation(BaseAugmentation):
    """
    This class contains methods for frequency-domain and time-domain stretching augmentation.
    """
    def __init__(self, data_dirs: list, augmentation_args: dict, pe: BasePE = None):
        super().__init__(data_dirs, augmentation_args)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Length regulator expands phoneme durations into frame-level alignments (mel2ph).
        self.lr = LengthRegulator().to(self.device)
        # Pitch extractor used to re-extract f0 after time stretching.
        self.pe = pe

    @require_same_keys
    def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None) -> dict:
        aug_item = deepcopy(item)
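        # Load the source waveform at the configured sample rate (mono).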
        waveform, _ = librosa.load(aug_item['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
        # Key shift and speed change are applied inside the mel extraction itself.
        mel = get_mel_torch(
            waveform, hparams['audio_sample_rate'], num_mel_bins=hparams['audio_num_mel_bins'],
            hop_size=hparams['hop_size'], win_size=hparams['win_size'], fft_size=hparams['fft_size'],
            fmin=hparams['fmin'], fmax=hparams['fmax'],
            keyshift=key_shift, speed=speed, device=self.device
        )

        aug_item['mel'] = mel
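
        # Time-domain stretching: recompute the frame count, durations,
        # alignment and f0 to match the new speed.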
        if speed != 1. or hparams['use_speed_embed']:
            aug_item['length'] = mel.shape[0]
            # The effective (real) speed is quantized because the stretched hop
            # size must be an integer number of samples.
            aug_item['speed'] = int(np.round(hparams['hop_size'] * speed)) / hparams['hop_size']
            aug_item['seconds'] /= aug_item['speed']
            aug_item['ph_dur'] /= aug_item['speed']
            # Rebuild the frame-to-phoneme alignment from the scaled durations.
            aug_item['mel2ph'] = get_mel2ph_torch(
                self.lr, torch.from_numpy(aug_item['ph_dur']), aug_item['length'], self.timestep, device=self.device
            ).cpu().numpy()
            # Re-extract f0 from the original waveform, compensating for the speed change.
            f0, _ = self.pe.get_pitch(
                waveform, samplerate=hparams['audio_sample_rate'], length=aug_item['length'],
                hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
                speed=speed, interp_uv=True
            )
            aug_item['f0'] = f0.astype(np.float32)
            # NOTE: variance curves are directly resampled according to the speed,
            # regardless of how frequency-domain features change after the
            # augmentation. For acoustic models, this makes it slightly harder
            # to learn how variance curves affect the mel spectrogram, since the
            # model must account for the mismatch introduced by the augmentation.
            #
            # This is a simple way to combine augmentation with variance curves.
            # However, handling variance curves like this reduces the accuracy
            # of the variance controls. In most situations, being less than
            # ~100% accurate does not hurt the user experience. For example, it
            # does not matter if the energy does not exactly equal the RMS, as
            # long as higher energy brings higher loudness and strength. The
            # neural network itself cannot be 100% accurate anyway.
            #
            # Other possible ways to obtain the variance curves:
            #   1. Re-extract the features from the resampled waveform;
            #   2. Re-extract the features from a waveform reconstructed by the
            #      vocoder from the transformed mel spectrogram.
            # But none of these makes all curves both accurate and stable.
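            # Resample each variance curve present in the item onto the new time axis.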
            for v_name in VARIANCE_CHECKLIST:
                if v_name in item:
                    aug_item[v_name] = resample_align_curve(
                        aug_item[v_name],
                        original_timestep=self.timestep,
                        target_timestep=self.timestep * aug_item['speed'],
                        align_length=aug_item['length']
                    )
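
        # Frequency-domain shifting: record the key shift (or remap the speaker id)
        # and transpose f0 by the corresponding factor.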
        if key_shift != 0. or hparams['use_key_shift_embed']:
            if replace_spk_id is None:
                aug_item['key_shift'] = key_shift
            else:
                # When a replacement speaker id is given, the shifted data is
                # treated as a separate speaker instead of storing the shift amount.
                aug_item['spk_id'] = replace_spk_id
            aug_item['f0'] *= 2 ** (key_shift / 12)  # shift f0 by key_shift semitones

        return aug_item
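
# Example usage (a minimal sketch, not part of the original module): it assumes
# hparams has already been populated by the surrounding pipeline and that `item`
# is a binarized data item containing 'wav_fn', 'seconds', 'ph_dur', etc. The
# variable names `data_dirs`, `augmentation_args`, `pitch_extractor` and `item`
# are illustrative placeholders.
#
#     aug = SpectrogramStretchAugmentation(data_dirs, augmentation_args, pe=pitch_extractor)
#     faster = aug.process_item(item, speed=1.25)      # time-domain stretching
#     higher = aug.process_item(item, key_shift=3.)    # frequency-domain shifting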