import gradio as gr
import numpy as np
import torch
import torchaudio
import time
import os
import urllib.request
import queue
import threading
from scipy.spatial.distance import cosine
from RealtimeSTT import AudioToTextRecorder

# Configuration parameters (kept same as original)
SILENCE_THRESHS = [0, 0.4]
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# Speaker change detection parameters
DEFAULT_CHANGE_THRESHOLD = 0.7
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.0
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10

# Audio parameters
FAST_SENTENCE_END = True
SAMPLE_RATE = 16000
BUFFER_SIZE = 512
CHANNELS = 1

# Speaker colors for HTML display
SPEAKER_COLORS = [
    "#FFFF00", "#FF0000", "#00FF00", "#00FFFF", "#FF00FF",
    "#0000FF", "#FF8000", "#00FF80", "#8000FF", "#FFFFFF",
]
SPEAKER_COLOR_NAMES = [
    "Yellow", "Red", "Green", "Cyan", "Magenta",
    "Blue", "Orange", "Spring Green", "Purple", "White",
]


class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""

    def __init__(self, device="cpu"):
        self.device = device
        self.model = None
        self.embedding_dim = 192
        self.model_loaded = False
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    def _download_model(self):
        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
        if not os.path.exists(model_path):
            print(f"Downloading ECAPA-TDNN model to {model_path}...")
            urllib.request.urlretrieve(model_url, model_path)
        return model_path

    def load_model(self):
        """Load the ECAPA-TDNN model"""
        try:
            # SpeechBrain >= 1.0 moved the pretrained interfaces to
            # speechbrain.inference; fall back to the old path for older installs.
            try:
                from speechbrain.inference import EncoderClassifier
            except ImportError:
                from speechbrain.pretrained import EncoderClassifier

            # Pre-fetch the checkpoint into the cache dir (from_hparams would
            # also download it on demand, so this is a warm-up, not required).
            self._download_model()
            self.model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=self.cache_dir,
                run_opts={"device": self.device},
            )
            self.model_loaded = True
            return True
        except Exception as e:
            print(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract speaker embedding from audio"""
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")
        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
            else:
                waveform = audio.unsqueeze(0)
            if sr != 16000:
                waveform = torchaudio.functional.resample(
                    waveform, orig_freq=sr, new_freq=16000
                )
            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)
            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)
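
# A minimal standalone sketch of the encoder API above (hypothetical usage;
# assumes the model download succeeds, and one second of noise stands in for speech):
#
#     encoder = SpeechBrainEncoder(device="cpu")
#     if encoder.load_model():
#         audio = np.random.uniform(-1, 1, SAMPLE_RATE).astype(np.float32)
#         embedding = encoder.embed_utterance(audio)  # shape: (192,)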

class AudioProcessor:
    """Processes audio data to extract speaker embeddings"""

    def __init__(self, encoder):
        self.encoder = encoder

    def extract_embedding(self, audio_int16):
        try:
            # Convert int16 PCM to float in [-1, 1], renormalizing if clipped
            float_audio = audio_int16.astype(np.float32) / 32768.0
            if np.abs(float_audio).max() > 1.0:
                float_audio = float_audio / np.abs(float_audio).max()
            embedding = self.encoder.embed_utterance(float_audio)
            return embedding
        except Exception as e:
            print(f"Embedding extraction error: {e}")
            return np.zeros(self.encoder.embedding_dim)


class SpeakerChangeDetector:
    """Speaker change detector with configurable number of speakers"""

    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD,
                 max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = {0}

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers"""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        if new_max < self.max_speakers:
            # Drop any active speakers that fall outside the new limit
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0
        if new_max > self.max_speakers:
            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
            self.speaker_embeddings.extend(
                [[] for _ in range(new_max - self.max_speakers)]
            )
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the threshold for detecting speaker changes"""
        self.change_threshold = max(0.1, min(threshold, 0.99))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and check if there's a speaker change"""
        current_time = timestamp or time.time()

        # First embedding: assign it to speaker 0 with perfect similarity
        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0

        # Compare against the current speaker's mean embedding
        # (scipy's cosine() is a distance, so similarity = 1 - distance)
        current_mean = self.mean_embeddings[self.current_speaker]
        if current_mean is not None:
            similarity = 1.0 - cosine(embedding, current_mean)
        else:
            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
        self.last_similarity = similarity

        time_since_last_change = current_time - self.last_change_time
        is_speaker_change = False
        if time_since_last_change >= MIN_SEGMENT_DURATION:
            if similarity < self.change_threshold:
                # Look for a better-matching known speaker
                best_speaker = self.current_speaker
                best_similarity = similarity
                for speaker_id in range(self.max_speakers):
                    if speaker_id == self.current_speaker:
                        continue
                    speaker_mean = self.mean_embeddings[speaker_id]
                    if speaker_mean is not None:
                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
                        if speaker_similarity > best_similarity:
                            best_similarity = speaker_similarity
                            best_speaker = speaker_id
                if best_speaker != self.current_speaker:
                    is_speaker_change = True
                    self.current_speaker = best_speaker
                elif len(self.active_speakers) < self.max_speakers:
                    # No good match among known speakers: open a new slot
                    for new_id in range(self.max_speakers):
                        if new_id not in self.active_speakers:
                            is_speaker_change = True
                            self.current_speaker = new_id
                            self.active_speakers.add(new_id)
                            break
        if is_speaker_change:
            self.last_change_time = current_time

        # Update rolling history and per-speaker statistics
        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)
        self.speaker_embeddings[self.current_speaker].append(embedding)
        self.active_speakers.add(self.current_speaker)
        if len(self.speaker_embeddings[self.current_speaker]) > 30:
            self.speaker_embeddings[self.current_speaker] = \
                self.speaker_embeddings[self.current_speaker][-30:]
        if self.speaker_embeddings[self.current_speaker]:
            self.mean_embeddings[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )
        return self.current_speaker, similarity

    def get_color_for_speaker(self, speaker_id):
        """Return color for speaker ID"""
        if 0 <= speaker_id < len(SPEAKER_COLORS):
            return SPEAKER_COLORS[speaker_id]
        return "#FFFFFF"
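
# A small illustrative sketch of the detector loop (hypothetical, hand-built
# embeddings; real ones come from AudioProcessor.extract_embedding). Explicit
# timestamps satisfy the MIN_SEGMENT_DURATION guard:
#
#     detector = SpeakerChangeDetector(max_speakers=2)
#     t0 = time.time()
#     detector.add_embedding(np.ones(192), timestamp=t0)         # -> (0, 1.0)
#     detector.add_embedding(-np.ones(192), timestamp=t0 + 2.0)  # -> (1, -1.0)
#
# The second embedding's similarity (-1.0) falls below the 0.7 threshold and no
# other speaker matches, so a new speaker slot is opened.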


class RealtimeASRDiarization:
    """Main class for real-time ASR with speaker diarization"""

    def __init__(self):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.recorder = None
        self.is_recording = False
        self.full_sentences = []
        self.sentence_speakers = []
        self.pending_sentences = []
        self.last_realtime_text = ""
        self.sentence_queue = queue.Queue()
        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
        self.max_speakers = DEFAULT_MAX_SPEAKERS

        # Initialize model
        self.initialize_model()

    def initialize_model(self):
        """Initialize the speaker encoder model"""
        try:
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device_str}")
            self.encoder = SpeechBrainEncoder(device=device_str)
            success = self.encoder.load_model()
            if success:
                print("ECAPA-TDNN model loaded successfully!")
                self.audio_processor = AudioProcessor(self.encoder)
                self.speaker_detector = SpeakerChangeDetector(
                    embedding_dim=self.encoder.embedding_dim,
                    change_threshold=self.change_threshold,
                    max_speakers=self.max_speakers,
                )
                # Start sentence processing thread
                self.sentence_thread = threading.Thread(
                    target=self.process_sentences, daemon=True
                )
                self.sentence_thread.start()
            else:
                print("Failed to load ECAPA-TDNN model")
        except Exception as e:
            print(f"Model initialization error: {e}")

    def process_sentences(self):
        """Process sentences in background thread"""
        while True:
            try:
                text, audio_data = self.sentence_queue.get(timeout=1)
                self.process_sentence(text, audio_data)
            except queue.Empty:
                continue

    def process_sentence(self, text, audio_data):
        """Process a sentence with speaker diarization"""
        if self.audio_processor is None or self.speaker_detector is None:
            return
        try:
            # The recorder delivers the utterance as normalized float samples
            # (RealtimeSTT's last_transcription_bytes, despite the name),
            # so scale back to int16 for the embedding pipeline
            audio_int16 = np.int16(audio_data * 32767)

            # Extract speaker embedding
            speaker_embedding = self.audio_processor.extract_embedding(audio_int16)

            # Store sentence and embedding
            self.full_sentences.append((text, speaker_embedding))

            # Fill in any missing speaker assignments
            while len(self.sentence_speakers) < len(self.full_sentences) - 1:
                self.sentence_speakers.append(0)

            # Detect speaker changes
            speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
            self.sentence_speakers.append(speaker_id)

            # Remove from pending
            if text in self.pending_sentences:
                self.pending_sentences.remove(text)
        except Exception as e:
            print(f"Error processing sentence: {e}")
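
    # End-to-end flow for each utterance (summary of the methods below):
    # Gradio audio chunk -> feed_audio() -> RealtimeSTT VAD + transcription ->
    # final text + utterance audio -> sentence_queue -> process_sentence() ->
    # ECAPA embedding -> SpeakerChangeDetector -> labeled transcript.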

    def setup_recorder(self):
        """Setup the audio recorder"""
        try:
            recorder_config = {
                'spinner': False,
                'use_microphone': False,
                'model': FINAL_TRANSCRIPTION_MODEL,
                'language': TRANSCRIPTION_LANGUAGE,
                'silero_sensitivity': SILERO_SENSITIVITY,
                'webrtc_sensitivity': WEBRTC_SENSITIVITY,
                'post_speech_silence_duration': SILENCE_THRESHS[1],
                'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
                'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
                'min_gap_between_recordings': 0,
                'enable_realtime_transcription': True,
                'realtime_processing_pause': 0,
                'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
                'on_realtime_transcription_update': self.live_text_detected,
                'beam_size': FINAL_BEAM_SIZE,
                'beam_size_realtime': REALTIME_BEAM_SIZE,
                'buffer_size': BUFFER_SIZE,
                'sample_rate': SAMPLE_RATE,
            }
            self.recorder = AudioToTextRecorder(**recorder_config)
            return True
        except Exception as e:
            print(f"Error setting up recorder: {e}")
            return False

    def live_text_detected(self, text):
        """Handle live text detection"""
        text = text.strip()
        if not text:
            return

        # Treat two consecutive realtime updates that both end in sentence
        # punctuation as a probable sentence end
        sentence_delimiters = '.?!。'
        prob_sentence_end = (
            len(self.last_realtime_text) > 0
            and text[-1] in sentence_delimiters
            and self.last_realtime_text[-1] in sentence_delimiters
        )
        self.last_realtime_text = text

        if prob_sentence_end:
            if FAST_SENTENCE_END:
                self.recorder.stop()
            else:
                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
        else:
            self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]

    def process_audio_chunk(self, audio_chunk):
        """Process an incoming audio chunk from the Gradio stream"""
        if self.recorder is None:
            if not self.setup_recorder():
                return "Failed to setup recorder"
        try:
            # Gradio streams audio as (sample_rate, ndarray) tuples
            if isinstance(audio_chunk, tuple):
                sample_rate, audio_data = audio_chunk
            else:
                audio_data = audio_chunk
                sample_rate = SAMPLE_RATE

            # Normalize to float32 in [-1, 1] so downmixing and resampling
            # behave uniformly regardless of the incoming dtype
            if audio_data.dtype == np.int16:
                audio_data = audio_data.astype(np.float32) / 32768.0
            elif audio_data.dtype != np.float32:
                audio_data = audio_data.astype(np.float32)

            # Downmix multi-channel audio to mono
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)

            # Resample if the capture rate differs from the recorder's rate
            # (browser microphones often deliver 48 kHz)
            if sample_rate != SAMPLE_RATE:
                audio_data = torchaudio.functional.resample(
                    torch.from_numpy(audio_data),
                    orig_freq=sample_rate, new_freq=SAMPLE_RATE,
                ).numpy()

            # Convert to int16 bytes and feed to the recorder
            audio_int16 = np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
            self.recorder.feed_audio(audio_int16.tobytes())

            # Queue final sentences (with their audio) for diarization
            def process_final_text(text):
                text = text.strip()
                if text:
                    self.pending_sentences.append(text)
                    sentence_audio = self.recorder.last_transcription_bytes
                    self.sentence_queue.put((text, sentence_audio))

            # Blocks until the current utterance ends, then hands the final
            # transcription to the callback
            self.recorder.text(process_final_text)

            return self.get_formatted_transcript()
        except Exception as e:
            print(f"Error processing audio: {e}")
            return f"Error: {e}"

    def get_formatted_transcript(self):
        """Get formatted transcript with speaker labels"""
        try:
            transcript_parts = []

            # Add completed sentences with speaker labels
            for i, (sentence_text, _) in enumerate(self.full_sentences):
                if i < len(self.sentence_speakers):
                    speaker_id = self.sentence_speakers[i]
                    speaker_label = f"Speaker {speaker_id + 1}"
                    transcript_parts.append(f"{speaker_label}: {sentence_text}")

            # Add pending sentences
            for pending in self.pending_sentences:
                transcript_parts.append(f"[Processing]: {pending}")

            # Add current live text
            if self.last_realtime_text:
                transcript_parts.append(f"[Live]: {self.last_realtime_text}")

            return "\n".join(transcript_parts)
        except Exception as e:
            print(f"Error formatting transcript: {e}")
            return "Error formatting transcript"
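
    # Example of the resulting transcript layout (illustrative text only):
    #   Speaker 1: Hello there.
    #   Speaker 2: Hi, how are you?
    #   [Processing]: a sentence awaiting its speaker label
    #   [Live]: partial text still being transcribed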

    def update_settings(self, change_threshold, max_speakers):
        """Update diarization settings"""
        self.change_threshold = change_threshold
        self.max_speakers = max_speakers
        if self.speaker_detector:
            self.speaker_detector.set_change_threshold(change_threshold)
            self.speaker_detector.set_max_speakers(max_speakers)

    def clear_transcript(self):
        """Clear all transcript data"""
        self.full_sentences = []
        self.sentence_speakers = []
        self.pending_sentences = []
        self.last_realtime_text = ""
        if self.speaker_detector:
            # Re-create the detector so speaker statistics start fresh
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers,
            )


# Global instance
asr_diarization = RealtimeASRDiarization()


def process_audio_stream(audio_chunk, change_threshold, max_speakers):
    """Process audio stream and return transcript"""
    # Update settings if changed
    asr_diarization.update_settings(change_threshold, max_speakers)
    # Process audio
    transcript = asr_diarization.process_audio_chunk(audio_chunk)
    return transcript


def clear_transcript():
    """Clear the transcript"""
    asr_diarization.clear_transcript()
    return "Transcript cleared. Ready for new input..."


def create_interface():
    """Create the Gradio interface with streaming audio input"""
    with gr.Blocks(title="Real-time Speaker Diarization") as iface:
        gr.Markdown("# Real-time ASR with Speaker Diarization")
        gr.Markdown("Speak into your microphone to see real-time transcription with speaker labels!")

        with gr.Row():
            with gr.Column(scale=3):
                # Streaming microphone input
                audio_input = gr.Audio(
                    sources=["microphone"],
                    streaming=True,
                    label="Microphone Input"
                )

                # Transcript output
                transcript_output = gr.Textbox(
                    label="Live Transcript with Speaker Labels",
                    lines=15,
                    max_lines=20,
                    value="Ready to start transcription...",
                    interactive=False
                )

            with gr.Column(scale=1):
                gr.Markdown("### Settings")

                # Speaker change threshold
                change_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.95,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    step=0.05,
                    label="Speaker Change Threshold",
                    info="Lower values = more sensitive to speaker changes"
                )

                # Max speakers
                max_speakers = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    value=DEFAULT_MAX_SPEAKERS,
                    step=1,
                    label="Maximum Speakers",
                    info="Maximum number of speakers to detect"
                )

                # Clear button
                clear_btn = gr.Button("Clear Transcript", variant="secondary")

                gr.Markdown("### Speaker Colors")
                color_info = "\n".join([
                    f"- Speaker {i + 1}: {SPEAKER_COLOR_NAMES[i]}"
                    for i in range(min(DEFAULT_MAX_SPEAKERS, len(SPEAKER_COLOR_NAMES)))
                ])
                gr.Markdown(color_info)

        # Set up streaming
        audio_input.stream(
            fn=process_audio_stream,
            inputs=[audio_input, change_threshold, max_speakers],
            outputs=[transcript_output],
            show_progress="hidden"
        )

        # Clear button functionality
        clear_btn.click(
            fn=clear_transcript,
            outputs=[transcript_output]
        )

        gr.Markdown("""
        ### Instructions:
        1. Allow microphone access when prompted
        2. Start speaking - transcription will appear in real-time
        3. Different speakers will be automatically detected and labeled
        4. Adjust the threshold if speaker changes aren't detected properly
        5. Use the clear button to reset the transcript

        ### Notes:
        - The system works best with clear audio and distinct speakers
        - It may take a moment to load the speaker recognition model on first use
        - Lower threshold values make the system more sensitive to speaker changes
        """)

    return iface


if __name__ == "__main__":
    # Create and launch the interface
    iface = create_interface()
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
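
# Suggested setup (assumed package names; pin versions as needed):
#   pip install gradio RealtimeSTT speechbrain torch torchaudio scipy numpy
# Then run this file and open http://localhost:7860 in a browser
# (share=True also prints a temporary public URL).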