Saiyaswanth007 committed
Commit 91b17d7 · 1 Parent(s): 89ba8a1

Checkpoint 4

Files changed (1): app.py (+93, -85)
app.py CHANGED
@@ -330,10 +330,47 @@ class RealtimeSpeakerDiarization:
             logger.error(f"Model initialization error: {e}")
             return False
 
+    def feed_audio(self, audio_data):
+        """Feed audio data directly to the recorder for live transcription"""
+        if not self.is_running or not self.recorder:
+            return
+
+        try:
+            # Normalize if needed
+            if isinstance(audio_data, np.ndarray):
+                if audio_data.dtype != np.float32:
+                    audio_data = audio_data.astype(np.float32)
+
+                # Convert to int16 for the recorder
+                audio_int16 = (audio_data * 32767).astype(np.int16)
+                audio_bytes = audio_int16.tobytes()
+
+                # Feed to recorder
+                self.recorder.feed_audio(audio_bytes)
+
+                # Also process for speaker detection
+                self.process_audio_chunk(audio_data)
+
+            elif isinstance(audio_data, bytes):
+                # Feed raw bytes directly
+                self.recorder.feed_audio(audio_data)
+
+                # Convert to float for speaker detection
+                audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
+                audio_float = audio_int16.astype(np.float32) / 32768.0
+                self.process_audio_chunk(audio_float)
+
+            logger.debug("Audio fed to recorder")
+        except Exception as e:
+            logger.error(f"Error feeding audio: {e}")
+
     def live_text_detected(self, text):
         """Callback for real-time transcription updates"""
         with self.transcription_lock:
             self.last_transcription = text.strip()
+
+        # Update the display immediately on new transcription
+        self.update_conversation_display()
 
     def process_final_text(self, text):
         """Process final transcribed text with speaker embedding"""
@@ -419,7 +456,7 @@ class RealtimeSpeakerDiarization:
         # Setup recorder configuration
         recorder_config = {
             'spinner': False,
-            'use_microphone': False,  # Must be False since we're using FastRTC
+            'use_microphone': False,  # Change to False for Hugging Face Spaces
             'model': FINAL_TRANSCRIPTION_MODEL,
             'language': TRANSCRIPTION_LANGUAGE,
             'silero_sensitivity': SILERO_SENSITIVITY,
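For context, this config dict is passed to RealtimeSTT's `AudioToTextRecorder`. A minimal sketch of such a push-mode recorder, assuming the RealtimeSTT package is installed (the model name and callback are placeholders, not the app's constants):

```python
from RealtimeSTT import AudioToTextRecorder

# Push-mode recorder: use_microphone=False means audio arrives via feed_audio(),
# matching the FastRTC setup in this commit. Values here are illustrative.
recorder = AudioToTextRecorder(
    spinner=False,
    use_microphone=False,
    model='base.en',  # placeholder for FINAL_TRANSCRIPTION_MODEL
    language='en',
    enable_realtime_transcription=True,
    realtime_processing_pause=0.1,
    on_realtime_transcription_update=lambda text: print('live:', text),
)
```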
@@ -429,7 +466,7 @@ class RealtimeSpeakerDiarization:
             'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
             'min_gap_between_recordings': 0,
             'enable_realtime_transcription': True,
-            'realtime_processing_pause': 0.05,  # Faster updates for live transcription
+            'realtime_processing_pause': 0.1,
             'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
             'on_realtime_transcription_update': self.live_text_detected,
             'beam_size': FINAL_BEAM_SIZE,
@@ -447,8 +484,7 @@ class RealtimeSpeakerDiarization:
             self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
             self.transcription_thread.start()
 
-            logger.info("Recording started with FastRTC integration")
-            return "Recording started successfully! Speak now..."
+            return "Recording started successfully!"
 
         except Exception as e:
             logger.error(f"Error starting recording: {e}")
@@ -587,26 +623,35 @@ class DiarizationHandler(AsyncStreamHandler):
                 return
 
             # Extract audio data
-            if hasattr(frame, 'data'):
-                audio_data = frame.data
-            else:
-                audio_data = frame
-
-            # Convert to numpy array
-            if isinstance(audio_data, bytes):
-                audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
-            elif isinstance(audio_data, tuple) and len(audio_data) >= 2:
-                sample_rate, data = audio_data
-                audio_array = np.array(data, dtype=np.float32)
-            elif isinstance(audio_data, (list, tuple)):
-                audio_array = np.array(audio_data, dtype=np.float32)
+            audio_data = getattr(frame, 'data', frame)
+
+            # Check if this is a tuple (sample_rate, audio_array)
+            if isinstance(audio_data, tuple) and len(audio_data) >= 2:
+                sample_rate, audio_array = audio_data
             else:
-                audio_array = np.array(audio_data, dtype=np.float32)
+                # If not a tuple, assume it's raw audio bytes/array
+                sample_rate = SAMPLE_RATE  # Use default sample rate
+
+                # Convert to numpy array
+                if isinstance(audio_data, bytes):
+                    audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
+                elif isinstance(audio_data, (list, tuple)):
+                    audio_array = np.array(audio_data, dtype=np.float32)
+                else:
+                    audio_array = np.array(audio_data, dtype=np.float32)
 
             # Ensure 1D
             if len(audio_array.shape) > 1:
                 audio_array = audio_array.flatten()
 
+            # Send audio to recorder for live transcription
+            if self.diarization_system.recorder:
+                try:
+                    self.diarization_system.recorder.feed_audio(audio_array)
+                    logger.info("Fed audio to recorder")
+                except Exception as e:
+                    logger.error(f"Error feeding audio to recorder: {e}")
+
             # Buffer audio chunks
             self.audio_buffer.extend(audio_array)
 
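The rewritten `receive` accepts three frame shapes: a `(sample_rate, array)` tuple, raw int16 bytes, or a bare array. A self-contained sketch of that normalization logic, testable without FastRTC (`SAMPLE_RATE` and the function name are illustrative):

```python
import numpy as np

SAMPLE_RATE = 16000  # assumed default, mirroring the app's constant

def normalize_frame(frame):
    data = getattr(frame, 'data', frame)
    if isinstance(data, tuple) and len(data) >= 2:
        sample_rate, audio = data
        audio = np.asarray(audio, dtype=np.float32)
    else:
        sample_rate = SAMPLE_RATE
        if isinstance(data, bytes):
            audio = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
        else:
            audio = np.asarray(data, dtype=np.float32)
    return sample_rate, audio.flatten()  # always mono, 1-D

print(normalize_frame((48000, [[0.1], [-0.2]])))  # (48000, array([ 0.1, -0.2], ...))
```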
@@ -615,16 +660,8 @@ class DiarizationHandler(AsyncStreamHandler):
                 chunk = np.array(self.audio_buffer[:self.buffer_size])
                 self.audio_buffer = self.audio_buffer[self.buffer_size:]
 
-                # Process both for speaker detection and feed to the recorder for transcription
+                # Process asynchronously
                 await self.process_audio_async(chunk)
-
-                # If recorder exists, feed audio for transcription
-                if self.diarization_system.recorder:
-                    # Convert to bytes for the recorder's audio buffer
-                    audio_bytes = (chunk * 32768.0).astype(np.int16).tobytes()
-                    if hasattr(self.diarization_system.recorder, '_handle_audio'):
-                        # Send audio to the recorder's audio buffer
-                        self.diarization_system.recorder._handle_audio(audio_bytes)
 
         except Exception as e:
             logger.error(f"Error in FastRTC receive: {e}")
@@ -643,17 +680,18 @@ class DiarizationHandler(AsyncStreamHandler):
             logger.error(f"Error in async audio processing: {e}")
 
     async def start_up(self):
-        """Called when stream starts"""
-        logger.info("FastRTC stream handler started")
+        logger.info("DiarizationHandler started")
 
     async def shutdown(self):
-        """Called when stream ends"""
-        logger.info("FastRTC stream handler shutdown")
+        logger.info("DiarizationHandler shutdown")
 
 
 # Global instances
 diarization_system = RealtimeSpeakerDiarization()
-audio_handler = None
+
+# We'll initialize the stream in initialize_system()
+# For now, just create a placeholder
+stream = None
 
 def initialize_system():
     """Initialize the diarization system"""
@@ -661,14 +699,20 @@ def initialize_system():
     try:
         success = diarization_system.initialize_models()
        if success:
-            # Create a fresh handler that uses our diarization system
+            # Create a DiarizationHandler linked to our system
             handler = DiarizationHandler(diarization_system)
-
             # Update the Stream's handler
-            stream.handler = handler
+            stream = Stream(
+                handler=handler,
+                modality="audio",
+                mode="send-receive",
+                stream_name="audio_stream"  # Match the stream_name in WebRTC component
+            )
+
+            # Mount the stream to the FastAPI app
+            stream.mount(app)
 
-            logger.info("FastRTC handler initialized successfully")
-            return "✅ System initialized successfully! Click 'Start' to begin recording."
+            return "✅ System initialized successfully!"
         else:
             return "❌ Failed to initialize system. Check logs for details."
     except Exception as e:
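For reference, a minimal FastRTC stream mounted on FastAPI, assuming the fastrtc package; `EchoHandler` is illustrative, not the app's handler:

```python
from fastapi import FastAPI
from fastrtc import Stream, AsyncStreamHandler

class EchoHandler(AsyncStreamHandler):
    async def receive(self, frame):
        pass  # inspect or buffer incoming audio here

    async def emit(self):
        return None  # nothing sent back to the client

    def copy(self):
        return EchoHandler()  # FastRTC clones the handler per connection

app = FastAPI()
stream = Stream(handler=EchoHandler(), modality="audio", mode="send-receive")
stream.mount(app)  # exposes the WebRTC signaling routes on the FastAPI app
```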
@@ -685,8 +729,7 @@ def start_recording():
 
 def on_start():
     result = start_recording()
-    # When starting recording, update UI and return WebRTC component with autostart=True
-    return result, gr.update(interactive=False), gr.update(interactive=True), gr.update(autostart=True)
+    return result, gr.update(interactive=False), gr.update(interactive=True)
 
 def stop_recording():
     """Stop recording and transcription"""
@@ -726,15 +769,6 @@ def get_status():
     except Exception as e:
         return f"Error getting status: {str(e)}"
 
-def refresh_conversation():
-    """Get the current conversation and update live transcription status"""
-    has_live = diarization_system.last_transcription != ""
-    status = "🟢 **Live Transcription Status:** Active" if has_live else "🟠 **Live Transcription Status:** Ready (No speech detected)"
-    if not diarization_system.is_running:
-        status = "🔴 **Live Transcription Status:** Not running"
-
-    return get_conversation(), status
-
 # Create Gradio interface
 def create_interface():
     with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
@@ -743,17 +777,12 @@ def create_interface():
 
         with gr.Row():
             with gr.Column(scale=2):
-                # Replace standard Gradio audio with FastRTC WebRTC component
+                # Replace standard Audio with WebRTC component
                 audio_component = WebRTC(
-                    stream=stream,
-                    label="Audio Input (FastRTC)",
-                    show_audio_waveform=True,
-                    autostart=False,
-                )
-
-                # Add live transcription status indicator
-                live_transcription_status = gr.Markdown(
-                    "🔴 **Live Transcription Status:** Waiting to initialize...",
+                    label="Audio Input",
+                    stream_name="audio_stream",
+                    modality="audio",
+                    mode="send-receive"
                 )
 
                 # Conversation display
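A minimal Blocks layout using the WebRTC audio component, assuming the fastrtc package; `stream_name` from the diff is omitted here since its availability depends on the installed version:

```python
import gradio as gr
from fastrtc import WebRTC

with gr.Blocks(title="Real-time Speaker Diarization") as demo:
    audio_component = WebRTC(label="Audio Input", modality="audio", mode="send-receive")
    conversation_output = gr.HTML("<i>Conversation appears here</i>")

demo.launch()
```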
@@ -829,8 +858,7 @@ def create_interface():
 
         def on_start():
             result = start_recording()
-            # When starting recording, update UI and return WebRTC component with autostart=True
-            return result, gr.update(interactive=False), gr.update(interactive=True), gr.update(autostart=True)
+            return result, gr.update(interactive=False), gr.update(interactive=True)
 
         def on_stop():
             result = stop_recording()
@@ -858,7 +886,7 @@ def create_interface():
 
         start_btn.click(
             fn=on_start,
-            outputs=[status_output, start_btn, stop_btn, audio_component]
+            outputs=[status_output, start_btn, stop_btn]
         )
 
         stop_btn.click(
@@ -879,7 +907,7 @@ def create_interface():
 
         # Auto-refresh conversation display every 1 second
         conversation_timer = gr.Timer(1)
-        conversation_timer.tick(refresh_conversation, outputs=[conversation_output, live_transcription_status])
+        conversation_timer.tick(refresh_conversation, outputs=[conversation_output])
 
         # Auto-refresh status every 2 seconds
         status_timer = gr.Timer(2)
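The polling refresh relies on `gr.Timer` (available in Gradio 4.38 and later). A standalone sketch of the pattern, with the conversation source stubbed:

```python
import datetime
import gradio as gr

def get_conversation():
    # Stub: the app returns rendered conversation HTML here.
    return f"<p>last refresh: {datetime.datetime.now():%H:%M:%S}</p>"

with gr.Blocks() as demo:
    conversation_output = gr.HTML()
    conversation_timer = gr.Timer(1)  # ticks every second
    conversation_timer.tick(get_conversation, outputs=[conversation_output])

demo.launch()
```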
@@ -891,29 +919,9 @@ def create_interface():
 
 # FastAPI setup for FastRTC integration
 app = FastAPI()
 
-# Create a placeholder handler - will be properly initialized later
-class DefaultHandler(AsyncStreamHandler):
-    def __init__(self):
-        super().__init__()
-
-    async def receive(self, frame):
-        pass
-
-    async def emit(self):
-        return None
-
-    def copy(self):
-        return DefaultHandler()
-
-    async def shutdown(self):
-        pass
-
-    async def start_up(self):
-        pass
-
-# Initialize with placeholder handler
-stream = Stream(handler=DefaultHandler(), modality="audio", mode="send-receive")
-stream.mount(app)
+# We'll initialize the stream in initialize_system()
+# For now, just create a placeholder
+stream = None
 
 @app.get("/")
 async def root():