import gradio as gr
import numpy as np
import torch
import torchaudio
import queue
import time
from scipy.spatial.distance import cosine
import librosa

# Configuration parameters
FINAL_TRANSCRIPTION_MODEL = "openai/whisper-small"  # reserved; the demo below loads openai-whisper "base"
TRANSCRIPTION_LANGUAGE = "en"
DEFAULT_CHANGE_THRESHOLD = 0.7    # cosine-similarity threshold for a speaker change
EMBEDDING_HISTORY_SIZE = 5        # rolling window of recent embeddings
MIN_SEGMENT_DURATION = 1.0        # seconds before another speaker change is allowed
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 6
SAMPLE_RATE = 16000

# Speaker colors for up to 6 speakers
SPEAKER_COLORS = [
    "#FFD700",  # Gold
    "#FF6B6B",  # Red
    "#4ECDC4",  # Teal
    "#45B7D1",  # Blue
    "#96CEB4",  # Green
    "#FFEAA7",  # Yellow
]

SPEAKER_COLOR_NAMES = [
    "Gold", "Red", "Teal", "Blue", "Green", "Yellow"
]


class SpeechBrainEncoder:
    """Lightweight stand-in encoder: despite the name, it uses plain torchaudio
    MFCC statistics as speaker embeddings rather than a SpeechBrain model."""

    def __init__(self, device="cpu"):
        self.device = device
        self.embedding_dim = 128
        self.model_loaded = True

    def load_model(self):
        """No-op; kept for API compatibility with a real model loader."""
        return True

    def embed_utterance(self, audio, sr=16000):
        """Extract simple spectral features as a fixed-size speaker embedding."""
        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32)
            else:
                waveform = audio

            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)

            # Resample if needed
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)

            # Extract MFCC features as a simple embedding
            mfcc_transform = torchaudio.transforms.MFCC(
                sample_rate=16000,
                n_mfcc=13,
                melkwargs={'n_mels': 40}
            )
            mfcc = mfcc_transform(waveform)

            # Take the mean across the time dimension and flatten
            embedding = mfcc.mean(dim=2).flatten()

            # Pad or truncate to the fixed embedding size
            if len(embedding) > self.embedding_dim:
                embedding = embedding[:self.embedding_dim]
            elif len(embedding) < self.embedding_dim:
                padding = torch.zeros(self.embedding_dim - len(embedding))
                embedding = torch.cat([embedding, padding])

            return embedding.numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            # Fallback: a random vector so downstream code keeps working
            return np.random.randn(self.embedding_dim)
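
# --- Illustrative sketch (not called by the app) ---------------------------
# A minimal sanity check of the MFCC embedding above, assuming one second of
# synthetic 220 Hz tone. The helper name and the tone parameters are
# illustrative assumptions, not part of the original application.
def _embedding_sanity_sketch():
    enc = SpeechBrainEncoder()
    t = np.arange(SAMPLE_RATE) / SAMPLE_RATE
    tone = np.sin(2 * np.pi * 220.0 * t).astype(np.float32)  # 1 s test tone
    emb = enc.embed_utterance(tone, sr=SAMPLE_RATE)
    assert emb.shape == (128,)  # fixed size: 13 MFCC means, zero-padded to 128
    return emb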
class SpeakerChangeDetector:
    """Speaker change detector for real-time diarization."""

    def __init__(self, embedding_dim=128, change_threshold=DEFAULT_CHANGE_THRESHOLD,
                 max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = set([0])

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers."""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)

        if new_max < self.max_speakers:
            # Drop speakers that no longer fit
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0

        if new_max > self.max_speakers:
            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]

        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the threshold for detecting speaker changes."""
        self.change_threshold = max(0.1, min(threshold, 0.99))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and check whether the speaker has changed."""
        current_time = timestamp if timestamp is not None else time.time()

        # First embedding: assign it to the initial speaker
        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0

        # Similarity against the current speaker's running mean
        current_mean = self.mean_embeddings[self.current_speaker]
        if current_mean is not None:
            similarity = 1.0 - cosine(embedding, current_mean)
        else:
            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])

        self.last_similarity = similarity
        time_since_last_change = current_time - self.last_change_time
        is_speaker_change = False

        if time_since_last_change >= MIN_SEGMENT_DURATION:
            if similarity < self.change_threshold:
                # Look for a better-matching known speaker
                best_speaker = self.current_speaker
                best_similarity = similarity
                for speaker_id in range(self.max_speakers):
                    if speaker_id == self.current_speaker:
                        continue
                    speaker_mean = self.mean_embeddings[speaker_id]
                    if speaker_mean is not None:
                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
                        if speaker_similarity > best_similarity:
                            best_similarity = speaker_similarity
                            best_speaker = speaker_id

                if best_speaker != self.current_speaker:
                    is_speaker_change = True
                    self.current_speaker = best_speaker
                elif len(self.active_speakers) < self.max_speakers:
                    # No good match among known speakers: open a new slot
                    for new_id in range(self.max_speakers):
                        if new_id not in self.active_speakers:
                            is_speaker_change = True
                            self.current_speaker = new_id
                            self.active_speakers.add(new_id)
                            break

        if is_speaker_change:
            self.last_change_time = current_time

        # Maintain a bounded history of recent embeddings
        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)

        # Update the current speaker's profile (capped at 30 embeddings)
        self.speaker_embeddings[self.current_speaker].append(embedding)
        self.active_speakers.add(self.current_speaker)
        if len(self.speaker_embeddings[self.current_speaker]) > 30:
            self.speaker_embeddings[self.current_speaker] = (
                self.speaker_embeddings[self.current_speaker][-30:]
            )
        if self.speaker_embeddings[self.current_speaker]:
            self.mean_embeddings[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )

        return self.current_speaker, similarity

    def get_color_for_speaker(self, speaker_id):
        """Return the display color for a speaker ID."""
        if 0 <= speaker_id < len(SPEAKER_COLORS):
            return SPEAKER_COLORS[speaker_id]
        return "#FFFFFF"
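
# --- Illustrative sketch (not called by the app) ---------------------------
# Feed the detector two orthogonal synthetic embeddings and watch it open a
# second speaker slot: the first vector becomes Speaker 1's mean, the second
# has cosine similarity 0.0 to it (below the 0.7 threshold), so a new speaker
# is assigned. The vectors are assumptions; real input comes from
# SpeechBrainEncoder.embed_utterance().
def _detector_change_sketch():
    detector = SpeakerChangeDetector(embedding_dim=128, change_threshold=0.7)
    emb_a = np.zeros(128)
    emb_a[0] = 1.0  # "speaker A" direction
    emb_b = np.zeros(128)
    emb_b[1] = 1.0  # orthogonal "speaker B" direction
    detector.add_embedding(emb_a)         # first embedding -> speaker id 0
    time.sleep(MIN_SEGMENT_DURATION)      # respect the change cooldown
    speaker, sim = detector.add_embedding(emb_b)
    print(speaker, sim)                   # expected: 1, 0.0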
class RealTimeASRDiarization:
    """Main class for real-time ASR with speaker diarization."""

    def __init__(self):
        self.encoder = SpeechBrainEncoder()
        self.encoder.load_model()
        self.speaker_detector = SpeakerChangeDetector()
        self.transcription_queue = queue.Queue()
        self.conversation_history = []
        self.is_processing = False

        # Load the Whisper model (openai-whisper package)
        try:
            import whisper
            self.whisper_model = whisper.load_model("base")
        except ImportError:
            print("Whisper not available, using mock transcription")
            self.whisper_model = None

    def transcribe_audio(self, audio_data, sr=16000):
        """Transcribe audio using Whisper."""
        try:
            if self.whisper_model is None:
                return "Mock transcription: Hello, this is a test."

            # Ensure the audio is in the right format
            if isinstance(audio_data, tuple):
                sr, audio_data = audio_data

            if len(audio_data.shape) > 1:
                audio_data = audio_data.mean(axis=1)  # downmix to mono

            # Normalize audio
            audio_data = audio_data.astype(np.float32)
            if np.abs(audio_data).max() > 1.0:
                audio_data = audio_data / np.abs(audio_data).max()

            # Resample to 16 kHz if needed
            if sr != 16000:
                audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)

            # Transcribe
            result = self.whisper_model.transcribe(audio_data, language=TRANSCRIPTION_LANGUAGE)
            return result["text"].strip()
        except Exception as e:
            print(f"Transcription error: {e}")
            return ""

    def extract_speaker_embedding(self, audio_data, sr=16000):
        """Extract a speaker embedding from audio."""
        return self.encoder.embed_utterance(audio_data, sr)

    def process_audio_segment(self, audio_data, sr=16000):
        """Process an audio segment: transcription plus speaker identification."""
        if len(audio_data) < sr * 0.5:  # Skip very short segments
            return None, None, None

        # Transcribe the audio
        transcription = self.transcribe_audio(audio_data, sr)
        if not transcription:
            return None, None, None

        # Extract the speaker embedding and detect the speaker
        embedding = self.extract_speaker_embedding(audio_data, sr)
        speaker_id, similarity = self.speaker_detector.add_embedding(embedding)

        return transcription, speaker_id, similarity

    def update_conversation(self, transcription, speaker_id):
        """Update the conversation history with a new transcription."""
        speaker_name = f"Speaker {speaker_id + 1}"
        color = self.speaker_detector.get_color_for_speaker(speaker_id)

        entry = {
            "speaker": speaker_name,
            "text": transcription,
            "color": color,
            "timestamp": time.time()
        }
        self.conversation_history.append(entry)
        return entry

    def format_conversation_html(self):
        """Format the conversation history as HTML."""
        if not self.conversation_history:
            return (
                "<div style='text-align: center; color: #666; padding: 20px;'>"
                "No conversation yet. Start speaking to see real-time transcription "
                "with speaker diarization.</div>"
            )

        html_parts = []
        for entry in self.conversation_history:
            html_parts.append(
                f'<div style="background-color: {entry["color"]}; color: #000; '
                f'padding: 8px; margin: 4px 0; border-radius: 6px;">'
                f'<b>{entry["speaker"]}:</b> {entry["text"]}</div>'
            )
        return "".join(html_parts)

    def get_status_info(self):
        """Get current status information."""
        return {
            "active_speakers": len(self.speaker_detector.active_speakers),
            "max_speakers": self.speaker_detector.max_speakers,
            "current_speaker": self.speaker_detector.current_speaker + 1,
            "total_segments": len(self.conversation_history),
            "threshold": self.speaker_detector.change_threshold
        }

    def clear_conversation(self):
        """Clear the conversation history and reset the speaker detector."""
        self.conversation_history = []
        self.speaker_detector = SpeakerChangeDetector(
            change_threshold=self.speaker_detector.change_threshold,
            max_speakers=self.speaker_detector.max_speakers
        )

    def set_parameters(self, threshold, max_speakers):
        """Update diarization parameters."""
        self.speaker_detector.set_change_threshold(threshold)
        self.speaker_detector.set_max_speakers(max_speakers)
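
# --- Illustrative sketch (not called by the app) ---------------------------
# Run one synthetic two-second segment through the full pipeline. With
# openai-whisper installed this actually transcribes (a pure tone will likely
# yield empty or meaningless text); without it, the mock transcription path is
# exercised. The signal parameters are arbitrary assumptions.
def _pipeline_sketch():
    system = RealTimeASRDiarization()
    t = np.arange(2 * SAMPLE_RATE) / SAMPLE_RATE
    segment = (0.1 * np.sin(2 * np.pi * 180.0 * t)).astype(np.float32)
    text, speaker_id, similarity = system.process_audio_segment(segment, sr=SAMPLE_RATE)
    if text:
        entry = system.update_conversation(text, speaker_id)
        print(entry["speaker"], entry["text"])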
# Global instance
asr_system = RealTimeASRDiarization()


def process_audio_realtime(audio_data, threshold, max_speakers):
    """Process one chunk of audio in real time."""
    global asr_system

    if audio_data is None:
        return asr_system.format_conversation_html(), get_status_display()

    # Update parameters
    asr_system.set_parameters(threshold, max_speakers)

    try:
        sr, audio_array = audio_data

        # Convert integer PCM to float32 in [-1, 1] *before* any dtype change;
        # converting first would make the int16/int32 checks unreachable
        if audio_array.dtype == np.int16:
            audio_array = audio_array.astype(np.float32) / 32768.0
        elif audio_array.dtype == np.int32:
            audio_array = audio_array.astype(np.float32) / 2147483648.0
        elif audio_array.dtype != np.float32:
            audio_array = audio_array.astype(np.float32)

        # Process the audio segment
        transcription, speaker_id, similarity = asr_system.process_audio_segment(audio_array, sr)

        if transcription and speaker_id is not None:
            asr_system.update_conversation(transcription, speaker_id)
    except Exception as e:
        print(f"Error processing audio: {e}")

    return asr_system.format_conversation_html(), get_status_display()


def get_status_display():
    """Get a formatted status display."""
    status = asr_system.get_status_info()
    status_html = f"""
    <b>Status:</b><br>
    Current Speaker: {status['current_speaker']}<br>
    Active Speakers: {status['active_speakers']} / {status['max_speakers']}<br>
    Total Segments: {status['total_segments']}<br>
    Threshold: {status['threshold']:.2f}
    """
    return status_html


def clear_conversation():
    """Clear the conversation."""
    global asr_system
    asr_system.clear_conversation()
    return asr_system.format_conversation_html(), get_status_display()

def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks(
        title="Real-time ASR with Speaker Diarization",
        theme=gr.themes.Soft(),
        css="""
        .conversation-box { height: 400px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; background-color: #f9f9f9; }
        .status-box { border: 1px solid #ccc; padding: 10px; background-color: #f0f0f0; }
        """
    ) as demo:
        gr.Markdown(
            """
            # 🎤 Real-time ASR with Live Speaker Diarization

            This application provides real-time speech recognition with speaker diarization.
            It distinguishes between different speakers and displays their turns in different colors.

            **Instructions:**
            1. Adjust the speaker change threshold and the maximum number of speakers
            2. Click the microphone button to start recording
            3. Speak naturally; the system detects speaker changes and transcribes speech
            4. Each speaker is assigned a different color
            """
        )

        with gr.Row():
            with gr.Column(scale=3):
                # Main conversation display
                conversation_display = gr.HTML(
                    value="<div style='text-align: center; color: #666;'>"
                          "Click the microphone to start recording...</div>",
                    elem_classes=["conversation-box"]
                )

                # Audio input (Gradio 3.x-style component arguments)
                audio_input = gr.Audio(
                    source="microphone",
                    type="numpy",
                    streaming=True,
                    label="🎤 Microphone Input"
                )

            with gr.Column(scale=1):
                # Controls
                gr.Markdown("### Controls")
                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    step=0.05,
                    label="Speaker Change Threshold",
                    info="Higher values = less sensitive to speaker changes"
                )
                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    value=DEFAULT_MAX_SPEAKERS,
                    step=1,
                    label="Maximum Speakers",
                    info="Maximum number of different speakers to detect"
                )
                clear_btn = gr.Button("🗑️ Clear Conversation", variant="secondary")

                # Status display
                gr.Markdown("### Status")
                status_display = gr.HTML(
                    value=get_status_display(),
                    elem_classes=["status-box"]
                )

                # Speaker color legend
                gr.Markdown("### Speaker Colors")
                legend_html = ""
                for i in range(ABSOLUTE_MAX_SPEAKERS):
                    color = SPEAKER_COLORS[i]
                    name = SPEAKER_COLOR_NAMES[i]
                    legend_html += (
                        f'<div style="margin: 2px 0;">'
                        f'<span style="color: {color}; font-weight: bold;">●</span> '
                        f'Speaker {i + 1} ({name})</div>'
                    )
                gr.HTML(legend_html)

        # Event handlers
        audio_input.change(
            fn=process_audio_realtime,
            inputs=[audio_input, threshold_slider, max_speakers_slider],
            outputs=[conversation_display, status_display],
            show_progress=False
        )

        clear_btn.click(
            fn=clear_conversation,
            outputs=[conversation_display, status_display]
        )

        # Refresh the display periodically (the `every` parameter needs the queue)
        demo.load(
            fn=lambda: (asr_system.format_conversation_html(), get_status_display()),
            outputs=[conversation_display, status_display],
            every=2
        )

    return demo


if __name__ == "__main__":
    # Create and launch the interface
    demo = create_interface()
    demo.queue()  # required for streaming input and the `every=2` refresh
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
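
# Hedged launch note: share=True requests a public *.gradio.live tunnel. For a
# local-only run, a sketch using standard Gradio launch options would be:
#     demo.launch(server_name="127.0.0.1", server_port=7860, share=False)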