import os
import tempfile
import gradio as gr
import torch
import torchaudio
import spaces
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_audio, load_voice
import numpy as np
import uvicorn
from typing import Optional
import uuid
from pydub import AudioSegment

# Create output directory if it doesn't exist
os.makedirs("outputs", exist_ok=True)

# Check for CUDA availability (this will show CPU due to Zero-GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Initial device check: {device}")

# Create a tensor to verify Zero-GPU is working
zero = torch.Tensor([0])
if torch.cuda.is_available():
    zero = zero.cuda()
    print(f"Zero tensor device: {zero.device}")

# Initialize FastAPI
app = FastAPI(title="Tortoise TTS API")

# Initialize TTS (will be loaded on demand with Zero-GPU)
tts = None

# Available preset voice options
PRESET_VOICES = ["random", "angie", "daniel", "deniro", "emma", "freeman", 
                "geralt", "halle", "jlaw", "lj", "mol", "myself", "pat", 
                "snakes", "tim_reynolds", "tom", "train_atkins", "train_daws", 
                "train_dotrice", "train_dreams", "train_empire", "train_grace", 
                "train_kennard", "train_lescault", "train_mouse", "weaver", "william"]

def process_audio_file(audio_file_path):
    """Process uploaded audio file to ensure it meets Tortoise requirements"""
    # Load audio file
    audio = AudioSegment.from_file(audio_file_path)
    
    # Convert to WAV format if it's not already
    if not audio_file_path.lower().endswith('.wav'):
        temp_wav = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
        audio.export(temp_wav.name, format="wav")
        audio_file_path = temp_wav.name
    
    # Resample to 22.05 kHz, which is what Tortoise expects
    y, sr = torchaudio.load(audio_file_path)
    if sr != 22050:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=22050)
        y = resampler(y)
        temp_file = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
        torchaudio.save(temp_file.name, y, 22050)
        audio_file_path = temp_file.name
    
    return audio_file_path

@spaces.GPU
def generate_tts_with_voice(text, voice_sample_path=None, preset_voice=None):
    """Generate TTS audio using Tortoise with either a custom voice or preset"""
    global tts
    
    try:
        # Now that we're inside the @spaces.GPU decorated function, CUDA should be available
        print(f"GPU function device: {zero.device}")
        
        # Initialize TTS model if not already initialized
        if tts is None:
            tts = TextToSpeech(use_deepspeed=torch.cuda.is_available())
            print("TTS model initialized")
        
        voice_samples = None
        conditioning_latents = None

        if voice_sample_path:
            # Clone the voice from the uploaded/recorded sample
            voice_sample_path = process_audio_file(voice_sample_path)
            voice_samples = [load_audio(voice_sample_path, 22050)]
        elif preset_voice and preset_voice != "random":
            # Load one of Tortoise's built-in voices by name
            voice_samples, conditioning_latents = load_voice(preset_voice)
        # Otherwise (random voice): leave both as None so Tortoise picks a random voice

        # Generate the speech
        output_id = str(uuid.uuid4())[:8]
        output_path = f"outputs/tts_output_{output_id}.wav"

        gen = tts.tts_with_preset(
            text,
            voice_samples=voice_samples,
            conditioning_latents=conditioning_latents,
            preset="fast"  # quality preset, not a voice name
        )
        
        # Save the generated audio
        torchaudio.save(output_path, gen.squeeze(0).cpu(), 24000)
        
        return output_path, "Success: TTS generation completed."
    except Exception as e:
        return None, f"Error: {str(e)}"

@spaces.GPU
def tts_interface(text, audio_file, preset_voice, record_audio):
    """Interface function for Gradio with GPU acceleration"""
    print(f"Processing with device: {zero.device}")
    
    voice_sample_path = None
    
    # Determine which voice input to use
    if record_audio is not None:
        # Use recorded audio: Gradio returns a (sample_rate, numpy_array) tuple
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
        temp_file.close()
        sr, data = record_audio
        waveform = torch.from_numpy(np.asarray(data, dtype=np.float32))
        if waveform.abs().max() > 1.0:  # browser recordings are typically int16 PCM
            waveform = waveform / 32768.0
        # (samples,) -> (1, samples); (samples, channels) -> (channels, samples)
        waveform = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform.T
        torchaudio.save(temp_file.name, waveform, sr)
        voice_sample_path = temp_file.name
    elif audio_file is not None:
        # Use uploaded audio file
        voice_sample_path = audio_file
    
    # If neither a custom voice nor a preset is selected, fall back to the random voice
    if voice_sample_path is None and preset_voice == "":
        preset_voice = "random"
    
    # Generate TTS
    output_path, message = generate_tts_with_voice(text, voice_sample_path, preset_voice)
    
    if output_path:
        return output_path, message
    else:
        return None, message

# FastAPI endpoints
@app.post("/api/tts_with_voice_file/")
@spaces.GPU
async def tts_with_voice_file(
    text: str = Form(...),
    voice_file: Optional[UploadFile] = File(None),
    preset_voice: Optional[str] = Form("random")
):
    """API endpoint for TTS with an uploaded voice file"""
    try:
        print(f"Processing with device: {zero.device}")
        
        voice_sample_path = None
        if voice_file:
            # Save uploaded file temporarily
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(voice_file.filename)[1])
            temp_file.write(await voice_file.read())
            temp_file.close()
            voice_sample_path = temp_file.name
        
        output_path, message = generate_tts_with_voice(text, voice_sample_path, preset_voice)
        
        if output_path:
            return FileResponse(output_path, media_type="audio/wav", filename="tts_output.wav")
        else:
            return {"status": "error", "message": message}
    except Exception as e:
        return {"status": "error", "message": f"Failed to process: {str(e)}"}

@app.post("/api/tts_with_preset/")
@spaces.GPU
async def tts_with_preset(
    text: str = Form(...),
    preset_voice: str = Form("random")
):
    """API endpoint for TTS with a preset voice"""
    try:
        print(f"Processing with device: {zero.device}")
        
        output_path, message = generate_tts_with_voice(text, preset_voice=preset_voice)
        
        if output_path:
            return FileResponse(output_path, media_type="audio/wav", filename="tts_output.wav")
        else:
            return {"status": "error", "message": message}
    except Exception as e:
        return {"status": "error", "message": f"Failed to process: {str(e)}"}

# Create Gradio interface
with gr.Blocks(title="Tortoise TTS with Voice Cloning") as demo:
    gr.Markdown("# Tortoise Text-to-Speech with Voice Cloning")
    gr.Markdown("Enter text and either upload a voice sample, record your voice, or select a preset voice.")
    
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to speak",
                placeholder="Enter the text you want to convert to speech...",
                lines=5
            )
            preset_voice = gr.Dropdown(
                choices=[""] + PRESET_VOICES,
                label="Preset Voice (optional)",
                value=""
            )
            
        with gr.Column():
            gr.Markdown("### Voice Input Options")
            with gr.Tab("Upload Voice"):
                audio_file = gr.Audio(
                    label="Upload Voice Sample (optional)",
                    type="filepath"
                )
            with gr.Tab("Record Voice"):
                record_audio = gr.Audio(
                    label="Record Your Voice (optional)",
                    source="microphone"
                )
    
    generate_button = gr.Button("Generate Speech")
    
    with gr.Row():
        output_audio = gr.Audio(label="Generated Speech")
        output_message = gr.Textbox(label="Status")
    
    generate_button.click(
        fn=tts_interface,
        inputs=[text_input, audio_file, preset_voice, record_audio],
        outputs=[output_audio, output_message]
    )
    
    gr.Markdown("### API Endpoints")
    gr.Markdown("""
    This app also provides API endpoints:
    
    1. **Voice File TTS** - `/api/tts_with_voice_file/`
       - POST request with:
         - `text`: Text to convert to speech (required)
         - `voice_file`: Audio file for voice cloning (optional)
         - `preset_voice`: Name of preset voice (optional, defaults to "random")
    
    2. **Preset Voice TTS** - `/api/tts_with_preset/`
       - POST request with:
         - `text`: Text to convert to speech (required)
         - `preset_voice`: Name of preset voice (required)
    
    Both endpoints return a WAV file with the generated speech.
    """)

# Mount the Gradio app to FastAPI
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
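
# Example API usage (a rough sketch; assumes the app is reachable at http://localhost:7860
# and that a local sample named my_voice.wav exists -- adjust both to your deployment):
#
#   # Preset voice
#   curl -X POST http://localhost:7860/api/tts_with_preset/ \
#        -F "text=Hello from Tortoise" \
#        -F "preset_voice=angie" \
#        -o tts_output.wav
#
#   # Voice cloning from an uploaded sample
#   curl -X POST http://localhost:7860/api/tts_with_voice_file/ \
#        -F "text=Hello from Tortoise" \
#        -F "voice_file=@my_voice.wav" \
#        -o tts_output.wav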