Update app.py
app.py CHANGED
@@ -30,6 +30,8 @@ from f5_tts.infer.utils_infer import (
     infer_process,
 )
 
+import torch  # Added missing import
+
 try:
     import spaces
     USING_SPACES = True
@@ -58,7 +60,10 @@ def load_f5tts(ckpt_path=None):
         "text_dim": 512,
         "conv_layers": 4
     }
-
+    model = load_model(DiT, model_cfg, ckpt_path)
+    model.eval()  # Ensure the model is in evaluation mode
+    model.to('cuda')  # Move model to GPU
+    return model
 
 F5TTS_ema_model = load_f5tts()
 
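Reviewer note: `model.to('cuda')` assumes a GPU is always present and will raise on a CPU-only machine. A minimal sketch of a device fallback (an assumption, not part of this commit):

import torch

def to_best_device(model: torch.nn.Module) -> torch.nn.Module:
    # Prefer the GPU when available, but keep CPU-only machines working.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model.to(device)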
@@ -67,24 +72,33 @@ chat_tokenizer_state = None
 
 @gpu_decorator
 def generate_response(messages, model, tokenizer):
-    """Generate a response using the provided model and tokenizer."""
+    """Generate a response using the provided model and tokenizer with full precision."""
     text = tokenizer.apply_chat_template(
         messages,
         tokenize=False,
         add_generation_prompt=True,
     )
 
+    # Tokenizer and model input preparation
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-
-
-
-
-
-
+
+    # Use full precision for higher audio quality
+    with torch.no_grad():
+        # Ensure full precision by disabling autocast if necessary
+        # Assuming infer_process handles precision internally
+        generated_ids = model.generate(
+            input_ids=model_inputs.input_ids,
+            max_new_tokens=1024,
+            temperature=0.5,
+            top_p=0.9,
+            do_sample=True,  # Enable sampling for more natural responses
+            repetition_penalty=1.2,  # Prevent repetition
+        )
 
     if not generated_ids:
         raise ValueError("No generated IDs returned by the model.")
 
+    # Post-processing the generated IDs
     generated_ids = [
         output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
     ]
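For context, a hedged usage sketch of the updated `generate_response`; `chat_model` and `chat_tokenizer` are placeholder names, not identifiers from this diff:

messages = [
    {"role": "system", "content": "You are a helpful narrator."},
    {"role": "user", "content": "Introduce chapter one in two sentences."},
]
reply = generate_response(messages, chat_model, chat_tokenizer)  # placeholder globals
print(reply)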
@@ -92,6 +106,7 @@ def generate_response(messages, model, tokenizer):
     if not generated_ids or not generated_ids[0]:
         raise ValueError("Generated IDs are empty after processing.")
 
+    # Decode and return the response
     return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 def extract_metadata_and_cover(ebook_path):
@@ -218,7 +233,7 @@ def show_converted_audiobooks():
     return [os.path.join(output_dir, f) for f in files]
 
 @gpu_decorator
-def infer(ref_audio_orig, ref_text, gen_text, cross_fade_duration=0.15, speed=1, show_info=gr.Info, progress=gr.Progress()):
+def infer(ref_audio_orig, ref_text, gen_text, cross_fade_duration=0.0, speed=1, show_info=gr.Info, progress=gr.Progress()):
     """Perform inference to generate audio from text."""
     try:
         ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
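The default `cross_fade_duration` drops from 0.15 to 0.0, meaning chunks are now joined with no overlap. Purely as an illustration of what the parameter controls (not the repo's implementation), a linear cross-fade between two chunks looks roughly like:

import numpy as np

def crossfade(a: np.ndarray, b: np.ndarray, sr: int, duration: float) -> np.ndarray:
    # Number of overlapping samples; clamp so short chunks still work.
    n = min(int(sr * duration), len(a), len(b))
    if n == 0:
        return np.concatenate([a, b])  # duration 0.0 = plain concatenation
    fade = np.linspace(0.0, 1.0, n)
    mixed = a[-n:] * (1.0 - fade) + b[:n] * fade
    return np.concatenate([a[:-n], mixed, b[n:]])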
@@ -229,17 +244,19 @@ def infer(ref_audio_orig, ref_text, gen_text, cross_fade_duration=0.15, speed=1,
         raise ValueError("Generated text is empty. Please provide valid text content.")
 
     try:
-
-
-
-
-
-
-
-
-
-
-
+        # Ensure inference is in full precision
+        with torch.no_grad():
+            final_wave, final_sample_rate, _ = infer_process(
+                ref_audio,
+                ref_text,
+                gen_text,
+                F5TTS_ema_model,
+                vocoder,
+                cross_fade_duration=cross_fade_duration,
+                speed=speed,
+                show_info=show_info,
+                progress=progress,  # Pass progress here
+            )
     except Exception as e:
         raise RuntimeError(f"Error during inference process: {e}")
 
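One caveat on the "Ensure inference is in full precision" comment: `torch.no_grad()` only disables gradient tracking; it does not control numeric precision. If fp32 math is the actual goal, autocast would have to be disabled explicitly, e.g. (a sketch, assuming a CUDA device):

import torch

def run_fp32(fn, *args, **kwargs):
    # no_grad() saves memory; disabling autocast forces float32 math.
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=False):
        return fn(*args, **kwargs)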
@@ -284,7 +301,8 @@ def basic_tts(ref_audio_input, ref_text_input, gen_file_input, cross_fade_durati
     progress(0.8, desc="Stitching audio files")
     sample_rate, wave = audio_out
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_wav:
-
+        # Save WAV with higher bit depth and sample rate if possible
+        sf.write(tmp_wav.name, wave, sample_rate, subtype='PCM_24')
         tmp_wav_path = tmp_wav.name
 
     progress(0.9, desc="Converting to MP3")
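`sf.write(..., subtype='PCM_24')` writes 24-bit PCM; soundfile expects a float array in [-1.0, 1.0] (or an integer array) here, and the sample rate is whatever `infer_process` returned. A self-contained sketch with assumed values:

import numpy as np
import soundfile as sf

sr = 24000                             # assumed sample rate
wave = np.zeros(sr, dtype=np.float32)  # one second of silence, in [-1, 1]
sf.write("example.wav", wave, sr, subtype="PCM_24")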
@@ -292,12 +310,21 @@ def basic_tts(ref_audio_input, ref_text_input, gen_file_input, cross_fade_durati
     tmp_mp3_path = os.path.join("Working_files", "Book", f"{sanitized_title}.mp3")
     ensure_directory(os.path.dirname(tmp_mp3_path))
 
+    # Load WAV with Pydub
     audio = AudioSegment.from_wav(tmp_wav_path)
-
+
+    # Export to MP3 with higher bitrate and quality settings
+    audio.export(
+        tmp_mp3_path,
+        format="mp3",
+        bitrate="320k",
+        parameters=["-q:a", "0"]  # Highest quality for VBR
+    )
 
     if cover_image:
         embed_cover_into_mp3(tmp_mp3_path, cover_image)
 
+    # Clean up temporary files
     os.remove(tmp_wav_path)
     if cover_image and os.path.exists(cover_image):
         os.remove(cover_image)
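Note that the export passes both `bitrate="320k"` (constant bitrate) and `parameters=["-q:a", "0"]` (LAME VBR), which are conflicting modes; picking one explicitly would be clearer. A sketch with a hypothetical input path:

from pydub import AudioSegment

audio = AudioSegment.from_wav("input.wav")  # hypothetical path
# Either constant 320 kbps...
audio.export("out_cbr.mp3", format="mp3", bitrate="320k")
# ...or highest-quality VBR:
audio.export("out_vbr.mp3", format="mp3", parameters=["-q:a", "0"])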
@@ -353,7 +380,7 @@ def create_gradio_app():
                 label="Cross-Fade Duration (Between Generated Audio Chunks)",
                 minimum=0.0,
                 maximum=1.0,
-                value=0.15,
+                value=0.0,
                 step=0.01,
             )
 
@@ -396,7 +423,7 @@ def main(port, host, share, api):
     app.queue().launch(
         server_name="0.0.0.0",
         server_port=port or 7860,
-        share=
+        share=share,
         show_api=api,
         debug=True
     )