Tomtom84 commited on
Commit
4189fe1
·
verified ·
1 Parent(s): 66f2c2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -3
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import spaces
2
  from snac import SNAC
3
  import torch
 
4
  import gradio as gr
5
  import os
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -235,7 +236,7 @@ with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
235
  inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
236
  outputs=audio_output,
237
  fn=generate_speech,
238
- cache_examples=True,
239
  )
240
 
241
  # Set up event handlers
@@ -251,6 +252,34 @@ with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
251
  outputs=[text_input, audio_output]
252
  )
253
 
254
- # Launch the app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  if __name__ == "__main__":
256
- demo.queue().launch(share=False, ssr_mode=False)
 
1
  import spaces
2
  from snac import SNAC
3
  import torch
4
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
5
  import gradio as gr
6
  import os
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
236
  inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
237
  outputs=audio_output,
238
  fn=generate_speech,
239
+ cache_examples=False,
240
  )
241
 
242
  # Set up event handlers
 
252
  outputs=[text_input, audio_output]
253
  )
254
 
255
+ # Create FastAPI app and mount Gradio
256
+ app = FastAPI()
257
+ app.mount("/", demo)
258
+
259
+ # WebSocket TTS endpoint\@app.websocket("/ws/tts")
260
+ async def websocket_tts(websocket: WebSocket):
261
+ await websocket.accept()
262
+ try:
263
+ while True:
264
+ msg = await websocket.receive_text()
265
+ data = json.loads(msg)
266
+ text = data.get("text", "")
267
+ voice = data.get("voice", VOICES[0])
268
+ # Generate audio for the chunk
269
+ _, audio = generate_speech(text, voice, 0.7, 0.95, 1.1, 1200)
270
+ # Stream audio in 0.1s chunks
271
+ chunk_size = 2400 # 24000 Hz -> 2400 samples = 0.1s
272
+ for i in range(0, len(audio), chunk_size):
273
+ chunk = audio[i:i+chunk_size]
274
+ await websocket.send_bytes(chunk.tobytes())
275
+ await websocket.send_text("__END__")
276
+ except WebSocketDisconnect:
277
+ print("Client disconnected from /ws/tts")
278
+
279
+ # Launch if run directly
280
+ def main():
281
+ import uvicorn
282
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)
283
+
284
  if __name__ == "__main__":
285
+ main()