import gradio as gr
from fastapi import FastAPI
from shared import (
    DEFAULT_CHANGE_THRESHOLD,
    DEFAULT_MAX_SPEAKERS,
    ABSOLUTE_MAX_SPEAKERS,
    FINAL_TRANSCRIPTION_MODEL,
    REALTIME_TRANSCRIPTION_MODEL,
)
# Log the installed Gradio version (useful when debugging deployment issues)
print(gr.__version__)
# Connection configuration (separate signaling server from model server)
# These will be replaced at deployment time with the correct URLs
RENDER_SIGNALING_URL = "wss://render-signal-audio.onrender.com/stream"
HF_SPACE_URL = "https://androidguy-speaker-diarization.hf.space"
def build_ui():
"""Build Gradio UI for speaker diarization"""
with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
# Add configuration variables to page using custom component
gr.HTML(
f"""
<!-- Configuration parameters -->
<script>
            window.RENDER_SIGNALING_URL = "{RENDER_SIGNALING_URL}";
            window.HF_SPACE_URL = "{HF_SPACE_URL}";
            // Expose the model names too, so the settings panel below can display them
            window.FINAL_TRANSCRIPTION_MODEL = "{FINAL_TRANSCRIPTION_MODEL}";
            window.REALTIME_TRANSCRIPTION_MODEL = "{REALTIME_TRANSCRIPTION_MODEL}";
</script>
"""
)
# Header and description
gr.Markdown("# 🎀 Live Speaker Diarization")
gr.Markdown(f"Real-time speech recognition with automatic speaker identification")
# Add transcription model info
gr.Markdown(f"**Using Models:** Final: {FINAL_TRANSCRIPTION_MODEL}, Realtime: {REALTIME_TRANSCRIPTION_MODEL}")
# Status indicator
connection_status = gr.HTML(
"""<div class="status-indicator">
<span id="status-text" style="color:#888;">Waiting to connect...</span>
<span id="status-icon" style="width:10px; height:10px; display:inline-block;
background-color:#888; border-radius:50%; margin-left:5px;"></span>
</div>"""
)
with gr.Row():
with gr.Column(scale=2):
# Conversation display with embedded JavaScript for WebRTC and audio handling
conversation_display = gr.HTML(
"""
<div class='output' id="conversation" style='padding:20px; background:#111; border-radius:10px;
min-height:400px; font-family:Arial; font-size:16px; line-height:1.5; overflow-y:auto;'>
<i>Click 'Start Listening' to begin...</i>
</div>
<script>
// Global variables
let rtcConnection;
let mediaStream;
let wsConnection;
let statusUpdateInterval;
// Check connection to HF space
async function checkHfConnection() {
try {
let response = await fetch(`${window.HF_SPACE_URL}/health`);
return response.ok;
} catch (err) {
return false;
}
}
// Start the connection and audio streaming
async function startStreaming() {
try {
// Update status
updateStatus('connecting');
// Request microphone access
mediaStream = await navigator.mediaDevices.getUserMedia({audio: {
echoCancellation: true,
noiseSuppression: true,
autoGainControl: true
}});
// Set up WebRTC connection to Render signaling server
await setupWebRTC();
                    // Also open a WebSocket (relayed via the signaling server) for conversation updates
setupWebSocket();
// Start status update interval
statusUpdateInterval = setInterval(updateConnectionInfo, 5000);
// Update status
updateStatus('connected');
document.getElementById("conversation").innerHTML = "<i>Connected! Start speaking...</i>";
} catch (err) {
console.error('Error starting stream:', err);
updateStatus('error', err.message);
}
}
// Set up WebRTC connection to Render signaling server
async function setupWebRTC() {
try {
if (rtcConnection) {
rtcConnection.close();
}
// Use FastRTC's connection approach
const pc = new RTCPeerConnection({
iceServers: [{ urls: 'stun:stun.l.google.com:19302' }]
});
// Add audio track
mediaStream.getAudioTracks().forEach(track => {
pc.addTrack(track, mediaStream);
});
// Connect to FastRTC signaling via WebSocket
                    const signalWs = new WebSocket(window.RENDER_SIGNALING_URL);
// Handle signaling messages
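                    // Expected message shapes (assumed from the handlers below, not from a spec):
                    //   { type: 'offer', sdp: '...' }
                    //   { type: 'candidate', candidate: { candidate: '...', sdpMid: ..., sdpMLineIndex: ... } }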
signalWs.onmessage = async (event) => {
const message = JSON.parse(event.data);
if (message.type === 'offer') {
await pc.setRemoteDescription(new RTCSessionDescription(message));
const answer = await pc.createAnswer();
await pc.setLocalDescription(answer);
signalWs.send(JSON.stringify(pc.localDescription));
} else if (message.type === 'candidate') {
if (message.candidate) {
                                await pc.addIceCandidate(new RTCIceCandidate(message.candidate));
}
}
};
// Send ICE candidates
pc.onicecandidate = (event) => {
if (event.candidate) {
signalWs.send(JSON.stringify({
type: 'candidate',
candidate: event.candidate
}));
}
};
// Keep connection reference
rtcConnection = pc;
// Wait for connection to be established
await new Promise((resolve, reject) => {
const timeout = setTimeout(() => reject(new Error("WebRTC connection timeout")), 10000);
pc.onconnectionstatechange = () => {
if (pc.connectionState === 'connected') {
clearTimeout(timeout);
resolve();
} else if (pc.connectionState === 'failed' || pc.connectionState === 'disconnected') {
clearTimeout(timeout);
reject(new Error("WebRTC connection failed"));
}
};
});
updateStatus('connected');
} catch (err) {
console.error('WebRTC setup error:', err);
updateStatus('error', 'WebRTC setup failed: ' + err.message);
}
}
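            // Connection flow as wired above: mic audio -> RTCPeerConnection -> Render signaling
            // server; transcription results arrive over the relay WebSocket set up below.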
            // Set up WebSocket connection (via the signaling server's relay) for conversation updates
            function setupWebSocket() {
                // The relay endpoint lives on the signaling server at /ws_relay rather than /stream
                const wsUrl = window.RENDER_SIGNALING_URL.replace('stream', 'ws_relay');
wsConnection = new WebSocket(wsUrl);
wsConnection.onopen = () => {
console.log('WebSocket connection established');
};
wsConnection.onmessage = (event) => {
try {
// Parse the JSON message
const message = JSON.parse(event.data);
// Process different message types
switch(message.type) {
case 'transcription':
// Handle transcription data
if (message && message.data && typeof message.data === 'object') {
document.getElementById("conversation").innerHTML = message.data.conversation_html ||
JSON.stringify(message.data);
}
break;
case 'processing_result':
// Handle individual audio chunk processing result
console.log('Processing result:', message.data);
// Update status info if needed
if (message.data && message.data.status === "processed") {
const statusElem = document.getElementById('status-text');
if (statusElem) {
const speakerId = message.data.speaker_id !== undefined ?
`Speaker ${message.data.speaker_id + 1}` : '';
if (speakerId) {
statusElem.textContent = `Connected - ${speakerId} active`;
}
}
} else if (message.data && message.data.status === "error") {
updateStatus('error', message.data.message || 'Processing error');
}
break;
case 'connection':
console.log('Connection status:', message.status);
updateStatus(message.status === 'connected' ? 'connected' : 'warning');
break;
case 'connection_established':
console.log('Connection established:', message);
updateStatus('connected');
// If initial conversation is provided, display it
if (message.conversation) {
document.getElementById("conversation").innerHTML = message.conversation;
}
break;
case 'conversation_update':
if (message.conversation_html) {
document.getElementById("conversation").innerHTML = message.conversation_html;
}
break;
case 'conversation_cleared':
document.getElementById("conversation").innerHTML =
"<i>Conversation cleared. Start speaking again...</i>";
break;
case 'error':
console.error('Error message from server:', message.message);
updateStatus('warning', message.message);
break;
default:
// If it's just HTML content without proper JSON structure (legacy format)
document.getElementById("conversation").innerHTML = event.data;
}
// Auto-scroll to bottom
const container = document.getElementById("conversation");
container.scrollTop = container.scrollHeight;
} catch (e) {
// Fallback for non-JSON messages (legacy format)
document.getElementById("conversation").innerHTML = event.data;
// Auto-scroll to bottom
const container = document.getElementById("conversation");
container.scrollTop = container.scrollHeight;
}
};
wsConnection.onerror = (error) => {
console.error('WebSocket error:', error);
updateStatus('warning', 'WebSocket error');
};
wsConnection.onclose = () => {
console.log('WebSocket connection closed');
// Try to reconnect after a delay
setTimeout(setupWebSocket, 3000);
};
}
// Update connection info in the UI
async function updateConnectionInfo() {
try {
const hfConnected = await checkHfConnection();
if (!hfConnected) {
updateStatus('warning', 'HF Space connection issue');
} else if (rtcConnection?.connectionState === 'connected' ||
rtcConnection?.iceConnectionState === 'connected') {
updateStatus('connected');
} else {
updateStatus('warning', 'Connection unstable');
}
} catch (err) {
console.error('Error updating connection info:', err);
}
}
// Update status indicator
function updateStatus(status, message = '') {
const statusText = document.getElementById('status-text');
const statusIcon = document.getElementById('status-icon');
switch(status) {
case 'connected':
statusText.textContent = 'Connected';
statusIcon.style.backgroundColor = '#4CAF50';
break;
case 'connecting':
statusText.textContent = 'Connecting...';
statusIcon.style.backgroundColor = '#FFC107';
break;
case 'disconnected':
statusText.textContent = 'Disconnected';
statusIcon.style.backgroundColor = '#9E9E9E';
break;
case 'error':
statusText.textContent = 'Error: ' + message;
statusIcon.style.backgroundColor = '#F44336';
break;
case 'warning':
statusText.textContent = 'Warning: ' + message;
statusIcon.style.backgroundColor = '#FF9800';
break;
default:
statusText.textContent = 'Unknown';
statusIcon.style.backgroundColor = '#9E9E9E';
}
}
// Stop streaming and clean up
function stopStreaming() {
// Close WebRTC connection
if (rtcConnection) {
rtcConnection.close();
rtcConnection = null;
}
// Close WebSocket
if (wsConnection) {
wsConnection.close();
wsConnection = null;
}
// Stop all tracks in media stream
if (mediaStream) {
mediaStream.getTracks().forEach(track => track.stop());
mediaStream = null;
}
// Clear interval
if (statusUpdateInterval) {
clearInterval(statusUpdateInterval);
statusUpdateInterval = null;
}
// Update status
updateStatus('disconnected');
}
// Set up event listeners when the DOM is loaded
document.addEventListener('DOMContentLoaded', () => {
updateStatus('disconnected');
});
</script>
""",
label="Live Conversation"
)
# Control buttons
with gr.Row():
                        # elem_id gives the JavaScript below a stable hook for each button
                        start_btn = gr.Button("▢️ Start Listening", variant="primary", size="lg", elem_id="start-btn")
                        stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", elem_id="stop-btn")
                        clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", size="lg", elem_id="clear-btn")
# Status display
                status_output = gr.Markdown(
                    """
                    ## System Status
                    Waiting to connect...
                    *Click Start Listening to begin*
                    """,
                    label="Status Information",
                    elem_id="status-output",
                )
with gr.Column(scale=1):
# Settings
gr.Markdown("## βš™οΈ Settings")
                threshold_slider = gr.Slider(
                    minimum=0.3,
                    maximum=0.9,
                    step=0.05,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    label="Speaker Change Sensitivity",
                    info="Lower = more sensitive (more speaker changes)",
                    elem_id="threshold-slider",
                )
                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    step=1,
                    value=DEFAULT_MAX_SPEAKERS,
                    label="Maximum Speakers",
                    elem_id="max-speakers-slider",
                )
                update_btn = gr.Button("Update Settings", variant="secondary", elem_id="update-settings-btn")
# Instructions
gr.Markdown("""
## πŸ“‹ Instructions
                1. **Start Listening** - grant the browser access to your microphone
                2. **Speak** - the system transcribes speech and identifies each speaker
                3. **Stop** when finished
                4. **Clear** to reset the conversation
## 🎨 Speaker Colors
- πŸ”΄ Speaker 1 (Red)
- 🟒 Speaker 2 (Teal)
- πŸ”΅ Speaker 3 (Blue)
- 🟑 Speaker 4 (Green)
- ⭐ Speaker 5 (Yellow)
- 🟣 Speaker 6 (Plum)
- 🟀 Speaker 7 (Mint)
- 🟠 Speaker 8 (Gold)
""")
# JavaScript to connect buttons to the script functions
gr.HTML("""
<script>
// Wait for Gradio to fully load
document.addEventListener('DOMContentLoaded', () => {
// Wait a bit for Gradio buttons to be created
setTimeout(() => {
                // Look up the controls via the elem_id set on each component above
                const startBtn = document.getElementById('start-btn');
                const stopBtn = document.getElementById('stop-btn');
                const clearBtn = document.getElementById('clear-btn');
if (startBtn) startBtn.onclick = () => startStreaming();
if (stopBtn) stopBtn.onclick = () => stopStreaming();
if (clearBtn) clearBtn.onclick = () => {
// Make API call to clear conversation
fetch(`${window.HF_SPACE_URL}/clear`, {
method: 'POST'
}).then(resp => resp.json())
.then(data => {
document.getElementById("conversation").innerHTML =
"<i>Conversation cleared. Start speaking again...</i>";
});
}
                // Set up settings update
                const updateBtn = document.getElementById('update-settings-btn');
                if (updateBtn) updateBtn.onclick = () => {
                    // Assumption about Gradio's DOM: each slider renders a numeric input
                    // alongside the range control, which holds the current value
                    const threshold = document.querySelector('#threshold-slider input[type="number"]').value;
                    const maxSpeakers = document.querySelector('#max-speakers-slider input[type="number"]').value;
fetch(`${window.HF_SPACE_URL}/settings?threshold=${threshold}&max_speakers=${maxSpeakers}`, {
method: 'POST'
}).then(resp => resp.json())
.then(data => {
                            const statusOutput = document.getElementById('status-output');
                            if (statusOutput) {
statusOutput.innerHTML = `
<h2>System Status</h2>
<p>Settings updated:</p>
<ul>
<li>Threshold: ${threshold}</li>
<li>Max Speakers: ${maxSpeakers}</li>
</ul>
<p>Transcription Models:</p>
<ul>
<li>Final: ${window.FINAL_TRANSCRIPTION_MODEL || "distil-large-v3"}</li>
<li>Realtime: ${window.REALTIME_TRANSCRIPTION_MODEL || "distil-small.en"}</li>
</ul>
`;
}
});
}
}, 1000);
});
</script>
""")
# Set up periodic status updates
        def get_status():
            """Poll the HF Space /status endpoint - called periodically by the timer"""
            import requests  # local import; only needed for this periodic poll
            try:
                resp = requests.get(f"{HF_SPACE_URL}/status", timeout=5)
                if resp.status_code == 200:
                    return resp.json().get('status', 'No status information')
                return "Error getting status"
            except Exception as e:
                return f"Connection error: {str(e)}"
status_timer = gr.Timer(5)
status_timer.tick(fn=get_status, outputs=status_output)
return demo
# Create Gradio interface
demo = build_ui()
def mount_ui(app: FastAPI):
    """Mount the Gradio UI onto an existing FastAPI application at /ui"""
    # gr.mount_gradio_app is Gradio's supported way to attach Blocks to FastAPI;
    # demo.app only exists after launch(), so mounting it directly would fail here
    gr.mount_gradio_app(app, demo, path="/ui")
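# Example (sketch) of serving the UI from a host FastAPI app; the module name
# "ui" below is hypothetical - substitute whatever this file is saved as:
#
#     from fastapi import FastAPI
#     from ui import mount_ui
#
#     server = FastAPI()
#     mount_ui(server)  # Gradio UI is then available at /ui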
# For standalone testing
if __name__ == "__main__":
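    # launch() serves the UI standalone, by default at http://127.0.0.1:7860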
demo.launch()