import gradio as gr
import numpy as np
import torch
import torchaudio
import time
import os
import urllib.request
from scipy.spatial.distance import cosine
import threading
import queue
from collections import deque
import asyncio
from typing import Generator, Tuple, List, Optional

# Configuration parameters (keeping original models).
# NOTE(review): the transcription/VAD settings below are not referenced by the
# code visible here -- confirm whether another module uses them.
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# Speaker change detection parameters
DEFAULT_CHANGE_THRESHOLD = 0.7
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.0          # seconds that must pass before another change
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10
MAX_EMBEDDINGS_PER_SPEAKER = 30     # window used for each speaker's running mean
SAMPLE_RATE = 16000

# Speaker labels
SPEAKER_LABELS = [f"Speaker {i+1}" for i in range(ABSOLUTE_MAX_SPEAKERS)]


class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings."""

    def __init__(self, device="cpu"):
        self.device = device
        self.model = None
        self.embedding_dim = 192  # size of the zero vector returned on failure
        self.model_loaded = False
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    def load_model(self):
        """Load the pretrained ``speechbrain/spkrec-ecapa-voxceleb`` model.

        Returns:
            True on success, False if the import or model loading failed
            (the error is printed).
        """
        try:
            try:
                # SpeechBrain >= 1.0 location of the pretrained interfaces.
                from speechbrain.inference.speaker import EncoderClassifier
            except ImportError:
                # Older SpeechBrain releases.
                from speechbrain.pretrained import EncoderClassifier
            self.model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=self.cache_dir,
                run_opts={"device": self.device},
            )
            self.model_loaded = True
            return True
        except Exception as e:
            print(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract a speaker embedding from a 1-D waveform.

        Args:
            audio: NumPy array or torch tensor of samples (a batch axis is added here).
            sr: Sample rate of ``audio``; resampled to 16 kHz if different.

        Returns:
            The squeezed embedding as a NumPy array, or
            ``np.zeros(self.embedding_dim)`` if extraction fails.

        Raises:
            ValueError: if :meth:`load_model` has not succeeded yet.
        """
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")
        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
            else:
                waveform = audio.unsqueeze(0)
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)
            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)


class SpeakerChangeDetector:
    """Online speaker-change detector supporting a configurable number of speakers.

    Each embedding is compared (cosine similarity) with the running mean of the
    current speaker. If similarity falls below ``change_threshold`` and at least
    ``MIN_SEGMENT_DURATION`` seconds have passed since the last change, the
    embedding goes to the known speaker whose mean is more similar. If no known
    speaker is more similar, it goes to the lowest unused speaker slot, if any.
    """

    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD,
                 max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.set_change_threshold(change_threshold)  # same clamping as runtime updates
        self.max_speakers = self._clamp_max_speakers(max_speakers)
        self.current_speaker = 0
        # Most recent embeddings (any speaker); deque drops the oldest automatically.
        self.previous_embeddings = deque(maxlen=EMBEDDING_HISTORY_SIZE)
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = {0}

    @staticmethod
    def _clamp_max_speakers(max_speakers):
        """Coerce ``max_speakers`` (possibly a float from a UI slider) into [1, ABSOLUTE_MAX_SPEAKERS]."""
        return max(1, min(int(max_speakers), ABSOLUTE_MAX_SPEAKERS))

    def set_max_speakers(self, max_speakers):
        """Resize the speaker slots, discarding state of speakers beyond the new limit."""
        new_max = self._clamp_max_speakers(max_speakers)
        if new_max < self.max_speakers:
            self.active_speakers = {s for s in self.active_speakers if s < new_max}
            if self.current_speaker >= new_max:
                self.current_speaker = 0
            # The current speaker must stay "active" so the new-speaker search
            # below can never re-pick it.
            self.active_speakers.add(self.current_speaker)
            del self.mean_embeddings[new_max:]
            del self.speaker_embeddings[new_max:]
        elif new_max > self.max_speakers:
            extra = new_max - self.max_speakers
            self.mean_embeddings.extend([None] * extra)
            self.speaker_embeddings.extend([] for _ in range(extra))
        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Set the similarity threshold below which a speaker change is considered (clamped to [0.1, 0.99])."""
        self.change_threshold = max(0.1, min(threshold, 0.99))

    def _pick_speaker(self, embedding, current_similarity):
        """Choose the speaker for an embedding that no longer matches the current one.

        Returns the known speaker with the highest similarity if it beats
        ``current_similarity``. Otherwise returns the lowest free slot, or the
        current speaker when every slot is taken.
        """
        best_speaker = self.current_speaker
        best_similarity = current_similarity
        for speaker_id, speaker_mean in enumerate(self.mean_embeddings):
            if speaker_id == self.current_speaker or speaker_mean is None:
                continue
            speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
            if speaker_similarity > best_similarity:
                best_similarity = speaker_similarity
                best_speaker = speaker_id
        if best_speaker != self.current_speaker:
            return best_speaker
        if len(self.active_speakers) < self.max_speakers:
            for new_id in range(self.max_speakers):
                if new_id not in self.active_speakers:
                    return new_id
        return self.current_speaker

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and check whether the speaker changed.

        Args:
            embedding: Speaker embedding (NumPy array).
            timestamp: Time of the segment in seconds; defaults to ``time.time()``.

        Returns:
            ``(speaker_id, similarity)``. Similarity is 1.0 for the very first
            embedding. It is 0.0 when the embedding is all-zero or non-finite;
            such embeddings are ignored and leave the state untouched.
        """
        current_time = time.time() if timestamp is None else timestamp

        # The encoders return a zero vector on failure. It has no direction, so
        # cosine similarity against it is NaN and it would corrupt the means.
        if not np.all(np.isfinite(embedding)) or not np.any(embedding):
            return self.current_speaker, 0.0

        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0

        reference = self.mean_embeddings[self.current_speaker]
        if reference is None:
            reference = self.previous_embeddings[-1]
        similarity = 1.0 - cosine(embedding, reference)
        self.last_similarity = similarity

        if (current_time - self.last_change_time >= MIN_SEGMENT_DURATION
                and similarity < self.change_threshold):
            chosen = self._pick_speaker(embedding, similarity)
            if chosen != self.current_speaker:
                self.current_speaker = chosen
                self.active_speakers.add(chosen)
                self.last_change_time = current_time

        self.previous_embeddings.append(embedding)
        history = self.speaker_embeddings[self.current_speaker]
        history.append(embedding)
        self.active_speakers.add(self.current_speaker)
        if len(history) > MAX_EMBEDDINGS_PER_SPEAKER:
            del history[:-MAX_EMBEDDINGS_PER_SPEAKER]
        self.mean_embeddings[self.current_speaker] = np.mean(history, axis=0)

        return self.current_speaker, similarity


class AudioProcessor:
    """Processes audio data to extract speaker embeddings."""

    def __init__(self, encoder):
        self.encoder = encoder

    def extract_embedding(self, audio_data):
        """Return the embedding of ``audio_data``, or zeros if it is empty or extraction fails.

        Audio is cast to float32 and peak-normalised when its peak exceeds 1.0.
        """
        try:
            audio_data = np.asarray(audio_data)
            if audio_data.size == 0:
                return np.zeros(self.encoder.embedding_dim)
            if audio_data.dtype != np.float32:
                audio_data = audio_data.astype(np.float32)
            peak = np.abs(audio_data).max()
            if peak > 1.0:
                audio_data = audio_data / peak
            return self.encoder.embed_utterance(audio_data)
        except Exception as e:
            print(f"Embedding extraction error: {e}")
            return np.zeros(self.encoder.embedding_dim)


class RealTimeSpeakerDiarization:
    """Main class for real-time speaker diarization."""

    MAX_TRANSCRIPT_ENTRIES = 50  # cap on retained transcript lines

    def __init__(self, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.change_threshold = change_threshold
        self.max_speakers = max_speakers
        self.transcript_history = []
        self.is_initialized = False

        # Threading components (not used by the code visible here).
        self.audio_queue = queue.Queue()
        self.processing_thread = None
        self.running = False

    async def initialize(self):
        """Load the encoder and create the detector; returns True on success."""
        if self.is_initialized:
            return True
        try:
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Initializing ECAPA-TDNN model on {device_str}...")
            self.encoder = SpeechBrainEncoder(device=device_str)
            # Loading/downloading the model blocks; keep it off the event loop.
            loop = asyncio.get_running_loop()
            success = await loop.run_in_executor(None, self.encoder.load_model)
            if not success:
                return False

            self.audio_processor = AudioProcessor(self.encoder)
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers,
            )
            self.is_initialized = True
            print("Speaker diarization system initialized successfully!")
            return True
        except Exception as e:
            print(f"Initialization error: {e}")
            return False

    def update_settings(self, change_threshold, max_speakers):
        """Update diarization settings, forwarding them to the live detector."""
        self.change_threshold = change_threshold
        self.max_speakers = int(max_speakers)  # UI sliders may deliver floats
        if self.speaker_detector:
            self.speaker_detector.set_change_threshold(change_threshold)
            self.speaker_detector.set_max_speakers(self.max_speakers)

    def process_audio_segment(self, audio_data: np.ndarray, text: str) -> Tuple[int, str]:
        """Process an audio segment and return ``(speaker_id, "Speaker N: text")``.

        Falls back to speaker 0 when uninitialized or on error.
        """
        if not self.is_initialized:
            return 0, text
        try:
            embedding = self.audio_processor.extract_embedding(audio_data)
            speaker_id, _similarity = self.speaker_detector.add_embedding(embedding)
            return speaker_id, f"{SPEAKER_LABELS[speaker_id]}: {text}"
        except Exception as e:
            print(f"Error processing audio segment: {e}")
            return 0, f"Speaker 1: {text}"

    def get_transcript_history(self):
        """Return the transcript history joined with newlines."""
        return "\n".join(self.transcript_history)

    def add_to_transcript(self, formatted_text: str):
        """Append a line, keeping only the most recent ``MAX_TRANSCRIPT_ENTRIES``."""
        self.transcript_history.append(formatted_text)
        if len(self.transcript_history) > self.MAX_TRANSCRIPT_ENTRIES:
            del self.transcript_history[:-self.MAX_TRANSCRIPT_ENTRIES]

    def clear_transcript(self):
        """Clear transcript history and reset speaker detector."""
        self.transcript_history = []
        if self.speaker_detector is not None:
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers,
            )


# Global instance
diarization_system = RealTimeSpeakerDiarization()


async def initialize_system():
    """Initialize the diarization system and return a status message."""
    success = await diarization_system.initialize()
    if success:
        return "✅ Speaker diarization system initialized successfully!"
    else:
        return "❌ Failed to initialize speaker diarization system. Please check your setup."


def _to_mono_float32(audio_data):
    """Convert samples to a contiguous 1-D float32 array.

    Integer PCM is rescaled to roughly [-1, 1) around its midpoint. Multi-channel
    audio, assumed to be laid out as ``(samples, channels)``, is averaged to mono.
    """
    audio = np.asarray(audio_data)
    if np.issubdtype(audio.dtype, np.integer):
        info = np.iinfo(audio.dtype)
        midpoint = (int(info.max) + int(info.min) + 1) / 2.0   # 0 for signed, 2**(b-1) for unsigned
        half_range = (int(info.max) - int(info.min) + 1) / 2.0
        audio = (audio.astype(np.float64) - midpoint) / half_range
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    return np.ascontiguousarray(audio, dtype=np.float32)


def process_audio_with_transcript(audio_data, sample_rate, transcription_text, change_threshold, max_speakers):
    """Diarize one audio segment and append its labelled transcription.

    Args:
        audio_data: Either a bare sample array, or the ``(sample_rate, samples)``
            tuple produced by ``gr.Audio(type="numpy")``. A tuple's own rate
            overrides ``sample_rate``.
        sample_rate: Sample rate of a bare ``audio_data`` array.
        transcription_text: Text spoken in the segment.
        change_threshold: Similarity threshold forwarded to the detector.
        max_speakers: Maximum number of speakers forwarded to the detector.

    Returns:
        ``(transcript, status)`` strings for the two output boxes.
    """
    if not diarization_system.is_initialized:
        return "Please initialize the system first.", ""

    if audio_data is None or not (transcription_text or "").strip():
        return diarization_system.get_transcript_history(), ""

    try:
        diarization_system.update_settings(change_threshold, max_speakers)

        if isinstance(audio_data, tuple):
            sample_rate, audio_data = audio_data
        audio = _to_mono_float32(audio_data)
        if audio.size == 0:
            return diarization_system.get_transcript_history(), ""

        sample_rate = int(sample_rate)
        if sample_rate != SAMPLE_RATE:
            audio = torchaudio.functional.resample(
                torch.from_numpy(audio), sample_rate, SAMPLE_RATE
            ).numpy()

        speaker_id, formatted_text = diarization_system.process_audio_segment(audio, transcription_text)
        diarization_system.add_to_transcript(formatted_text)

        transcript = diarization_system.get_transcript_history()
        current_speaker_info = f"Current Speaker: {SPEAKER_LABELS[speaker_id]}"
        return transcript, current_speaker_info
    except Exception as e:
        error_msg = f"Error processing audio: {str(e)}"
        return diarization_system.get_transcript_history(), error_msg


def clear_conversation():
    """Clear the conversation transcript and speaker memory."""
    diarization_system.clear_transcript()
    return "", "Conversation cleared."
def create_gradio_interface():
    """Build the Gradio Blocks UI for the diarization demo.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎙️ Real-time Speaker Diarization with ASR")
        gr.Markdown("Upload audio with transcription to perform real-time speaker diarization.")

        # Initialization section
        with gr.Row():
            init_btn = gr.Button("🚀 Initialize System", variant="primary")
            init_status = gr.Textbox(label="Initialization Status", interactive=False)

        # Settings section
        with gr.Row():
            with gr.Column():
                change_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    step=0.05,
                    label="Speaker Change Threshold",
                    # A change is considered when similarity to the current speaker
                    # drops BELOW this value, so higher values are more sensitive.
                    info="Higher values = more sensitive to speaker changes",
                )
            with gr.Column():
                max_speakers = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    value=DEFAULT_MAX_SPEAKERS,
                    step=1,
                    label="Maximum Number of Speakers",
                    info="Maximum number of speakers to detect",
                )

        # Audio input and transcription
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="Audio Input",
                    type="numpy",
                    format="wav",
                )
                transcription_input = gr.Textbox(
                    label="Transcription Text",
                    placeholder="Enter the transcription of the audio...",
                    lines=3,
                )
                process_btn = gr.Button("🎯 Process Audio", variant="secondary")

            with gr.Column():
                current_speaker = gr.Textbox(
                    label="Current Speaker",
                    interactive=False,
                )
                clear_btn = gr.Button("🗑️ Clear Conversation", variant="stop")

        # Output section
        transcript_output = gr.Textbox(
            label="Live Transcript with Speaker Labels",
            lines=15,
            max_lines=20,
            interactive=False,
            placeholder="Processed transcript will appear here...",
        )

        # One hidden component supplies the constant sample-rate argument to
        # both processing triggers.
        sample_rate_input = gr.Number(value=SAMPLE_RATE, visible=False)
        processing_inputs = [
            audio_input,
            sample_rate_input,
            transcription_input,
            change_threshold,
            max_speakers,
        ]
        processing_outputs = [transcript_output, current_speaker]

        # Event handlers
        init_btn.click(
            fn=initialize_system,
            outputs=[init_status],
        )

        process_btn.click(
            fn=process_audio_with_transcript,
            inputs=processing_inputs,
            outputs=processing_outputs,
        )

        clear_btn.click(
            fn=clear_conversation,
            outputs=processing_outputs,
        )

        # Auto-process when audio and transcription are provided
        audio_input.change(
            fn=process_audio_with_transcript,
            inputs=processing_inputs,
            outputs=processing_outputs,
        )

        # Instructions
        gr.Markdown("""
        ## Instructions:
        1. **Initialize**: Click "Initialize System" to load the speaker diarization models
        2. **Upload Audio**: Upload an audio file (WAV format recommended)
        3. **Add Transcription**: Enter the transcription text for the audio
        4. **Adjust Settings**:
           - **Speaker Change Threshold**: Higher values detect speaker changes more easily
           - **Max Speakers**: Set the maximum number of speakers you expect
        5. **Process**: Click "Process Audio" or the system will auto-process
        6. **View Results**: See the transcript with speaker labels (Speaker 1, Speaker 2, etc.)

        ## Tips:
        - For similar-sounding speakers, increase the threshold (0.6-0.8)
        - For different-sounding speakers, lower threshold works better (0.3-0.5)
        - The system maintains speaker consistency across the conversation
        - Use "Clear Conversation" to reset the speaker memory
        """)

    return demo


if __name__ == "__main__":
    # Create and launch the Gradio interface.
    # NOTE(review): share=True together with server_name="0.0.0.0" exposes the
    # app beyond localhost -- confirm that this is intended.
    demo = create_gradio_interface()
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )