File size: 4,087 Bytes
c21ab36 8db92ed c21ab36 ec8ba93 579fccc 8db92ed c21ab36 8db92ed c21ab36 33551a3 c21ab36 8db92ed c21ab36 8db92ed c21ab36 229bbd8 c21ab36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import streamlit as st
import os
import time
import sys
import torch
from huggingface_hub import snapshot_download
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts"))
from indextts.infer import IndexTTS
from tools.i18n.i18n import I18nAuto
# Internationalization helper; UI strings are served in English.
i18n = I18nAuto(language="en")

# Resolve the compute device once at startup.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Streamlit page settings.
st.set_page_config(page_title="echoAI - IndexTTS", layout="wide")

# Working directories: generated audio tasks and uploaded voice prompts.
for _workdir in ("outputs/tasks", "prompts"):
    os.makedirs(_workdir, exist_ok=True)

# Fetch model weights from the Hub on first run only.
if not os.path.exists("checkpoints"):
    snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints")
# Load TTS model with GPU support
@st.cache_resource
def load_model():
    """Build the IndexTTS engine once and cache it for the Streamlit session.

    Returns:
        The initialized IndexTTS instance, with its normalizer loaded and
        its weights moved to the GPU when one was detected at startup.
    """
    engine = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
    engine.load_normalizer()
    if DEVICE == "cuda":
        # NOTE(review): assumes IndexTTS exposes its network as `.model` —
        # confirm against the indextts package.
        engine.model.to(DEVICE)
    return engine


tts = load_model()
def infer(voice_path, text, output_path=None):
    """Synthesize `text` in the voice of `voice_path` and return the wav path.

    Args:
        voice_path: Path to the reference (prompt) audio file on disk.
        text: Target text to synthesize.
        output_path: Optional destination path; when omitted, a timestamped
            file under ``outputs/`` is used.

    Returns:
        The path of the written wav file.
    """
    if not output_path:
        output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    # Guarantee the destination directory exists even when a caller passes a
    # custom path (the module-level setup only creates outputs/tasks and
    # prompts). The previous "ensure input is on correct device" comment
    # described code that did not exist and has been removed.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    tts.infer(voice_path, text, output_path)
    return output_path
# ---- Page header ----
st.title("echoAI - IndexTTS")

# Rendered as raw HTML (centered subtitle plus ArXiv badge link).
_HEADER_HTML = """
<h4 style='text-align: center;'>
An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System
</h4>
<p style='text-align: center;'>
<a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
</p>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)

# Surface which device inference will run on.
st.sidebar.markdown(f"**Device:** {DEVICE.upper()}")
# ---- Main interface ----
# BUG FIX: the original `st.download_button(` call was never closed — the
# `except` clause followed the argument list directly, so the file did not
# parse. The call is now properly terminated.
with st.container():
    st.header("Audio Generation")

    col1, col2 = st.columns(2)

    with col1:
        # Reference voice prompt for zero-shot cloning.
        uploaded_audio = st.file_uploader(
            "Upload reference audio",
            type=["wav", "mp3", "ogg"],
            accept_multiple_files=False
        )
        input_text = st.text_area(
            "Input target text",
            height=150,
            placeholder="Enter text to synthesize..."
        )
        generate_btn = st.button("Generate Speech")

    with col2:
        if generate_btn and uploaded_audio and input_text:
            with st.spinner("Generating audio..."):
                # Persist the uploaded prompt so IndexTTS can read it from disk.
                audio_path = os.path.join("prompts", uploaded_audio.name)
                with open(audio_path, "wb") as f:
                    f.write(uploaded_audio.getbuffer())

                try:
                    output_path = infer(audio_path, input_text)
                    st.audio(output_path, format="audio/wav")
                    st.success("Generation complete!")
                    # Offer the synthesized wav as a download.
                    with open(output_path, "rb") as f:
                        st.download_button(
                            "Download Result",
                            f,
                            file_name=os.path.basename(output_path),
                        )
                except Exception as e:
                    # Broad catch is intentional here: any inference failure is
                    # surfaced to the user rather than crashing the app.
                    st.error(f"Error: {str(e)}")
        elif generate_btn:
            st.warning("Please upload an audio file and enter text first!")
# ---- Sidebar: project description and quick-start steps ----
_FEATURES_MD = """
### Key Features:
- Zero-shot voice cloning
- Industrial-grade TTS
- Efficient synthesis
- Controllable output
"""

_USAGE_MD = """
### Usage Instructions:
1. Upload a reference audio clip
2. Enter target text
3. Click 'Generate Speech'
"""

with st.sidebar:
    st.header("About echoAI")
    st.markdown(_FEATURES_MD)
    st.markdown("---")
    st.markdown(_USAGE_MD)
# Script entry point — currently a no-op, reserved for future cleanup of
# stale files under outputs/.
if __name__ == "__main__":
    pass