# Saiyaswanth007's picture
# updated import
# d65b6e8
# raw
# history blame
# 31.4 kB
import gradio as gr
import numpy as np
import queue
import torch
import time
import threading
import os
import urllib.request
import torchaudio
from scipy.spatial.distance import cosine
import json
import asyncio
from typing import Iterator
import logging
# Logging: root logger at INFO, plus a module-level logger used by every class below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Transcription / voice-activity configuration --------------------------
# NOTE(review): apart from FINAL_TRANSCRIPTION_MODEL, these values appear to be
# unused in this module — confirm before relying on them.
SILENCE_THRESHS = [0, 0.4]
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# --- Speaker change detection ----------------------------------------------
DEFAULT_CHANGE_THRESHOLD = 0.7  # cosine similarity below which a change is considered
EMBEDDING_HISTORY_SIZE = 5      # recent embeddings remembered by the detector
MIN_SEGMENT_DURATION = 1.0      # seconds that must pass between two speaker changes
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10      # hard upper bound on speaker slots

# --- Audio stream ------------------------------------------------------------
FAST_SENTENCE_END = True
SAMPLE_RATE = 16000             # Hz
BUFFER_SIZE = 1024
CHANNELS = 1
CHUNK_DURATION_MS = 100         # 100ms chunks for FastRTC

# --- Speaker colours: (hex code, display name), indexed by speaker id ------
_SPEAKER_PALETTE = (
    ("#FFFF00", "Yellow"),
    ("#FF0000", "Red"),
    ("#00FF00", "Green"),
    ("#00FFFF", "Cyan"),
    ("#FF00FF", "Magenta"),
    ("#0000FF", "Blue"),
    ("#FF8000", "Orange"),
    ("#00FF80", "Spring Green"),
    ("#8000FF", "Purple"),
    ("#FFFFFF", "White"),
)
SPEAKER_COLORS = [hex_code for hex_code, _name in _SPEAKER_PALETTE]
SPEAKER_COLOR_NAMES = [name for _hex_code, name in _SPEAKER_PALETTE]
class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings.

    Wraps the ``speechbrain/spkrec-ecapa-voxceleb`` pretrained model and maps a
    mono waveform to a speaker embedding (``embedding_dim`` values).

    Args:
        device: Torch device string forwarded to SpeechBrain through
            ``run_opts`` (e.g. ``"cpu"`` or ``"cuda"``).
    """

    # Hugging Face repo the pretrained encoder is loaded from.
    MODEL_SOURCE = "speechbrain/spkrec-ecapa-voxceleb"
    # Direct URL of the encoder checkpoint inside that repo.
    CHECKPOINT_URL = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"

    def __init__(self, device="cpu"):
        self.device = device
        self.model = None
        self.embedding_dim = 192
        self.model_loaded = False
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    def _download_model(self):
        """Download the ECAPA-TDNN checkpoint into ``cache_dir`` if it is missing.

        The file is written to a ``.part`` sibling first and only renamed into
        place once the download completes. An interrupted download therefore
        never leaves a truncated checkpoint at the final path.

        Returns:
            Path of ``embedding_model.ckpt`` inside ``cache_dir``.
        """
        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
        if not os.path.exists(model_path):
            logger.info(f"Downloading ECAPA-TDNN model to {model_path}...")
            tmp_path = model_path + ".part"
            try:
                urllib.request.urlretrieve(self.CHECKPOINT_URL, tmp_path)
                os.replace(tmp_path, model_path)
            finally:
                # Only present if the download or rename failed part-way.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
        return model_path

    def load_model(self):
        """Load the pretrained ECAPA-TDNN model.

        Returns:
            True on success; False if anything failed (the error is logged).
        """
        try:
            try:
                # SpeechBrain >= 1.0 moved the inference interfaces here.
                from speechbrain.inference.speaker import EncoderClassifier
            except ImportError:
                # Older SpeechBrain releases expose it under `pretrained`.
                from speechbrain.pretrained import EncoderClassifier
            self._download_model()
            self.model = EncoderClassifier.from_hparams(
                source=self.MODEL_SOURCE,
                savedir=self.cache_dir,
                run_opts={"device": self.device},
            )
            self.model_loaded = True
            return True
        except Exception as e:
            logger.error(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract a speaker embedding from one utterance.

        Args:
            audio: 1-D waveform as a NumPy array or torch tensor. It is converted
                to float32 either way. Presumably samples lie in [-1, 1]; this is
                a TODO to confirm with callers.
            sr: Sample rate of ``audio``. The audio is resampled to 16 kHz when
                this differs.

        Returns:
            The squeezed model output as a NumPy array. If extraction fails, a
            zero vector of length ``embedding_dim`` is returned and the error is
            logged.

        Raises:
            ValueError: If ``load_model()`` has not succeeded yet.
        """
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")
        try:
            # Add a batch dimension: (samples,) -> (1, samples).
            waveform = torch.as_tensor(audio, dtype=torch.float32).unsqueeze(0)
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)
            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)
class AudioProcessor:
    """Adapter that turns raw audio into speaker embeddings through an encoder."""

    def __init__(self, encoder):
        self.encoder = encoder

    def extract_embedding(self, audio_float):
        """Return ``self.encoder``'s embedding for ``audio_float``.

        Audio whose absolute peak exceeds 1.0 is peak-normalised before
        encoding. Any failure is logged and produces a zero vector of length
        ``encoder.embedding_dim``.
        """
        try:
            peak = np.abs(audio_float).max()
            scaled = audio_float / peak if peak > 1.0 else audio_float
            return self.encoder.embed_utterance(scaled)
        except Exception as e:
            logger.error(f"Embedding extraction error: {e}")
            return np.zeros(self.encoder.embedding_dim)
class SpeakerChangeDetector:
    """Speaker change detector that supports a configurable number of speakers.

    Each embedding is compared, by cosine similarity, with the running mean
    embedding of the current speaker. A change is considered only when both
    hold:

    * the similarity falls below ``change_threshold``;
    * at least ``MIN_SEGMENT_DURATION`` seconds have passed since the last change.

    The embedding is then reassigned to the known speaker whose mean is most
    similar. If no other speaker beats the current one and a slot is still
    free, it goes to a new speaker id instead.

    Args:
        embedding_dim: Length of the embeddings (stored; not otherwise used here).
        change_threshold: Similarity threshold, clamped to [0.1, 0.99].
        max_speakers: Number of speaker slots, clamped to [1, ABSOLUTE_MAX_SPEAKERS].
    """

    # Number of most recent embeddings per speaker that feed the running mean.
    MAX_EMBEDDINGS_PER_SPEAKER = 30

    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = self._clamp_threshold(change_threshold)
        self.max_speakers = self._clamp_speakers(max_speakers)
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = {0}

    @staticmethod
    def _clamp_threshold(threshold):
        """Clamp a similarity threshold into [0.1, 0.99]."""
        return max(0.1, min(threshold, 0.99))

    @staticmethod
    def _clamp_speakers(max_speakers):
        """Coerce a speaker count to int and clamp it into [1, ABSOLUTE_MAX_SPEAKERS].

        The int coercion covers a UI slider that delivers a float.
        """
        return max(1, min(int(max_speakers), ABSOLUTE_MAX_SPEAKERS))

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers.

        When shrinking, speakers whose ids no longer fit are forgotten. If the
        current speaker is one of them, it falls back to speaker 0.
        """
        new_max = self._clamp_speakers(max_speakers)
        if new_max < self.max_speakers:
            self.active_speakers = {s for s in self.active_speakers if s < new_max}
            if self.current_speaker >= new_max:
                self.current_speaker = 0
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
        elif new_max > self.max_speakers:
            extra = new_max - self.max_speakers
            self.mean_embeddings.extend([None] * extra)
            self.speaker_embeddings.extend([] for _ in range(extra))
        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the threshold for detecting speaker changes (clamped to [0.1, 0.99])."""
        self.change_threshold = self._clamp_threshold(threshold)

    def _choose_speaker(self, embedding, current_similarity):
        """Pick the speaker for an embedding that no longer matches the current one.

        Returns, in order of preference:
        1. the other known speaker whose mean is more similar than the current one;
        2. otherwise the lowest free speaker id, if a slot is free;
        3. otherwise the current speaker.
        """
        best_speaker, best_similarity = self.current_speaker, current_similarity
        for speaker_id, speaker_mean in enumerate(self.mean_embeddings):
            if speaker_id == self.current_speaker or speaker_mean is None:
                continue
            candidate = 1.0 - cosine(embedding, speaker_mean)
            if candidate > best_similarity:
                best_speaker, best_similarity = speaker_id, candidate
        if best_speaker != self.current_speaker:
            return best_speaker
        if len(self.active_speakers) < self.max_speakers:
            for new_id in range(self.max_speakers):
                if new_id not in self.active_speakers:
                    return new_id
        return self.current_speaker

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and decide which speaker it belongs to.

        Args:
            embedding: 1-D speaker embedding.
            timestamp: Time of the embedding in seconds on the ``time.time()``
                scale. Defaults to now; an explicit ``0.0`` is honoured.

        Returns:
            ``(speaker_id, similarity)``. ``similarity`` is the cosine
            similarity to the current speaker's reference *before* any
            reassignment, or ``1.0`` for the first embedding.

            All-zero or non-finite embeddings carry no speaker information (for
            example, the zero vector that embedding extraction returns on
            failure). They are ignored, and ``(current_speaker, 0.0)`` is
            returned without touching any state.
        """
        current_time = time.time() if timestamp is None else timestamp
        embedding = np.asarray(embedding)
        if embedding.size == 0 or not np.all(np.isfinite(embedding)) or not np.any(embedding):
            return self.current_speaker, 0.0

        if not self.previous_embeddings:
            # The first valid embedding seeds the current speaker's profile.
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0

        reference = self.mean_embeddings[self.current_speaker]
        if reference is None:
            reference = self.previous_embeddings[-1]
        similarity = 1.0 - cosine(embedding, reference)
        self.last_similarity = similarity

        segment_long_enough = current_time - self.last_change_time >= MIN_SEGMENT_DURATION
        if segment_long_enough and similarity < self.change_threshold:
            new_speaker = self._choose_speaker(embedding, similarity)
            if new_speaker != self.current_speaker:
                self.current_speaker = new_speaker
                self.last_change_time = current_time

        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)

        history = self.speaker_embeddings[self.current_speaker]
        history.append(embedding)
        if len(history) > self.MAX_EMBEDDINGS_PER_SPEAKER:
            del history[:-self.MAX_EMBEDDINGS_PER_SPEAKER]
        self.active_speakers.add(self.current_speaker)
        self.mean_embeddings[self.current_speaker] = np.mean(history, axis=0)
        return self.current_speaker, similarity

    def get_color_for_speaker(self, speaker_id):
        """Return the hex colour for ``speaker_id`` ("#FFFFFF" when out of range)."""
        if 0 <= speaker_id < len(SPEAKER_COLORS):
            return SPEAKER_COLORS[speaker_id]
        return "#FFFFFF"

    def get_status_info(self):
        """Return a dict summarising the detector's current state."""
        return {
            "current_speaker": self.current_speaker,
            "speaker_counts": [len(history) for history in self.speaker_embeddings],
            "active_speakers": len(self.active_speakers),
            "max_speakers": self.max_speakers,
            "last_similarity": self.last_similarity,
            "threshold": self.change_threshold,
        }
class WhisperTranscriber:
    """Whisper speech-to-text using Hugging Face ``transformers``.

    Args:
        model_name: One of:

            * a full Hub repo id (``"org/name"``);
            * a Distil-Whisper name such as ``"distil-large-v3"``;
            * an OpenAI Whisper size such as ``"large-v3"`` (``"whisper-large-v3"``
              is accepted too).
    """

    # Whisper consumes 16 kHz audio in fixed 30 s windows.
    TARGET_SAMPLE_RATE = 16000
    WINDOW_SECONDS = 30

    def __init__(self, model_name="distil-large-v3"):
        self.model = None
        self.processor = None
        self.model_name = model_name
        self.model_loaded = False

    @staticmethod
    def _resolve_model_id(model_name):
        """Map a short model name to its Hugging Face Hub repository id."""
        if "/" in model_name:
            return model_name
        if model_name.startswith("distil-"):
            # "distil-large-v3" -> "distil-whisper/distil-large-v3"
            # (the old code produced "distil-whisper/distil-distil-large-v3").
            return f"distil-whisper/{model_name}"
        if model_name.startswith("whisper-"):
            return f"openai/{model_name}"
        return f"openai/whisper-{model_name}"

    def load_model(self):
        """Load the Whisper processor and model.

        The model uses fp16 and is moved to the GPU when CUDA is available.

        Returns:
            True on success; False if loading failed (the error is logged).
        """
        try:
            from transformers import WhisperProcessor, WhisperForConditionalGeneration
            model_id = self._resolve_model_id(self.model_name)
            use_cuda = torch.cuda.is_available()
            self.processor = WhisperProcessor.from_pretrained(model_id)
            self.model = WhisperForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            if use_cuda:
                self.model = self.model.cuda()
            self.model_loaded = True
            return True
        except Exception as e:
            logger.error(f"Error loading Whisper model: {e}")
            return False

    def transcribe(self, audio_array, sample_rate=16000):
        """Transcribe a mono waveform.

        Args:
            audio_array: 1-D float samples. Presumably they lie in [-1, 1]; this
                is a TODO to confirm with callers.
            sample_rate: Sample rate of ``audio_array``. The audio is resampled
                to 16 kHz when this differs.

        Returns:
            The stripped transcription. Returns ``""`` when the model is not
            loaded, the input has fewer than 1600 samples, or transcription
            fails (the error is logged).
        """
        if not self.model_loaded:
            return ""
        try:
            if len(audio_array) < 1600:  # Less than 0.1 seconds at 16 kHz
                return ""
            if sample_rate != self.TARGET_SAMPLE_RATE:
                import torchaudio.functional as F
                audio_tensor = torch.tensor(audio_array, dtype=torch.float32)
                audio_array = F.resample(audio_tensor, sample_rate, self.TARGET_SAMPLE_RATE).numpy()

            generate_kwargs = {"max_length": 448, "num_beams": 1, "do_sample": False, "use_cache": True}
            if len(audio_array) > self.TARGET_SAMPLE_RATE * self.WINDOW_SECONDS:
                # Long-form (> 30 s): keep all frames and pass the frame-level
                # attention mask so generate() can process the audio in segments.
                inputs = self.processor(
                    audio_array,
                    sampling_rate=self.TARGET_SAMPLE_RATE,
                    return_tensors="pt",
                    truncation=False,
                    padding="longest",
                    return_attention_mask=True,
                )
                generate_kwargs["attention_mask"] = inputs["attention_mask"].to(self.model.device)
            else:
                # Short-form: default padding fills the fixed 30 s window the
                # Whisper encoder expects.
                inputs = self.processor(
                    audio_array,
                    sampling_rate=self.TARGET_SAMPLE_RATE,
                    return_tensors="pt",
                )

            # Features come back as float32; match the model's device and dtype
            # (fp16 on CUDA), otherwise the encoder raises a dtype mismatch.
            input_features = inputs["input_features"].to(self.model.device, dtype=self.model.dtype)
            with torch.no_grad():
                predicted_ids = self.model.generate(input_features, **generate_kwargs)
            transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            return transcription.strip()
        except Exception as e:
            logger.error(f"Transcription error: {e}")
            return ""
class FastRTCSpeakerDiarization:
    """Streaming transcription plus speaker-diarization pipeline.

    How it works:

    * UI callbacks push audio chunks into a rolling buffer via
      ``process_audio_chunk``.
    * At most every ``PROCESS_INTERVAL_S`` seconds, the most recent ``WINDOW_S``
      seconds are queued.
    * A background worker thread (``process_audio_queue``) transcribes each
      window, embeds it and assigns a speaker id.

    Results accumulate in the parallel lists ``full_sentences`` and
    ``sentence_speakers``, where the same index means the same utterance.
    """

    # Minimum wall-clock seconds between windows handed to the worker.
    PROCESS_INTERVAL_S = 1.5
    # Minimum buffered audio (seconds) before a window is queued.
    MIN_BUFFER_S = 0.8
    # Length (seconds) of each queued window.
    # NOTE(review): consecutive windows can overlap by up to
    # WINDOW_S - PROCESS_INTERVAL_S seconds, which may repeat words.
    WINDOW_S = 2.0
    # How many of the most recent sentences the conversation view renders.
    MAX_DISPLAYED_SENTENCES = 10

    def __init__(self):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.transcriber = None
        self.audio_queue = queue.Queue(maxsize=100)
        self.processing_thread = None
        # Set to stop the worker that belongs to the current recording run.
        self._stop_event = None
        self.full_sentences = []
        self.sentence_speakers = []
        self.is_running = False
        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
        self.max_speakers = DEFAULT_MAX_SPEAKERS
        self.audio_buffer = np.zeros(0, dtype=np.float32)
        self.buffer_duration = 3.0  # seconds
        self.last_transcription_time = time.time()
        self.chunk_size = int(SAMPLE_RATE * CHUNK_DURATION_MS / 1000)

    def initialize_models(self):
        """Initialize the speaker encoder and transcription models.

        ``audio_processor`` and ``speaker_detector`` are created only when both
        models load.

        Returns:
            True if both models loaded, otherwise False (errors are logged).
        """
        try:
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {device_str}")
            self.encoder = SpeechBrainEncoder(device=device_str)
            encoder_success = self.encoder.load_model()
            self.transcriber = WhisperTranscriber(FINAL_TRANSCRIPTION_MODEL)
            transcriber_success = self.transcriber.load_model()
            if encoder_success and transcriber_success:
                self.audio_processor = AudioProcessor(self.encoder)
                self.speaker_detector = SpeakerChangeDetector(
                    embedding_dim=self.encoder.embedding_dim,
                    change_threshold=self.change_threshold,
                    max_speakers=self.max_speakers,
                )
                logger.info("Models loaded successfully!")
                return True
            logger.error("Failed to load models")
            return False
        except Exception as e:
            logger.error(f"Model initialization error: {e}")
            return False

    @staticmethod
    def _to_mono_float32(audio_chunk):
        """Return ``audio_chunk`` as 1-D float32 samples scaled into [-1, 1].

        Signed-integer PCM (e.g. int16) is divided by its full-scale value, so
        relative loudness is preserved across chunks. Per-chunk peak
        normalisation would blow quiet chunks up to full volume. Float input
        whose peak exceeds 1.0 is still peak-normalised.
        """
        chunk = np.asarray(audio_chunk)
        if np.issubdtype(chunk.dtype, np.signedinteger):
            chunk = chunk.astype(np.float32) / np.float32(-np.iinfo(chunk.dtype).min)
        else:
            chunk = chunk.astype(np.float32)
        if chunk.ndim > 1:
            # Down-mix to mono. Assumes a (samples, channels) layout; TODO
            # confirm this for every client.
            chunk = chunk.mean(axis=1)
        chunk = chunk.reshape(-1)
        peak = float(np.abs(chunk).max()) if chunk.size else 0.0
        if peak > 1.0:
            chunk = chunk / np.float32(peak)
        return chunk

    @staticmethod
    def _resample(audio, sample_rate):
        """Return ``audio`` as float32 samples at ``SAMPLE_RATE`` Hz (no-op if already there)."""
        audio = np.asarray(audio, dtype=np.float32)
        if sample_rate == SAMPLE_RATE:
            return audio
        waveform = torch.as_tensor(audio)
        return torchaudio.functional.resample(
            waveform, orig_freq=int(sample_rate), new_freq=SAMPLE_RATE
        ).numpy()

    def process_audio_chunk(self, audio_chunk: np.ndarray, sample_rate: int):
        """Append one streamed chunk to the rolling buffer and queue a window when due.

        The chunk is ignored when recording is stopped or ``audio_chunk`` is None.
        """
        if not self.is_running or audio_chunk is None:
            return
        try:
            chunk = self._to_mono_float32(audio_chunk)
            buffered = np.concatenate((np.asarray(self.audio_buffer, dtype=np.float32), chunk))
            max_buffer_length = int(self.buffer_duration * sample_rate)
            if len(buffered) > max_buffer_length:
                buffered = buffered[-max_buffer_length:]
            self.audio_buffer = buffered

            now = time.time()
            if (now - self.last_transcription_time > self.PROCESS_INTERVAL_S
                    and len(buffered) > sample_rate * self.MIN_BUFFER_S):
                window = buffered[-int(sample_rate * self.WINDOW_S):].copy()
                try:
                    self.audio_queue.put_nowait((window, sample_rate))
                except queue.Full:
                    # Worker is behind: drop this window rather than block the stream.
                    pass
                self.last_transcription_time = now
        except Exception as e:
            logger.error(f"Audio chunk processing error: {e}")

    def _process_window(self, audio_data, sample_rate):
        """Transcribe one queued window, identify its speaker and record the result."""
        if len(audio_data) < 1600:  # Skip very short audio
            return
        # Resample once so the transcriber and the speaker encoder both
        # receive 16 kHz audio. The encoder is called without a rate and
        # treats its input as 16 kHz.
        audio_16k = self._resample(audio_data, sample_rate)
        transcription = (self.transcriber.transcribe(audio_16k, SAMPLE_RATE) or "").strip()
        if not transcription:
            return
        speaker_embedding = self.audio_processor.extract_embedding(audio_16k)
        speaker_id, _similarity = self.speaker_detector.add_embedding(speaker_embedding)
        # Speaker first, so a reader never sees a sentence without its speaker.
        self.sentence_speakers.append(speaker_id)
        self.full_sentences.append(transcription)
        logger.info(f"Processed: Speaker {speaker_id + 1}: {transcription[:50]}...")

    def process_audio_queue(self, stop_event=None):
        """Worker loop: process queued windows until recording stops.

        Args:
            stop_event: Optional ``threading.Event`` for this run. Once it is
                set, the loop exits even if a newer run has set ``is_running``
                back to True.
        """
        while self.is_running and (stop_event is None or not stop_event.is_set()):
            try:
                audio_data, sample_rate = self.audio_queue.get(timeout=1)
            except queue.Empty:
                continue
            try:
                self._process_window(audio_data, sample_rate)
            except Exception as e:
                logger.error(f"Error processing audio queue: {e}")

    def _drain_queue(self):
        """Discard any audio windows still waiting in the queue."""
        while True:
            try:
                self.audio_queue.get_nowait()
            except queue.Empty:
                return

    def start_recording(self):
        """Start a recording run with a fresh buffer, queue and worker thread.

        Returns:
            A user-facing status message.
        """
        # speaker_detector/audio_processor exist only if both models loaded.
        # encoder/transcriber are assigned even when loading fails.
        if self.speaker_detector is None or self.audio_processor is None or self.transcriber is None:
            return "Please initialize models first!"
        if self.is_running:
            return "Recording already in progress."
        try:
            self.audio_buffer = np.zeros(0, dtype=np.float32)
            self.last_transcription_time = time.time()
            self._drain_queue()
            stop_event = threading.Event()
            self._stop_event = stop_event
            self.is_running = True
            self.processing_thread = threading.Thread(
                target=self.process_audio_queue, args=(stop_event,), daemon=True
            )
            self.processing_thread.start()
            logger.info("Recording started successfully!")
            return "Recording started successfully!"
        except Exception as e:
            self.is_running = False
            logger.error(f"Error starting recording: {e}")
            return f"Error starting recording: {e}"

    def stop_recording(self):
        """Stop the recording run.

        The current worker finishes its in-flight window, then exits.
        """
        self.is_running = False
        stop_event = getattr(self, "_stop_event", None)
        if stop_event is not None:
            stop_event.set()
        logger.info("Recording stopped!")
        return "Recording stopped!"

    def clear_conversation(self):
        """Clear transcripts, buffered audio and queued windows, and reset speaker tracking."""
        self.full_sentences = []
        self.sentence_speakers = []
        self.audio_buffer = np.zeros(0, dtype=np.float32)
        self._drain_queue()
        if self.speaker_detector is not None:
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers,
            )
        return "Conversation cleared!"

    def update_settings(self, threshold, max_speakers):
        """Update speaker detection settings.

        ``max_speakers`` is coerced to int in case the UI slider delivers a float.
        """
        max_speakers = int(max_speakers)
        self.change_threshold = threshold
        self.max_speakers = max_speakers
        if self.speaker_detector is not None:
            self.speaker_detector.set_change_threshold(threshold)
            self.speaker_detector.set_max_speakers(max_speakers)
        return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"

    def get_formatted_conversation(self):
        """Render the last ``MAX_DISPLAYED_SENTENCES`` sentences as colour-coded HTML.

        Transcribed text is HTML-escaped. A sentence whose speaker has not been
        recorded yet is shown as "Unknown" in white.
        """
        import html

        try:
            # Snapshot: the worker thread may append while we render.
            sentences = list(self.full_sentences)
            if not sentences:
                return "Waiting for speech input... 🎤"
            speakers = list(self.sentence_speakers)
            start = max(0, len(sentences) - self.MAX_DISPLAYED_SENTENCES)
            parts = []
            for index in range(start, len(sentences)):
                if index < len(speakers):
                    speaker_id = speakers[index]
                    color = self.speaker_detector.get_color_for_speaker(speaker_id)
                    speaker_name = f"Speaker {speaker_id + 1}"
                else:
                    color = "#FFFFFF"
                    speaker_name = "Unknown"
                parts.append(
                    f'<p><span style="color:{color}; font-weight: bold;">{speaker_name}:</span> '
                    f'{html.escape(sentences[index])}</p>'
                )
            return "".join(parts)
        except Exception as e:
            return f"Error formatting conversation: {e}"

    def get_status_info(self):
        """Return a multi-line status report for the UI status box."""
        if not self.speaker_detector:
            return "Speaker detector not initialized"
        try:
            status = self.speaker_detector.get_status_info()
            queue_size = self.audio_queue.qsize()
            status_lines = [
                f"**Current Speaker:** {status['current_speaker'] + 1}",
                f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}",
                f"**Last Similarity:** {status['last_similarity']:.3f}",
                f"**Change Threshold:** {status['threshold']:.2f}",
                f"**Total Sentences:** {len(self.full_sentences)}",
                f"**Buffer Length:** {len(self.audio_buffer)} samples",
                f"**Queue Size:** {queue_size}",
                "",
                "**Speaker Segment Counts:**",
            ]
            for i in range(status['max_speakers']):
                color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
                status_lines.append(f"Speaker {i+1} ({color_name}): {status['speaker_counts'][i]}")
            return "\n".join(status_lines)
        except Exception as e:
            return f"Error getting status: {e}"
# Global instance: the single shared pipeline that the module-level wrapper
# functions delegate to.
diarization_system = FastRTCSpeakerDiarization()
def initialize_system():
    """Load all models of the shared pipeline and return a user-facing status line."""
    if not diarization_system.initialize_models():
        return "❌ Failed to initialize system. Please check the logs."
    return "✅ System initialized successfully! Models loaded."
def start_recording():
    """Begin capturing and transcribing on the shared pipeline; returns its status message."""
    system = diarization_system
    return system.start_recording()
def stop_recording():
    """Halt capturing and transcribing on the shared pipeline; returns its status message."""
    system = diarization_system
    return system.stop_recording()
def clear_conversation():
    """Wipe the shared pipeline's conversation state; returns its confirmation message."""
    system = diarization_system
    return system.clear_conversation()
def update_settings(threshold, max_speakers):
    """Apply new speaker-change settings to the shared pipeline; returns its confirmation text."""
    system = diarization_system
    return system.update_settings(threshold, max_speakers)
def get_conversation():
    """Return the conversation HTML rendered by the shared pipeline."""
    rendered = diarization_system.get_formatted_conversation()
    return rendered
def get_status():
    """Return the shared pipeline's status report text."""
    report = diarization_system.get_status_info()
    return report
def process_audio_stream(audio_stream):
    """Feed one streamed ``(sample_rate, samples)`` pair to the shared pipeline.

    The chunk is skipped when nothing was streamed or recording is stopped.
    The refreshed ``(conversation_html, status_text)`` pair is always returned
    for the UI.
    """
    should_process = audio_stream is not None and diarization_system.is_running
    if should_process:
        rate, samples = audio_stream
        diarization_system.process_audio_chunk(samples, rate)
    conversation = get_conversation()
    status = get_status()
    return conversation, status
# Create Gradio interface with FastRTC
def create_interface():
    """Build the Gradio Blocks UI for real-time speaker diarization.

    Layout:

    * Left column: streaming microphone input, live conversation HTML, control
      buttons and a status box.
    * Right column: detection settings, the speaker colour legend and a short
      component overview.

    Event handlers delegate to the module-level wrapper functions.

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` app.
    """
    with gr.Blocks(title="FastRTC Real-time Speaker Diarization", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🎤 FastRTC Real-time Speech Recognition with Speaker Diarization")
        gr.Markdown("This app uses Hugging Face FastRTC for real-time audio streaming with automatic speaker identification and color-coding.")
        with gr.Row():
            with gr.Column(scale=2):
                # FastRTC Audio input for real-time streaming
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="numpy",
                    streaming=True,
                    label="🎙️ FastRTC Microphone Input",
                    format="wav",
                    show_download_button=False,
                    container=True,
                    elem_id="fastrtc_audio",
                )
                # Main conversation display
                conversation_output = gr.HTML(
                    value="<i>Click 'Initialize System' and then 'Start Recording' to begin...</i>",
                    label="Live Conversation",
                    elem_id="conversation_display",
                )
                # Control buttons
                with gr.Row():
                    init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
                    start_btn = gr.Button("🎙️ Start Recording", variant="primary", interactive=False, size="lg")
                    stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", interactive=False, size="lg")
                    clear_btn = gr.Button("🗑️ Clear", interactive=False, size="lg")
                # Status display
                status_output = gr.Textbox(
                    label="System Status",
                    value="System not initialized",
                    lines=10,
                    interactive=False,
                    show_copy_button=True,
                )
            with gr.Column(scale=1):
                # Settings panel
                gr.Markdown("## ⚙️ Settings")
                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.95,
                    step=0.05,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    label="Speaker Change Sensitivity",
                    info="Lower = more sensitive to changes",
                )
                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    step=1,
                    value=DEFAULT_MAX_SPEAKERS,
                    label="Maximum Number of Speakers",
                )
                update_settings_btn = gr.Button("Update Settings", variant="secondary")
                # Speaker color legend
                gr.Markdown("## 🎨 Speaker Colors")
                color_info = []
                for i, (color, name) in enumerate(zip(SPEAKER_COLORS, SPEAKER_COLOR_NAMES)):
                    color_info.append(f'<span style="color:{color}; font-size: 16px;">●</span> Speaker {i+1} ({name})')
                gr.HTML("<br>".join(color_info[:DEFAULT_MAX_SPEAKERS]))
                # Performance info
                gr.Markdown("## 📊 Performance")
                gr.Markdown("""
- **FastRTC**: Low-latency audio streaming
- **Whisper**: distil-large-v3 for transcription
- **ECAPA-TDNN**: Speaker embeddings
- **Real-time**: ~100ms processing chunks
""")

        # Event handlers
        def on_initialize():
            result = initialize_system()
            ready = "successfully" in result
            # status_output appears once in the outputs, so the init result and
            # the detector status are combined rather than one overwriting the other.
            return (
                f"{result}\n\n{get_status()}",   # status_output
                gr.update(interactive=ready),    # start_btn
                gr.update(interactive=ready),    # clear_btn
                get_conversation(),              # conversation_output
            )

        def on_start():
            result = start_recording()
            # Only flip the buttons if a run is actually active (start may fail).
            running = diarization_system.is_running
            return (
                result,                               # status_output
                gr.update(interactive=not running),   # start_btn
                gr.update(interactive=running),       # stop_btn
            )

        def on_stop():
            result = stop_recording()
            return (
                result,                          # status_output
                gr.update(interactive=True),     # start_btn
                gr.update(interactive=False),    # stop_btn
            )

        def on_clear():
            # Refresh the conversation immediately instead of waiting for the timer.
            return clear_conversation(), get_conversation()

        # Auto-refresh function
        def refresh_display():
            return get_conversation(), get_status()

        # Connect event handlers
        init_btn.click(
            on_initialize,
            outputs=[status_output, start_btn, clear_btn, conversation_output],
        )
        start_btn.click(
            on_start,
            outputs=[status_output, start_btn, stop_btn],
        )
        stop_btn.click(
            on_stop,
            outputs=[status_output, start_btn, stop_btn],
        )
        clear_btn.click(
            on_clear,
            outputs=[status_output, conversation_output],
        )
        update_settings_btn.click(
            update_settings,
            inputs=[threshold_slider, max_speakers_slider],
            outputs=[status_output],
        )
        # FastRTC streaming audio processing
        audio_input.stream(
            process_audio_stream,
            inputs=[audio_input],
            outputs=[conversation_output, status_output],
            stream_every=0.1,  # Process every 100ms
            time_limit=None,
        )
        # Auto-refresh timer
        refresh_timer = gr.Timer(2.0)
        refresh_timer.tick(
            refresh_display,
            outputs=[conversation_output, status_output],
        )
    return app
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces at port 7860 with a public share link.
    app = create_interface()
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": True}
    app.launch(**launch_options)