Saiyaswanth007 committed
Commit 876be23
1 Parent(s): 534a53d

Check point 4

Files changed (1)
  1. app.py +65 -28

app.py CHANGED
@@ -10,12 +10,13 @@ import torchaudio
 from scipy.spatial.distance import cosine
 from RealtimeSTT import AudioToTextRecorder
 from fastapi import FastAPI, APIRouter
-from fastrtc import Stream, AsyncStreamHandler, WebRTC
+from fastrtc import Stream, AsyncStreamHandler
 import json
 import asyncio
 import uvicorn
 from queue import Queue
 import logging
+from gradio_webrtc import WebRTC
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -330,10 +331,47 @@ class RealtimeSpeakerDiarization:
             logger.error(f"Model initialization error: {e}")
             return False
 
+    def feed_audio(self, audio_data):
+        """Feed audio data directly to the recorder for live transcription"""
+        if not self.is_running or not self.recorder:
+            return
+
+        try:
+            # Normalize if needed
+            if isinstance(audio_data, np.ndarray):
+                if audio_data.dtype != np.float32:
+                    audio_data = audio_data.astype(np.float32)
+
+                # Convert to int16 for the recorder
+                audio_int16 = (audio_data * 32767).astype(np.int16)
+                audio_bytes = audio_int16.tobytes()
+
+                # Feed to recorder
+                self.recorder.feed_audio(audio_bytes)
+
+                # Also process for speaker detection
+                self.process_audio_chunk(audio_data)
+
+            elif isinstance(audio_data, bytes):
+                # Feed raw bytes directly
+                self.recorder.feed_audio(audio_data)
+
+                # Convert to float for speaker detection
+                audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
+                audio_float = audio_int16.astype(np.float32) / 32768.0
+                self.process_audio_chunk(audio_float)
+
+            logger.debug("Audio fed to recorder")
+        except Exception as e:
+            logger.error(f"Error feeding audio: {e}")
+
     def live_text_detected(self, text):
         """Callback for real-time transcription updates"""
         with self.transcription_lock:
             self.last_transcription = text.strip()
+
+        # Update the display immediately on new transcription
+        self.update_conversation_display()
 
     def process_final_text(self, text):
         """Process final transcribed text with speaker embedding"""
@@ -652,7 +690,9 @@ class DiarizationHandler(AsyncStreamHandler):
 # Global instances
 diarization_system = RealtimeSpeakerDiarization()
 
-# We'll initialize the stream properly in initialize_system()
+# We'll initialize the stream in initialize_system()
+# For now, just create a placeholder
+stream = None
 
 def initialize_system():
     """Initialize the diarization system"""
@@ -666,9 +706,12 @@ def initialize_system():
         stream = Stream(
             handler=handler,
             modality="audio",
-            mode="send-receive",
-            stream_name="audio_stream"  # Match the stream_name in WebRTC component
+            mode="send-receive"
         )
+
+        # Mount the stream to the FastAPI app
+        stream.mount(app)
+
        return "✅ System initialized successfully!"
     else:
         return "❌ Failed to initialize system. Check logs for details."
@@ -737,7 +780,6 @@ def create_interface():
         # Replace standard Audio with WebRTC component
         audio_component = WebRTC(
             label="Audio Input",
-            stream_name="audio_stream",
             modality="audio",
             mode="send-receive"
         )
@@ -869,6 +911,21 @@ def create_interface():
         # Auto-refresh status every 2 seconds
         status_timer = gr.Timer(2)
         status_timer.tick(refresh_status, outputs=[status_output])
+
+        # Connect the WebRTC component to our processing function
+        def process_webrtc_audio(audio_data):
+            if audio_data is not None and diarization_system.is_running:
+                try:
+                    # Feed audio to our diarization system
+                    diarization_system.feed_audio(audio_data)
+                except Exception as e:
+                    logger.error(f"Error processing WebRTC audio: {e}")
+            return get_conversation()
+
+        audio_component.stream(
+            fn=process_webrtc_audio,
+            outputs=[conversation_output]
+        )
 
     return interface
 
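A reduced, self-contained sketch of the gradio_webrtc wiring added above: a WebRTC component streams audio chunks into a callback whose return value updates an output component. echo_audio and the Textbox are hypothetical placeholders; the WebRTC(...) and .stream(...) arguments mirror the ones used in this commit.

import gradio as gr
from gradio_webrtc import WebRTC

def echo_audio(audio_data):
    # Placeholder for process_webrtc_audio(): inspect the chunk, return display text
    return "received audio chunk" if audio_data is not None else ""

with gr.Blocks() as demo:
    audio_component = WebRTC(label="Audio Input", modality="audio", mode="send-receive")
    conversation_output = gr.Textbox(label="Conversation")
    # Same streaming hookup as the commit: feed chunks to a callback, update an output
    audio_component.stream(fn=echo_audio, outputs=[conversation_output])

if __name__ == "__main__":
    demo.launch()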
 
@@ -876,29 +933,9 @@ def create_interface():
 # FastAPI setup for FastRTC integration
 app = FastAPI()
 
-# Create a placeholder handler - will be properly initialized later
-class DefaultHandler(AsyncStreamHandler):
-    def __init__(self):
-        super().__init__()
-
-    async def receive(self, frame):
-        pass
-
-    async def emit(self):
-        return None
-
-    def copy(self):
-        return DefaultHandler()
-
-    async def shutdown(self):
-        pass
-
-    async def start_up(self):
-        pass
-
-# Initialize with placeholder handler
-stream = Stream(handler=DefaultHandler(), modality="audio", mode="send-receive")
-stream.mount(app)
+# We'll initialize the stream in initialize_system()
+# For now, just create a placeholder
+stream = None
 
 @app.get("/")
 async def root():
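Both the Gradio and FastAPI halves now start from a module-level stream = None placeholder that initialize_system() is expected to replace. A hypothetical reduction of that pattern, using a global declaration so the assignment inside the function rebinds the module-level name (make_stream() stands in for the Stream construction):

stream = None  # module-level placeholder, as in app.py

def make_stream():
    """Hypothetical stand-in for Stream(handler=..., modality=..., mode=...)."""
    return object()

def initialize_system():
    global stream          # rebind the module-level name rather than a local
    stream = make_stream()
    return "initialized"

initialize_system()
print(stream is not None)  # True once initialization has run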
 