Saiyaswanth007 committed
Commit b37c0fc · 1 Parent(s): 33445e6

Code fixing

Files changed (2)
  1. app.py +178 -213
  2. realtime_diarize.py +970 -0
app.py CHANGED
@@ -8,18 +8,14 @@ import os
 import urllib.request
 import torchaudio
 from scipy.spatial.distance import cosine
-from RealtimeSTT import AudioToTextRecorder
 import json
 import io
 import wave
+from fastrtc import Stream, ReplyOnPause, AsyncStreamHandler, get_stt_model
 
 # Simplified configuration parameters
 SILENCE_THRESHS = [0, 0.4]
-FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
-FINAL_BEAM_SIZE = 5
-REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
-REALTIME_BEAM_SIZE = 5
-TRANSCRIPTION_LANGUAGE = "en"
+FINAL_TRANSCRIPTION_MODEL = "moonshine/base"  # Using FastRTC's moonshine model
 SILERO_SENSITIVITY = 0.4
 WEBRTC_SENSITIVITY = 3
 MIN_LENGTH_OF_RECORDING = 0.7
@@ -271,52 +267,56 @@ class SpeakerChangeDetector:
         }
 
 
-class WebRTCAudioProcessor:
-    """Processes WebRTC audio streams for speaker diarization"""
+class DiarizationStreamHandler(AsyncStreamHandler):
+    """FastRTC stream handler for real-time diarization"""
     def __init__(self, diarization_system):
+        super().__init__(input_sample_rate=16000)
         self.diarization_system = diarization_system
-        self.audio_buffer = []
-        self.buffer_lock = threading.Lock()
-        self.processing_thread = None
-        self.is_processing = False
+        self.stt_model = get_stt_model(model=FINAL_TRANSCRIPTION_MODEL)
+        self.current_text = ""
+        self.current_audio_buffer = []
+        self.transcript_queue = queue.Queue()
 
-    def process_audio(self, audio_data, sample_rate):
-        """Process incoming audio data from WebRTC"""
-        try:
-            # Convert audio data to numpy array if needed
-            if isinstance(audio_data, bytes):
-                audio_array = np.frombuffer(audio_data, dtype=np.int16)
-            elif isinstance(audio_data, tuple):
-                # Handle tuple format (sample_rate, audio_array)
-                sample_rate, audio_array = audio_data
-                if isinstance(audio_array, np.ndarray):
-                    if audio_array.dtype != np.int16:
-                        audio_array = (audio_array * 32767).astype(np.int16)
-                else:
-                    audio_array = np.array(audio_array, dtype=np.int16)
-            else:
-                audio_array = np.array(audio_data, dtype=np.int16)
-
-            # Ensure mono audio
-            if len(audio_array.shape) > 1:
-                audio_array = audio_array[:, 0]
-
-            # Add to buffer
-            with self.buffer_lock:
-                self.audio_buffer.extend(audio_array)
-
-                # Process buffer when it's large enough (1 second of audio)
-                if len(self.audio_buffer) >= sample_rate:
-                    buffer_to_process = np.array(self.audio_buffer[:sample_rate])
-                    self.audio_buffer = self.audio_buffer[sample_rate//2:]  # Keep 50% overlap
-
-                    # Feed to recorder in separate thread
-                    if self.diarization_system.recorder:
-                        audio_bytes = buffer_to_process.tobytes()
-                        self.diarization_system.recorder.feed_audio(audio_bytes)
-
-        except Exception as e:
-            print(f"Error processing WebRTC audio: {e}")
+    def copy(self):
+        return DiarizationStreamHandler(self.diarization_system)
+
+    async def start_up(self):
+        """Initialize the stream handler"""
+        pass
+
+    async def receive(self, frame):
+        """Process incoming audio frame"""
+        # Extract audio data
+        sample_rate, audio_data = frame
+
+        # Convert to numpy array if needed
+        if isinstance(audio_data, torch.Tensor):
+            audio_data = audio_data.numpy()
+
+        # Add to buffer
+        self.current_audio_buffer.append(audio_data)
+
+        # If buffer is large enough, process it
+        if len(self.current_audio_buffer) > 3:  # Process ~1.5 seconds of audio
+            # Concatenate audio data
+            combined_audio = np.concatenate(self.current_audio_buffer)
+
+            # Run speech-to-text
+            text = self.stt_model.stt((16000, combined_audio))
+
+            if text and text.strip():
+                # Save text and audio for processing
+                self.transcript_queue.put((text, combined_audio))
+                self.current_text = text
+
+            # Reset buffer but keep some overlap
+            if len(self.current_audio_buffer) > 5:
+                self.current_audio_buffer = self.current_audio_buffer[-2:]
+
+    async def emit(self):
+        """Emit processed data"""
+        # Return current text as dummy; actual processing is done in background
+        return self.current_text
 
 
 class RealtimeSpeakerDiarization:
@@ -324,8 +324,8 @@ class RealtimeSpeakerDiarization:
         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
-        self.recorder = None
-        self.webrtc_processor = None
+        self.stream = None
+        self.stream_handler = None
        self.sentence_queue = queue.Queue()
         self.full_sentences = []
         self.sentence_speakers = []
@@ -352,7 +352,6 @@ class RealtimeSpeakerDiarization:
                 change_threshold=self.change_threshold,
                 max_speakers=self.max_speakers
             )
-            self.webrtc_processor = WebRTCAudioProcessor(self)
             print("ECAPA-TDNN model loaded successfully!")
             return True
         else:
@@ -362,45 +361,69 @@
            print(f"Model initialization error: {e}")
            return False
 
-    def live_text_detected(self, text):
-        """Callback for real-time transcription updates"""
-        text = text.strip()
-        if text:
-            sentence_delimiters = '.?!。'
-            prob_sentence_end = (
-                len(self.last_realtime_text) > 0
-                and text[-1] in sentence_delimiters
-                and self.last_realtime_text[-1] in sentence_delimiters
-            )
-
-            self.last_realtime_text = text
-
-            if prob_sentence_end and FAST_SENTENCE_END:
-                self.recorder.stop()
-            elif prob_sentence_end:
-                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
-            else:
-                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
+    def start_stream(self, app):
+        """Start the FastRTC stream"""
+        if self.encoder is None:
+            return "Please initialize models first!"
+
+        try:
+            # Create a FastRTC stream handler
+            self.stream_handler = DiarizationStreamHandler(self)
+
+            # Create FastRTC stream
+            self.stream = Stream(
+                handler=self.stream_handler,
+                modality="audio",
+                mode="send-receive"
+            )
+
+            # Mount the stream to the provided FastAPI app
+            self.stream.mount(app)
+
+            # Start sentence processing thread
+            self.is_running = True
+            self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
+            self.sentence_thread.start()
+
+            # Start diarization processor thread
+            self.diarization_thread = threading.Thread(target=self.process_transcript_queue, daemon=True)
+            self.diarization_thread.start()
+
+            return "Stream started successfully! Ready for audio input."
+
+        except Exception as e:
+            return f"Error starting stream: {e}"
 
-    def process_final_text(self, text):
-        """Process final transcribed text with speaker embedding"""
-        text = text.strip()
-        if text:
-            try:
-                bytes_data = self.recorder.last_transcription_bytes
-                self.sentence_queue.put((text, bytes_data))
-                self.pending_sentences.append(text)
-            except Exception as e:
-                print(f"Error processing final text: {e}")
+    def process_transcript_queue(self):
+        """Process transcripts from the stream handler"""
+        while self.is_running:
+            try:
+                if self.stream_handler and not self.stream_handler.transcript_queue.empty():
+                    text, audio_data = self.stream_handler.transcript_queue.get(timeout=1)
+
+                    # Add to sentence queue for diarization
+                    self.pending_sentences.append(text)
+                    self.sentence_queue.put((text, audio_data))
+            except queue.Empty:
+                time.sleep(0.1)  # Short sleep to prevent CPU hogging
+            except Exception as e:
+                print(f"Error processing transcript queue: {e}")
+                time.sleep(0.5)  # Slightly longer sleep on error
 
     def process_sentence_queue(self):
         """Process sentences in the queue for speaker detection"""
         while self.is_running:
             try:
-                text, bytes_data = self.sentence_queue.get(timeout=1)
+                text, audio_data = self.sentence_queue.get(timeout=1)
 
                 # Convert audio data to int16
-                audio_int16 = np.int16(bytes_data * 32767)
+                if isinstance(audio_data, np.ndarray):
+                    if audio_data.dtype != np.int16:
+                        audio_int16 = (audio_data * 32767).astype(np.int16)
+                    else:
+                        audio_int16 = audio_data
+                else:
+                    audio_int16 = np.int16(audio_data * 32767)
 
                 # Extract speaker embedding
                 speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
@@ -425,64 +448,10 @@ class RealtimeSpeakerDiarization:
            except Exception as e:
                print(f"Error processing sentence: {e}")
 
-    def start_recording(self):
-        """Start the recording and transcription process"""
-        if self.encoder is None:
-            return "Please initialize models first!"
-
-        try:
-            # Setup recorder configuration for WebRTC input
-            recorder_config = {
-                'spinner': False,
-                'use_microphone': False,  # We'll feed audio manually
-                'model': FINAL_TRANSCRIPTION_MODEL,
-                'language': TRANSCRIPTION_LANGUAGE,
-                'silero_sensitivity': SILERO_SENSITIVITY,
-                'webrtc_sensitivity': WEBRTC_SENSITIVITY,
-                'post_speech_silence_duration': SILENCE_THRESHS[1],
-                'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
-                'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
-                'min_gap_between_recordings': 0,
-                'enable_realtime_transcription': True,
-                'realtime_processing_pause': 0,
-                'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
-                'on_realtime_transcription_update': self.live_text_detected,
-                'beam_size': FINAL_BEAM_SIZE,
-                'beam_size_realtime': REALTIME_BEAM_SIZE,
-                'buffer_size': BUFFER_SIZE,
-                'sample_rate': SAMPLE_RATE,
-            }
-
-            self.recorder = AudioToTextRecorder(**recorder_config)
-
-            # Start sentence processing thread
-            self.is_running = True
-            self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
-            self.sentence_thread.start()
-
-            # Start transcription thread
-            self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
-            self.transcription_thread.start()
-
-            return "Recording started successfully! WebRTC audio input ready."
-
-        except Exception as e:
-            return f"Error starting recording: {e}"
-
-    def run_transcription(self):
-        """Run the transcription loop"""
-        try:
-            while self.is_running:
-                self.recorder.text(self.process_final_text)
-        except Exception as e:
-            print(f"Transcription error: {e}")
-
-    def stop_recording(self):
-        """Stop the recording process"""
+    def stop_stream(self):
+        """Stop the stream and processing"""
         self.is_running = False
-        if self.recorder:
-            self.recorder.stop()
-        return "Recording stopped!"
+        return "Stream stopped!"
 
     def clear_conversation(self):
         """Clear all conversation data"""
@@ -575,79 +544,33 @@ class RealtimeSpeakerDiarization:
 diarization_system = RealtimeSpeakerDiarization()
 
 
-def initialize_system():
-    """Initialize the diarization system"""
-    success = diarization_system.initialize_models()
-    if success:
-        return "✅ System initialized successfully! Models loaded."
-    else:
-        return "❌ Failed to initialize system. Please check the logs."
-
-
-def start_recording():
-    """Start recording and transcription"""
-    return diarization_system.start_recording()
-
-
-def stop_recording():
-    """Stop recording and transcription"""
-    return diarization_system.stop_recording()
-
-
-def clear_conversation():
-    """Clear the conversation"""
-    return diarization_system.clear_conversation()
-
-
-def update_settings(threshold, max_speakers):
-    """Update system settings"""
-    return diarization_system.update_settings(threshold, max_speakers)
-
-
-def get_conversation():
-    """Get the current conversation"""
-    return diarization_system.get_formatted_conversation()
-
-
-def get_status():
-    """Get system status"""
-    return diarization_system.get_status_info()
-
-
-def process_audio_stream(audio):
-    """Process audio stream from WebRTC"""
-    if diarization_system.webrtc_processor and diarization_system.is_running:
-        diarization_system.webrtc_processor.process_audio(audio, SAMPLE_RATE)
-    return None
-
-
-# Create Gradio interface
+# Create Gradio interface with FastAPI app integrated
 def create_interface():
-    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Monochrome()) as app:
+    app = gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Monochrome())
+
+    with app:
         gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
-        gr.Markdown("This app performs real-time speech recognition with automatic speaker identification and color-coding using WebRTC.")
+        gr.Markdown("This app performs real-time speech recognition with automatic speaker identification and color-coding using FastRTC.")
 
         with gr.Row():
            with gr.Column(scale=2):
-                # WebRTC Audio Input
-                audio_input = gr.Audio(
-                    sources=["microphone"],
-                    streaming=True,
-                    label="🎙️ Microphone Input",
-                    type="numpy"
-                )
-
                # Main conversation display
                conversation_output = gr.HTML(
-                    value="<i>Click 'Initialize System' to start...</i>",
+                    value="<i>Click 'Initialize System' and then 'Start Stream' to begin...</i>",
                    label="Live Conversation"
                )
 
+                # FastRTC microphone widget for visualization only (the real audio comes through the FastRTC stream)
+                audio_widget = gr.Audio(
+                    sources=["microphone"],
+                    label="🎙️ Microphone Input (Click Start Stream to enable)",
+                    type="numpy"
+                )
+
                # Control buttons
                with gr.Row():
                    init_btn = gr.Button("🔧 Initialize System", variant="secondary")
-                    start_btn = gr.Button("🎙️ Start Recording", variant="primary", interactive=False)
-                    stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", interactive=False)
+                    start_btn = gr.Button("🎙️ Start Stream", variant="primary", interactive=False)
+                    stop_btn = gr.Button("⏹️ Stop Stream", variant="stop", interactive=False)
                    clear_btn = gr.Button("🗑️ Clear Conversation", interactive=False)
 
            # Status display
@@ -685,13 +608,30 @@ def create_interface():
                gr.Markdown("## 📝 Instructions")
                gr.Markdown("""
                1. Click **Initialize System** to load models
-                2. Click **Start Recording** to begin processing
+                2. Click **Start Stream** to begin processing
                3. Allow microphone access when prompted
                4. Speak into your microphone
                5. Watch real-time transcription with speaker labels
                6. Adjust settings as needed
                """)
 
+                # QR code for mobile access
+                gr.Markdown("## 📱 Mobile Access")
+                gr.Markdown("Scan this QR code to access from mobile device:")
+                qr_code = gr.HTML("""
+                <div id="qrcode" style="text-align: center;"></div>
+                <script src="https://cdn.jsdelivr.net/npm/qrcode-generator@1.4.4/qrcode.min.js"></script>
+                <script>
+                setTimeout(function() {
+                    var currentUrl = window.location.href;
+                    var qr = qrcode(0, 'M');
+                    qr.addData(currentUrl);
+                    qr.make();
+                    document.getElementById('qrcode').innerHTML = qr.createImgTag(5);
+                }, 1000);
+                </script>
+                """)
+
                # Speaker color legend
                gr.Markdown("## 🎨 Speaker Colors")
                color_info = []
@@ -702,7 +642,7 @@
 
        # Auto-refresh conversation and status
        def refresh_display():
-            return get_conversation(), get_status()
+            return get_formatted_conversation(), get_status()
 
        # Event handlers
        def on_initialize():
@@ -712,7 +652,7 @@
                    result,
                    gr.update(interactive=True),   # start_btn
                    gr.update(interactive=True),   # clear_btn
-                    get_conversation(),
+                    get_formatted_conversation(),
                    get_status()
                )
            else:
@@ -720,26 +660,58 @@
                    result,
                    gr.update(interactive=False),  # start_btn
                    gr.update(interactive=False),  # clear_btn
-                    get_conversation(),
+                    get_formatted_conversation(),
                    get_status()
                )
 
-        def on_start():
-            result = start_recording()
+        def on_start_stream():
+            result = start_stream(app)
            return (
                result,
                gr.update(interactive=False),  # start_btn
                gr.update(interactive=True),   # stop_btn
            )
 
-        def on_stop():
-            result = stop_recording()
+        def on_stop_stream():
+            result = stop_stream()
            return (
                result,
                gr.update(interactive=True),   # start_btn
                gr.update(interactive=False),  # stop_btn
            )
 
+        def initialize_system():
+            """Initialize the diarization system"""
+            success = diarization_system.initialize_models()
+            if success:
+                return "✅ System initialized successfully! Models loaded."
+            else:
+                return "❌ Failed to initialize system. Please check the logs."
+
+        def start_stream(app):
+            """Start the FastRTC stream"""
+            return diarization_system.start_stream(app)
+
+        def stop_stream():
+            """Stop the FastRTC stream"""
+            return diarization_system.stop_stream()
+
+        def clear_conversation():
+            """Clear the conversation"""
+            return diarization_system.clear_conversation()
+
+        def update_settings(threshold, max_speakers):
+            """Update system settings"""
+            return diarization_system.update_settings(threshold, max_speakers)
+
+        def get_formatted_conversation():
+            """Get the current conversation"""
+            return diarization_system.get_formatted_conversation()
+
+        def get_status():
+            """Get system status"""
+            return diarization_system.get_status_info()
+
        # Connect event handlers
        init_btn.click(
            on_initialize,
@@ -747,12 +719,12 @@ def create_interface():
        )
 
        start_btn.click(
-            on_start,
+            on_start_stream,
            outputs=[status_output, start_btn, stop_btn]
        )
 
        stop_btn.click(
-            on_stop,
+            on_stop_stream,
            outputs=[status_output, start_btn, stop_btn]
        )
 
@@ -767,14 +739,7 @@
            outputs=[status_output]
        )
 
-        # Connect WebRTC audio stream to processing
-        audio_input.stream(
-            process_audio_stream,
-            inputs=[audio_input],
-            outputs=[]
-        )
-
-        # Auto-refresh every 2 seconds when recording
+        # Auto-refresh every 2 seconds when streaming
        refresh_timer = gr.Timer(2.0)
        refresh_timer.tick(
            refresh_display,
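
Aside: the diff above swaps RealtimeSTT's microphone pipeline for a FastRTC stream handler. Below is a minimal, self-contained sketch of that wiring, assuming the fastrtc API exactly as it appears in this commit (Stream, AsyncStreamHandler, get_stt_model, stream.mount); the FastAPI app and the EchoTextHandler name are illustrative, not part of the commit.

from fastapi import FastAPI
from fastrtc import Stream, AsyncStreamHandler, get_stt_model


class EchoTextHandler(AsyncStreamHandler):
    """Hypothetical handler: transcribe incoming audio, emit the latest text."""

    def __init__(self):
        super().__init__(input_sample_rate=16000)
        self.stt_model = get_stt_model(model="moonshine/base")
        self.latest_text = ""

    def copy(self):
        # fastrtc clones one handler per WebRTC connection
        return EchoTextHandler()

    async def start_up(self):
        pass

    async def receive(self, frame):
        sample_rate, audio = frame  # (int, np.ndarray) pairs from the browser
        text = self.stt_model.stt((sample_rate, audio))
        if text:
            self.latest_text = text

    async def emit(self):
        return self.latest_text


api = FastAPI()
stream = Stream(handler=EchoTextHandler(), modality="audio", mode="send-receive")
stream.mount(api)  # exposes the WebRTC signalling endpoints on the app
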
realtime_diarize.py ADDED
@@ -0,0 +1,970 @@
+from PyQt6.QtWidgets import (QApplication, QTextEdit, QMainWindow, QLabel, QVBoxLayout, QWidget,
+                             QHBoxLayout, QPushButton, QSizePolicy, QGroupBox, QSlider, QSpinBox)
+from PyQt6.QtCore import Qt, pyqtSignal, QThread, QEvent, QTimer
+from scipy.spatial.distance import cosine
+from RealtimeSTT import AudioToTextRecorder
+import numpy as np
+import soundcard as sc
+import queue
+import torch
+import time
+import sys
+import os
+import urllib.request
+import torchaudio
+
+# Simplified configuration parameters
+SILENCE_THRESHS = [0, 0.4]
+FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
+FINAL_BEAM_SIZE = 5
+REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
+REALTIME_BEAM_SIZE = 5
+TRANSCRIPTION_LANGUAGE = "en"  # Accuracy in languages other than English is very low.
+SILERO_SENSITIVITY = 0.4
+WEBRTC_SENSITIVITY = 3
+MIN_LENGTH_OF_RECORDING = 0.7
+PRE_RECORDING_BUFFER_DURATION = 0.35
+
+# Speaker change detection parameters
+DEFAULT_CHANGE_THRESHOLD = 0.7  # Threshold for detecting speaker change
+EMBEDDING_HISTORY_SIZE = 5  # Number of embeddings to keep for comparison
+MIN_SEGMENT_DURATION = 1.0  # Minimum duration before considering a speaker change
+DEFAULT_MAX_SPEAKERS = 4  # Default maximum number of speakers
+ABSOLUTE_MAX_SPEAKERS = 10  # Absolute maximum number of speakers allowed
+
+# Global variables
+FAST_SENTENCE_END = True
+USE_MICROPHONE = False
+SAMPLE_RATE = 16000
+BUFFER_SIZE = 512
+CHANNELS = 1
+
+# Speaker colors - now we have colors for up to 10 speakers
+SPEAKER_COLORS = [
+    "#FFFF00",  # Yellow
+    "#FF0000",  # Red
+    "#00FF00",  # Green
+    "#00FFFF",  # Cyan
+    "#FF00FF",  # Magenta
+    "#0000FF",  # Blue
+    "#FF8000",  # Orange
+    "#00FF80",  # Spring Green
+    "#8000FF",  # Purple
+    "#FFFFFF",  # White
+]
+
+# Color names for display
+SPEAKER_COLOR_NAMES = [
+    "Yellow",
+    "Red",
+    "Green",
+    "Cyan",
+    "Magenta",
+    "Blue",
+    "Orange",
+    "Spring Green",
+    "Purple",
+    "White"
+]
+
+
+class SpeechBrainEncoder:
+    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
+    def __init__(self, device="cpu"):
+        self.device = device
+        self.model = None
+        self.embedding_dim = 192  # ECAPA-TDNN default dimension
+        self.model_loaded = False
+        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
+        os.makedirs(self.cache_dir, exist_ok=True)
+
+    def _download_model(self):
+        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
+        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
+        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
+
+        if not os.path.exists(model_path):
+            print(f"Downloading ECAPA-TDNN model to {model_path}...")
+            urllib.request.urlretrieve(model_url, model_path)
+
+        return model_path
+
+    def load_model(self):
+        """Load the ECAPA-TDNN model"""
+        try:
+            # Import SpeechBrain
+            from speechbrain.pretrained import EncoderClassifier
+
+            # Get model path
+            model_path = self._download_model()
+
+            # Load the pre-trained model
+            self.model = EncoderClassifier.from_hparams(
+                source="speechbrain/spkrec-ecapa-voxceleb",
+                savedir=self.cache_dir,
+                run_opts={"device": self.device}
+            )
+
+            self.model_loaded = True
+            return True
+        except Exception as e:
+            print(f"Error loading ECAPA-TDNN model: {e}")
+            return False
+
+    def embed_utterance(self, audio, sr=16000):
+        """Extract speaker embedding from audio"""
+        if not self.model_loaded:
+            raise ValueError("Model not loaded. Call load_model() first.")
+
+        try:
+            # Convert numpy array to torch tensor
+            if isinstance(audio, np.ndarray):
+                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
+            else:
+                waveform = audio.unsqueeze(0)
+
+            # Ensure sample rate matches model expected rate
+            if sr != 16000:
+                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
+
+            # Get embedding
+            with torch.no_grad():
+                embedding = self.model.encode_batch(waveform)
+
+            return embedding.squeeze().cpu().numpy()
+        except Exception as e:
+            print(f"Error extracting embedding: {e}")
+            return np.zeros(self.embedding_dim)
+
+
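A quick usage sketch for the encoder class above (illustrative, not from the diff): load the model once, then embed a one-second utterance. The random buffer merely stands in for real 16 kHz speech.

import numpy as np

encoder = SpeechBrainEncoder(device="cpu")
if encoder.load_model():
    utterance = np.random.uniform(-1.0, 1.0, 16000).astype(np.float32)  # 1 s of stand-in audio
    embedding = encoder.embed_utterance(utterance, sr=16000)
    print(embedding.shape)  # (192,), the ECAPA-TDNN embedding dimension
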
+class AudioProcessor:
+    """Processes audio data to extract speaker embeddings"""
+    def __init__(self, encoder):
+        self.encoder = encoder
+
+    def extract_embedding(self, audio_int16):
+        try:
+            # Convert int16 audio data to float32
+            float_audio = audio_int16.astype(np.float32) / 32768.0
+
+            # Normalize if needed
+            if np.abs(float_audio).max() > 1.0:
+                float_audio = float_audio / np.abs(float_audio).max()
+
+            # Extract embedding using the loaded encoder
+            embedding = self.encoder.embed_utterance(float_audio)
+
+            return embedding
+        except Exception as e:
+            print(f"Embedding extraction error: {e}")
+            return np.zeros(self.encoder.embedding_dim)
+
+
+class EncoderLoaderThread(QThread):
+    """Thread for loading the speaker encoder model"""
+    model_loaded = pyqtSignal(object)
+    progress_update = pyqtSignal(str)
+
+    def run(self):
+        try:
+            self.progress_update.emit("Initializing speaker encoder model...")
+
+            # Check device
+            device_str = "cuda" if torch.cuda.is_available() else "cpu"
+            self.progress_update.emit(f"Using device: {device_str}")
+
+            # Create SpeechBrain encoder
+            self.progress_update.emit("Loading ECAPA-TDNN model...")
+            encoder = SpeechBrainEncoder(device=device_str)
+
+            # Load the model
+            success = encoder.load_model()
+
+            if success:
+                self.progress_update.emit("ECAPA-TDNN model loading complete!")
+                self.model_loaded.emit(encoder)
+            else:
+                self.progress_update.emit("Failed to load ECAPA-TDNN model. Using fallback...")
+                self.model_loaded.emit(None)
+        except Exception as e:
+            self.progress_update.emit(f"Model loading error: {e}")
+            self.model_loaded.emit(None)
+
+
+class SpeakerChangeDetector:
+    """Modified speaker change detector that supports a configurable number of speakers"""
+    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
+        self.embedding_dim = embedding_dim
+        self.change_threshold = change_threshold
+        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)  # Ensure we don't exceed absolute max
+        self.current_speaker = 0  # Initial speaker (0 to max_speakers-1)
+        self.previous_embeddings = []
+        self.last_change_time = time.time()
+        self.mean_embeddings = [None] * self.max_speakers  # Mean embeddings for each speaker
+        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]  # All embeddings for each speaker
+        self.last_similarity = 0.0
+        self.active_speakers = set([0])  # Track which speakers have been detected
+
+    def set_max_speakers(self, max_speakers):
+        """Update the maximum number of speakers"""
+        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
+
+        # If reducing the number of speakers
+        if new_max < self.max_speakers:
+            # Remove any speakers beyond the new max
+            for speaker_id in list(self.active_speakers):
+                if speaker_id >= new_max:
+                    self.active_speakers.discard(speaker_id)
+
+            # Ensure current speaker is valid
+            if self.current_speaker >= new_max:
+                self.current_speaker = 0
+
+        # Expand arrays if increasing max speakers
+        if new_max > self.max_speakers:
+            # Extend mean_embeddings array
+            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
+
+            # Extend speaker_embeddings array
+            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
+
+        # Truncate arrays if decreasing max speakers
+        else:
+            self.mean_embeddings = self.mean_embeddings[:new_max]
+            self.speaker_embeddings = self.speaker_embeddings[:new_max]
+
+        self.max_speakers = new_max
+
+    def set_change_threshold(self, threshold):
+        """Update the threshold for detecting speaker changes"""
+        self.change_threshold = max(0.1, min(threshold, 0.99))
+
+    def add_embedding(self, embedding, timestamp=None):
+        """Add a new embedding and check if there's a speaker change"""
+        current_time = timestamp or time.time()
+
+        # Initialize first speaker if no embeddings yet
+        if not self.previous_embeddings:
+            self.previous_embeddings.append(embedding)
+            self.speaker_embeddings[self.current_speaker].append(embedding)
+            if self.mean_embeddings[self.current_speaker] is None:
+                self.mean_embeddings[self.current_speaker] = embedding.copy()
+            return self.current_speaker, 1.0
+
+        # Calculate similarity with current speaker's mean embedding
+        current_mean = self.mean_embeddings[self.current_speaker]
+        if current_mean is not None:
+            similarity = 1.0 - cosine(embedding, current_mean)
+        else:
+            # If no mean yet, compare with most recent embedding
+            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
+
+        self.last_similarity = similarity
+
+        # Decide if this is a speaker change
+        time_since_last_change = current_time - self.last_change_time
+        is_speaker_change = False
+
+        # Only consider change if minimum time has passed since last change
+        if time_since_last_change >= MIN_SEGMENT_DURATION:
+            # Check similarity against threshold
+            if similarity < self.change_threshold:
+                # Compare with all other speakers' means if available
+                best_speaker = self.current_speaker
+                best_similarity = similarity
+
+                # Check each active speaker
+                for speaker_id in range(self.max_speakers):
+                    if speaker_id == self.current_speaker:
+                        continue
+
+                    speaker_mean = self.mean_embeddings[speaker_id]
+
+                    if speaker_mean is not None:
+                        # Calculate similarity with this speaker
+                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
+
+                        # If more similar to this speaker, update best match
+                        if speaker_similarity > best_similarity:
+                            best_similarity = speaker_similarity
+                            best_speaker = speaker_id
+
+                # If best match is different from current speaker, change speaker
+                if best_speaker != self.current_speaker:
+                    is_speaker_change = True
+                    self.current_speaker = best_speaker
+                # If no good match with existing speakers and we haven't used all speakers yet
+                elif len(self.active_speakers) < self.max_speakers:
+                    # Find the next unused speaker ID
+                    for new_id in range(self.max_speakers):
+                        if new_id not in self.active_speakers:
+                            is_speaker_change = True
+                            self.current_speaker = new_id
+                            self.active_speakers.add(new_id)
+                            break
+
+        # Handle speaker change
+        if is_speaker_change:
+            self.last_change_time = current_time
+
+        # Update embeddings
+        self.previous_embeddings.append(embedding)
+        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
+            self.previous_embeddings.pop(0)
+
+        # Update current speaker's embeddings and mean
+        self.speaker_embeddings[self.current_speaker].append(embedding)
+        self.active_speakers.add(self.current_speaker)
+
+        if len(self.speaker_embeddings[self.current_speaker]) > 30:  # Limit history size
+            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
+
+        # Update mean embedding for current speaker
+        if self.speaker_embeddings[self.current_speaker]:
+            self.mean_embeddings[self.current_speaker] = np.mean(
+                self.speaker_embeddings[self.current_speaker], axis=0
+            )
+
+        return self.current_speaker, similarity
+
+    def get_color_for_speaker(self, speaker_id):
+        """Return color for speaker ID (0 to max_speakers-1)"""
+        if 0 <= speaker_id < len(SPEAKER_COLORS):
+            return SPEAKER_COLORS[speaker_id]
+        return "#FFFFFF"  # Default to white if out of range
+
+    def get_status_info(self):
+        """Return status information about the speaker change detector"""
+        speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
+
+        return {
+            "current_speaker": self.current_speaker,
+            "speaker_counts": speaker_counts,
+            "active_speakers": len(self.active_speakers),
+            "max_speakers": self.max_speakers,
+            "last_similarity": self.last_similarity,
+            "threshold": self.change_threshold
+        }
+
+
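The decision rule in add_embedding above reduces to one SciPy call; here is a standalone worked example (toy 3-d vectors rather than real 192-d embeddings) showing how a low cosine similarity against the current speaker's mean flags a change candidate.

import numpy as np
from scipy.spatial.distance import cosine

change_threshold = 0.7                       # DEFAULT_CHANGE_THRESHOLD
speaker_mean = np.array([0.9, 0.1, 0.0])     # running mean of the current speaker
new_embedding = np.array([0.1, 0.9, 0.1])    # embedding of the latest sentence

similarity = 1.0 - cosine(new_embedding, speaker_mean)
print(f"similarity = {similarity:.3f}")      # ~0.218 for these vectors

if similarity < change_threshold:
    print("below threshold -> candidate speaker change")
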
+class TextUpdateThread(QThread):
+    text_update_signal = pyqtSignal(str)
+
+    def __init__(self, text):
+        super().__init__()
+        self.text = text
+
+    def run(self):
+        self.text_update_signal.emit(self.text)
+
+
+class SentenceWorker(QThread):
+    sentence_update_signal = pyqtSignal(list, list)
+    status_signal = pyqtSignal(str)
+
+    def __init__(self, queue, encoder, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
+        super().__init__()
+        self.queue = queue
+        self.encoder = encoder
+        self._is_running = True
+        self.full_sentences = []
+        self.sentence_speakers = []
+        self.change_threshold = change_threshold
+        self.max_speakers = max_speakers
+
+        # Initialize audio processor for embedding extraction
+        self.audio_processor = AudioProcessor(self.encoder)
+
+        # Initialize speaker change detector
+        self.speaker_detector = SpeakerChangeDetector(
+            embedding_dim=self.encoder.embedding_dim,
+            change_threshold=self.change_threshold,
+            max_speakers=self.max_speakers
+        )
+
+        # Setup monitoring timer
+        self.monitoring_timer = QTimer()
+        self.monitoring_timer.timeout.connect(self.report_status)
+        self.monitoring_timer.start(2000)  # Report every 2 seconds
+
+    def set_change_threshold(self, threshold):
+        """Update change detection threshold"""
+        self.change_threshold = threshold
+        self.speaker_detector.set_change_threshold(threshold)
+
+    def set_max_speakers(self, max_speakers):
+        """Update maximum number of speakers"""
+        self.max_speakers = max_speakers
+        self.speaker_detector.set_max_speakers(max_speakers)
+
+    def run(self):
+        """Main worker thread loop"""
+        while self._is_running:
+            try:
+                text, bytes = self.queue.get(timeout=1)
+                self.process_item(text, bytes)
+            except queue.Empty:
+                continue
+
+    def report_status(self):
+        """Report status information"""
+        # Get status information from speaker detector
+        status = self.speaker_detector.get_status_info()
+
+        # Prepare status message with information for all speakers
+        status_text = f"Current speaker: {status['current_speaker'] + 1}\n"
+        status_text += f"Active speakers: {status['active_speakers']} of {status['max_speakers']}\n"
+
+        # Show segment counts for each speaker
+        for i in range(status['max_speakers']):
+            if i < len(SPEAKER_COLOR_NAMES):
+                color_name = SPEAKER_COLOR_NAMES[i]
+            else:
+                color_name = f"Speaker {i+1}"
+            status_text += f"Speaker {i+1} ({color_name}) segments: {status['speaker_counts'][i]}\n"
+
+        status_text += f"Last similarity score: {status['last_similarity']:.3f}\n"
+        status_text += f"Change threshold: {status['threshold']:.2f}\n"
+        status_text += f"Total sentences: {len(self.full_sentences)}"
+
+        # Send to UI
+        self.status_signal.emit(status_text)
+
+    def process_item(self, text, bytes):
+        """Process a new text-audio pair"""
+        # Convert audio data to int16
+        audio_int16 = np.int16(bytes * 32767)
+
+        # Extract speaker embedding
+        speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
+
+        # Store sentence and embedding
+        self.full_sentences.append((text, speaker_embedding))
+
+        # Fill in any missing speaker assignments
+        if len(self.sentence_speakers) < len(self.full_sentences) - 1:
+            while len(self.sentence_speakers) < len(self.full_sentences) - 1:
+                self.sentence_speakers.append(0)  # Default to first speaker
+
+        # Detect speaker changes
+        speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
+        self.sentence_speakers.append(speaker_id)
+
+        # Send updated data to UI
+        self.sentence_update_signal.emit(self.full_sentences, self.sentence_speakers)
+
+    def stop(self):
+        """Stop the worker thread"""
+        self._is_running = False
+        if self.monitoring_timer.isActive():
+            self.monitoring_timer.stop()
+
+
+class RecordingThread(QThread):
+    def __init__(self, recorder):
+        super().__init__()
+        self.recorder = recorder
+        self._is_running = True
+
+        # Determine input source
+        if USE_MICROPHONE:
+            self.device_id = str(sc.default_microphone().name)
+            self.include_loopback = False
+        else:
+            self.device_id = str(sc.default_speaker().name)
+            self.include_loopback = True
+
+    def updateDevice(self, device_id, include_loopback):
+        self.device_id = device_id
+        self.include_loopback = include_loopback
+
+    def run(self):
+        while self._is_running:
+            try:
+                with sc.get_microphone(id=self.device_id, include_loopback=self.include_loopback).recorder(
+                    samplerate=SAMPLE_RATE, blocksize=BUFFER_SIZE
+                ) as mic:
+                    # Process audio chunks while device hasn't changed
+                    current_device = self.device_id
+                    current_loopback = self.include_loopback
+
+                    while self._is_running and current_device == self.device_id and current_loopback == self.include_loopback:
+                        # Record audio chunk
+                        audio_data = mic.record(numframes=BUFFER_SIZE)
+
+                        # Convert stereo to mono if needed
+                        if audio_data.shape[1] > 1 and CHANNELS == 1:
+                            audio_data = audio_data[:, 0]
+
+                        # Convert to int16
+                        audio_int16 = (audio_data.flatten() * 32767).astype(np.int16)
+
+                        # Feed to recorder
+                        audio_bytes = audio_int16.tobytes()
+                        self.recorder.feed_audio(audio_bytes)
+
+            except Exception as e:
+                print(f"Recording error: {e}")
+                # Wait before retry on error
+                time.sleep(1)
+
+    def stop(self):
+        self._is_running = False
+
+
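The float-to-int16 hand-off in RecordingThread.run is easy to get wrong, so here is the same scaling as a standalone sketch with a synthetic block in place of a soundcard capture; the feed_audio call is commented out because it needs a live AudioToTextRecorder.

import numpy as np

SAMPLE_RATE = 16000
BUFFER_SIZE = 512

t = np.arange(BUFFER_SIZE) / SAMPLE_RATE
block = 0.1 * np.sin(2 * np.pi * 440.0 * t)        # mono float block in [-1, 1]
audio_int16 = (block * 32767).astype(np.int16)     # same scaling as the thread
audio_bytes = audio_int16.tobytes()                # raw little-endian PCM

# recorder.feed_audio(audio_bytes)                 # with a real recorder instance
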
+class TextRetrievalThread(QThread):
+    textRetrievedFinal = pyqtSignal(str, np.ndarray)
+    textRetrievedLive = pyqtSignal(str)
+    recorderStarted = pyqtSignal()
+
+    def __init__(self):
+        super().__init__()
+
+    def live_text_detected(self, text):
+        self.textRetrievedLive.emit(text)
+
+    def run(self):
+        recorder_config = {
+            'spinner': False,
+            'use_microphone': False,
+            'model': FINAL_TRANSCRIPTION_MODEL,
+            'language': TRANSCRIPTION_LANGUAGE,
+            'silero_sensitivity': SILERO_SENSITIVITY,
+            'webrtc_sensitivity': WEBRTC_SENSITIVITY,
+            'post_speech_silence_duration': SILENCE_THRESHS[1],
+            'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
+            'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
+            'min_gap_between_recordings': 0,
+            'enable_realtime_transcription': True,
+            'realtime_processing_pause': 0,
+            'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
+            'on_realtime_transcription_update': self.live_text_detected,
+            'beam_size': FINAL_BEAM_SIZE,
+            'beam_size_realtime': REALTIME_BEAM_SIZE,
+            'buffer_size': BUFFER_SIZE,
+            'sample_rate': SAMPLE_RATE,
+        }
+
+        self.recorder = AudioToTextRecorder(**recorder_config)
+        self.recorderStarted.emit()
+
+        def process_text(text):
+            bytes = self.recorder.last_transcription_bytes
+            self.textRetrievedFinal.emit(text, bytes)
+
+        while True:
+            self.recorder.text(process_text)
+
+
+class MainWindow(QMainWindow):
+    def __init__(self):
+        super().__init__()
+
+        self.setWindowTitle("Real-time Speaker Change Detection")
+
+        self.encoder = None
+        self.initialized = False
+        self.displayed_text = ""
+        self.last_realtime_text = ""
+        self.full_sentences = []
+        self.sentence_speakers = []
+        self.pending_sentences = []
+        self.queue = queue.Queue()
+        self.recording_thread = None
+        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
+        self.max_speakers = DEFAULT_MAX_SPEAKERS
+
+        # Create main horizontal layout
+        self.mainLayout = QHBoxLayout()
+
+        # Add text edit area to main layout
+        self.text_edit = QTextEdit(self)
+        self.mainLayout.addWidget(self.text_edit, 1)
+
+        # Create right layout for controls
+        self.rightLayout = QVBoxLayout()
+        self.rightLayout.setAlignment(Qt.AlignmentFlag.AlignTop)
+
+        # Create all controls
+        self.create_controls()
+
+        # Create container for right layout
+        self.rightContainer = QWidget()
+        self.rightContainer.setLayout(self.rightLayout)
+        self.mainLayout.addWidget(self.rightContainer, 0)
+
+        # Set main layout as central widget
+        self.centralWidget = QWidget()
+        self.centralWidget.setLayout(self.mainLayout)
+        self.setCentralWidget(self.centralWidget)
+
+        self.setStyleSheet("""
+            QGroupBox {
+                border: 1px solid #555;
+                border-radius: 3px;
+                margin-top: 10px;
+                padding-top: 10px;
+                color: #ddd;
+            }
+            QGroupBox::title {
+                subcontrol-origin: margin;
+                subcontrol-position: top center;
+                padding: 0 5px;
+            }
+            QLabel {
+                color: #ddd;
+            }
+            QPushButton {
+                background: #444;
+                color: #ddd;
+                border: 1px solid #555;
+                padding: 5px;
+                margin-bottom: 10px;
+            }
+            QPushButton:hover {
+                background: #555;
+            }
+            QTextEdit {
+                background-color: #1e1e1e;
+                color: #ffffff;
+                font-family: 'Arial';
+                font-size: 16pt;
+            }
+            QSlider {
+                height: 30px;
+            }
+            QSlider::groove:horizontal {
+                height: 8px;
+                background: #333;
+                margin: 2px 0;
+            }
+            QSlider::handle:horizontal {
+                background: #666;
+                border: 1px solid #777;
+                width: 18px;
+                margin: -8px 0;
+                border-radius: 9px;
+            }
+        """)
+
+    def create_controls(self):
+        # Speaker change threshold control
+        self.threshold_group = QGroupBox("Speaker Change Sensitivity")
+        threshold_layout = QVBoxLayout()
+
+        self.threshold_label = QLabel(f"Change threshold: {self.change_threshold:.2f}")
+        threshold_layout.addWidget(self.threshold_label)
+
+        self.threshold_slider = QSlider(Qt.Orientation.Horizontal)
+        self.threshold_slider.setMinimum(10)
+        self.threshold_slider.setMaximum(95)
+        self.threshold_slider.setValue(int(self.change_threshold * 100))
+        self.threshold_slider.valueChanged.connect(self.update_threshold)
+        threshold_layout.addWidget(self.threshold_slider)
+
+        self.threshold_explanation = QLabel(
+            "Set the threshold above 0.5 when speakers have similar voices; use a lower value when their voices are clearly different."
+        )
+        self.threshold_explanation.setWordWrap(True)
+        threshold_layout.addWidget(self.threshold_explanation)
+
+        self.threshold_group.setLayout(threshold_layout)
+        self.rightLayout.addWidget(self.threshold_group)
+
+        # Max speakers control
+        self.max_speakers_group = QGroupBox("Maximum Number of Speakers")
+        max_speakers_layout = QVBoxLayout()
+
+        self.max_speakers_label = QLabel(f"Max speakers: {self.max_speakers}")
+        max_speakers_layout.addWidget(self.max_speakers_label)
+
+        self.max_speakers_spinbox = QSpinBox()
+        self.max_speakers_spinbox.setMinimum(2)
+        self.max_speakers_spinbox.setMaximum(ABSOLUTE_MAX_SPEAKERS)
+        self.max_speakers_spinbox.setValue(self.max_speakers)
+        self.max_speakers_spinbox.valueChanged.connect(self.update_max_speakers)
+        max_speakers_layout.addWidget(self.max_speakers_spinbox)
+
+        self.max_speakers_explanation = QLabel(
+            f"You can set between 2 and {ABSOLUTE_MAX_SPEAKERS} speakers.\n"
+            "Changes will apply immediately."
+        )
+        self.max_speakers_explanation.setWordWrap(True)
+        max_speakers_layout.addWidget(self.max_speakers_explanation)
+
+        self.max_speakers_group.setLayout(max_speakers_layout)
+        self.rightLayout.addWidget(self.max_speakers_group)
+
+        # Speaker color legend - dynamic based on max speakers
+        self.legend_group = QGroupBox("Speaker Colors")
+        self.legend_layout = QVBoxLayout()
+
+        # Create speaker labels dynamically
+        self.speaker_labels = []
+        for i in range(ABSOLUTE_MAX_SPEAKERS):
+            color = SPEAKER_COLORS[i]
+            color_name = SPEAKER_COLOR_NAMES[i]
+            label = QLabel(f"Speaker {i+1} ({color_name}): <span style='color:{color};'>■■■■■</span>")
+            self.speaker_labels.append(label)
+            if i < self.max_speakers:
+                self.legend_layout.addWidget(label)
+
+        self.legend_group.setLayout(self.legend_layout)
+        self.rightLayout.addWidget(self.legend_group)
+
+        # Status display area
+        self.status_group = QGroupBox("Status")
+        status_layout = QVBoxLayout()
+
+        self.status_label = QLabel("Status information will be displayed here.")
+        self.status_label.setWordWrap(True)
+        status_layout.addWidget(self.status_label)
+
+        self.status_group.setLayout(status_layout)
+        self.rightLayout.addWidget(self.status_group)
+
+        # Clear button
+        self.clear_button = QPushButton("Clear Conversation")
+        self.clear_button.clicked.connect(self.clear_state)
+        self.clear_button.setEnabled(False)
+        self.rightLayout.addWidget(self.clear_button)
+
+    def update_threshold(self, value):
+        """Update speaker change detection threshold"""
+        threshold = value / 100.0
+        self.change_threshold = threshold
+        self.threshold_label.setText(f"Change threshold: {threshold:.2f}")
+
+        # Update in worker if it exists
+        if hasattr(self, 'worker_thread'):
+            self.worker_thread.set_change_threshold(threshold)
+
+    def update_max_speakers(self, value):
+        """Update maximum number of speakers"""
+        self.max_speakers = value
+        self.max_speakers_label.setText(f"Max speakers: {value}")
+
+        # Update visible speaker labels
+        self.update_speaker_labels()
+
+        # Update in worker if it exists
+        if hasattr(self, 'worker_thread'):
+            self.worker_thread.set_max_speakers(value)
+
+    def update_speaker_labels(self):
+        """Update which speaker labels are visible based on max_speakers"""
+        # Clear all labels first
+        for i in range(len(self.speaker_labels)):
+            label = self.speaker_labels[i]
+            if label.parent():
+                self.legend_layout.removeWidget(label)
+                label.setParent(None)
+
+        # Add only the labels for the current max_speakers
+        for i in range(min(self.max_speakers, len(self.speaker_labels))):
+            self.legend_layout.addWidget(self.speaker_labels[i])
+
+    def clear_state(self):
+        # Clear text edit area
+        self.text_edit.clear()
+
+        # Reset state variables
+        self.displayed_text = ""
+        self.last_realtime_text = ""
+        self.full_sentences = []
+        self.sentence_speakers = []
+        self.pending_sentences = []
+
+        if hasattr(self, 'worker_thread'):
+            self.worker_thread.full_sentences = []
+            self.worker_thread.sentence_speakers = []
+            # Reset speaker detector with current threshold and max_speakers
+            self.worker_thread.speaker_detector = SpeakerChangeDetector(
+                embedding_dim=self.encoder.embedding_dim,
+                change_threshold=self.change_threshold,
+                max_speakers=self.max_speakers
+            )
+
+        # Display message
+        self.text_edit.setHtml("<i>All content cleared. Waiting for new input...</i>")
+
+    def update_status(self, status_text):
+        self.status_label.setText(status_text)
+
+    def showEvent(self, event):
+        super().showEvent(event)
+        if event.type() == QEvent.Type.Show:
+            if not self.initialized:
+                self.initialized = True
+                self.resize(1200, 800)
+                self.update_text("<i>Initializing application...</i>")
+
+                QTimer.singleShot(500, self.init)
+
+    def process_live_text(self, text):
+        text = text.strip()
+
+        if text:
+            sentence_delimiters = '.?!。'
+            prob_sentence_end = (
+                len(self.last_realtime_text) > 0
+                and text[-1] in sentence_delimiters
+                and self.last_realtime_text[-1] in sentence_delimiters
+            )
+
+            self.last_realtime_text = text
+
+            if prob_sentence_end:
+                if FAST_SENTENCE_END:
+                    self.text_retrieval_thread.recorder.stop()
+                else:
+                    self.text_retrieval_thread.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
+            else:
+                self.text_retrieval_thread.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
+
+            self.text_detected(text)
+
+    def text_detected(self, text):
+        try:
+            sentences_with_style = []
+            for i, sentence in enumerate(self.full_sentences):
+                sentence_text, _ = sentence
+                if i >= len(self.sentence_speakers):
+                    color = "#FFFFFF"  # Default white
+                else:
+                    speaker_id = self.sentence_speakers[i]
+                    color = self.worker_thread.speaker_detector.get_color_for_speaker(speaker_id)
+
+                sentences_with_style.append(
+                    f'<span style="color:{color};">{sentence_text}</span>')
+
+            for pending_sentence in self.pending_sentences:
+                sentences_with_style.append(
+                    f'<span style="color:#60FFFF;">{pending_sentence}</span>')
+
+            new_text = " ".join(sentences_with_style).strip() + " " + text if len(sentences_with_style) > 0 else text
+
+            if new_text != self.displayed_text:
+                self.displayed_text = new_text
+                self.update_text(new_text)
+        except Exception as e:
+            print(f"Error: {e}")
+
+    def process_final(self, text, bytes):
+        text = text.strip()
+        if text:
+            try:
+                self.pending_sentences.append(text)
+                self.queue.put((text, bytes))
+            except Exception as e:
+                print(f"Error: {e}")
+
+    def capture_output_and_feed_to_recorder(self):
+        # Use default device settings
+        device_id = str(sc.default_speaker().name)
+        include_loopback = True
+
+        self.recording_thread = RecordingThread(self.text_retrieval_thread.recorder)
+        # Update with current device settings
+        self.recording_thread.updateDevice(device_id, include_loopback)
+        self.recording_thread.start()
+
+    def recorder_ready(self):
+        self.update_text("<i>Recording ready</i>")
+        self.capture_output_and_feed_to_recorder()
+
+    def init(self):
+        self.update_text("<i>Loading ECAPA-TDNN model... Please wait.</i>")
+
+        # Start model loading in background thread
+        self.start_encoder()
+
+    def update_loading_status(self, message):
+        self.update_text(f"<i>{message}</i>")
+
+    def start_encoder(self):
+        # Create and start encoder loader thread
+        self.encoder_loader_thread = EncoderLoaderThread()
+        self.encoder_loader_thread.model_loaded.connect(self.on_model_loaded)
+        self.encoder_loader_thread.progress_update.connect(self.update_loading_status)
+        self.encoder_loader_thread.start()
+
+    def on_model_loaded(self, encoder):
+        # Store loaded encoder model
+        self.encoder = encoder
+
+        if self.encoder is None:
+            self.update_text("<i>Failed to load ECAPA-TDNN model. Please check your configuration.</i>")
+            return
+
+        # Enable all controls after model is loaded
+        self.clear_button.setEnabled(True)
+        self.threshold_slider.setEnabled(True)
+
+        # Continue initialization
+        self.update_text("<i>ECAPA-TDNN model loaded. Starting recorder...</i>")
+
+        self.text_retrieval_thread = TextRetrievalThread()
+        self.text_retrieval_thread.recorderStarted.connect(
+            self.recorder_ready)
+        self.text_retrieval_thread.textRetrievedLive.connect(
+            self.process_live_text)
+        self.text_retrieval_thread.textRetrievedFinal.connect(
+            self.process_final)
+        self.text_retrieval_thread.start()
+
+        self.worker_thread = SentenceWorker(
+            self.queue,
+            self.encoder,
+            change_threshold=self.change_threshold,
+            max_speakers=self.max_speakers
+        )
+        self.worker_thread.sentence_update_signal.connect(
+            self.sentence_updated)
+        self.worker_thread.status_signal.connect(
+            self.update_status)
+        self.worker_thread.start()
+
+    def sentence_updated(self, full_sentences, sentence_speakers):
+        self.pending_text = ""
+        self.full_sentences = full_sentences
+        self.sentence_speakers = sentence_speakers
+        for sentence in self.full_sentences:
+            sentence_text, _ = sentence
+            if sentence_text in self.pending_sentences:
+                self.pending_sentences.remove(sentence_text)
+        self.text_detected("")
+
+    def set_text(self, text):
+        self.update_thread = TextUpdateThread(text)
+        self.update_thread.text_update_signal.connect(self.update_text)
+        self.update_thread.start()
+
+    def update_text(self, text):
+        self.text_edit.setHtml(text)
+        self.text_edit.verticalScrollBar().setValue(
+            self.text_edit.verticalScrollBar().maximum())
+
+
+def main():
+    app = QApplication(sys.argv)
+
+    dark_stylesheet = """
+    QMainWindow {
+        background-color: #323232;
+    }
+    QTextEdit {
+        background-color: #1e1e1e;
+        color: #ffffff;
+    }
+    """
+    app.setStyleSheet(dark_stylesheet)
+
+    main_window = MainWindow()
+    main_window.show()
+
+    sys.exit(app.exec())
+
+
+if __name__ == "__main__":
+    main()
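
Finally, a condensed sketch of the RealtimeSTT flow this file is built around, using only constructor keys and calls that appear above; the callback names are illustrative.

from RealtimeSTT import AudioToTextRecorder

def on_live_update(text):
    print("live:", text)

def on_final(text):
    print("final:", text)

recorder = AudioToTextRecorder(
    spinner=False,
    use_microphone=False,                  # audio is pushed in via feed_audio()
    model="distil-large-v3",
    language="en",
    enable_realtime_transcription=True,
    realtime_model_type="distil-small.en",
    on_realtime_transcription_update=on_live_update,
)

# A capture thread calls recorder.feed_audio(pcm_bytes); each recorder.text()
# call blocks until a sentence is finalized, then invokes the callback.
while True:
    recorder.text(on_final)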