import copy
import json
import pathlib
from collections import OrderedDict
from typing import List, Tuple

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from scipy import interpolate

from basics.base_svs_infer import BaseSVSInfer
from modules.fastspeech.tts_modules import (
    LengthRegulator, RhythmRegulator,
    mel2ph_to_dur
)
from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
from modules.toplevel import DiffSingerVariance
from utils import load_ckpt
from utils.hparams import hparams
from utils.infer_utils import resample_align_curve
from utils.phoneme_utils import build_phoneme_list
from utils.pitch_utils import interp_f0
from utils.text_encoder import TokenTextEncoder


class DiffSingerVarianceInfer(BaseSVSInfer):
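    """
    Inference wrapper for the variance model (DiffSingerVariance).

    Given segments from a .ds file, it builds model inputs and predicts phoneme
    durations, pitch, and other variance parameters, depending on which predictors
    the model has enabled and which predictions were requested.
    """
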
    def __init__(
            self, device=None, ckpt_steps=None,
            predictions: set = None
    ):
        super().__init__(device=device)
        self.ph_encoder = TokenTextEncoder(vocab_list=build_phoneme_list())
        if hparams['use_spk_id']:
            with open(pathlib.Path(hparams['work_dir']) / 'spk_map.json', 'r', encoding='utf8') as f:
                self.spk_map = json.load(f)
            assert isinstance(self.spk_map, dict) and len(self.spk_map) > 0, 'Invalid or empty speaker map!'
            assert len(self.spk_map) == len(set(self.spk_map.values())), 'Duplicate speaker id in speaker map!'
        self.model: DiffSingerVariance = self.build_model(ckpt_steps=ckpt_steps)
        self.lr = LengthRegulator()
        self.rr = RhythmRegulator()
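
        # The MIDI smoother: a fixed 1-D convolution with a normalized half-sine window,
        # used to turn the stepwise frame-level note pitch into a smooth base pitch curve.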
        smooth_kernel_size = round(hparams['midi_smooth_width'] / self.timestep)
        self.smooth = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=smooth_kernel_size,
            bias=False,
            padding='same',
            padding_mode='replicate'
        ).eval().to(self.device)
        smooth_kernel = torch.sin(torch.from_numpy(
            np.linspace(0, 1, smooth_kernel_size).astype(np.float32) * np.pi
        ).to(self.device))
        smooth_kernel /= smooth_kernel.sum()
        self.smooth.weight.data = smooth_kernel[None, None]

        glide_types = hparams.get('glide_types', [])
        assert 'none' not in glide_types, "Type name 'none' is reserved and should not appear in glide_types."
        self.glide_map = {
            'none': 0,
            **{
                typename: idx + 1
                for idx, typename in enumerate(glide_types)
            }
        }
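
        # Which attributes to predict: with an empty `predictions` set the model runs in
        # auto-completion mode (fill in whatever the .ds segment is missing); otherwise
        # only the explicitly requested attributes are predicted.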
        self.auto_completion_mode = len(predictions) == 0
        self.global_predict_dur = 'dur' in predictions and hparams['predict_dur']
        self.global_predict_pitch = 'pitch' in predictions and hparams['predict_pitch']
        self.variance_prediction_set = predictions.intersection(VARIANCE_CHECKLIST)
        self.global_predict_variances = len(self.variance_prediction_set) > 0

    def build_model(self, ckpt_steps=None):
        model = DiffSingerVariance(
            vocab_size=len(self.ph_encoder)
        ).eval().to(self.device)
        load_ckpt(model, hparams['work_dir'], ckpt_steps=ckpt_steps,
                  prefix_in_ckpt='model', strict=True, device=self.device)
        return model

    def preprocess_input(
            self, param, idx=0,
            load_dur: bool = False,
            load_pitch: bool = False
    ):
        """
        :param param: one segment in the .ds file
        :param idx: index of the segment
        :param load_dur: whether to load the manual phoneme durations (ph_dur) from the segment
        :param load_pitch: whether to load the manual pitch curve (f0_seq) from the segment
        :return: batch of the model inputs
        """
        batch = {}
        summary = OrderedDict()
        txt_tokens = torch.LongTensor([self.ph_encoder.encode(param['ph_seq'].split())]).to(self.device)  # [B=1, T_ph]
        T_ph = txt_tokens.shape[1]
        batch['tokens'] = txt_tokens
        ph_num = torch.from_numpy(np.array([param['ph_num'].split()], np.int64)).to(self.device)  # [B=1, T_w]
        ph2word = self.lr(ph_num)  # => [B=1, T_ph]
        T_w = int(ph2word.max())
        batch['ph2word'] = ph2word
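
        # Parse the note sequence into fractional MIDI numbers; rests are marked and
        # their pitch is filled by nearest-neighbor interpolation from the real notes.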
        note_midi = np.array(
            [(librosa.note_to_midi(n, round_midi=False) if n != 'rest' else -1) for n in param['note_seq'].split()],
            dtype=np.float32
        )
        note_rest = note_midi < 0
        if np.all(note_rest):
            # All rests, fill with constants
            note_midi = np.full_like(note_midi, fill_value=60.)
        else:
            # Interpolate rest values
            interp_func = interpolate.interp1d(
                np.where(~note_rest)[0], note_midi[~note_rest],
                kind='nearest', fill_value='extrapolate'
            )
            note_midi[note_rest] = interp_func(np.where(note_rest)[0])
        note_midi = torch.from_numpy(note_midi).to(self.device)[None]  # [B=1, T_n]
        note_rest = torch.from_numpy(note_rest).to(self.device)[None]  # [B=1, T_n]
        T_n = note_midi.shape[1]
        note_dur_sec = torch.from_numpy(np.array([param['note_dur'].split()], np.float32)).to(self.device)  # [B=1, T_n]
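        # Quantize note durations to frames: round the cumulative end times to frame
        # boundaries, then take differences to get per-note frame counts.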
        note_acc = torch.round(torch.cumsum(note_dur_sec, dim=1) / self.timestep + 0.5).long()
        note_dur = torch.diff(note_acc, dim=1, prepend=note_acc.new_zeros(1, 1))
        mel2note = self.lr(note_dur)  # [B=1, T_s]
        T_s = mel2note.shape[1]

        summary['words'] = T_w
        summary['notes'] = T_n
        summary['tokens'] = T_ph
        summary['frames'] = T_s
        summary['seconds'] = '%.2f' % (T_s * self.timestep)

        if hparams['use_spk_id']:
            ph_spk_mix_id, ph_spk_mix_value = self.load_speaker_mix(
                param_src=param, summary_dst=summary, mix_mode='token', mix_length=T_ph
            )
            spk_mix_id, spk_mix_value = self.load_speaker_mix(
                param_src=param, summary_dst=summary, mix_mode='frame', mix_length=T_s
            )
            batch['ph_spk_mix_id'] = ph_spk_mix_id
            batch['ph_spk_mix_value'] = ph_spk_mix_value
            batch['spk_mix_id'] = spk_mix_id
            batch['spk_mix_value'] = spk_mix_value

        if load_dur:
            # Get mel2ph if ph_dur is needed
            ph_dur_sec = torch.from_numpy(
                np.array([param['ph_dur'].split()], np.float32)
            ).to(self.device)  # [B=1, T_ph]
            ph_acc = torch.round(torch.cumsum(ph_dur_sec, dim=1) / self.timestep + 0.5).long()
            ph_dur = torch.diff(ph_acc, dim=1, prepend=ph_acc.new_zeros(1, 1))
            mel2ph = self.lr(ph_dur, txt_tokens == 0)
            if mel2ph.shape[1] != T_s:  # Align phones with notes
                mel2ph = F.pad(mel2ph, [0, T_s - mel2ph.shape[1]], value=mel2ph[0, -1])
                ph_dur = mel2ph_to_dur(mel2ph, T_ph)
            # Get word_dur from ph_dur and ph_num
            word_dur = note_dur.new_zeros(1, T_w + 1).scatter_add(
                1, ph2word, ph_dur
            )[:, 1:]  # => [B=1, T_w]
        else:
            ph_dur = None
            mel2ph = None
            # Get word_dur from note_dur and note_slur
            is_slur = torch.BoolTensor([[int(s) for s in param['note_slur'].split()]]).to(self.device)  # [B=1, T_n]
            note2word = torch.cumsum(~is_slur, dim=1)  # [B=1, T_n]
            word_dur = note_dur.new_zeros(1, T_w + 1).scatter_add(
                1, note2word, note_dur
            )[:, 1:]  # => [B=1, T_w]
        batch['ph_dur'] = ph_dur
        batch['mel2ph'] = mel2ph

        mel2word = self.lr(word_dur)  # [B=1, T_s]
        if mel2word.shape[1] != T_s:  # Align words with notes
            mel2word = F.pad(mel2word, [0, T_s - mel2word.shape[1]], value=mel2word[0, -1])
            word_dur = mel2ph_to_dur(mel2word, T_w)
        batch['word_dur'] = word_dur

        batch['note_midi'] = note_midi
        batch['note_dur'] = note_dur
        batch['note_rest'] = note_rest
        if hparams.get('use_glide_embed', False) and param.get('note_glide') is not None:
            batch['note_glide'] = torch.LongTensor(
                [[self.glide_map.get(x, 0) for x in param['note_glide'].split()]]
            ).to(self.device)
        else:
            batch['note_glide'] = torch.zeros(1, T_n, dtype=torch.long, device=self.device)
        batch['mel2note'] = mel2note

        # Calculate and smooth the frame-level MIDI pitch, which is a step function curve
        frame_midi_pitch = torch.gather(
            F.pad(note_midi, [1, 0]), 1, mel2note
        )  # => frame-level MIDI pitch, [B=1, T_s]
        base_pitch = self.smooth(frame_midi_pitch)
        batch['base_pitch'] = base_pitch
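
        # The scatter_add below averages the frame-level pitch within each phoneme
        # (or word): each frame contributes pitch / (frames in its phoneme/word),
        # so the sums come out as per-segment means.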
        if ph_dur is not None:
            # Phone durations are available, calculate phoneme-level MIDI.
            mel2pdur = torch.gather(F.pad(ph_dur, [1, 0], value=1), 1, mel2ph)  # frame-level phone duration
            ph_midi = frame_midi_pitch.new_zeros(1, T_ph + 1).scatter_add(
                1, mel2ph, frame_midi_pitch / mel2pdur
            )[:, 1:]
        else:
            # Phone durations are not available, calculate word-level MIDI instead.
            mel2wdur = torch.gather(F.pad(word_dur, [1, 0], value=1), 1, mel2word)
            w_midi = frame_midi_pitch.new_zeros(1, T_w + 1).scatter_add(
                1, mel2word, frame_midi_pitch / mel2wdur
            )[:, 1:]
            # Convert word-level MIDI to phoneme-level MIDI
            ph_midi = torch.gather(F.pad(w_midi, [1, 0]), 1, ph2word)
        ph_midi = ph_midi.round().long()
        batch['midi'] = ph_midi

        if load_pitch:
            f0 = resample_align_curve(
                np.array(param['f0_seq'].split(), np.float32),
                original_timestep=float(param['f0_timestep']),
                target_timestep=self.timestep,
                align_length=T_s
            )
            batch['pitch'] = torch.from_numpy(
                librosa.hz_to_midi(interp_f0(f0)[0]).astype(np.float32)
            ).to(self.device)[None]

        if self.model.predict_dur:
            if load_dur:
                summary['ph_dur'] = 'manual'
            elif self.auto_completion_mode or self.global_predict_dur:
                summary['ph_dur'] = 'auto'
            else:
                summary['ph_dur'] = 'ignored'

        if self.model.predict_pitch:
            if load_pitch:
                summary['pitch'] = 'manual'
            elif self.auto_completion_mode or self.global_predict_pitch:
                summary['pitch'] = 'auto'

                # Load expressiveness
                expr = param.get('expr', 1.)
                if isinstance(expr, (int, float, bool)):
                    summary['expr'] = f'static({expr:.3f})'
                    batch['expr'] = torch.FloatTensor([expr]).to(self.device)[:, None]  # [B=1, T=1]
                else:
                    summary['expr'] = 'dynamic'
                    expr = resample_align_curve(
                        np.array(expr.split(), np.float32),
                        original_timestep=float(param['expr_timestep']),
                        target_timestep=self.timestep,
                        align_length=T_s
                    )
                    batch['expr'] = torch.from_numpy(expr.astype(np.float32)).to(self.device)[None]
            else:
                summary['pitch'] = 'ignored'

        if self.model.predict_variances:
            for v_name in self.model.variance_prediction_list:
                if (self.auto_completion_mode and param.get(v_name) is None) or v_name in self.variance_prediction_set:
                    summary[v_name] = 'auto'
                else:
                    summary[v_name] = 'ignored'

        print(f'[{idx}]\t' + ', '.join(f'{k}: {v}' for k, v in summary.items()))

        return batch

    def forward_model(self, sample):
        txt_tokens = sample['tokens']
        midi = sample['midi']
        ph2word = sample['ph2word']
        word_dur = sample['word_dur']
        ph_dur = sample['ph_dur']
        mel2ph = sample['mel2ph']
        note_midi = sample['note_midi']
        note_rest = sample['note_rest']
        note_dur = sample['note_dur']
        note_glide = sample['note_glide']
        mel2note = sample['mel2note']
        base_pitch = sample['base_pitch']
        expr = sample.get('expr')
        pitch = sample.get('pitch')
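
        # Mix speaker embeddings: weight each speaker's embedding by its per-token
        # (ph_spk_mix) and per-frame (spk_mix) mixture value and sum over speakers.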
        if hparams['use_spk_id']:
            ph_spk_mix_id = sample['ph_spk_mix_id']
            ph_spk_mix_value = sample['ph_spk_mix_value']
            spk_mix_id = sample['spk_mix_id']
            spk_mix_value = sample['spk_mix_value']
            ph_spk_mix_embed = torch.sum(
                self.model.spk_embed(ph_spk_mix_id) * ph_spk_mix_value.unsqueeze(3),  # => [B, T_ph, N, H]
                dim=2, keepdim=False
            )  # => [B, T_ph, H]
            spk_mix_embed = torch.sum(
                self.model.spk_embed(spk_mix_id) * spk_mix_value.unsqueeze(3),  # => [B, T_s, N, H]
                dim=2, keepdim=False
            )  # => [B, T_s, H]
        else:
            ph_spk_mix_embed = spk_mix_embed = None

        dur_pred, pitch_pred, variance_pred = self.model(
            txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, mel2ph=mel2ph,
            note_midi=note_midi, note_rest=note_rest, note_dur=note_dur, note_glide=note_glide, mel2note=mel2note,
            base_pitch=base_pitch, pitch=pitch, pitch_expr=expr,
            ph_spk_mix_embed=ph_spk_mix_embed, spk_mix_embed=spk_mix_embed,
            infer=True
        )
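        # Rhythm regulation: rescale the predicted phoneme durations so that the
        # phonemes of each word sum to that word's duration.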
        if dur_pred is not None:
            dur_pred = self.rr(dur_pred, ph2word, word_dur)
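        # The pitch predictor outputs a delta on top of the smoothed base pitch curve.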
        if pitch_pred is not None:
            pitch_pred = base_pitch + pitch_pred
        return dur_pred, pitch_pred, variance_pred

    def infer_once(self, param):
        batch = self.preprocess_input(param)
        dur_pred, pitch_pred, variance_pred = self.forward_model(batch)
        if dur_pred is not None:
            dur_pred = dur_pred[0].cpu().numpy()
        if pitch_pred is not None:
            pitch_pred = pitch_pred[0].cpu().numpy()
            f0_pred = librosa.midi_to_hz(pitch_pred)
        else:
            f0_pred = None
        variance_pred = {
            k: v[0].cpu().numpy()
            for k, v in variance_pred.items()
        }
        return dur_pred, f0_pred, variance_pred

    def run_inference(
            self, params,
            out_dir: pathlib.Path = None,
            title: str = None,
            num_runs: int = 1,
            seed: int = -1
    ):
        batches = []
        predictor_flags: List[Tuple[bool, bool, bool]] = []
        for i, param in enumerate(params):
            param: dict
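            # flag = (predict_dur, predict_pitch, predict_variances) for this segment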
            if self.auto_completion_mode:
                flag = (
                    self.model.fs2.predict_dur and param.get('ph_dur') is None,
                    self.model.predict_pitch and param.get('f0_seq') is None,
                    self.model.predict_variances and any(
                        param.get(v_name) is None for v_name in self.model.variance_prediction_list
                    )
                )
            else:
                predict_variances = self.model.predict_variances and self.global_predict_variances
                predict_pitch = self.model.predict_pitch and (
                    self.global_predict_pitch or (param.get('f0_seq') is None and predict_variances)
                )
                predict_dur = self.model.predict_dur and (
                    self.global_predict_dur or (param.get('ph_dur') is None and (predict_pitch or predict_variances))
                )
                flag = (predict_dur, predict_pitch, predict_variances)
            predictor_flags.append(flag)
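            # Load manual inputs only where this segment will not predict them but a
            # downstream predictor still needs them as conditioning.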
            batches.append(self.preprocess_input(
                param, idx=i,
                load_dur=not flag[0] and (flag[1] or flag[2]),
                load_pitch=not flag[1] and flag[2]
            ))

        out_dir.mkdir(parents=True, exist_ok=True)
        for i in range(num_runs):
            results = []
            for param, flag, batch in tqdm.tqdm(
                    zip(params, predictor_flags, batches), desc='infer segments', total=len(params)
            ):
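                # A per-segment 'seed' field overrides the global seed, if present.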
                if 'seed' in param:
                    torch.manual_seed(param["seed"] & 0xffff_ffff)
                    torch.cuda.manual_seed_all(param["seed"] & 0xffff_ffff)
                elif seed >= 0:
                    torch.manual_seed(seed & 0xffff_ffff)
                    torch.cuda.manual_seed_all(seed & 0xffff_ffff)

                param_copy = copy.deepcopy(param)
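
                # Temporarily override the model's predictor switches with this segment's
                # flags for the forward pass, then restore the saved values.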
                flag_saved = (
                    self.model.fs2.predict_dur,
                    self.model.predict_pitch,
                    self.model.predict_variances
                )
                (
                    self.model.fs2.predict_dur,
                    self.model.predict_pitch,
                    self.model.predict_variances
                ) = flag
                dur_pred, pitch_pred, variance_pred = self.forward_model(batch)
                (
                    self.model.fs2.predict_dur,
                    self.model.predict_pitch,
                    self.model.predict_variances
                ) = flag_saved
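
                # Write predictions back into the segment copy; manually provided values
                # are kept unless they were requested to be (re-)predicted.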
                if dur_pred is not None and (self.auto_completion_mode or self.global_predict_dur):
                    dur_pred = dur_pred[0].cpu().numpy()
                    param_copy['ph_dur'] = ' '.join(str(round(dur, 6)) for dur in (dur_pred * self.timestep).tolist())
                if pitch_pred is not None and (self.auto_completion_mode or self.global_predict_pitch):
                    pitch_pred = pitch_pred[0].cpu().numpy()
                    f0_pred = librosa.midi_to_hz(pitch_pred)
                    param_copy['f0_seq'] = ' '.join([str(round(freq, 1)) for freq in f0_pred.tolist()])
                    param_copy['f0_timestep'] = str(self.timestep)
                variance_pred = {
                    k: v[0].cpu().numpy()
                    for k, v in variance_pred.items()
                    if (self.auto_completion_mode and param.get(k) is None) or k in self.variance_prediction_set
                }
                for v_name, v_pred in variance_pred.items():
                    param_copy[v_name] = ' '.join([str(round(v, 4)) for v in v_pred.tolist()])
                    param_copy[f'{v_name}_timestep'] = str(self.timestep)

                # Restore ph_spk_mix and spk_mix
                if 'ph_spk_mix' in param_copy and 'spk_mix' in param_copy:
                    if 'ph_spk_mix_backup' in param_copy:
                        if param_copy['ph_spk_mix_backup'] is None:
                            del param_copy['ph_spk_mix']
                        else:
                            param_copy['ph_spk_mix'] = param_copy['ph_spk_mix_backup']
                        del param_copy['ph_spk_mix_backup']
                    if 'spk_mix_backup' in param_copy:
                        if param_copy['spk_mix_backup'] is None:
                            del param_copy['spk_mix']
                        else:
                            param_copy['spk_mix'] = param_copy['spk_mix_backup']
                        del param_copy['spk_mix_backup']

                results.append(param_copy)

            if num_runs > 1:
                filename = f'{title}-{str(i).zfill(3)}.ds'
            else:
                filename = f'{title}.ds'
            save_path = out_dir / filename
            with open(save_path, 'w', encoding='utf8') as f:
                print(f'| save params: {save_path}')
                json.dump(results, f, ensure_ascii=False, indent=2)
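
# A minimal usage sketch (hypothetical paths and file names; assumes hparams have already
# been populated by the caller, e.g. via the project's hparams loading utilities, and that
# a trained variance checkpoint exists under hparams['work_dir']):
#
#     with open('song.ds', 'r', encoding='utf8') as f:
#         params = json.load(f)  # a .ds file is a JSON list of segment dicts
#     infer_ins = DiffSingerVarianceInfer(device='cuda', predictions=set())  # auto-completion mode
#     infer_ins.run_inference(params, out_dir=pathlib.Path('output'), title='song')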