import csv
import json
import os
import pathlib

import librosa
import numpy as np
import torch
import torch.nn.functional as F
from scipy import interpolate

from basics.base_binarizer import BaseBinarizer, BinarizationError
from basics.base_pe import BasePE
from modules.fastspeech.tts_modules import LengthRegulator
from modules.pe import initialize_pe
from utils.binarizer_utils import (
    SinusoidalSmoothingConv1d,
    get_mel2ph_torch,
    get_energy_librosa,
    get_breathiness,
    get_voicing,
    get_tension_base_harmonic,
)
from utils.decomposed_waveform import DecomposedWaveform
from utils.hparams import hparams
from utils.infer_utils import resample_align_curve
from utils.pitch_utils import interp_f0
from utils.plot import distribution_to_figure
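
# Keep OpenMP single-threaded inside each process, presumably to avoid CPU
# oversubscription when binarization runs multiple worker processes.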
os.environ["OMP_NUM_THREADS"] = "1"

VARIANCE_ITEM_ATTRIBUTES = [
    'spk_id',  # index number of dataset/speaker, int64
    'tokens',  # index numbers of phonemes, int64[T_ph,]
    'ph_dur',  # durations of phonemes, in number of frames, int64[T_ph,]
    'midi',  # phoneme-level mean MIDI pitch, int64[T_ph,]
    'ph2word',  # similar to mel2ph format, representing number of phones within each word, int64[T_ph,]
    'mel2ph',  # mel2ph format representing number of frames within each phone, int64[T_s,]
    'note_midi',  # note-level MIDI pitch, float32[T_n,]
    'note_rest',  # flags for rest notes, bool[T_n,]
    'note_dur',  # durations of notes, in number of frames, int64[T_n,]
    'note_glide',  # flags for glides, 0 = none, 1 = up, 2 = down, int64[T_n,]
    'mel2note',  # mel2ph format representing number of frames within each note, int64[T_s,]
    'base_pitch',  # interpolated and smoothed frame-level MIDI pitch, float32[T_s,]
    'pitch',  # actual pitch in semitones, float32[T_s,]
    'uv',  # unvoiced masks (only for objective evaluation metrics), bool[T_s,]
    'energy',  # frame-level RMS (dB), float32[T_s,]
    'breathiness',  # frame-level RMS of aperiodic parts (dB), float32[T_s,]
    'voicing',  # frame-level RMS of harmonic parts (dB), float32[T_s,]
    'tension',  # frame-level tension (logit), float32[T_s,]
]
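
# Separator between an item name and its segment index, e.g. 'name#0';
# also used in per-segment .ds filenames such as 'name#0.ds'.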
DS_INDEX_SEP = '#'

# These operators are used as global variables due to a PyTorch shared memory bug on Windows platforms.
# See https://github.com/pytorch/pytorch/issues/100358
pitch_extractor: BasePE = None
midi_smooth: SinusoidalSmoothingConv1d = None
energy_smooth: SinusoidalSmoothingConv1d = None
breathiness_smooth: SinusoidalSmoothingConv1d = None
voicing_smooth: SinusoidalSmoothingConv1d = None
tension_smooth: SinusoidalSmoothingConv1d = None


class VarianceBinarizer(BaseBinarizer):
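    """
    Binarizer for variance model datasets: loads utterance labels from
    transcriptions.csv (optionally preferring pre-computed .ds files), extracts
    pitch and the optional variance curves (energy, breathiness, voicing,
    tension), and packs the attributes listed in VARIANCE_ITEM_ATTRIBUTES.
    """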

    def __init__(self):
        super().__init__(data_attrs=VARIANCE_ITEM_ATTRIBUTES)

        self.use_glide_embed = hparams['use_glide_embed']
        glide_types = hparams['glide_types']
        assert 'none' not in glide_types, 'Type name \'none\' is reserved and should not appear in glide_types.'
        self.glide_map = {
            'none': 0,
            **{
                typename: idx + 1
                for idx, typename in enumerate(glide_types)
            }
        }

        predict_energy = hparams['predict_energy']
        predict_breathiness = hparams['predict_breathiness']
        predict_voicing = hparams['predict_voicing']
        predict_tension = hparams['predict_tension']
        self.predict_variances = predict_energy or predict_breathiness or predict_voicing or predict_tension
        self.lr = LengthRegulator().to(self.device)
        self.prefer_ds = self.binarization_args['prefer_ds']
        self.cached_ds = {}

    def load_attr_from_ds(self, ds_id, name, attr, idx=0):
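        """
        Load attribute `attr` of item `name` in dataset `ds_id` from a .ds file,
        preferring a per-segment file ('{name}#{idx}.ds') over a whole-item one
        ('{name}.ds'). Parsed files are cached in self.cached_ds. Returns None
        if no .ds file exists or the attribute is absent.
        """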
        item_name = f'{ds_id}:{name}'
        item_name_with_idx = f'{item_name}{DS_INDEX_SEP}{idx}'
        if item_name_with_idx in self.cached_ds:
            ds = self.cached_ds[item_name_with_idx][0]
        elif item_name in self.cached_ds:
            ds = self.cached_ds[item_name][idx]
        else:
            ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}{DS_INDEX_SEP}{idx}.ds'
            if ds_path.exists():
                cache_key = item_name_with_idx
            else:
                ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}.ds'
                cache_key = item_name
            if not ds_path.exists():
                return None
            with open(ds_path, 'r', encoding='utf8') as f:
                ds = json.load(f)
            if not isinstance(ds, list):
                ds = [ds]
            self.cached_ds[cache_key] = ds
            ds = ds[idx]
        return ds.get(attr)

    def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id):
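        """
        Read utterance labels from transcriptions.csv under `raw_data_dir` and
        register them into self.items. Which attributes are required depends on
        hparams: ph_num when predicting duration; note_seq/note_dur (and
        optionally note_glide) when predicting pitch. With prefer_ds enabled,
        attributes are first looked up in the corresponding .ds files.
        """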
        meta_data_dict = {}

        with open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') as f:
            for utterance_label in csv.DictReader(f):
                utterance_label: dict
                item_name = utterance_label['name']
                item_idx = int(item_name.rsplit(DS_INDEX_SEP, maxsplit=1)[-1]) if DS_INDEX_SEP in item_name else 0

                def require(attr, optional=False):
                    if self.prefer_ds:
                        value = self.load_attr_from_ds(ds_id, item_name, attr, item_idx)
                    else:
                        value = None
                    if value is None:
                        value = utterance_label.get(attr)
                    if value is None and not optional:
                        raise ValueError(f'Missing required attribute {attr} of item \'{item_name}\'.')
                    return value

                temp_dict = {
                    'ds_idx': item_idx,
                    'spk_id': spk_id,
                    'spk_name': self.speakers[ds_id],
                    'wav_fn': str(raw_data_dir / 'wavs' / f'{item_name}.wav'),
                    'ph_seq': require('ph_seq').split(),
                    'ph_dur': [float(x) for x in require('ph_dur').split()]
                }

                assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \
                    f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.'
                assert all(ph_dur >= 0 for ph_dur in temp_dict['ph_dur']), \
                    f'Negative ph_dur found in \'{item_name}\'.'

                if hparams['predict_dur']:
                    temp_dict['ph_num'] = [int(x) for x in require('ph_num').split()]
                    assert len(temp_dict['ph_seq']) == sum(temp_dict['ph_num']), \
                        f'Sum of ph_num does not equal length of ph_seq in \'{item_name}\'.'

                if hparams['predict_pitch']:
                    temp_dict['note_seq'] = require('note_seq').split()
                    temp_dict['note_dur'] = [float(x) for x in require('note_dur').split()]
                    assert all(note_dur >= 0 for note_dur in temp_dict['note_dur']), \
                        f'Negative note_dur found in \'{item_name}\'.'
                    assert len(temp_dict['note_seq']) == len(temp_dict['note_dur']), \
                        f'Lengths of note_seq and note_dur mismatch in \'{item_name}\'.'
                    assert any(note != 'rest' for note in temp_dict['note_seq']), \
                        f'All notes are rest in \'{item_name}\'.'

                    if hparams['use_glide_embed']:
                        note_glide = require('note_glide', optional=True)
                        if note_glide is None:
                            note_glide = ['none' for _ in temp_dict['note_seq']]
                        else:
                            note_glide = note_glide.split()
                            assert len(note_glide) == len(temp_dict['note_seq']), \
                                f'Lengths of note_seq and note_glide mismatch in \'{item_name}\'.'
                            assert all(g in self.glide_map for g in note_glide), \
                                f'Invalid glide type found in \'{item_name}\'.'
                        temp_dict['note_glide'] = note_glide

                meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict

        self.items.update(meta_data_dict)

    def check_coverage(self):
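        """
        Run the base-class coverage checks, then print and plot the MIDI pitch
        distribution of the dataset and, if glide embedding is enabled, verify
        that every configured glide type occurs at least once.
        """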
        super().check_coverage()
        if not hparams['predict_pitch']:
            return

        # MIDI pitch distribution summary
        midi_map = {}
        for item_name in self.items:
            for midi in self.items[item_name]['note_seq']:
                if midi == 'rest':
                    continue
                midi = librosa.note_to_midi(midi, round_midi=True)
                if midi in midi_map:
                    midi_map[midi] += 1
                else:
                    midi_map[midi] = 1

        print('===== MIDI Pitch Distribution Summary =====')
        for i, key in enumerate(sorted(midi_map.keys())):
            if i == len(midi_map) - 1:
                end = '\n'
            elif i % 10 == 9:
                end = ',\n'
            else:
                end = ', '
            print(f'\'{librosa.midi_to_note(key, unicode=False)}\': {midi_map[key]}', end=end)

        # Draw graph.
        midis = sorted(midi_map.keys())
        notes = [librosa.midi_to_note(m, unicode=False) for m in range(midis[0], midis[-1] + 1)]
        plt = distribution_to_figure(
            title='MIDI Pitch Distribution Summary',
            x_label='MIDI Key', y_label='Number of occurrences',
            items=notes, values=[midi_map.get(m, 0) for m in range(midis[0], midis[-1] + 1)]
        )
        filename = self.binary_data_dir / 'midi_distribution.jpg'
        plt.savefig(fname=filename, bbox_inches='tight', pad_inches=0.25)
        print(f'| save summary to \'{filename}\'')

        if self.use_glide_embed:
            # Glide type distribution summary
            glide_count = {g: 0 for g in self.glide_map}
            for item_name in self.items:
                for glide in self.items[item_name]['note_glide']:
                    if glide == 'none' or glide not in self.glide_map:
                        glide_count['none'] += 1
                    else:
                        glide_count[glide] += 1

            print('===== Glide Type Distribution Summary =====')
            for i, key in enumerate(sorted(glide_count.keys(), key=lambda k: self.glide_map[k])):
                if i == len(glide_count) - 1:
                    end = '\n'
                elif i % 10 == 9:
                    end = ',\n'
                else:
                    end = ', '
                print(f'\'{key}\': {glide_count[key]}', end=end)

            if any(n == 0 for n in glide_count.values()):
                raise BinarizationError(
                    f'Missing glide types in dataset: '
                    f'{sorted([g for g, n in glide_count.items() if n == 0], key=lambda k: self.glide_map[k])}'
                )

    def process_item(self, item_name, meta_data, binarization_args):
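        """
        Binarize one utterance: quantize phoneme/note durations to frames,
        load or extract f0, and compute the optional duration, pitch and
        variance (energy, breathiness, voicing, tension) targets. Returns a
        dict of attributes, or None if the item is skipped (fully unvoiced f0).
        """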
        ds_id, name = item_name.split(':', maxsplit=1)
        name = name.rsplit(DS_INDEX_SEP, maxsplit=1)[0]
        ds_id = int(ds_id)
        ds_seg_idx = meta_data['ds_idx']

        seconds = sum(meta_data['ph_dur'])
        length = round(seconds / self.timestep)
        T_ph = len(meta_data['ph_seq'])
        processed_input = {
            'name': item_name,
            'wav_fn': meta_data['wav_fn'],
            'spk_id': meta_data['spk_id'],
            'spk_name': meta_data['spk_name'],
            'seconds': seconds,
            'length': length,
            'tokens': np.array(self.phone_encoder.encode(meta_data['ph_seq']), dtype=np.int64)
        }
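
        # Round *cumulative* phoneme boundaries to frame indices and take differences,
        # so that rounding error does not accumulate across the sequence.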
        ph_dur_sec = torch.FloatTensor(meta_data['ph_dur']).to(self.device)
        ph_acc = torch.round(torch.cumsum(ph_dur_sec, dim=0) / self.timestep + 0.5).long()
        ph_dur = torch.diff(ph_acc, dim=0, prepend=torch.LongTensor([0]).to(self.device))
        processed_input['ph_dur'] = ph_dur.cpu().numpy()

        mel2ph = get_mel2ph_torch(
            self.lr, ph_dur_sec, length, self.timestep, device=self.device
        )

        if hparams['predict_pitch'] or self.predict_variances:
            processed_input['mel2ph'] = mel2ph.cpu().numpy()

        # Below: extract actual f0 and convert it to MIDI pitch
        if pathlib.Path(meta_data['wav_fn']).exists():
            waveform, _ = librosa.load(meta_data['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
        elif not self.prefer_ds:
            raise FileNotFoundError(meta_data['wav_fn'])
        else:
            waveform = None

        global pitch_extractor
        if pitch_extractor is None:
            pitch_extractor = initialize_pe()
        f0 = uv = None
        if self.prefer_ds:
            f0_seq = self.load_attr_from_ds(ds_id, name, 'f0_seq', idx=ds_seg_idx)
            if f0_seq is not None:
                f0 = resample_align_curve(
                    np.array(f0_seq.split(), np.float32),
                    original_timestep=float(self.load_attr_from_ds(ds_id, name, 'f0_timestep', idx=ds_seg_idx)),
                    target_timestep=self.timestep,
                    align_length=length
                )
                uv = f0 == 0
                f0, _ = interp_f0(f0, uv)
        if f0 is None:
            f0, uv = pitch_extractor.get_pitch(
                waveform, samplerate=hparams['audio_sample_rate'], length=length,
                hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
                interp_uv=True
            )
        if uv.all():  # All unvoiced
            print(f'Skipped \'{item_name}\': empty gt f0')
            return None
        pitch = torch.from_numpy(librosa.hz_to_midi(f0.astype(np.float32))).to(self.device)

        if hparams['predict_dur']:
            ph_num = torch.LongTensor(meta_data['ph_num']).to(self.device)
            ph2word = self.lr(ph_num[None])[0]
            processed_input['ph2word'] = ph2word.cpu().numpy()
            mel2dur = torch.gather(F.pad(ph_dur, [1, 0], value=1), 0, mel2ph)  # frame-level phone duration
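            # Phoneme-level mean MIDI pitch: each frame contributes pitch / (its
            # phoneme's frame count), so scatter-adding over mel2ph yields per-phoneme
            # averages (index 0 collects padding frames and is dropped).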
            ph_midi = pitch.new_zeros(T_ph + 1).scatter_add(
                0, mel2ph, pitch / mel2dur
            )[1:]
            processed_input['midi'] = ph_midi.round().long().clamp(min=0, max=127).cpu().numpy()

        if hparams['predict_pitch']:
            # Below: get note sequence and interpolate rest notes
            note_midi = np.array(
                [(librosa.note_to_midi(n, round_midi=False) if n != 'rest' else -1) for n in meta_data['note_seq']],
                dtype=np.float32
            )
            note_rest = note_midi < 0
            interp_func = interpolate.interp1d(
                np.where(~note_rest)[0], note_midi[~note_rest],
                kind='nearest', fill_value='extrapolate'
            )
            note_midi[note_rest] = interp_func(np.where(note_rest)[0])
            processed_input['note_midi'] = note_midi
            processed_input['note_rest'] = note_rest
            note_midi = torch.from_numpy(note_midi).to(self.device)

            note_dur_sec = torch.FloatTensor(meta_data['note_dur']).to(self.device)
            note_acc = torch.round(torch.cumsum(note_dur_sec, dim=0) / self.timestep + 0.5).long()
            note_dur = torch.diff(note_acc, dim=0, prepend=torch.LongTensor([0]).to(self.device))
            processed_input['note_dur'] = note_dur.cpu().numpy()

            mel2note = get_mel2ph_torch(
                self.lr, note_dur_sec, mel2ph.shape[0], self.timestep, device=self.device
            )
            processed_input['mel2note'] = mel2note.cpu().numpy()

            # Below: get ornament attributes
            if hparams['use_glide_embed']:
                processed_input['note_glide'] = np.array([
                    self.glide_map.get(x, 0) for x in meta_data['note_glide']
                ], dtype=np.int64)

            # Below:
            # 1. Get the frame-level MIDI pitch, which is a step function curve
            # 2. Smooth the pitch step curve to form the base pitch curve
            frame_midi_pitch = torch.gather(F.pad(note_midi, [1, 0], value=0), 0, mel2note)
            global midi_smooth
            if midi_smooth is None:
                midi_smooth = SinusoidalSmoothingConv1d(
                    round(hparams['midi_smooth_width'] / self.timestep)
                ).eval().to(self.device)
            smoothed_midi_pitch = midi_smooth(frame_midi_pitch[None])[0]
            processed_input['base_pitch'] = smoothed_midi_pitch.cpu().numpy()

        if hparams['predict_pitch'] or self.predict_variances:
            processed_input['pitch'] = pitch.cpu().numpy()
            processed_input['uv'] = uv

        # Below: extract energy
        if hparams['predict_energy']:
            energy = None
            energy_from_wav = False
            if self.prefer_ds:
                energy_seq = self.load_attr_from_ds(ds_id, name, 'energy', idx=ds_seg_idx)
                if energy_seq is not None:
                    energy = resample_align_curve(
                        np.array(energy_seq.split(), np.float32),
                        original_timestep=float(self.load_attr_from_ds(
                            ds_id, name, 'energy_timestep', idx=ds_seg_idx
                        )),
                        target_timestep=self.timestep,
                        align_length=length
                    )
            if energy is None:
                energy = get_energy_librosa(
                    waveform, length,
                    hop_size=hparams['hop_size'], win_size=hparams['win_size']
                ).astype(np.float32)
                energy_from_wav = True
            if energy_from_wav:
                global energy_smooth
                if energy_smooth is None:
                    energy_smooth = SinusoidalSmoothingConv1d(
                        round(hparams['energy_smooth_width'] / self.timestep)
                    ).eval().to(self.device)
                energy = energy_smooth(torch.from_numpy(energy).to(self.device)[None])[0].cpu().numpy()

            processed_input['energy'] = energy

        # Create a DecomposedWaveform object for further feature extraction
        dec_waveform = DecomposedWaveform(
            waveform, samplerate=hparams['audio_sample_rate'], f0=f0 * ~uv,
            hop_size=hparams['hop_size'], fft_size=hparams['fft_size'], win_size=hparams['win_size'],
            algorithm=hparams['hnsep']
        ) if waveform is not None else None
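
        # Breathiness, voicing and tension below are computed from the harmonic/aperiodic
        # decomposition of the waveform (separated by the configured 'hnsep' algorithm).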
        # Below: extract breathiness
        if hparams['predict_breathiness']:
            breathiness = None
            breathiness_from_wav = False
            if self.prefer_ds:
                breathiness_seq = self.load_attr_from_ds(ds_id, name, 'breathiness', idx=ds_seg_idx)
                if breathiness_seq is not None:
                    breathiness = resample_align_curve(
                        np.array(breathiness_seq.split(), np.float32),
                        original_timestep=float(self.load_attr_from_ds(
                            ds_id, name, 'breathiness_timestep', idx=ds_seg_idx
                        )),
                        target_timestep=self.timestep,
                        align_length=length
                    )
            if breathiness is None:
                breathiness = get_breathiness(
                    dec_waveform, None, None, length=length
                )
                breathiness_from_wav = True
            if breathiness_from_wav:
                global breathiness_smooth
                if breathiness_smooth is None:
                    breathiness_smooth = SinusoidalSmoothingConv1d(
                        round(hparams['breathiness_smooth_width'] / self.timestep)
                    ).eval().to(self.device)
                breathiness = breathiness_smooth(torch.from_numpy(breathiness).to(self.device)[None])[0].cpu().numpy()

            processed_input['breathiness'] = breathiness

        # Below: extract voicing
        if hparams['predict_voicing']:
            voicing = None
            voicing_from_wav = False
            if self.prefer_ds:
                voicing_seq = self.load_attr_from_ds(ds_id, name, 'voicing', idx=ds_seg_idx)
                if voicing_seq is not None:
                    voicing = resample_align_curve(
                        np.array(voicing_seq.split(), np.float32),
                        original_timestep=float(self.load_attr_from_ds(
                            ds_id, name, 'voicing_timestep', idx=ds_seg_idx
                        )),
                        target_timestep=self.timestep,
                        align_length=length
                    )
            if voicing is None:
                voicing = get_voicing(
                    dec_waveform, None, None, length=length
                )
                voicing_from_wav = True
            if voicing_from_wav:
                global voicing_smooth
                if voicing_smooth is None:
                    voicing_smooth = SinusoidalSmoothingConv1d(
                        round(hparams['voicing_smooth_width'] / self.timestep)
                    ).eval().to(self.device)
                voicing = voicing_smooth(torch.from_numpy(voicing).to(self.device)[None])[0].cpu().numpy()

            processed_input['voicing'] = voicing

        # Below: extract tension
        if hparams['predict_tension']:
            tension = None
            tension_from_wav = False
            if self.prefer_ds:
                tension_seq = self.load_attr_from_ds(ds_id, name, 'tension', idx=ds_seg_idx)
                if tension_seq is not None:
                    tension = resample_align_curve(
                        np.array(tension_seq.split(), np.float32),
                        original_timestep=float(self.load_attr_from_ds(
                            ds_id, name, 'tension_timestep', idx=ds_seg_idx
                        )),
                        target_timestep=self.timestep,
                        align_length=length
                    )
            if tension is None:
                tension = get_tension_base_harmonic(
                    dec_waveform, None, None, length=length, domain='logit'
                )
                tension_from_wav = True
            if tension_from_wav:
                global tension_smooth
                if tension_smooth is None:
                    tension_smooth = SinusoidalSmoothingConv1d(
                        round(hparams['tension_smooth_width'] / self.timestep)
                    ).eval().to(self.device)
                tension = tension_smooth(torch.from_numpy(tension).to(self.device)[None])[0].cpu().numpy()

            processed_input['tension'] = tension

        return processed_input

    def arrange_data_augmentation(self, data_iterator):
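        """
        The variance binarizer performs no data augmentation, so no augmentation
        tasks are arranged.
        """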
        return {}