import os
from pathlib import Path

import gradio as gr
import librosa
import numpy as np
import paderbox as pb
import torch
from onnxruntime import InferenceSession

from creapy.utils import config
from pvq_manipulation.helper.creapy_wrapper import process_file
from pvq_manipulation.helper.vad import EnergyVAD
from pvq_manipulation.models.ffjord import FFJORD
from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER
from pvq_manipulation.models.vits import Vits_NT

torch.set_num_threads(os.cpu_count() or 1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pvq_labels = ['Weight', 'Resonance', 'Breathiness', 'Roughness', 'Loudness', 'Strain', 'Pitch']
dataset_dict = pb.io.load_yaml('./Dataset/dataset.yaml')
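# module-level cache for the most recently selected example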
cached_example_id = None
cached_loaded_example = None
cached_labels = None
cached_d_vector = None
cached_unmanipulated = None
cached_transcription = None
# path to stats
stats_path = Path('./Dataset/Embeddings/')
# load normalizing flow
storage_dir_normalizing_flow = Path("./models/norm_flow")
config_norm_flow = pb.io.load_yaml(storage_dir_normalizing_flow / "config.json")
normalizing_flow = FFJORD.load_model(storage_dir_normalizing_flow, checkpoint="model.pt", device=device)
# load tts model
storage_dir_tts = Path("./models/tts_model/")
tts_model = Vits_NT.load_model(storage_dir_tts, "model.pt")
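# point creapy at the repository's configuration files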
config._CONFIG_DIR = "./pvq_manipulation/helper/creapy_config.yaml"
config._USER_CONFIG_DIR = "./pvq_manipulation/helper/user_config.yaml"
config.USER_CONFIG_DIR = "./pvq_manipulation/helper/user_config.yaml"
# load hubert features model
hubert_model = HubertExtractor(
    layer=SID_LARGE_LAYER,
    model_name="HUBERT_LARGE",
    backend="torchaudio",
    device=device,
    # storage_dir=  # target storage dir hubert model
)
# load pvq models
reg_stor_dir = Path('./models/pvq_extractor/')
onnx_sessions = {}
for pvq in pvq_labels:
    onnx_path = reg_stor_dir / f"{pvq}.onnx"
    onnx_sessions[pvq] = InferenceSession(
        str(onnx_path),
        providers=["CPUExecutionProvider"]
    )
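# map the speaker embedding into the flow's latent space with the original labels,
# shift the selected pvq label, map back, and synthesize with the manipulated embedding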
def get_manipulation(
    example,
    labels,
    flow,
    tts_model,
    d_vector,
    config_norm_flow,
    manipulation_idx=0,
    manipulation_fkt=1,
):
    labels_manipulated = labels.clone()
    labels_manipulated[:, manipulation_idx] += manipulation_fkt
    if config_norm_flow['flag_remove_mean']:
        global_mean = pb.io.load(stats_path / "mean.json")
        global_mean = torch.tensor(global_mean, dtype=torch.float32)
        speaker_embedding_norm = (d_vector - global_mean)
        global_std = pb.io.load(stats_path / "std.json")
        global_std = torch.tensor(global_std, dtype=torch.float32)
        speaker_embedding_norm = speaker_embedding_norm / global_std
    else:
        speaker_embedding_norm = d_vector
    output_forward = flow.forward((speaker_embedding_norm.float(), labels))[0]
    sampled_class_manipulated = flow.sample((output_forward, labels_manipulated))[0]
    if config_norm_flow['flag_remove_mean']:
        sampled_class_manipulated = (sampled_class_manipulated * global_std + global_mean)
    wav = tts_model.synthesize_from_example({
        'text': example['transcription'],
        'd_vector': d_vector.detach().numpy(),
        'd_vector_man': sampled_class_manipulated.detach().numpy(),
        'd_vector_storage_root': example['d_vector_storage_root'],
    })
    return wav
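# compute the mean creak prediction (in percent) with creapy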
def get_creak_label(example):
    audio_data = example['loaded_audio_data']['16_000']
    _, y_pred, included_indices = process_file(audio_data)
    mean_creak = np.mean(y_pred[included_indices])
    return mean_creak * 100
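# predict the pvq labels from hubert features with the onnx regressors and append the creak label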
def load_speaker_labels(example):
    audio_data = torch.tensor(example['loaded_audio_data']['16_000'], dtype=torch.float)[None, :]
    num_samples = torch.tensor([audio_data.shape[-1]])
    if torch.cuda.is_available():
        audio_data = audio_data.cuda()
        num_samples = num_samples.cuda()
    with torch.no_grad():
        features, seq_len = hubert_model(
            audio_data,
            16_000,
            sequence_lengths=num_samples,
        )
    features = np.mean(features.squeeze(0).detach().cpu().numpy(), axis=-1)
    pvqd_predictions = {}
    for pvq in pvq_labels:
        sess = onnx_sessions[pvq]
        pred = sess.run(None, {"X": features[None]})[0].squeeze(1)
        pvqd_predictions[pvq] = pred.tolist()[0]
    pvqd_predictions['Creak_mean'] = get_creak_label(example)
    labels = [pvqd_predictions[key] / 100 for key in pvq_labels + ["Creak_mean"]]
    return torch.tensor(labels, device=device).float()
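# load the audio file, resample to 16 khz and 24 khz, and remove silence with an energy vad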
def load_audio_files(example):
    observation_loaded, sr = pb.io.load_audio(example['audio_path']['observation'], return_sample_rate=True)
    example['loaded_audio_data'] = {}
    observation = librosa.resample(observation_loaded, orig_sr=sr, target_sr=16_000)
    vad = EnergyVAD(sample_rate=16_000)
    if observation.ndim == 1:
        observation = observation[None, :]
    observation = vad({'audio_data': observation})['audio_data']
    example['loaded_audio_data']['16_000'] = observation
    observation = librosa.resample(observation_loaded, orig_sr=sr, target_sr=24_000)
    vad = EnergyVAD(sample_rate=24_000)
    if observation.ndim == 1:
        observation = observation[None, :]
    observation = vad({'audio_data': observation})['audio_data']
    example['loaded_audio_data']['24_000'] = observation
    return example
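# reset the module-level cache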
def delete_cache():
    global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector, cached_unmanipulated
    cached_example_id = None
    cached_loaded_example = None
    cached_labels = None
    cached_d_vector = None
    cached_unmanipulated = None
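# gradio callback: synthesize the unmanipulated and the manipulated utterance for the
# selected speaker, transcription, pvq feature, and manipulation intensity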
def update_manipulation(manipulation_idx, example_id, transcription, manipulation_fkt):
    global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector, cached_unmanipulated, cached_transcription
    speaker_id = dataset_dict['dataset'][example_id]['speaker_id']
    example = {
        'audio_path': {'observation': f"./Dataset/Audio_files/{example_id}.wav"},
        'd_vector_storage_root': f"./Saved_models/Dataset/Embeddings/{speaker_id}/{example_id}.pth",
        'speaker_id': speaker_id,
        'example_id': example_id,
        'transcription': transcription
    }
    if cached_example_id != example_id:
        delete_cache()
        cached_loaded_example = load_audio_files(example)
        cached_d_vector = torch.load(f"./Dataset/Embeddings/{speaker_id}/{example_id}.pth")
        cached_labels = load_speaker_labels(example)
        cached_example_id = example_id
        with torch.no_grad():
            cached_unmanipulated = tts_model.synthesize_from_example({
                'text': transcription,
                'd_vector': cached_d_vector.detach().numpy(),
            })
        cached_transcription = transcription
    if transcription != cached_transcription:
        with torch.no_grad():
            cached_unmanipulated = tts_model.synthesize_from_example({
                'text': transcription,
                'd_vector': cached_d_vector.detach().numpy(),
            })
        cached_transcription = transcription
    with torch.no_grad():
        wav_manipulated = get_manipulation(
            example=example,
            d_vector=cached_d_vector,
            labels=cached_labels[None, :],
            flow=normalizing_flow,
            tts_model=tts_model,
            manipulation_idx=manipulation_idx,
            manipulation_fkt=manipulation_fkt,
            config_norm_flow=config_norm_flow,
        )
    return (24_000, cached_unmanipulated), (24_000, wav_manipulated)
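# build the gradio demo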
demo = gr.Interface(
    title="Perceptual Voice Quality (PVQ) Manipulation",
    fn=update_manipulation,
    inputs=[
        gr.Dropdown(
            label="PVQ Feature",
            choices=[('Weight', 0), ('Resonance', 1), ('Breathiness', 2), ('Roughness', 3), ('Creak', 7)],
            value=2, type="value"
        ),
        gr.Dropdown(
            label="Speaker",
            choices=[(str(idx), example_id) for idx, example_id in enumerate(dataset_dict['dataset'].keys())],
            value="1422_149735_000006_000000",
            type="value"
        ),
        gr.Textbox(
            label="Text Input",
            value="Department of Communications Engineering Paderborn University.",
            placeholder='Type something'
        ),
        gr.Slider(label="Manipulation Intensity", minimum=-1.0, maximum=2.0, value=1.0, step=0.1),
    ],
    outputs=[gr.Audio(label="original synthesized utterance"), gr.Audio(label="manipulated synthesized utterance")],
)
if __name__ == "__main__":
    demo.launch(share=True)