Saiyaswanth007 committed
Commit 89ba8a1 · 1 Parent(s): cba237d

Check point 4

Files changed (1): app.py (+63 -34)
app.py CHANGED
@@ -10,13 +10,12 @@ import torchaudio
 from scipy.spatial.distance import cosine
 from RealtimeSTT import AudioToTextRecorder
 from fastapi import FastAPI, APIRouter
-from fastrtc import Stream, AsyncStreamHandler
+from fastrtc import Stream, AsyncStreamHandler, WebRTC
 import json
 import asyncio
 import uvicorn
 from queue import Queue
 import logging
-from fastrtc import WebRTC
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -420,7 +419,7 @@ class RealtimeSpeakerDiarization:
         # Setup recorder configuration
         recorder_config = {
             'spinner': False,
-            'use_microphone': False,  # Change to False for Hugging Face Spaces
+            'use_microphone': False,  # Must be False since we're using FastRTC
             'model': FINAL_TRANSCRIPTION_MODEL,
             'language': TRANSCRIPTION_LANGUAGE,
             'silero_sensitivity': SILERO_SENSITIVITY,
@@ -430,7 +429,7 @@ class RealtimeSpeakerDiarization:
             'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
             'min_gap_between_recordings': 0,
             'enable_realtime_transcription': True,
-            'realtime_processing_pause': 0.1,
+            'realtime_processing_pause': 0.05,  # Faster updates for live transcription
             'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
             'on_realtime_transcription_update': self.live_text_detected,
             'beam_size': FINAL_BEAM_SIZE,
@@ -448,7 +447,8 @@ class RealtimeSpeakerDiarization:
             self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
             self.transcription_thread.start()
 
-            return "Recording started successfully!"
+            logger.info("Recording started with FastRTC integration")
+            return "Recording started successfully! Speak now..."
 
         except Exception as e:
             logger.error(f"Error starting recording: {e}")
@@ -587,11 +587,17 @@ class DiarizationHandler(AsyncStreamHandler):
                 return
 
            # Extract audio data
-            audio_data = getattr(frame, 'data', frame)
+            if hasattr(frame, 'data'):
+                audio_data = frame.data
+            else:
+                audio_data = frame
 
             # Convert to numpy array
             if isinstance(audio_data, bytes):
                 audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
+            elif isinstance(audio_data, tuple) and len(audio_data) >= 2:
+                sample_rate, data = audio_data
+                audio_array = np.array(data, dtype=np.float32)
             elif isinstance(audio_data, (list, tuple)):
                 audio_array = np.array(audio_data, dtype=np.float32)
             else:
@@ -609,8 +615,16 @@ class DiarizationHandler(AsyncStreamHandler):
                 chunk = np.array(self.audio_buffer[:self.buffer_size])
                 self.audio_buffer = self.audio_buffer[self.buffer_size:]
 
-                # Process asynchronously
+                # Process both for speaker detection and feed to the recorder for transcription
                 await self.process_audio_async(chunk)
+
+                # If recorder exists, feed audio for transcription
+                if self.diarization_system.recorder:
+                    # Convert to bytes for the recorder's audio buffer
+                    audio_bytes = (chunk * 32768.0).astype(np.int16).tobytes()
+                    if hasattr(self.diarization_system.recorder, '_handle_audio'):
+                        # Send audio to the recorder's audio buffer
+                        self.diarization_system.recorder._handle_audio(audio_bytes)
 
         except Exception as e:
             logger.error(f"Error in FastRTC receive: {e}")
@@ -627,6 +641,14 @@ class DiarizationHandler(AsyncStreamHandler):
             )
         except Exception as e:
             logger.error(f"Error in async audio processing: {e}")
+
+    async def start_up(self):
+        """Called when stream starts"""
+        logger.info("FastRTC stream handler started")
+
+    async def shutdown(self):
+        """Called when stream ends"""
+        logger.info("FastRTC stream handler shutdown")
 
 
 # Global instances
@@ -639,9 +661,14 @@ def initialize_system():
     try:
         success = diarization_system.initialize_models()
         if success:
-            # Update the Stream's handler to use our DiarizationHandler
-            stream.handler = DiarizationHandler(diarization_system)
-            return "✅ System initialized successfully!"
+            # Create a fresh handler that uses our diarization system
+            handler = DiarizationHandler(diarization_system)
+
+            # Update the Stream's handler
+            stream.handler = handler
+
+            logger.info("FastRTC handler initialized successfully")
+            return "✅ System initialized successfully! Click 'Start' to begin recording."
         else:
            return "❌ Failed to initialize system. Check logs for details."
    except Exception as e:
@@ -658,7 +685,8 @@ def start_recording():
 
 def on_start():
     result = start_recording()
-    return result, gr.update(interactive=False), gr.update(interactive=True)
+    # When starting recording, update UI and return WebRTC component with autostart=True
+    return result, gr.update(interactive=False), gr.update(interactive=True), gr.update(autostart=True)
 
 def stop_recording():
     """Stop recording and transcription"""
@@ -698,6 +726,15 @@ def get_status():
     except Exception as e:
         return f"Error getting status: {str(e)}"
 
+def refresh_conversation():
+    """Get the current conversation and update live transcription status"""
+    has_live = diarization_system.last_transcription != ""
+    status = "🟢 **Live Transcription Status:** Active" if has_live else "🟠 **Live Transcription Status:** Ready (No speech detected)"
+    if not diarization_system.is_running:
+        status = "🔴 **Live Transcription Status:** Not running"
+
+    return get_conversation(), status
+
 # Create Gradio interface
 def create_interface():
     with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
@@ -706,11 +743,17 @@ def create_interface():
 
         with gr.Row():
             with gr.Column(scale=2):
-                # Replace WebRTC with standard Gradio audio component
-                audio_component = gr.Audio(
-                    label="Audio Input",
-                    sources=["microphone"],
-                    streaming=True
+                # Replace standard Gradio audio with FastRTC WebRTC component
+                audio_component = WebRTC(
+                    stream=stream,
+                    label="Audio Input (FastRTC)",
+                    show_audio_waveform=True,
+                    autostart=False,
+                )
+
+                # Add live transcription status indicator
+                live_transcription_status = gr.Markdown(
+                    "🔴 **Live Transcription Status:** Waiting to initialize...",
                 )
 
                 # Conversation display
@@ -786,7 +829,8 @@ def create_interface():
 
         def on_start():
             result = start_recording()
-            return result, gr.update(interactive=False), gr.update(interactive=True)
+            # When starting recording, update UI and return WebRTC component with autostart=True
+            return result, gr.update(interactive=False), gr.update(interactive=True), gr.update(autostart=True)
 
         def on_stop():
             result = stop_recording()
@@ -814,7 +858,7 @@ def create_interface():
 
         start_btn.click(
             fn=on_start,
-            outputs=[status_output, start_btn, stop_btn]
+            outputs=[status_output, start_btn, stop_btn, audio_component]
         )
 
         stop_btn.click(
@@ -835,27 +879,12 @@ def create_interface():
 
         # Auto-refresh conversation display every 1 second
         conversation_timer = gr.Timer(1)
-        conversation_timer.tick(refresh_conversation, outputs=[conversation_output])
+        conversation_timer.tick(refresh_conversation, outputs=[conversation_output, live_transcription_status])
 
         # Auto-refresh status every 2 seconds
        status_timer = gr.Timer(2)
        status_timer.tick(refresh_status, outputs=[status_output])
 
-        # Process audio from Gradio component
-        def process_audio_input(audio_data):
-            if audio_data is not None and diarization_system.is_running:
-                # Extract audio data
-                if isinstance(audio_data, tuple) and len(audio_data) >= 2:
-                    sample_rate, audio_array = audio_data[0], audio_data[1]
-                    diarization_system.process_audio_chunk(audio_array, sample_rate)
-            return get_conversation()
-
-        # Connect audio component to processing function
-        audio_component.stream(
-            fn=process_audio_input,
-            outputs=[conversation_output]
-        )
-
         return interface
 
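A note on the audio path in this commit: receive() decodes incoming 16-bit PCM bytes into normalized float32 samples, and the new recorder feed encodes chunks back to int16 bytes. Below is a minimal, self-contained sketch of that round-trip. The helper names are hypothetical, not from app.py, and the encoder here scales by 32767 with clipping, a common safeguard; the commit's own (chunk * 32768.0).astype(np.int16) can wrap to -32768 for samples at exactly 1.0.

import numpy as np

# Sketch of the PCM round-trip in DiarizationHandler.receive():
# WebRTC frames arrive as int16 PCM bytes, the diarization pipeline wants
# float32 samples in [-1.0, 1.0], and the RealtimeSTT feed wants int16
# bytes back. Helper names below are hypothetical.

def bytes_to_float32(audio_bytes: bytes) -> np.ndarray:
    """int16 PCM bytes -> normalized float32 samples (mirrors receive())."""
    return np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0

def float32_to_bytes(samples: np.ndarray) -> bytes:
    """Normalized float32 samples -> int16 PCM bytes, clipped so a sample of
    exactly 1.0 cannot wrap around the int16 range."""
    clipped = np.clip(samples, -1.0, 1.0)
    return (clipped * 32767.0).astype(np.int16).tobytes()

# Round-trip check: a 10 ms, 440 Hz tone at 16 kHz survives to within ~1 LSB.
t = np.arange(160) / 16000.0
tone = (0.5 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
restored = bytes_to_float32(float32_to_bytes(tone))
assert np.allclose(tone, restored, atol=1e-4)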
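The handler also only hands audio to process_audio_async() once a full chunk has accumulated. Here is a framework-agnostic sketch of that chunking pattern, with no fastrtc dependency; the ChunkBuffer class and the tiny 4-sample buffer_size are illustrative stand-ins for the handler's internal buffer, which in the app holds on the order of a second of audio.

import numpy as np

class ChunkBuffer:
    """Illustrative stand-in for the buffering in DiarizationHandler.receive():
    accumulate incoming samples, emit fixed-size chunks, keep the remainder."""

    def __init__(self, buffer_size: int):
        self.buffer_size = buffer_size
        self.audio_buffer = []

    def add(self, samples: np.ndarray) -> list:
        """Append samples; return every complete chunk now available."""
        self.audio_buffer.extend(samples.tolist())
        chunks = []
        while len(self.audio_buffer) >= self.buffer_size:
            chunks.append(np.array(self.audio_buffer[:self.buffer_size], dtype=np.float32))
            self.audio_buffer = self.audio_buffer[self.buffer_size:]
        return chunks

# Feed 3-sample frames into 4-sample chunks: a chunk is emitted only once
# enough audio has accumulated, mirroring when process_audio_async(chunk)
# fires in the handler.
buf = ChunkBuffer(buffer_size=4)
for frame in (np.zeros(3, np.float32), np.ones(3, np.float32), np.ones(3, np.float32)):
    for chunk in buf.add(frame):
        print("chunk ready:", chunk)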