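"""Real-time speaker detection and transcription from system audio.

Captures loopback (or microphone) audio with soundcard, transcribes it with
RealtimeSTT, and tags each final transcript with a speaker label derived from
a simple MFCC-statistics embedding and cosine-similarity change detection.
"""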
import sys
import time
import queue
import threading
import signal
import atexit
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

import numpy as np
import torch
import torchaudio
from scipy.spatial.distance import cosine

try:
    import soundcard as sc
except ImportError:
    print("soundcard not found. Install with: pip install soundcard")
    sys.exit(1)

try:
    from RealtimeSTT import AudioToTextRecorder
except ImportError:
    print("RealtimeSTT not found. Install with: pip install RealtimeSTT")
    sys.exit(1)

# Configuration
class Config:
    # Audio settings
    SAMPLE_RATE = 16000
    BUFFER_SIZE = 1024
    CHANNELS = 1

    # Transcription settings
    FINAL_MODEL = "distil-large-v3"
    REALTIME_MODEL = "distil-small.en"
    LANGUAGE = "en"
    BEAM_SIZE = 5
    REALTIME_BEAM_SIZE = 3

    # Voice activity detection
    SILENCE_THRESHOLD = 0.4
    MIN_RECORDING_LENGTH = 0.5
    PRE_RECORDING_BUFFER = 0.2
    SILERO_SENSITIVITY = 0.4
    WEBRTC_SENSITIVITY = 3

    # Speaker detection
    CHANGE_THRESHOLD = 0.65
    MAX_SPEAKERS = 4
    MIN_SEGMENT_DURATION = 1.0
    EMBEDDING_HISTORY_SIZE = 3
    SPEAKER_MEMORY_SIZE = 20

    # Console colors for speakers
    COLORS = [
        '\033[93m',  # Yellow
        '\033[91m',  # Red
        '\033[92m',  # Green
        '\033[96m',  # Cyan
        '\033[95m',  # Magenta
        '\033[94m',  # Blue
        '\033[97m',  # White
        '\033[33m',  # Dark yellow (often rendered orange)
    ]
    RESET = '\033[0m'
    LIVE_COLOR = '\033[90m'  # Grey, used for in-progress (live) text
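
# Tuning note: a speaker change is only considered when cosine similarity to
# the current speaker's centroid drops below CHANGE_THRESHOLD, so lowering it
# makes switching less eager; MIN_SEGMENT_DURATION additionally blocks changes
# within one second of the previous one.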


class SpeakerEncoder:
    """Simplified speaker encoder using torchaudio transforms"""

    def __init__(self, device="cpu"):
        self.device = device
        self.embedding_dim = 128
        self.model_loaded = False
        self._setup_model()

    def _setup_model(self):
        """Setup a simple MFCC-based feature extractor"""
        try:
            self.mfcc_transform = torchaudio.transforms.MFCC(
                sample_rate=Config.SAMPLE_RATE,
                n_mfcc=13,
                melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23}
            ).to(self.device)
            self.model_loaded = True
            print("Simple MFCC-based encoder initialized")
        except Exception as e:
            print(f"Error setting up encoder: {e}")
            self.model_loaded = False
    def extract_embedding(self, audio):
        """Extract speaker embedding from audio"""
        if not self.model_loaded:
            return np.zeros(self.embedding_dim)
        try:
            # Ensure audio is a float32 tensor on the same device as the transform
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio).float()
            audio = audio.to(self.device)

            # Normalize audio
            if audio.abs().max() > 0:
                audio = audio / audio.abs().max()

            # Add batch dimension if needed
            if audio.dim() == 1:
                audio = audio.unsqueeze(0)

            # Extract MFCC features
            with torch.no_grad():
                mfcc = self.mfcc_transform(audio)

            # Statistics-based embedding: 13 MFCCs x 4 statistics = 52 values
            embedding = torch.cat([
                mfcc.mean(dim=2).flatten(),
                mfcc.std(dim=2).flatten(),
                mfcc.max(dim=2)[0].flatten(),
                mfcc.min(dim=2)[0].flatten()
            ])

            # Pad or truncate to the fixed embedding_dim (52 -> 128 here)
            if embedding.size(0) > self.embedding_dim:
                embedding = embedding[:self.embedding_dim]
            elif embedding.size(0) < self.embedding_dim:
                padding = torch.zeros(
                    self.embedding_dim - embedding.size(0), device=embedding.device
                )
                embedding = torch.cat([embedding, padding])
            return embedding.cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)
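
# Quick sanity check (hypothetical usage): one second of silence should yield
# a fixed-size vector without errors.
#   enc = SpeakerEncoder()
#   emb = enc.extract_embedding(np.zeros(Config.SAMPLE_RATE, dtype=np.float32))
#   assert emb.shape == (128,)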


class SpeakerDetector:
    """Speaker change detection using embeddings"""

    def __init__(self, threshold=Config.CHANGE_THRESHOLD, max_speakers=Config.MAX_SPEAKERS):
        self.threshold = threshold
        self.max_speakers = max_speakers
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(max_speakers)]
        self.speaker_centroids = [None] * max_speakers
        self.last_change_time = time.time()
        self.active_speakers = {0}

    def detect_speaker(self, embedding):
        """Detect current speaker from embedding"""
        current_time = time.time()

        # Initialize first speaker
        if not self.speaker_embeddings[0]:
            self.speaker_embeddings[0].append(embedding)
            self.speaker_centroids[0] = embedding.copy()
            return 0, 1.0

        # Calculate similarity with current speaker
        current_centroid = self.speaker_centroids[self.current_speaker]
        if current_centroid is not None:
            similarity = 1.0 - cosine(embedding, current_centroid)
        else:
            similarity = 0.0

        # Check if enough time has passed for a speaker change
        if current_time - self.last_change_time < Config.MIN_SEGMENT_DURATION:
            self._update_speaker_model(self.current_speaker, embedding)
            return self.current_speaker, similarity

        # Check for speaker change
        if similarity < self.threshold:
            # Find best matching existing speaker
            best_speaker = self.current_speaker
            best_similarity = similarity
            for speaker_id in self.active_speakers:
                if speaker_id == self.current_speaker:
                    continue
                centroid = self.speaker_centroids[speaker_id]
                if centroid is not None:
                    sim = 1.0 - cosine(embedding, centroid)
                    if sim > best_similarity and sim > self.threshold:
                        best_similarity = sim
                        best_speaker = speaker_id

            # Create new speaker if no good match and slots available
            if (best_speaker == self.current_speaker and
                    len(self.active_speakers) < self.max_speakers):
                for new_id in range(self.max_speakers):
                    if new_id not in self.active_speakers:
                        best_speaker = new_id
                        best_similarity = 0.0
                        self.active_speakers.add(new_id)
                        break

            # Update current speaker if changed
            if best_speaker != self.current_speaker:
                self.current_speaker = best_speaker
                self.last_change_time = current_time
                similarity = best_similarity

        # Update speaker model
        self._update_speaker_model(self.current_speaker, embedding)
        return self.current_speaker, similarity

    def _update_speaker_model(self, speaker_id, embedding):
        """Update speaker model with new embedding"""
        self.speaker_embeddings[speaker_id].append(embedding)

        # Keep only recent embeddings
        if len(self.speaker_embeddings[speaker_id]) > Config.SPEAKER_MEMORY_SIZE:
            self.speaker_embeddings[speaker_id] = \
                self.speaker_embeddings[speaker_id][-Config.SPEAKER_MEMORY_SIZE:]

        # Update centroid
        if self.speaker_embeddings[speaker_id]:
            self.speaker_centroids[speaker_id] = np.mean(
                self.speaker_embeddings[speaker_id], axis=0
            )
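
# Hypothetical walkthrough: repeated identical embeddings stay with speaker 0.
#   det = SpeakerDetector()
#   emb = np.ones(128, dtype=np.float32)
#   det.detect_speaker(emb)  # -> (0, 1.0): first call initializes speaker 0
#   det.detect_speaker(emb)  # -> (0, ~1.0): identical embedding, no change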


class AudioRecorder:
    """Handles audio recording from system audio"""

    def __init__(self, audio_queue):
        self.audio_queue = audio_queue
        self.running = False
        self.thread = None

    def start(self):
        """Start recording"""
        self.running = True
        self.thread = threading.Thread(target=self._record_loop, daemon=True)
        self.thread.start()
        print("Audio recording started")

    def stop(self):
        """Stop recording"""
        self.running = False
        if self.thread and self.thread.is_alive():
            self.thread.join(timeout=2)

    def _capture(self, device, label):
        """Capture audio from a soundcard device until stopped"""
        with device.recorder(
            samplerate=Config.SAMPLE_RATE,
            blocksize=Config.BUFFER_SIZE,
            channels=Config.CHANNELS
        ) as recorder:
            print(f"Recording from {label}: {device.name}")
            while self.running:
                data = recorder.record(numframes=Config.BUFFER_SIZE)
                if data is not None and len(data) > 0:
                    # Convert to mono if needed
                    if data.ndim > 1:
                        data = data[:, 0]
                    self.audio_queue.put(data.flatten())

    def _record_loop(self):
        """Main recording loop"""
        try:
            # Try to capture system audio via the default speaker's loopback
            # device (supported by soundcard on Windows/WASAPI; on Linux this
            # typically requires a PulseAudio monitor source)
            try:
                loopback = sc.get_microphone(
                    id=str(sc.default_speaker().name), include_loopback=True
                )
                self._capture(loopback, "loopback")
            except Exception as e:
                print(f"Loopback recording failed: {e}")
                print("Falling back to microphone...")
                self._capture(sc.default_microphone(), "microphone")
        except Exception as e:
            print(f"Recording error: {e}")
            self.running = False
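
# To see which loopback devices soundcard can enumerate on your system, try:
#   python -c "import soundcard as sc; print(sc.all_microphones(include_loopback=True))"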


class TranscriptionProcessor:
    """Handles transcription and speaker detection"""

    def __init__(self):
        self.encoder = SpeakerEncoder()
        self.detector = SpeakerDetector()
        self.recorder = None
        self.audio_queue = queue.Queue(maxsize=100)
        self.audio_recorder = AudioRecorder(self.audio_queue)
        self.processing_thread = None
        self.running = False

    def setup(self):
        """Setup transcription recorder"""
        try:
            self.recorder = AudioToTextRecorder(
                spinner=False,
                use_microphone=False,
                model=Config.FINAL_MODEL,
                language=Config.LANGUAGE,
                silero_sensitivity=Config.SILERO_SENSITIVITY,
                webrtc_sensitivity=Config.WEBRTC_SENSITIVITY,
                post_speech_silence_duration=Config.SILENCE_THRESHOLD,
                min_length_of_recording=Config.MIN_RECORDING_LENGTH,
                pre_recording_buffer_duration=Config.PRE_RECORDING_BUFFER,
                enable_realtime_transcription=True,
                realtime_model_type=Config.REALTIME_MODEL,
                beam_size=Config.BEAM_SIZE,
                beam_size_realtime=Config.REALTIME_BEAM_SIZE,
                on_realtime_transcription_update=self._on_live_text,
            )
            print("Transcription recorder setup complete")
            return True
        except Exception as e:
            print(f"Transcription setup failed: {e}")
            return False
    def start(self):
        """Start processing"""
        if not self.setup():
            return False
        self.running = True

        # Start audio recording
        self.audio_recorder.start()

        # Start audio processing thread
        self.processing_thread = threading.Thread(target=self._process_audio, daemon=True)
        self.processing_thread.start()

        # Start transcription
        self._start_transcription()
        return True

    def stop(self):
        """Stop processing"""
        print("\nStopping transcription...")
        self.running = False
        if self.audio_recorder:
            self.audio_recorder.stop()
        if self.processing_thread and self.processing_thread.is_alive():
            self.processing_thread.join(timeout=2)
        if self.recorder:
            try:
                self.recorder.shutdown()
            except Exception:
                pass
    def _process_audio(self):
        """Process audio chunks and feed them to the transcriber"""
        audio_buffer = []
        while self.running:
            try:
                # Get audio chunk
                chunk = self.audio_queue.get(timeout=0.1)
                audio_buffer.extend(chunk)

                # Process when we have enough audio (about 1 second)
                if len(audio_buffer) >= Config.SAMPLE_RATE:
                    audio_array = np.array(audio_buffer[:Config.SAMPLE_RATE])
                    # Advance by the full window so no sample is fed twice
                    # (overlapping windows would duplicate audio in the
                    # transcription stream and cause repeated words)
                    audio_buffer = audio_buffer[Config.SAMPLE_RATE:]

                    # Convert float32 [-1, 1] to 16-bit PCM, the format the
                    # recorder consumes when use_microphone=False
                    audio_int16 = (audio_array * 32767).astype(np.int16)

                    # Feed to transcription recorder
                    if self.recorder:
                        self.recorder.feed_audio(audio_int16.tobytes())
            except queue.Empty:
                continue
            except Exception as e:
                if self.running:
                    print(f"Audio processing error: {e}")
    def _start_transcription(self):
        """Start transcription loop"""
        def transcription_loop():
            while self.running:
                try:
                    text = self.recorder.text()
                    if text and text.strip():
                        self._process_final_text(text)
                except Exception as e:
                    if self.running:
                        print(f"Transcription error: {e}")
                    break

        transcription_thread = threading.Thread(target=transcription_loop, daemon=True)
        transcription_thread.start()

    def _on_live_text(self, text):
        """Handle live transcription updates"""
        if text and text.strip():
            print(f"\r{Config.LIVE_COLOR}[Live] {text}{Config.RESET}", end="", flush=True)
    def _process_final_text(self, text):
        """Process final transcription with speaker detection"""
        # Clear the live-text line
        print("\r" + " " * 80 + "\r", end="")
        try:
            # Peek at recent audio for speaker detection, then put it back so
            # the processing thread still sees it
            recent_audio = []
            temp_queue = []
            for _ in range(min(10, self.audio_queue.qsize())):
                try:
                    chunk = self.audio_queue.get_nowait()
                    recent_audio.extend(chunk)
                    temp_queue.append(chunk)
                except queue.Empty:
                    break

            # Put the chunks back in their original order (they will land
            # behind anything that arrived meanwhile, which is acceptable here)
            for chunk in temp_queue:
                try:
                    self.audio_queue.put_nowait(chunk)
                except queue.Full:
                    break

            # Extract speaker embedding if we have audio
            if recent_audio:
                audio_tensor = torch.tensor(recent_audio[-Config.SAMPLE_RATE:], dtype=torch.float32)
                embedding = self.encoder.extract_embedding(audio_tensor)
                speaker_id, similarity = self.detector.detect_speaker(embedding)
            else:
                speaker_id, similarity = 0, 1.0

            # Display with speaker color
            color = Config.COLORS[speaker_id % len(Config.COLORS)]
            print(f"{color}Speaker {speaker_id + 1}: {text}{Config.RESET}")
        except Exception as e:
            print(f"Error processing text: {e}")
            print(f"Text: {text}")


class RealTimeSpeakerDetection:
    """Main application class"""

    def __init__(self):
        self.processor = None
        self.running = False

        # Setup signal handlers for clean shutdown
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)
        atexit.register(self.cleanup)

    def _signal_handler(self, signum, frame):
        """Handle shutdown signals"""
        print(f"\nReceived signal {signum}, shutting down...")
        self.stop()

    def start(self):
        """Start the application"""
        print("=== Real-time Speaker Detection and Transcription ===")
        print("Initializing...")

        self.processor = TranscriptionProcessor()
        if not self.processor.start():
            print("Failed to start. Check your audio setup and dependencies.")
            return False

        self.running = True
        print("=" * 60)
        print("System ready! Listening for audio...")
        print("Different speakers will be shown in different colors.")
        print("Press Ctrl+C to stop.")
        print("=" * 60)

        # Keep main thread alive
        try:
            while self.running:
                time.sleep(1)
        except KeyboardInterrupt:
            pass
        return True

    def stop(self):
        """Stop the application"""
        if not self.running:
            return
        self.running = False
        if self.processor:
            self.processor.stop()
        print("System stopped.")

    def cleanup(self):
        """Cleanup resources"""
        self.stop()


def main():
    """Main entry point"""
    app = RealTimeSpeakerDetection()
    try:
        app.start()
    except Exception as e:
        print(f"Application error: {e}")
    finally:
        app.cleanup()


if __name__ == "__main__":
    main()