import os
from pathlib import Path

import gradio as gr
import librosa
import numpy as np
import paderbox as pb
import torch
from onnxruntime import InferenceSession

from creapy.utils import config
from pvq_manipulation.helper.creapy_wrapper import process_file
from pvq_manipulation.helper.vad import EnergyVAD
from pvq_manipulation.models.ffjord import FFJORD
from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER
from pvq_manipulation.models.vits import Vits_NT

torch.set_num_threads(os.cpu_count() or 1)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pvq_labels = ['Weight', 'Resonance', 'Breathiness', 'Roughness', 'Loudness', 'Strain', 'Pitch']

dataset_dict = pb.io.load_yaml('./Dataset/dataset.yaml')

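# cache for the most recently processed example (audio, PVQ labels, d-vector, unmanipulated synthesis)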
cached_example_id = None
cached_loaded_example = None
cached_labels = None
cached_d_vector = None
cached_unmanipulated = None
cached_transcription = None

# path to the d-vector normalization statistics (mean.json / std.json)
stats_path = Path('./Dataset/Embeddings/')

# load normalizing flow
storage_dir_normalizing_flow = Path("./models/norm_flow")
config_norm_flow = pb.io.load_yaml(storage_dir_normalizing_flow / "config.json")
normalizing_flow = FFJORD.load_model(storage_dir_normalizing_flow, checkpoint="model.pt", device=device)

# load tts model
storage_dir_tts = Path("./models/tts_model/")
tts_model = Vits_NT.load_model(storage_dir_tts, "model.pt")


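# point creapy to the bundled configuration files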
config._CONFIG_DIR = "./pvq_manipulation/helper/creapy_config.yaml"
config._USER_CONFIG_DIR = "./pvq_manipulation/helper/user_config.yaml"
config.USER_CONFIG_DIR = "./pvq_manipulation/helper/user_config.yaml"

# load HuBERT feature extractor
hubert_model = HubertExtractor(
    layer=SID_LARGE_LAYER,
    model_name="HUBERT_LARGE",
    backend="torchaudio",
    device=device,
    # storage_dir=...,  # optional: directory in which to store the HuBERT model
)

# load PVQ regressors (one ONNX model per PVQ dimension)
reg_stor_dir = Path('./models/pvq_extractor/')
onnx_sessions = {}
for pvq in pvq_labels:
    onnx_path = reg_stor_dir / f"{pvq}.onnx"
    onnx_sessions[pvq] = InferenceSession(
        str(onnx_path),
        providers=["CPUExecutionProvider"]
    )


def get_manipulation(
    example,
    labels,
    flow,
    tts_model,
    d_vector,
    config_norm_flow,
    manipulation_idx=0,
    manipulation_fkt=1,
):
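    """Shift one PVQ dimension of the label vector, transport the (optionally normalized) speaker
    embedding through the normalizing flow conditioned on the original and manipulated labels,
    and synthesize the text with the manipulated embedding."""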
    labels_manipulated = labels.clone()
    labels_manipulated[:, manipulation_idx] += manipulation_fkt

    if config_norm_flow['flag_remove_mean']:
        global_mean = pb.io.load(stats_path / "mean.json")
        global_mean = torch.tensor(global_mean, dtype=torch.float32)
        speaker_embedding_norm = (d_vector - global_mean)
        global_std = pb.io.load(stats_path / "std.json")
        global_std = torch.tensor(global_std, dtype=torch.float32)
        speaker_embedding_norm = speaker_embedding_norm / global_std
    else:
        speaker_embedding_norm = d_vector

    output_forward = flow.forward((speaker_embedding_norm.float(), labels))[0]
    sampled_class_manipulated = flow.sample((output_forward, labels_manipulated))[0]

    if config_norm_flow['flag_remove_mean']:
        sampled_class_manipulated = (sampled_class_manipulated * global_std + global_mean)

    wav = tts_model.synthesize_from_example({
        'text': example['transcription'],
        'd_vector': d_vector.detach().numpy(),
        'd_vector_man': sampled_class_manipulated.detach().numpy(),
        'd_vector_storage_root': example['d_vector_storage_root'],
    })
    return wav


def get_creak_label(example):
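    """Estimate the creak level of the 16 kHz audio with creapy and return the mean prediction in percent."""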
    audio_data = example['loaded_audio_data']['16_000']
    test, y_pred, included_indices = process_file(audio_data)
    mean_creak = np.mean(y_pred[included_indices])
    return mean_creak * 100


def load_speaker_labels(example):
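    """Predict the PVQ labels of the example: HuBERT features are fed to one ONNX regressor per PVQ,
    a creak estimate is appended, and all values are scaled to [0, 1]."""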
    audio_data = torch.tensor(example['loaded_audio_data']['16_000'], dtype=torch.float)[None, :]
    num_samples = torch.tensor([audio_data.shape[-1]])

    if torch.cuda.is_available():
        audio_data = audio_data.cuda()
        num_samples = num_samples.cuda()

    with torch.no_grad():
        features, seq_len = hubert_model(
            audio_data,
            16_000,
            sequence_lengths=num_samples,
        )
        features = np.mean(features.squeeze(0).detach().cpu().numpy(), axis=-1)
        pvqd_predictions = {}
        for pvq in pvq_labels:
            sess = onnx_sessions[pvq]
            pred = sess.run(None, {"X": features[None]})[0].squeeze(1)
            pvqd_predictions[pvq] = pred.tolist()[0]

    pvqd_predictions['Creak_mean'] = get_creak_label(example)
    labels = [pvqd_predictions[key] / 100 for key in pvq_labels + ["Creak_mean"]]
    return torch.tensor(labels, device=device).float()


def load_audio_files(example):
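    """Load the observation, create VAD-trimmed 16 kHz and 24 kHz versions,
    and store them under example['loaded_audio_data']."""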
    observation_loaded, sr = pb.io.load_audio(example['audio_path']['observation'], return_sample_rate=True)

    example['loaded_audio_data'] = {}
    observation = librosa.resample(observation_loaded, orig_sr=sr, target_sr=16_000)

    vad = EnergyVAD(sample_rate=16_000)
    if observation.ndim == 1:
        observation = observation[None, :]

    observation = vad({'audio_data': observation})['audio_data']
    example['loaded_audio_data']['16_000'] = observation

    # observation has already been resampled to 16 kHz above
    observation = librosa.resample(observation, orig_sr=16_000, target_sr=24_000)
    vad = EnergyVAD(sample_rate=24_000)
    if observation.ndim == 1:
        observation = observation[None, :]
    observation = vad({'audio_data': observation})['audio_data']
    example['loaded_audio_data']['24_000'] = observation
    return example


def delete_cache():
    """Reset the module-level cache so the next request reloads audio, labels, and embedding."""
    global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector, cached_unmanipulated, cached_transcription
    cached_example_id = None
    cached_loaded_example = None
    cached_labels = None
    cached_d_vector = None
    cached_unmanipulated = None
    cached_transcription = None


def update_manipulation(manipulation_idx, example_id, transcription, manipulation_fkt):
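    """Gradio callback: return the unmanipulated and the manipulated synthesized utterance for the
    selected example, text, PVQ dimension, and manipulation intensity. Results for the current
    example are cached."""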
    global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector, cached_unmanipulated, cached_transcription
    speaker_id = dataset_dict['dataset'][example_id]['speaker_id']

    example = {
        'audio_path': {'observation': f"./Dataset/Audio_files/{example_id}.wav"},
        'd_vector_storage_root': f"./Saved_models/Dataset/Embeddings/{speaker_id}/{example_id}.pth",
        'speaker_id': speaker_id,
        'example_id': example_id,
        'transcription': transcription
    }

    if cached_example_id != example_id:
        delete_cache()
        cached_loaded_example = load_audio_files(example)
        cached_d_vector = torch.load(f"./Dataset/Embeddings/{speaker_id}/{example_id}.pth")
        cached_labels = load_speaker_labels(example)
        cached_example_id = example_id
        with torch.no_grad():
            cached_unmanipulated = tts_model.synthesize_from_example({
                'text': transcription,
                'd_vector': cached_d_vector.detach().numpy(),
            })
        cached_transcription = transcription
    # re-synthesize the unmanipulated reference only if the transcription changed
    if transcription != cached_transcription:
        with torch.no_grad():
            cached_unmanipulated = tts_model.synthesize_from_example({
                'text': transcription,
                'd_vector': cached_d_vector.detach().numpy(),
            })
        cached_transcription = transcription

    with torch.no_grad():
        wav_manipulated = get_manipulation(
            example=example,
            d_vector=cached_d_vector,
            labels=cached_labels[None, :],
            flow=normalizing_flow,
            tts_model=tts_model,
            manipulation_idx=manipulation_idx,
            manipulation_fkt=manipulation_fkt,
            config_norm_flow=config_norm_flow,
        )
    return (24_000, cached_unmanipulated), (24_000, wav_manipulated)


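# Gradio interface: select a PVQ dimension, an example/speaker, a text prompt, and the manipulation intensity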
demo = gr.Interface(
    title="Perceptual Voice Quality (PVQ) Manipulation",
    fn=update_manipulation,
    inputs=[
        gr.Dropdown(
            label="PVQ Feature",
            choices=[('Weight', 0), ('Resonance', 1), ('Breathiness', 2), ('Roughness', 3), ('Creak', 7)],
            value=2, type="value"
        ),
        gr.Dropdown(
            label="Speaker",
            choices=[(str(idx), example_id) for idx, example_id in enumerate(dataset_dict['dataset'].keys())],
            value="1422_149735_000006_000000",
            type="value"
        ),
        gr.Textbox(
            label="Text Input",
            value="Department of Communications Engineering Paderborn University.",
            placeholder='Type something'
        ),
        gr.Slider(label="Manipulation Intensity", minimum=-1.0, maximum=2.0, value=1.0, step=0.1),
    ],
    outputs=[gr.Audio(label="original synthesized utterance"), gr.Audio(label="manipulated synthesized utterance")],
)
if __name__ == "__main__":
    demo.launch(share=True)