|
import os |
|
import time |
|
import shutil |
|
from huggingface_hub import snapshot_download |
|
import streamlit as st |
|
|
|
|
|
|
|
from indextts.infer import IndexTTS |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face repository passed to snapshot_download(), and the name of the
# config file that is looked up inside the checkpoint directory.
MODEL_REPO = "IndexTeam/IndexTTS-1.5"
CFG_FILENAME = "config.yaml"

# Working directories, relative to the process's current directory.
CHECKPOINT_DIR = "checkpoints"
OUTPUT_DIR = "outputs"
PROMPTS_DIR = "prompts"

# Create every working directory up front; exist_ok=True makes this a no-op
# when the script is re-executed and the directories already exist.
for _workdir in (CHECKPOINT_DIR, OUTPUT_DIR, PROMPTS_DIR):
    os.makedirs(_workdir, exist_ok=True)
del _workdir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_tts_model():
    """Fetch the checkpoint files and build the IndexTTS model.

    The repository named by ``MODEL_REPO`` is downloaded into
    ``CHECKPOINT_DIR``, an ``IndexTTS`` instance is constructed from that
    directory together with its ``CFG_FILENAME`` config file, and
    ``load_normalizer()`` is called on it. The function is wrapped in
    ``st.cache_resource`` so the returned object is cached instead of being
    rebuilt on every script rerun.

    Returns:
        The initialized ``IndexTTS`` instance.
    """
    st.write("⏳ Loading model... This may take a moment.")

    # Step 1: pull the model repository into the local checkpoint directory.
    snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=CHECKPOINT_DIR,
        local_dir_use_symlinks=False,
    )

    # Step 2: build the model from the downloaded directory and its config.
    config_path = os.path.join(CHECKPOINT_DIR, CFG_FILENAME)
    model = IndexTTS(model_dir=CHECKPOINT_DIR, cfg_path=config_path)
    model.load_normalizer()

    st.write("✅ Model loaded!")
    return model
|
|
|
|
|
|
|
|
|
# Page configuration is issued before anything else. Streamlit releases that
# require set_page_config() to be the first Streamlit command reject it once
# an element has been emitted, and load_tts_model() below calls st.write().
st.set_page_config(page_title="IndexTTS Demo", layout="wide")

# Model instance shared by run_inference(); load_tts_model() is decorated
# with st.cache_resource, so script reruns reuse the cached object.
tts = load_tts_model()


def run_inference(reference_audio_path: str, text: str) -> str:
    """Synthesize ``text`` conditioned on a reference audio file.

    Args:
        reference_audio_path: Path to the reference audio on disk.
        text: The text to convert to speech.

    Returns:
        The path of the generated ``.wav`` file inside ``OUTPUT_DIR``.

    Raises:
        FileNotFoundError: If ``reference_audio_path`` does not exist.
    """
    if not os.path.exists(reference_audio_path):
        raise FileNotFoundError(f"Reference audio not found at {reference_audio_path}")

    # Nanosecond timestamp rather than whole seconds: two generations started
    # within the same second would otherwise get the same file name and
    # overwrite each other's output.
    output_filename = f"generated_{time.time_ns()}.wav"
    output_path = os.path.join(OUTPUT_DIR, output_filename)

    tts.infer(reference_audio_path, text, output_path)

    return output_path


# ---------------------------------------------------------------------------
# Page header
# ---------------------------------------------------------------------------
st.markdown(
    """
    <h1 style="text-align: center;">IndexTTS: Zero-Shot Controllable & Efficient TTS</h1>
    <p style="text-align: center;">
    <a href="https://arxiv.org/abs/2502.05512" target="_blank">
    View the paper on arXiv (2502.05512)
    </a>
    </p>
    """,
    unsafe_allow_html=True,
)

# ---------------------------------------------------------------------------
# Sidebar: show where the app keeps its files
# ---------------------------------------------------------------------------
st.sidebar.header("Settings")
with st.sidebar.expander("🗂️ Output Directories"):
    st.write(f"- Checkpoints: `{CHECKPOINT_DIR}`")
    st.write(f"- Generated audio: `{OUTPUT_DIR}`")
    st.write(f"- Uploaded prompts: `{PROMPTS_DIR}`")
    st.info("These directories are located within your Space's persistent storage.")

# ---------------------------------------------------------------------------
# Step 1: reference audio upload
# ---------------------------------------------------------------------------
st.header("1. Upload Reference Audio")
ref_audio_file = st.file_uploader(
    label="Upload a reference audio (wav or mp3)",
    type=["wav", "mp3"],
    help="This audio will condition the voice characteristics.",
    key="ref_audio_uploader",
)

# Path of the saved reference audio; stays None until a file is uploaded.
ref_path = None

if ref_audio_file is not None:
    # The file name is supplied by the client. Keep only its final path
    # component so the write below cannot land outside PROMPTS_DIR.
    ref_filename = os.path.basename(ref_audio_file.name)
    ref_path = os.path.join(PROMPTS_DIR, ref_filename)

    # Persist the upload to disk because run_inference() takes a file path.
    with open(ref_path, "wb") as f:
        f.write(ref_audio_file.getbuffer())

    st.success(f"Saved reference audio: `{ref_filename}`")
    # Preview with the upload's own MIME type: mp3 uploads are not audio/wav.
    st.audio(ref_path, format=ref_audio_file.type or "audio/wav")

# ---------------------------------------------------------------------------
# Step 2: text input
# ---------------------------------------------------------------------------
st.header("2. Enter Text to Synthesize")
text_input = st.text_area(
    label="Enter the text you want to convert to speech",
    placeholder="Type your sentence here...",
    key="text_input_area",
)

generate_button = st.button("Generate Speech", key="generate_tts_button")

# ---------------------------------------------------------------------------
# Step 3: validate the inputs, run synthesis, and show the result
# ---------------------------------------------------------------------------
if generate_button:
    if not ref_path or not os.path.exists(ref_path):
        st.error("Please upload a reference audio first.")
    elif not text_input or not text_input.strip():
        st.error("Please enter some text to synthesize.")
    else:
        with st.spinner("🚀 Generating speech..."):
            # Broad catch on purpose: this is the UI boundary, and any
            # failure is reported to the user instead of aborting the page.
            try:
                output_wav_path = run_inference(ref_path, text_input)

                if os.path.exists(output_wav_path):
                    st.success("🎉 Done! Here’s your generated audio:")
                    st.audio(output_wav_path, format="audio/wav")
                else:
                    st.error("Generation failed: Output file was not created.")
            except Exception as e:
                st.error(f"An error occurred during inference: {e}")

# ---------------------------------------------------------------------------
# Footer
# ---------------------------------------------------------------------------
st.markdown("---")
st.markdown("Demo powered by [IndexTTS](https://arxiv.org/abs/2502.05512) and built with Streamlit.")