Saiyaswanth007 committed on
Commit
534a53d
·
1 Parent(s): 2b9c901

Check point 4

Browse files
Files changed (1) hide show
  1. app.py +325 -110
app.py CHANGED
@@ -10,7 +10,7 @@ import torchaudio
10
  from scipy.spatial.distance import cosine
11
  from RealtimeSTT import AudioToTextRecorder
12
  from fastapi import FastAPI, APIRouter
13
- from fastrtc import Stream, ReplyOnPause, StreamHandler, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials
14
  import json
15
  import asyncio
16
  import uvicorn
@@ -329,48 +329,11 @@ class RealtimeSpeakerDiarization:
329
  except Exception as e:
330
  logger.error(f"Model initialization error: {e}")
331
  return False
332
-
333
- def feed_audio(self, audio_data):
334
- """Feed audio data directly to the recorder for live transcription"""
335
- if not self.is_running or not self.recorder:
336
- return
337
-
338
- try:
339
- # Normalize if needed
340
- if isinstance(audio_data, np.ndarray):
341
- if audio_data.dtype != np.float32:
342
- audio_data = audio_data.astype(np.float32)
343
-
344
- # Convert to int16 for the recorder
345
- audio_int16 = (audio_data * 32767).astype(np.int16)
346
- audio_bytes = audio_int16.tobytes()
347
-
348
- # Feed to recorder
349
- self.recorder.feed_audio(audio_bytes)
350
-
351
- # Also process for speaker detection
352
- self.process_audio_chunk(audio_data)
353
-
354
- elif isinstance(audio_data, bytes):
355
- # Feed raw bytes directly
356
- self.recorder.feed_audio(audio_data)
357
-
358
- # Convert to float for speaker detection
359
- audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
360
- audio_float = audio_int16.astype(np.float32) / 32768.0
361
- self.process_audio_chunk(audio_float)
362
-
363
- logger.debug("Audio fed to recorder")
364
- except Exception as e:
365
- logger.error(f"Error feeding audio: {e}")
366
 
367
  def live_text_detected(self, text):
368
  """Callback for real-time transcription updates"""
369
  with self.transcription_lock:
370
  self.last_transcription = text.strip()
371
-
372
- # Update the display immediately on new transcription
373
- self.update_conversation_display()
374
 
375
  def process_final_text(self, text):
376
  """Process final transcribed text with speaker embedding"""
@@ -600,47 +563,112 @@ class RealtimeSpeakerDiarization:
600
  logger.error(f"Error processing audio chunk: {e}")
601
 
602
 
603
- # Create diarization handler for FastRTC
604
- class DiarizationAudioHandler(StreamHandler):
605
  def __init__(self, diarization_system):
606
  super().__init__()
607
  self.diarization_system = diarization_system
 
 
608
 
609
- def receive(self, frame):
610
- """Process incoming audio frame"""
611
- if not self.diarization_system.is_running:
612
- return
613
-
 
 
 
 
 
614
  try:
 
 
 
615
  # Extract audio data
616
- sample_rate, audio_array = frame
617
 
618
- # Send audio to diarization system for processing
619
- self.diarization_system.feed_audio(audio_array)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
620
  except Exception as e:
621
- logger.error(f"Error processing FastRTC audio: {e}")
622
-
623
- def copy(self):
624
- """Return a fresh handler instance"""
625
- return DiarizationAudioHandler(self.diarization_system)
626
 
627
- def shutdown(self):
628
- """Clean up resources"""
629
- pass
630
-
631
- def start_up(self):
632
- """Initialize resources"""
633
- logger.info("DiarizationAudioHandler started")
 
 
 
 
 
 
 
 
 
 
 
634
 
635
 
636
- # Global diarization system instance
637
  diarization_system = RealtimeSpeakerDiarization()
638
 
 
 
639
  def initialize_system():
640
  """Initialize the diarization system"""
 
641
  try:
642
  success = diarization_system.initialize_models()
643
  if success:
 
 
 
 
 
 
 
 
 
644
  return "✅ System initialized successfully!"
645
  else:
646
  return "❌ Failed to initialize system. Check logs for details."
@@ -656,6 +684,10 @@ def start_recording():
656
  except Exception as e:
657
  return f"❌ Failed to start recording: {str(e)}"
658
 
 
 
 
 
659
  def stop_recording():
660
  """Stop recording and transcription"""
661
  try:
@@ -694,56 +726,238 @@ def get_status():
694
  except Exception as e:
695
  return f"Error getting status: {str(e)}"
696
 
697
- # Create handler wrapper function for FastRTC
698
- def diarization_handler(audio_data):
699
- """Handler function for FastRTC stream"""
700
- try:
701
- # Process the audio data
702
- diarization_system.process_audio_chunk(audio_data[1], audio_data[0])
703
 
704
- # Just yield the original audio back (echo)
705
- # This can be changed to just return None since we don't need echo
706
- # This can be changed to just return None since we don't need echo
707
- yield audio_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
708
 
709
- except Exception as e:
710
- logger.error(f"Error in diarization handler: {e}")
711
-
712
- # Create FastRTC stream with ReplyOnPause pattern
713
- stream = Stream(
714
- handler=ReplyOnPause(diarization_handler),
715
- modality="audio",
716
- mode="send-receive",
717
- rtc_configuration=get_cloudflare_turn_credentials_async,
718
- server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000),
719
- ui_args={
720
- "title": "Real-time Speaker Diarization",
721
- "description": "Live transcription with automatic speaker identification"
722
- }
723
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724
 
725
  # Main execution
726
  if __name__ == "__main__":
727
  import argparse
728
- import os
729
 
730
  parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System")
731
- parser.add_argument("--mode", choices=["ui", "api", "both"], default="ui",
732
- help="Run mode: FastRTC UI, API only, or both")
733
  parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
734
- parser.add_argument("--port", type=int, default=int(os.environ.get("GRADIO_SERVER_PORT", 7860)),
735
- help="Port to bind to")
736
  parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)")
737
 
738
  args = parser.parse_args()
739
 
740
- # Initialize the system before running anything
741
- initialize_system()
742
- start_recording()
743
-
744
- if args.mode == "ui":
745
- # Launch the FastRTC built-in UI
746
- stream.ui.launch(
747
  server_name=args.host,
748
  server_port=args.port,
749
  share=True,
@@ -752,8 +966,6 @@ if __name__ == "__main__":
752
 
753
  elif args.mode == "api":
754
  # Run FastAPI only
755
- app = FastAPI()
756
- stream.mount(app)
757
  uvicorn.run(
758
  app,
759
  host=args.host,
@@ -762,12 +974,20 @@ if __name__ == "__main__":
762
  )
763
 
764
  elif args.mode == "both":
765
- # Run both FastRTC UI and API
 
766
  import threading
767
 
 
 
 
 
 
 
 
 
 
768
  def run_fastapi():
769
- app = FastAPI()
770
- stream.mount(app)
771
  uvicorn.run(
772
  app,
773
  host=args.host,
@@ -779,10 +999,5 @@ if __name__ == "__main__":
779
  api_thread = threading.Thread(target=run_fastapi, daemon=True)
780
  api_thread.start()
781
 
782
- # Start FastRTC UI in main thread
783
- stream.ui.launch(
784
- server_name=args.host,
785
- server_port=args.port,
786
- share=True,
787
- show_error=True
788
- )
 
10
  from scipy.spatial.distance import cosine
11
  from RealtimeSTT import AudioToTextRecorder
12
  from fastapi import FastAPI, APIRouter
13
+ from fastrtc import Stream, AsyncStreamHandler, WebRTC
14
  import json
15
  import asyncio
16
  import uvicorn
 
329
  except Exception as e:
330
  logger.error(f"Model initialization error: {e}")
331
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
  def live_text_detected(self, text):
334
  """Callback for real-time transcription updates"""
335
  with self.transcription_lock:
336
  self.last_transcription = text.strip()
 
 
 
337
 
338
  def process_final_text(self, text):
339
  """Process final transcribed text with speaker embedding"""
 
563
  logger.error(f"Error processing audio chunk: {e}")
564
 
565
 
566
# FastRTC audio handler
class DiarizationHandler(AsyncStreamHandler):
    """Receive-only FastRTC handler that forwards incoming audio to the diarization system.

    Every received frame is normalised to a flat float32 array in [-1.0, 1.0],
    sent to the system's recorder as 16-bit PCM bytes for live transcription,
    and accumulated in a buffer that is handed to
    ``diarization_system.process_audio_chunk`` in slices of ``buffer_size``
    samples.
    """

    def __init__(self, diarization_system):
        super().__init__()
        self.diarization_system = diarization_system
        # Samples received but not yet passed on for speaker processing.
        self.audio_buffer = np.zeros(0, dtype=np.float32)
        self.buffer_size = BUFFER_SIZE

    def copy(self):
        """Return a fresh handler for each new stream connection."""
        return DiarizationHandler(self.diarization_system)

    async def emit(self):
        """Not used - this handler only receives audio."""
        return None

    @staticmethod
    def _split_frame(frame):
        """Return ``(sample_rate, samples)`` for a frame or a bare audio payload."""
        payload = frame
        # Only unwrap ``.data`` on wrapper objects: a bare ndarray also has a
        # ``.data`` attribute (its raw memory buffer), which is not wanted here.
        if not isinstance(frame, (tuple, list, bytes, bytearray, np.ndarray)):
            payload = getattr(frame, "data", frame)

        if isinstance(payload, tuple) and len(payload) >= 2:
            # Index instead of unpacking so extra tuple items cannot raise.
            return payload[0], payload[1]
        return SAMPLE_RATE, payload

    @staticmethod
    def _to_float_mono(samples):
        """Convert bytes, sequences or arrays to a flat float32 array.

        Raw bytes and int16 arrays are treated as 16-bit PCM and scaled to
        [-1.0, 1.0]; any other input is cast to float32 unchanged.
        """
        if isinstance(samples, (bytes, bytearray)):
            pcm = np.frombuffer(samples, dtype=np.int16)
            return pcm.astype(np.float32) / 32768.0

        array = np.asarray(samples)
        if array.dtype == np.int16:
            # presumably the dtype FastRTC delivers - TODO confirm
            array = array.astype(np.float32) / 32768.0
        else:
            array = array.astype(np.float32)
        return array.reshape(-1)

    def _feed_recorder(self, audio_array):
        """Send one float32 frame to the recorder as int16 PCM bytes."""
        recorder = self.diarization_system.recorder
        if not recorder:
            return
        try:
            # Clip before scaling so out-of-range samples cannot wrap around
            # when cast to int16.
            pcm16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
            # NOTE(review): the bytes carry no sample-rate information; confirm
            # that frames arrive at the rate the recorder is configured for.
            recorder.feed_audio(pcm16.tobytes())
            logger.debug("Fed audio to recorder")
        except Exception as e:
            logger.error(f"Error feeding audio to recorder: {e}")

    async def receive(self, frame):
        """Receive one audio frame from FastRTC and route it to the system.

        Frames are ignored while the diarization system is not running.
        Errors are logged and swallowed so a bad frame does not end the stream.
        """
        try:
            if not self.diarization_system.is_running:
                return

            sample_rate, samples = self._split_frame(frame)
            audio_array = self._to_float_mono(samples)
            if audio_array.size == 0:
                return

            # Live transcription path.
            self._feed_recorder(audio_array)

            # Speaker-processing path: accumulate, then hand off fixed-size chunks.
            self.audio_buffer = np.concatenate((self.audio_buffer, audio_array))
            while self.audio_buffer.size >= self.buffer_size:
                chunk = self.audio_buffer[:self.buffer_size].copy()
                self.audio_buffer = self.audio_buffer[self.buffer_size:]
                await self.process_audio_async(chunk, sample_rate)

        except Exception as e:
            logger.error(f"Error in FastRTC receive: {e}")

    async def process_audio_async(self, audio_data, sample_rate=None):
        """Run ``process_audio_chunk`` in the default executor.

        Args:
            audio_data: Chunk of samples to process.
            sample_rate: Sample rate of ``audio_data``; falls back to
                ``SAMPLE_RATE`` when omitted.
        """
        try:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                self.diarization_system.process_audio_chunk,
                audio_data,
                SAMPLE_RATE if sample_rate is None else sample_rate,
            )
        except Exception as e:
            logger.error(f"Error in async audio processing: {e}")

    async def start_up(self):
        """Log that the handler has started."""
        logger.info("DiarizationHandler started")

    async def shutdown(self):
        """Log that the handler has shut down."""
        logger.info("DiarizationHandler shutdown")
650
 
651
 
652
+ # Global instances
653
  diarization_system = RealtimeSpeakerDiarization()
654
 
655
+ # We'll initialize the stream properly in initialize_system()
656
+
657
  def initialize_system():
658
  """Initialize the diarization system"""
659
+ global stream
660
  try:
661
  success = diarization_system.initialize_models()
662
  if success:
663
+ # Create a DiarizationHandler linked to our system
664
+ handler = DiarizationHandler(diarization_system)
665
+ # Update the Stream's handler
666
+ stream = Stream(
667
+ handler=handler,
668
+ modality="audio",
669
+ mode="send-receive",
670
+ stream_name="audio_stream" # Match the stream_name in WebRTC component
671
+ )
672
  return "✅ System initialized successfully!"
673
  else:
674
  return "❌ Failed to initialize system. Check logs for details."
 
684
  except Exception as e:
685
  return f"❌ Failed to start recording: {str(e)}"
686
 
687
def on_start():
    """Start recording and return the status text plus two button updates.

    The tuple is ``(status message, update disabling a button, update
    enabling a button)``, the same shape that the identically named handler
    inside ``create_interface`` returns for
    ``[status_output, start_btn, stop_btn]``.
    """
    status_message = start_recording()
    disable_update = gr.update(interactive=False)
    enable_update = gr.update(interactive=True)
    return status_message, disable_update, enable_update
690
+
691
  def stop_recording():
692
  """Stop recording and transcription"""
693
  try:
 
726
  except Exception as e:
727
  return f"Error getting status: {str(e)}"
728
 
729
# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI for the diarization system.

    The layout has a wide column (audio input, conversation view, control
    buttons, status box) and a narrow column (settings sliders and
    instructions). Button clicks call the module-level control functions, and
    two timers refresh the conversation view and the status box.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
        gr.Markdown("Live transcription with automatic speaker identification using FastRTC audio streaming.")

        with gr.Row():
            with gr.Column(scale=2):
                # WebRTC component used in place of the standard Audio input.
                # NOTE(review): no `.stream(...)` event is registered on this
                # component anywhere in this function, and `stream_name` should
                # be checked against the installed fastrtc WebRTC signature.
                audio_component = WebRTC(
                    label="Audio Input",
                    stream_name="audio_stream",
                    modality="audio",
                    mode="send-receive"
                )

                # Conversation display
                conversation_output = gr.HTML(
                    value="<div style='padding: 20px; background: #f8f9fa; border-radius: 10px; min-height: 300px;'><i>Click 'Initialize System' to start...</i></div>",
                    label="Live Conversation"
                )

                # Control buttons: everything except Initialize starts disabled.
                with gr.Row():
                    init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
                    start_btn = gr.Button("🎙️ Start", variant="primary", size="lg", interactive=False)
                    stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", interactive=False)
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)

                # Status display
                status_output = gr.Textbox(
                    label="System Status",
                    value="Ready to initialize...",
                    lines=8,
                    interactive=False
                )

            with gr.Column(scale=1):
                # Settings
                gr.Markdown("## ⚙️ Settings")

                threshold_slider = gr.Slider(
                    minimum=0.3,
                    maximum=0.9,
                    step=0.05,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    label="Speaker Change Sensitivity",
                    info="Lower = more sensitive"
                )

                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    step=1,
                    value=DEFAULT_MAX_SPEAKERS,
                    label="Maximum Speakers"
                )

                update_btn = gr.Button("Update Settings", variant="secondary")

                # Instructions
                gr.Markdown("""
                ## 📋 Instructions
                1. **Initialize** the system (loads AI models)
                2. **Start** recording
                3. **Speak** - system will transcribe and identify speakers
                4. **Monitor** real-time results below

                ## 🎨 Speaker Colors
                - 🔴 Speaker 1 (Red)
                - 🟢 Speaker 2 (Teal)
                - 🔵 Speaker 3 (Blue)
                - 🟡 Speaker 4 (Green)
                - 🟣 Speaker 5 (Yellow)
                - 🟤 Speaker 6 (Plum)
                - 🟫 Speaker 7 (Mint)
                - 🟨 Speaker 8 (Gold)
                """)

        # Event handlers
        def on_initialize():
            """Initialise the system and set the Start/Stop/Clear button states."""
            result = initialize_system()
            ready = "✅" in result
            recording = bool(diarization_system.is_running)
            # Stop is only useful while a recording is running, matching the
            # Start/Stop toggling done by on_start and on_stop.
            return (
                result,
                gr.update(interactive=ready and not recording),
                gr.update(interactive=ready and recording),
                gr.update(interactive=ready),
            )

        def on_start():
            """Start recording; disable Start and enable Stop."""
            result = start_recording()
            return result, gr.update(interactive=False), gr.update(interactive=True)

        def on_stop():
            """Stop recording; enable Start and disable Stop."""
            result = stop_recording()
            return result, gr.update(interactive=True), gr.update(interactive=False)

        def on_clear():
            """Clear the conversation and return the status message."""
            return clear_conversation()

        def on_update_settings(threshold, max_speakers):
            """Apply the slider values; the speaker count is cast to int."""
            return update_settings(threshold, int(max_speakers))

        def refresh_conversation():
            """Return the current conversation for the HTML view."""
            return get_conversation()

        def refresh_status():
            """Return the current status text."""
            return get_status()

        # Button click handlers
        init_btn.click(
            fn=on_initialize,
            outputs=[status_output, start_btn, stop_btn, clear_btn]
        )

        start_btn.click(
            fn=on_start,
            outputs=[status_output, start_btn, stop_btn]
        )

        stop_btn.click(
            fn=on_stop,
            outputs=[status_output, start_btn, stop_btn]
        )

        clear_btn.click(
            fn=on_clear,
            outputs=[status_output]
        )

        update_btn.click(
            fn=on_update_settings,
            inputs=[threshold_slider, max_speakers_slider],
            outputs=[status_output]
        )

        # Auto-refresh conversation display every 1 second
        conversation_timer = gr.Timer(1)
        conversation_timer.tick(refresh_conversation, outputs=[conversation_output])

        # Auto-refresh status every 2 seconds. This writes to the same textbox
        # as the button handlers, so their messages are replaced on the next tick.
        status_timer = gr.Timer(2)
        status_timer.tick(refresh_status, outputs=[status_output])

    return interface
874
+
875
+
876
+ # FastAPI setup for FastRTC integration
877
+ app = FastAPI()
878
+
879
# Placeholder handler used to build the initial Stream below.
class DefaultHandler(AsyncStreamHandler):
    """Do-nothing stream handler: ignores received frames and emits nothing."""

    def __init__(self):
        super().__init__()

    def copy(self):
        """Return a new, independent placeholder handler."""
        return DefaultHandler()

    async def start_up(self):
        """No resources to set up."""
        pass

    async def shutdown(self):
        """No resources to release."""
        pass

    async def receive(self, frame):
        """Discard the incoming frame."""
        pass

    async def emit(self):
        """Produce no output."""
        return None
898
+
899
# Build the stream around the placeholder handler and mount it on the app.
# NOTE(review): initialize_system() later rebinds the module-level `stream`
# name to a new Stream. Nothing visible here mounts that new instance, so the
# one mounted on `app` remains this placeholder - confirm this is intended.
stream = Stream(
    handler=DefaultHandler(),
    modality="audio",
    mode="send-receive",
)
stream.mount(app)
902
+
903
@app.get("/")
async def root():
    """Return a static message identifying the API."""
    banner = {"message": "Real-time Speaker Diarization API"}
    return banner
906
+
907
@app.get("/health")
async def health_check():
    """Report a fixed "healthy" status plus whether the diarization system is running."""
    running = diarization_system.is_running
    return {"status": "healthy", "system_running": running}
910
+
911
@app.post("/initialize")
async def api_initialize():
    """Initialise the diarization system and report the outcome.

    ``initialize_system`` is a synchronous call, so it is run in the default
    executor rather than on the event loop; other requests stay responsive
    while it runs.

    Returns:
        ``{"result": <status message>, "success": <bool>}``, where success
        means the status message contains the "✅" marker.
    """
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, initialize_system)
    return {"result": result, "success": "✅" in result}
915
+
916
@app.post("/start")
async def api_start():
    """Start recording; success means the status message contains "🎙️"."""
    outcome = start_recording()
    started = "🎙️" in outcome
    return {"result": outcome, "success": started}
920
+
921
@app.post("/stop")
async def api_stop():
    """Stop recording; success means the status message contains "⏹️"."""
    outcome = stop_recording()
    stopped = "⏹️" in outcome
    return {"result": outcome, "success": stopped}
925
+
926
@app.post("/clear")
async def api_clear():
    """Clear the conversation and return the resulting status message."""
    return {"result": clear_conversation()}
930
+
931
@app.get("/conversation")
async def api_get_conversation():
    """Return the current conversation as produced by ``get_conversation``."""
    current = get_conversation()
    return {"conversation": current}
934
+
935
@app.get("/status")
async def api_get_status():
    """Return the current status text as produced by ``get_status``."""
    current = get_status()
    return {"status": current}
938
+
939
@app.post("/settings")
async def api_update_settings(threshold: float, max_speakers: int):
    """Apply new diarization settings and return the status message.

    Args:
        threshold: Passed to ``update_settings`` as the first argument.
        max_speakers: Passed to ``update_settings`` as the second argument.
    """
    return {"result": update_settings(threshold, max_speakers)}
943
 
944
  # Main execution
945
  if __name__ == "__main__":
946
  import argparse
 
947
 
948
  parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System")
949
+ parser.add_argument("--mode", choices=["gradio", "api", "both"], default="gradio",
950
+ help="Run mode: gradio interface, API only, or both")
951
  parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
952
+ parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
 
953
  parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)")
954
 
955
  args = parser.parse_args()
956
 
957
+ if args.mode == "gradio":
958
+ # Run Gradio interface only
959
+ interface = create_interface()
960
+ interface.launch(
 
 
 
961
  server_name=args.host,
962
  server_port=args.port,
963
  share=True,
 
966
 
967
  elif args.mode == "api":
968
  # Run FastAPI only
 
 
969
  uvicorn.run(
970
  app,
971
  host=args.host,
 
974
  )
975
 
976
  elif args.mode == "both":
977
+ # Run both Gradio and FastAPI
978
+ import multiprocessing
979
  import threading
980
 
981
+ def run_gradio():
982
+ interface = create_interface()
983
+ interface.launch(
984
+ server_name=args.host,
985
+ server_port=args.port,
986
+ share=True,
987
+ show_error=True
988
+ )
989
+
990
  def run_fastapi():
 
 
991
  uvicorn.run(
992
  app,
993
  host=args.host,
 
999
  api_thread = threading.Thread(target=run_fastapi, daemon=True)
1000
  api_thread.start()
1001
 
1002
+ # Start Gradio in main thread
1003
+ run_gradio()