import os
from pathlib import Path

import gradio as gr
import librosa
import numpy as np
import paderbox as pb
import torch
from onnxruntime import InferenceSession

from creapy.utils import config
from pvq_manipulation.helper.creapy_wrapper import process_file
from pvq_manipulation.helper.vad import EnergyVAD
from pvq_manipulation.models.ffjord import FFJORD
from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER
from pvq_manipulation.models.vits import Vits_NT

torch.set_num_threads(os.cpu_count() or 1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

pvq_labels = ['Weight', 'Resonance', 'Breathiness', 'Roughness', 'Loudness', 'Strain', 'Pitch']
dataset_dict = pb.io.load_yaml('./Dataset/dataset.yaml')

# cache for the most recently processed example
cached_example_id = None
cached_loaded_example = None
cached_labels = None
cached_d_vector = None
cached_unmanipulated = None
cached_transcription = None

# path to speaker-embedding statistics (mean/std)
stats_path = Path('./Dataset/Embeddings/')

# load normalizing flow
storage_dir_normalizing_flow = Path("./models/norm_flow")
config_norm_flow = pb.io.load_yaml(storage_dir_normalizing_flow / "config.json")
normalizing_flow = FFJORD.load_model(storage_dir_normalizing_flow, checkpoint="model.pt", device=device)

# load tts model
storage_dir_tts = Path("./models/tts_model/")
tts_model = Vits_NT.load_model(storage_dir_tts, "model.pt")

# point creapy at the bundled configuration files
config._CONFIG_DIR = "./pvq_manipulation/helper/creapy_config.yaml"
config._USER_CONFIG_DIR = "./pvq_manipulation/helper/user_config.yaml"
config.USER_CONFIG_DIR = "./pvq_manipulation/helper/user_config.yaml"

# load hubert features model
hubert_model = HubertExtractor(
    layer=SID_LARGE_LAYER,
    model_name="HUBERT_LARGE",
    backend="torchaudio",
    device=device,
    # storage_dir=  # target storage dir hubert model
)

# load pvq models (one ONNX regression session per perceptual voice quality)
reg_stor_dir = Path('./models/pvq_extractor/')
onnx_sessions = {}
for pvq in pvq_labels:
    onnx_path = reg_stor_dir / f"{pvq}.onnx"
    onnx_sessions[pvq] = InferenceSession(
        str(onnx_path),
        providers=["CPUExecutionProvider"],
    )


def get_manipulation(
    example,
    labels,
    flow,
    tts_model,
    d_vector,
    config_norm_flow,
    manipulation_idx=0,
    manipulation_fkt=1,
):
    """Shift one PVQ label, map the speaker embedding through the flow and synthesize."""
    labels_manipulated = labels.clone()
    labels_manipulated[:, manipulation_idx] += manipulation_fkt

    if config_norm_flow['flag_remove_mean']:
        global_mean = torch.tensor(pb.io.load(stats_path / "mean.json"), dtype=torch.float32)
        global_std = torch.tensor(pb.io.load(stats_path / "std.json"), dtype=torch.float32)
        speaker_embedding_norm = (d_vector - global_mean) / global_std
    else:
        speaker_embedding_norm = d_vector

    output_forward = flow.forward((speaker_embedding_norm.float(), labels))[0]
    sampled_class_manipulated = flow.sample((output_forward, labels_manipulated))[0]

    if config_norm_flow['flag_remove_mean']:
        sampled_class_manipulated = sampled_class_manipulated * global_std + global_mean

    wav = tts_model.synthesize_from_example({
        'text': example['transcription'],
        'd_vector': d_vector.detach().numpy(),
        'd_vector_man': sampled_class_manipulated.detach().numpy(),
        'd_vector_storage_root': example['d_vector_storage_root'],
    })
    return wav


def get_creak_label(example):
    """Estimate the mean creak probability (in percent) with creapy."""
    audio_data = example['loaded_audio_data']['16_000']
    _, y_pred, included_indices = process_file(audio_data)
    mean_creak = np.mean(y_pred[included_indices])
    return mean_creak * 100


def load_speaker_labels(example):
    """Predict the PVQ labels (plus creak) for one example from HuBERT features."""
    audio_data = torch.tensor(example['loaded_audio_data']['16_000'], dtype=torch.float)[None, :]
    num_samples = torch.tensor([audio_data.shape[-1]])
    if torch.cuda.is_available():
        audio_data = audio_data.cuda()
        num_samples = num_samples.cuda()

    with torch.no_grad():
        features, seq_len = hubert_model(
            audio_data,
            16_000,
            sequence_lengths=num_samples,
        )
    features = np.mean(features.squeeze(0).detach().cpu().numpy(), axis=-1)

    pvqd_predictions = {}
    for pvq in pvq_labels:
        sess = onnx_sessions[pvq]
        pred = sess.run(None, {"X": features[None]})[0].squeeze(1)
        pvqd_predictions[pvq] = pred.tolist()[0]
    pvqd_predictions['Creak_mean'] = get_creak_label(example)

    labels = [pvqd_predictions[key] / 100 for key in pvq_labels + ["Creak_mean"]]
    return torch.tensor(labels, device=device).float()


def load_audio_files(example):
    """Load the observation, apply an energy VAD and store 16 kHz and 24 kHz versions."""
    observation_loaded, sr = pb.io.load_audio(example['audio_path']['observation'], return_sample_rate=True)
    example['loaded_audio_data'] = {}

    observation = librosa.resample(observation_loaded, orig_sr=sr, target_sr=16_000)
    vad = EnergyVAD(sample_rate=16_000)
    if observation.ndim == 1:
        observation = observation[None, :]
    observation = vad({'audio_data': observation})['audio_data']
    example['loaded_audio_data']['16_000'] = observation

    # the VAD output above is at 16 kHz, so resample from 16 kHz to 24 kHz
    observation = librosa.resample(observation, orig_sr=16_000, target_sr=24_000)
    vad = EnergyVAD(sample_rate=24_000)
    if observation.ndim == 1:
        observation = observation[None, :]
    observation = vad({'audio_data': observation})['audio_data']
    example['loaded_audio_data']['24_000'] = observation
    return example


def delete_cache():
    global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector, cached_unmanipulated
    cached_example_id = None
    cached_loaded_example = None
    cached_labels = None
    cached_d_vector = None
    cached_unmanipulated = None


def update_manipulation(manipulation_idx, example_id, transcription, manipulation_fkt):
    global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector, cached_unmanipulated, cached_transcription

    speaker_id = dataset_dict['dataset'][example_id]['speaker_id']
    example = {
        'audio_path': {'observation': f"./Dataset/Audio_files/{example_id}.wav"},
        'd_vector_storage_root': f"./Saved_models/Dataset/Embeddings/{speaker_id}/{example_id}.pth",
        'speaker_id': speaker_id,
        'example_id': example_id,
        'transcription': transcription,
    }

    # rebuild the cache when a different example is selected
    if cached_example_id != example_id:
        delete_cache()
        cached_loaded_example = load_audio_files(example)
        cached_d_vector = torch.load(f"./Dataset/Embeddings/{speaker_id}/{example_id}.pth")
        cached_labels = load_speaker_labels(example)
        cached_example_id = example_id
        with torch.no_grad():
            cached_unmanipulated = tts_model.synthesize_from_example({
                'text': transcription,
                'd_vector': cached_d_vector.detach().numpy(),
            })
        cached_transcription = transcription

    # re-synthesize the reference if the example or the transcription changed
    if cached_loaded_example != example or transcription != cached_transcription:
        with torch.no_grad():
            cached_unmanipulated = tts_model.synthesize_from_example({
                'text': transcription,
                'd_vector': cached_d_vector.detach().numpy(),
            })
        cached_transcription = transcription

    with torch.no_grad():
        wav_manipulated = get_manipulation(
            example=example,
            d_vector=cached_d_vector,
            labels=cached_labels[None, :],
            flow=normalizing_flow,
            tts_model=tts_model,
            manipulation_idx=manipulation_idx,
            manipulation_fkt=manipulation_fkt,
            config_norm_flow=config_norm_flow,
        )
    return (24_000, cached_unmanipulated), (24_000, wav_manipulated)


demo = gr.Interface(
    title="Perceptual Voice Quality (PVQ) Manipulation",
    fn=update_manipulation,
    inputs=[
        gr.Dropdown(
            label="PVQ Feature",
            choices=[('Weight', 0), ('Resonance', 1), ('Breathiness', 2), ('Roughness', 3), ('Creak', 7)],
            value=2,
            type="value",
        ),
        gr.Dropdown(
            label="Speaker",
            choices=[(str(idx), example_id) for idx, example_id in enumerate(dataset_dict['dataset'].keys())],
            value="1422_149735_000006_000000",
            type="value",
        ),
        gr.Textbox(
            label="Text Input",
            value="Department of Communications Engineering Paderborn University.",
            placeholder='Type something',
        ),
        gr.Slider(label="Manipulation Intensity", minimum=-1.0, maximum=2.0, value=1.0, step=0.1),
    ],
    outputs=[
        gr.Audio(label="Original synthesized utterance"),
        gr.Audio(label="Manipulated synthesized utterance"),
    ],
)

if __name__ == "__main__":
    demo.launch(share=True)
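
# Optional: a minimal sketch of programmatic use without the Gradio UI, assuming the
# same dataset and model paths as above. The example id and transcription are just the
# interface defaults; any id in ./Dataset/dataset.yaml should work the same way.
#
# (sr_ref, wav_ref), (sr_man, wav_man) = update_manipulation(
#     manipulation_idx=2,      # index 2 == 'Breathiness' in the PVQ dropdown
#     example_id="1422_149735_000006_000000",
#     transcription="Department of Communications Engineering Paderborn University.",
#     manipulation_fkt=1.0,    # slider value: manipulation intensity
# )
# pb.io.dump_audio(wav_man, "manipulated.wav", sample_rate=sr_man)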