from copy import deepcopy

import librosa
import numpy as np
import torch

from basics.base_augmentation import BaseAugmentation, require_same_keys
from basics.base_pe import BasePE
from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
from modules.fastspeech.tts_modules import LengthRegulator
from utils.binarizer_utils import get_mel_torch, get_mel2ph_torch
from utils.hparams import hparams
from utils.infer_utils import resample_align_curve


class SpectrogramStretchAugmentation(BaseAugmentation):
    """
    This class contains methods for frequency-domain and time-domain stretching augmentation.
    """

    def __init__(self, data_dirs: list, augmentation_args: dict, pe: BasePE = None):
        super().__init__(data_dirs, augmentation_args)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.lr = LengthRegulator().to(self.device)
        self.pe = pe

    @require_same_keys
    def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None) -> dict:
        aug_item = deepcopy(item)
        waveform, _ = librosa.load(aug_item['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
        mel = get_mel_torch(
            waveform, hparams['audio_sample_rate'], num_mel_bins=hparams['audio_num_mel_bins'],
            hop_size=hparams['hop_size'], win_size=hparams['win_size'], fft_size=hparams['fft_size'],
            fmin=hparams['fmin'], fmax=hparams['fmax'],
            keyshift=key_shift, speed=speed, device=self.device
        )

        aug_item['mel'] = mel

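        # Even at speed == 1, recompute the speed-dependent fields when speed
        # embedding is enabled, so that every item carries a `speed` value.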
        if speed != 1. or hparams['use_speed_embed']:
            aug_item['length'] = mel.shape[0]
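            # The augmented hop length is rounded to a whole number of samples,
            # so the speed that actually takes effect is quantized accordingly.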
            aug_item['speed'] = int(np.round(hparams['hop_size'] * speed)) / hparams['hop_size']  # real speed
            aug_item['seconds'] /= aug_item['speed']
            aug_item['ph_dur'] /= aug_item['speed']
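            # Re-derive the frame-to-phoneme alignment from the scaled phoneme
            # durations, since the number of mel frames has changed.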
            aug_item['mel2ph'] = get_mel2ph_torch(
                self.lr, torch.from_numpy(aug_item['ph_dur']), aug_item['length'], self.timestep, device=self.device
            ).cpu().numpy()

            f0, _ = self.pe.get_pitch(
                waveform, samplerate=hparams['audio_sample_rate'], length=aug_item['length'],
                hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
                speed=speed, interp_uv=True
            )
            aug_item['f0'] = f0.astype(np.float32)

            # NOTE: variance curves are directly resampled according to speed,
            # ignoring how the frequency-domain features actually change after
            # the augmentation. For acoustic models, this can make it slightly
            # harder to learn how variance curves affect the mel spectrograms,
            # since the models must account for the mismatch introduced by the
            # augmentation.
            #
            # This is a simple way to combine augmentation with variance
            # curves. However, handling them this way reduces the accuracy of
            # variance controls. In most situations, not being ~100% accurate
            # will not ruin the user experience. For example, it does not
            # matter if the energy does not exactly equal the RMS; it is fine
            # as long as higher energy brings higher loudness and strength.
            # The neural network itself cannot be 100% accurate anyway.
            #
            # There are other possible ways to obtain the variance curves:
            #   1. Re-extract the features from the resampled waveforms;
            #   2. Re-extract the features from waveforms reconstructed by the
            #      vocoder from the transformed mel spectrograms.
            # However, neither approach is guaranteed to be both accurate and
            # stable.
            for v_name in VARIANCE_CHECKLIST:
                if v_name in item:
                    aug_item[v_name] = resample_align_curve(
                        aug_item[v_name],
                        original_timestep=self.timestep,
                        target_timestep=self.timestep * aug_item['speed'],
                        align_length=aug_item['length']
                    )

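        # Likewise, run the key-shift branch even at key_shift == 0 when key
        # shift embedding is enabled, so the corresponding field is always set.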
        if key_shift != 0. or hparams['use_key_shift_embed']:
            if replace_spk_id is None:
                aug_item['key_shift'] = key_shift
            else:
                aug_item['spk_id'] = replace_spk_id
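            # Shifting by key_shift semitones scales f0 by 2 ** (key_shift / 12).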
            aug_item['f0'] *= 2 ** (key_shift / 12)

        return aug_item
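

# A minimal usage sketch (hypothetical names): the binarizer pipeline normally
# constructs and drives this class, and the exact keys in `item` depend on it.
# `MyPitchExtractor`, `data_dirs`, `aug_args`, and `item` below are assumed
# placeholders, not part of this module:
#
#     pe = MyPitchExtractor()  # any BasePE implementation
#     aug = SpectrogramStretchAugmentation(data_dirs, aug_args, pe=pe)
#     # shift up 2 semitones and speed up by 25% in one pass
#     aug_item = aug.process_item(item, key_shift=2., speed=1.25)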