|
import logging
import os
import queue
import threading
import time

import numpy as np
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI
from scipy.spatial.distance import cosine

from RealtimeSTT import AudioToTextRecorder
from fastrtc import Stream, ReplyOnPause, AudioStreamHandler
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
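
# Pipeline overview:
#   FastRTC streams microphone audio into diarization_handler below, which routes
#   the samples to RealtimeSTT (Whisper-family models) for live and final
#   transcription and to a SpeechBrain ECAPA-TDNN encoder that turns the most
#   recent audio into 192-dim speaker embeddings. SpeakerChangeDetector compares
#   each embedding against per-speaker centroids (cosine similarity) to decide who
#   is talking, and the conversation is rendered as color-coded HTML.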
|
|
|
|
|
SILENCE_THRESHS = [0, 0.4] |
|
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3" |
|
FINAL_BEAM_SIZE = 5 |
|
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en" |
|
REALTIME_BEAM_SIZE = 5 |
|
TRANSCRIPTION_LANGUAGE = "en" |
|
SILERO_SENSITIVITY = 0.4 |
|
WEBRTC_SENSITIVITY = 3 |
|
MIN_LENGTH_OF_RECORDING = 0.7 |
|
PRE_RECORDING_BUFFER_DURATION = 0.35 |
|
|
|
|
|
DEFAULT_CHANGE_THRESHOLD = 0.65 |
|
EMBEDDING_HISTORY_SIZE = 5 |
|
MIN_SEGMENT_DURATION = 1.5 |
|
DEFAULT_MAX_SPEAKERS = 4 |
|
ABSOLUTE_MAX_SPEAKERS = 8 |
|
|
|
|
|
SAMPLE_RATE = 16000 |
|
BUFFER_SIZE = 1024 |
|
CHANNELS = 1 |
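
# All audio in this app is treated as 16 kHz mono. Embedding extraction works on
# float32 samples in [-1.0, 1.0]; the RealtimeSTT recorder is fed 16-bit PCM, hence
# the int16 <-> float32 conversions (scaling by 32767 / 32768) used below.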
|
|
|
|
|
SPEAKER_COLORS = [
    "#FF6B6B",  # Red
    "#4ECDC4",  # Teal
    "#45B7D1",  # Blue
    "#96CEB4",  # Green
    "#FFEAA7",  # Yellow
    "#DDA0DD",  # Plum
    "#98D8C8",  # Mint
    "#F7DC6F",  # Gold
]
|
|
|
SPEAKER_COLOR_NAMES = [ |
|
"Red", "Teal", "Blue", "Green", "Yellow", "Plum", "Mint", "Gold" |
|
] |
|
|
|
|
|
class SpeechBrainEncoder: |
|
"""ECAPA-TDNN encoder from SpeechBrain for speaker embeddings""" |
|
def __init__(self, device="cpu"): |
|
self.device = device |
|
self.model = None |
|
self.embedding_dim = 192 |
|
self.model_loaded = False |
|
self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain") |
|
os.makedirs(self.cache_dir, exist_ok=True) |
|
|
|
def load_model(self): |
|
"""Load the ECAPA-TDNN model""" |
|
try: |
|
            # SpeechBrain >= 1.0 exposes pretrained interfaces under speechbrain.inference;
            # fall back to the legacy speechbrain.pretrained module on older releases.
            try:
                from speechbrain.inference.speaker import EncoderClassifier
            except ImportError:
                from speechbrain.pretrained import EncoderClassifier
|
|
|
self.model = EncoderClassifier.from_hparams( |
|
source="speechbrain/spkrec-ecapa-voxceleb", |
|
savedir=self.cache_dir, |
|
run_opts={"device": self.device} |
|
) |
|
|
|
self.model_loaded = True |
|
logger.info("ECAPA-TDNN model loaded successfully!") |
|
return True |
|
except Exception as e: |
|
logger.error(f"Error loading ECAPA-TDNN model: {e}") |
|
return False |
|
|
|
def embed_utterance(self, audio, sr=16000): |
|
"""Extract speaker embedding from audio""" |
|
if not self.model_loaded: |
|
raise ValueError("Model not loaded. Call load_model() first.") |
|
|
|
try: |
|
if isinstance(audio, np.ndarray): |
|
|
|
audio = audio.astype(np.float32) |
|
if np.max(np.abs(audio)) > 1.0: |
|
audio = audio / np.max(np.abs(audio)) |
|
waveform = torch.tensor(audio).unsqueeze(0) |
|
else: |
|
waveform = audio.unsqueeze(0) |
|
|
|
|
|
if sr != 16000: |
|
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) |
|
|
|
with torch.no_grad(): |
|
embedding = self.model.encode_batch(waveform) |
|
|
|
return embedding.squeeze().cpu().numpy() |
|
except Exception as e: |
|
logger.error(f"Error extracting embedding: {e}") |
|
return np.zeros(self.embedding_dim) |
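

# Illustrative sketch (not called anywhere in this app): how SpeechBrainEncoder is
# intended to be used on one second of 16 kHz mono audio. Downloading the
# speechbrain/spkrec-ecapa-voxceleb weights on first use is assumed to succeed.
def _example_embed_utterance():
    encoder = SpeechBrainEncoder(device="cpu")
    if not encoder.load_model():
        return None
    one_second = np.random.uniform(-1.0, 1.0, SAMPLE_RATE).astype(np.float32)
    embedding = encoder.embed_utterance(one_second)  # np.ndarray of shape (192,)
    return embedding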
|
|
|
|
|
class AudioProcessor: |
|
"""Processes audio data to extract speaker embeddings""" |
|
def __init__(self, encoder): |
|
self.encoder = encoder |
|
self.audio_buffer = [] |
|
self.min_audio_length = int(SAMPLE_RATE * 1.0) |
|
|
|
def add_audio_chunk(self, audio_chunk): |
|
"""Add audio chunk to buffer""" |
|
self.audio_buffer.extend(audio_chunk) |
|
|
|
|
|
max_buffer_size = int(SAMPLE_RATE * 10) |
|
if len(self.audio_buffer) > max_buffer_size: |
|
self.audio_buffer = self.audio_buffer[-max_buffer_size:] |
|
|
|
def extract_embedding_from_buffer(self): |
|
"""Extract embedding from current audio buffer""" |
|
if len(self.audio_buffer) < self.min_audio_length: |
|
return None |
|
|
|
try: |
|
|
|
audio_segment = np.array(self.audio_buffer[-self.min_audio_length:], dtype=np.float32) |
|
|
|
|
|
if np.max(np.abs(audio_segment)) > 0: |
|
audio_segment = audio_segment / np.max(np.abs(audio_segment)) |
|
else: |
|
return None |
|
|
|
embedding = self.encoder.embed_utterance(audio_segment) |
|
return embedding |
|
except Exception as e: |
|
logger.error(f"Embedding extraction error: {e}") |
|
return None |
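
# The processor keeps at most 10 s of samples and embeds only the most recent
# 1 s (min_audio_length), so the embedding tracks whoever is speaking right now
# rather than the whole conversation.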
|
|
|
|
|
class SpeakerChangeDetector: |
|
"""Improved speaker change detector""" |
|
def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS): |
|
self.embedding_dim = embedding_dim |
|
self.change_threshold = change_threshold |
|
self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS) |
|
self.current_speaker = 0 |
|
self.speaker_embeddings = [[] for _ in range(self.max_speakers)] |
|
self.speaker_centroids = [None] * self.max_speakers |
|
self.last_change_time = time.time() |
|
self.last_similarity = 1.0 |
|
self.active_speakers = set([0]) |
|
self.segment_counter = 0 |
|
|
|
def set_max_speakers(self, max_speakers): |
|
"""Update the maximum number of speakers""" |
|
new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS) |
|
|
|
if new_max < self.max_speakers: |
|
|
|
for speaker_id in list(self.active_speakers): |
|
if speaker_id >= new_max: |
|
self.active_speakers.discard(speaker_id) |
|
|
|
if self.current_speaker >= new_max: |
|
self.current_speaker = 0 |
|
|
|
|
|
if new_max > self.max_speakers: |
|
self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)]) |
|
self.speaker_centroids.extend([None] * (new_max - self.max_speakers)) |
|
else: |
|
self.speaker_embeddings = self.speaker_embeddings[:new_max] |
|
self.speaker_centroids = self.speaker_centroids[:new_max] |
|
|
|
self.max_speakers = new_max |
|
|
|
def set_change_threshold(self, threshold): |
|
"""Update the threshold for detecting speaker changes""" |
|
self.change_threshold = max(0.1, min(threshold, 0.95)) |
|
|
|
def add_embedding(self, embedding, timestamp=None): |
|
"""Add a new embedding and detect speaker changes""" |
|
current_time = timestamp or time.time() |
|
self.segment_counter += 1 |
|
|
|
|
|
if not self.speaker_embeddings[0]: |
|
self.speaker_embeddings[0].append(embedding) |
|
self.speaker_centroids[0] = embedding.copy() |
|
self.active_speakers.add(0) |
|
return 0, 1.0 |
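
        # Otherwise compare the new embedding with the current speaker's centroid;
        # a missing centroid falls back to a neutral similarity of 0.5.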
|
|
|
|
|
current_centroid = self.speaker_centroids[self.current_speaker] |
|
if current_centroid is not None: |
|
similarity = 1.0 - cosine(embedding, current_centroid) |
|
else: |
|
similarity = 0.5 |
|
|
|
self.last_similarity = similarity |
|
|
|
|
|
time_since_last_change = current_time - self.last_change_time |
|
speaker_changed = False |
|
|
|
if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold: |
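            # Look for a better-matching active speaker first; only open a new
            # speaker slot if nobody else matches and capacity remains.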
|
|
|
best_speaker = self.current_speaker |
|
best_similarity = similarity |
|
|
|
for speaker_id in self.active_speakers: |
|
if speaker_id == self.current_speaker: |
|
continue |
|
|
|
centroid = self.speaker_centroids[speaker_id] |
|
if centroid is not None: |
|
speaker_similarity = 1.0 - cosine(embedding, centroid) |
|
if speaker_similarity > best_similarity and speaker_similarity > self.change_threshold: |
|
best_similarity = speaker_similarity |
|
best_speaker = speaker_id |
|
|
|
|
|
if best_speaker == self.current_speaker and len(self.active_speakers) < self.max_speakers: |
|
for new_id in range(self.max_speakers): |
|
if new_id not in self.active_speakers: |
|
best_speaker = new_id |
|
self.active_speakers.add(new_id) |
|
break |
|
|
|
if best_speaker != self.current_speaker: |
|
self.current_speaker = best_speaker |
|
self.last_change_time = current_time |
|
speaker_changed = True |
|
|
|
|
|
self.speaker_embeddings[self.current_speaker].append(embedding) |
|
|
|
|
|
max_embeddings = 20 |
|
if len(self.speaker_embeddings[self.current_speaker]) > max_embeddings: |
|
self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-max_embeddings:] |
|
|
|
|
|
if self.speaker_embeddings[self.current_speaker]: |
|
self.speaker_centroids[self.current_speaker] = np.mean( |
|
self.speaker_embeddings[self.current_speaker], axis=0 |
|
) |
|
|
|
return self.current_speaker, similarity |
|
|
|
def get_color_for_speaker(self, speaker_id): |
|
"""Return color for speaker ID""" |
|
if 0 <= speaker_id < len(SPEAKER_COLORS): |
|
return SPEAKER_COLORS[speaker_id] |
|
return "#FFFFFF" |
|
|
|
def get_status_info(self): |
|
"""Return status information""" |
|
speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)] |
|
|
|
return { |
|
"current_speaker": self.current_speaker, |
|
"speaker_counts": speaker_counts, |
|
"active_speakers": len(self.active_speakers), |
|
"max_speakers": self.max_speakers, |
|
"last_similarity": self.last_similarity, |
|
"threshold": self.change_threshold, |
|
"segment_counter": self.segment_counter |
|
} |
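

# Illustrative sketch (not used by the app): feeding the detector two synthetic
# "voices" built from random base vectors plus small noise. With the default
# threshold this would typically end up in two different speaker slots.
def _example_speaker_change_detector():
    rng = np.random.default_rng(0)
    detector = SpeakerChangeDetector(embedding_dim=192, change_threshold=0.65, max_speakers=2)
    voice_a = rng.normal(size=192)
    voice_b = rng.normal(size=192)
    for base in (voice_a, voice_a, voice_b, voice_b):
        # MIN_SEGMENT_DURATION gates how quickly a change can be accepted, so the
        # synthetic timestamps are spaced two seconds apart.
        speaker, similarity = detector.add_embedding(
            base + 0.05 * rng.normal(size=192),
            timestamp=time.time() + 2.0 * detector.segment_counter,
        )
        print(f"assigned to speaker {speaker + 1} (similarity {similarity:.2f})")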
|
|
|
|
|
class RealtimeSpeakerDiarization: |
|
def __init__(self): |
|
self.encoder = None |
|
self.audio_processor = None |
|
self.speaker_detector = None |
|
self.recorder = None |
|
self.sentence_queue = queue.Queue() |
|
self.full_sentences = [] |
|
self.sentence_speakers = [] |
|
self.pending_sentences = [] |
|
self.current_conversation = "" |
|
self.is_running = False |
|
self.change_threshold = DEFAULT_CHANGE_THRESHOLD |
|
self.max_speakers = DEFAULT_MAX_SPEAKERS |
|
self.last_transcription = "" |
|
self.transcription_lock = threading.Lock() |
|
|
|
def initialize_models(self): |
|
"""Initialize the speaker encoder model""" |
|
try: |
|
device_str = "cuda" if torch.cuda.is_available() else "cpu" |
|
logger.info(f"Using device: {device_str}") |
|
|
|
self.encoder = SpeechBrainEncoder(device=device_str) |
|
success = self.encoder.load_model() |
|
|
|
if success: |
|
self.audio_processor = AudioProcessor(self.encoder) |
|
self.speaker_detector = SpeakerChangeDetector( |
|
embedding_dim=self.encoder.embedding_dim, |
|
change_threshold=self.change_threshold, |
|
max_speakers=self.max_speakers |
|
) |
|
logger.info("Models initialized successfully!") |
|
return True |
|
else: |
|
logger.error("Failed to load models") |
|
return False |
|
except Exception as e: |
|
logger.error(f"Model initialization error: {e}") |
|
return False |
|
|
|
def feed_audio(self, audio_data): |
|
"""Feed audio data directly to the recorder for live transcription""" |
|
if not self.is_running or not self.recorder: |
|
return |
|
|
|
try: |
|
|
|
if isinstance(audio_data, np.ndarray): |
|
if audio_data.dtype != np.float32: |
|
audio_data = audio_data.astype(np.float32) |
|
|
|
|
|
audio_int16 = (audio_data * 32767).astype(np.int16) |
|
audio_bytes = audio_int16.tobytes() |
|
|
|
|
|
self.recorder.feed_audio(audio_bytes) |
|
|
|
|
|
self.process_audio_chunk(audio_data) |
|
|
|
elif isinstance(audio_data, bytes): |
|
|
|
self.recorder.feed_audio(audio_data) |
|
|
|
|
|
audio_int16 = np.frombuffer(audio_data, dtype=np.int16) |
|
audio_float = audio_int16.astype(np.float32) / 32768.0 |
|
self.process_audio_chunk(audio_float) |
|
|
|
logger.debug("Audio fed to recorder") |
|
except Exception as e: |
|
logger.error(f"Error feeding audio: {e}") |
|
|
|
def live_text_detected(self, text): |
|
"""Callback for real-time transcription updates""" |
|
with self.transcription_lock: |
|
self.last_transcription = text.strip() |
|
|
|
|
|
self.update_conversation_display() |
|
|
|
def process_final_text(self, text): |
|
"""Process final transcribed text with speaker embedding""" |
|
text = text.strip() |
|
if text: |
|
try: |
|
|
|
audio_bytes = getattr(self.recorder, 'last_transcription_bytes', None) |
|
if audio_bytes: |
|
self.sentence_queue.put((text, audio_bytes)) |
|
else: |
|
|
|
self.sentence_queue.put((text, None)) |
|
|
|
except Exception as e: |
|
logger.error(f"Error processing final text: {e}") |
|
|
|
def process_sentence_queue(self): |
|
"""Process sentences in the queue for speaker detection""" |
|
while self.is_running: |
|
try: |
|
text, audio_bytes = self.sentence_queue.get(timeout=1) |
|
|
|
current_speaker = self.speaker_detector.current_speaker |
|
|
|
if audio_bytes: |
|
|
|
audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16) |
|
audio_float = audio_int16.astype(np.float32) / 32768.0 |
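
                    # Embedding the exact bytes of the finished utterance gives a
                    # per-sentence speaker estimate, independent of the rolling
                    # buffer used by process_audio_chunk.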
|
|
|
|
|
embedding = self.audio_processor.encoder.embed_utterance(audio_float) |
|
if embedding is not None: |
|
current_speaker, similarity = self.speaker_detector.add_embedding(embedding) |
|
|
|
|
|
with self.transcription_lock: |
|
self.full_sentences.append((text, current_speaker)) |
|
self.update_conversation_display() |
|
|
|
except queue.Empty: |
|
continue |
|
except Exception as e: |
|
logger.error(f"Error processing sentence: {e}") |
|
|
|
def update_conversation_display(self): |
|
"""Update the conversation display""" |
|
try: |
|
sentences_with_style = [] |
|
|
|
for sentence_text, speaker_id in self.full_sentences: |
|
color = self.speaker_detector.get_color_for_speaker(speaker_id) |
|
speaker_name = f"Speaker {speaker_id + 1}" |
|
sentences_with_style.append( |
|
f'<span style="color:{color}; font-weight: bold;">{speaker_name}:</span> ' |
|
f'<span style="color:#333333;">{sentence_text}</span>' |
|
) |
|
|
|
|
|
if self.last_transcription: |
|
current_color = self.speaker_detector.get_color_for_speaker(self.speaker_detector.current_speaker) |
|
current_speaker = f"Speaker {self.speaker_detector.current_speaker + 1}" |
|
sentences_with_style.append( |
|
f'<span style="color:{current_color}; font-weight: bold; opacity: 0.7;">{current_speaker}:</span> ' |
|
f'<span style="color:#666666; font-style: italic;">{self.last_transcription}...</span>' |
|
) |
|
|
|
if sentences_with_style: |
|
self.current_conversation = "<br><br>".join(sentences_with_style) |
|
else: |
|
self.current_conversation = "<i>Waiting for speech input...</i>" |
|
|
|
except Exception as e: |
|
logger.error(f"Error updating conversation display: {e}") |
|
self.current_conversation = f"<i>Error: {str(e)}</i>" |
|
|
|
def start_recording(self): |
|
"""Start the recording and transcription process""" |
|
if self.encoder is None: |
|
return "Please initialize models first!" |
|
|
|
try: |
|
|
|
recorder_config = { |
|
'spinner': False, |
|
'use_microphone': False, |
|
'model': FINAL_TRANSCRIPTION_MODEL, |
|
'language': TRANSCRIPTION_LANGUAGE, |
|
'silero_sensitivity': SILERO_SENSITIVITY, |
|
'webrtc_sensitivity': WEBRTC_SENSITIVITY, |
|
'post_speech_silence_duration': SILENCE_THRESHS[1], |
|
'min_length_of_recording': MIN_LENGTH_OF_RECORDING, |
|
'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION, |
|
'min_gap_between_recordings': 0, |
|
'enable_realtime_transcription': True, |
|
'realtime_processing_pause': 0.1, |
|
'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL, |
|
'on_realtime_transcription_update': self.live_text_detected, |
|
'beam_size': FINAL_BEAM_SIZE, |
|
'beam_size_realtime': REALTIME_BEAM_SIZE, |
|
'sample_rate': SAMPLE_RATE, |
|
} |
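
            # Two model sizes are used: REALTIME_TRANSCRIPTION_MODEL streams partial
            # text through live_text_detected, while FINAL_TRANSCRIPTION_MODEL produces
            # the sentence-level text consumed by process_final_text via
            # recorder.text() in run_transcription.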
|
|
|
self.recorder = AudioToTextRecorder(**recorder_config) |
|
|
|
|
|
self.is_running = True |
|
self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True) |
|
self.sentence_thread.start() |
|
|
|
self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True) |
|
self.transcription_thread.start() |
|
|
|
return "Recording started successfully!" |
|
|
|
except Exception as e: |
|
logger.error(f"Error starting recording: {e}") |
|
return f"Error starting recording: {e}" |
|
|
|
def run_transcription(self): |
|
"""Run the transcription loop""" |
|
try: |
|
while self.is_running: |
|
self.recorder.text(self.process_final_text) |
|
except Exception as e: |
|
logger.error(f"Transcription error: {e}") |
|
|
|
def stop_recording(self): |
|
"""Stop the recording process""" |
|
self.is_running = False |
|
if self.recorder: |
|
self.recorder.stop() |
|
return "Recording stopped!" |
|
|
|
def clear_conversation(self): |
|
"""Clear all conversation data""" |
|
with self.transcription_lock: |
|
self.full_sentences = [] |
|
self.last_transcription = "" |
|
self.current_conversation = "Conversation cleared!" |
|
|
|
if self.speaker_detector: |
|
self.speaker_detector = SpeakerChangeDetector( |
|
embedding_dim=self.encoder.embedding_dim, |
|
change_threshold=self.change_threshold, |
|
max_speakers=self.max_speakers |
|
) |
|
|
|
return "Conversation cleared!" |
|
|
|
def update_settings(self, threshold, max_speakers): |
|
"""Update speaker detection settings""" |
|
self.change_threshold = threshold |
|
self.max_speakers = max_speakers |
|
|
|
if self.speaker_detector: |
|
self.speaker_detector.set_change_threshold(threshold) |
|
self.speaker_detector.set_max_speakers(max_speakers) |
|
|
|
return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}" |
|
|
|
def get_formatted_conversation(self): |
|
"""Get the formatted conversation""" |
|
return self.current_conversation |
|
|
|
def get_status_info(self): |
|
"""Get current status information""" |
|
if not self.speaker_detector: |
|
return "Speaker detector not initialized" |
|
|
|
try: |
|
status = self.speaker_detector.get_status_info() |
|
|
|
status_lines = [ |
|
f"**Current Speaker:** {status['current_speaker'] + 1}", |
|
f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}", |
|
f"**Last Similarity:** {status['last_similarity']:.3f}", |
|
f"**Change Threshold:** {status['threshold']:.2f}", |
|
f"**Total Sentences:** {len(self.full_sentences)}", |
|
f"**Segments Processed:** {status['segment_counter']}", |
|
"", |
|
"**Speaker Activity:**" |
|
] |
|
|
|
for i in range(status['max_speakers']): |
|
color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}" |
|
count = status['speaker_counts'][i] |
|
active = "🟢" if count > 0 else "⚫" |
|
status_lines.append(f"{active} Speaker {i+1} ({color_name}): {count} segments") |
|
|
|
return "\n".join(status_lines) |
|
|
|
except Exception as e: |
|
return f"Error getting status: {e}" |
|
|
|
def process_audio_chunk(self, audio_data, sample_rate=16000): |
|
"""Process audio chunk from FastRTC input""" |
|
if not self.is_running or self.audio_processor is None: |
|
return |
|
|
|
try: |
|
|
|
if isinstance(audio_data, np.ndarray): |
|
if audio_data.dtype != np.float32: |
|
audio_data = audio_data.astype(np.float32) |
|
else: |
|
audio_data = np.array(audio_data, dtype=np.float32) |
|
|
|
|
|
if len(audio_data.shape) > 1: |
|
audio_data = np.mean(audio_data, axis=1) if audio_data.shape[1] > 1 else audio_data.flatten() |
|
|
|
|
|
if np.max(np.abs(audio_data)) > 1.0: |
|
audio_data = audio_data / np.max(np.abs(audio_data)) |
|
|
|
|
|
            # Add samples to the diarization buffer and refresh the speaker embedding
            # after roughly every 0.5 s of new audio, tracked with a running sample
            # counter so arbitrary chunk sizes still trigger the update.
            self.audio_processor.add_audio_chunk(audio_data)
            self._samples_since_embedding = getattr(self, "_samples_since_embedding", 0) + len(audio_data)
            if self._samples_since_embedding >= SAMPLE_RATE // 2:
                self._samples_since_embedding = 0
                embedding = self.audio_processor.extract_embedding_from_buffer()
                if embedding is not None:
                    self.speaker_detector.add_embedding(embedding)
|
|
|
except Exception as e: |
|
logger.error(f"Error processing audio chunk: {e}") |
|
|
|
|
|
|
|
class DiarizationAudioHandler(AudioStreamHandler):
    """FastRTC stream handler that forwards incoming frames to the diarization
    system. The Stream below is built with ReplyOnPause(diarization_handler)
    instead, so this class is an alternative, currently unused entry point."""

    def __init__(self, diarization_system):
        super().__init__()
        self.diarization_system = diarization_system
|
|
|
def receive(self, frame): |
|
"""Process incoming audio frame""" |
|
if not self.diarization_system.is_running: |
|
return |
|
|
|
try: |
|
|
|
sample_rate, audio_array = frame |
|
|
|
|
|
self.diarization_system.feed_audio(audio_array) |
|
except Exception as e: |
|
logger.error(f"Error processing FastRTC audio: {e}") |
|
|
|
def copy(self): |
|
"""Return a fresh handler instance""" |
|
return DiarizationAudioHandler(self.diarization_system) |
|
|
|
def shutdown(self): |
|
"""Clean up resources""" |
|
pass |
|
|
|
def start_up(self): |
|
"""Initialize resources""" |
|
logger.info("DiarizationAudioHandler started") |
|
|
|
|
|
|
|
diarization_system = RealtimeSpeakerDiarization() |
|
|
|
def initialize_system(): |
|
"""Initialize the diarization system""" |
|
try: |
|
success = diarization_system.initialize_models() |
|
if success: |
|
return "✅ System initialized successfully!" |
|
else: |
|
return "❌ Failed to initialize system. Check logs for details." |
|
except Exception as e: |
|
logger.error(f"Initialization error: {e}") |
|
return f"❌ Initialization error: {str(e)}" |
|
|
|
def start_recording(): |
|
"""Start recording and transcription""" |
|
try: |
|
result = diarization_system.start_recording() |
|
return result |
|
except Exception as e: |
|
return f"❌ Failed to start recording: {str(e)}" |
|
|
|
def stop_recording(): |
|
"""Stop recording and transcription""" |
|
try: |
|
result = diarization_system.stop_recording() |
|
return f"⏹️ {result}" |
|
except Exception as e: |
|
return f"❌ Failed to stop recording: {str(e)}" |
|
|
|
def clear_conversation(): |
|
"""Clear the conversation""" |
|
try: |
|
result = diarization_system.clear_conversation() |
|
return f"🗑️ {result}" |
|
except Exception as e: |
|
return f"❌ Failed to clear conversation: {str(e)}" |
|
|
|
def update_settings(threshold, max_speakers): |
|
"""Update system settings""" |
|
try: |
|
result = diarization_system.update_settings(threshold, max_speakers) |
|
return f"⚙️ {result}" |
|
except Exception as e: |
|
return f"❌ Failed to update settings: {str(e)}" |
|
|
|
def get_conversation(): |
|
"""Get the current conversation""" |
|
try: |
|
return diarization_system.get_formatted_conversation() |
|
except Exception as e: |
|
return f"<i>Error getting conversation: {str(e)}</i>" |
|
|
|
def get_status(): |
|
"""Get system status""" |
|
try: |
|
return diarization_system.get_status_info() |
|
except Exception as e: |
|
return f"Error getting status: {str(e)}" |
|
|
|
|
|
def diarization_handler(audio_data):
    """Handler function for FastRTC stream"""
    try:
        # ReplyOnPause delivers a (sample_rate, numpy_array) tuple once the caller
        # pauses. Normalize to mono float32 in [-1, 1] (int16 PCM is assumed to be
        # the common case) and route it through feed_audio so both the RealtimeSTT
        # recorder and the speaker-embedding buffer receive the samples.
        sample_rate, audio_array = audio_data
        audio_array = np.asarray(audio_array).flatten()
        if audio_array.dtype == np.int16:
            audio_array = audio_array.astype(np.float32) / 32768.0
        else:
            audio_array = audio_array.astype(np.float32)
        diarization_system.feed_audio(audio_array)

        # Echo the audio back unchanged; send-receive mode expects output frames.
        yield audio_data

    except Exception as e:
        logger.error(f"Error in diarization handler: {e}")
|
|
|
|
|
stream = Stream( |
|
handler=ReplyOnPause(diarization_handler), |
|
modality="audio", |
|
mode="send-receive", |
|
ui_args={ |
|
"title": "Real-time Speaker Diarization", |
|
"description": "Live transcription with automatic speaker identification" |
|
} |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
import argparse |
|
|
|
parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System") |
|
parser.add_argument("--mode", choices=["ui", "api", "both"], default="ui", |
|
help="Run mode: FastRTC UI, API only, or both") |
|
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") |
|
parser.add_argument("--port", type=int, default=7860, help="Port to bind to") |
|
parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)") |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
initialize_system() |
|
start_recording() |
|
|
|
if args.mode == "ui": |
|
|
|
stream.ui.launch( |
|
server_name=args.host, |
|
server_port=args.port, |
|
share=True, |
|
show_error=True |
|
) |
|
|
|
elif args.mode == "api": |
|
|
|
app = FastAPI() |
|
stream.mount(app) |
|
uvicorn.run( |
|
app, |
|
host=args.host, |
|
port=args.port, |
|
log_level="info" |
|
) |
|
|
|
elif args.mode == "both": |
|
|
|
|
|
|
def run_fastapi(): |
|
app = FastAPI() |
|
stream.mount(app) |
|
uvicorn.run( |
|
app, |
|
host=args.host, |
|
port=args.api_port, |
|
log_level="info" |
|
) |
|
|
|
|
|
api_thread = threading.Thread(target=run_fastapi, daemon=True) |
|
api_thread.start() |
|
|
|
|
|
stream.ui.launch( |
|
server_name=args.host, |
|
server_port=args.port, |
|
share=True, |
|
show_error=True |
|
) |
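

# Typical invocations (script name is whatever this file is saved as):
#   --mode ui    : launch the FastRTC/Gradio UI on --port (default 7860)
#   --mode api   : mount the stream on a FastAPI app served by uvicorn on --port
#   --mode both  : UI on --port plus the FastAPI app on --api-port (default 8000)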