Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import spaces
|
2 |
from snac import SNAC
|
3 |
import torch
|
|
|
4 |
import gradio as gr
|
5 |
import os
|
6 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
@@ -235,7 +236,7 @@ with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
|
|
235 |
inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
|
236 |
outputs=audio_output,
|
237 |
fn=generate_speech,
|
238 |
-
cache_examples=
|
239 |
)
|
240 |
|
241 |
# Set up event handlers
|
@@ -251,6 +252,34 @@ with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
|
|
251 |
outputs=[text_input, audio_output]
|
252 |
)
|
253 |
|
254 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
if __name__ == "__main__":
|
256 |
-
|
|
|
1 |
import spaces
|
2 |
from snac import SNAC
|
3 |
import torch
|
4 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
5 |
import gradio as gr
|
6 |
import os
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
236 |
inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
|
237 |
outputs=audio_output,
|
238 |
fn=generate_speech,
|
239 |
+
cache_examples=False,
|
240 |
)
|
241 |
|
242 |
# Set up event handlers
|
|
|
252 |
outputs=[text_input, audio_output]
|
253 |
)
|
254 |
|
255 |
+
# Create FastAPI app and mount Gradio
|
256 |
+
app = FastAPI()
|
257 |
+
app.mount("/", demo)
|
258 |
+
|
259 |
+
# WebSocket TTS endpoint\@app.websocket("/ws/tts")
|
260 |
+
async def websocket_tts(websocket: WebSocket):
|
261 |
+
await websocket.accept()
|
262 |
+
try:
|
263 |
+
while True:
|
264 |
+
msg = await websocket.receive_text()
|
265 |
+
data = json.loads(msg)
|
266 |
+
text = data.get("text", "")
|
267 |
+
voice = data.get("voice", VOICES[0])
|
268 |
+
# Generate audio for the chunk
|
269 |
+
_, audio = generate_speech(text, voice, 0.7, 0.95, 1.1, 1200)
|
270 |
+
# Stream audio in 0.1s chunks
|
271 |
+
chunk_size = 2400 # 24000 Hz -> 2400 samples = 0.1s
|
272 |
+
for i in range(0, len(audio), chunk_size):
|
273 |
+
chunk = audio[i:i+chunk_size]
|
274 |
+
await websocket.send_bytes(chunk.tobytes())
|
275 |
+
await websocket.send_text("__END__")
|
276 |
+
except WebSocketDisconnect:
|
277 |
+
print("Client disconnected from /ws/tts")
|
278 |
+
|
279 |
+
# Launch if run directly
|
280 |
+
def main():
|
281 |
+
import uvicorn
|
282 |
+
uvicorn.run("app:app", host="0.0.0.0", port=7860)
|
283 |
+
|
284 |
if __name__ == "__main__":
|
285 |
+
main()
|