"""Gradio app exposing a WebSocket endpoint for real-time transcription.

A single ``TranscriptionInterface`` instance accepts WebSocket clients,
forwards binary audio frames to an (optional) inference module, and
broadcasts results to every connected client.  A Gradio UI shows server
status, the rolling conversation history, and integration documentation.
"""

import asyncio
import itertools
import json
import logging
import threading
import time
from typing import Any, Dict

import gradio as gr
import websockets

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Maximum number of transcript entries retained in memory.
MAX_HISTORY_ENTRIES = 100


class TranscriptionInterface:
    """Interface for real-time transcription and speaker diarization."""

    def __init__(self):
        self.connected_clients = set()   # live WebSocket connections
        self.is_running = False          # True while the server is serving
        self.websocket_server = None     # handle returned by websockets.serve
        self.current_transcript = ""     # most recent transcript text
        self.conversation_history = []   # bounded list of transcript dicts
        # BUG FIX: the original derived client ids from int(time.time()),
        # which collides for clients connecting within the same second.
        self._client_counter = itertools.count(1)

    async def handle_client(self, websocket, path=None):
        """Handle one WebSocket client connection until it closes.

        ``path`` is supplied by older ``websockets`` versions; newer
        versions call the handler with the connection only, hence the
        default value (backward-compatible signature change).
        """
        client_id = f"client_{next(self._client_counter)}"
        self.connected_clients.add(websocket)
        logger.info(f"Client connected: {client_id}. Total clients: {len(self.connected_clients)}")

        try:
            # Send connection confirmation
            await websocket.send(json.dumps({
                "type": "connection",
                "status": "connected",
                "timestamp": time.time(),
                "client_id": client_id,
            }))

            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        # Binary frames carry raw audio data.
                        await self.process_audio_chunk(message, websocket)
                    else:
                        # Text frames carry JSON control messages.
                        data = json.loads(message)
                        await self.handle_message(data, websocket)
                except json.JSONDecodeError:
                    logger.warning(f"Invalid JSON received from client: {message}")
                except Exception as e:
                    logger.error(f"Error processing message: {e}")

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Client {client_id} disconnected")
        except Exception as e:
            logger.error(f"Client handler error: {e}")
        finally:
            self.connected_clients.discard(websocket)
            logger.info(f"Client removed. Remaining clients: {len(self.connected_clients)}")

    async def process_audio_chunk(self, audio_data: bytes, websocket):
        """Run one audio chunk through inference and broadcast the result.

        Best-effort: if the project's ``inference`` module is absent the
        chunk is dropped with a warning; processing errors are reported
        back to the sending client only.
        """
        try:
            # Imported lazily so the app still starts without the module.
            from inference import process_audio_for_transcription

            result = await process_audio_for_transcription(audio_data)
            if result:
                # Broadcast result to all connected clients.
                await self.broadcast_result({
                    "type": "processing_result",
                    "timestamp": time.time(),
                    "data": result,
                })
                # Update conversation history.
                if "transcription" in result:
                    self.update_conversation(result)
        except ImportError:
            logger.warning("Inference module not found - audio processing disabled")
        except Exception as e:
            logger.error(f"Error processing audio chunk: {e}")
            await websocket.send(json.dumps({
                "type": "error",
                "message": f"Audio processing error: {str(e)}",
                "timestamp": time.time(),
            }))

    async def handle_message(self, data: Dict[str, Any], websocket):
        """Dispatch a JSON control message from a client.

        Supported ``type`` values: ``config``, ``request_history``,
        ``clear_history``; anything else is logged and ignored.
        """
        message_type = data.get("type", "unknown")

        if message_type == "config":
            # Configuration updates are currently only logged.
            logger.info(f"Configuration update: {data}")
        elif message_type == "request_history":
            # Send conversation history to the requesting client only.
            await websocket.send(json.dumps({
                "type": "conversation_history",
                "data": self.conversation_history,
                "timestamp": time.time(),
            }))
        elif message_type == "clear_history":
            # Clear conversation history and notify everyone.
            self.conversation_history = []
            self.current_transcript = ""
            await self.broadcast_result({
                "type": "conversation_update",
                "action": "cleared",
                "timestamp": time.time(),
            })
        else:
            logger.warning(f"Unknown message type: {message_type}")

    async def broadcast_result(self, result: Dict[str, Any]):
        """Send ``result`` (JSON-encoded) to every connected client.

        Clients whose send fails are pruned from the connection set.
        """
        if not self.connected_clients:
            return

        message = json.dumps(result)
        disconnected = set()

        # Iterate over a copy: handle_client may mutate the set concurrently.
        for client in self.connected_clients.copy():
            try:
                await client.send(message)
            except Exception as e:
                logger.warning(f"Failed to send to client: {e}")
                disconnected.add(client)

        # Clean up disconnected clients.
        for client in disconnected:
            self.connected_clients.discard(client)

    def update_conversation(self, result: Dict[str, Any]):
        """Append a transcription result to the bounded history list."""
        if "transcription" in result:
            transcript_data = {
                "timestamp": time.time(),
                "text": result["transcription"],
                "speaker": result.get("speaker", "Unknown"),
                "confidence": result.get("confidence", 0.0),
            }
            self.conversation_history.append(transcript_data)
            # Trim to the newest entries to bound memory use.
            if len(self.conversation_history) > MAX_HISTORY_ENTRIES:
                self.conversation_history = self.conversation_history[-MAX_HISTORY_ENTRIES:]

    async def start_websocket_server(self, host="0.0.0.0", port=7861):
        """Start the WebSocket server and block until it is closed.

        BUG FIX: ``websockets.serve`` accepts no ``path`` keyword, so the
        original call raised ``TypeError`` and the server never started.
        Path routing is not enforced here; clients connect as documented.

        BUG FIX: the default port is 7861 — Gradio binds 7860 in
        ``__main__``, and the original shared default made one of the two
        servers fail to bind.
        """
        try:
            self.websocket_server = await websockets.serve(self.handle_client, host, port)
            self.is_running = True
            logger.info(f"WebSocket server started on {host}:{port}")
            # Keep serving until the server handle is closed.
            await self.websocket_server.wait_closed()
        except Exception as e:
            logger.error(f"WebSocket server error: {e}")
            self.is_running = False

    def get_status(self):
        """Return a snapshot dict of server state for the UI."""
        return {
            "connected_clients": len(self.connected_clients),
            "is_running": self.is_running,
            "conversation_entries": len(self.conversation_history),
            "last_activity": time.time(),
        }


# Module-level singleton shared by the WebSocket thread and the Gradio UI.
transcription_interface = TranscriptionInterface()


def create_gradio_interface():
    """Build and return the Gradio Blocks UI (does not launch it)."""

    def get_server_status():
        """Format current server status as markdown for display."""
        status = transcription_interface.get_status()
        return f"""
**Server Status:**
- WebSocket Server: {'Running' if status['is_running'] else 'Stopped'}
- Connected Clients: {status['connected_clients']}
- Conversation Entries: {status['conversation_entries']}
- Last Activity: {time.ctime(status['last_activity'])}
"""

    def get_conversation_history():
        """Format the most recent conversation entries as markdown."""
        if not transcription_interface.conversation_history:
            return "No conversation history available."

        formatted_history = []
        for entry in transcription_interface.conversation_history[-10:]:  # Show last 10 entries
            timestamp = time.ctime(entry['timestamp'])
            speaker = entry.get('speaker', 'Unknown')
            text = entry.get('text', '')
            confidence = entry.get('confidence', 0.0)
            formatted_history.append(
                f"**[{timestamp}] {speaker}** (confidence: {confidence:.2f})\n{text}\n"
            )
        return "\n".join(formatted_history)

    def clear_conversation():
        """Clear conversation history and report the action."""
        transcription_interface.conversation_history = []
        transcription_interface.current_transcript = ""
        return "Conversation history cleared."

    # Create Gradio interface
    with gr.Blocks(title="Real-time Audio Transcription & Speaker Diarization") as demo:
        gr.Markdown("# Real-time Audio Transcription & Speaker Diarization")
        gr.Markdown("This Hugging Face Space provides WebSocket endpoints for real-time audio processing.")

        with gr.Tab("Server Status"):
            status_display = gr.Markdown(get_server_status())
            refresh_btn = gr.Button("Refresh Status")
            refresh_btn.click(get_server_status, outputs=status_display)

        with gr.Tab("Live Transcription"):
            gr.Markdown("### Live Conversation")
            conversation_display = gr.Markdown(get_conversation_history())
            with gr.Row():
                refresh_conv_btn = gr.Button("Refresh Conversation")
                clear_conv_btn = gr.Button("Clear History", variant="secondary")
            refresh_conv_btn.click(get_conversation_history, outputs=conversation_display)
            clear_conv_btn.click(clear_conversation, outputs=conversation_display)

        with gr.Tab("WebSocket Info"):
            gr.Markdown("""
### WebSocket Endpoint
Connect to this Space's WebSocket endpoint for real-time audio processing:

**WebSocket URL:** `wss://your-space-name.hf.space/ws_inference`

### Message Format
**Audio Data:** Send raw audio bytes directly to the WebSocket

**Text Messages:** JSON format
```json
{
    "type": "config",
    "settings": {
        "language": "en",
        "enable_diarization": true
    }
}
```

### Response Format
```json
{
    "type": "processing_result",
    "timestamp": 1234567890.123,
    "data": {
        "transcription": "Hello world",
        "speaker": "Speaker_1",
        "confidence": 0.95
    }
}
```
""")

        with gr.Tab("API Documentation"):
            gr.Markdown("""
### Available Endpoints
- **WebSocket:** `/ws_inference` - Main endpoint for real-time audio processing
- **HTTP:** `/health` - Check server health status
- **HTTP:** `/stats` - Get detailed statistics

### Integration Example
```javascript
const ws = new WebSocket('wss://your-space-name.hf.space/ws_inference');

ws.onopen = function() {
    console.log('Connected to transcription service');
};

ws.onmessage = function(event) {
    const data = JSON.parse(event.data);
    if (data.type === 'processing_result') {
        console.log('Transcription:', data.data.transcription);
        console.log('Speaker:', data.data.speaker);
    }
};

// Send audio data
ws.send(audioBuffer);
```
""")

    return demo


def run_websocket_server():
    """Run the WebSocket server on its own event loop (thread target)."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(transcription_interface.start_websocket_server())
    except Exception as e:
        logger.error(f"WebSocket server thread error: {e}")
    finally:
        loop.close()


# Start WebSocket server in a background thread.
# NOTE(review): this runs on import as in the original; daemon=True means
# the thread dies with the main process.
websocket_thread = threading.Thread(target=run_websocket_server, daemon=True)
websocket_thread.start()

# Create and launch Gradio interface.
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )