MPCIRCLE committed on
Commit
252addc
·
verified ·
1 Parent(s): b2ef099

Update webui.py

Browse files

Back to Gradio UI

Files changed (1) hide show
  1. webui.py +76 -188
webui.py CHANGED
@@ -1,194 +1,82 @@
 
1
  import os
2
- import sys
 
3
  import time
4
- from pathlib import Path
 
5
  from huggingface_hub import snapshot_download
6
- import streamlit as st
7
-
8
- # ----------------------- Critical Path Configuration --------------------------
9
- current_dir = Path(__file__).parent.resolve() # Get absolute path to current file
10
- sys.path.insert(0, str(current_dir)) # Add current directory to Python path
11
- sys.path.insert(1, str(current_dir / "indextts")) # Add indextts package
12
- sys.path.insert(2, str(current_dir.parent)) # Add parent directory for utils
13
-
14
- try:
15
- from indextts.infer import IndexTTS
16
- except ModuleNotFoundError as e:
17
- st.error(f"Module import error: {str(e)}")
18
- st.stop()
19
-
20
- # ----------------------- Rest of Your Original Code ---------------------------
21
- CHECKPOINT_DIR = "checkpoints"
22
- OUTPUT_DIR = "outputs"
23
- PROMPTS_DIR = "prompts"
24
-
25
- # Ensure necessary directories exist. Hugging Face Spaces provides a writable filesystem.
26
- os.makedirs(CHECKPOINT_DIR, exist_ok=True)
27
- os.makedirs(OUTPUT_DIR, exist_ok=True)
28
- os.makedirs(PROMPTS_DIR, exist_ok=True)
29
-
30
- MODEL_REPO = "IndexTeam/IndexTTS-1.5"
31
- CFG_FILENAME = "config.yaml"
32
-
33
- # ------------------------------------------------------------------------------
34
- # Model loading (cached so it only runs once per resource identifier)
35
- # ------------------------------------------------------------------------------
36
-
37
- # @st.cache_resource is the recommended way in Streamlit to cache large objects
38
- # like ML models that should be loaded only once.
39
- # This is crucial for efficiency on platforms like Spaces, preventing re-loading
40
- # the model on every user interaction/script re-run.
41
- @st.cache_resource(show_spinner=False)
42
- def load_tts_model():
43
- """
44
- Downloads the model snapshot and initializes the IndexTTS model.
45
- Cached using st.cache_resource to load only once.
46
- """
47
- st.write("⏳ Loading model... This may take a moment.")
48
- # Download the model snapshot if not already present
49
- # local_dir_use_symlinks=False is often safer in containerized environments
50
- snapshot_download(
51
- repo_id=MODEL_REPO,
52
- local_dir=CHECKPOINT_DIR,
53
- local_dir_use_symlinks=False,
54
- )
55
- # Initialize the TTS object
56
- # The underlying IndexTTS library should handle using the GPU if available
57
- # and if dependencies (like CUDA-enabled PyTorch/TensorFlow) are installed.
58
- tts = IndexTTS(
59
- model_dir=CHECKPOINT_DIR,
60
- cfg_path=os.path.join(CHECKPOINT_DIR, CFG_FILENAME)
61
- )
62
- # Load any normalizer or auxiliary data required by the model
63
- tts.load_normalizer()
64
- st.write("✅ Model loaded!")
65
- return tts
66
-
67
- # Load the TTS model using the cached function
68
- # This line is executed on each script run, but the function body only runs
69
- # the first time or if the function signature/dependencies change.
70
- tts = load_tts_model()
71
-
72
- # ------------------------------------------------------------------------------
73
- # Inference function
74
- # ------------------------------------------------------------------------------
75
-
76
- def run_inference(reference_audio_path: str, text: str) -> str:
77
- """
78
- Run TTS inference using the uploaded reference audio and the target text.
79
- Returns the path to the generated .wav file.
80
- """
81
- if not os.path.exists(reference_audio_path):
82
- raise FileNotFoundError(f"Reference audio not found at {reference_audio_path}")
83
-
84
- # Generate a unique output filename
85
- timestamp = int(time.time())
86
- output_filename = f"generated_{timestamp}.wav"
87
- output_path = os.path.join(OUTPUT_DIR, output_filename)
88
-
89
- # Perform the TTS inference
90
- # The efficiency of this step depends on the IndexTTS library and hardware
91
- tts.infer(reference_audio_path, text, output_path)
92
-
93
- # Optional: Clean up old files in output/prompts directories if space is limited
94
- # This can be added if you find directories filling up on Spaces.
95
- # E.g., a function to remove files older than X hours/days.
96
- # For a simple demo, may not be necessary initially.
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  return output_path
99
 
100
- # ------------------------------------------------------------------------------
101
- # Streamlit UI
102
- # ------------------------------------------------------------------------------
103
-
104
- st.set_page_config(page_title="IndexTTS Demo", layout="wide")
105
-
106
- st.markdown(
107
- """
108
- <h1 style="text-align: center;">IndexTTS: Zero-Shot Controllable & Efficient TTS</h1>
109
- <p style="text-align: center;">
110
- <a href="https://arxiv.org/abs/2502.05512" target="_blank">
111
- View the paper on arXiv (2502.05512)
112
- </a>
113
- </p>
114
- """,
115
- unsafe_allow_html=True
116
- )
117
-
118
- st.sidebar.header("Settings")
119
- with st.sidebar.expander("🗂️ Output Directories"):
120
- st.write(f"- Checkpoints: `{CHECKPOINT_DIR}`")
121
- st.write(f"- Generated audio: `{OUTPUT_DIR}`")
122
- st.write(f"- Uploaded prompts: `{PROMPTS_DIR}`")
123
- st.info("These directories are located within your Space's persistent storage.")
124
-
125
-
126
- st.header("1. Upload Reference Audio")
127
- ref_audio_file = st.file_uploader(
128
- label="Upload a reference audio (wav or mp3)",
129
- type=["wav", "mp3"],
130
- help="This audio will condition the voice characteristics.",
131
- key="ref_audio_uploader" # Added a key for potential future state management
132
- )
133
-
134
- ref_path = None # Initialize ref_path
135
-
136
- if ref_audio_file:
137
- # Save the uploaded file to the prompts directory
138
- # Streamlit's uploader provides file-like object
139
- ref_filename = ref_audio_file.name
140
- ref_path = os.path.join(PROMPTS_DIR, ref_filename)
141
-
142
- # Use a more robust way to save the file
143
- with open(ref_path, "wb") as f:
144
- # Use getbuffer() for efficiency with large files
145
- f.write(ref_audio_file.getbuffer())
146
-
147
- st.success(f"Saved reference audio: `{ref_filename}`")
148
- st.audio(ref_path, format="audio/wav") # Display the uploaded audio
149
-
150
-
151
- st.header("2. Enter Text to Synthesize")
152
- text_input = st.text_area(
153
- label="Enter the text you want to convert to speech",
154
- placeholder="Type your sentence here...",
155
- key="text_input_area" # Added a key
156
- )
157
-
158
- # Button to trigger generation
159
- generate_button = st.button("Generate Speech", key="generate_tts_button")
160
-
161
- # ------------------------------------------------------------------------------
162
- # Trigger Inference and Display Results
163
- # ------------------------------------------------------------------------------
164
-
165
- # This block runs only when the button is clicked AND inputs are valid
166
- if generate_button:
167
- if not ref_path or not os.path.exists(ref_path):
168
- st.error("Please upload a reference audio first.")
169
- elif not text_input or not text_input.strip():
170
- st.error("Please enter some text to synthesize.")
171
- else:
172
- # Use st.spinner to indicate processing is happening
173
- with st.spinner("🚀 Generating speech..."):
174
- try:
175
- # Call the inference function
176
- output_wav_path = run_inference(ref_path, text_input)
177
-
178
- # Check if output file was actually created
179
- if os.path.exists(output_wav_path):
180
- st.success("🎉 Done! Here’s your generated audio:")
181
- # Display the generated audio
182
- st.audio(output_wav_path, format="audio/wav")
183
- else:
184
- st.error("Generation failed: Output file was not created.")
185
-
186
- except Exception as e:
187
- st.error(f"An error occurred during inference: {e}")
188
- # Optional: Log the full traceback for debugging on Spaces
189
- # import traceback
190
- # st.exception(e) # This shows traceback in the app
191
-
192
- # Add a footer or more info
193
- st.markdown("---")
194
- st.markdown("Demo powered by [IndexTTS](https://arxiv.org/abs/2502.05512) and built with Streamlit.")
 
1
+ import spaces
2
  import os
3
+ import shutil
4
+ import threading
5
  import time
6
+ import sys
7
+
8
  from huggingface_hub import snapshot_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
# ----- Runtime path setup: make the repo root and the bundled indextts package importable -----
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts"))

import gradio as gr
from indextts.infer import IndexTTS
from tools.i18n.i18n import I18nAuto

# UI localisation helper (Simplified Chinese strings).
i18n = I18nAuto(language="zh_CN")
MODE = 'local'

# Fetch the model weights from the Hub (a no-op when already cached locally),
# then build the TTS engine from the downloaded checkpoint + config.
snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints")
tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")

# Writable working directories for generated audio and uploaded reference prompts.
os.makedirs("outputs/tasks", exist_ok=True)
os.makedirs("prompts", exist_ok=True)
25
+
26
@spaces.GPU
def infer(voice, text, output_path=None):
    """Synthesize `text` in the voice of the reference clip `voice`.

    Args:
        voice: Filepath of the reference audio conditioning the speaker timbre.
        text: Target text to convert to speech.
        output_path: Optional destination .wav path; when omitted, a
            timestamped file under "outputs/" is used.

    Returns:
        Path to the generated .wav file.

    Raises:
        RuntimeError: If the module-level TTS model was not initialised.
    """
    if not tts:
        # RuntimeError is more specific than a bare Exception and is still
        # caught by any caller handling Exception.
        raise RuntimeError("Model not loaded")
    if not output_path:
        output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    # Ensure the destination directory exists before the engine writes to it
    # (a caller-supplied output_path may point somewhere not yet created).
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    tts.infer(voice, text, output_path)
    return output_path
34
 
35
def gen_single(prompt, text):
    """Gradio callback: synthesize `text` with reference `prompt`, then reveal the player."""
    wav_path = infer(prompt, text)
    return gr.update(value=wav_path, visible=True)
38
+
39
def update_prompt_audio():
    """Re-enable the generate button after a new reference prompt is uploaded."""
    return gr.update(interactive=True)
42
+
43
+
44
# Gradio UI definition. Removed dead code from the original: an unused
# threading.Lock, a redundant os.makedirs("prompts") (already created at
# module import), and a prompt_list/default scan that was never read.
with gr.Blocks() as demo:
    gr.HTML('''
    <h2><center>Echo AI: High-Fidelity, Controllable, and Zero-Shot Text-to-Speech for the Real World</center></h2>

    <p align="center">
    <a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>

    ''')
    with gr.Tab("audio generation"):
        with gr.Row():
            prompt_audio = gr.Audio(label="Please upload reference audio", key="prompt_audio",
                                    sources=["upload", "microphone"], type="filepath")
            input_text_single = gr.Textbox(label="Please enter target text", key="input_text_single")
            gen_button = gr.Button("generate speech", key="gen_button", interactive=True)
        output_audio = gr.Audio(label="Generate results", visible=False, key="output_audio")

    # Re-enable the generate button whenever a new reference clip is uploaded.
    prompt_audio.upload(update_prompt_audio,
                        inputs=[],
                        outputs=[gen_button])

    # Run synthesis and reveal the (initially hidden) output player.
    gen_button.click(gen_single,
                     inputs=[prompt_audio, input_text_single],
                     outputs=[output_audio])
73
+
74
+
75
def main():
    """Warm up the text normalizer, then serve the Gradio app on all interfaces."""
    tts.load_normalizer()
    # Queue up to 20 concurrent requests so GPU inference is serialized sanely.
    demo.queue(20)
    demo.launch(server_name="0.0.0.0")


if __name__ == "__main__":
    main()
82
+