Command_RTC / app.py
RSHVR's picture
Create app.py
1e82508 verified
raw
history blame
9.17 kB
import os
import tempfile
import gradio as gr
import torch
import torchaudio
import spaces
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_audio
import numpy as np
import uvicorn
from typing import Optional
import uuid
from pydub import AudioSegment
# Generated WAV files are written to ./outputs; make sure the folder exists.
os.makedirs("outputs", exist_ok=True)

# Report which device torch sees at import time (the original author notes
# that this shows CPU on a Zero-GPU Space).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Initial device check: {device}")

# Small probe tensor; its .device is printed later from inside the
# @spaces.GPU functions to show where work is running.
zero = torch.Tensor([0])
if torch.cuda.is_available():
    zero = zero.cuda()
print(f"Zero tensor device: {zero.device}")

# FastAPI application that hosts the REST endpoints (the Gradio UI is
# mounted onto it at the bottom of the file).
app = FastAPI(title="Tortoise TTS API")

# The TextToSpeech model is created lazily on the first generation request.
tts = None

# Voice names offered in the UI dropdown.
PRESET_VOICES = [
    "random",
    "angie",
    "daniel",
    "deniro",
    "emma",
    "freeman",
    "geralt",
    "halle",
    "jlaw",
    "lj",
    "mol",
    "myself",
    "pat",
    "snakes",
    "tim_reynolds",
    "tom",
    "train_atkins",
    "train_daws",
    "train_dotrice",
    "train_dreams",
    "train_empire",
    "train_grace",
    "train_kennard",
    "train_lescault",
    "train_mouse",
    "weaver",
    "william",
]
def process_audio_file(audio_file_path):
    """Return the path of a 22.05 kHz WAV version of ``audio_file_path``.

    The file is first decoded with pydub. When its name does not end in
    ``.wav`` (case-insensitive) it is exported to a temporary WAV file. The
    WAV is then loaded with torchaudio and, if its sample rate differs from
    22050 Hz, resampled and written to a second temporary WAV file.

    Args:
        audio_file_path: Path of the audio file to normalise.

    Returns:
        ``audio_file_path`` itself when it is already a 22050 Hz ``.wav``
        file, otherwise the path of a newly created temporary WAV file.
        Temporary files are created with ``delete=False`` and are not
        removed here.

    Raises:
        Whatever pydub or torchaudio raise when the file cannot be decoded.
    """
    target_rate = 22050  # sample rate the rest of this file passes to Tortoise

    # Decode up front so an unreadable upload fails before any file is written.
    audio = AudioSegment.from_file(audio_file_path)

    if not audio_file_path.lower().endswith('.wav'):
        # Only the unique name is needed: close our handle immediately so the
        # file is not held open while pydub writes to the same path.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_wav:
            wav_path = temp_wav.name
        # export() hands back the file object it wrote to; close it so the
        # descriptor is released before torchaudio reads the file.
        audio.export(wav_path, format="wav").close()
        audio_file_path = wav_path

    waveform, sample_rate = torchaudio.load(audio_file_path)
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_rate
        )
        waveform = resampler(waveform)
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            resampled_path = temp_file.name
        torchaudio.save(resampled_path, waveform, target_rate)
        audio_file_path = resampled_path

    return audio_file_path
@spaces.GPU
def generate_tts_with_voice(text, voice_sample_path=None, preset_voice=None):
    """Generate speech for ``text`` with a cloned, named, or random voice.

    Voice selection, in priority order:

    1. ``voice_sample_path`` - the file is normalised with
       ``process_audio_file`` and used as the single reference clip.
    2. ``preset_voice`` (anything other than ``"random"``) - the named
       Tortoise voice is loaded with ``load_voice``.
    3. Otherwise no conditioning is passed and Tortoise picks a random voice.

    Args:
        text: Text to synthesise.
        voice_sample_path: Optional path to a reference recording.
        preset_voice: Optional name of a built-in Tortoise voice.

    Returns:
        ``(output_path, message)``. On success ``output_path`` is the WAV
        file written under ``outputs/``; on failure it is ``None`` and
        ``message`` carries the error text. Exceptions are not propagated.
    """
    global tts
    try:
        # Kept inside the try block so an import problem is reported through
        # the same (None, "Error: ...") channel as every other failure.
        from tortoise.utils.audio import load_voice

        print(f"GPU function device: {zero.device}")

        # Build the model on first use only.
        if tts is None:
            tts = TextToSpeech(use_deepspeed=torch.cuda.is_available())
            print("TTS model initialized")

        voice_samples = None
        conditioning_latents = None
        if voice_sample_path:
            voice_sample_path = process_audio_file(voice_sample_path)
            # load_audio returns one tensor (not a (tensor, rate) pair), so it
            # must not be tuple-unpacked.
            voice_samples = [load_audio(voice_sample_path, 22050)]
        elif preset_voice and preset_voice != "random":
            # Named voices ship either as reference clips or as precomputed
            # latents; load_voice returns whichever is available.
            voice_samples, conditioning_latents = load_voice(preset_voice)
        # else: both stay None, which selects a random voice.

        output_id = str(uuid.uuid4())[:8]
        output_path = f"outputs/tts_output_{output_id}.wav"

        # NOTE(review): in Tortoise the `preset` keyword of tts_with_preset
        # selects a quality preset (e.g. "fast"), not a voice, so the voice
        # name must not be passed there. The library default quality is used;
        # confirm against the installed tortoise version.
        gen = tts.tts_with_preset(
            text,
            voice_samples=voice_samples,
            conditioning_latents=conditioning_latents,
        )

        # Output is saved at 24 kHz, as in the original code.
        torchaudio.save(output_path, gen.squeeze(0).cpu(), 24000)
        return output_path, "Success: TTS generation completed."
    except Exception as e:
        return None, f"Error: {str(e)}"
@spaces.GPU
def tts_interface(text, audio_file, preset_voice, record_audio):
    """Gradio callback: choose the voice source and run TTS generation.

    A microphone recording takes precedence over an uploaded file; when
    neither is present the preset voice is used, falling back to
    ``"random"`` if none is selected.

    Args:
        text: Text to synthesise.
        audio_file: Path of an uploaded voice sample, or ``None``.
        preset_voice: Name chosen in the dropdown (``""``/``None`` if unset).
        record_audio: Microphone capture in Gradio's numpy audio format,
            a ``(sample_rate, samples)`` tuple, or ``None``.

    Returns:
        ``(output_path, message)`` exactly as produced by
        ``generate_tts_with_voice``; ``output_path`` is ``None`` on failure.
    """
    print(f"Processing with device: {zero.device}")
    voice_sample_path = None

    if record_audio is not None:
        # Gradio delivers (sample_rate, samples): the rate comes first and
        # must be kept as recorded, otherwise the clip is pitch-shifted.
        sample_rate, samples = record_audio
        waveform = torch.as_tensor(samples)
        if not waveform.is_floating_point():
            # Integer PCM -> float in [-1, 1].
            full_scale = torch.iinfo(waveform.dtype).max
            waveform = waveform.to(torch.float32) / full_scale
        else:
            waveform = waveform.to(torch.float32)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)  # mono -> (1, samples)
        else:
            # (samples, channels) -> (channels, samples) for torchaudio.save
            waveform = waveform.t().contiguous()

        # Reserve a unique name and release the handle before torchaudio
        # writes to that path.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
            voice_sample_path = temp_file.name
        torchaudio.save(voice_sample_path, waveform, int(sample_rate))
    elif audio_file is not None:
        voice_sample_path = audio_file

    # No custom voice and nothing picked in the dropdown: use a random voice.
    if voice_sample_path is None and not preset_voice:
        preset_voice = "random"

    return generate_tts_with_voice(text, voice_sample_path, preset_voice)
# FastAPI endpoints
@app.post("/api/tts_with_voice_file/")
@spaces.GPU
async def tts_with_voice_file(
    text: str = Form(...),
    voice_file: Optional[UploadFile] = File(None),
    preset_voice: Optional[str] = Form("random")
):
    """API endpoint: synthesise ``text`` using an uploaded voice sample.

    Args:
        text: Text to convert to speech (required form field).
        voice_file: Optional audio upload used as the voice reference.
        preset_voice: Voice name used when no file is uploaded.

    Returns:
        A ``FileResponse`` with the generated WAV on success, otherwise a
        ``{"status": "error", "message": ...}`` dictionary.
    """
    try:
        print(f"Processing with device: {zero.device}")
        voice_sample_path = None

        if voice_file:
            # Keep the upload's extension so later format detection works;
            # the filename may be missing, in which case no suffix is used.
            suffix = os.path.splitext(voice_file.filename or "")[1]
            payload = await voice_file.read()
            # The context manager closes the handle even if the write fails.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file.write(payload)
            voice_sample_path = temp_file.name

        output_path, message = generate_tts_with_voice(text, voice_sample_path, preset_voice)

        if output_path:
            return FileResponse(output_path, media_type="audio/wav", filename="tts_output.wav")
        return {"status": "error", "message": message}
    except Exception as e:
        return {"status": "error", "message": f"Failed to process: {str(e)}"}
@app.post("/api/tts_with_preset/")
@spaces.GPU
async def tts_with_preset(
    text: str = Form(...),
    preset_voice: str = Form("random")
):
    """API endpoint: synthesise ``text`` with a preset voice.

    Returns a ``FileResponse`` carrying the generated WAV when generation
    succeeds, and a ``{"status": "error", "message": ...}`` dictionary when
    it fails or an exception is raised.
    """
    try:
        print(f"Processing with device: {zero.device}")
        wav_path, status = generate_tts_with_voice(text, preset_voice=preset_voice)
        # Guard clause: a falsy path means generation reported an error.
        if not wav_path:
            return {"status": "error", "message": status}
        return FileResponse(wav_path, media_type="audio/wav", filename="tts_output.wav")
    except Exception as exc:
        return {"status": "error", "message": f"Failed to process: {str(exc)}"}
# Create Gradio interface
with gr.Blocks(title="Tortoise TTS with Voice Cloning") as demo:
    # Page header.
    gr.Markdown("# Tortoise Text-to-Speech with Voice Cloning")
    gr.Markdown("Enter text and either upload a voice sample, record your voice, or select a preset voice.")

    # Input row: text and preset on the left, voice sources on the right.
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to speak",
                placeholder="Enter the text you want to convert to speech...",
                lines=5
            )
            # The leading empty choice means "no preset selected".
            preset_voice = gr.Dropdown(
                choices=[""] + PRESET_VOICES,
                label="Preset Voice (optional)",
                value=""
            )

        with gr.Column():
            gr.Markdown("### Voice Input Options")
            with gr.Tab("Upload Voice"):
                audio_file = gr.Audio(
                    label="Upload Voice Sample (optional)",
                    type="filepath"
                )
            with gr.Tab("Record Voice"):
                record_audio = gr.Audio(
                    label="Record Your Voice (optional)",
                    source="microphone"
                )

    generate_button = gr.Button("Generate Speech")

    # Output row: generated audio plus a status message.
    with gr.Row():
        output_audio = gr.Audio(label="Generated Speech")
        output_message = gr.Textbox(label="Status")

    # Input order must match the parameter order of tts_interface.
    generate_button.click(
        fn=tts_interface,
        inputs=[text_input, audio_file, preset_voice, record_audio],
        outputs=[output_audio, output_message]
    )

    # In-page documentation of the REST endpoints.
    gr.Markdown("### API Endpoints")
    gr.Markdown("""
This app also provides API endpoints:
1. **Voice File TTS** - `/api/tts_with_voice_file/`
- POST request with:
- `text`: Text to convert to speech (required)
- `voice_file`: Audio file for voice cloning (optional)
- `preset_voice`: Name of preset voice (optional, defaults to "random")
2. **Preset Voice TTS** - `/api/tts_with_preset/`
- POST request with:
- `text`: Text to convert to speech (required)
- `preset_voice`: Name of preset voice (required)
Both endpoints return a WAV file with the generated speech.
""")
# Serve the Gradio UI from the root path of the FastAPI application, so the
# UI and the /api/... endpoints share one server.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )