import gradio as gr
import numpy as np
import torch
import torchaudio
import threading
import queue
import time
import os
import urllib.request
from scipy.spatial.distance import cosine
from collections import deque
import tempfile
import librosa
# Configuration parameters
FINAL_TRANSCRIPTION_MODEL = "openai/whisper-small"
TRANSCRIPTION_LANGUAGE = "en"
DEFAULT_CHANGE_THRESHOLD = 0.7
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.0
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 6
SAMPLE_RATE = 16000

# Speaker colors for up to 6 speakers
SPEAKER_COLORS = [
    "#FFD700",  # Gold
    "#FF6B6B",  # Red
    "#4ECDC4",  # Teal
    "#45B7D1",  # Blue
    "#96CEB4",  # Green
    "#FFEAA7",  # Yellow
]
SPEAKER_COLOR_NAMES = [
    "Gold", "Red", "Teal", "Blue", "Green", "Yellow"
]
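
# Illustrative note (not used by the app): the detector below scores embeddings with
# cosine similarity (1 - scipy's cosine distance) and only treats values *below*
# DEFAULT_CHANGE_THRESHOLD as a possible speaker change. A quick sanity check with
# two toy 128-dim vectors:
#
#   a = np.ones(128)
#   b = np.ones(128)
#   1.0 - cosine(a, b)   # -> 1.0: identical direction, read as "same speaker"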

class SpeechBrainEncoder:
    """Simplified encoder for speaker embeddings using torch audio features"""

    def __init__(self, device="cpu"):
        self.device = device
        self.embedding_dim = 128
        self.model_loaded = True

    def load_model(self):
        """Model loading simulation"""
        return True

    def embed_utterance(self, audio, sr=16000):
        """Extract simple spectral features as speaker embedding"""
        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32)
            else:
                waveform = audio
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            # Resample if needed
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
            # Extract MFCC features as a simple embedding
            mfcc_transform = torchaudio.transforms.MFCC(
                sample_rate=16000,
                n_mfcc=13,
                melkwargs={'n_mels': 40}
            )
            mfcc = mfcc_transform(waveform)
            # Take mean across time dimension and flatten
            embedding = mfcc.mean(dim=2).flatten()
            # Pad or truncate to fixed size
            if len(embedding) > self.embedding_dim:
                embedding = embedding[:self.embedding_dim]
            elif len(embedding) < self.embedding_dim:
                padding = torch.zeros(self.embedding_dim - len(embedding))
                embedding = torch.cat([embedding, padding])
            return embedding.numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.random.randn(self.embedding_dim)
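
# Usage sketch (hypothetical, not called anywhere in this file): embed two
# one-second 16 kHz buffers and compare them. Random noise stands in for real
# speech, so the resulting similarity value is meaningless beyond the mechanics.
#
#   encoder = SpeechBrainEncoder()
#   emb_a = encoder.embed_utterance(np.random.randn(16000).astype(np.float32))
#   emb_b = encoder.embed_utterance(np.random.randn(16000).astype(np.float32))
#   similarity = 1.0 - cosine(emb_a, emb_b)   # higher = more alike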

class SpeakerChangeDetector:
    """Speaker change detector for real-time diarization"""

    def __init__(self, embedding_dim=128, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = set([0])

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers"""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        if new_max < self.max_speakers:
            # Drop speakers that no longer fit
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0
        if new_max > self.max_speakers:
            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the threshold for detecting speaker changes"""
        self.change_threshold = max(0.1, min(threshold, 0.99))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and check if there's a speaker change"""
        current_time = timestamp or time.time()
        # First embedding: assign it to speaker 0 and return immediately
        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0
        # Compare against the running mean of the current speaker (or the last embedding)
        current_mean = self.mean_embeddings[self.current_speaker]
        if current_mean is not None:
            similarity = 1.0 - cosine(embedding, current_mean)
        else:
            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
        self.last_similarity = similarity
        time_since_last_change = current_time - self.last_change_time
        is_speaker_change = False
        if time_since_last_change >= MIN_SEGMENT_DURATION:
            if similarity < self.change_threshold:
                # Look for a better-matching known speaker
                best_speaker = self.current_speaker
                best_similarity = similarity
                for speaker_id in range(self.max_speakers):
                    if speaker_id == self.current_speaker:
                        continue
                    speaker_mean = self.mean_embeddings[speaker_id]
                    if speaker_mean is not None:
                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
                        if speaker_similarity > best_similarity:
                            best_similarity = speaker_similarity
                            best_speaker = speaker_id
                if best_speaker != self.current_speaker:
                    is_speaker_change = True
                    self.current_speaker = best_speaker
                elif len(self.active_speakers) < self.max_speakers:
                    # No good match among known speakers: open a slot for a new one
                    for new_id in range(self.max_speakers):
                        if new_id not in self.active_speakers:
                            is_speaker_change = True
                            self.current_speaker = new_id
                            self.active_speakers.add(new_id)
                            break
        if is_speaker_change:
            self.last_change_time = current_time
        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)
        self.speaker_embeddings[self.current_speaker].append(embedding)
        self.active_speakers.add(self.current_speaker)
        # Keep only the most recent embeddings per speaker and refresh the running mean
        if len(self.speaker_embeddings[self.current_speaker]) > 30:
            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
        if self.speaker_embeddings[self.current_speaker]:
            self.mean_embeddings[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )
        return self.current_speaker, similarity

    def get_color_for_speaker(self, speaker_id):
        """Return color for speaker ID"""
        if 0 <= speaker_id < len(SPEAKER_COLORS):
            return SPEAKER_COLORS[speaker_id]
        return "#FFFFFF"

class RealTimeASRDiarization:
    """Main class for real-time ASR with speaker diarization"""

    def __init__(self):
        self.encoder = SpeechBrainEncoder()
        self.encoder.load_model()
        self.speaker_detector = SpeakerChangeDetector()
        self.transcription_queue = queue.Queue()
        self.conversation_history = []
        self.is_processing = False
        # Load the Whisper model; fall back to mock transcription if it cannot be loaded
        try:
            import whisper
            self.whisper_model = whisper.load_model("base")
        except Exception as e:
            print(f"Whisper not available ({e}), using mock transcription")
            self.whisper_model = None
    def transcribe_audio(self, audio_data, sr=16000):
        """Transcribe audio using Whisper"""
        try:
            if self.whisper_model is None:
                return "Mock transcription: Hello, this is a test."
            # Ensure audio is the right format
            if isinstance(audio_data, tuple):
                sr, audio_data = audio_data
            if len(audio_data.shape) > 1:
                audio_data = audio_data.mean(axis=1)
            # Normalize audio
            audio_data = audio_data.astype(np.float32)
            if np.abs(audio_data).max() > 1.0:
                audio_data = audio_data / np.abs(audio_data).max()
            # Resample to 16kHz if needed
            if sr != 16000:
                audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
            # Transcribe
            result = self.whisper_model.transcribe(audio_data, language="en")
            return result["text"].strip()
        except Exception as e:
            print(f"Transcription error: {e}")
            return ""

    def extract_speaker_embedding(self, audio_data, sr=16000):
        """Extract speaker embedding from audio"""
        return self.encoder.embed_utterance(audio_data, sr)

    def process_audio_segment(self, audio_data, sr=16000):
        """Process an audio segment for transcription and speaker identification"""
        if len(audio_data) < sr * 0.5:  # Skip very short segments
            return None, None, None
        # Transcribe the audio
        transcription = self.transcribe_audio(audio_data, sr)
        if not transcription:
            return None, None, None
        # Extract speaker embedding
        embedding = self.extract_speaker_embedding(audio_data, sr)
        # Detect speaker
        speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
        return transcription, speaker_id, similarity
    def update_conversation(self, transcription, speaker_id):
        """Update conversation history with new transcription"""
        speaker_name = f"Speaker {speaker_id + 1}"
        color = self.speaker_detector.get_color_for_speaker(speaker_id)
        entry = {
            "speaker": speaker_name,
            "text": transcription,
            "color": color,
            "timestamp": time.time()
        }
        self.conversation_history.append(entry)
        return entry

    def format_conversation_html(self):
        """Format conversation history as HTML"""
        if not self.conversation_history:
            return "<p><i>No conversation yet. Start speaking to see real-time transcription with speaker diarization.</i></p>"
        html_parts = []
        for entry in self.conversation_history:
            html_parts.append(
                f'<p><span style="color: {entry["color"]}; font-weight: bold;">'
                f'{entry["speaker"]}:</span> {entry["text"]}</p>'
            )
        return "".join(html_parts)

    def get_status_info(self):
        """Get current status information"""
        status = {
            "active_speakers": len(self.speaker_detector.active_speakers),
            "max_speakers": self.speaker_detector.max_speakers,
            "current_speaker": self.speaker_detector.current_speaker + 1,
            "total_segments": len(self.conversation_history),
            "threshold": self.speaker_detector.change_threshold
        }
        return status

    def clear_conversation(self):
        """Clear conversation history and reset speaker detector"""
        self.conversation_history = []
        self.speaker_detector = SpeakerChangeDetector(
            change_threshold=self.speaker_detector.change_threshold,
            max_speakers=self.speaker_detector.max_speakers
        )

    def set_parameters(self, threshold, max_speakers):
        """Update parameters"""
        self.speaker_detector.set_change_threshold(threshold)
        self.speaker_detector.set_max_speakers(max_speakers)
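
# Usage sketch (hypothetical, offline): push one chunk through the full pipeline
# without Gradio. "chunk" would be a float32 numpy array of 16 kHz speech.
#
#   system = RealTimeASRDiarization()
#   text, speaker_id, sim = system.process_audio_segment(chunk, sr=16000)
#   if text:
#       entry = system.update_conversation(text, speaker_id)
#       print(entry["speaker"], entry["text"])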

# Global instance
asr_system = RealTimeASRDiarization()


def process_audio_realtime(audio_data, threshold, max_speakers):
    """Process audio in real-time"""
    global asr_system
    if audio_data is None:
        return asr_system.format_conversation_html(), get_status_display()
    # Update parameters
    asr_system.set_parameters(threshold, max_speakers)
    try:
        sr, audio_array = audio_data
        # Scale integer PCM to [-1, 1] *before* casting away the original dtype;
        # otherwise the int16/int32 branches can never match
        if audio_array.dtype == np.int16:
            audio_array = audio_array.astype(np.float32) / 32768.0
        elif audio_array.dtype == np.int32:
            audio_array = audio_array.astype(np.float32) / 2147483648.0
        elif audio_array.dtype != np.float32:
            audio_array = audio_array.astype(np.float32)
        # Process the audio segment
        transcription, speaker_id, similarity = asr_system.process_audio_segment(audio_array, sr)
        if transcription and speaker_id is not None:
            # Update conversation
            asr_system.update_conversation(transcription, speaker_id)
    except Exception as e:
        print(f"Error processing audio: {e}")
    return asr_system.format_conversation_html(), get_status_display()
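
# Note on the dtype handling above (illustrative): Gradio's "numpy" microphone audio
# usually arrives as int16, so the scaling to [-1, 1] has to inspect the dtype before
# converting to float32. For example:
#
#   raw = np.array([0, 16384, -32768], dtype=np.int16)
#   raw.astype(np.float32) / 32768.0   # -> [0.0, 0.5, -1.0]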

def get_status_display():
    """Get formatted status display"""
    status = asr_system.get_status_info()
    status_html = f"""
    <div style="font-family: monospace; font-size: 12px;">
        <strong>Status:</strong><br>
        Current Speaker: {status['current_speaker']}<br>
        Active Speakers: {status['active_speakers']} / {status['max_speakers']}<br>
        Total Segments: {status['total_segments']}<br>
        Threshold: {status['threshold']:.2f}<br>
    </div>
    """
    return status_html


def clear_conversation():
    """Clear the conversation"""
    global asr_system
    asr_system.clear_conversation()
    return asr_system.format_conversation_html(), get_status_display()

def create_interface():
    """Create Gradio interface"""
    with gr.Blocks(
        title="Real-time ASR with Speaker Diarization",
        theme=gr.themes.Soft(),
        css="""
        .conversation-box {
            height: 400px;
            overflow-y: auto;
            border: 1px solid #ddd;
            padding: 10px;
            background-color: #f9f9f9;
        }
        .status-box {
            border: 1px solid #ccc;
            padding: 10px;
            background-color: #f0f0f0;
        }
        """
    ) as demo:
        gr.Markdown(
            """
            # 🎤 Real-time ASR with Live Speaker Diarization

            This application provides real-time speech recognition with speaker diarization.
            It can distinguish between different speakers and display their conversations in different colors.

            **Instructions:**
            1. Adjust the speaker change threshold and maximum speakers
            2. Click the microphone button to start recording
            3. Speak naturally - the system will detect speaker changes and transcribe speech
            4. Each speaker will be assigned a different color
            """
        )
        with gr.Row():
            with gr.Column(scale=3):
                # Main conversation display
                conversation_display = gr.HTML(
                    value="<p><i>Click the microphone to start recording...</i></p>",
                    elem_classes=["conversation-box"]
                )
                # Audio input
                audio_input = gr.Audio(
                    source="microphone",
                    type="numpy",
                    streaming=True,
                    label="🎤 Microphone Input"
                )
            with gr.Column(scale=1):
                # Controls
                gr.Markdown("### Controls")
                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    step=0.05,
                    label="Speaker Change Threshold",
                    info="Higher values = more sensitive to speaker changes"
                )
                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    value=DEFAULT_MAX_SPEAKERS,
                    step=1,
                    label="Maximum Speakers",
                    info="Maximum number of different speakers to detect"
                )
                clear_btn = gr.Button("🗑️ Clear Conversation", variant="secondary")
                # Status display
                gr.Markdown("### Status")
                status_display = gr.HTML(
                    value=get_status_display(),
                    elem_classes=["status-box"]
                )
                # Speaker color legend
                gr.Markdown("### Speaker Colors")
                legend_html = ""
                for i in range(ABSOLUTE_MAX_SPEAKERS):
                    color = SPEAKER_COLORS[i]
                    name = SPEAKER_COLOR_NAMES[i]
                    legend_html += f'<p><span style="color: {color}; font-weight: bold;">● Speaker {i+1} ({name})</span></p>'
                gr.HTML(legend_html)

        # Event handlers
        audio_input.change(
            fn=process_audio_realtime,
            inputs=[audio_input, threshold_slider, max_speakers_slider],
            outputs=[conversation_display, status_display],
            show_progress=False
        )
        clear_btn.click(
            fn=clear_conversation,
            outputs=[conversation_display, status_display]
        )
        # Update status periodically
        demo.load(
            fn=lambda: (asr_system.format_conversation_html(), get_status_display()),
            outputs=[conversation_display, status_display],
            every=2
        )
    return demo

if __name__ == "__main__":
    # Create and launch the interface (the queue is needed for the periodic status refresh)
    demo = create_interface()
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )