Saiyaswanth007 committed on
Commit
af84a93
·
1 Parent(s): 98185ad

Check point 4

Browse files
Files changed (1) hide show
  1. app.py +52 -88
app.py CHANGED
@@ -16,7 +16,7 @@ import asyncio
16
  import uvicorn
17
  from queue import Queue
18
  import logging
19
- from gradio_webrtc import WebRTC
20
 
21
  # Set up logging
22
  logging.basicConfig(level=logging.INFO)
@@ -331,47 +331,10 @@ class RealtimeSpeakerDiarization:
331
  logger.error(f"Model initialization error: {e}")
332
  return False
333
 
334
- def feed_audio(self, audio_data):
335
- """Feed audio data directly to the recorder for live transcription"""
336
- if not self.is_running or not self.recorder:
337
- return
338
-
339
- try:
340
- # Normalize if needed
341
- if isinstance(audio_data, np.ndarray):
342
- if audio_data.dtype != np.float32:
343
- audio_data = audio_data.astype(np.float32)
344
-
345
- # Convert to int16 for the recorder
346
- audio_int16 = (audio_data * 32767).astype(np.int16)
347
- audio_bytes = audio_int16.tobytes()
348
-
349
- # Feed to recorder
350
- self.recorder.feed_audio(audio_bytes)
351
-
352
- # Also process for speaker detection
353
- self.process_audio_chunk(audio_data)
354
-
355
- elif isinstance(audio_data, bytes):
356
- # Feed raw bytes directly
357
- self.recorder.feed_audio(audio_data)
358
-
359
- # Convert to float for speaker detection
360
- audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
361
- audio_float = audio_int16.astype(np.float32) / 32768.0
362
- self.process_audio_chunk(audio_float)
363
-
364
- logger.debug("Audio fed to recorder")
365
- except Exception as e:
366
- logger.error(f"Error feeding audio: {e}")
367
-
368
  def live_text_detected(self, text):
369
  """Callback for real-time transcription updates"""
370
  with self.transcription_lock:
371
  self.last_transcription = text.strip()
372
-
373
- # Update the display immediately on new transcription
374
- self.update_conversation_display()
375
 
376
  def process_final_text(self, text):
377
  """Process final transcribed text with speaker embedding"""
@@ -626,33 +589,18 @@ class DiarizationHandler(AsyncStreamHandler):
626
  # Extract audio data
627
  audio_data = getattr(frame, 'data', frame)
628
 
629
- # Check if this is a tuple (sample_rate, audio_array)
630
- if isinstance(audio_data, tuple) and len(audio_data) >= 2:
631
- sample_rate, audio_array = audio_data
 
 
632
  else:
633
- # If not a tuple, assume it's raw audio bytes/array
634
- sample_rate = SAMPLE_RATE # Use default sample rate
635
-
636
- # Convert to numpy array
637
- if isinstance(audio_data, bytes):
638
- audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
639
- elif isinstance(audio_data, (list, tuple)):
640
- audio_array = np.array(audio_data, dtype=np.float32)
641
- else:
642
- audio_array = np.array(audio_data, dtype=np.float32)
643
 
644
  # Ensure 1D
645
  if len(audio_array.shape) > 1:
646
  audio_array = audio_array.flatten()
647
 
648
- # Send audio to recorder for live transcription
649
- if self.diarization_system.recorder:
650
- try:
651
- self.diarization_system.recorder.feed_audio(audio_array)
652
- logger.info("Fed audio to recorder")
653
- except Exception as e:
654
- logger.error(f"Error feeding audio to recorder: {e}")
655
-
656
  # Buffer audio chunks
657
  self.audio_buffer.extend(audio_array)
658
 
@@ -679,20 +627,11 @@ class DiarizationHandler(AsyncStreamHandler):
679
  )
680
  except Exception as e:
681
  logger.error(f"Error in async audio processing: {e}")
682
-
683
- async def start_up(self):
684
- logger.info("DiarizationHandler started")
685
-
686
- async def shutdown(self):
687
- logger.info("DiarizationHandler shutdown")
688
 
689
 
690
  # Global instances
691
  diarization_system = RealtimeSpeakerDiarization()
692
-
693
- # We'll initialize the stream in initialize_system()
694
- # For now, just create a placeholder
695
- stream = None
696
 
697
  def initialize_system():
698
  """Initialize the diarization system"""
@@ -700,18 +639,8 @@ def initialize_system():
700
  try:
701
  success = diarization_system.initialize_models()
702
  if success:
703
- # Create a DiarizationHandler linked to our system
704
- handler = DiarizationHandler(diarization_system)
705
- # Update the Stream's handler
706
- stream = Stream(
707
- handler=handler,
708
- modality="audio",
709
- mode="send-receive"
710
- )
711
-
712
- # Mount the stream to the FastAPI app
713
- stream.mount(app)
714
-
715
  return "✅ System initialized successfully!"
716
  else:
717
  return "❌ Failed to initialize system. Check logs for details."
@@ -777,11 +706,11 @@ def create_interface():
777
 
778
  with gr.Row():
779
  with gr.Column(scale=2):
780
- # Replace standard Audio with WebRTC component
781
- audio_component = WebRTC(
782
  label="Audio Input",
783
- modality="audio",
784
- mode="send-receive"
785
  )
786
 
787
  # Conversation display
@@ -912,15 +841,50 @@ def create_interface():
912
  status_timer = gr.Timer(2)
913
  status_timer.tick(refresh_status, outputs=[status_output])
914
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
915
  return interface
916
 
917
 
918
  # FastAPI setup for FastRTC integration
919
  app = FastAPI()
920
 
921
- # We'll initialize the stream in initialize_system()
922
- # For now, just create a placeholder
923
- stream = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
924
 
925
  @app.get("/")
926
  async def root():
 
16
  import uvicorn
17
  from queue import Queue
18
  import logging
19
+ from fastrtc import WebRTC
20
 
21
  # Set up logging
22
  logging.basicConfig(level=logging.INFO)
 
331
  logger.error(f"Model initialization error: {e}")
332
  return False
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  def live_text_detected(self, text):
335
  """Callback for real-time transcription updates"""
336
  with self.transcription_lock:
337
  self.last_transcription = text.strip()
 
 
 
338
 
339
  def process_final_text(self, text):
340
  """Process final transcribed text with speaker embedding"""
 
589
  # Extract audio data
590
  audio_data = getattr(frame, 'data', frame)
591
 
592
+ # Convert to numpy array
593
+ if isinstance(audio_data, bytes):
594
+ audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
595
+ elif isinstance(audio_data, (list, tuple)):
596
+ audio_array = np.array(audio_data, dtype=np.float32)
597
  else:
598
+ audio_array = np.array(audio_data, dtype=np.float32)
 
 
 
 
 
 
 
 
 
599
 
600
  # Ensure 1D
601
  if len(audio_array.shape) > 1:
602
  audio_array = audio_array.flatten()
603
 
 
 
 
 
 
 
 
 
604
  # Buffer audio chunks
605
  self.audio_buffer.extend(audio_array)
606
 
 
627
  )
628
  except Exception as e:
629
  logger.error(f"Error in async audio processing: {e}")
 
 
 
 
 
 
630
 
631
 
632
  # Global instances
633
  diarization_system = RealtimeSpeakerDiarization()
634
+ audio_handler = None
 
 
 
635
 
636
  def initialize_system():
637
  """Initialize the diarization system"""
 
639
  try:
640
  success = diarization_system.initialize_models()
641
  if success:
642
+ # Update the Stream's handler to use our DiarizationHandler
643
+ stream.handler = DiarizationHandler(diarization_system)
 
 
 
 
 
 
 
 
 
 
644
  return "✅ System initialized successfully!"
645
  else:
646
  return "❌ Failed to initialize system. Check logs for details."
 
706
 
707
  with gr.Row():
708
  with gr.Column(scale=2):
709
+ # Replace WebRTC with standard Gradio audio component
710
+ audio_component = gr.Audio(
711
  label="Audio Input",
712
+ sources=["microphone"],
713
+ streaming=True
714
  )
715
 
716
  # Conversation display
 
841
  status_timer = gr.Timer(2)
842
  status_timer.tick(refresh_status, outputs=[status_output])
843
 
844
+ # Process audio from Gradio component
845
+ def process_audio_input(audio_data):
846
+ if audio_data is not None and diarization_system.is_running:
847
+ # Extract audio data
848
+ if isinstance(audio_data, tuple) and len(audio_data) >= 2:
849
+ sample_rate, audio_array = audio_data[0], audio_data[1]
850
+ diarization_system.process_audio_chunk(audio_array, sample_rate)
851
+ return get_conversation()
852
+
853
+ # Connect audio component to processing function
854
+ audio_component.stream(
855
+ fn=process_audio_input,
856
+ outputs=[conversation_output]
857
+ )
858
+
859
  return interface
860
 
861
 
862
  # FastAPI setup for FastRTC integration
863
  app = FastAPI()
864
 
865
+ # Create a placeholder handler - will be properly initialized later
866
+ class DefaultHandler(AsyncStreamHandler):
867
+ def __init__(self):
868
+ super().__init__()
869
+
870
+ async def receive(self, frame):
871
+ pass
872
+
873
+ async def emit(self):
874
+ return None
875
+
876
+ def copy(self):
877
+ return DefaultHandler()
878
+
879
+ async def shutdown(self):
880
+ pass
881
+
882
+ async def start_up(self):
883
+ pass
884
+
885
+ # Initialize with placeholder handler
886
+ stream = Stream(handler=DefaultHandler(), modality="audio", mode="send-receive")
887
+ stream.mount(app)
888
 
889
  @app.get("/")
890
  async def root():