import gradio as gr
import numpy as np
import torch
import torchaudio
from scipy.spatial.distance import cosine
import tempfile
import os
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

try:
    from transformers import pipeline
except ImportError:
    print("transformers not found. Install with: pip install transformers")


# Configuration
class Config:
    # Audio settings
    SAMPLE_RATE = 16000

    # Speaker detection
    CHANGE_THRESHOLD = 0.65
    MAX_SPEAKERS = 4
    MIN_SEGMENT_DURATION = 1.0
    EMBEDDING_HISTORY_SIZE = 3
    SPEAKER_MEMORY_SIZE = 20

    # Console colors for speakers (HTML version)
    SPEAKER_COLORS = [
        "#FFD700",  # Gold
        "#FF6B6B",  # Red
        "#4ECDC4",  # Teal
        "#45B7D1",  # Blue
        "#96CEB4",  # Mint
        "#FFEAA7",  # Light Yellow
        "#DDA0DD",  # Plum
        "#98D8C8",  # Mint Green
    ]


class SpeakerEncoder:
    """Simplified speaker encoder using torchaudio transforms"""

    def __init__(self, device="cpu"):
        self.device = device
        self.embedding_dim = 128
        self.model_loaded = False
        self._setup_model()

    def _setup_model(self):
        """Setup a simple MFCC-based feature extractor"""
        try:
            self.mfcc_transform = torchaudio.transforms.MFCC(
                sample_rate=Config.SAMPLE_RATE,
                n_mfcc=13,
                melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23}
            ).to(self.device)
            self.model_loaded = True
            print("Simple MFCC-based encoder initialized")
        except Exception as e:
            print(f"Error setting up encoder: {e}")
            self.model_loaded = False

    def extract_embedding(self, audio):
        """Extract speaker embedding from audio"""
        if not self.model_loaded:
            return np.zeros(self.embedding_dim)
        try:
            # Ensure audio is float32 and normalized
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio).float()

            # Normalize audio
            if audio.abs().max() > 0:
                audio = audio / audio.abs().max()

            # Add batch dimension if needed
            if audio.dim() == 1:
                audio = audio.unsqueeze(0)

            # Extract MFCC features
            with torch.no_grad():
                mfcc = self.mfcc_transform(audio)

            # Simple statistics-based embedding
            embedding = torch.cat([
                mfcc.mean(dim=2).flatten(),
                mfcc.std(dim=2).flatten(),
                mfcc.max(dim=2)[0].flatten(),
                mfcc.min(dim=2)[0].flatten()
            ])

            # Pad or truncate to fixed size
            if embedding.size(0) > self.embedding_dim:
                embedding = embedding[:self.embedding_dim]
            elif embedding.size(0) < self.embedding_dim:
                padding = torch.zeros(self.embedding_dim - embedding.size(0))
                embedding = torch.cat([embedding, padding])

            return embedding.cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)


class SpeakerDetector:
    """Speaker change detection using embeddings"""

    def __init__(self, threshold=Config.CHANGE_THRESHOLD, max_speakers=Config.MAX_SPEAKERS):
        self.threshold = threshold
        self.max_speakers = max_speakers
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(max_speakers)]
        self.speaker_centroids = [None] * max_speakers
        self.active_speakers = {0}

    def reset(self):
        """Reset speaker detection state"""
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.speaker_centroids = [None] * self.max_speakers
        self.active_speakers = {0}

    def detect_speaker(self, embedding):
        """Detect current speaker from embedding"""
        # Initialize first speaker
        if not self.speaker_embeddings[0]:
            self.speaker_embeddings[0].append(embedding)
            self.speaker_centroids[0] = embedding.copy()
            return 0, 1.0

        # Calculate similarity with current speaker
        current_centroid = self.speaker_centroids[self.current_speaker]
        if current_centroid is not None:
            similarity = 1.0 - cosine(embedding, current_centroid)
        else:
            similarity = 0.0

        # Check for speaker change
        if similarity < self.threshold:
            # Find best matching existing speaker
            best_speaker = self.current_speaker
            best_similarity = similarity

            for speaker_id in self.active_speakers:
                if speaker_id == self.current_speaker:
                    continue
                centroid = self.speaker_centroids[speaker_id]
                if centroid is not None:
                    sim = 1.0 - cosine(embedding, centroid)
                    if sim > best_similarity and sim > self.threshold:
                        best_similarity = sim
                        best_speaker = speaker_id

            # Create new speaker if no good match and slots available
            if (best_speaker == self.current_speaker and
                    len(self.active_speakers) < self.max_speakers):
                for new_id in range(self.max_speakers):
                    if new_id not in self.active_speakers:
                        best_speaker = new_id
                        best_similarity = 0.0
                        self.active_speakers.add(new_id)
                        break

            # Update current speaker if changed
            if best_speaker != self.current_speaker:
                self.current_speaker = best_speaker
                similarity = best_similarity

        # Update speaker model
        self._update_speaker_model(self.current_speaker, embedding)
        return self.current_speaker, similarity

    def _update_speaker_model(self, speaker_id, embedding):
        """Update speaker model with new embedding"""
        self.speaker_embeddings[speaker_id].append(embedding)

        # Keep only recent embeddings
        if len(self.speaker_embeddings[speaker_id]) > Config.SPEAKER_MEMORY_SIZE:
            self.speaker_embeddings[speaker_id] = \
                self.speaker_embeddings[speaker_id][-Config.SPEAKER_MEMORY_SIZE:]

        # Update centroid
        if self.speaker_embeddings[speaker_id]:
            self.speaker_centroids[speaker_id] = np.mean(
                self.speaker_embeddings[speaker_id], axis=0
            )
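

# Illustrative sketch (not called anywhere in the app): how SpeakerEncoder and
# SpeakerDetector cooperate. The sine tones below are arbitrary stand-ins for real
# speech chunks, chosen only to exercise the API defined above.
def _demo_speaker_detection():
    """Feed a few synthetic chunks through the encoder and change detector."""
    encoder = SpeakerEncoder()
    detector = SpeakerDetector(threshold=0.65, max_speakers=2)

    # Two 3-second tones acting as hypothetical "speaker A" and "speaker B"
    t = np.linspace(0, 3.0, int(3.0 * Config.SAMPLE_RATE), endpoint=False)
    tone_a = (0.5 * np.sin(2 * np.pi * 220 * t)).astype(np.float32)
    tone_b = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

    for label, chunk in [("A", tone_a), ("A", tone_a), ("B", tone_b)]:
        embedding = encoder.extract_embedding(chunk)
        speaker_id, similarity = detector.detect_speaker(embedding)
        print(f"chunk {label}: assigned Speaker {speaker_id + 1}, similarity {similarity:.2f}")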


class AudioProcessor:
    """Handles audio processing and transcription"""

    def __init__(self):
        self.encoder = SpeakerEncoder()
        self.detector = SpeakerDetector()

        # Initialize Whisper model for transcription
        try:
            self.transcriber = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-base",
                chunk_length_s=30,
                device=0 if torch.cuda.is_available() else -1
            )
            print("Whisper model loaded successfully")
        except Exception as e:
            print(f"Error loading Whisper model: {e}")
            self.transcriber = None

    def process_audio_file(self, audio_file):
        """Process uploaded audio file"""
        if audio_file is None:
            return "Please upload an audio file.", ""

        try:
            # Reset speaker detection for new file
            self.detector.reset()

            # Load audio file
            waveform, sample_rate = torchaudio.load(audio_file)

            # Convert to mono if stereo
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            # Resample to 16kHz if needed
            if sample_rate != Config.SAMPLE_RATE:
                resampler = torchaudio.transforms.Resample(sample_rate, Config.SAMPLE_RATE)
                waveform = resampler(waveform)

            # Convert to numpy
            audio_data = waveform.squeeze().numpy()

            # Transcribe entire audio
            if self.transcriber:
                transcription_result = self.transcriber(audio_file)
                full_transcription = transcription_result['text']
            else:
                full_transcription = "Transcription service unavailable"

            # Process audio in chunks for speaker detection
            chunk_duration = 3.0  # 3 second chunks
            chunk_samples = int(chunk_duration * Config.SAMPLE_RATE)
            results = []

            for i in range(0, len(audio_data), chunk_samples // 2):  # 50% overlap
                chunk = audio_data[i:i + chunk_samples]
                if len(chunk) < Config.SAMPLE_RATE:  # Skip chunks less than 1 second
                    continue

                # Extract speaker embedding
                embedding = self.encoder.extract_embedding(chunk)
                speaker_id, similarity = self.detector.detect_speaker(embedding)

                # Get timestamp
                start_time = i / Config.SAMPLE_RATE
                end_time = (i + len(chunk)) / Config.SAMPLE_RATE

                # Transcribe chunk
                if self.transcriber and len(chunk) > Config.SAMPLE_RATE:
                    # Save chunk temporarily for transcription
                    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                        torchaudio.save(tmp_file.name, torch.tensor(chunk).unsqueeze(0), Config.SAMPLE_RATE)
                    chunk_result = self.transcriber(tmp_file.name)
                    chunk_text = chunk_result['text'].strip()
                    os.unlink(tmp_file.name)  # Clean up temp file
                else:
                    chunk_text = ""

                if chunk_text:  # Only add if there's actual text
                    results.append({
                        'speaker_id': speaker_id,
                        'start_time': start_time,
                        'end_time': end_time,
                        'text': chunk_text,
                        'similarity': similarity
                    })

            # Format results
            formatted_output = self._format_results(results)
            return formatted_output, full_transcription

        except Exception as e:
            return f"Error processing audio: {str(e)}", ""

    def _format_results(self, results):
        """Format results with speaker colors"""
        if not results:
            return "No speech detected in the audio file."

        formatted_lines = []
        formatted_lines.append("<h3>🎤 Speaker Diarization Results</h3>")

        for result in results:
            speaker_id = result['speaker_id']
            start_time = result['start_time']
            end_time = result['end_time']
            text = result['text']
            similarity = result['similarity']

            color = Config.SPEAKER_COLORS[speaker_id % len(Config.SPEAKER_COLORS)]

            # Format timestamp
            start_min, start_sec = divmod(int(start_time), 60)
            end_min, end_sec = divmod(int(end_time), 60)
            timestamp = f"[{start_min:02d}:{start_sec:02d} - {end_min:02d}:{end_sec:02d}]"

            # Create colored HTML output (one inline-styled block per segment)
            formatted_lines.append(
                f'<div style="margin: 8px 0; padding: 8px; border-left: 4px solid {color};">'
                f'<span style="color: {color}; font-weight: bold;">Speaker {speaker_id + 1}</span> '
                f'<span style="color: #888888;">{timestamp}</span><br>'
                f'{text}'
                f'</div>'
            )

        return "".join(formatted_lines)


# Global processor instance
processor = AudioProcessor()


def process_audio(audio_file, sensitivity):
    """Process audio file with speaker detection"""
    if audio_file is None:
        return "Please upload an audio file.", ""

    # Update sensitivity
    processor.detector.threshold = sensitivity

    # Process the audio
    diarized_output, full_transcription = processor.process_audio_file(audio_file)
    return diarized_output, full_transcription


# Create Gradio interface
def create_interface():
    """Create Gradio interface"""
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Speaker Diarization & Transcription",
        css="""
        .gradio-container { max-width: 1200px !important; }
        .speaker-output { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
        """
    ) as demo:
        gr.Markdown(
            """
            # 🎙️ Speaker Diarization & Transcription

            Upload an audio file to automatically detect different speakers and transcribe their speech.
            The system will identify speaker changes and display each speaker's text in a different color.
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    label="Upload Audio File",
                    type="filepath",
                    sources=["upload", "microphone"]
                )

                sensitivity_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.65,
                    step=0.05,
                    label="Speaker Change Sensitivity",
                    info="Higher values = more sensitive to speaker changes"
                )

                process_btn = gr.Button("🎯 Process Audio", variant="primary", size="lg")

                gr.Markdown(
                    """
                    ### Instructions:
                    1. Upload an audio file (WAV, MP3, etc.)
                    2. Adjust sensitivity if needed
                    3. Click "Process Audio"
                    4. View results with speaker colors

                    ### Tips:
                    - Works best with clear speech
                    - Supports multiple file formats
                    - Different speakers shown in different colors
                    - Processing may take a moment for longer files
                    """
                )

            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("🎨 Speaker Diarization"):
                        diarized_output = gr.HTML(
                            label="Speaker Diarization Results",
                            elem_classes=["speaker-output"]
                        )

                    with gr.TabItem("📝 Full Transcription"):
                        full_transcription = gr.Textbox(
                            label="Complete Transcription",
                            lines=15,
                            max_lines=20,
                            show_copy_button=True
                        )

        # Event handlers
        process_btn.click(
            fn=process_audio,
            inputs=[audio_input, sensitivity_slider],
            outputs=[diarized_output, full_transcription],
            show_progress=True
        )

        # Auto-process when audio is uploaded
        audio_input.change(
            fn=process_audio,
            inputs=[audio_input, sensitivity_slider],
            outputs=[diarized_output, full_transcription],
            show_progress=True
        )

        gr.Markdown(
            """
            ---
            ### About

            This application uses:
            - **MFCC features** for speaker embedding extraction
            - **Cosine similarity** for speaker change detection
            - **OpenAI Whisper** for speech-to-text transcription
            - **Gradio** for the web interface

            **Note**: This is a simplified speaker diarization system. For production use,
            consider more advanced speaker embedding models such as speechbrain or pyannote.audio.
            """
        )

    return demo


# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
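
# Programmatic use (illustrative, not executed): the Gradio UI is optional. The same
# pipeline can be driven directly from Python; the file path below is a placeholder.
#
#     processor = AudioProcessor()
#     diarized_html, transcript = processor.process_audio_file("recording.wav")
#     print(transcript)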