File size: 14,265 Bytes
c42fe7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
"""
    item: one piece of data
    item_name: data id
    wav_fn: wave file path
    spk: dataset name
    ph_seq: phoneme sequence
    ph_dur: phoneme durations
"""
import csv
import os
import pathlib
import random
from copy import deepcopy

import librosa
import numpy as np
import torch

from basics.base_binarizer import BaseBinarizer
from basics.base_pe import BasePE
from modules.fastspeech.tts_modules import LengthRegulator
from modules.pe import initialize_pe
from utils.binarizer_utils import (
    SinusoidalSmoothingConv1d,
    get_mel_torch,
    get_mel2ph_torch,
    get_energy_librosa,
    get_breathiness,
    get_voicing,
    get_tension_base_harmonic,
)
from utils.decomposed_waveform import DecomposedWaveform
from utils.hparams import hparams

# Force OMP_NUM_THREADS to "1" for this process.
# NOTE(review): presumably meant to avoid thread oversubscription when several
# binarization workers run in parallel; this assignment runs after numpy/torch
# were imported above -- confirm it still takes effect where it is needed.
os.environ["OMP_NUM_THREADS"] = "1"
# Attribute names passed to BaseBinarizer as ``data_attrs`` by AcousticBinarizer.
ACOUSTIC_ITEM_ATTRIBUTES = [
    'spk_id',
    'mel',
    'tokens',
    'mel2ph',
    'f0',
    'energy',
    'breathiness',
    'voicing',
    'tension',
    'key_shift',
    'speed',
]

# Module-level singletons. Each one starts as None and is created lazily on
# first use inside AcousticBinarizer.process_item, then reused for later items.
pitch_extractor: BasePE = None
energy_smooth: SinusoidalSmoothingConv1d = None
breathiness_smooth: SinusoidalSmoothingConv1d = None
voicing_smooth: SinusoidalSmoothingConv1d = None
tension_smooth: SinusoidalSmoothingConv1d = None


class AcousticBinarizer(BaseBinarizer):
    def __init__(self):
        """Initialize the binarizer and record which variance curves are required.

        The base class is initialized with ``ACOUSTIC_ITEM_ATTRIBUTES`` as its
        ``data_attrs``. Each ``need_*`` attribute is copied from the matching
        ``use_*_embed`` hyperparameter.

        Raises:
            AssertionError: if ``hparams['mel_base']`` is not ``'e'``.
        """
        super().__init__(data_attrs=ACOUSTIC_ITEM_ATTRIBUTES)
        self.lr = LengthRegulator()

        # (instance attribute, hyperparameter key) pairs, applied in this order.
        flag_sources = (
            ('need_energy', 'use_energy_embed'),
            ('need_breathiness', 'use_breathiness_embed'),
            ('need_voicing', 'use_voicing_embed'),
            ('need_tension', 'use_tension_embed'),
        )
        for attr_name, hparam_key in flag_sources:
            setattr(self, attr_name, hparams[hparam_key])

        mel_base = hparams['mel_base']
        assert mel_base == 'e', (
            "Mel base must be set to \'e\' according to 2nd stage of the migration plan. "
            "See https://github.com/openvpi/DiffSinger/releases/tag/v2.3.0 for more details."
        )

    def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id):
        """Read ``transcriptions.csv`` of one raw dataset and register its items.

        Every CSV row becomes one entry of ``self.items`` keyed by
        ``'{ds_id}:{name}'``. The entry holds the wave file path
        (``<raw_data_dir>/wavs/<name>.wav``), the phoneme sequence, the phoneme
        durations as floats, the speaker id and the speaker name looked up in
        ``self.speakers[ds_id]``.

        Args:
            raw_data_dir: directory containing ``transcriptions.csv`` and ``wavs/``.
            ds_id: dataset index; used as key prefix and to index ``self.speakers``.
            spk_id: speaker id stored with every item of this dataset.

        Raises:
            AssertionError: if ``ph_seq`` and ``ph_dur`` differ in length, or a
                duration is not ``>= 0``.
        """
        meta_data_dict = {}
        wav_dir = raw_data_dir / 'wavs'
        # 'utf-8-sig' transparently drops a leading BOM (common in CSV files
        # exported by spreadsheet tools), which would otherwise corrupt the first
        # header name; plain UTF-8 files are decoded exactly as before.
        # newline='' lets the csv module perform its own newline handling, as its
        # documentation requires for file objects.
        with open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf-8-sig', newline='') as f:
            for utterance_label in csv.DictReader(f):
                item_name = utterance_label['name']
                temp_dict = {
                    'wav_fn': str(wav_dir / f'{item_name}.wav'),
                    'ph_seq': utterance_label['ph_seq'].split(),
                    'ph_dur': [float(x) for x in utterance_label['ph_dur'].split()],
                    'spk_id': spk_id,
                    'spk_name': self.speakers[ds_id],
                }
                assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \
                    f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.'
                assert all(ph_dur >= 0 for ph_dur in temp_dict['ph_dur']), \
                    f'Negative ph_dur found in \'{item_name}\'.'
                # Prefix with the dataset id so equal names from different datasets do not collide.
                meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict

        self.items.update(meta_data_dict)

    @torch.no_grad()
    def process_item(self, item_name, meta_data, binarization_args):
        """Extract the acoustic features of a single utterance.

        The waveform at ``meta_data['wav_fn']`` is loaded and turned into a mel
        spectrogram, the phoneme durations are mapped to a frame-level
        ``mel2ph`` and f0 is extracted. Energy, breathiness, voicing and tension
        are only computed when the matching ``need_*`` flag is set; each of them
        is passed through its own lazily created smoothing module.

        Args:
            item_name: identifier stored under ``'name'`` and used in log output.
            meta_data: dict providing ``wav_fn``, ``spk_id``, ``spk_name``,
                ``ph_seq`` and ``ph_dur``.
            binarization_args: not used by this method.

        Returns:
            A dict with the extracted features, or ``None`` when the item is
            skipped (every frame is unvoiced, or the smoothed tension curve
            contains NaN).
        """
        # The extractor and the smoothing modules are module-level singletons
        # shared by all items processed in this process.
        global pitch_extractor, energy_smooth, breathiness_smooth, voicing_smooth, tension_smooth

        sample_rate = hparams['audio_sample_rate']
        waveform, _ = librosa.load(meta_data['wav_fn'], sr=sample_rate, mono=True)

        hop_size = hparams['hop_size']
        win_size = hparams['win_size']
        fft_size = hparams['fft_size']
        device = self.device

        def make_smoother(width_key):
            # The configured width is divided by the timestep before rounding.
            width = round(hparams[width_key] / self.timestep)
            return SinusoidalSmoothingConv1d(width).eval().to(device)

        def smooth(smoother, curve):
            # numpy curve -> tensor with a leading axis on the device -> first row of the result
            batched = torch.from_numpy(curve).to(device)[None]
            return smoother(batched)[0]

        mel = get_mel_torch(
            waveform, sample_rate,
            num_mel_bins=hparams['audio_num_mel_bins'],
            hop_size=hop_size, win_size=win_size, fft_size=fft_size,
            fmin=hparams['fmin'], fmax=hparams['fmax'],
            device=device
        )
        length = mel.shape[0]
        processed_input = {
            'name': item_name,
            'wav_fn': meta_data['wav_fn'],
            'spk_id': meta_data['spk_id'],
            'spk_name': meta_data['spk_name'],
            'seconds': length * hop_size / sample_rate,
            'length': length,
            'mel': mel,
            'tokens': np.array(self.phone_encoder.encode(meta_data['ph_seq']), dtype=np.int64),
            'ph_dur': np.array(meta_data['ph_dur']).astype(np.float32),
        }

        # Ground truth durations: phoneme durations -> mel2ph.
        ph_dur_tensor = torch.from_numpy(processed_input['ph_dur'])
        mel2ph = get_mel2ph_torch(self.lr, ph_dur_tensor, length, self.timestep, device=device)
        processed_input['mel2ph'] = mel2ph.cpu().numpy()

        # Ground truth f0.
        if pitch_extractor is None:
            pitch_extractor = initialize_pe()
        gt_f0, uv = pitch_extractor.get_pitch(
            waveform, samplerate=sample_rate, length=length,
            hop_size=hop_size, f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
            interp_uv=True
        )
        if uv.all():
            # Nothing voiced in the whole item: drop it.
            print(f'Skipped \'{item_name}\': empty gt f0')
            return None
        processed_input['f0'] = gt_f0.astype(np.float32)

        if self.need_energy:
            # Ground truth energy.
            raw_energy = get_energy_librosa(
                waveform, length, hop_size=hop_size, win_size=win_size
            ).astype(np.float32)
            if energy_smooth is None:
                energy_smooth = make_smoother('energy_smooth_width')
            processed_input['energy'] = smooth(energy_smooth, raw_energy).cpu().numpy()

        # Decomposed waveform shared by the remaining extractors. f0 is multiplied
        # by the inverted uv mask (assumes uv is a boolean array -- TODO confirm).
        voiced_mask = ~uv
        dec_waveform = DecomposedWaveform(
            waveform, samplerate=sample_rate, f0=gt_f0 * voiced_mask,
            hop_size=hop_size, fft_size=fft_size, win_size=win_size,
            algorithm=hparams['hnsep']
        )

        if self.need_breathiness:
            # Ground truth breathiness.
            raw_breathiness = get_breathiness(dec_waveform, None, None, length=length)
            if breathiness_smooth is None:
                breathiness_smooth = make_smoother('breathiness_smooth_width')
            processed_input['breathiness'] = smooth(breathiness_smooth, raw_breathiness).cpu().numpy()

        if self.need_voicing:
            # Ground truth voicing.
            raw_voicing = get_voicing(dec_waveform, None, None, length=length)
            if voicing_smooth is None:
                voicing_smooth = make_smoother('voicing_smooth_width')
            processed_input['voicing'] = smooth(voicing_smooth, raw_voicing).cpu().numpy()

        if self.need_tension:
            # Ground truth tension.
            raw_tension = get_tension_base_harmonic(
                dec_waveform, None, None, length=length, domain='logit'
            )
            if tension_smooth is None:
                tension_smooth = make_smoother('tension_smooth_width')
            tension = smooth(tension_smooth, raw_tension)
            if tension.isnan().any():
                # A NaN anywhere in the curve discards the whole item.
                print('Error:', item_name)
                print(tension)
                return None
            processed_input['tension'] = tension.cpu().numpy()

        # Unaugmented items carry the neutral key shift / speed values.
        if hparams['use_key_shift_embed']:
            processed_input['key_shift'] = 0.

        if hparams['use_speed_embed']:
            processed_input['speed'] = 1.

        return processed_input

    def arrange_data_augmentation(self, data_iterator):
        """Plan the augmentation tasks to apply to each raw item.

        Up to three augmentations are arranged, in this order, depending on
        ``self.augmentation_args``:

        * random pitch shifting: items are drawn with replacement and each draw
          gets a random ``key_shift`` inside the configured range;
        * fixed pitch shifting: for every target key shift, items are drawn with
          replacement and assigned a replacement speaker id offset by
          ``(target_index + 1) * (max(self.spk_ids) + 1)``;
        * random time stretching: new speed-only tasks are created for raw
          items, copies of already arranged tasks get a ``speed`` added, and
          some already arranged tasks get a ``speed`` set in place. Speeds are
          sampled uniformly in the log domain.

        Args:
            data_iterator: iterable of ``(item_name, <anything>)`` pairs; only
                the item names are used.

        Returns:
            A dict mapping each selected item name to the list of its tasks.
            Every task is a dict with the keys ``'name'`` (source item name),
            ``'func'`` (callable performing the augmentation) and ``'kwargs'``.

        Raises:
            AssertionError: if an enabled augmentation is configured
                inconsistently (see the individual messages).
        """
        aug_map = {}
        aug_list = []
        all_item_names = [item_name for item_name, _ in data_iterator]
        total_scale = 0
        aug_pe = initialize_pe()

        def register(source_name, task):
            # Keep the per-item map and the flat task list in sync.
            aug_map.setdefault(source_name, []).append(task)
            aug_list.append(task)

        if self.augmentation_args['random_pitch_shifting']['enabled']:
            from augmentation.spec_stretch import SpectrogramStretchAugmentation
            aug_args = self.augmentation_args['random_pitch_shifting']
            key_shift_min, key_shift_max = aug_args['range']
            assert hparams['use_key_shift_embed'], \
                'Random pitch shifting augmentation requires use_key_shift_embed == True.'
            assert key_shift_min < 0 < key_shift_max, \
                'Random pitch shifting augmentation must have a range where min < 0 < max.'

            aug_ins = SpectrogramStretchAugmentation(self.raw_data_dirs, aug_args, pe=aug_pe)
            scale = aug_args['scale']
            aug_item_names = random.choices(all_item_names, k=int(scale * len(all_item_names)))

            for aug_item_name in aug_item_names:
                # Negative draws scale towards the minimum, the rest towards the maximum.
                rand = random.uniform(-1, 1)
                if rand < 0:
                    key_shift = key_shift_min * abs(rand)
                else:
                    key_shift = key_shift_max * rand
                register(aug_item_name, {
                    'name': aug_item_name,
                    'func': aug_ins.process_item,
                    'kwargs': {'key_shift': key_shift}
                })

            total_scale += scale

        if self.augmentation_args['fixed_pitch_shifting']['enabled']:
            from augmentation.spec_stretch import SpectrogramStretchAugmentation
            aug_args = self.augmentation_args['fixed_pitch_shifting']
            targets = aug_args['targets']
            scale = aug_args['scale']
            spk_id_size = max(self.spk_ids) + 1
            min_num_spk = (1 + len(targets)) * spk_id_size
            assert not self.augmentation_args['random_pitch_shifting']['enabled'], \
                'Fixed pitch shifting augmentation is not compatible with random pitch shifting.'
            assert len(targets) == len(set(targets)), \
                'Fixed pitch shifting augmentation requires having no duplicate targets.'
            assert hparams['use_spk_id'], 'Fixed pitch shifting augmentation requires use_spk_id == True.'
            assert hparams['num_spk'] >= min_num_spk, \
                'Fixed pitch shifting augmentation requires num_spk >= (1 + len(targets)) * (max(spk_ids) + 1).'
            assert scale < 1, 'Fixed pitch shifting augmentation requires scale < 1.'

            aug_ins = SpectrogramStretchAugmentation(self.raw_data_dirs, aug_args, pe=aug_pe)
            for i, target in enumerate(targets):
                aug_item_names = random.choices(all_item_names, k=int(scale * len(all_item_names)))
                for aug_item_name in aug_item_names:
                    # The part of the item name before the first ':' is the dataset
                    # index; every target gets its own block of speaker ids.
                    ds_index = int(aug_item_name.split(':', maxsplit=1)[0])
                    replace_spk_id = self.spk_ids[ds_index] + (i + 1) * spk_id_size
                    register(aug_item_name, {
                        'name': aug_item_name,
                        'func': aug_ins.process_item,
                        'kwargs': {'key_shift': target, 'replace_spk_id': replace_spk_id}
                    })

            total_scale += scale * len(targets)

        if self.augmentation_args['random_time_stretching']['enabled']:
            from augmentation.spec_stretch import SpectrogramStretchAugmentation
            aug_args = self.augmentation_args['random_time_stretching']
            speed_min, speed_max = aug_args['range']
            assert hparams['use_speed_embed'], \
                'Random time stretching augmentation requires use_speed_embed == True.'
            assert 0 < speed_min < 1 < speed_max, \
                'Random time stretching augmentation must have a range where 0 < min < 1 < max.'

            aug_ins = SpectrogramStretchAugmentation(self.raw_data_dirs, aug_args, pe=aug_pe)
            scale = aug_args['scale']
            k_from_raw = int(scale / (1 + total_scale) * len(all_item_names))
            k_from_aug = int(total_scale * scale / (1 + total_scale) * len(all_item_names))
            k_mutate = int(total_scale * scale / (1 + scale) * len(all_item_names))
            if not aug_list:
                # Earlier augmentations may be enabled yet have produced no task
                # (tiny dataset or scale); random.choices() would raise IndexError
                # on an empty population with k > 0, so there is nothing to derive from.
                k_from_aug = k_mutate = 0
            aug_types = [0] * k_from_raw + [1] * k_from_aug + [2] * k_mutate
            aug_items = random.choices(all_item_names, k=k_from_raw) + random.choices(aug_list, k=k_from_aug + k_mutate)

            for aug_type, aug_item in zip(aug_types, aug_items):
                # Uniform distribution in log domain
                speed = speed_min * (speed_max / speed_min) ** random.random()
                if aug_type == 0:
                    # New speed-only task on a raw item (aug_item is an item name).
                    register(aug_item, {
                        'name': aug_item,
                        'func': aug_ins.process_item,
                        'kwargs': {'speed': speed}
                    })
                elif aug_type == 1:
                    # New task derived from an existing one (aug_item is a task dict):
                    # same function, copied kwargs plus a speed. 'name' must be the
                    # source item name, not the task dict itself.
                    source_name = aug_item['name']
                    aug_task = {
                        'name': source_name,
                        'func': aug_item['func'],
                        'kwargs': deepcopy(aug_item['kwargs'])
                    }
                    aug_task['kwargs']['speed'] = speed
                    register(source_name, aug_task)
                elif aug_type == 2:
                    # Add the speed to an existing task in place.
                    aug_item['kwargs']['speed'] = speed

            # Kept so that the running total stays correct if another
            # augmentation is arranged after this one.
            total_scale += scale

        return aug_map