Spaces:
Running
Running
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from shared import RealtimeSpeakerDiarization | |
import numpy as np | |
import uvicorn | |
import logging | |
import asyncio | |
import json | |
import time | |
from typing import Set, Dict, Any | |
import traceback | |
# RealtimeSTT is required at runtime: import it if it is already installed,
# otherwise pip-install it into the running interpreter and import again.
# NOTE(review): installing an unpinned package at import time needs network
# access and executes third-party code; a requirements file is usually safer.
try:
    from RealtimeSTT import AudioToTextRecorder
except ImportError:
    import subprocess
    import sys

    print("Installing RealtimeSTT dependency...")
    _pip_install_cmd = [sys.executable, "-m", "pip", "install", "RealtimeSTT"]
    subprocess.check_call(_pip_install_cmd)
    from RealtimeSTT import AudioToTextRecorder
# Configure process-wide logging (INFO and above, timestamped lines) and
# create this module's logger.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# The ASGI application; title and version are published in the OpenAPI schema.
app = FastAPI(
    version="1.0.0",
    title="Real-time Speaker Diarization API",
)

# Fully permissive CORS so browser clients served from any origin can call
# the HTTP endpoints.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended outside of a demo deployment.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Shared module state
# ---------------------------------------------------------------------------
# The RealtimeSpeakerDiarization instance; None until
# initialize_diarization_system() assigns it.
diart = None
# NOTE(review): not referenced by the code visible here — ConnectionManager
# keeps its own `active_connections` set. Kept for external importers.
active_connections: Set[WebSocket] = set()
# Process-wide counters, mutated by ConnectionManager and ws_inference() and
# returned by health_check(), get_status() and get_connection_stats().
connection_stats: Dict[str, Any] = dict(
    total_connections=0,
    current_connections=0,
    last_audio_received=None,
    total_audio_chunks=0,
)
class ConnectionManager:
    """Keep track of live WebSocket clients and fan messages out to them.

    Besides bookkeeping, the manager ties the diarization recorder to client
    presence: recording is started when the first client connects (if the
    system exists and is not already running) and stopped once the last
    client leaves. The module-level ``connection_stats`` counters are kept
    in sync on every connect/disconnect.
    """

    def __init__(self):
        # Sockets that are currently connected.
        self.active_connections: Set[WebSocket] = set()
        # Per-socket info: client_id, connected_at (epoch seconds), messages_sent.
        self.connection_metadata: Dict[WebSocket, Dict] = {}

    async def connect(self, websocket: WebSocket, client_id: str = None):
        """Accept *websocket*, register it, and start recording for the first client.

        Args:
            websocket: The incoming connection; it is accepted here.
            client_id: Label used in logs/metadata; a timestamp-based id is
                generated when omitted.
        """
        await websocket.accept()
        self.active_connections.add(websocket)
        meta = {
            "client_id": client_id or f"client_{int(time.time())}",
            "connected_at": time.time(),
            "messages_sent": 0,
        }
        self.connection_metadata[websocket] = meta

        connection_stats["current_connections"] = len(self.active_connections)
        connection_stats["total_connections"] += 1

        # Start the recorder lazily — only once somebody is listening.
        is_first_client = len(self.active_connections) == 1
        if is_first_client and diart and not diart.is_running:
            logger.info("First connection established, starting recording")
            diart.start_recording()

        logger.info(f"WebSocket connected: {meta['client_id']}. "
                    f"Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Unregister *websocket*; stop recording when no clients remain.

        Sockets that are not registered are ignored, so calling this twice
        for the same socket is harmless.
        """
        if websocket not in self.active_connections:
            return

        client_id = self.connection_metadata.get(websocket, {}).get("client_id", "unknown")
        self.active_connections.discard(websocket)
        self.connection_metadata.pop(websocket, None)
        connection_stats["current_connections"] = len(self.active_connections)

        # Release recorder resources while nobody is listening.
        if not self.active_connections and diart and diart.is_running:
            logger.info("No active connections, stopping recording")
            diart.stop_recording()

        logger.info(f"WebSocket disconnected: {client_id}. "
                    f"Remaining connections: {len(self.active_connections)}")

    async def broadcast(self, message: str):
        """Send *message* to every client; clients whose send fails are disconnected."""
        if not self.active_connections:
            return

        failed = set()
        # Iterate over a snapshot: each await yields to the event loop, during
        # which the live set may change.
        for ws in list(self.active_connections):
            try:
                await ws.send_text(message)
                meta = self.connection_metadata.get(ws)
                if meta is not None:
                    meta["messages_sent"] += 1
            except Exception as e:
                logger.warning(f"Failed to send message to client: {e}")
                failed.add(ws)

        for ws in failed:
            self.disconnect(ws)

    def get_stats(self):
        """Return the live connection count and per-client metadata keyed by position (0, 1, ...)."""
        return {
            "active_connections": len(self.active_connections),
            "connection_metadata": dict(enumerate(self.connection_metadata.values())),
        }


# Single connection manager shared by every endpoint in this module.
manager = ConnectionManager()
async def initialize_diarization_system():
    """Create the diarization engine and load its models.

    The global ``diart`` is only assigned once ``initialize_models()``
    reports success. On any failure it is reset to ``None``, so the rest of
    the module (which branches on ``if diart`` / ``if not diart``) treats the
    system as not initialized instead of driving a half-built instance.
    Recording is deliberately not started here; it is started when the first
    WebSocket client connects.

    Returns:
        True if the models were initialized, False otherwise. Errors are
        logged, never raised.
    """
    global diart
    try:
        logger.info("Initializing diarization system...")
        candidate = RealtimeSpeakerDiarization()
        success = candidate.initialize_models()
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Error initializing diarization system: {e}")
        diart = None
        return False

    if not success:
        logger.error("Failed to initialize models")
        diart = None
        return False

    diart = candidate
    logger.info("Models initialized successfully")
    logger.info("System ready for connections")
    return True
async def send_conversation_updates():
    """Background loop that pushes the conversation to clients when it changes.

    Roughly every 0.5 s, if the recorder is running and at least one client
    is connected, the formatted conversation is fetched. A
    ``conversation_update`` JSON message is broadcast only when its hash
    differs from the last one sent. Errors are logged and the loop carries on.
    """
    poll_seconds = 0.5  # 500ms update intervals
    previous_digest = None

    while True:
        try:
            should_push = diart and diart.is_running and manager.active_connections
            if should_push:
                html = diart.get_formatted_conversation()
                digest = hash(html)
                # Unchanged conversation: skip to save bandwidth.
                if digest != previous_digest:
                    payload = {
                        "type": "conversation_update",
                        "timestamp": time.time(),
                        "conversation_html": html,
                        "status": diart.get_status_info() if hasattr(diart, 'get_status_info') else {},
                    }
                    await manager.broadcast(json.dumps(payload))
                    previous_digest = digest
        except Exception as e:
            logger.error(f"Error in conversation update: {e}")
        await asyncio.sleep(poll_seconds)
# Strong references to fire-and-forget background tasks. The event loop keeps
# only weak references to tasks, so an unreferenced task may be garbage
# collected before it finishes.
_background_tasks = set()


@app.on_event("startup")
async def startup_event():
    """Initialize the diarization system and start the conversation broadcaster.

    Registered as the application's startup handler. A failed initialization
    is logged but does not abort startup; the update loop is started either
    way.
    """
    logger.info("Starting Real-time Speaker Diarization Service")

    success = await initialize_diarization_system()
    if not success:
        logger.error("Failed to initialize diarization system!")

    task = asyncio.create_task(send_conversation_updates())
    _background_tasks.add(task)
    # Drop our reference once the task ends so the set does not grow.
    task.add_done_callback(_background_tasks.discard)
    logger.info("Background tasks started")
@app.on_event("shutdown")
async def shutdown_event():
    """Stop recording and shut down the transcription recorder on app shutdown.

    Registered as the application's shutdown handler. The two cleanup steps
    are attempted independently: a failure to stop recording no longer
    prevents the recorder from being shut down. Errors are logged, not raised.
    """
    logger.info("Shutting down...")
    if not diart:
        return

    try:
        diart.stop_recording()
        logger.info("Recording stopped")
    except Exception as e:
        logger.error(f"Error stopping recording: {e}")

    # The diarization object may or may not expose a RealtimeSTT recorder.
    recorder = getattr(diart, "recorder", None)
    if recorder:
        try:
            recorder.shutdown()
            logger.info("Transcription model shut down")
        except Exception as e:
            logger.error(f"Error shutting down transcription model: {e}")
@app.get("/")
async def root():
    """Describe the service: name, version, coarse status and endpoint paths.

    NOTE(review): "status" is "running" only while the recorder is running.
    Recording is started on the first WebSocket connection and stopped after
    the last one, so an idle service reports "initializing".
    """
    return {
        "service": "Real-time Speaker Diarization API",
        "version": "1.0.0",
        "status": "running" if diart and diart.is_running else "initializing",
        "endpoints": {
            "websocket": "/ws_inference",
            "health": "/health",
            "conversation": "/conversation",
            "status": "/status"
        }
    }
@app.get("/health")
async def health_check():
    """Report service health together with connection and diarization statistics.

    "healthy" means the diarization recorder is currently running.
    NOTE(review): recording only runs while WebSocket clients are connected,
    so an initialized but idle service reports "unhealthy".
    """
    # bool(): `diart and ...` would otherwise yield None (or another non-bool)
    # for "system_running" when the system is not initialized.
    system_running = bool(diart and diart.is_running)
    return {
        "status": "healthy" if system_running else "unhealthy",
        "system_running": system_running,
        "active_connections": len(manager.active_connections),
        "connection_stats": connection_stats,
        "diarization_status": diart.get_status_info() if diart and hasattr(diart, 'get_status_info') else {}
    }
def _to_jsonable(value):
    """Recursively convert NumPy scalars/arrays inside *value* to built-in Python types.

    Dict values, list items and tuple items are converted (tuples become
    lists, which is how json.dumps emits them anyway). Dict keys and all
    other objects are returned unchanged.
    """
    if isinstance(value, np.generic):  # covers np.number and np.bool_
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value


@app.websocket("/ws_inference")
async def ws_inference(websocket: WebSocket):
    """WebSocket endpoint for real-time audio processing.

    Protocol, as implemented below:
      * on connect the server sends a ``connection_established`` JSON message
        with the client id, system status and current conversation;
      * each binary frame is passed to ``diart.process_audio_chunk`` while the
        recorder is running, and a truthy result is sent back as a
        ``processing_result`` message;
      * per-chunk failures are reported as ``error`` messages and the socket
        stays open.
    The socket is always unregistered from ``manager`` when the handler exits.
    """
    client_id = f"client_{int(time.time())}"
    try:
        await manager.connect(websocket, client_id)

        # Send initial connection confirmation
        initial_message = json.dumps({
            "type": "connection_established",
            "client_id": client_id,
            "system_status": "ready" if diart and diart.is_running else "initializing",
            "conversation": diart.get_formatted_conversation() if diart else ""
        })
        await websocket.send_text(initial_message)

        # Process incoming audio data
        async for data in websocket.iter_bytes():
            try:
                if data and diart and diart.is_running:
                    connection_stats["last_audio_received"] = time.time()
                    connection_stats["total_audio_chunks"] += 1

                    result = diart.process_audio_chunk(data)
                    if result:
                        # json.dumps cannot encode NumPy types; convert them
                        # everywhere in the result, not just top-level numbers.
                        result_message = json.dumps({
                            "type": "processing_result",
                            "timestamp": time.time(),
                            "data": _to_jsonable(result)
                        })
                        await websocket.send_text(result_message)

                    if connection_stats["total_audio_chunks"] % 100 == 0:  # Log every 100 chunks
                        logger.debug(f"Processed {connection_stats['total_audio_chunks']} audio chunks")
                elif not diart:
                    logger.warning("Received audio data but diarization system not initialized")
                    error_message = json.dumps({
                        "type": "error",
                        "message": "Diarization system not initialized",
                        "timestamp": time.time()
                    })
                    await websocket.send_text(error_message)
            except Exception as e:
                logger.error(f"Error processing audio chunk: {e}")
                # Report the failure but keep the connection open.
                error_message = json.dumps({
                    "type": "error",
                    "message": "Error processing audio",
                    "details": str(e),
                    "timestamp": time.time()
                })
                await websocket.send_text(error_message)
    except WebSocketDisconnect:
        logger.info(f"WebSocket {client_id} disconnected normally")
    except Exception as e:
        logger.error(f"WebSocket {client_id} error: {e}")
    finally:
        manager.disconnect(websocket)
@app.get("/conversation")
async def get_conversation():
    """Return the current conversation as HTML plus a timestamp and system status.

    Raises:
        HTTPException: 503 if the diarization system is not initialized,
            500 if the conversation cannot be retrieved.
    """
    if not diart:
        raise HTTPException(status_code=503, detail="Diarization system not initialized")
    try:
        return {
            "conversation": diart.get_formatted_conversation(),
            "timestamp": time.time(),
            "system_status": diart.get_status_info() if hasattr(diart, 'get_status_info') else {}
        }
    except Exception as e:
        logger.exception(f"Error getting conversation: {e}")
        raise HTTPException(status_code=500, detail="Error retrieving conversation") from e
@app.get("/status")
async def get_status():
    """Return diarization status merged with connection statistics and uptime.

    Never raises. Returns ``{"status": "system_not_initialized"}`` when the
    system is missing and ``{"status": "error", "message": ...}`` on
    unexpected errors.
    """
    if not diart:
        return {"status": "system_not_initialized"}
    try:
        base_status = diart.get_status_info() if hasattr(diart, 'get_status_info') else {}
        # Read the clock once: the previous form evaluated time.time() twice,
        # giving a slightly negative uptime when no start time was recorded.
        now = time.time()
        return {
            **base_status,
            "connection_stats": connection_stats,
            "active_connections": len(manager.active_connections),
            "system_uptime": now - connection_stats.get("system_start_time", now)
        }
    except Exception as e:
        logger.exception(f"Error getting status: {e}")
        return {"status": "error", "message": str(e)}
async def update_settings(threshold: float = None, max_speakers: int = None):
    """Validate and apply new speaker-detection settings.

    Args:
        threshold: Detection threshold, must lie in [0, 1]. None is passed
            through to ``diart.update_settings`` unchanged.
        max_speakers: Maximum number of speakers, must lie in [1, 20]. None is
            passed through unchanged.

    Returns:
        The value returned by ``diart.update_settings`` and the requested values.

    Raises:
        HTTPException: 503 if the system is not initialized, 400 for
            out-of-range (or NaN) values, 500 if applying the settings fails.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm
    # how (and under which path) it is registered.
    if not diart:
        raise HTTPException(status_code=503, detail="Diarization system not initialized")

    # Validate BEFORE the try block: its broad `except Exception` used to catch
    # these HTTPExceptions and turn every 400 into a 500.
    # `not lo <= x <= hi` also rejects NaN, which compares False to everything.
    if threshold is not None and not 0 <= threshold <= 1:
        raise HTTPException(status_code=400, detail="Threshold must be between 0 and 1")
    if max_speakers is not None and not 1 <= max_speakers <= 20:
        raise HTTPException(status_code=400, detail="Max speakers must be between 1 and 20")

    try:
        result = diart.update_settings(threshold, max_speakers)
    except Exception as e:
        logger.exception(f"Error updating settings: {e}")
        raise HTTPException(status_code=500, detail="Error updating settings") from e

    return {
        "result": result,
        "updated_settings": {
            "threshold": threshold,
            "max_speakers": max_speakers
        }
    }
async def clear_conversation():
    """Wipe the conversation history and notify every connected client.

    Returns:
        The value returned by ``diart.clear_conversation`` and a confirmation message.

    Raises:
        HTTPException: 503 if the diarization system is not initialized,
            500 if clearing or the broadcast fails.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm
    # how (and under which path) it is registered.
    if not diart:
        raise HTTPException(status_code=503, detail="Diarization system not initialized")

    try:
        outcome = diart.clear_conversation()
        notice = {"type": "conversation_cleared", "timestamp": time.time()}
        await manager.broadcast(json.dumps(notice))
    except Exception as e:
        logger.error(f"Error clearing conversation: {e}")
        raise HTTPException(status_code=500, detail="Error clearing conversation")

    return {"result": outcome, "message": "Conversation cleared successfully"}
async def get_connection_stats():
    """Return global counters, per-client metadata and a short system summary."""
    # NOTE(review): no route decorator is visible for this handler; confirm
    # how (and under which path) it is registered.
    summary = {
        "connection_stats": connection_stats,
        "manager_stats": manager.get_stats(),
    }
    summary["system_info"] = {
        "diarization_running": diart.is_running if diart else False,
        "total_active_connections": len(manager.active_connections),
    }
    return summary
# Mount the optional Gradio UI provided by a local `ui` module, if present.
try:
    import ui

    ui.mount_ui(app)
    logger.info("Gradio UI mounted successfully")
except ImportError as e:
    if e.name == "ui":
        # The module itself is absent: run as a plain API service.
        logger.warning("UI module not found, running in API-only mode")
    else:
        # `ui` exists but one of its own imports failed; do not mislabel that
        # as "module not found" — log the real cause with its traceback.
        logger.exception(f"Error mounting UI: {e}")
except Exception as e:
    logger.exception(f"Error mounting UI: {e}")
# Stamp the process start time; get_status() derives "system_uptime" from it.
connection_stats.update(system_start_time=time.time())
if __name__ == "__main__":
    # Launch the ASGI server when executed as a script: all interfaces, port 7860.
    server_settings = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
        "access_log": True,
    }
    uvicorn.run(app, **server_settings)