Spaces:
Sleeping
Sleeping
import gradio as gr | |
import asyncio | |
import websockets | |
import json | |
import logging | |
import time | |
from typing import Dict, Any, Optional | |
import threading | |
from queue import Queue | |
import base64 | |
import numpy as np | |
import os | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def http_to_ws_url(url: str, path: str = "") -> str:
    """Convert an http(s) base URL (or a bare host) into a ws(s) URL ending in *path*.

    ``https://`` becomes ``wss://`` and ``http://`` becomes ``ws://``; URLs that
    already use a ws scheme are kept; a value with no scheme is treated as a
    bare host and gets ``wss://``. A trailing ``/`` on *url* is dropped so that
    *path* (which should start with ``/``) is appended cleanly.

    >>> http_to_ws_url("https://example.hf.space", "/ws")
    'wss://example.hf.space/ws'
    """
    base = url.strip().rstrip("/")
    lowered = base.lower()
    for http_scheme, ws_scheme in (("https://", "wss://"), ("http://", "ws://")):
        if lowered.startswith(http_scheme):
            return ws_scheme + base[len(http_scheme):] + path
    if lowered.startswith(("wss://", "ws://")):
        return base + path
    return "wss://" + base + path


# Environment-configurable HF Space URL (matching backend.py)
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://androidguy-speaker-diarization.hf.space")
# WebSocket endpoint derived from HF_SPACE_URL. The default value already
# carries an https:// scheme, so it must be converted, not prefixed with wss://.
API_WS = http_to_ws_url(HF_SPACE_URL, "/ws_inference")
class TranscriptionWebSocketServer:
    """WebSocket server that receives audio from the backend and returns transcription results.

    Binary frames are treated as audio chunks and answered with a
    ``processing_result`` message. Text frames are parsed as JSON control
    messages (``ping``, ``config``, ``status_request``). Failures are reported
    to the client as ``{"type": "error", ...}`` messages.
    """

    # URL path the backend is expected to connect to.
    WS_PATH = "/ws_inference"
    # Maximum number of entries retained in ``conversation_history``.
    MAX_HISTORY = 50

    def __init__(self):
        self.connected_clients = set()
        self.is_running = False
        self.websocket_server = None
        self.conversation_history = []
        # Lazily created ``shared.RealtimeSpeakerDiarization`` fallback engine.
        self.diarization_system = None
        self.processing_stats = {
            "total_audio_chunks": 0,
            "total_transcriptions": 0,
            "last_audio_received": None,
            "server_start_time": time.time(),
            "backend_url": HF_SPACE_URL,
        }

    @staticmethod
    def _request_path(websocket, path: Optional[str] = None) -> Optional[str]:
        """Return the HTTP request path of a connection, or None if it cannot be found.

        Old websockets versions pass ``path`` to the handler. The legacy
        connection object exposes ``websocket.path``, and the newer asyncio
        implementation exposes ``websocket.request.path``. All three are tried.
        """
        if path is not None:
            return path
        legacy_path = getattr(websocket, "path", None)
        if legacy_path is not None:
            return legacy_path
        request = getattr(websocket, "request", None)
        return getattr(request, "path", None)

    async def handle_client_connection(self, websocket, path=None):
        """Serve one backend connection until it closes.

        Args:
            websocket: Connection object supplied by ``websockets``.
            path: Request path. Only very old websockets versions pass it; it is
                optional so newer versions can call the handler with a single
                argument.
        """
        request_path = self._request_path(websocket, path)
        # websockets.serve has no URL-route filter, so enforce the endpoint here.
        if request_path is not None and request_path.split("?", 1)[0] != self.WS_PATH:
            logger.warning(f"Rejecting WebSocket connection on unexpected path: {request_path}")
            await websocket.close(code=1008, reason="Unsupported path")
            return

        client_addr = websocket.remote_address
        logger.info(f"Backend client connected from {client_addr}")
        self.connected_clients.add(websocket)
        try:
            # Send initial connection acknowledgment
            await websocket.send(json.dumps({
                "type": "connection_ack",
                "status": "connected",
                "timestamp": time.time(),
                "message": "HuggingFace transcription service ready",
            }))
            # Binary frames are audio; text frames are JSON control messages.
            async for message in websocket:
                try:
                    if isinstance(message, bytes):
                        await self.process_audio_data(message, websocket)
                    else:
                        await self.handle_text_message(message, websocket)
                except websockets.exceptions.ConnectionClosed:
                    # Let the outer handler treat this as a normal disconnect.
                    raise
                except Exception as e:
                    logger.exception(f"Error processing message: {e}")
                    await self.send_error(websocket, f"Processing error: {str(e)}")
        except websockets.exceptions.ConnectionClosed:
            logger.info("Backend client disconnected")
        except Exception as e:
            logger.exception(f"Client connection error: {e}")
        finally:
            self.connected_clients.discard(websocket)
            logger.info(f"Client removed. Active connections: {len(self.connected_clients)}")

    async def process_audio_data(self, audio_data: bytes, websocket):
        """Transcribe one audio chunk and send the result back on ``websocket``.

        Backends are tried in order:

        1. ``inference.transcribe_audio``, plus ``identify_speakers`` for the
           speaker. Nothing is sent if it returns a falsy result.
        2. ``shared.RealtimeSpeakerDiarization``, when ``inference`` cannot be
           imported.
        3. A placeholder "mock" result, if the ``shared`` path fails.

        Any other failure is reported to the client via :meth:`send_error`.
        """
        try:
            self.processing_stats["total_audio_chunks"] += 1
            self.processing_stats["last_audio_received"] = time.time()
            logger.debug(f"Received {len(audio_data)} bytes of audio data")

            # Only the import decides whether to fall back. An ImportError raised
            # while transcribing is a real failure and must not be masked.
            try:
                from inference import transcribe_audio, identify_speakers
            except ImportError:
                transcribe_audio = identify_speakers = None

            if transcribe_audio is not None:
                result = await transcribe_audio(audio_data)
                if not result:
                    return
                try:
                    speaker_info = await identify_speakers(audio_data)
                    result.update(speaker_info)
                except Exception as e:
                    logger.warning(f"Speaker diarization failed: {e}")
                    result["speaker"] = "Unknown"
                await self._send_result(websocket, result)
                logger.info(f"Sent transcription result: {str(result.get('text') or '')[:50]}...")
                return

            logger.warning("Inference module not found, using mock transcription")
            try:
                result = await self._transcribe_with_shared(audio_data)
            except Exception as e:
                logger.warning(f"Failed to use shared module: {e}")
                result = self._mock_result(len(audio_data))
            await self._send_result(websocket, result)
        except Exception as e:
            logger.exception(f"Audio processing error: {e}")
            await self.send_error(websocket, f"Audio processing failed: {str(e)}")

    async def _transcribe_with_shared(self, audio_data: bytes) -> Dict[str, Any]:
        """Process a chunk with ``shared.RealtimeSpeakerDiarization``, creating it on first use."""
        from shared import RealtimeSpeakerDiarization

        if getattr(self, "diarization_system", None) is None:
            system = RealtimeSpeakerDiarization()
            await asyncio.to_thread(system.initialize_models)
            await asyncio.to_thread(system.start_recording)
            # Publish only after initialization succeeded. A failed attempt is
            # retried on the next chunk instead of reusing a half-built engine.
            self.diarization_system = system
        result = await asyncio.to_thread(self.diarization_system.process_audio_chunk, audio_data)
        return self._format_shared_result(result, len(audio_data))

    @classmethod
    def _format_shared_result(cls, result, n_bytes: int) -> Dict[str, Any]:
        """Map a ``process_audio_chunk`` result to the response schema.

        Falls back to :meth:`_mock_result` when *result* is empty or reports
        ``status == "error"``. ``speaker_id`` and ``similarity`` are coerced to
        plain ``int``/``float`` (numpy scalars are not JSON serializable); a
        missing or invalid value uses the defaults 0 and 0.85.
        """
        if not result or result.get("status") == "error":
            return cls._mock_result(n_bytes)
        try:
            speaker_index = int(result.get("speaker_id", 0))
        except (TypeError, ValueError):
            speaker_index = 0
        try:
            confidence = float(result.get("similarity", 0.85))
        except (TypeError, ValueError):
            confidence = 0.85
        return {
            "text": result.get("text", f"[Processing {n_bytes} bytes]"),
            "speaker": f"Speaker_{speaker_index + 1}",
            "confidence": confidence,
            "timestamp": time.time(),
        }

    @staticmethod
    def _mock_result(n_bytes: int) -> Dict[str, Any]:
        """Build the placeholder result used when no real transcription is available."""
        return {
            "text": f"[Mock transcription - {n_bytes} bytes processed]",
            "speaker": "Speaker_1",
            "confidence": 0.85,
            "timestamp": time.time(),
        }

    async def _send_result(self, websocket, result: Dict[str, Any]):
        """Record *result* in the history, send it as ``processing_result``, and count it."""
        self.update_conversation_history(result)
        await websocket.send(json.dumps({
            "type": "processing_result",
            "timestamp": time.time(),
            "data": result,
        }))
        self.processing_stats["total_transcriptions"] += 1

    async def handle_text_message(self, message: str, websocket):
        """Handle a JSON control message (``ping``, ``config``, ``status_request``).

        Invalid JSON, or JSON that is not an object, is answered with an error
        message. Unknown message types are logged and otherwise ignored.
        """
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            # Truncate so a huge bad payload does not flood the log.
            logger.error(f"Invalid JSON received: {message[:200]}")
            await self.send_error(websocket, "Invalid JSON format")
            return
        if not isinstance(data, dict):
            logger.error(f"Expected a JSON object, got {type(data).__name__}")
            await self.send_error(websocket, "Invalid JSON format: expected an object")
            return

        message_type = data.get("type", "unknown")
        logger.info(f"Received message type: {message_type}")

        if message_type == "ping":
            await websocket.send(json.dumps({
                "type": "pong",
                "timestamp": time.time(),
            }))
        elif message_type == "config":
            logger.info(f"Configuration update: {data}")
            settings = data.get("settings") or {}
            if not isinstance(settings, dict):
                settings = {}
            # NOTE(review): settings are only logged here; nothing applies them yet.
            if "max_speakers" in settings:
                logger.info(f"Setting max_speakers to {settings.get('max_speakers')}")
            if "threshold" in settings:
                logger.info(f"Setting speaker change threshold to {settings.get('threshold')}")
            await websocket.send(json.dumps({
                "type": "config_ack",
                "message": "Configuration received",
                "timestamp": time.time(),
            }))
        elif message_type == "status_request":
            await websocket.send(json.dumps({
                "type": "status_response",
                "data": self.get_processing_stats(),
                "timestamp": time.time(),
            }))
        else:
            logger.warning(f"Unknown message type: {message_type}")

    async def send_error(self, websocket, error_message: str):
        """Send an ``error`` message to the client; send failures are logged, not raised."""
        try:
            await websocket.send(json.dumps({
                "type": "error",
                "message": error_message,
                "timestamp": time.time(),
            }))
        except Exception as e:
            logger.error(f"Failed to send error message: {e}")

    def update_conversation_history(self, transcription_result: Dict[str, Any]):
        """Append a transcription to the history, keeping only the newest MAX_HISTORY entries."""
        self.conversation_history.append({
            "timestamp": time.time(),
            "text": transcription_result.get("text", ""),
            "speaker": transcription_result.get("speaker", "Unknown"),
            "confidence": transcription_result.get("confidence", 0.0),
        })
        # Trim in place so the list object shared with the UI stays the same.
        if len(self.conversation_history) > self.MAX_HISTORY:
            del self.conversation_history[:-self.MAX_HISTORY]

    def get_processing_stats(self):
        """Return a JSON-serializable snapshot of connection and processing counters."""
        return {
            "connected_clients": len(self.connected_clients),
            "total_audio_chunks": self.processing_stats["total_audio_chunks"],
            "total_transcriptions": self.processing_stats["total_transcriptions"],
            "last_audio_received": self.processing_stats["last_audio_received"],
            "server_uptime": time.time() - self.processing_stats["server_start_time"],
            "conversation_entries": len(self.conversation_history),
            "backend_url": self.processing_stats.get("backend_url", HF_SPACE_URL),
        }

    async def start_server(self, host="0.0.0.0", port=7860):
        """Start the WebSocket server and wait until it is closed.

        ``websockets.serve``'s ``path`` keyword is the Unix-socket path, not a
        URL route. The ``/ws_inference`` route is therefore enforced in
        :meth:`handle_client_connection`. ``is_running`` is True only while the
        server is up.
        """
        try:
            self.websocket_server = await websockets.serve(
                self.handle_client_connection,
                host,
                port,
                subprotocols=[],
            )
            self.is_running = True
            logger.info(f"WebSocket server started on ws://{host}:{port}{self.WS_PATH}")
            await self.websocket_server.wait_closed()
        except Exception as e:
            logger.exception(f"Failed to start WebSocket server: {e}")
        finally:
            self.is_running = False
# Module-level singleton shared by the Gradio callbacks below and by the
# background WebSocket thread that this module starts when imported.
ws_server: TranscriptionWebSocketServer = TranscriptionWebSocketServer()
def create_gradio_interface():
    """Build the Gradio Blocks UI used to monitor and test the WebSocket server.

    The callbacks read the module-level ``ws_server`` each time they run, so
    the UI reflects live state. The returned app is not launched.

    Returns:
        gr.Blocks: The assembled (unlaunched) interface.
    """
    import html
    import textwrap

    fallback_colors = [
        "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4",
        "#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F",
    ]

    def public_ws_endpoint():
        """Return the public WebSocket URL derived from HF_SPACE_URL.

        HF_SPACE_URL normally includes an ``https://`` scheme, so prefixing it
        with ``wss://`` would produce ``wss://https://...``.
        """
        url = HF_SPACE_URL.strip().rstrip("/")
        if url.startswith("https://"):
            url = "wss://" + url[len("https://"):]
        elif url.startswith("http://"):
            url = "ws://" + url[len("http://"):]
        elif not url.startswith(("ws://", "wss://")):
            url = "wss://" + url
        return f"{url}/ws_inference"

    def speaker_palette():
        """Return ``shared.SPEAKER_COLORS`` when importable and non-empty, else the fallback list."""
        try:
            from shared import SPEAKER_COLORS
        except ImportError:
            return fallback_colors
        return list(SPEAKER_COLORS) or fallback_colors

    def get_server_status():
        """Render the current server status and processing counters as Markdown."""
        stats = ws_server.get_processing_stats()
        state = "🟢 Running" if ws_server.is_running else "🔴 Stopped"
        last_audio = (
            time.ctime(stats["last_audio_received"])
            if stats["last_audio_received"]
            else "Never"
        )
        return "\n".join([
            "### Server Status",
            f"- **WebSocket Server**: {state}",
            f"- **Connected Clients**: {stats['connected_clients']}",
            f"- **Server Uptime**: {stats['server_uptime']:.1f} seconds",
            "",
            "### Processing Statistics",
            f"- **Audio Chunks Processed**: {stats['total_audio_chunks']}",
            f"- **Transcriptions Generated**: {stats['total_transcriptions']}",
            f"- **Last Audio Received**: {last_audio}",
            "",
            "### Conversation",
            f"- **History Entries**: {stats['conversation_entries']}",
        ])

    def get_recent_transcriptions():
        """Render the last 10 history entries as colour-coded Markdown/HTML."""
        recent_entries = list(ws_server.conversation_history[-10:])
        if not recent_entries:
            return "No transcriptions yet. Waiting for audio data from backend..."

        # Resolve the palette once per render instead of once per entry.
        palette = speaker_palette()
        parts = ["### Recent Transcriptions\n\n"]
        for entry in recent_entries:
            timestamp = time.strftime("%H:%M:%S", time.localtime(entry["timestamp"]))
            speaker = str(entry["speaker"])

            # "Speaker_N" maps to palette index N-1, matching shared.py colouring.
            speaker_num = 0
            if speaker.startswith("Speaker_"):
                try:
                    speaker_num = int(speaker.split("_")[1]) - 1
                except (ValueError, IndexError):
                    speaker_num = 0
            color = palette[speaker_num % len(palette)]

            try:
                confidence = float(entry["confidence"])
            except (TypeError, ValueError):
                confidence = 0.0

            # Speaker/text originate from the backend; escape them before
            # embedding them in HTML.
            parts.append(
                f"<span style='color:{color};font-weight:bold;'>[{timestamp}] "
                f"{html.escape(speaker)}</span> (confidence: {confidence:.2f})\n"
            )
            parts.append(f"{html.escape(str(entry['text']))}\n\n")
        return "".join(parts)

    def clear_conversation_history():
        """Empty the shared history; return a confirmation and the refreshed transcript view."""
        ws_server.conversation_history.clear()
        return "Conversation history cleared!", get_recent_transcriptions()

    with gr.Blocks(
        title="Real-time Audio Transcription Service",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown("# 🎤 Real-time Audio Transcription Service")
        gr.Markdown("This HuggingFace Space receives audio from your backend and returns transcription results with speaker diarization.")

        with gr.Tab("📊 Server Status"):
            status_display = gr.Markdown(get_server_status())
            with gr.Row():
                refresh_status_btn = gr.Button("🔄 Refresh Status", variant="primary")
            refresh_status_btn.click(fn=get_server_status, outputs=status_display)

        with gr.Tab("📝 Live Transcription"):
            transcription_display = gr.Markdown(get_recent_transcriptions())
            with gr.Row():
                refresh_transcription_btn = gr.Button("🔄 Refresh Transcriptions", variant="primary")
                clear_history_btn = gr.Button("🗑️ Clear History", variant="secondary")
            clear_status_display = gr.Markdown()
            refresh_transcription_btn.click(
                fn=get_recent_transcriptions,
                outputs=transcription_display,
            )
            # Update both the confirmation line and the (now empty) transcript view.
            clear_history_btn.click(
                fn=clear_conversation_history,
                outputs=[clear_status_display, transcription_display],
            )

        with gr.Tab("🔧 Connection Info"):
            gr.Markdown(textwrap.dedent(f"""\
                ### WebSocket Connection Details
                **WebSocket Endpoint**: `{public_ws_endpoint()}`

                ### Backend Connection
                Your backend should connect to this WebSocket endpoint and:
                1. **Send Audio Data**: Stream raw audio bytes to this endpoint
                2. **Receive Results**: Get JSON responses with transcription results

                ### Expected Message Flow
                **Backend → HuggingFace**:
                - Raw audio bytes (binary data)
                - Configuration messages (JSON)

                **HuggingFace → Backend**:
                ```json
                {{
                  "type": "processing_result",
                  "timestamp": 1234567890.123,
                  "data": {{
                    "text": "transcribed text here",
                    "speaker": "Speaker_1",
                    "confidence": 0.95
                  }}
                }}
                ```

                ### Test Connection
                Your backend is configured to connect to: `{ws_server.processing_stats.get('backend_url', HF_SPACE_URL)}`
                """))

        with gr.Tab("📚 API Documentation"):
            gr.Markdown(textwrap.dedent("""\
                ### WebSocket API Reference

                #### Endpoint
                - **URL**: `/ws_inference`
                - **Protocol**: WebSocket
                - **Accepts**: Binary audio data + JSON messages

                #### Message Types

                ##### 1. Audio Processing
                - **Input**: Raw audio bytes (binary)
                - **Output**: Processing result (JSON)

                ##### 2. Configuration
                - **Input**:

                ```json
                {
                  "type": "config",
                  "settings": {
                    "language": "en",
                    "enable_diarization": true,
                    "max_speakers": 4,
                    "threshold": 0.65
                  }
                }
                ```

                ##### 3. Status Check
                - **Input**: `{"type": "status_request"}`
                - **Output**: Server statistics

                ##### 4. Ping/Pong
                - **Input**: `{"type": "ping"}`
                - **Output**: `{"type": "pong", "timestamp": 1234567890}`

                #### Error Handling
                All errors are returned as:

                ```json
                {
                  "type": "error",
                  "message": "Error description",
                  "timestamp": 1234567890.123
                }
                ```
                """))

        # The initial values above are computed once, when the UI is built;
        # refresh them whenever a browser loads the page.
        demo.load(fn=get_server_status, outputs=status_display)
        demo.load(fn=get_recent_transcriptions, outputs=transcription_display)

    return demo
def run_websocket_server():
    """Run the WebSocket server in the calling thread until it stops.

    Intended as the target of a background ``threading.Thread``. ``asyncio.run``
    gives this thread its own event loop. On exit it cancels leftover tasks,
    finalizes async generators and shuts down the default executor used by
    ``asyncio.to_thread``; the manual ``new_event_loop()``/``close()`` pairing
    skipped that cleanup.

    The listening port defaults to 7860 and can be overridden with the
    ``WS_PORT`` environment variable. Errors are logged, not raised.
    """
    raw_port = os.getenv("WS_PORT", "7860")
    try:
        port = int(raw_port)
    except ValueError:
        logger.warning(f"Invalid WS_PORT value {raw_port!r}; falling back to 7860")
        port = 7860

    try:
        logger.info("Starting WebSocket server thread...")
        asyncio.run(ws_server.start_server(port=port))
    except Exception as e:
        logger.exception(f"WebSocket server error: {e}")
# Mount the monitoring UI onto the FastAPI app created in inference.py.
def mount_ui(app, path="/"):
    """Mount the Gradio monitoring UI onto an existing FastAPI app.

    The UI is attached without launching a separate Gradio server; whatever
    serves *app* also serves the UI.

    Args:
        app: FastAPI application to mount onto.
        path: URL prefix for the UI (defaults to the site root).

    Returns:
        True if the UI was mounted, False if building or mounting failed
        (the error is logged).
    """
    try:
        demo = create_gradio_interface()
        # gr.Blocks has no ``mount_to_app`` method; ``gr.mount_gradio_app`` is
        # Gradio's public API for attaching Blocks to a FastAPI app.
        gr.mount_gradio_app(app, demo, path=path)
        logger.info(f"Gradio UI mounted to FastAPI app at {path}")
        return True
    except Exception as e:
        logger.exception(f"Error mounting UI: {e}")
        return False
# Start the WebSocket server in a background daemon thread as soon as this
# module is imported (both when run as a script and when imported elsewhere).
logger.info("Initializing WebSocket server...")
websocket_thread = threading.Thread(target=run_websocket_server, daemon=True)
websocket_thread.start()

# Give the server up to 2 seconds to come up. Stop waiting as soon as it reports
# running, or as soon as its thread has exited (e.g. the bind failed), instead
# of always blocking the import for the full 2 seconds.
_startup_deadline = time.monotonic() + 2
while (
    not ws_server.is_running
    and websocket_thread.is_alive()
    and time.monotonic() < _startup_deadline
):
    time.sleep(0.05)

# Create and launch the Gradio interface when run as a script.
# NOTE(review): demo.launch below listens on port 7860, which is also the
# WebSocket server's default port; only one of them can own it. Confirm the
# intended deployment layout.
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )