Commit 7efb86f · debug threads
1 parent: 091b3ba

app.py CHANGED
```diff
@@ -52,6 +52,16 @@ hubert_model = HubertExtractor(
     # storage_dir= # target storage dir hubert model
 )
 
+# load pvq models
+reg_stor_dir = Path('./models/pvq_extractor/')
+onnx_sessions = {}
+for pvq in pvq_labels:
+    onnx_path = reg_stor_dir / f"{pvq}.onnx"
+    onnx_sessions[pvq] = InferenceSession(
+        str(onnx_path),
+        providers=["CPUExecutionProvider"]
+    )
+
 
 def get_manipulation(
     example,
```
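This hunk moves ONNX model loading out of the request path: each `InferenceSession` is now built once at import time and reused. A minimal sketch of the pattern follows; the model path, the feature shape, and the PVQ label name are assumptions taken from the surrounding diff, not verified against the repo. `InferenceSession.run()` is safe to call concurrently from multiple threads, so a module-level session also holds up under a threaded web UI, which may be what the "debug threads" message refers to.

```python
# Sketch (not the commit's code): per-call session creation vs. a one-time
# module-level session, mirroring the pattern this hunk adopts.
import numpy as np
from onnxruntime import InferenceSession

MODEL_PATH = "models/pvq_extractor/breathiness.onnx"    # hypothetical PVQ model
features = np.random.rand(1, 1024).astype(np.float32)   # assumed feature shape

# Slow path: every call re-reads the file and re-runs graph optimization.
def predict_per_call(x: np.ndarray) -> np.ndarray:
    sess = InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
    return sess.run(None, {"X": x})[0]

# Fast path: construct once, call many times; run() is thread-safe.
SESSION = InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])

def predict_cached(x: np.ndarray) -> np.ndarray:
    return SESSION.run(None, {"X": x})[0]
```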
```diff
@@ -98,14 +108,13 @@ def get_creak_label(example):
     return mean_creak * 100
 
 
-def load_speaker_labels(example, reg_stor_dir=Path('./models/pvq_extractor/')):
+def load_speaker_labels(example):
    audio_data = torch.tensor(example['loaded_audio_data']['16_000'], dtype=torch.float)[None, :]
     num_samples = torch.tensor([audio_data.shape[-1]])
 
     if torch.cuda.is_available():
         audio_data = audio_data.cuda()
         num_samples = num_samples.cuda()
-    providers = ["CPUExecutionProvider"]
 
     with torch.no_grad():
         features, seq_len = hubert_model(
```
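Apart from dropping the now-unused `reg_stor_dir` and `providers`, the function is unchanged: it shapes a mono 16 kHz clip into a batch of one and moves both tensors to the GPU when available. A self-contained sketch of that pattern, with made-up audio values:

```python
# Sketch (illustrative values): batching with [None, :] and keeping the
# length tensor on the same device as the audio, so downstream model code
# never sees a CPU/GPU device mismatch.
import torch

clip = torch.tensor([0.01, -0.02, 0.03], dtype=torch.float)  # stand-in audio
audio_data = clip[None, :]                  # (T,) -> (1, T) batch of one
num_samples = torch.tensor([audio_data.shape[-1]])

if torch.cuda.is_available():               # opportunistic GPU placement
    audio_data = audio_data.cuda()
    num_samples = num_samples.cuda()

print(audio_data.shape)                     # torch.Size([1, 3])
```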
```diff
@@ -116,9 +125,7 @@ def load_speaker_labels(example, reg_stor_dir=Path('./models/pvq_extractor/')):
     features = np.mean(features.squeeze(0).detach().cpu().numpy(), axis=-1)
     pvqd_predictions = {}
     for pvq in pvq_labels:
-        with open(reg_stor_dir / f"{pvq}.onnx", "rb") as fid:
-            onnx = fid.read()
-        sess = InferenceSession(onnx, providers=providers)
+        sess = onnx_sessions[pvq]
         pred = sess.run(None, {"X": features[None]})[0].squeeze(1)
         pvqd_predictions[pvq] = pred.tolist()[0]
 
```
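The prediction loop still hard-codes the input name `"X"`. That is fine while the export is stable, but onnxruntime can report the graph's input and output names at runtime, which keeps the loop working if the models are ever re-exported. A sketch, with the file path and feature shape assumed:

```python
# Sketch (assumed paths/shapes): query I/O names instead of hard-coding "X".
import numpy as np
from onnxruntime import InferenceSession

sess = InferenceSession("models/pvq_extractor/breathiness.onnx",  # hypothetical
                        providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name      # "X" for this app's exports
output_name = sess.get_outputs()[0].name
x = np.random.rand(1, 1024).astype(np.float32)   # assumed feature shape
pred = sess.run([output_name], {input_name: x})[0]
```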
```diff
@@ -149,6 +156,15 @@ def load_audio_files(example):
     return example
 
 
+def delete_cache():
+    global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector, cached_unmanipulated
+    del cached_example_id
+    del cached_loaded_example
+    del cached_labels
+    del cached_d_vector
+    del cached_unmanipulated
+
+
 def update_manipulation(manipulation_idx, example_id, transcription, manipulation_fkt):
     global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector, example_database, cached_unmanipulated
 
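One caveat with the new `delete_cache()`: `del` on a global that was never assigned raises `NameError`, so the first call fails unless all five names are bound at module import. A hedged alternative (my sketch, not the commit's code) that tolerates unbound names:

```python
# Sketch: pop the cache names from globals() instead of del-ing them, which
# is a no-op for names that were never set. The dropped references let the
# GC reclaim the cached tensors just as del would.
_CACHE_NAMES = ("cached_example_id", "cached_loaded_example",
                "cached_labels", "cached_d_vector", "cached_unmanipulated")

def delete_cache_safe():
    for name in _CACHE_NAMES:
        globals().pop(name, None)   # no-op if the name was never bound
```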
```diff
@@ -163,25 +179,28 @@ def update_manipulation(manipulation_idx, example_id, transcription, manipulation_fkt):
     }
 
     if cached_example_id != example_id:
+        delete_cache()
         cached_loaded_example = load_audio_files(example)
         cached_d_vector = torch.load(f"./Dataset/Embeddings/{speaker_id}/{example_id}.pth")
         cached_labels = load_speaker_labels(example)
         cached_example_id = example_id
-        cached_unmanipulated = tts_model.synthesize_from_example({
-            'text': transcription,
-            'd_vector': cached_d_vector.detach().numpy(),
-        })
-
-    wav_manipulated = get_manipulation(
-        example=example,
-        d_vector=cached_d_vector,
-        labels=cached_labels[None, :],
-        flow=normalizing_flow,
-        tts_model=tts_model,
-        manipulation_idx=manipulation_idx,
-        manipulation_fkt=manipulation_fkt,
-        config_norm_flow=config_norm_flow,
-    )
+        with torch.no_grad():
+            cached_unmanipulated = tts_model.synthesize_from_example({
+                'text': transcription,
+                'd_vector': cached_d_vector.detach().numpy(),
+            })
+
+    with torch.no_grad():
+        wav_manipulated = get_manipulation(
+            example=example,
+            d_vector=cached_d_vector,
+            labels=cached_labels[None, :],
+            flow=normalizing_flow,
+            tts_model=tts_model,
+            manipulation_idx=manipulation_idx,
+            manipulation_fkt=manipulation_fkt,
+            config_norm_flow=config_norm_flow,
+        )
     return (24_000, cached_unmanipulated), (24_000, wav_manipulated)
 
 
```
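Wrapping each inference call site in its own `with torch.no_grad():` matters in a threaded server because PyTorch's grad mode is thread-local: a `no_grad()` entered in one thread does not carry over to worker threads. Connecting this to the "debug threads" commit message is an inference on my part, but the behavior itself is easy to demonstrate:

```python
# Sketch: grad mode is thread-local, so a worker thread runs with autograd
# ENABLED even while the main thread sits inside no_grad(). Each call site
# has to wrap itself, as this commit does.
import threading
import torch

x = torch.ones(1, requires_grad=True)
results = {}

def worker():
    results["worker"] = (x * 2).requires_grad   # True: fresh thread, grad on

with torch.no_grad():
    results["main"] = (x * 2).requires_grad     # False inside the context
    t = threading.Thread(target=worker)
    t.start()
    t.join()

print(results)   # {'main': False, 'worker': True}
```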