import gradio as gr
import numpy as np
import queue
import torch
import time
import threading
import os
import urllib.request
import torchaudio
from scipy.spatial.distance import cosine
import json
import asyncio
from typing import Iterator
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Simplified configuration parameters
SILENCE_THRESHS = [0, 0.4]
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# Speaker change detection parameters
DEFAULT_CHANGE_THRESHOLD = 0.7
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.0
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10

# Global variables
FAST_SENTENCE_END = True
SAMPLE_RATE = 16000
BUFFER_SIZE = 1024
CHANNELS = 1
CHUNK_DURATION_MS = 100  # 100 ms chunks for FastRTC
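# Worked example of the chunk math above: at SAMPLE_RATE = 16000 Hz, a
# CHUNK_DURATION_MS = 100 chunk is 16000 * 100 / 1000 = 1600 samples, the same
# figure used later as the "too short to transcribe" cutoff.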

# Speaker colors
SPEAKER_COLORS = [
    "#FFFF00",  # Yellow
    "#FF0000",  # Red
    "#00FF00",  # Green
    "#00FFFF",  # Cyan
    "#FF00FF",  # Magenta
    "#0000FF",  # Blue
    "#FF8000",  # Orange
    "#00FF80",  # Spring Green
    "#8000FF",  # Purple
    "#FFFFFF",  # White
]

SPEAKER_COLOR_NAMES = [
    "Yellow", "Red", "Green", "Cyan", "Magenta",
    "Blue", "Orange", "Spring Green", "Purple", "White"
]


class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""

    def __init__(self, device="cpu"):
        self.device = device
        self.model = None
        self.embedding_dim = 192
        self.model_loaded = False
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    def _download_model(self):
        """Download the pre-trained SpeechBrain ECAPA-TDNN checkpoint if not present"""
        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
        if not os.path.exists(model_path):
            logger.info(f"Downloading ECAPA-TDNN model to {model_path}...")
            urllib.request.urlretrieve(model_url, model_path)
        return model_path

    def load_model(self):
        """Load the ECAPA-TDNN model"""
        try:
            from speechbrain.pretrained import EncoderClassifier
            model_path = self._download_model()
            self.model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=self.cache_dir,
                run_opts={"device": self.device}
            )
            self.model_loaded = True
            return True
        except Exception as e:
            logger.error(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract a speaker embedding from audio"""
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")
        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
            else:
                waveform = audio.unsqueeze(0)
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)
            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)


class AudioProcessor:
    """Processes audio data to extract speaker embeddings"""

    def __init__(self, encoder):
        self.encoder = encoder

    def extract_embedding(self, audio_float):
        try:
            # Ensure audio is in the right format
            if np.abs(audio_float).max() > 1.0:
                audio_float = audio_float / np.abs(audio_float).max()
            embedding = self.encoder.embed_utterance(audio_float)
            return embedding
        except Exception as e:
            logger.error(f"Embedding extraction error: {e}")
            return np.zeros(self.encoder.embedding_dim)


class SpeakerChangeDetector:
    """Speaker change detector that supports a configurable number of speakers"""

    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = set([0])

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers"""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        if new_max < self.max_speakers:
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0
        if new_max > self.max_speakers:
            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the threshold for detecting speaker changes"""
        self.change_threshold = max(0.1, min(threshold, 0.99))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and check whether the speaker has changed"""
        current_time = timestamp or time.time()

        # First embedding: assign it to the initial speaker
        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0

        # Cosine similarity against the current speaker's running mean embedding
        current_mean = self.mean_embeddings[self.current_speaker]
        if current_mean is not None:
            similarity = 1.0 - cosine(embedding, current_mean)
        else:
            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])

        self.last_similarity = similarity
        time_since_last_change = current_time - self.last_change_time
        is_speaker_change = False

        if time_since_last_change >= MIN_SEGMENT_DURATION:
            if similarity < self.change_threshold:
                # Look for a better-matching known speaker
                best_speaker = self.current_speaker
                best_similarity = similarity
                for speaker_id in range(self.max_speakers):
                    if speaker_id == self.current_speaker:
                        continue
                    speaker_mean = self.mean_embeddings[speaker_id]
                    if speaker_mean is not None:
                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
                        if speaker_similarity > best_similarity:
                            best_similarity = speaker_similarity
                            best_speaker = speaker_id
                if best_speaker != self.current_speaker:
                    is_speaker_change = True
                    self.current_speaker = best_speaker
                elif len(self.active_speakers) < self.max_speakers:
                    # No existing speaker matches well enough: open a new slot
                    for new_id in range(self.max_speakers):
                        if new_id not in self.active_speakers:
                            is_speaker_change = True
                            self.current_speaker = new_id
                            self.active_speakers.add(new_id)
                            break

        if is_speaker_change:
            self.last_change_time = current_time

        # Update history and per-speaker statistics
        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)

        self.speaker_embeddings[self.current_speaker].append(embedding)
        self.active_speakers.add(self.current_speaker)
        if len(self.speaker_embeddings[self.current_speaker]) > 30:
            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
        if self.speaker_embeddings[self.current_speaker]:
            self.mean_embeddings[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )

        return self.current_speaker, similarity

    def get_color_for_speaker(self, speaker_id):
        """Return the display color for a speaker ID"""
        if 0 <= speaker_id < len(SPEAKER_COLORS):
            return SPEAKER_COLORS[speaker_id]
        return "#FFFFFF"

    def get_status_info(self):
        """Return status information about the speaker change detector"""
        speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
        return {
            "current_speaker": self.current_speaker,
            "speaker_counts": speaker_counts,
            "active_speakers": len(self.active_speakers),
            "max_speakers": self.max_speakers,
            "last_similarity": self.last_similarity,
            "threshold": self.change_threshold
        }
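

# Illustrative sketch (not executed by the app) of how SpeakerChangeDetector is
# meant to be driven. The embeddings here are random stand-ins for real
# ECAPA-TDNN vectors, purely to show the add_embedding() contract:
#
#     detector = SpeakerChangeDetector(embedding_dim=192, change_threshold=0.7, max_speakers=4)
#     for _ in range(5):
#         fake_embedding = np.random.randn(192).astype(np.float32)
#         speaker_id, similarity = detector.add_embedding(fake_embedding)
#         print(speaker_id, round(similarity, 3))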


class WhisperTranscriber:
    """Whisper transcriber using transformers with FastRTC optimization"""

    def __init__(self, model_name="distil-large-v3"):
        self.model = None
        self.processor = None
        self.model_name = model_name
        self.model_loaded = False

    def load_model(self):
        """Load the Whisper model"""
        try:
            from transformers import WhisperProcessor, WhisperForConditionalGeneration

            # Distil-Whisper checkpoints live under the "distil-whisper" org,
            # plain Whisper checkpoints under "openai".
            model_id = (
                f"distil-whisper/{self.model_name}"
                if "distil" in self.model_name
                else f"openai/whisper-{self.model_name}"
            )
            self.processor = WhisperProcessor.from_pretrained(model_id)
            self.model = WhisperForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
                use_safetensors=True
            )
            if torch.cuda.is_available():
                self.model = self.model.cuda()
            self.model_loaded = True
            return True
        except Exception as e:
            logger.error(f"Error loading Whisper model: {e}")
            return False

    def transcribe(self, audio_array, sample_rate=16000):
        """Transcribe an audio array"""
        if not self.model_loaded:
            return ""
        try:
            # Skip anything shorter than ~0.1 seconds at 16 kHz
            if len(audio_array) < 1600:
                return ""

            # Resample to 16 kHz if needed
            if sample_rate != 16000:
                import torchaudio.functional as F
                audio_tensor = torch.tensor(audio_array, dtype=torch.float32)
                audio_array = F.resample(audio_tensor, sample_rate, 16000).numpy()

            # Extract log-mel features (the processor pads to Whisper's fixed 30 s window)
            inputs = self.processor(
                audio_array,
                sampling_rate=16000,
                return_tensors="pt"
            )
            input_features = inputs["input_features"]
            if torch.cuda.is_available():
                # Match the model's device and dtype (float16 on GPU)
                input_features = input_features.to("cuda", dtype=self.model.dtype)

            with torch.no_grad():
                predicted_ids = self.model.generate(
                    input_features,
                    max_length=448,
                    num_beams=1,
                    do_sample=False,
                    use_cache=True
                )
            transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            return transcription.strip()
        except Exception as e:
            logger.error(f"Transcription error: {e}")
            return ""


class FastRTCSpeakerDiarization:
    def __init__(self):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.transcriber = None
        self.audio_queue = queue.Queue(maxsize=100)
        self.processing_thread = None
        self.full_sentences = []
        self.sentence_speakers = []
        self.is_running = False
        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
        self.max_speakers = DEFAULT_MAX_SPEAKERS
        self.audio_buffer = []
        self.buffer_duration = 3.0  # seconds
        self.last_transcription_time = time.time()
        self.chunk_size = int(SAMPLE_RATE * CHUNK_DURATION_MS / 1000)

    def initialize_models(self):
        """Initialize the speaker encoder and transcription models"""
        try:
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {device_str}")

            # Initialize speaker encoder
            self.encoder = SpeechBrainEncoder(device=device_str)
            encoder_success = self.encoder.load_model()

            # Initialize transcriber
            self.transcriber = WhisperTranscriber(FINAL_TRANSCRIPTION_MODEL)
            transcriber_success = self.transcriber.load_model()

            if encoder_success and transcriber_success:
                self.audio_processor = AudioProcessor(self.encoder)
                self.speaker_detector = SpeakerChangeDetector(
                    embedding_dim=self.encoder.embedding_dim,
                    change_threshold=self.change_threshold,
                    max_speakers=self.max_speakers
                )
                logger.info("Models loaded successfully!")
                return True
            else:
                logger.error("Failed to load models")
                return False
        except Exception as e:
            logger.error(f"Model initialization error: {e}")
            return False

    def process_audio_chunk(self, audio_chunk: np.ndarray, sample_rate: int):
        """Process an individual audio chunk from FastRTC"""
        if not self.is_running or audio_chunk is None:
            return
        try:
            if isinstance(audio_chunk, np.ndarray):
                # Ensure mono audio
                if len(audio_chunk.shape) > 1:
                    audio_chunk = audio_chunk.mean(axis=1)
                # Normalize audio
                if audio_chunk.dtype != np.float32:
                    audio_chunk = audio_chunk.astype(np.float32)
                if np.abs(audio_chunk).max() > 1.0:
                    audio_chunk = audio_chunk / np.abs(audio_chunk).max()

                # Add to rolling buffer, capped at buffer_duration seconds
                self.audio_buffer.extend(audio_chunk)
                max_buffer_length = int(self.buffer_duration * sample_rate)
                if len(self.audio_buffer) > max_buffer_length:
                    self.audio_buffer = self.audio_buffer[-max_buffer_length:]

                # Queue a segment once enough audio has accumulated and enough time has passed
                current_time = time.time()
                if (current_time - self.last_transcription_time > 1.5 and
                        len(self.audio_buffer) > sample_rate * 0.8):  # at least 0.8 seconds
                    if not self.audio_queue.full():
                        self.audio_queue.put((np.array(self.audio_buffer[-int(sample_rate * 2):]), sample_rate))
                        self.last_transcription_time = current_time
        except Exception as e:
            logger.error(f"Audio chunk processing error: {e}")

    def process_audio_queue(self):
        """Process audio segments from the queue"""
        while self.is_running:
            try:
                audio_data, sample_rate = self.audio_queue.get(timeout=1)
                if len(audio_data) < 1600:  # skip very short audio
                    continue

                # Transcribe audio
                transcription = self.transcriber.transcribe(audio_data, sample_rate)
                if transcription and len(transcription.strip()) > 0:
                    # Extract speaker embedding and assign a speaker
                    speaker_embedding = self.audio_processor.extract_embedding(audio_data)
                    speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)

                    # Store results
                    self.full_sentences.append(transcription.strip())
                    self.sentence_speakers.append(speaker_id)
                    logger.info(f"Processed: Speaker {speaker_id + 1}: {transcription.strip()[:50]}...")
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Error processing audio queue: {e}")

    def start_recording(self):
        """Start recording and processing"""
        if self.encoder is None or self.transcriber is None:
            return "Please initialize models first!"
        try:
            self.is_running = True
            self.audio_buffer = []
            self.last_transcription_time = time.time()

            # Clear the queue
            while not self.audio_queue.empty():
                try:
                    self.audio_queue.get_nowait()
                except queue.Empty:
                    break

            # Start processing thread
            self.processing_thread = threading.Thread(target=self.process_audio_queue, daemon=True)
            self.processing_thread.start()

            logger.info("Recording started successfully!")
            return "Recording started successfully!"
        except Exception as e:
            logger.error(f"Error starting recording: {e}")
            return f"Error starting recording: {e}"

    def stop_recording(self):
        """Stop the recording process"""
        self.is_running = False
        logger.info("Recording stopped!")
        return "Recording stopped!"

    def clear_conversation(self):
        """Clear all conversation data"""
        self.full_sentences = []
        self.sentence_speakers = []
        self.audio_buffer = []

        # Clear the queue
        while not self.audio_queue.empty():
            try:
                self.audio_queue.get_nowait()
            except queue.Empty:
                break

        if self.speaker_detector:
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers
            )
        return "Conversation cleared!"

    def update_settings(self, threshold, max_speakers):
        """Update speaker detection settings"""
        self.change_threshold = threshold
        self.max_speakers = max_speakers
        if self.speaker_detector:
            self.speaker_detector.set_change_threshold(threshold)
            self.speaker_detector.set_max_speakers(max_speakers)
        return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"

    def get_formatted_conversation(self):
        """Get the formatted conversation with speaker colors"""
        try:
            if not self.full_sentences:
                return "Waiting for speech input... 🎤"

            # Show the last 10 sentences, each tagged with its speaker's color
            sentences_with_style = []
            recent_sentences = self.full_sentences[-10:]
            recent_speakers = self.sentence_speakers[-10:]
            for idx, sentence in enumerate(recent_sentences):
                if idx < len(recent_speakers):
                    speaker_id = recent_speakers[idx]
                    color = self.speaker_detector.get_color_for_speaker(speaker_id)
                    speaker_name = f"Speaker {speaker_id + 1}"
                else:
                    color = "#FFFFFF"
                    speaker_name = "Unknown"
                sentences_with_style.append(
                    f'<p><span style="color:{color}; font-weight: bold;">{speaker_name}:</span> {sentence}</p>'
                )
            return "".join(sentences_with_style)
        except Exception as e:
            return f"Error formatting conversation: {e}"

    def get_status_info(self):
        """Get current status information"""
        if not self.speaker_detector:
            return "Speaker detector not initialized"
        try:
            status = self.speaker_detector.get_status_info()
            queue_size = self.audio_queue.qsize()
            status_lines = [
                f"**Current Speaker:** {status['current_speaker'] + 1}",
                f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}",
                f"**Last Similarity:** {status['last_similarity']:.3f}",
                f"**Change Threshold:** {status['threshold']:.2f}",
                f"**Total Sentences:** {len(self.full_sentences)}",
                f"**Buffer Length:** {len(self.audio_buffer)} samples",
                f"**Queue Size:** {queue_size}",
                "",
                "**Speaker Segment Counts:**"
            ]
            for i in range(status['max_speakers']):
                color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
                status_lines.append(f"Speaker {i+1} ({color_name}): {status['speaker_counts'][i]}")
            return "\n".join(status_lines)
        except Exception as e:
            return f"Error getting status: {e}"


# Global instance
diarization_system = FastRTCSpeakerDiarization()


def initialize_system():
    """Initialize the diarization system"""
    success = diarization_system.initialize_models()
    if success:
        return "✅ System initialized successfully! Models loaded."
    else:
        return "❌ Failed to initialize system. Please check the logs."


def start_recording():
    """Start recording and transcription"""
    return diarization_system.start_recording()


def stop_recording():
    """Stop recording and transcription"""
    return diarization_system.stop_recording()


def clear_conversation():
    """Clear the conversation"""
    return diarization_system.clear_conversation()


def update_settings(threshold, max_speakers):
    """Update system settings"""
    return diarization_system.update_settings(threshold, max_speakers)


def get_conversation():
    """Get the current conversation"""
    return diarization_system.get_formatted_conversation()


def get_status():
    """Get system status"""
    return diarization_system.get_status_info()


def process_audio_stream(audio_stream):
    """Process streaming audio from FastRTC"""
    if audio_stream is not None and diarization_system.is_running:
        sample_rate, audio_data = audio_stream
        diarization_system.process_audio_chunk(audio_data, sample_rate)
    return get_conversation(), get_status()
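

# Note on the streaming payload: with type="numpy", Gradio delivers microphone
# stream events as (sample_rate, numpy_array) tuples, which is why
# process_audio_stream unpacks the pair before handing the raw samples to
# diarization_system.process_audio_chunk.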


# Create Gradio interface with FastRTC
def create_interface():
    with gr.Blocks(title="FastRTC Real-time Speaker Diarization", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🎤 FastRTC Real-time Speech Recognition with Speaker Diarization")
        gr.Markdown("This app uses Hugging Face FastRTC for real-time audio streaming with automatic speaker identification and color-coding.")

        with gr.Row():
            with gr.Column(scale=2):
                # FastRTC audio input for real-time streaming
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="numpy",
                    streaming=True,
                    label="🎙️ FastRTC Microphone Input",
                    format="wav",
                    show_download_button=False,
                    container=True,
                    elem_id="fastrtc_audio"
                )

                # Main conversation display
                conversation_output = gr.HTML(
                    value="<i>Click 'Initialize System' and then 'Start Recording' to begin...</i>",
                    label="Live Conversation",
                    elem_id="conversation_display"
                )

                # Control buttons
                with gr.Row():
                    init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
                    start_btn = gr.Button("🎙️ Start Recording", variant="primary", interactive=False, size="lg")
                    stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", interactive=False, size="lg")
                    clear_btn = gr.Button("🗑️ Clear", interactive=False, size="lg")

                # Status display
                status_output = gr.Textbox(
                    label="System Status",
                    value="System not initialized",
                    lines=10,
                    interactive=False,
                    show_copy_button=True
                )

            with gr.Column(scale=1):
                # Settings panel
                gr.Markdown("## ⚙️ Settings")
                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.95,
                    step=0.05,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    label="Speaker Change Sensitivity",
                    info="Lower = more sensitive to changes"
                )
                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    step=1,
                    value=DEFAULT_MAX_SPEAKERS,
                    label="Maximum Number of Speakers"
                )
                update_settings_btn = gr.Button("Update Settings", variant="secondary")

                # Speaker color legend
                gr.Markdown("## 🎨 Speaker Colors")
                color_info = []
                for i, (color, name) in enumerate(zip(SPEAKER_COLORS, SPEAKER_COLOR_NAMES)):
                    color_info.append(f'<span style="color:{color}; font-size: 16px;">●</span> Speaker {i+1} ({name})')
                gr.HTML("<br>".join(color_info[:DEFAULT_MAX_SPEAKERS]))

                # Performance info
                gr.Markdown("## 📊 Performance")
                gr.Markdown("""
                - **FastRTC**: Low-latency audio streaming
                - **Whisper**: distil-large-v3 for transcription
                - **ECAPA-TDNN**: Speaker embeddings
                - **Real-time**: ~100ms processing chunks
                """)

        # Event handlers
        def on_initialize():
            result = initialize_system()
            if "successfully" in result:
                return (
                    result,                        # status_output
                    gr.update(interactive=True),   # start_btn
                    gr.update(interactive=True),   # clear_btn
                    get_conversation(),            # conversation_output
                    get_status()                   # status_output update
                )
            else:
                return (
                    result,                        # status_output
                    gr.update(interactive=False),  # start_btn
                    gr.update(interactive=False),  # clear_btn
                    get_conversation(),            # conversation_output
                    get_status()                   # status_output update
                )

        def on_start():
            result = start_recording()
            return (
                result,                        # status_output
                gr.update(interactive=False),  # start_btn
                gr.update(interactive=True),   # stop_btn
            )

        def on_stop():
            result = stop_recording()
            return (
                result,                        # status_output
                gr.update(interactive=True),   # start_btn
                gr.update(interactive=False),  # stop_btn
            )

        # Auto-refresh function
        def refresh_display():
            return get_conversation(), get_status()

        # Connect event handlers
        init_btn.click(
            on_initialize,
            outputs=[status_output, start_btn, clear_btn, conversation_output, status_output]
        )
        start_btn.click(
            on_start,
            outputs=[status_output, start_btn, stop_btn]
        )
        stop_btn.click(
            on_stop,
            outputs=[status_output, start_btn, stop_btn]
        )
        clear_btn.click(
            clear_conversation,
            outputs=[status_output]
        )
        update_settings_btn.click(
            update_settings,
            inputs=[threshold_slider, max_speakers_slider],
            outputs=[status_output]
        )

        # FastRTC streaming audio processing
        audio_input.stream(
            process_audio_stream,
            inputs=[audio_input],
            outputs=[conversation_output, status_output],
            stream_every=0.1,  # process every 100 ms
            time_limit=None
        )

        # Auto-refresh timer
        refresh_timer = gr.Timer(2.0)
        refresh_timer.tick(
            refresh_display,
            outputs=[conversation_output, status_output]
        )

    return app


if __name__ == "__main__":
    app = create_interface()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
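
# Rough dependency list, inferred from the imports in this file (treat it as a
# starting point rather than a tested requirements.txt; pin versions for your
# own environment):
#   gradio, numpy, scipy, torch, torchaudio, transformers, speechbrain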