alexnasa commited on
Commit
ce045a7
·
verified ·
1 Parent(s): bd73c78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -1
app.py CHANGED
@@ -220,6 +220,27 @@ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p,
220
  audio_path, transcribe_state, transcript, smart_transcript,
221
  mode, prompt_end_time, edit_start_time, edit_end_time,
222
  split_text, selected_sentence, previous_audio_tensors):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  if voicecraft_model is None:
224
  raise gr.Error("VoiceCraft model not loaded")
225
  if smart_transcript and (transcribe_state is None):
@@ -617,4 +638,4 @@ if __name__ == "__main__":
617
  MODELS_PATH = args.models_path
618
 
619
  app = get_app()
620
- app.queue().launch(share=args.share, server_port=args.port)
 
220
  audio_path, transcribe_state, transcript, smart_transcript,
221
  mode, prompt_end_time, edit_start_time, edit_end_time,
222
  split_text, selected_sentence, previous_audio_tensors):
223
+
224
+ voicecraft_model_name = "830M_TTSEnhanced"
225
+
226
+ voicecraft_name = f"{voicecraft_model_name}.pth"
227
+ model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
228
+ phn2num = model.args.phn2num
229
+ config = model.args
230
+ model.to(device)
231
+
232
+ encodec_fn = f"{MODELS_PATH}/encodec_4cb2048_giga.th"
233
+ if not os.path.exists(encodec_fn):
234
+ os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th -O " + encodec_fn)
235
+
236
+ voicecraft_model = {
237
+ "config": config,
238
+ "phn2num": phn2num,
239
+ "model": model,
240
+ "text_tokenizer": TextTokenizer(backend="espeak"),
241
+ "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
242
+ }
243
+
244
  if voicecraft_model is None:
245
  raise gr.Error("VoiceCraft model not loaded")
246
  if smart_transcript and (transcribe_state is None):
 
638
  MODELS_PATH = args.models_path
639
 
640
  app = get_app()
641
+ app.queue().launch(share=args.share, server_port=args.port)