import gradio as gr
import numpy as np
import torch
import torchaudio
from scipy.spatial.distance import cosine
import tempfile
import os
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

try:
    from transformers import pipeline
except ImportError:
    pipeline = None
    print("transformers not found. Install with: pip install transformers")
# Configuration
class Config:
    # Audio settings
    SAMPLE_RATE = 16000

    # Speaker detection
    CHANGE_THRESHOLD = 0.65
    MAX_SPEAKERS = 4
    MIN_SEGMENT_DURATION = 1.0
    EMBEDDING_HISTORY_SIZE = 3
    SPEAKER_MEMORY_SIZE = 20

    # Speaker colors for the HTML output
    SPEAKER_COLORS = [
        "#FFD700",  # Gold
        "#FF6B6B",  # Red
        "#4ECDC4",  # Teal
        "#45B7D1",  # Blue
        "#96CEB4",  # Mint
        "#FFEAA7",  # Light Yellow
        "#DDA0DD",  # Plum
        "#98D8C8",  # Mint Green
    ]
class SpeakerEncoder:
    """Simplified speaker encoder using torchaudio transforms"""

    def __init__(self, device="cpu"):
        self.device = device
        self.embedding_dim = 128
        self.model_loaded = False
        self._setup_model()

    def _setup_model(self):
        """Setup a simple MFCC-based feature extractor"""
        try:
            self.mfcc_transform = torchaudio.transforms.MFCC(
                sample_rate=Config.SAMPLE_RATE,
                n_mfcc=13,
                melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23}
            ).to(self.device)
            self.model_loaded = True
            print("Simple MFCC-based encoder initialized")
        except Exception as e:
            print(f"Error setting up encoder: {e}")
            self.model_loaded = False
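    # At 16 kHz these MFCC settings correspond to ~25 ms analysis windows
    # (n_fft=400) with a 10 ms hop (hop_length=160), a common speech setup.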
    def extract_embedding(self, audio):
        """Extract speaker embedding from audio"""
        if not self.model_loaded:
            return np.zeros(self.embedding_dim)
        try:
            # Ensure audio is float32 and normalized
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio).float()

            # Normalize audio
            if audio.abs().max() > 0:
                audio = audio / audio.abs().max()

            # Add batch dimension if needed
            if audio.dim() == 1:
                audio = audio.unsqueeze(0)

            # Extract MFCC features
            with torch.no_grad():
                mfcc = self.mfcc_transform(audio)

            # Simple statistics-based embedding
            embedding = torch.cat([
                mfcc.mean(dim=2).flatten(),
                mfcc.std(dim=2).flatten(),
                mfcc.max(dim=2)[0].flatten(),
                mfcc.min(dim=2)[0].flatten()
            ])

            # Pad or truncate to fixed size
            if embedding.size(0) > self.embedding_dim:
                embedding = embedding[:self.embedding_dim]
            elif embedding.size(0) < self.embedding_dim:
                padding = torch.zeros(self.embedding_dim - embedding.size(0))
                embedding = torch.cat([embedding, padding])

            return embedding.cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)
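# Note: with n_mfcc=13, the four statistics above (mean, std, max, min per
# coefficient) give a 52-dimensional vector that is zero-padded up to
# embedding_dim=128; the extra dimensions carry no information and only keep
# the embedding size fixed for the cosine comparisons in SpeakerDetector.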
class SpeakerDetector:
    """Speaker change detection using embeddings"""

    def __init__(self, threshold=Config.CHANGE_THRESHOLD, max_speakers=Config.MAX_SPEAKERS):
        self.threshold = threshold
        self.max_speakers = max_speakers
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(max_speakers)]
        self.speaker_centroids = [None] * max_speakers
        self.active_speakers = {0}

    def reset(self):
        """Reset speaker detection state"""
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.speaker_centroids = [None] * self.max_speakers
        self.active_speakers = {0}
    def detect_speaker(self, embedding):
        """Detect current speaker from embedding"""
        # Initialize first speaker
        if not self.speaker_embeddings[0]:
            self.speaker_embeddings[0].append(embedding)
            self.speaker_centroids[0] = embedding.copy()
            return 0, 1.0

        # Calculate similarity with current speaker
        current_centroid = self.speaker_centroids[self.current_speaker]
        if current_centroid is not None:
            similarity = 1.0 - cosine(embedding, current_centroid)
        else:
            similarity = 0.0

        # Check for speaker change
        if similarity < self.threshold:
            # Find best matching existing speaker
            best_speaker = self.current_speaker
            best_similarity = similarity
            for speaker_id in self.active_speakers:
                if speaker_id == self.current_speaker:
                    continue
                centroid = self.speaker_centroids[speaker_id]
                if centroid is not None:
                    sim = 1.0 - cosine(embedding, centroid)
                    if sim > best_similarity and sim > self.threshold:
                        best_similarity = sim
                        best_speaker = speaker_id

            # Create new speaker if no good match and slots available
            if (best_speaker == self.current_speaker and
                    len(self.active_speakers) < self.max_speakers):
                for new_id in range(self.max_speakers):
                    if new_id not in self.active_speakers:
                        best_speaker = new_id
                        best_similarity = 0.0
                        self.active_speakers.add(new_id)
                        break

            # Update current speaker if changed
            if best_speaker != self.current_speaker:
                self.current_speaker = best_speaker
                similarity = best_similarity

        # Update speaker model
        self._update_speaker_model(self.current_speaker, embedding)
        return self.current_speaker, similarity
    def _update_speaker_model(self, speaker_id, embedding):
        """Update speaker model with new embedding"""
        self.speaker_embeddings[speaker_id].append(embedding)

        # Keep only recent embeddings
        if len(self.speaker_embeddings[speaker_id]) > Config.SPEAKER_MEMORY_SIZE:
            self.speaker_embeddings[speaker_id] = \
                self.speaker_embeddings[speaker_id][-Config.SPEAKER_MEMORY_SIZE:]

        # Update centroid
        if self.speaker_embeddings[speaker_id]:
            self.speaker_centroids[speaker_id] = np.mean(
                self.speaker_embeddings[speaker_id], axis=0
            )
class AudioProcessor:
    """Handles audio processing and transcription"""

    def __init__(self):
        self.encoder = SpeakerEncoder()
        self.detector = SpeakerDetector()

        # Initialize Whisper model for transcription
        try:
            self.transcriber = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-base",
                chunk_length_s=30,
                device=0 if torch.cuda.is_available() else -1
            )
            print("Whisper model loaded successfully")
        except Exception as e:
            print(f"Error loading Whisper model: {e}")
            self.transcriber = None
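    # whisper-base with chunk_length_s=30 lets the transformers pipeline
    # transcribe long files in 30-second windows; device=0 selects the first
    # GPU when available, -1 falls back to CPU.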
    def process_audio_file(self, audio_file):
        """Process uploaded audio file"""
        if audio_file is None:
            return "Please upload an audio file.", ""
        try:
            # Reset speaker detection for new file
            self.detector.reset()

            # Load audio file
            waveform, sample_rate = torchaudio.load(audio_file)

            # Convert to mono if stereo
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            # Resample to 16kHz if needed
            if sample_rate != Config.SAMPLE_RATE:
                resampler = torchaudio.transforms.Resample(sample_rate, Config.SAMPLE_RATE)
                waveform = resampler(waveform)

            # Convert to numpy
            audio_data = waveform.squeeze().numpy()

            # Transcribe entire audio
            if self.transcriber:
                transcription_result = self.transcriber(audio_file)
                full_transcription = transcription_result['text']
            else:
                full_transcription = "Transcription service unavailable"

            # Process audio in chunks for speaker detection
            chunk_duration = 3.0  # 3 second chunks
            chunk_samples = int(chunk_duration * Config.SAMPLE_RATE)
            results = []

            for i in range(0, len(audio_data), chunk_samples // 2):  # 50% overlap
                chunk = audio_data[i:i + chunk_samples]
                if len(chunk) < Config.SAMPLE_RATE:  # Skip chunks less than 1 second
                    continue

                # Extract speaker embedding
                embedding = self.encoder.extract_embedding(chunk)
                speaker_id, similarity = self.detector.detect_speaker(embedding)

                # Get timestamp
                start_time = i / Config.SAMPLE_RATE
                end_time = (i + len(chunk)) / Config.SAMPLE_RATE

                # Transcribe chunk
                if self.transcriber and len(chunk) > Config.SAMPLE_RATE:
                    # Save chunk temporarily for transcription
                    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                        torchaudio.save(tmp_file.name, torch.tensor(chunk).unsqueeze(0), Config.SAMPLE_RATE)
                        chunk_result = self.transcriber(tmp_file.name)
                        chunk_text = chunk_result['text'].strip()
                    os.unlink(tmp_file.name)  # Clean up temp file
                else:
                    chunk_text = ""

                if chunk_text:  # Only add if there's actual text
                    results.append({
                        'speaker_id': speaker_id,
                        'start_time': start_time,
                        'end_time': end_time,
                        'text': chunk_text,
                        'similarity': similarity
                    })

            # Format results
            formatted_output = self._format_results(results)
            return formatted_output, full_transcription
        except Exception as e:
            return f"Error processing audio: {str(e)}", ""
    def _format_results(self, results):
        """Format results with speaker colors"""
        if not results:
            return "No speech detected in the audio file."

        formatted_lines = []
        formatted_lines.append("🎤 **Speaker Diarization Results**\n")

        for result in results:
            speaker_id = result['speaker_id']
            start_time = result['start_time']
            end_time = result['end_time']
            text = result['text']
            similarity = result['similarity']
            color = Config.SPEAKER_COLORS[speaker_id % len(Config.SPEAKER_COLORS)]

            # Format timestamp
            start_min, start_sec = divmod(int(start_time), 60)
            end_min, end_sec = divmod(int(end_time), 60)
            timestamp = f"[{start_min:02d}:{start_sec:02d} - {end_min:02d}:{end_sec:02d}]"

            # Create colored HTML output
            formatted_lines.append(
                f'<div style="margin-bottom: 10px; padding: 8px; border-left: 4px solid {color}; background-color: {color}20;">'
                f'<strong style="color: {color};">Speaker {speaker_id + 1}</strong> '
                f'<span style="color: #666; font-size: 0.9em;">{timestamp}</span><br>'
                f'<span style="color: #333;">{text}</span>'
                f'</div>'
            )

        return "".join(formatted_lines)
# Global processor instance
processor = AudioProcessor()


def process_audio(audio_file, sensitivity):
    """Process audio file with speaker detection"""
    if audio_file is None:
        return "Please upload an audio file.", ""

    # Update sensitivity
    processor.detector.threshold = sensitivity

    # Process the audio
    diarized_output, full_transcription = processor.process_audio_file(audio_file)
    return diarized_output, full_transcription
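# Caveat: the processor (and its SpeakerDetector state) is a single module-level
# instance, so concurrent Gradio requests share and overwrite the same state.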
# Create Gradio interface
def create_interface():
    """Create Gradio interface"""
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Speaker Diarization & Transcription",
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        .speaker-output {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
        }
        """
    ) as demo:
        gr.Markdown(
            """
            # 🎙️ Speaker Diarization & Transcription
            Upload an audio file to automatically detect different speakers and transcribe their speech.
            The system will identify speaker changes and display each speaker's text in different colors.
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    label="Upload Audio File",
                    type="filepath",
                    sources=["upload", "microphone"]
                )

                sensitivity_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.65,
                    step=0.05,
                    label="Speaker Change Sensitivity",
                    info="Higher values = more sensitive to speaker changes"
                )

                process_btn = gr.Button("🎯 Process Audio", variant="primary", size="lg")

                gr.Markdown(
                    """
                    ### Instructions:
                    1. Upload an audio file (WAV, MP3, etc.)
                    2. Adjust sensitivity if needed
                    3. Click "Process Audio"
                    4. View results with speaker colors

                    ### Tips:
                    - Works best with clear speech
                    - Supports multiple file formats
                    - Different speakers shown in different colors
                    - Processing may take a moment for longer files
                    """
                )
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("🎨 Speaker Diarization"):
                        diarized_output = gr.HTML(
                            label="Speaker Diarization Results",
                            elem_classes=["speaker-output"]
                        )
                    with gr.TabItem("📝 Full Transcription"):
                        full_transcription = gr.Textbox(
                            label="Complete Transcription",
                            lines=15,
                            max_lines=20,
                            show_copy_button=True
                        )

        # Event handlers
        process_btn.click(
            fn=process_audio,
            inputs=[audio_input, sensitivity_slider],
            outputs=[diarized_output, full_transcription],
            show_progress=True
        )

        # Auto-process when audio is uploaded
        audio_input.change(
            fn=process_audio,
            inputs=[audio_input, sensitivity_slider],
            outputs=[diarized_output, full_transcription],
            show_progress=True
        )
        gr.Markdown(
            """
            ---
            ### About
            This application uses:
            - **MFCC features** for speaker embedding extraction
            - **Cosine similarity** for speaker change detection
            - **OpenAI Whisper** for speech-to-text transcription
            - **Gradio** for the web interface

            **Note**: This is a simplified speaker diarization system. For production use,
            consider more advanced speaker embedding models like speechbrain or pyannote.audio.
            """
        )

    return demo
# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
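# Local sanity check (hypothetical, not part of the Space's normal flow): the
# processor can also be exercised without the Gradio UI, for example from a
# Python shell:
#
#   html_output, transcript = processor.process_audio_file("sample.wav")  # placeholder path
#   print(transcript)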