import os
import time
import shutil # Added shutil for potentially cleaning old files if needed, though not used in this version
from huggingface_hub import snapshot_download
import streamlit as st
# Imports from your package
# Ensure 'indextts' is correctly installed or available in your environment/requirements.txt
from indextts.infer import IndexTTS
# ------------------------------------------------------------------------------
# Configuration
# ------------------------------------------------------------------------------
# Where to store model checkpoints and outputs
# These paths are relative to the root directory of your Spaces repository
CHECKPOINT_DIR = "checkpoints"
OUTPUT_DIR = "outputs"
PROMPTS_DIR = "prompts" # Directory to save uploaded reference audio
# Ensure necessary directories exist. Hugging Face Spaces provides a writable filesystem.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(PROMPTS_DIR, exist_ok=True)
MODEL_REPO = "IndexTeam/IndexTTS-1.5"
CFG_FILENAME = "config.yaml"
# ------------------------------------------------------------------------------
# Model loading (cached so it only runs once per resource identifier)
# ------------------------------------------------------------------------------
# @st.cache_resource is the recommended way in Streamlit to cache large objects
# like ML models that should be loaded only once.
# This is crucial for efficiency on platforms like Spaces, preventing re-loading
# the model on every user interaction/script re-run.
@st.cache_resource(show_spinner=False)
def load_tts_model():
"""
Downloads the model snapshot and initializes the IndexTTS model.
Cached using st.cache_resource to load only once.
"""
st.write("⏳ Loading model... This may take a moment.")
# Download the model snapshot if not already present
# local_dir_use_symlinks=False is often safer in containerized environments
snapshot_download(
repo_id=MODEL_REPO,
local_dir=CHECKPOINT_DIR,
local_dir_use_symlinks=False,
)
# Initialize the TTS object
# The underlying IndexTTS library should handle using the GPU if available
# and if dependencies (like CUDA-enabled PyTorch/TensorFlow) are installed.
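    # Optional diagnostic (an addition, not in the original code): report whether
    # a CUDA GPU is visible so users on CPU-only Spaces know why inference may be
    # slow. Guarded with try/except in case torch is not importable here.
    try:
        import torch
        st.write(f"CUDA available: {torch.cuda.is_available()}")
    except ImportError:
        pass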
    tts = IndexTTS(
        model_dir=CHECKPOINT_DIR,
        cfg_path=os.path.join(CHECKPOINT_DIR, CFG_FILENAME),
    )
    # Load any normalizer or auxiliary data required by the model
    tts.load_normalizer()
    st.write("✅ Model loaded!")
    return tts
# Configure the page before any other Streamlit command runs:
# st.set_page_config() is expected to be the first st.* call in the script
# (older Streamlit versions raise an error otherwise), and the cached loader
# below writes status messages with st.write().
st.set_page_config(page_title="IndexTTS Demo", layout="wide")
# Load the TTS model using the cached function.
# This line is executed on each script run, but the function body only runs
# the first time or if the function signature/dependencies change.
tts = load_tts_model()
# ------------------------------------------------------------------------------
# Inference function
# ------------------------------------------------------------------------------
def run_inference(reference_audio_path: str, text: str) -> str:
"""
Run TTS inference using the uploaded reference audio and the target text.
Returns the path to the generated .wav file.
"""
if not os.path.exists(reference_audio_path):
raise FileNotFoundError(f"Reference audio not found at {reference_audio_path}")
# Generate a unique output filename
timestamp = int(time.time())
output_filename = f"generated_{timestamp}.wav"
output_path = os.path.join(OUTPUT_DIR, output_filename)
# Perform the TTS inference
# The efficiency of this step depends on the IndexTTS library and hardware
tts.infer(reference_audio_path, text, output_path)
# Optional: Clean up old files in output/prompts directories if space is limited
# This can be added if you find directories filling up on Spaces.
# E.g., a function to remove files older than X hours/days.
# For a simple demo, may not be necessary initially.
return output_path
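# A minimal cleanup sketch for the idea noted in run_inference() above. This is
# an illustrative addition, not part of the original app; the helper name and the
# default age threshold are assumptions. Call it manually (or at the start of
# run_inference) if the Space's disk begins to fill up.
def prune_old_files(directory: str, max_age_hours: float = 24.0) -> None:
    """Best-effort removal of files in `directory` older than `max_age_hours`."""
    cutoff = time.time() - max_age_hours * 3600
    for name in os.listdir(directory):
        path = os.path.join(directory, name)
        # Only consider regular files; leave subdirectories alone.
        if os.path.isfile(path) and os.path.getmtime(path) < cutoff:
            try:
                os.remove(path)
            except OSError:
                pass  # Ignore files we cannot delete; cleanup is best-effort.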
# ------------------------------------------------------------------------------
# Streamlit UI
# ------------------------------------------------------------------------------
st.markdown(
"""
<h1 style="text-align: center;">IndexTTS: Zero-Shot Controllable & Efficient TTS</h1>
<p style="text-align: center;">
<a href="https://arxiv.org/abs/2502.05512" target="_blank">
View the paper on arXiv (2502.05512)
</a>
</p>
""",
unsafe_allow_html=True
)
st.sidebar.header("Settings")
with st.sidebar.expander("🗂️ Output Directories"):
st.write(f"- Checkpoints: `{CHECKPOINT_DIR}`")
st.write(f"- Generated audio: `{OUTPUT_DIR}`")
st.write(f"- Uploaded prompts: `{PROMPTS_DIR}`")
st.info("These directories are located within your Space's persistent storage.")
st.header("1. Upload Reference Audio")
ref_audio_file = st.file_uploader(
label="Upload a reference audio (wav or mp3)",
type=["wav", "mp3"],
help="This audio will condition the voice characteristics.",
key="ref_audio_uploader" # Added a key for potential future state management
)
ref_path = None # Initialize ref_path
if ref_audio_file:
    # Save the uploaded file to the prompts directory.
    # Streamlit's uploader provides a file-like object.
    ref_filename = ref_audio_file.name
    ref_path = os.path.join(PROMPTS_DIR, ref_filename)
    with open(ref_path, "wb") as f:
        # getbuffer() avoids an extra copy for large files
        f.write(ref_audio_file.getbuffer())
    st.success(f"Saved reference audio: `{ref_filename}`")
    # Play back the uploaded audio; use the uploader's MIME type so mp3 files
    # are labeled correctly.
    st.audio(ref_path, format=ref_audio_file.type)
st.header("2. Enter Text to Synthesize")
text_input = st.text_area(
label="Enter the text you want to convert to speech",
placeholder="Type your sentence here...",
key="text_input_area" # Added a key
)
# Button to trigger generation
generate_button = st.button("Generate Speech", key="generate_tts_button")
# ------------------------------------------------------------------------------
# Trigger Inference and Display Results
# ------------------------------------------------------------------------------
# This block runs only when the button is clicked AND inputs are valid
if generate_button:
    if not ref_path or not os.path.exists(ref_path):
        st.error("Please upload a reference audio first.")
    elif not text_input or not text_input.strip():
        st.error("Please enter some text to synthesize.")
    else:
        # Use st.spinner to indicate processing is happening
        with st.spinner("🚀 Generating speech..."):
            try:
                # Call the inference function
                output_wav_path = run_inference(ref_path, text_input)
                # Check that the output file was actually created
                if os.path.exists(output_wav_path):
                    st.success("🎉 Done! Here’s your generated audio:")
                    # Play back the generated audio
                    st.audio(output_wav_path, format="audio/wav")
                else:
                    st.error("Generation failed: the output file was not created.")
            except Exception as e:
                st.error(f"An error occurred during inference: {e}")
                # Optional: show the full traceback in the app for debugging on Spaces
                # st.exception(e)
# Add a footer or more info
st.markdown("---")
st.markdown("Demo powered by [IndexTTS](https://arxiv.org/abs/2502.05512) and built with Streamlit.") |