# ECHOAI / webui.py
# (Header lines below were Hugging Face page chrome captured by the scrape —
#  "MPCIRCLE's picture / Update webui.py / f222c6c verified / raw / history
#  blame / 7.63 kB" — preserved here as a comment so the file parses.)
import os
import time
import shutil # Added shutil for potentially cleaning old files if needed, though not used in this version
from huggingface_hub import snapshot_download
import streamlit as st
# Imports from your package
# Ensure 'indextts' is correctly installed or available in your environment/requirements.txt
from indextts.infer import IndexTTS
# ------------------------------------------------------------------------------
# Configuration
# ------------------------------------------------------------------------------
# All paths are relative to the repository root; Hugging Face Spaces provides
# a writable filesystem, so the directories can be created at import time.
CHECKPOINT_DIR = "checkpoints"  # downloaded model snapshot lives here
OUTPUT_DIR = "outputs"          # generated .wav files
PROMPTS_DIR = "prompts"         # uploaded reference audio

for _dirname in (CHECKPOINT_DIR, OUTPUT_DIR, PROMPTS_DIR):
    os.makedirs(_dirname, exist_ok=True)

MODEL_REPO = "IndexTeam/IndexTTS-1.5"
CFG_FILENAME = "config.yaml"
# ------------------------------------------------------------------------------
# Model loading (cached so it only runs once per resource identifier)
# ------------------------------------------------------------------------------
# @st.cache_resource is the recommended way in Streamlit to cache large objects
# like ML models that should be loaded only once.
# This is crucial for efficiency on platforms like Spaces, preventing re-loading
# the model on every user interaction/script re-run.
@st.cache_resource(show_spinner=False)
def load_tts_model() -> IndexTTS:
    """
    Download the model snapshot and initialize the IndexTTS model.

    Cached with st.cache_resource so the download and model construction run
    only once per process, not on every Streamlit script re-run.

    Returns:
        IndexTTS: the ready-to-use TTS engine (normalizer already loaded).
    """
    # NOTE: because this runs inside a cached function, these st.write()
    # status messages appear only on the first (uncached) execution.
    st.write("⏳ Loading model... This may take a moment.")
    # Download the model snapshot into CHECKPOINT_DIR if not already present.
    # NOTE(review): local_dir_use_symlinks is deprecated and ignored by recent
    # huggingface_hub releases — confirm the pinned version still accepts it.
    snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=CHECKPOINT_DIR,
        local_dir_use_symlinks=False,
    )
    # Initialize the TTS object from the downloaded checkpoint + config.
    # GPU use (if any) is decided inside the IndexTTS library, not here.
    tts = IndexTTS(
        model_dir=CHECKPOINT_DIR,
        cfg_path=os.path.join(CHECKPOINT_DIR, CFG_FILENAME)
    )
    # Load the text normalizer required before inference.
    # NOTE(review): assumes IndexTTS exposes load_normalizer(); verify against
    # the installed indextts version.
    tts.load_normalizer()
    st.write("✅ Model loaded!")
    return tts
# Load the TTS model using the cached function.
# This call executes on every script re-run, but the function body only runs
# the first time (or when the cache is invalidated).
# Module-level global: run_inference() below reads this `tts` instance.
tts = load_tts_model()
# ------------------------------------------------------------------------------
# Inference function
# ------------------------------------------------------------------------------
def run_inference(reference_audio_path: str, text: str) -> str:
    """
    Run TTS inference using the uploaded reference audio and the target text.

    Args:
        reference_audio_path: Path to the saved reference (prompt) audio file.
        text: The text to synthesize in the reference voice.

    Returns:
        Path to the generated .wav file inside OUTPUT_DIR.

    Raises:
        FileNotFoundError: If the reference audio file does not exist.
    """
    if not os.path.exists(reference_audio_path):
        raise FileNotFoundError(f"Reference audio not found at {reference_audio_path}")
    # Nanosecond timestamp: int(time.time()) has 1-second resolution, so two
    # concurrent requests in the same second would silently overwrite each
    # other's output file. time_ns() makes collisions practically impossible.
    output_filename = f"generated_{time.time_ns()}.wav"
    output_path = os.path.join(OUTPUT_DIR, output_filename)
    # Perform the TTS inference; runtime depends on the IndexTTS library
    # and available hardware.
    tts.infer(reference_audio_path, text, output_path)
    # Optional follow-up: prune old files in outputs/prompts if the Space's
    # disk fills up (not needed for a simple demo).
    return output_path
# ------------------------------------------------------------------------------
# Streamlit UI
# ------------------------------------------------------------------------------
st.set_page_config(page_title="IndexTTS Demo", layout="wide")

# Centered page title + paper link, rendered as raw HTML.
_HEADER_HTML = """
<h1 style="text-align: center;">IndexTTS: Zero-Shot Controllable & Efficient TTS</h1>
<p style="text-align: center;">
<a href="https://arxiv.org/abs/2502.05512" target="_blank">
View the paper on arXiv (2502.05512)
</a>
</p>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)

# Sidebar: show where files live on the Space's writable filesystem.
st.sidebar.header("Settings")
with st.sidebar.expander("🗂️ Output Directories"):
    for _label, _path in (
        ("Checkpoints", CHECKPOINT_DIR),
        ("Generated audio", OUTPUT_DIR),
        ("Uploaded prompts", PROMPTS_DIR),
    ):
        st.write(f"- {_label}: `{_path}`")
    st.info("These directories are located within your Space's persistent storage.")
st.header("1. Upload Reference Audio")
ref_audio_file = st.file_uploader(
    label="Upload a reference audio (wav or mp3)",
    type=["wav", "mp3"],
    help="This audio will condition the voice characteristics.",
    key="ref_audio_uploader"  # key for potential future state management
)

# Path of the saved reference audio; stays None until a file is uploaded.
ref_path = None
if ref_audio_file:
    # basename() strips any directory components a malicious client could
    # smuggle into the filename (path traversal out of PROMPTS_DIR).
    ref_filename = os.path.basename(ref_audio_file.name)
    ref_path = os.path.join(PROMPTS_DIR, ref_filename)
    # Persist the upload; getbuffer() avoids copying the whole payload.
    with open(ref_path, "wb") as f:
        f.write(ref_audio_file.getbuffer())
    st.success(f"Saved reference audio: `{ref_filename}`")
    # Use the correct MIME hint so the browser player also works for mp3
    # uploads (the uploader accepts both wav and mp3).
    audio_format = "audio/mpeg" if ref_filename.lower().endswith(".mp3") else "audio/wav"
    st.audio(ref_path, format=audio_format)

st.header("2. Enter Text to Synthesize")
text_input = st.text_area(
    label="Enter the text you want to convert to speech",
    placeholder="Type your sentence here...",
    key="text_input_area"
)

# Button that triggers generation in the block below.
generate_button = st.button("Generate Speech", key="generate_tts_button")
# ------------------------------------------------------------------------------
# Trigger Inference and Display Results
# ------------------------------------------------------------------------------
# Runs only on a button click; inputs are validated before inference starts.
if generate_button:
    # Decide on a validation failure message first (reference audio takes
    # priority over empty text, matching the original check order).
    validation_error = None
    if not ref_path or not os.path.exists(ref_path):
        validation_error = "Please upload a reference audio first."
    elif not text_input or not text_input.strip():
        validation_error = "Please enter some text to synthesize."

    if validation_error is not None:
        st.error(validation_error)
    else:
        # Spinner signals that synthesis is in progress.
        with st.spinner("🚀 Generating speech..."):
            try:
                generated_path = run_inference(ref_path, text_input)
                # Guard against the backend silently producing nothing.
                if os.path.exists(generated_path):
                    st.success("🎉 Done! Here’s your generated audio:")
                    st.audio(generated_path, format="audio/wav")
                else:
                    st.error("Generation failed: Output file was not created.")
            except Exception as e:
                st.error(f"An error occurred during inference: {e}")
                # For debugging on Spaces, st.exception(e) would show the
                # full traceback in the app.

# Footer
st.markdown("---")
st.markdown("Demo powered by [IndexTTS](https://arxiv.org/abs/2502.05512) and built with Streamlit.")