import gradio as gr
from fastapi import FastAPI
from shared import DEFAULT_CHANGE_THRESHOLD, DEFAULT_MAX_SPEAKERS, ABSOLUTE_MAX_SPEAKERS, FINAL_TRANSCRIPTION_MODEL, REALTIME_TRANSCRIPTION_MODEL

print(gr.__version__)
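
# `shared` must provide the constants imported above. If you are running this
# file in isolation, a minimal stand-in for shared.py might look like this
# (values are illustrative assumptions, not the project's real defaults; the
# model names mirror the fallbacks used in the settings script further down):
#
#     DEFAULT_CHANGE_THRESHOLD = 0.65
#     DEFAULT_MAX_SPEAKERS = 4
#     ABSOLUTE_MAX_SPEAKERS = 8
#     FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
#     REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"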

# Connection configuration (separate signaling server from model server)
# These will be replaced at deployment time with the correct URLs
RENDER_SIGNALING_URL = "wss://render-signal-audio.onrender.com/stream"
HF_SPACE_URL = "https://androidguy-speaker-diarization.hf.space"
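
# Optional overrides for local testing, so the URLs can change without editing
# the file. A hedged convenience, not part of the deployment contract; the
# SIGNALING_URL / SPACE_URL environment variable names are assumptions:
import os
RENDER_SIGNALING_URL = os.environ.get("SIGNALING_URL", RENDER_SIGNALING_URL)
HF_SPACE_URL = os.environ.get("SPACE_URL", HF_SPACE_URL)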


def build_ui():
    """Build the Gradio UI for speaker diarization."""
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
        # Expose configuration parameters to the page for the embedded scripts
        # (the model names are also exported so the settings panel can read them)
        gr.HTML(
            f"""
            <!-- Configuration parameters -->
            <script>
            window.RENDER_SIGNALING_URL = "{RENDER_SIGNALING_URL}";
            window.HF_SPACE_URL = "{HF_SPACE_URL}";
            window.FINAL_TRANSCRIPTION_MODEL = "{FINAL_TRANSCRIPTION_MODEL}";
            window.REALTIME_TRANSCRIPTION_MODEL = "{REALTIME_TRANSCRIPTION_MODEL}";
            </script>
            """
        )
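
        # NOTE: depending on the Gradio version, <script> tags injected through
        # gr.HTML may not execute (the HTML is inserted without running scripts).
        # If the configuration above never reaches the page, the usual
        # alternatives are gr.Blocks(head=...) or gr.Blocks(js=...).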

        # Header and description
        gr.Markdown("# 🎤 Live Speaker Diarization")
        gr.Markdown("Real-time speech recognition with automatic speaker identification")

        # Transcription model info
        gr.Markdown(f"**Using Models:** Final: {FINAL_TRANSCRIPTION_MODEL}, Realtime: {REALTIME_TRANSCRIPTION_MODEL}")

        # Status indicator
        connection_status = gr.HTML(
            """<div class="status-indicator">
                <span id="status-text" style="color:#888;">Waiting to connect...</span>
                <span id="status-icon" style="width:10px; height:10px; display:inline-block;
                      background-color:#888; border-radius:50%; margin-left:5px;"></span>
            </div>"""
        )

        with gr.Row():
            with gr.Column(scale=2):
                # Conversation display with embedded JavaScript for WebRTC and audio handling
                conversation_display = gr.HTML(
""" | |
<div class='output' id="conversation" style='padding:20px; background:#111; border-radius:10px; | |
min-height:400px; font-family:Arial; font-size:16px; line-height:1.5; overflow-y:auto;'> | |
<i>Click 'Start Listening' to begin...</i> | |
</div> | |
<script> | |
// Global variables | |
let rtcConnection; | |
let mediaStream; | |
let wsConnection; | |
let statusUpdateInterval; | |
// Check connection to HF space | |
async function checkHfConnection() { | |
try { | |
let response = await fetch(`${window.HF_SPACE_URL}/health`); | |
return response.ok; | |
} catch (err) { | |
return false; | |
} | |
} | |

                    // Start the connection and audio streaming
                    async function startStreaming() {
                        try {
                            updateStatus('connecting');

                            // Request microphone access
                            mediaStream = await navigator.mediaDevices.getUserMedia({audio: {
                                echoCancellation: true,
                                noiseSuppression: true,
                                autoGainControl: true
                            }});

                            // Set up the WebRTC connection to the Render signaling server
                            await setupWebRTC();

                            // Also connect a WebSocket for conversation updates
                            setupWebSocket();

                            // Start the periodic connection check
                            statusUpdateInterval = setInterval(updateConnectionInfo, 5000);

                            updateStatus('connected');
                            document.getElementById("conversation").innerHTML = "<i>Connected! Start speaking...</i>";
                        } catch (err) {
                            console.error('Error starting stream:', err);
                            updateStatus('error', err.message);
                        }
                    }

                    // Set up the WebRTC connection to the Render signaling server
                    async function setupWebRTC() {
                        try {
                            if (rtcConnection) {
                                rtcConnection.close();
                            }

                            // Use FastRTC's connection approach
                            const pc = new RTCPeerConnection({
                                iceServers: [{ urls: 'stun:stun.l.google.com:19302' }]
                            });

                            // Add the microphone audio track
                            mediaStream.getAudioTracks().forEach(track => {
                                pc.addTrack(track, mediaStream);
                            });

                            // Connect to the FastRTC signaling endpoint over WebSocket
                            const signalWs = new WebSocket(window.RENDER_SIGNALING_URL);

                            // Handle signaling messages
                            signalWs.onmessage = async (event) => {
                                const message = JSON.parse(event.data);

                                if (message.type === 'offer') {
                                    await pc.setRemoteDescription(new RTCSessionDescription(message));
                                    const answer = await pc.createAnswer();
                                    await pc.setLocalDescription(answer);
                                    signalWs.send(JSON.stringify(pc.localDescription));
                                } else if (message.type === 'candidate') {
                                    if (message.candidate) {
                                        // Unwrap the candidate payload (mirrors the shape we send below)
                                        await pc.addIceCandidate(new RTCIceCandidate(message.candidate));
                                    }
                                }
                            };

                            // Send our ICE candidates to the peer
                            pc.onicecandidate = (event) => {
                                if (event.candidate) {
                                    signalWs.send(JSON.stringify({
                                        type: 'candidate',
                                        candidate: event.candidate
                                    }));
                                }
                            };

                            // Keep a reference to the connection
                            rtcConnection = pc;

                            // Wait for the connection to be established
                            await new Promise((resolve, reject) => {
                                const timeout = setTimeout(() => reject(new Error("WebRTC connection timeout")), 10000);
                                pc.onconnectionstatechange = () => {
                                    if (pc.connectionState === 'connected') {
                                        clearTimeout(timeout);
                                        resolve();
                                    } else if (pc.connectionState === 'failed' || pc.connectionState === 'disconnected') {
                                        clearTimeout(timeout);
                                        reject(new Error("WebRTC connection failed"));
                                    }
                                };
                            });

                            updateStatus('connected');
                        } catch (err) {
                            console.error('WebRTC setup error:', err);
                            updateStatus('error', 'WebRTC setup failed: ' + err.message);
                        }
                    }
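
                    // Signaling message shapes handled above, inferred from this client
                    // code alone (the server half lives in the Render signaling app and
                    // is not shown here):
                    //   { "type": "offer", "sdp": "..." }            -> answered with localDescription
                    //   { "type": "candidate", "candidate": {...} }  -> fed to addIceCandidate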

                    // Open a WebSocket to the signaling server's relay endpoint,
                    // which forwards conversation updates from the HF Space
                    function setupWebSocket() {
                        const wsUrl = window.RENDER_SIGNALING_URL.replace('stream', 'ws_relay');
                        wsConnection = new WebSocket(wsUrl);

                        wsConnection.onopen = () => {
                            console.log('WebSocket connection established');
                        };

                        wsConnection.onmessage = (event) => {
                            try {
                                // Parse the JSON message
                                const message = JSON.parse(event.data);

                                // Process the different message types
                                switch (message.type) {
                                    case 'transcription':
                                        // Full transcription payload
                                        if (message && message.data && typeof message.data === 'object') {
                                            document.getElementById("conversation").innerHTML = message.data.conversation_html ||
                                                JSON.stringify(message.data);
                                        }
                                        break;

                                    case 'processing_result':
                                        // Result for an individual audio chunk
                                        console.log('Processing result:', message.data);
                                        if (message.data && message.data.status === "processed") {
                                            const statusElem = document.getElementById('status-text');
                                            if (statusElem) {
                                                const speakerId = message.data.speaker_id !== undefined ?
                                                    `Speaker ${message.data.speaker_id + 1}` : '';
                                                if (speakerId) {
                                                    statusElem.textContent = `Connected - ${speakerId} active`;
                                                }
                                            }
                                        } else if (message.data && message.data.status === "error") {
                                            updateStatus('error', message.data.message || 'Processing error');
                                        }
                                        break;

                                    case 'connection':
                                        console.log('Connection status:', message.status);
                                        updateStatus(message.status === 'connected' ? 'connected' : 'warning');
                                        break;

                                    case 'connection_established':
                                        console.log('Connection established:', message);
                                        updateStatus('connected');
                                        // Display the initial conversation if one is provided
                                        if (message.conversation) {
                                            document.getElementById("conversation").innerHTML = message.conversation;
                                        }
                                        break;

                                    case 'conversation_update':
                                        if (message.conversation_html) {
                                            document.getElementById("conversation").innerHTML = message.conversation_html;
                                        }
                                        break;

                                    case 'conversation_cleared':
                                        document.getElementById("conversation").innerHTML =
                                            "<i>Conversation cleared. Start speaking again...</i>";
                                        break;

                                    case 'error':
                                        console.error('Error message from server:', message.message);
                                        updateStatus('warning', message.message);
                                        break;

                                    default:
                                        // Plain HTML content without a JSON envelope (legacy format)
                                        document.getElementById("conversation").innerHTML = event.data;
                                }

                                // Auto-scroll to the bottom
                                const container = document.getElementById("conversation");
                                container.scrollTop = container.scrollHeight;
                            } catch (e) {
                                // Fallback for non-JSON messages (legacy format)
                                document.getElementById("conversation").innerHTML = event.data;

                                // Auto-scroll to the bottom
                                const container = document.getElementById("conversation");
                                container.scrollTop = container.scrollHeight;
                            }
                        };

                        wsConnection.onerror = (error) => {
                            console.error('WebSocket error:', error);
                            updateStatus('warning', 'WebSocket error');
                        };

                        wsConnection.onclose = () => {
                            console.log('WebSocket connection closed');
                            // Try to reconnect after a short delay
                            setTimeout(setupWebSocket, 3000);
                        };
                    }
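
                    // Example relay payloads this handler consumes (shapes inferred from
                    // the branches above; field values are illustrative, not captured traffic):
                    //   { "type": "conversation_update", "conversation_html": "<p>Speaker 1: hi</p>" }
                    //   { "type": "processing_result", "data": { "status": "processed", "speaker_id": 0 } }
                    //   { "type": "error", "message": "..." }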

                    // Refresh the connection info shown in the UI
                    async function updateConnectionInfo() {
                        try {
                            const hfConnected = await checkHfConnection();
                            if (!hfConnected) {
                                updateStatus('warning', 'HF Space connection issue');
                            } else if (rtcConnection?.connectionState === 'connected' ||
                                       rtcConnection?.iceConnectionState === 'connected') {
                                updateStatus('connected');
                            } else {
                                updateStatus('warning', 'Connection unstable');
                            }
                        } catch (err) {
                            console.error('Error updating connection info:', err);
                        }
                    }

                    // Update the status indicator
                    function updateStatus(status, message = '') {
                        const statusText = document.getElementById('status-text');
                        const statusIcon = document.getElementById('status-icon');

                        switch (status) {
                            case 'connected':
                                statusText.textContent = 'Connected';
                                statusIcon.style.backgroundColor = '#4CAF50';
                                break;
                            case 'connecting':
                                statusText.textContent = 'Connecting...';
                                statusIcon.style.backgroundColor = '#FFC107';
                                break;
                            case 'disconnected':
                                statusText.textContent = 'Disconnected';
                                statusIcon.style.backgroundColor = '#9E9E9E';
                                break;
                            case 'error':
                                statusText.textContent = 'Error: ' + message;
                                statusIcon.style.backgroundColor = '#F44336';
                                break;
                            case 'warning':
                                statusText.textContent = 'Warning: ' + message;
                                statusIcon.style.backgroundColor = '#FF9800';
                                break;
                            default:
                                statusText.textContent = 'Unknown';
                                statusIcon.style.backgroundColor = '#9E9E9E';
                        }
                    }

                    // Stop streaming and clean up all resources
                    function stopStreaming() {
                        // Close the WebRTC connection
                        if (rtcConnection) {
                            rtcConnection.close();
                            rtcConnection = null;
                        }

                        // Close the WebSocket (detach onclose first so the
                        // auto-reconnect handler doesn't immediately reopen it)
                        if (wsConnection) {
                            wsConnection.onclose = null;
                            wsConnection.close();
                            wsConnection = null;
                        }

                        // Stop all tracks in the media stream
                        if (mediaStream) {
                            mediaStream.getTracks().forEach(track => track.stop());
                            mediaStream = null;
                        }

                        // Clear the status polling interval
                        if (statusUpdateInterval) {
                            clearInterval(statusUpdateInterval);
                            statusUpdateInterval = null;
                        }

                        updateStatus('disconnected');
                    }

                    // Initialize the status indicator once the DOM is ready
                    document.addEventListener('DOMContentLoaded', () => {
                        updateStatus('disconnected');
                    });
                    </script>
                    """,
                    label="Live Conversation"
                )

                # Control buttons
                with gr.Row():
                    start_btn = gr.Button("▶️ Start Listening", variant="primary", size="lg")
                    stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg")
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg")

                # Status display (elem_id lets the settings script target it reliably)
                status_output = gr.Markdown(
                    """
                    ## System Status

                    Waiting to connect...

                    *Click Start Listening to begin*
                    """,
                    label="Status Information",
                    elem_id="status-info"
                )

            with gr.Column(scale=1):
                # Settings
                gr.Markdown("## ⚙️ Settings")

                threshold_slider = gr.Slider(
                    minimum=0.3,
                    maximum=0.9,
                    step=0.05,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    label="Speaker Change Sensitivity",
                    info="Lower = more sensitive (more speaker changes)"
                )

                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    step=1,
                    value=DEFAULT_MAX_SPEAKERS,
                    label="Maximum Speakers"
                )

                update_btn = gr.Button("Update Settings", variant="secondary")

                # Instructions
                gr.Markdown("""
                ## 📋 Instructions

                1. **Start Listening** - grant the browser access to your microphone
                2. **Speak** - the system transcribes speech and identifies speakers
                3. **Stop** when finished
                4. **Clear** to reset the conversation

                ## 🎨 Speaker Colors

                - 🔴 Speaker 1 (Red)
                - 🟢 Speaker 2 (Teal)
                - 🔵 Speaker 3 (Blue)
                - 🟡 Speaker 4 (Green)
                - ⚫ Speaker 5 (Yellow)
                - 🟣 Speaker 6 (Plum)
                - 🟤 Speaker 7 (Mint)
                - 🟠 Speaker 8 (Gold)
                """)

        # JavaScript to connect the buttons to the script functions above
        gr.HTML("""
        <script>
        // Wait for Gradio to fully load
        document.addEventListener('DOMContentLoaded', () => {
            // Wait a bit for Gradio to create its buttons
            setTimeout(() => {
                // Gradio renders the visible label (emoji included) as the button
                // text, so match on textContent rather than aria-label
                const findButton = (text) =>
                    Array.from(document.querySelectorAll('button'))
                         .find(btn => btn.textContent.includes(text));

                const startBtn = findButton('Start Listening');
                const stopBtn = findButton('Stop');
                const clearBtn = findButton('Clear');

                if (startBtn) startBtn.onclick = () => startStreaming();
                if (stopBtn) stopBtn.onclick = () => stopStreaming();
                if (clearBtn) clearBtn.onclick = () => {
                    // Ask the HF Space to clear the conversation, then reset the display
                    fetch(`${window.HF_SPACE_URL}/clear`, { method: 'POST' })
                        .then(resp => resp.json())
                        .then(() => {
                            document.getElementById("conversation").innerHTML =
                                "<i>Conversation cleared. Start speaking again...</i>";
                        });
                };

                // Set up the settings update button
                const updateBtn = findButton('Update Settings');
                if (updateBtn) updateBtn.onclick = () => {
                    // These selectors assume Gradio exposes the slider labels as
                    // aria-labels on the underlying inputs
                    const threshold = document.querySelector('input[aria-label="Speaker Change Sensitivity"]').value;
                    const maxSpeakers = document.querySelector('input[aria-label="Maximum Speakers"]').value;
                    fetch(`${window.HF_SPACE_URL}/settings?threshold=${threshold}&max_speakers=${maxSpeakers}`, {
                        method: 'POST'
                    }).then(resp => resp.json())
                      .then(data => {
                          const statusOutput = document.querySelector('#status-info');
                          if (statusOutput) {
                              statusOutput.innerHTML = `
                                  <h2>System Status</h2>
                                  <p>Settings updated:</p>
                                  <ul>
                                      <li>Threshold: ${threshold}</li>
                                      <li>Max Speakers: ${maxSpeakers}</li>
                                  </ul>
                                  <p>Transcription Models:</p>
                                  <ul>
                                      <li>Final: ${window.FINAL_TRANSCRIPTION_MODEL || "distil-large-v3"}</li>
                                      <li>Realtime: ${window.REALTIME_TRANSCRIPTION_MODEL || "distil-small.en"}</li>
                                  </ul>
                              `;
                          }
                      });
                };
            }, 1000);
        });
        </script>
        """)

        # Periodic status updates
        def get_status():
            """Poll the HF Space for system status (called by the timer below)."""
            import requests
            try:
                resp = requests.get(f"{HF_SPACE_URL}/status", timeout=5)
                if resp.status_code == 200:
                    return resp.json().get('status', 'No status information')
                return "Error getting status"
            except Exception as e:
                return f"Connection error: {str(e)}"

        status_timer = gr.Timer(5)
        status_timer.tick(fn=get_status, outputs=status_output)

    return demo


# Create the Gradio interface
demo = build_ui()


def mount_ui(app: FastAPI):
    """Mount the Gradio app onto a FastAPI application at /ui."""
    # demo.app is only populated after launch(); gr.mount_gradio_app is the
    # supported way to serve a Blocks app from an existing FastAPI instance
    gr.mount_gradio_app(app, demo, path="/ui")
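

# The embedded page scripts assume the HF Space exposes a small HTTP API.
# A minimal FastAPI sketch of that contract, inferred from the client calls
# above (response shapes are assumptions, except /status, which get_status
# reads as {"status": ...}):
#
#     from fastapi import FastAPI
#
#     api = FastAPI()
#
#     @api.get("/health")
#     def health():                      # checkHfConnection() only needs a 200
#         return {"ok": True}
#
#     @api.post("/clear")
#     def clear():                       # Clear button; response body is ignored
#         return {"status": "cleared"}
#
#     @api.post("/settings")
#     def settings(threshold: float, max_speakers: int):   # query parameters
#         return {"threshold": threshold, "max_speakers": max_speakers}
#
#     @api.get("/status")
#     def status():
#         return {"status": "Running"}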

# For standalone testing
if __name__ == "__main__":
    demo.launch()