FrederikRautenberg committed on
Commit
7efb86f
·
1 Parent(s): 091b3ba

debug threads

Browse files
Files changed (1) hide show
  1. app.py +39 -20
app.py CHANGED
@@ -52,6 +52,16 @@ hubert_model = HubertExtractor(
52
  # storage_dir= # target storage dir hubert model
53
  )
54
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def get_manipulation(
57
  example,
@@ -98,14 +108,13 @@ def get_creak_label(example):
98
  return mean_creak * 100
99
 
100
 
101
- def load_speaker_labels(example, reg_stor_dir=Path('./models/pvq_extractor/')):
102
  audio_data = torch.tensor(example['loaded_audio_data']['16_000'], dtype=torch.float)[None, :]
103
  num_samples = torch.tensor([audio_data.shape[-1]])
104
 
105
  if torch.cuda.is_available():
106
  audio_data = audio_data.cuda()
107
  num_samples = num_samples.cuda()
108
- providers = ["CPUExecutionProvider"]
109
 
110
  with torch.no_grad():
111
  features, seq_len = hubert_model(
@@ -116,9 +125,7 @@ def load_speaker_labels(example, reg_stor_dir=Path('./models/pvq_extractor/')):
116
  features = np.mean(features.squeeze(0).detach().cpu().numpy(), axis=-1)
117
  pvqd_predictions = {}
118
  for pvq in pvq_labels:
119
- with open(reg_stor_dir / f"{pvq}.onnx", "rb") as fid:
120
- onnx = fid.read()
121
- sess = InferenceSession(onnx, providers=providers)
122
  pred = sess.run(None, {"X": features[None]})[0].squeeze(1)
123
  pvqd_predictions[pvq] = pred.tolist()[0]
124
 
@@ -149,6 +156,15 @@ def load_audio_files(example):
149
  return example
150
 
151
 
 
 
 
 
 
 
 
 
 
152
  def update_manipulation(manipulation_idx, example_id, transcription, manipulation_fkt):
153
  global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector, example_database, cached_unmanipulated
154
 
@@ -163,25 +179,28 @@ def update_manipulation(manipulation_idx, example_id, transcription, manipulatio
163
  }
164
 
165
  if cached_example_id != example_id:
 
166
  cached_loaded_example = load_audio_files(example)
167
  cached_d_vector = torch.load(f"./Dataset/Embeddings/{speaker_id}/{example_id}.pth")
168
  cached_labels = load_speaker_labels(example)
169
  cached_example_id = example_id
170
- cached_unmanipulated = tts_model.synthesize_from_example({
171
- 'text': transcription,
172
- 'd_vector': cached_d_vector.detach().numpy(),
173
- })
174
-
175
- wav_manipulated = get_manipulation(
176
- example=example,
177
- d_vector=cached_d_vector,
178
- labels=cached_labels[None, :],
179
- flow=normalizing_flow,
180
- tts_model=tts_model,
181
- manipulation_idx=manipulation_idx,
182
- manipulation_fkt=manipulation_fkt,
183
- config_norm_flow=config_norm_flow,
184
- )
 
 
185
  return (24_000, cached_unmanipulated), (24_000, wav_manipulated)
186
 
187
 
 
52
  # storage_dir= # target storage dir hubert model
53
  )
54
 
55
# Load the PVQ regressors once at import time: one CPU-only ONNX Runtime
# session per label in `pvq_labels`, keyed by that label, so that
# `load_speaker_labels` can reuse them instead of re-reading the model files
# on every call.
reg_stor_dir = Path('./models/pvq_extractor/')
onnx_sessions = {
    pvq: InferenceSession(
        str(reg_stor_dir / f"{pvq}.onnx"),
        providers=["CPUExecutionProvider"],
    )
    for pvq in pvq_labels
}
64
+
65
 
66
  def get_manipulation(
67
  example,
 
108
  return mean_creak * 100
109
 
110
 
111
+ def load_speaker_labels(example):
112
  audio_data = torch.tensor(example['loaded_audio_data']['16_000'], dtype=torch.float)[None, :]
113
  num_samples = torch.tensor([audio_data.shape[-1]])
114
 
115
  if torch.cuda.is_available():
116
  audio_data = audio_data.cuda()
117
  num_samples = num_samples.cuda()
 
118
 
119
  with torch.no_grad():
120
  features, seq_len = hubert_model(
 
125
  features = np.mean(features.squeeze(0).detach().cpu().numpy(), axis=-1)
126
  pvqd_predictions = {}
127
  for pvq in pvq_labels:
128
+ sess = onnx_sessions[pvq]
 
 
129
  pred = sess.run(None, {"X": features[None]})[0].squeeze(1)
130
  pvqd_predictions[pvq] = pred.tolist()[0]
131
 
 
156
  return example
157
 
158
 
159
def delete_cache():
    """Drop all per-example cached objects held in module globals.

    Called by ``update_manipulation`` when a different example is requested,
    before the new example is loaded, so the previous example's data can be
    released first.

    The globals are rebound to ``None`` rather than removed with ``del``.
    Rebinding releases the old objects just the same, but keeps the names
    defined: if reloading fails part-way, the next
    ``cached_example_id != example_id`` check still works instead of raising
    ``NameError``, and calling this function twice in a row is harmless.
    """
    global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector, cached_unmanipulated
    cached_example_id = None
    cached_loaded_example = None
    cached_labels = None
    cached_d_vector = None
    cached_unmanipulated = None
166
+
167
+
168
  def update_manipulation(manipulation_idx, example_id, transcription, manipulation_fkt):
169
  global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector, example_database, cached_unmanipulated
170
 
 
179
  }
180
 
181
  if cached_example_id != example_id:
182
+ delete_cache()
183
  cached_loaded_example = load_audio_files(example)
184
  cached_d_vector = torch.load(f"./Dataset/Embeddings/{speaker_id}/{example_id}.pth")
185
  cached_labels = load_speaker_labels(example)
186
  cached_example_id = example_id
187
+ with torch.no_grad():
188
+ cached_unmanipulated = tts_model.synthesize_from_example({
189
+ 'text': transcription,
190
+ 'd_vector': cached_d_vector.detach().numpy(),
191
+ })
192
+
193
+ with torch.no_grad():
194
+ wav_manipulated = get_manipulation(
195
+ example=example,
196
+ d_vector=cached_d_vector,
197
+ labels=cached_labels[None, :],
198
+ flow=normalizing_flow,
199
+ tts_model=tts_model,
200
+ manipulation_idx=manipulation_idx,
201
+ manipulation_fkt=manipulation_fkt,
202
+ config_norm_flow=config_norm_flow,
203
+ )
204
  return (24_000, cached_unmanipulated), (24_000, wav_manipulated)
205
 
206