Commit b37c0fc · Parent: 33445e6
Code fixing

Files changed:
- app.py (+178, -213)
- realtime_diarize.py (+970, -0)
app.py
CHANGED
@@ -8,18 +8,14 @@ import os
 import urllib.request
 import torchaudio
 from scipy.spatial.distance import cosine
-from RealtimeSTT import AudioToTextRecorder
 import json
 import io
 import wave

 # Simplified configuration parameters
 SILENCE_THRESHS = [0, 0.4]
-FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
-FINAL_BEAM_SIZE = 5
-REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
-REALTIME_BEAM_SIZE = 5
-TRANSCRIPTION_LANGUAGE = "en"
 SILERO_SENSITIVITY = 0.4
 WEBRTC_SENSITIVITY = 3
 MIN_LENGTH_OF_RECORDING = 0.7
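The constants removed above configured RealtimeSTT's local faster-whisper models; the replacement config later in this diff points final transcription at FastRTC's bundled "moonshine/base" STT model instead. A minimal sketch of that call path, assuming only the fastrtc helpers as they are used in this commit (get_stt_model and the returned model's stt((rate, samples)) method):

import numpy as np
from fastrtc import get_stt_model

# Load FastRTC's bundled moonshine model (same identifier the new config uses).
stt_model = get_stt_model(model="moonshine/base")

# One second of 16 kHz mono float32 audio; real input would come from the stream handler.
audio_chunk = np.zeros(16000, dtype=np.float32)
text = stt_model.stt((16000, audio_chunk))
print(repr(text))  # silence should come back as an empty or near-empty string

The stream handler added below applies the same call to buffers of roughly 1.5 seconds of microphone audio.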
@@ -271,52 +267,56 @@ class SpeakerChangeDetector:
         }


-class WebRTCAudioProcessor:
-    """…"""
     def __init__(self, diarization_system):
         self.diarization_system = diarization_system
-        self.…
-        self.…
-        self.…
-        self.…

-    def …
-        (The remaining removed method bodies are truncated in this view; the legible
-        fragments include the stereo downmix "audio_array = audio_array[:, 0]" and the
-        failure message print(f"Error processing WebRTC audio: {e}").)

 class RealtimeSpeakerDiarization:
@@ -324,8 +324,8 @@ class RealtimeSpeakerDiarization:
         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
-        self.…
-        self.…
         self.sentence_queue = queue.Queue()
         self.full_sentences = []
         self.sentence_speakers = []

@@ -352,7 +352,6 @@ class RealtimeSpeakerDiarization:
                 change_threshold=self.change_threshold,
                 max_speakers=self.max_speakers
             )
-            self.webrtc_processor = WebRTCAudioProcessor(self)
             print("ECAPA-TDNN model loaded successfully!")
             return True
         else:
@@ -362,45 +361,69 @@ class RealtimeSpeakerDiarization:
             print(f"Model initialization error: {e}")
             return False

-    def …
-        """…"""
-        …
             )
-        …

-    def …
-        """Process …"""
-        if text:
             try:
-                …
             except Exception as e:
-                print(f"Error processing …: {e}")

     def process_sentence_queue(self):
         """Process sentences in the queue for speaker detection"""
         while self.is_running:
             try:
-                text, …

                 # Convert audio data to int16
-                …

                 # Extract speaker embedding
                 speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
@@ -425,64 +448,10 @@ class RealtimeSpeakerDiarization:
             except Exception as e:
                 print(f"Error processing sentence: {e}")

-    def start_recording(self):
-        """…"""
-        if self.encoder is None:
-            return "Please initialize models first!"
-
-        try:
-            # Setup recorder configuration for WebRTC input
-            recorder_config = {
-                'spinner': False,
-                'use_microphone': False,  # We'll feed audio manually
-                'model': FINAL_TRANSCRIPTION_MODEL,
-                'language': TRANSCRIPTION_LANGUAGE,
-                'silero_sensitivity': SILERO_SENSITIVITY,
-                'webrtc_sensitivity': WEBRTC_SENSITIVITY,
-                'post_speech_silence_duration': SILENCE_THRESHS[1],
-                'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
-                'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
-                'min_gap_between_recordings': 0,
-                'enable_realtime_transcription': True,
-                'realtime_processing_pause': 0,
-                'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
-                'on_realtime_transcription_update': self.live_text_detected,
-                'beam_size': FINAL_BEAM_SIZE,
-                'beam_size_realtime': REALTIME_BEAM_SIZE,
-                'buffer_size': BUFFER_SIZE,
-                'sample_rate': SAMPLE_RATE,
-            }
-
-            self.recorder = AudioToTextRecorder(**recorder_config)
-
-            # Start sentence processing thread
-            self.is_running = True
-            self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
-            self.sentence_thread.start()
-
-            # Start transcription thread
-            self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
-            self.transcription_thread.start()
-
-            return "Recording started successfully! WebRTC audio input ready."
-
-        except Exception as e:
-            return f"Error starting recording: {e}"
-
-    def run_transcription(self):
-        """Run the transcription loop"""
-        try:
-            while self.is_running:
-                self.recorder.text(self.process_final_text)
-        except Exception as e:
-            print(f"Transcription error: {e}")
-
-    def stop_recording(self):
-        """Stop the recording process"""
         self.is_running = False
-
-        self.recorder.stop()
-        return "Recording stopped!"

     def clear_conversation(self):
         """Clear all conversation data"""
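For reference, a rough sketch (not part of the commit) of the RealtimeSTT flow the removed start_recording/run_transcription methods drove. Only the constructor keys and the feed_audio/text/stop calls visible in this diff are taken as given; the surrounding loop is illustrative:

import numpy as np
from RealtimeSTT import AudioToTextRecorder

recorder = AudioToTextRecorder(
    spinner=False,
    use_microphone=False,   # audio is pushed in manually, as in the removed code
    model="distil-large-v3",
    language="en",
)

def on_final_text(text):
    print("final:", text)

# Feed raw 16 kHz int16 PCM bytes, then ask for a finalized utterance.
chunk = (np.zeros(512, dtype=np.float32) * 32767).astype(np.int16).tobytes()
recorder.feed_audio(chunk)
recorder.text(on_final_text)   # blocks until RealtimeSTT finalizes an utterance from the fed audio
recorder.stop()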
@@ -575,79 +544,33 @@ class RealtimeSpeakerDiarization:
 diarization_system = RealtimeSpeakerDiarization()


-def initialize_system():
-    """Initialize the diarization system"""
-    success = diarization_system.initialize_models()
-    if success:
-        return "✅ System initialized successfully! Models loaded."
-    else:
-        return "❌ Failed to initialize system. Please check the logs."
-
-
-def start_recording():
-    """Start recording and transcription"""
-    return diarization_system.start_recording()
-
-
-def stop_recording():
-    """Stop recording and transcription"""
-    return diarization_system.stop_recording()
-
-
-def clear_conversation():
-    """Clear the conversation"""
-    return diarization_system.clear_conversation()
-
-
-def update_settings(threshold, max_speakers):
-    """Update system settings"""
-    return diarization_system.update_settings(threshold, max_speakers)
-
-
-def get_conversation():
-    """Get the current conversation"""
-    return diarization_system.get_formatted_conversation()
-
-
-def get_status():
-    """Get system status"""
-    return diarization_system.get_status_info()
-
-
-def process_audio_stream(audio):
-    """Process audio stream from WebRTC"""
-    if diarization_system.webrtc_processor and diarization_system.is_running:
-        diarization_system.webrtc_processor.process_audio(audio, SAMPLE_RATE)
-    return None
-
-
-# Create Gradio interface
 def create_interface():
-    …
         gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
-        gr.Markdown("This app performs real-time speech recognition with automatic speaker identification and color-coding using …")

         with gr.Row():
             with gr.Column(scale=2):
-                # WebRTC Audio Input
-                audio_input = gr.Audio(
-                    sources=["microphone"],
-                    streaming=True,
-                    label="🎙️ Microphone Input",
-                    type="numpy"
-                )
-
                 # Main conversation display
                 conversation_output = gr.HTML(
-                    value="<i>Click 'Initialize System' to …</i>",
                     label="Live Conversation"
                 )

                 # Control buttons
                 with gr.Row():
                     init_btn = gr.Button("🔧 Initialize System", variant="secondary")
-                    start_btn = gr.Button("🎙️ Start …")
-                    stop_btn = gr.Button("⏹️ Stop …")
                     clear_btn = gr.Button("🗑️ Clear Conversation", interactive=False)

                 # Status display
@@ -685,13 +608,30 @@ def create_interface():
         gr.Markdown("## 📝 Instructions")
         gr.Markdown("""
         1. Click **Initialize System** to load models
-        2. Click **Start …
         3. Allow microphone access when prompted
         4. Speak into your microphone
         5. Watch real-time transcription with speaker labels
         6. Adjust settings as needed
         """)

         # Speaker color legend
         gr.Markdown("## 🎨 Speaker Colors")
         color_info = []

@@ -702,7 +642,7 @@ def create_interface():

         # Auto-refresh conversation and status
         def refresh_display():
-            return …

         # Event handlers
         def on_initialize():

@@ -712,7 +652,7 @@ def create_interface():
                 result,
                 gr.update(interactive=True),   # start_btn
                 gr.update(interactive=True),   # clear_btn
-                …
                 get_status()
             )
         else:
@@ -720,26 +660,58 @@ def create_interface():
                 result,
                 gr.update(interactive=False),  # start_btn
                 gr.update(interactive=False),  # clear_btn
-                …
                 get_status()
             )

-        def …
-            result = …
             return (
                 result,
                 gr.update(interactive=False),  # start_btn
                 gr.update(interactive=True),   # stop_btn
             )

-        def …
-            result = …
             return (
                 result,
                 gr.update(interactive=True),   # start_btn
                 gr.update(interactive=False),  # stop_btn
             )

         # Connect event handlers
         init_btn.click(
             on_initialize,

@@ -747,12 +719,12 @@ def create_interface():
         )

         start_btn.click(
-            …
             outputs=[status_output, start_btn, stop_btn]
         )

         stop_btn.click(
-            …
             outputs=[status_output, start_btn, stop_btn]
         )

@@ -767,14 +739,7 @@ def create_interface():
             outputs=[status_output]
         )

-        # …
-        audio_input.stream(
-            process_audio_stream,
-            inputs=[audio_input],
-            outputs=[]
-        )
-
-        # Auto-refresh every 2 seconds when recording
         refresh_timer = gr.Timer(2.0)
         refresh_timer.tick(
             refresh_display,
 import urllib.request
 import torchaudio
 from scipy.spatial.distance import cosine
 import json
 import io
 import wave
+from fastrtc import Stream, ReplyOnPause, AsyncStreamHandler, get_stt_model

 # Simplified configuration parameters
 SILENCE_THRESHS = [0, 0.4]
+FINAL_TRANSCRIPTION_MODEL = "moonshine/base"  # Using FastRTC's moonshine model
 SILERO_SENSITIVITY = 0.4
 WEBRTC_SENSITIVITY = 3
 MIN_LENGTH_OF_RECORDING = 0.7

         }

+class DiarizationStreamHandler(AsyncStreamHandler):
+    """FastRTC stream handler for real-time diarization"""
     def __init__(self, diarization_system):
+        super().__init__(input_sample_rate=16000)
         self.diarization_system = diarization_system
+        self.stt_model = get_stt_model(model=FINAL_TRANSCRIPTION_MODEL)
+        self.current_text = ""
+        self.current_audio_buffer = []
+        self.transcript_queue = queue.Queue()

+    def copy(self):
+        return DiarizationStreamHandler(self.diarization_system)
+
+    async def start_up(self):
+        """Initialize the stream handler"""
+        pass
+
+    async def receive(self, frame):
+        """Process incoming audio frame"""
+        # Extract audio data
+        sample_rate, audio_data = frame
+
+        # Convert to numpy array if needed
+        if isinstance(audio_data, torch.Tensor):
+            audio_data = audio_data.numpy()
+
+        # Add to buffer
+        self.current_audio_buffer.append(audio_data)
+
+        # If buffer is large enough, process it
+        if len(self.current_audio_buffer) > 3:  # Process ~1.5 seconds of audio
+            # Concatenate audio data
+            combined_audio = np.concatenate(self.current_audio_buffer)

+            # Run speech-to-text
+            text = self.stt_model.stt((16000, combined_audio))

+            if text and text.strip():
+                # Save text and audio for processing
+                self.transcript_queue.put((text, combined_audio))
+                self.current_text = text
+
+            # Reset buffer but keep some overlap
+            if len(self.current_audio_buffer) > 5:
+                self.current_audio_buffer = self.current_audio_buffer[-2:]
+
+    async def emit(self):
+        """Emit processed data"""
+        # Return current text as dummy; actual processing is done in background
+        return self.current_text
 class RealtimeSpeakerDiarization:

         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
+        self.stream = None
+        self.stream_handler = None
         self.sentence_queue = queue.Queue()
         self.full_sentences = []
         self.sentence_speakers = []

                 change_threshold=self.change_threshold,
                 max_speakers=self.max_speakers
             )
             print("ECAPA-TDNN model loaded successfully!")
             return True
         else:

             print(f"Model initialization error: {e}")
             return False

+    def start_stream(self, app):
+        """Start the FastRTC stream"""
+        if self.encoder is None:
+            return "Please initialize models first!"
+
+        try:
+            # Create a FastRTC stream handler
+            self.stream_handler = DiarizationStreamHandler(self)
+
+            # Create FastRTC stream
+            self.stream = Stream(
+                handler=self.stream_handler,
+                modality="audio",
+                mode="send-receive"
             )
+
+            # Mount the stream to the provided FastAPI app
+            self.stream.mount(app)
+
+            # Start sentence processing thread
+            self.is_running = True
+            self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
+            self.sentence_thread.start()
+
+            # Start diarization processor thread
+            self.diarization_thread = threading.Thread(target=self.process_transcript_queue, daemon=True)
+            self.diarization_thread.start()
+
+            return "Stream started successfully! Ready for audio input."
+
+        except Exception as e:
+            return f"Error starting stream: {e}"

+    def process_transcript_queue(self):
+        """Process transcripts from the stream handler"""
+        while self.is_running:
             try:
+                if self.stream_handler and not self.stream_handler.transcript_queue.empty():
+                    text, audio_data = self.stream_handler.transcript_queue.get(timeout=1)
+
+                    # Add to sentence queue for diarization
+                    self.pending_sentences.append(text)
+                    self.sentence_queue.put((text, audio_data))
+            except queue.Empty:
+                time.sleep(0.1)  # Short sleep to prevent CPU hogging
             except Exception as e:
+                print(f"Error processing transcript queue: {e}")
+                time.sleep(0.5)  # Slightly longer sleep on error
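start_stream above wires the handler into a Stream and mounts it on the app object it receives. A hedged serving sketch: the Stream construction and mount call mirror the code above, while the standalone FastAPI/uvicorn wrapper is an assumption added for illustration:

import uvicorn
from fastapi import FastAPI
from fastrtc import Stream

api = FastAPI()
# DiarizationStreamHandler and diarization_system are defined earlier in app.py.
handler = DiarizationStreamHandler(diarization_system)
stream = Stream(handler=handler, modality="audio", mode="send-receive")
stream.mount(api)  # exposes the WebRTC endpoints on the FastAPI app

if __name__ == "__main__":
    uvicorn.run(api, host="0.0.0.0", port=7860)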
     def process_sentence_queue(self):
         """Process sentences in the queue for speaker detection"""
         while self.is_running:
             try:
+                text, audio_data = self.sentence_queue.get(timeout=1)

                 # Convert audio data to int16
+                if isinstance(audio_data, np.ndarray):
+                    if audio_data.dtype != np.int16:
+                        audio_int16 = (audio_data * 32767).astype(np.int16)
+                    else:
+                        audio_int16 = audio_data
+                else:
+                    audio_int16 = np.int16(audio_data * 32767)

                 # Extract speaker embedding
                 speaker_embedding = self.audio_processor.extract_embedding(audio_int16)

             except Exception as e:
                 print(f"Error processing sentence: {e}")
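A small worked example of the scaling used above: multiplying by 32767 maps float audio in [-1.0, 1.0] onto the int16 PCM range, and the embedding path divides by 32768.0 to get back to floats before handing audio to the encoder:

import numpy as np

float_audio = np.array([0.0, 0.5, -1.0], dtype=np.float32)
audio_int16 = (float_audio * 32767).astype(np.int16)    # -> [0, 16383, -32767]
restored = audio_int16.astype(np.float32) / 32768.0      # back to roughly [0.0, 0.49997, -0.99997]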
+    def stop_stream(self):
+        """Stop the stream and processing"""
         self.is_running = False
+        return "Stream stopped!"

     def clear_conversation(self):
         """Clear all conversation data"""
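Taken together, a condensed, hypothetical view of the lifecycle the interface below drives; the method names come from this diff, and the bare app argument stands in for the Blocks/FastAPI app passed to start_stream:

system = RealtimeSpeakerDiarization()

if system.initialize_models():          # load the ECAPA-TDNN encoder and speaker detector
    print(system.start_stream(app))     # mounts the FastRTC stream and starts worker threads
    # ... browser audio flows through DiarizationStreamHandler ...
    print(system.get_formatted_conversation())
    print(system.stop_stream())         # "Stream stopped!"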
 diarization_system = RealtimeSpeakerDiarization()


+# Create Gradio interface with FastAPI app integrated
 def create_interface():
+    app = gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Monochrome())
+
+    with app:
         gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
+        gr.Markdown("This app performs real-time speech recognition with automatic speaker identification and color-coding using FastRTC.")

         with gr.Row():
             with gr.Column(scale=2):
                 # Main conversation display
                 conversation_output = gr.HTML(
+                    value="<i>Click 'Initialize System' and then 'Start Stream' to begin...</i>",
                     label="Live Conversation"
                 )

+                # FastRTC microphone widget for visualization only (the real audio comes through FastRTC stream)
+                audio_widget = gr.Audio(
+                    label="🎙️ Microphone Input (Click Start Stream to enable)",
+                    type="microphone"
+                )
+
                 # Control buttons
                 with gr.Row():
                     init_btn = gr.Button("🔧 Initialize System", variant="secondary")
+                    start_btn = gr.Button("🎙️ Start Stream", variant="primary", interactive=False)
+                    stop_btn = gr.Button("⏹️ Stop Stream", variant="stop", interactive=False)
                     clear_btn = gr.Button("🗑️ Clear Conversation", interactive=False)

                 # Status display

         gr.Markdown("## 📝 Instructions")
         gr.Markdown("""
         1. Click **Initialize System** to load models
+        2. Click **Start Stream** to begin processing
         3. Allow microphone access when prompted
         4. Speak into your microphone
         5. Watch real-time transcription with speaker labels
         6. Adjust settings as needed
         """)

+        # QR code for mobile access
+        gr.Markdown("## 📱 Mobile Access")
+        gr.Markdown("Scan this QR code to access from mobile device:")
+        qr_code = gr.HTML("""
+            <div id="qrcode" style="text-align: center;"></div>
+            <script src="https://cdn.jsdelivr.net/npm/qrcode-generator@1.4.4/qrcode.min.js"></script>
+            <script>
+            setTimeout(function() {
+                var currentUrl = window.location.href;
+                var qr = qrcode(0, 'M');
+                qr.addData(currentUrl);
+                qr.make();
+                document.getElementById('qrcode').innerHTML = qr.createImgTag(5);
+            }, 1000);
+            </script>
+        """)
+
         # Speaker color legend
         gr.Markdown("## 🎨 Speaker Colors")
         color_info = []

         # Auto-refresh conversation and status
         def refresh_display():
+            return get_formatted_conversation(), get_status()

         # Event handlers
         def on_initialize():

                 result,
                 gr.update(interactive=True),   # start_btn
                 gr.update(interactive=True),   # clear_btn
+                get_formatted_conversation(),
                 get_status()
             )
         else:

                 result,
                 gr.update(interactive=False),  # start_btn
                 gr.update(interactive=False),  # clear_btn
+                get_formatted_conversation(),
                 get_status()
             )

+        def on_start_stream():
+            result = start_stream(app)
             return (
                 result,
                 gr.update(interactive=False),  # start_btn
                 gr.update(interactive=True),   # stop_btn
             )

+        def on_stop_stream():
+            result = stop_stream()
             return (
                 result,
                 gr.update(interactive=True),   # start_btn
                 gr.update(interactive=False),  # stop_btn
             )

+        def initialize_system():
+            """Initialize the diarization system"""
+            success = diarization_system.initialize_models()
+            if success:
+                return "✅ System initialized successfully! Models loaded."
+            else:
+                return "❌ Failed to initialize system. Please check the logs."
+
+        def start_stream(app):
+            """Start the FastRTC stream"""
+            return diarization_system.start_stream(app)
+
+        def stop_stream():
+            """Stop the FastRTC stream"""
+            return diarization_system.stop_stream()
+
+        def clear_conversation():
+            """Clear the conversation"""
+            return diarization_system.clear_conversation()
+
+        def update_settings(threshold, max_speakers):
+            """Update system settings"""
+            return diarization_system.update_settings(threshold, max_speakers)
+
+        def get_formatted_conversation():
+            """Get the current conversation"""
+            return diarization_system.get_formatted_conversation()
+
+        def get_status():
+            """Get system status"""
+            return diarization_system.get_status_info()

         # Connect event handlers
         init_btn.click(
             on_initialize,

         )

         start_btn.click(
+            on_start_stream,
             outputs=[status_output, start_btn, stop_btn]
         )

         stop_btn.click(
+            on_stop_stream,
             outputs=[status_output, start_btn, stop_btn]
         )

             outputs=[status_output]
         )

+        # Auto-refresh every 2 seconds when streaming
         refresh_timer = gr.Timer(2.0)
         refresh_timer.tick(
             refresh_display,
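For context, a minimal sketch of the Gradio polling pattern used above (a gr.Timer whose tick event refreshes outputs every two seconds). The component names here are illustrative and the outputs wiring is an assumption, since the tick call is cut off in this view:

import gradio as gr

with gr.Blocks() as demo:
    html_out = gr.HTML()
    status_out = gr.Textbox(label="Status")

    def refresh():
        # diarization_system is the global instance defined in app.py
        return diarization_system.get_formatted_conversation(), diarization_system.get_status_info()

    timer = gr.Timer(2.0)                        # fires every 2 seconds
    timer.tick(refresh, outputs=[html_out, status_out])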
realtime_diarize.py
ADDED
@@ -0,0 +1,970 @@
1 |
+
from PyQt6.QtWidgets import (QApplication, QTextEdit, QMainWindow, QLabel, QVBoxLayout, QWidget,
|
2 |
+
QHBoxLayout, QPushButton, QSizePolicy, QGroupBox, QSlider, QSpinBox)
|
3 |
+
from PyQt6.QtCore import Qt, pyqtSignal, QThread, QEvent, QTimer
|
4 |
+
from scipy.spatial.distance import cosine
|
5 |
+
from RealtimeSTT import AudioToTextRecorder
|
6 |
+
import numpy as np
|
7 |
+
import soundcard as sc
|
8 |
+
import queue
|
9 |
+
import torch
|
10 |
+
import time
|
11 |
+
import sys
|
12 |
+
import os
|
13 |
+
import urllib.request
|
14 |
+
import torchaudio
|
15 |
+
|
16 |
+
# Simplified configuration parameters
|
17 |
+
SILENCE_THRESHS = [0, 0.4]
|
18 |
+
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
|
19 |
+
FINAL_BEAM_SIZE = 5
|
20 |
+
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
|
21 |
+
REALTIME_BEAM_SIZE = 5
|
22 |
+
TRANSCRIPTION_LANGUAGE = "en" # Accuracy in languages other than English is very low.
|
23 |
+
SILERO_SENSITIVITY = 0.4
|
24 |
+
WEBRTC_SENSITIVITY = 3
|
25 |
+
MIN_LENGTH_OF_RECORDING = 0.7
|
26 |
+
PRE_RECORDING_BUFFER_DURATION = 0.35
|
27 |
+
|
28 |
+
# Speaker change detection parameters
|
29 |
+
DEFAULT_CHANGE_THRESHOLD = 0.7 # Threshold for detecting speaker change
|
30 |
+
EMBEDDING_HISTORY_SIZE = 5 # Number of embeddings to keep for comparison
|
31 |
+
MIN_SEGMENT_DURATION = 1.0 # Minimum duration before considering a speaker change
|
32 |
+
DEFAULT_MAX_SPEAKERS = 4 # Default maximum number of speakers
|
33 |
+
ABSOLUTE_MAX_SPEAKERS = 10 # Absolute maximum number of speakers allowed
|
34 |
+
|
35 |
+
# Global variables
|
36 |
+
FAST_SENTENCE_END = True
|
37 |
+
USE_MICROPHONE = False
|
38 |
+
SAMPLE_RATE = 16000
|
39 |
+
BUFFER_SIZE = 512
|
40 |
+
CHANNELS = 1
|
41 |
+
|
42 |
+
# Speaker colors - now we have colors for up to 10 speakers
|
43 |
+
SPEAKER_COLORS = [
|
44 |
+
"#FFFF00", # Yellow
|
45 |
+
"#FF0000", # Red
|
46 |
+
"#00FF00", # Green
|
47 |
+
"#00FFFF", # Cyan
|
48 |
+
"#FF00FF", # Magenta
|
49 |
+
"#0000FF", # Blue
|
50 |
+
"#FF8000", # Orange
|
51 |
+
"#00FF80", # Spring Green
|
52 |
+
"#8000FF", # Purple
|
53 |
+
"#FFFFFF", # White
|
54 |
+
]
|
55 |
+
|
56 |
+
# Color names for display
|
57 |
+
SPEAKER_COLOR_NAMES = [
|
58 |
+
"Yellow",
|
59 |
+
"Red",
|
60 |
+
"Green",
|
61 |
+
"Cyan",
|
62 |
+
"Magenta",
|
63 |
+
"Blue",
|
64 |
+
"Orange",
|
65 |
+
"Spring Green",
|
66 |
+
"Purple",
|
67 |
+
"White"
|
68 |
+
]
|
69 |
+
|
70 |
+
|
71 |
+
class SpeechBrainEncoder:
|
72 |
+
"""ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
|
73 |
+
def __init__(self, device="cpu"):
|
74 |
+
self.device = device
|
75 |
+
self.model = None
|
76 |
+
self.embedding_dim = 192 # ECAPA-TDNN default dimension
|
77 |
+
self.model_loaded = False
|
78 |
+
self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
|
79 |
+
os.makedirs(self.cache_dir, exist_ok=True)
|
80 |
+
|
81 |
+
def _download_model(self):
|
82 |
+
"""Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
|
83 |
+
model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
|
84 |
+
model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
|
85 |
+
|
86 |
+
if not os.path.exists(model_path):
|
87 |
+
print(f"Downloading ECAPA-TDNN model to {model_path}...")
|
88 |
+
urllib.request.urlretrieve(model_url, model_path)
|
89 |
+
|
90 |
+
return model_path
|
91 |
+
|
92 |
+
def load_model(self):
|
93 |
+
"""Load the ECAPA-TDNN model"""
|
94 |
+
try:
|
95 |
+
# Import SpeechBrain
|
96 |
+
from speechbrain.pretrained import EncoderClassifier
|
97 |
+
|
98 |
+
# Get model path
|
99 |
+
model_path = self._download_model()
|
100 |
+
|
101 |
+
# Load the pre-trained model
|
102 |
+
self.model = EncoderClassifier.from_hparams(
|
103 |
+
source="speechbrain/spkrec-ecapa-voxceleb",
|
104 |
+
savedir=self.cache_dir,
|
105 |
+
run_opts={"device": self.device}
|
106 |
+
)
|
107 |
+
|
108 |
+
self.model_loaded = True
|
109 |
+
return True
|
110 |
+
except Exception as e:
|
111 |
+
print(f"Error loading ECAPA-TDNN model: {e}")
|
112 |
+
return False
|
113 |
+
|
114 |
+
def embed_utterance(self, audio, sr=16000):
|
115 |
+
"""Extract speaker embedding from audio"""
|
116 |
+
if not self.model_loaded:
|
117 |
+
raise ValueError("Model not loaded. Call load_model() first.")
|
118 |
+
|
119 |
+
try:
|
120 |
+
# Convert numpy array to torch tensor
|
121 |
+
if isinstance(audio, np.ndarray):
|
122 |
+
waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
|
123 |
+
else:
|
124 |
+
waveform = audio.unsqueeze(0)
|
125 |
+
|
126 |
+
# Ensure sample rate matches model expected rate
|
127 |
+
if sr != 16000:
|
128 |
+
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
|
129 |
+
|
130 |
+
# Get embedding
|
131 |
+
with torch.no_grad():
|
132 |
+
embedding = self.model.encode_batch(waveform)
|
133 |
+
|
134 |
+
return embedding.squeeze().cpu().numpy()
|
135 |
+
except Exception as e:
|
136 |
+
print(f"Error extracting embedding: {e}")
|
137 |
+
return np.zeros(self.embedding_dim)
|
138 |
+
|
139 |
+
|
140 |
+
class AudioProcessor:
|
141 |
+
"""Processes audio data to extract speaker embeddings"""
|
142 |
+
def __init__(self, encoder):
|
143 |
+
self.encoder = encoder
|
144 |
+
|
145 |
+
def extract_embedding(self, audio_int16):
|
146 |
+
try:
|
147 |
+
# Convert int16 audio data to float32
|
148 |
+
float_audio = audio_int16.astype(np.float32) / 32768.0
|
149 |
+
|
150 |
+
# Normalize if needed
|
151 |
+
if np.abs(float_audio).max() > 1.0:
|
152 |
+
float_audio = float_audio / np.abs(float_audio).max()
|
153 |
+
|
154 |
+
# Extract embedding using the loaded encoder
|
155 |
+
embedding = self.encoder.embed_utterance(float_audio)
|
156 |
+
|
157 |
+
return embedding
|
158 |
+
except Exception as e:
|
159 |
+
print(f"Embedding extraction error: {e}")
|
160 |
+
return np.zeros(self.encoder.embedding_dim)
|
161 |
+
|
162 |
+
|
163 |
+
class EncoderLoaderThread(QThread):
|
164 |
+
"""Thread for loading the speaker encoder model"""
|
165 |
+
model_loaded = pyqtSignal(object)
|
166 |
+
progress_update = pyqtSignal(str)
|
167 |
+
|
168 |
+
def run(self):
|
169 |
+
try:
|
170 |
+
self.progress_update.emit("Initializing speaker encoder model...")
|
171 |
+
|
172 |
+
# Check device
|
173 |
+
device_str = "cuda" if torch.cuda.is_available() else "cpu"
|
174 |
+
self.progress_update.emit(f"Using device: {device_str}")
|
175 |
+
|
176 |
+
# Create SpeechBrain encoder
|
177 |
+
self.progress_update.emit("Loading ECAPA-TDNN model...")
|
178 |
+
encoder = SpeechBrainEncoder(device=device_str)
|
179 |
+
|
180 |
+
# Load the model
|
181 |
+
success = encoder.load_model()
|
182 |
+
|
183 |
+
if success:
|
184 |
+
self.progress_update.emit("ECAPA-TDNN model loading complete!")
|
185 |
+
self.model_loaded.emit(encoder)
|
186 |
+
else:
|
187 |
+
self.progress_update.emit("Failed to load ECAPA-TDNN model. Using fallback...")
|
188 |
+
self.model_loaded.emit(None)
|
189 |
+
except Exception as e:
|
190 |
+
self.progress_update.emit(f"Model loading error: {e}")
|
191 |
+
self.model_loaded.emit(None)
|
192 |
+
|
193 |
+
|
194 |
+
class SpeakerChangeDetector:
|
195 |
+
"""Modified speaker change detector that supports a configurable number of speakers"""
|
196 |
+
def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
|
197 |
+
self.embedding_dim = embedding_dim
|
198 |
+
self.change_threshold = change_threshold
|
199 |
+
self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS) # Ensure we don't exceed absolute max
|
200 |
+
self.current_speaker = 0 # Initial speaker (0 to max_speakers-1)
|
201 |
+
self.previous_embeddings = []
|
202 |
+
self.last_change_time = time.time()
|
203 |
+
self.mean_embeddings = [None] * self.max_speakers # Mean embeddings for each speaker
|
204 |
+
self.speaker_embeddings = [[] for _ in range(self.max_speakers)] # All embeddings for each speaker
|
205 |
+
self.last_similarity = 0.0
|
206 |
+
self.active_speakers = set([0]) # Track which speakers have been detected
|
207 |
+
|
208 |
+
def set_max_speakers(self, max_speakers):
|
209 |
+
"""Update the maximum number of speakers"""
|
210 |
+
new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
|
211 |
+
|
212 |
+
# If reducing the number of speakers
|
213 |
+
if new_max < self.max_speakers:
|
214 |
+
# Remove any speakers beyond the new max
|
215 |
+
for speaker_id in list(self.active_speakers):
|
216 |
+
if speaker_id >= new_max:
|
217 |
+
self.active_speakers.discard(speaker_id)
|
218 |
+
|
219 |
+
# Ensure current speaker is valid
|
220 |
+
if self.current_speaker >= new_max:
|
221 |
+
self.current_speaker = 0
|
222 |
+
|
223 |
+
# Expand arrays if increasing max speakers
|
224 |
+
if new_max > self.max_speakers:
|
225 |
+
# Extend mean_embeddings array
|
226 |
+
self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
|
227 |
+
|
228 |
+
# Extend speaker_embeddings array
|
229 |
+
self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
|
230 |
+
|
231 |
+
# Truncate arrays if decreasing max speakers
|
232 |
+
else:
|
233 |
+
self.mean_embeddings = self.mean_embeddings[:new_max]
|
234 |
+
self.speaker_embeddings = self.speaker_embeddings[:new_max]
|
235 |
+
|
236 |
+
self.max_speakers = new_max
|
237 |
+
|
238 |
+
def set_change_threshold(self, threshold):
|
239 |
+
"""Update the threshold for detecting speaker changes"""
|
240 |
+
self.change_threshold = max(0.1, min(threshold, 0.99))
|
241 |
+
|
242 |
+
def add_embedding(self, embedding, timestamp=None):
|
243 |
+
"""Add a new embedding and check if there's a speaker change"""
|
244 |
+
current_time = timestamp or time.time()
|
245 |
+
|
246 |
+
# Initialize first speaker if no embeddings yet
|
247 |
+
if not self.previous_embeddings:
|
248 |
+
self.previous_embeddings.append(embedding)
|
249 |
+
self.speaker_embeddings[self.current_speaker].append(embedding)
|
250 |
+
if self.mean_embeddings[self.current_speaker] is None:
|
251 |
+
self.mean_embeddings[self.current_speaker] = embedding.copy()
|
252 |
+
return self.current_speaker, 1.0
|
253 |
+
|
254 |
+
# Calculate similarity with current speaker's mean embedding
|
255 |
+
current_mean = self.mean_embeddings[self.current_speaker]
|
256 |
+
if current_mean is not None:
|
257 |
+
similarity = 1.0 - cosine(embedding, current_mean)
|
258 |
+
else:
|
259 |
+
# If no mean yet, compare with most recent embedding
|
260 |
+
similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
|
261 |
+
|
262 |
+
self.last_similarity = similarity
|
263 |
+
|
264 |
+
# Decide if this is a speaker change
|
265 |
+
time_since_last_change = current_time - self.last_change_time
|
266 |
+
is_speaker_change = False
|
267 |
+
|
268 |
+
# Only consider change if minimum time has passed since last change
|
269 |
+
if time_since_last_change >= MIN_SEGMENT_DURATION:
|
270 |
+
# Check similarity against threshold
|
271 |
+
if similarity < self.change_threshold:
|
272 |
+
# Compare with all other speakers' means if available
|
273 |
+
best_speaker = self.current_speaker
|
274 |
+
best_similarity = similarity
|
275 |
+
|
276 |
+
# Check each active speaker
|
277 |
+
for speaker_id in range(self.max_speakers):
|
278 |
+
if speaker_id == self.current_speaker:
|
279 |
+
continue
|
280 |
+
|
281 |
+
speaker_mean = self.mean_embeddings[speaker_id]
|
282 |
+
|
283 |
+
if speaker_mean is not None:
|
284 |
+
# Calculate similarity with this speaker
|
285 |
+
speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
|
286 |
+
|
287 |
+
# If more similar to this speaker, update best match
|
288 |
+
if speaker_similarity > best_similarity:
|
289 |
+
best_similarity = speaker_similarity
|
290 |
+
best_speaker = speaker_id
|
291 |
+
|
292 |
+
# If best match is different from current speaker, change speaker
|
293 |
+
if best_speaker != self.current_speaker:
|
294 |
+
is_speaker_change = True
|
295 |
+
self.current_speaker = best_speaker
|
296 |
+
# If no good match with existing speakers and we haven't used all speakers yet
|
297 |
+
elif len(self.active_speakers) < self.max_speakers:
|
298 |
+
# Find the next unused speaker ID
|
299 |
+
for new_id in range(self.max_speakers):
|
300 |
+
if new_id not in self.active_speakers:
|
301 |
+
is_speaker_change = True
|
302 |
+
self.current_speaker = new_id
|
303 |
+
self.active_speakers.add(new_id)
|
304 |
+
break
|
305 |
+
|
306 |
+
# Handle speaker change
|
307 |
+
if is_speaker_change:
|
308 |
+
self.last_change_time = current_time
|
309 |
+
|
310 |
+
# Update embeddings
|
311 |
+
self.previous_embeddings.append(embedding)
|
312 |
+
if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
|
313 |
+
self.previous_embeddings.pop(0)
|
314 |
+
|
315 |
+
# Update current speaker's embeddings and mean
|
316 |
+
self.speaker_embeddings[self.current_speaker].append(embedding)
|
317 |
+
self.active_speakers.add(self.current_speaker)
|
318 |
+
|
319 |
+
if len(self.speaker_embeddings[self.current_speaker]) > 30: # Limit history size
|
320 |
+
self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
|
321 |
+
|
322 |
+
# Update mean embedding for current speaker
|
323 |
+
if self.speaker_embeddings[self.current_speaker]:
|
324 |
+
self.mean_embeddings[self.current_speaker] = np.mean(
|
325 |
+
self.speaker_embeddings[self.current_speaker], axis=0
|
326 |
+
)
|
327 |
+
|
328 |
+
return self.current_speaker, similarity
|
329 |
+
|
330 |
+
def get_color_for_speaker(self, speaker_id):
|
331 |
+
"""Return color for speaker ID (0 to max_speakers-1)"""
|
332 |
+
if 0 <= speaker_id < len(SPEAKER_COLORS):
|
333 |
+
return SPEAKER_COLORS[speaker_id]
|
334 |
+
return "#FFFFFF" # Default to white if out of range
|
335 |
+
|
336 |
+
def get_status_info(self):
|
337 |
+
"""Return status information about the speaker change detector"""
|
338 |
+
speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
|
339 |
+
|
340 |
+
return {
|
341 |
+
"current_speaker": self.current_speaker,
|
342 |
+
"speaker_counts": speaker_counts,
|
343 |
+
"active_speakers": len(self.active_speakers),
|
344 |
+
"max_speakers": self.max_speakers,
|
345 |
+
"last_similarity": self.last_similarity,
|
346 |
+
"threshold": self.change_threshold
|
347 |
+
}
|
348 |
+
|
349 |
+
|
350 |
+
class TextUpdateThread(QThread):
|
351 |
+
text_update_signal = pyqtSignal(str)
|
352 |
+
|
353 |
+
def __init__(self, text):
|
354 |
+
super().__init__()
|
355 |
+
self.text = text
|
356 |
+
|
357 |
+
def run(self):
|
358 |
+
self.text_update_signal.emit(self.text)
|
359 |
+
|
360 |
+
|
361 |
+
class SentenceWorker(QThread):
|
362 |
+
sentence_update_signal = pyqtSignal(list, list)
|
363 |
+
status_signal = pyqtSignal(str)
|
364 |
+
|
365 |
+
def __init__(self, queue, encoder, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
|
366 |
+
super().__init__()
|
367 |
+
self.queue = queue
|
368 |
+
self.encoder = encoder
|
369 |
+
self._is_running = True
|
370 |
+
self.full_sentences = []
|
371 |
+
self.sentence_speakers = []
|
372 |
+
self.change_threshold = change_threshold
|
373 |
+
self.max_speakers = max_speakers
|
374 |
+
|
375 |
+
# Initialize audio processor for embedding extraction
|
376 |
+
self.audio_processor = AudioProcessor(self.encoder)
|
377 |
+
|
378 |
+
# Initialize speaker change detector
|
379 |
+
self.speaker_detector = SpeakerChangeDetector(
|
380 |
+
embedding_dim=self.encoder.embedding_dim,
|
381 |
+
change_threshold=self.change_threshold,
|
382 |
+
max_speakers=self.max_speakers
|
383 |
+
)
|
384 |
+
|
385 |
+
# Setup monitoring timer
|
386 |
+
self.monitoring_timer = QTimer()
|
387 |
+
self.monitoring_timer.timeout.connect(self.report_status)
|
388 |
+
self.monitoring_timer.start(2000) # Report every 2 seconds
|
389 |
+
|
390 |
+
def set_change_threshold(self, threshold):
|
391 |
+
"""Update change detection threshold"""
|
392 |
+
self.change_threshold = threshold
|
393 |
+
self.speaker_detector.set_change_threshold(threshold)
|
394 |
+
|
395 |
+
def set_max_speakers(self, max_speakers):
|
396 |
+
"""Update maximum number of speakers"""
|
397 |
+
self.max_speakers = max_speakers
|
398 |
+
self.speaker_detector.set_max_speakers(max_speakers)
|
399 |
+
|
400 |
+
def run(self):
|
401 |
+
"""Main worker thread loop"""
|
402 |
+
while self._is_running:
|
403 |
+
try:
|
404 |
+
text, bytes = self.queue.get(timeout=1)
|
405 |
+
self.process_item(text, bytes)
|
406 |
+
except queue.Empty:
|
407 |
+
continue
|
408 |
+
|
409 |
+
def report_status(self):
|
410 |
+
"""Report status information"""
|
411 |
+
# Get status information from speaker detector
|
412 |
+
status = self.speaker_detector.get_status_info()
|
413 |
+
|
414 |
+
# Prepare status message with information for all speakers
|
415 |
+
status_text = f"Current speaker: {status['current_speaker'] + 1}\n"
|
416 |
+
status_text += f"Active speakers: {status['active_speakers']} of {status['max_speakers']}\n"
|
417 |
+
|
418 |
+
# Show segment counts for each speaker
|
419 |
+
for i in range(status['max_speakers']):
|
420 |
+
if i < len(SPEAKER_COLOR_NAMES):
|
421 |
+
color_name = SPEAKER_COLOR_NAMES[i]
|
422 |
+
else:
|
423 |
+
color_name = f"Speaker {i+1}"
|
424 |
+
status_text += f"Speaker {i+1} ({color_name}) segments: {status['speaker_counts'][i]}\n"
|
425 |
+
|
426 |
+
status_text += f"Last similarity score: {status['last_similarity']:.3f}\n"
|
427 |
+
status_text += f"Change threshold: {status['threshold']:.2f}\n"
|
428 |
+
status_text += f"Total sentences: {len(self.full_sentences)}"
|
429 |
+
|
430 |
+
# Send to UI
|
431 |
+
self.status_signal.emit(status_text)
|
432 |
+
|
433 |
+
def process_item(self, text, bytes):
|
434 |
+
"""Process a new text-audio pair"""
|
435 |
+
# Convert audio data to int16
|
436 |
+
audio_int16 = np.int16(bytes * 32767)
|
437 |
+
|
438 |
+
# Extract speaker embedding
|
439 |
+
speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
|
440 |
+
|
441 |
+
# Store sentence and embedding
|
442 |
+
self.full_sentences.append((text, speaker_embedding))
|
443 |
+
|
444 |
+
# Fill in any missing speaker assignments
|
445 |
+
if len(self.sentence_speakers) < len(self.full_sentences) - 1:
|
446 |
+
while len(self.sentence_speakers) < len(self.full_sentences) - 1:
|
447 |
+
self.sentence_speakers.append(0) # Default to first speaker
|
448 |
+
|
449 |
+
# Detect speaker changes
|
450 |
+
speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
|
451 |
+
self.sentence_speakers.append(speaker_id)
|
452 |
+
|
453 |
+
# Send updated data to UI
|
454 |
+
self.sentence_update_signal.emit(self.full_sentences, self.sentence_speakers)
|
455 |
+
|
456 |
+
def stop(self):
|
457 |
+
"""Stop the worker thread"""
|
458 |
+
self._is_running = False
|
459 |
+
if self.monitoring_timer.isActive():
|
460 |
+
self.monitoring_timer.stop()
|
461 |
+
|
462 |
+
|
463 |
+
class RecordingThread(QThread):
|
464 |
+
def __init__(self, recorder):
|
465 |
+
super().__init__()
|
466 |
+
self.recorder = recorder
|
467 |
+
self._is_running = True
|
468 |
+
|
469 |
+
# Determine input source
|
470 |
+
if USE_MICROPHONE:
|
471 |
+
self.device_id = str(sc.default_microphone().name)
|
472 |
+
self.include_loopback = False
|
473 |
+
else:
|
474 |
+
self.device_id = str(sc.default_speaker().name)
|
475 |
+
self.include_loopback = True
|
476 |
+
|
477 |
+
def updateDevice(self, device_id, include_loopback):
|
478 |
+
self.device_id = device_id
|
479 |
+
self.include_loopback = include_loopback
|
480 |
+
|
481 |
+
def run(self):
|
482 |
+
while self._is_running:
|
483 |
+
try:
|
484 |
+
with sc.get_microphone(id=self.device_id, include_loopback=self.include_loopback).recorder(
|
485 |
+
samplerate=SAMPLE_RATE, blocksize=BUFFER_SIZE
|
486 |
+
) as mic:
|
487 |
+
# Process audio chunks while device hasn't changed
|
488 |
+
current_device = self.device_id
|
489 |
+
current_loopback = self.include_loopback
|
490 |
+
|
491 |
+
while self._is_running and current_device == self.device_id and current_loopback == self.include_loopback:
|
492 |
+
# Record audio chunk
|
493 |
+
audio_data = mic.record(numframes=BUFFER_SIZE)
|
494 |
+
|
495 |
+
# Convert stereo to mono if needed
|
496 |
+
if audio_data.shape[1] > 1 and CHANNELS == 1:
|
497 |
+
audio_data = audio_data[:, 0]
|
498 |
+
|
499 |
+
# Convert to int16
|
500 |
+
audio_int16 = (audio_data.flatten() * 32767).astype(np.int16)
|
501 |
+
|
502 |
+
# Feed to recorder
|
503 |
+
audio_bytes = audio_int16.tobytes()
|
504 |
+
self.recorder.feed_audio(audio_bytes)
|
505 |
+
|
506 |
+
except Exception as e:
|
507 |
+
print(f"Recording error: {e}")
|
508 |
+
# Wait before retry on error
|
509 |
+
time.sleep(1)
|
510 |
+
|
511 |
+
def stop(self):
|
512 |
+
self._is_running = False
|
513 |
+
|
514 |
+
|
515 |
+
class TextRetrievalThread(QThread):
|
516 |
+
textRetrievedFinal = pyqtSignal(str, np.ndarray)
|
517 |
+
textRetrievedLive = pyqtSignal(str)
|
518 |
+
recorderStarted = pyqtSignal()
|
519 |
+
|
520 |
+
def __init__(self):
|
521 |
+
super().__init__()
|
522 |
+
|
523 |
+
def live_text_detected(self, text):
|
524 |
+
self.textRetrievedLive.emit(text)
|
525 |
+
|
526 |
+
def run(self):
|
527 |
+
recorder_config = {
|
528 |
+
'spinner': False,
|
529 |
+
'use_microphone': False,
|
530 |
+
'model': FINAL_TRANSCRIPTION_MODEL,
|
531 |
+
'language': TRANSCRIPTION_LANGUAGE,
|
532 |
+
'silero_sensitivity': SILERO_SENSITIVITY,
|
533 |
+
'webrtc_sensitivity': WEBRTC_SENSITIVITY,
|
534 |
+
'post_speech_silence_duration': SILENCE_THRESHS[1],
|
535 |
+
'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
|
536 |
+
'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
|
537 |
+
'min_gap_between_recordings': 0,
|
538 |
+
'enable_realtime_transcription': True,
|
539 |
+
'realtime_processing_pause': 0,
|
540 |
+
'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
|
541 |
+
'on_realtime_transcription_update': self.live_text_detected,
|
542 |
+
'beam_size': FINAL_BEAM_SIZE,
|
543 |
+
'beam_size_realtime': REALTIME_BEAM_SIZE,
|
544 |
+
'buffer_size': BUFFER_SIZE,
|
545 |
+
'sample_rate': SAMPLE_RATE,
|
546 |
+
}
|
547 |
+
|
548 |
+
self.recorder = AudioToTextRecorder(**recorder_config)
|
549 |
+
self.recorderStarted.emit()
|
550 |
+
|
551 |
+
def process_text(text):
|
552 |
+
bytes = self.recorder.last_transcription_bytes
|
553 |
+
self.textRetrievedFinal.emit(text, bytes)
|
554 |
+
|
555 |
+
while True:
|
556 |
+
self.recorder.text(process_text)
|
557 |
+
|
558 |
+
|
559 |
+
class MainWindow(QMainWindow):
|
560 |
+
def __init__(self):
|
561 |
+
super().__init__()
|
562 |
+
|
563 |
+
self.setWindowTitle("Real-time Speaker Change Detection")
|
564 |
+
|
565 |
+
self.encoder = None
|
566 |
+
self.initialized = False
|
567 |
+
self.displayed_text = ""
|
568 |
+
self.last_realtime_text = ""
|
569 |
+
self.full_sentences = []
|
570 |
+
self.sentence_speakers = []
|
571 |
+
self.pending_sentences = []
|
572 |
+
self.queue = queue.Queue()
|
573 |
+
self.recording_thread = None
|
574 |
+
self.change_threshold = DEFAULT_CHANGE_THRESHOLD
|
575 |
+
self.max_speakers = DEFAULT_MAX_SPEAKERS
|
576 |
+
|
577 |
+
# Create main horizontal layout
|
578 |
+
self.mainLayout = QHBoxLayout()
|
579 |
+
|
580 |
+
# Add text edit area to main layout
|
581 |
+
self.text_edit = QTextEdit(self)
|
582 |
+
self.mainLayout.addWidget(self.text_edit, 1)
|
583 |
+
|
584 |
+
# Create right layout for controls
|
585 |
+
self.rightLayout = QVBoxLayout()
|
586 |
+
self.rightLayout.setAlignment(Qt.AlignmentFlag.AlignTop)
|
587 |
+
|
588 |
+
# Create all controls
|
589 |
+
self.create_controls()
|
590 |
+
|
591 |
+
# Create container for right layout
|
592 |
+
self.rightContainer = QWidget()
|
593 |
+
self.rightContainer.setLayout(self.rightLayout)
|
594 |
+
self.mainLayout.addWidget(self.rightContainer, 0)
|
595 |
+
|
596 |
+
# Set main layout as central widget
|
597 |
+
self.centralWidget = QWidget()
|
598 |
+
self.centralWidget.setLayout(self.mainLayout)
|
599 |
+
self.setCentralWidget(self.centralWidget)
|
600 |
+
|
601 |
+
self.setStyleSheet("""
|
602 |
+
QGroupBox {
|
603 |
+
border: 1px solid #555;
|
604 |
+
border-radius: 3px;
|
605 |
+
margin-top: 10px;
|
606 |
+
padding-top: 10px;
|
607 |
+
color: #ddd;
|
608 |
+
}
|
609 |
+
QGroupBox::title {
|
610 |
+
subcontrol-origin: margin;
|
611 |
+
subcontrol-position: top center;
|
612 |
+
padding: 0 5px;
|
613 |
+
}
|
614 |
+
QLabel {
|
615 |
+
color: #ddd;
|
616 |
+
}
|
617 |
+
QPushButton {
|
618 |
+
background: #444;
|
619 |
+
color: #ddd;
|
620 |
+
border: 1px solid #555;
|
621 |
+
padding: 5px;
|
622 |
+
margin-bottom: 10px;
|
623 |
+
}
|
624 |
+
QPushButton:hover {
|
625 |
+
background: #555;
|
626 |
+
}
|
627 |
+
QTextEdit {
|
628 |
+
background-color: #1e1e1e;
|
629 |
+
color: #ffffff;
|
630 |
+
font-family: 'Arial';
|
631 |
+
font-size: 16pt;
|
632 |
+
}
|
633 |
+
QSlider {
|
634 |
+
height: 30px;
|
635 |
+
}
|
636 |
+
QSlider::groove:horizontal {
|
637 |
+
height: 8px;
|
638 |
+
background: #333;
|
639 |
+
margin: 2px 0;
|
640 |
+
}
|
641 |
+
QSlider::handle:horizontal {
|
642 |
+
background: #666;
|
643 |
+
border: 1px solid #777;
|
644 |
+
width: 18px;
|
645 |
+
margin: -8px 0;
|
646 |
+
border-radius: 9px;
|
647 |
+
}
|
648 |
+
""")
|
649 |
+
|
650 |
+
def create_controls(self):
|
651 |
+
# Speaker change threshold control
|
652 |
+
self.threshold_group = QGroupBox("Speaker Change Sensitivity")
|
653 |
+
threshold_layout = QVBoxLayout()
|
654 |
+
|
655 |
+
self.threshold_label = QLabel(f"Change threshold: {self.change_threshold:.2f}")
|
656 |
+
threshold_layout.addWidget(self.threshold_label)
|
657 |
+
|
658 |
+
self.threshold_slider = QSlider(Qt.Orientation.Horizontal)
|
659 |
+
self.threshold_slider.setMinimum(10)
|
660 |
+
self.threshold_slider.setMaximum(95)
|
661 |
+
self.threshold_slider.setValue(int(self.change_threshold * 100))
|
662 |
+
self.threshold_slider.valueChanged.connect(self.update_threshold)
|
663 |
+
threshold_layout.addWidget(self.threshold_slider)
|
664 |
+
|
665 |
+
self.threshold_explanation = QLabel(
|
666 |
+
"If the speakers have similar voices, it would be better to set it above 0.5, and if they have different voices, it would be lower."
|
667 |
+
)
|
668 |
+
self.threshold_explanation.setWordWrap(True)
|
669 |
+
threshold_layout.addWidget(self.threshold_explanation)
|
670 |
+
|
671 |
+
self.threshold_group.setLayout(threshold_layout)
|
672 |
+
self.rightLayout.addWidget(self.threshold_group)
|
673 |
+
|
674 |
+
# Max speakers control
|
675 |
+
self.max_speakers_group = QGroupBox("Maximum Number of Speakers")
|
676 |
+
max_speakers_layout = QVBoxLayout()
|
677 |
+
|
678 |
+
self.max_speakers_label = QLabel(f"Max speakers: {self.max_speakers}")
|
679 |
+
max_speakers_layout.addWidget(self.max_speakers_label)
|
680 |
+
|
681 |
+
self.max_speakers_spinbox = QSpinBox()
|
682 |
+
self.max_speakers_spinbox.setMinimum(2)
|
683 |
+
self.max_speakers_spinbox.setMaximum(ABSOLUTE_MAX_SPEAKERS)
|
684 |
+
self.max_speakers_spinbox.setValue(self.max_speakers)
|
685 |
+
self.max_speakers_spinbox.valueChanged.connect(self.update_max_speakers)
|
686 |
+
max_speakers_layout.addWidget(self.max_speakers_spinbox)
|
687 |
+
|
688 |
+
self.max_speakers_explanation = QLabel(
|
689 |
+
f"You can set between 2 and {ABSOLUTE_MAX_SPEAKERS} speakers.\n"
|
690 |
+
"Changes will apply immediately."
|
691 |
+
)
|
692 |
+
self.max_speakers_explanation.setWordWrap(True)
|
693 |
+
max_speakers_layout.addWidget(self.max_speakers_explanation)
|
694 |
+
|
695 |
+
self.max_speakers_group.setLayout(max_speakers_layout)
|
696 |
+
self.rightLayout.addWidget(self.max_speakers_group)
|
697 |
+
|
698 |
+
# Speaker color legend - dynamic based on max speakers
|
699 |
+
self.legend_group = QGroupBox("Speaker Colors")
|
700 |
+
self.legend_layout = QVBoxLayout()
|
701 |
+
|
702 |
+
# Create speaker labels dynamically
|
703 |
+
self.speaker_labels = []
|
704 |
+
for i in range(ABSOLUTE_MAX_SPEAKERS):
|
705 |
+
color = SPEAKER_COLORS[i]
|
706 |
+
color_name = SPEAKER_COLOR_NAMES[i]
|
707 |
+
label = QLabel(f"Speaker {i+1} ({color_name}): <span style='color:{color};'>■■■■■</span>")
|
708 |
+
self.speaker_labels.append(label)
|
709 |
+
if i < self.max_speakers:
|
710 |
+
self.legend_layout.addWidget(label)
|
711 |
+
|
712 |
+
self.legend_group.setLayout(self.legend_layout)
|
713 |
+
self.rightLayout.addWidget(self.legend_group)
|
714 |
+
|
715 |
+
# Status display area
|
716 |
+
self.status_group = QGroupBox("Status")
|
717 |
+
status_layout = QVBoxLayout()
|
718 |
+
|
719 |
+
self.status_label = QLabel("Status information will be displayed here.")
|
720 |
+
self.status_label.setWordWrap(True)
|
721 |
+
status_layout.addWidget(self.status_label)
|
722 |
+
|
723 |
+
self.status_group.setLayout(status_layout)
|
724 |
+
self.rightLayout.addWidget(self.status_group)
|
725 |
+
|
726 |
+
# Clear button
|
727 |
+
self.clear_button = QPushButton("Clear Conversation")
|
728 |
+
self.clear_button.clicked.connect(self.clear_state)
|
729 |
+
self.clear_button.setEnabled(False)
|
730 |
+
self.rightLayout.addWidget(self.clear_button)
|
731 |
+
|
732 |
+
    def update_threshold(self, value):
        """Update speaker change detection threshold"""
        threshold = value / 100.0
        self.change_threshold = threshold
        self.threshold_label.setText(f"Change threshold: {threshold:.2f}")

        # Update in worker if it exists
        if hasattr(self, 'worker_thread'):
            self.worker_thread.set_change_threshold(threshold)

    def update_max_speakers(self, value):
        """Update maximum number of speakers"""
        self.max_speakers = value
        self.max_speakers_label.setText(f"Max speakers: {value}")

        # Update visible speaker labels
        self.update_speaker_labels()

        # Update in worker if it exists
        if hasattr(self, 'worker_thread'):
            self.worker_thread.set_max_speakers(value)
    def update_speaker_labels(self):
        """Update which speaker labels are visible based on max_speakers"""
        # Clear all labels first
        for i in range(len(self.speaker_labels)):
            label = self.speaker_labels[i]
            if label.parent():
                self.legend_layout.removeWidget(label)
                label.setParent(None)

        # Add only the labels for the current max_speakers
        for i in range(min(self.max_speakers, len(self.speaker_labels))):
            self.legend_layout.addWidget(self.speaker_labels[i])
    def clear_state(self):
        # Clear text edit area
        self.text_edit.clear()

        # Reset state variables
        self.displayed_text = ""
        self.last_realtime_text = ""
        self.full_sentences = []
        self.sentence_speakers = []
        self.pending_sentences = []

        if hasattr(self, 'worker_thread'):
            self.worker_thread.full_sentences = []
            self.worker_thread.sentence_speakers = []
            # Reset speaker detector with current threshold and max_speakers
            self.worker_thread.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers
            )

        # Display message
        self.text_edit.setHtml("<i>All content cleared. Waiting for new input...</i>")

    def update_status(self, status_text):
        self.status_label.setText(status_text)
    def showEvent(self, event):
        super().showEvent(event)
        if event.type() == QEvent.Type.Show:
            if not self.initialized:
                self.initialized = True
                self.resize(1200, 800)
                self.update_text("<i>Initializing application...</i>")

                QTimer.singleShot(500, self.init)
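    # QTimer.singleShot(500, self.init) defers the heavy initialization until
    # shortly after the first show event, so the window can become visible
    # before the ECAPA-TDNN encoder starts loading in the background.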
    def process_live_text(self, text):
        text = text.strip()

        if text:
            sentence_delimiters = '.?!。'
            prob_sentence_end = (
                len(self.last_realtime_text) > 0
                and text[-1] in sentence_delimiters
                and self.last_realtime_text[-1] in sentence_delimiters
            )

            self.last_realtime_text = text

            if prob_sentence_end:
                if FAST_SENTENCE_END:
                    self.text_retrieval_thread.recorder.stop()
                else:
                    self.text_retrieval_thread.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
            else:
                self.text_retrieval_thread.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]

        self.text_detected(text)
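    # Sentence-end heuristic above: a boundary is only assumed when two
    # consecutive realtime updates both end in a delimiter ('.', '?', '!', '。'),
    # e.g. the realtime text arriving as "How are you?" twice in a row. A single
    # trailing "." can belong to an abbreviation or number, so on a probable
    # boundary the recorder is either stopped outright (FAST_SENTENCE_END) or
    # given the short silence threshold SILENCE_THRESHS[0]; otherwise the longer
    # SILENCE_THRESHS[1] is kept.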
    def text_detected(self, text):
        try:
            sentences_with_style = []
            for i, sentence in enumerate(self.full_sentences):
                sentence_text, _ = sentence
                if i >= len(self.sentence_speakers):
                    color = "#FFFFFF"  # Default white
                else:
                    speaker_id = self.sentence_speakers[i]
                    color = self.worker_thread.speaker_detector.get_color_for_speaker(speaker_id)

                sentences_with_style.append(
                    f'<span style="color:{color};">{sentence_text}</span>')

            for pending_sentence in self.pending_sentences:
                sentences_with_style.append(
                    f'<span style="color:#60FFFF;">{pending_sentence}</span>')

            new_text = " ".join(sentences_with_style).strip() + " " + text if len(sentences_with_style) > 0 else text

            if new_text != self.displayed_text:
                self.displayed_text = new_text
                self.update_text(new_text)
        except Exception as e:
            print(f"Error: {e}")
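    # Rendering convention in text_detected: finalized sentences are wrapped in
    # a <span> using their assigned speaker color, sentences still awaiting a
    # speaker decision are shown in cyan (#60FFFF), and the current realtime
    # fragment is appended unstyled at the end of the transcript.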
    def process_final(self, text, audio_bytes):
        text = text.strip()
        if text:
            try:
                self.pending_sentences.append(text)
                self.queue.put((text, audio_bytes))
            except Exception as e:
                print(f"Error: {e}")
    def capture_output_and_feed_to_recorder(self):
        # Use default device settings
        device_id = str(sc.default_speaker().name)
        include_loopback = True

        self.recording_thread = RecordingThread(self.text_retrieval_thread.recorder)
        # Update with current device settings
        self.recording_thread.updateDevice(device_id, include_loopback)
        self.recording_thread.start()

    def recorder_ready(self):
        self.update_text("<i>Recording ready</i>")
        self.capture_output_and_feed_to_recorder()
    def init(self):
        self.update_text("<i>Loading ECAPA-TDNN model... Please wait.</i>")

        # Start model loading in background thread
        self.start_encoder()

    def update_loading_status(self, message):
        self.update_text(f"<i>{message}</i>")

    def start_encoder(self):
        # Create and start encoder loader thread
        self.encoder_loader_thread = EncoderLoaderThread()
        self.encoder_loader_thread.model_loaded.connect(self.on_model_loaded)
        self.encoder_loader_thread.progress_update.connect(self.update_loading_status)
        self.encoder_loader_thread.start()
    def on_model_loaded(self, encoder):
        # Store loaded encoder model
        self.encoder = encoder

        if self.encoder is None:
            self.update_text("<i>Failed to load ECAPA-TDNN model. Please check your configuration.</i>")
            return

        # Enable all controls after model is loaded
        self.clear_button.setEnabled(True)
        self.threshold_slider.setEnabled(True)

        # Continue initialization
        self.update_text("<i>ECAPA-TDNN model loaded. Starting recorder...</i>")

        self.text_retrieval_thread = TextRetrievalThread()
        self.text_retrieval_thread.recorderStarted.connect(
            self.recorder_ready)
        self.text_retrieval_thread.textRetrievedLive.connect(
            self.process_live_text)
        self.text_retrieval_thread.textRetrievedFinal.connect(
            self.process_final)
        self.text_retrieval_thread.start()

        self.worker_thread = SentenceWorker(
            self.queue,
            self.encoder,
            change_threshold=self.change_threshold,
            max_speakers=self.max_speakers
        )
        self.worker_thread.sentence_update_signal.connect(
            self.sentence_updated)
        self.worker_thread.status_signal.connect(
            self.update_status)
        self.worker_thread.start()
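    # Rough data flow once the model is loaded: TextRetrievalThread emits live
    # text (process_live_text) and finalized sentences with audio
    # (process_final); finals are queued for SentenceWorker, which assigns a
    # speaker to each sentence and signals sentence_updated to recolor the
    # transcript.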
    def sentence_updated(self, full_sentences, sentence_speakers):
        self.pending_text = ""
        self.full_sentences = full_sentences
        self.sentence_speakers = sentence_speakers
        for sentence in self.full_sentences:
            sentence_text, _ = sentence
            if sentence_text in self.pending_sentences:
                self.pending_sentences.remove(sentence_text)
        self.text_detected("")
    def set_text(self, text):
        self.update_thread = TextUpdateThread(text)
        self.update_thread.text_update_signal.connect(self.update_text)
        self.update_thread.start()

    def update_text(self, text):
        self.text_edit.setHtml(text)
        self.text_edit.verticalScrollBar().setValue(
            self.text_edit.verticalScrollBar().maximum())
def main():
    app = QApplication(sys.argv)

    dark_stylesheet = """
    QMainWindow {
        background-color: #323232;
    }
    QTextEdit {
        background-color: #1e1e1e;
        color: #ffffff;
    }
    """
    app.setStyleSheet(dark_stylesheet)

    main_window = MainWindow()
    main_window.show()

    sys.exit(app.exec())


if __name__ == "__main__":
    main()