jdana committed · verified
Commit aaad0fc · 1 Parent(s): a0c5946

Update app.py

Files changed (1)
  1. app.py +51 -24
app.py CHANGED
@@ -30,6 +30,8 @@ from f5_tts.infer.utils_infer import (
     infer_process,
 )
 
+import torch  # Added missing import
+
 try:
     import spaces
     USING_SPACES = True
@@ -58,7 +60,10 @@ def load_f5tts(ckpt_path=None):
         "text_dim": 512,
         "conv_layers": 4
     }
-    return load_model(DiT, model_cfg, ckpt_path)
+    model = load_model(DiT, model_cfg, ckpt_path)
+    model.eval()  # Ensure the model is in evaluation mode
+    model.to('cuda')  # Move model to GPU
+    return model
 
 F5TTS_ema_model = load_f5tts()
 
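Note: the new `model.to('cuda')` assumes a GPU is present and raises a RuntimeError on a CPU-only host. A minimal device-aware sketch, reusing the same `load_model`/`DiT` imports the file already has; the dim/depth/heads/ff_mult values are assumed from the F5-TTS base config (only text_dim and conv_layers are visible in this hunk), and the CPU fallback is illustrative, not part of this commit:

```python
import torch
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import load_model

def load_f5tts(ckpt_path=None):
    # text_dim/conv_layers match the hunk above; the rest is the assumed base config
    model_cfg = {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2,
                 "text_dim": 512, "conv_layers": 4}
    model = load_model(DiT, model_cfg, ckpt_path)
    model.eval()
    # Fall back to CPU when no GPU is available instead of hardcoding 'cuda'
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model.to(device)
```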
@@ -67,24 +72,33 @@ chat_tokenizer_state = None
 
 @gpu_decorator
 def generate_response(messages, model, tokenizer):
-    """Generate a response using the provided model and tokenizer."""
+    """Generate a response using the provided model and tokenizer with full precision."""
     text = tokenizer.apply_chat_template(
         messages,
         tokenize=False,
         add_generation_prompt=True,
     )
 
+    # Tokenizer and model input preparation
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-    generated_ids = model.generate(
-        input_features=model_inputs.input_features,
-        max_new_tokens=512,
-        temperature=0.7,
-        top_p=0.95,
-    )
+
+    # Use full precision for higher audio quality
+    with torch.no_grad():
+        # Ensure full precision by disabling autocast if necessary
+        # Assuming infer_process handles precision internally
+        generated_ids = model.generate(
+            input_ids=model_inputs.input_ids,
+            max_new_tokens=1024,
+            temperature=0.5,
+            top_p=0.9,
+            do_sample=True,  # Enable sampling for more natural responses
+            repetition_penalty=1.2,  # Prevent repetition
+        )
 
     if not generated_ids:
         raise ValueError("No generated IDs returned by the model.")
 
+    # Post-processing the generated IDs
     generated_ids = [
         output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
     ]
@@ -92,6 +106,7 @@ def generate_response(messages, model, tokenizer):
     if not generated_ids or not generated_ids[0]:
         raise ValueError("Generated IDs are empty after processing.")
 
+    # Decode and return the response
     return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 def extract_metadata_and_cover(ebook_path):
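Two details worth noting in this hunk: the old call passed `input_features` (an argument for speech models such as Whisper) where a causal LM expects `input_ids`, and it set `temperature`/`top_p` without `do_sample=True`, in which case Transformers ignores them and decodes greedily. With the fixed signature, a hypothetical usage sketch; the model name is only an example, not what this Space loads:

```python
# Any Hugging Face chat model with a chat template works with generate_response;
# Qwen2.5-0.5B-Instruct is an illustrative placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(name)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize chapter one in two sentences."},
]
print(generate_response(messages, model, tokenizer))
```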
@@ -218,7 +233,7 @@ def show_converted_audiobooks():
     return [os.path.join(output_dir, f) for f in files]
 
 @gpu_decorator
-def infer(ref_audio_orig, ref_text, gen_text, cross_fade_duration=0.15, speed=1, show_info=gr.Info, progress=gr.Progress()):
+def infer(ref_audio_orig, ref_text, gen_text, cross_fade_duration=0.0, speed=1, show_info=gr.Info, progress=gr.Progress()):
     """Perform inference to generate audio from text."""
     try:
         ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
@@ -229,17 +244,19 @@ def infer(ref_audio_orig, ref_text, gen_text, cross_fade_duration=0.15, speed=1, show_info=gr.Info, progress=gr.Progress()):
         raise ValueError("Generated text is empty. Please provide valid text content.")
 
     try:
-        final_wave, final_sample_rate, _ = infer_process(
-            ref_audio,
-            ref_text,
-            gen_text,
-            F5TTS_ema_model,
-            vocoder,
-            cross_fade_duration=cross_fade_duration,
-            speed=speed,
-            show_info=show_info,
-            progress=progress,  # Pass progress here
-        )
+        # Ensure inference is in full precision
+        with torch.no_grad():
+            final_wave, final_sample_rate, _ = infer_process(
+                ref_audio,
+                ref_text,
+                gen_text,
+                F5TTS_ema_model,
+                vocoder,
+                cross_fade_duration=cross_fade_duration,
+                speed=speed,
+                show_info=show_info,
+                progress=progress,  # Pass progress here
+            )
     except Exception as e:
         raise RuntimeError(f"Error during inference process: {e}")
 
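One caveat on the comment in this hunk: `torch.no_grad()` disables gradient tracking, which cuts memory use during inference, but it does not by itself force full precision. If autocast has to be ruled out explicitly, a sketch like this would do it (illustrative only; assumes a CUDA device):

```python
# no_grad() stops autograd bookkeeping; autocast(enabled=False) additionally
# guarantees ops run in the tensors' native dtype (e.g. float32) on CUDA.
with torch.no_grad(), torch.autocast(device_type="cuda", enabled=False):
    final_wave, final_sample_rate, _ = infer_process(
        ref_audio, ref_text, gen_text, F5TTS_ema_model, vocoder,
        cross_fade_duration=cross_fade_duration, speed=speed,
        show_info=show_info, progress=progress,
    )
```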
@@ -284,7 +301,8 @@ def basic_tts(ref_audio_input, ref_text_input, gen_file_input, cross_fade_duration
     progress(0.8, desc="Stitching audio files")
     sample_rate, wave = audio_out
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_wav:
-        sf.write(tmp_wav.name, wave, sample_rate)
+        # Save WAV with higher bit depth and sample rate if possible
+        sf.write(tmp_wav.name, wave, sample_rate, subtype='PCM_24')
         tmp_wav_path = tmp_wav.name
 
     progress(0.9, desc="Converting to MP3")
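A clarification on the comment above: `soundfile` writes WAV as 16-bit PCM by default, so `subtype='PCM_24'` raises the bit depth to 24 bits, but the sample rate is unchanged (it stays whatever `sample_rate` holds). A quick way to inspect what the installed libsndfile supports:

```python
import soundfile as sf

# The default subtype for WAV (normally 'PCM_16') and all available subtypes
print(sf.default_subtype("WAV"))
print(sf.available_subtypes("WAV"))
```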
@@ -292,12 +310,21 @@ def basic_tts(ref_audio_input, ref_text_input, gen_file_input, cross_fade_duration
     tmp_mp3_path = os.path.join("Working_files", "Book", f"{sanitized_title}.mp3")
     ensure_directory(os.path.dirname(tmp_mp3_path))
 
+    # Load WAV with Pydub
     audio = AudioSegment.from_wav(tmp_wav_path)
-    audio.export(tmp_mp3_path, format="mp3", bitrate="256k")
+
+    # Export to MP3 with higher bitrate and quality settings
+    audio.export(
+        tmp_mp3_path,
+        format="mp3",
+        bitrate="320k",
+        parameters=["-q:a", "0"]  # Highest quality for VBR
+    )
 
     if cover_image:
         embed_cover_into_mp3(tmp_mp3_path, cover_image)
 
+    # Clean up temporary files
     os.remove(tmp_wav_path)
     if cover_image and os.path.exists(cover_image):
         os.remove(cover_image)
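pydub shells out to ffmpeg here: `bitrate="320k"` becomes `-b:a 320k` and the extra `parameters` are appended to the command line, so mixing a fixed bitrate with `-q:a 0` (LAME's VBR quality scale) is redundant; ffmpeg will generally honor whichever option it parses last. A sketch of the two modes kept separate (filenames are placeholders):

```python
from pydub import AudioSegment

audio = AudioSegment.from_wav("input.wav")

# Constant bitrate: predictable file size
audio.export("cbr.mp3", format="mp3", bitrate="320k")

# Variable bitrate: -q:a 0 is LAME's highest-quality VBR tier
audio.export("vbr.mp3", format="mp3", parameters=["-q:a", "0"])
```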
@@ -353,7 +380,7 @@ def create_gradio_app():
                 label="Cross-Fade Duration (Between Generated Audio Chunks)",
                 minimum=0.0,
                 maximum=1.0,
-                value=0.15,
+                value=0.0,
                 step=0.01,
             )
 
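For intuition on what this slider controls: `cross_fade_duration` blends the tail of one generated chunk into the head of the next inside `infer_process`, and the new default of 0.0 means chunks are concatenated with no blending. A pydub illustration of the same idea (filenames are placeholders; this is not the app's actual code path):

```python
from pydub import AudioSegment

a = AudioSegment.from_wav("chunk1.wav")
b = AudioSegment.from_wav("chunk2.wav")

# crossfade is in milliseconds: 0 is a hard cut, 150 overlaps 0.15 s
joined = a.append(b, crossfade=150)
joined.export("stitched.wav", format="wav")
```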
@@ -396,7 +423,7 @@ def main(port, host, share, api):
     app.queue().launch(
         server_name="0.0.0.0",
         server_port=port or 7860,
-        share=True,
+        share=share,
         show_api=api,
         debug=True
     )
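This last hunk makes the launch call honor the `share` argument passed into `main`; previously `share=True` always requested a public Gradio tunnel regardless of the CLI flag. For reference, a minimal sketch of the flag's effect (not the app's full `main`):

```python
import gradio as gr

demo = gr.Interface(fn=lambda s: s, inputs="text", outputs="text")
# share=False serves only on the local network interface;
# share=True asks Gradio for a temporary public *.gradio.live URL.
demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
```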
 