Update webui.py
Browse files — back to Gradio UI
webui.py
CHANGED
@@ -1,194 +1,82 @@
|
|
|
|
1 |
import os
|
2 |
-
import
|
|
|
3 |
import time
|
4 |
-
|
|
|
5 |
from huggingface_hub import snapshot_download
|
6 |
-
import streamlit as st
|
7 |
-
|
8 |
-
# ----------------------- Critical Path Configuration --------------------------
# Make the local package importable regardless of the working directory the
# app is launched from.
# NOTE(review): relies on `Path` and `sys` being imported earlier in the file —
# the import block above appears truncated; confirm both imports exist.
current_dir = Path(__file__).parent.resolve()  # Get absolute path to current file
sys.path.insert(0, str(current_dir))  # Add current directory to Python path
sys.path.insert(1, str(current_dir / "indextts"))  # Add indextts package
sys.path.insert(2, str(current_dir.parent))  # Add parent directory for utils

# Fail fast with a visible Streamlit error instead of an unhandled traceback
# when the indextts package cannot be resolved.
try:
    from indextts.infer import IndexTTS
except ModuleNotFoundError as e:
    st.error(f"Module import error: {str(e)}")
    st.stop()
|
19 |
-
|
20 |
-
# ----------------------- Rest of Your Original Code ---------------------------
CHECKPOINT_DIR = "checkpoints"  # destination for the downloaded model snapshot
OUTPUT_DIR = "outputs"          # generated .wav files are written here
PROMPTS_DIR = "prompts"         # uploaded reference clips

# Ensure necessary directories exist. Hugging Face Spaces provides a writable filesystem.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(PROMPTS_DIR, exist_ok=True)

MODEL_REPO = "IndexTeam/IndexTTS-1.5"  # Hugging Face Hub repo id of the model
CFG_FILENAME = "config.yaml"           # model config file inside the snapshot
|
32 |
-
|
33 |
-
# ------------------------------------------------------------------------------
# Model loading (cached so it only runs once per resource identifier)
# ------------------------------------------------------------------------------

# @st.cache_resource is the recommended way in Streamlit to cache large objects
# like ML models that should be loaded only once.
# This is crucial for efficiency on platforms like Spaces, preventing re-loading
# the model on every user interaction/script re-run.
@st.cache_resource(show_spinner=False)
def load_tts_model():
    """
    Download the IndexTTS model snapshot and initialize the TTS engine.

    Cached with ``st.cache_resource`` so the body executes only once per
    process, not on every Streamlit re-run.

    Returns:
        The initialized IndexTTS instance, with its text normalizer loaded.
    """
    st.write("⏳ Loading model... This may take a moment.")
    # Download the model snapshot if not already present
    # local_dir_use_symlinks=False is often safer in containerized environments
    snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=CHECKPOINT_DIR,
        local_dir_use_symlinks=False,
    )
    # Initialize the TTS object
    # The underlying IndexTTS library should handle using the GPU if available
    # and if dependencies (like CUDA-enabled PyTorch/TensorFlow) are installed.
    tts = IndexTTS(
        model_dir=CHECKPOINT_DIR,
        cfg_path=os.path.join(CHECKPOINT_DIR, CFG_FILENAME)
    )
    # Load any normalizer or auxiliary data required by the model
    tts.load_normalizer()
    st.write("✅ Model loaded!")
    return tts
|
66 |
-
|
67 |
-
# Load the TTS model using the cached function
# This line is executed on each script run, but the function body only runs
# the first time or if the function signature/dependencies change.
tts = load_tts_model()
|
71 |
-
|
72 |
-
# ------------------------------------------------------------------------------
# Inference function
# ------------------------------------------------------------------------------

def run_inference(reference_audio_path: str, text: str) -> str:
    """
    Run TTS inference using the uploaded reference audio and the target text.

    Args:
        reference_audio_path: Path to the uploaded reference audio file.
        text: Target text to synthesize.

    Returns:
        Path to the generated .wav file inside ``OUTPUT_DIR``.

    Raises:
        FileNotFoundError: If ``reference_audio_path`` does not exist.
    """
    if not os.path.exists(reference_audio_path):
        raise FileNotFoundError(f"Reference audio not found at {reference_audio_path}")

    # Generate a unique output filename (second-resolution timestamp)
    timestamp = int(time.time())
    output_filename = f"generated_{timestamp}.wav"
    output_path = os.path.join(OUTPUT_DIR, output_filename)

    # Perform the TTS inference
    # The efficiency of this step depends on the IndexTTS library and hardware
    tts.infer(reference_audio_path, text, output_path)

    # Optional: Clean up old files in output/prompts directories if space is limited
    # This can be added if you find directories filling up on Spaces.
    # E.g., a function to remove files older than X hours/days.
    # For a simple demo, may not be necessary initially.
    return output_path
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
with
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
st.audio(ref_path, format="audio/wav") # Display the uploaded audio
|
149 |
-
|
150 |
-
|
151 |
-
st.header("2. Enter Text to Synthesize")
|
152 |
-
text_input = st.text_area(
|
153 |
-
label="Enter the text you want to convert to speech",
|
154 |
-
placeholder="Type your sentence here...",
|
155 |
-
key="text_input_area" # Added a key
|
156 |
-
)
|
157 |
-
|
158 |
-
# Button to trigger generation
|
159 |
-
generate_button = st.button("Generate Speech", key="generate_tts_button")
|
160 |
-
|
161 |
-
# ------------------------------------------------------------------------------
|
162 |
-
# Trigger Inference and Display Results
|
163 |
-
# ------------------------------------------------------------------------------
|
164 |
-
|
165 |
-
# This block runs only when the button is clicked AND inputs are valid
|
166 |
-
if generate_button:
|
167 |
-
if not ref_path or not os.path.exists(ref_path):
|
168 |
-
st.error("Please upload a reference audio first.")
|
169 |
-
elif not text_input or not text_input.strip():
|
170 |
-
st.error("Please enter some text to synthesize.")
|
171 |
-
else:
|
172 |
-
# Use st.spinner to indicate processing is happening
|
173 |
-
with st.spinner("🚀 Generating speech..."):
|
174 |
-
try:
|
175 |
-
# Call the inference function
|
176 |
-
output_wav_path = run_inference(ref_path, text_input)
|
177 |
-
|
178 |
-
# Check if output file was actually created
|
179 |
-
if os.path.exists(output_wav_path):
|
180 |
-
st.success("🎉 Done! Here’s your generated audio:")
|
181 |
-
# Display the generated audio
|
182 |
-
st.audio(output_wav_path, format="audio/wav")
|
183 |
-
else:
|
184 |
-
st.error("Generation failed: Output file was not created.")
|
185 |
-
|
186 |
-
except Exception as e:
|
187 |
-
st.error(f"An error occurred during inference: {e}")
|
188 |
-
# Optional: Log the full traceback for debugging on Spaces
|
189 |
-
# import traceback
|
190 |
-
# st.exception(e) # This shows traceback in the app
|
191 |
-
|
192 |
-
# Add a footer or more info
|
193 |
-
st.markdown("---")
|
194 |
-
st.markdown("Demo powered by [IndexTTS](https://arxiv.org/abs/2502.05512) and built with Streamlit.")
|
|
|
1 |
+
import spaces
|
2 |
import os
|
3 |
+
import shutil
|
4 |
+
import threading
|
5 |
import time
|
6 |
+
import sys
|
7 |
+
|
8 |
from huggingface_hub import snapshot_download
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
11 |
+
sys.path.append(current_dir)
|
12 |
+
sys.path.append(os.path.join(current_dir, "indextts"))
|
13 |
+
|
14 |
+
import gradio as gr
|
15 |
+
from indextts.infer import IndexTTS
|
16 |
+
from tools.i18n.i18n import I18nAuto
|
17 |
+
|
18 |
+
# Module-level setup: globals initialized at import time so the Gradio
# callbacks below can use them.
i18n = I18nAuto(language="zh_CN")  # UI string localization helper
# NOTE(review): MODE appears unused in this file — confirm before removing.
MODE = 'local'
# Fetch model weights/config into ./checkpoints (no-op if already downloaded).
snapshot_download("IndexTeam/IndexTTS-1.5",local_dir="checkpoints",)
tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")

# Writable working directories (Spaces provides a writable filesystem).
os.makedirs("outputs/tasks",exist_ok=True)
os.makedirs("prompts",exist_ok=True)
|
25 |
+
|
26 |
+
@spaces.GPU
def infer(voice, text, output_path=None):
    """
    Synthesize ``text`` in the voice of the reference clip ``voice``.

    Args:
        voice: Filesystem path to the reference audio clip (as delivered by
            the ``gr.Audio`` component with ``type="filepath"``).
        text: Target text to synthesize.
        output_path: Optional explicit destination for the generated .wav;
            defaults to a timestamped file under ./outputs.

    Returns:
        Path of the generated .wav file.

    Raises:
        RuntimeError: If the module-level model failed to initialize.
        FileNotFoundError: If no reference clip was provided or the path
            does not exist (Gradio passes None when nothing is uploaded).
    """
    if not tts:
        # RuntimeError is a subclass of Exception, so any existing broad
        # handlers around this call keep working.
        raise RuntimeError("Model not loaded")
    # Fail with a clear error instead of an obscure crash inside tts.infer
    # when the user clicks "generate" without uploading a reference clip.
    if not voice or not os.path.exists(voice):
        raise FileNotFoundError(f"Reference audio not found: {voice}")
    if not output_path:
        # Survive the outputs directory being absent or removed at runtime.
        os.makedirs("outputs", exist_ok=True)
        output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    tts.infer(voice, text, output_path)
    return output_path
|
34 |
|
35 |
+
def gen_single(prompt, text):
    """Run one synthesis job and reveal the output player with the result."""
    wav_path = infer(prompt, text)
    return gr.update(visible=True, value=wav_path)
|
38 |
+
|
39 |
+
def update_prompt_audio():
    """Return a Gradio update that re-enables the generate button."""
    return gr.update(interactive=True)
|
42 |
+
|
43 |
+
|
44 |
+
# UI definition: built once at import time; `demo` is launched from main().
with gr.Blocks() as demo:
    # NOTE(review): this lock is never used anywhere in the file — confirm
    # before removing.
    mutex = threading.Lock()
    gr.HTML('''
    <h2><center>Echo AI: High-Fidelity, Controllable, and Zero-Shot Text-to-Speech for the Real World</center></h2>

    <p align="center">
    <a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>

    ''')
    with gr.Tab("audio generation"):
        with gr.Row():
            # NOTE(review): duplicates the os.makedirs("prompts", ...) already
            # done at import time — harmless but redundant.
            os.makedirs("prompts",exist_ok=True)
            # Reference voice: file upload or microphone recording, delivered
            # to callbacks as a filesystem path.
            prompt_audio = gr.Audio(label="Please upload reference audio",key="prompt_audio",
                                    sources=["upload","microphone"],type="filepath")
            # NOTE(review): prompt_list/default are computed but never used in
            # this file — confirm before removing.
            prompt_list = os.listdir("prompts")
            default = ''
            if prompt_list:
                default = prompt_list[0]
        input_text_single = gr.Textbox(label="Please enter target text",key="input_text_single")
        gen_button = gr.Button("generate speech",key="gen_button",interactive=True)
        # Hidden until a generation succeeds; gen_single reveals it.
        output_audio = gr.Audio(label="Generate results", visible=False,key="output_audio")

    # Re-enable the generate button whenever a new reference clip is uploaded.
    prompt_audio.upload(update_prompt_audio,
                        inputs=[],
                        outputs=[gen_button])

    # Main action: synthesize the entered text with the uploaded voice.
    gen_button.click(gen_single,
                     inputs=[prompt_audio, input_text_single],
                     outputs=[output_audio])
|
73 |
+
|
74 |
+
|
75 |
+
def main():
    """Finish model initialization and launch the Gradio app."""
    # The normalizer must be loaded before the first inference request.
    tts.load_normalizer()
    # NOTE(review): the positional argument to queue() maps to different
    # parameters across Gradio major versions (concurrency vs. status rate) —
    # confirm 20 does what is intended for the installed version.
    demo.queue(20)
    # Bind to all interfaces so the app is reachable inside a container/Space.
    demo.launch(server_name="0.0.0.0")

if __name__ == "__main__":
    main()
|
82 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|