Saiyaswanth007 committed on
Commit
a5c083c
·
1 Parent(s): a3ec320

Check point 4

Browse files
Files changed (1) hide show
  1. app.py +69 -325
app.py CHANGED
@@ -10,13 +10,12 @@ import torchaudio
10
  from scipy.spatial.distance import cosine
11
  from RealtimeSTT import AudioToTextRecorder
12
  from fastapi import FastAPI, APIRouter
13
- from fastrtc import Stream, AsyncStreamHandler
14
  import json
15
  import asyncio
16
  import uvicorn
17
  from queue import Queue
18
  import logging
19
- from gradio_webrtc import WebRTC
20
 
21
  # Set up logging
22
  logging.basicConfig(level=logging.INFO)
@@ -330,7 +329,7 @@ class RealtimeSpeakerDiarization:
330
  except Exception as e:
331
  logger.error(f"Model initialization error: {e}")
332
  return False
333
-
334
  def feed_audio(self, audio_data):
335
  """Feed audio data directly to the recorder for live transcription"""
336
  if not self.is_running or not self.recorder:
@@ -601,117 +600,47 @@ class RealtimeSpeakerDiarization:
601
  logger.error(f"Error processing audio chunk: {e}")
602
 
603
 
604
- # FastRTC Audio Handler
605
- class DiarizationHandler(AsyncStreamHandler):
606
  def __init__(self, diarization_system):
607
  super().__init__()
608
  self.diarization_system = diarization_system
609
- self.audio_buffer = []
610
- self.buffer_size = BUFFER_SIZE
611
 
612
- def copy(self):
613
- """Return a fresh handler for each new stream connection"""
614
- return DiarizationHandler(self.diarization_system)
615
-
616
- async def emit(self):
617
- """Not used - we only receive audio"""
618
- return None
619
-
620
- async def receive(self, frame):
621
- """Receive audio data from FastRTC"""
622
  try:
623
- if not self.diarization_system.is_running:
624
- return
625
-
626
  # Extract audio data
627
- audio_data = getattr(frame, 'data', frame)
628
 
629
- # Check if this is a tuple (sample_rate, audio_array)
630
- if isinstance(audio_data, tuple) and len(audio_data) >= 2:
631
- sample_rate, audio_array = audio_data
632
- else:
633
- # If not a tuple, assume it's raw audio bytes/array
634
- sample_rate = SAMPLE_RATE # Use default sample rate
635
-
636
- # Convert to numpy array
637
- if isinstance(audio_data, bytes):
638
- audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
639
- elif isinstance(audio_data, (list, tuple)):
640
- audio_array = np.array(audio_data, dtype=np.float32)
641
- else:
642
- audio_array = np.array(audio_data, dtype=np.float32)
643
-
644
- # Ensure 1D
645
- if len(audio_array.shape) > 1:
646
- audio_array = audio_array.flatten()
647
-
648
- # Send audio to recorder for live transcription
649
- if self.diarization_system.recorder:
650
- try:
651
- self.diarization_system.recorder.feed_audio(audio_array)
652
- logger.info("Fed audio to recorder")
653
- except Exception as e:
654
- logger.error(f"Error feeding audio to recorder: {e}")
655
-
656
- # Buffer audio chunks
657
- self.audio_buffer.extend(audio_array)
658
-
659
- # Process in chunks
660
- while len(self.audio_buffer) >= self.buffer_size:
661
- chunk = np.array(self.audio_buffer[:self.buffer_size])
662
- self.audio_buffer = self.audio_buffer[self.buffer_size:]
663
-
664
- # Process asynchronously
665
- await self.process_audio_async(chunk)
666
-
667
  except Exception as e:
668
- logger.error(f"Error in FastRTC receive: {e}")
669
 
670
- async def process_audio_async(self, audio_data):
671
- """Process audio data asynchronously"""
672
- try:
673
- loop = asyncio.get_event_loop()
674
- await loop.run_in_executor(
675
- None,
676
- self.diarization_system.process_audio_chunk,
677
- audio_data,
678
- SAMPLE_RATE
679
- )
680
- except Exception as e:
681
- logger.error(f"Error in async audio processing: {e}")
682
-
683
- async def start_up(self):
684
- logger.info("DiarizationHandler started")
685
-
686
- async def shutdown(self):
687
- logger.info("DiarizationHandler shutdown")
688
 
689
 
690
- # Global instances
691
  diarization_system = RealtimeSpeakerDiarization()
692
 
693
- # We'll initialize the stream in initialize_system()
694
- # For now, just create a placeholder
695
- stream = None
696
-
697
  def initialize_system():
698
  """Initialize the diarization system"""
699
- global stream
700
  try:
701
  success = diarization_system.initialize_models()
702
  if success:
703
- # Create a DiarizationHandler linked to our system
704
- handler = DiarizationHandler(diarization_system)
705
- # Update the Stream's handler
706
- stream = Stream(
707
- handler=handler,
708
- modality="audio",
709
- mode="send-receive"
710
- )
711
-
712
- # Mount the stream to the FastAPI app
713
- stream.mount(app)
714
-
715
  return "✅ System initialized successfully!"
716
  else:
717
  return "❌ Failed to initialize system. Check logs for details."
@@ -727,10 +656,6 @@ def start_recording():
727
  except Exception as e:
728
  return f"❌ Failed to start recording: {str(e)}"
729
 
730
- def on_start():
731
- result = start_recording()
732
- return result, gr.update(interactive=False), gr.update(interactive=True)
733
-
734
  def stop_recording():
735
  """Stop recording and transcription"""
736
  try:
@@ -769,232 +694,52 @@ def get_status():
769
  except Exception as e:
770
  return f"Error getting status: {str(e)}"
771
 
772
- # Create Gradio interface
773
- def create_interface():
774
- with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
775
- gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
776
- gr.Markdown("Live transcription with automatic speaker identification using FastRTC audio streaming.")
777
-
778
- with gr.Row():
779
- with gr.Column(scale=2):
780
- # Replace standard Audio with WebRTC component
781
- audio_component = WebRTC(
782
- label="Audio Input",
783
- modality="audio",
784
- mode="send-receive"
785
- )
786
-
787
- # Conversation display
788
- conversation_output = gr.HTML(
789
- value="<div style='padding: 20px; background: #f8f9fa; border-radius: 10px; min-height: 300px;'><i>Click 'Initialize System' to start...</i></div>",
790
- label="Live Conversation"
791
- )
792
-
793
- # Control buttons
794
- with gr.Row():
795
- init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
796
- start_btn = gr.Button("🎙️ Start", variant="primary", size="lg", interactive=False)
797
- stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", interactive=False)
798
- clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)
799
-
800
- # Status display
801
- status_output = gr.Textbox(
802
- label="System Status",
803
- value="Ready to initialize...",
804
- lines=8,
805
- interactive=False
806
- )
807
-
808
- with gr.Column(scale=1):
809
- # Settings
810
- gr.Markdown("## ⚙️ Settings")
811
-
812
- threshold_slider = gr.Slider(
813
- minimum=0.3,
814
- maximum=0.9,
815
- step=0.05,
816
- value=DEFAULT_CHANGE_THRESHOLD,
817
- label="Speaker Change Sensitivity",
818
- info="Lower = more sensitive"
819
- )
820
-
821
- max_speakers_slider = gr.Slider(
822
- minimum=2,
823
- maximum=ABSOLUTE_MAX_SPEAKERS,
824
- step=1,
825
- value=DEFAULT_MAX_SPEAKERS,
826
- label="Maximum Speakers"
827
- )
828
-
829
- update_btn = gr.Button("Update Settings", variant="secondary")
830
-
831
- # Instructions
832
- gr.Markdown("""
833
- ## 📋 Instructions
834
- 1. **Initialize** the system (loads AI models)
835
- 2. **Start** recording
836
- 3. **Speak** - system will transcribe and identify speakers
837
- 4. **Monitor** real-time results below
838
-
839
- ## 🎨 Speaker Colors
840
- - 🔴 Speaker 1 (Red)
841
- - 🟢 Speaker 2 (Teal)
842
- - 🔵 Speaker 3 (Blue)
843
- - 🟡 Speaker 4 (Green)
844
- - 🟣 Speaker 5 (Yellow)
845
- - 🟤 Speaker 6 (Plum)
846
- - 🟫 Speaker 7 (Mint)
847
- - 🟨 Speaker 8 (Gold)
848
- """)
849
-
850
- # Event handlers
851
- def on_initialize():
852
- result = initialize_system()
853
- if "✅" in result:
854
- return result, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
855
- else:
856
- return result, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
857
-
858
- def on_start():
859
- result = start_recording()
860
- return result, gr.update(interactive=False), gr.update(interactive=True)
861
-
862
- def on_stop():
863
- result = stop_recording()
864
- return result, gr.update(interactive=True), gr.update(interactive=False)
865
-
866
- def on_clear():
867
- result = clear_conversation()
868
- return result
869
-
870
- def on_update_settings(threshold, max_speakers):
871
- result = update_settings(threshold, int(max_speakers))
872
- return result
873
-
874
- def refresh_conversation():
875
- return get_conversation()
876
-
877
- def refresh_status():
878
- return get_status()
879
-
880
- # Button click handlers
881
- init_btn.click(
882
- fn=on_initialize,
883
- outputs=[status_output, start_btn, stop_btn, clear_btn]
884
- )
885
-
886
- start_btn.click(
887
- fn=on_start,
888
- outputs=[status_output, start_btn, stop_btn]
889
- )
890
-
891
- stop_btn.click(
892
- fn=on_stop,
893
- outputs=[status_output, start_btn, stop_btn]
894
- )
895
-
896
- clear_btn.click(
897
- fn=on_clear,
898
- outputs=[status_output]
899
- )
900
-
901
- update_btn.click(
902
- fn=on_update_settings,
903
- inputs=[threshold_slider, max_speakers_slider],
904
- outputs=[status_output]
905
- )
906
-
907
- # Auto-refresh conversation display every 1 second
908
- conversation_timer = gr.Timer(1)
909
- conversation_timer.tick(refresh_conversation, outputs=[conversation_output])
910
 
911
- # Auto-refresh status every 2 seconds
912
- status_timer = gr.Timer(2)
913
- status_timer.tick(refresh_status, outputs=[status_output])
 
914
 
915
- # Connect the WebRTC component to our processing function
916
- def process_webrtc_audio(audio_data):
917
- if audio_data is not None and diarization_system.is_running:
918
- try:
919
- # Feed audio to our diarization system
920
- diarization_system.feed_audio(audio_data)
921
- except Exception as e:
922
- logger.error(f"Error processing WebRTC audio: {e}")
923
- return get_conversation()
924
-
925
- audio_component.stream(
926
- fn=process_webrtc_audio,
927
- outputs=[conversation_output]
928
- )
929
-
930
- return interface
931
-
932
-
933
- # FastAPI setup for FastRTC integration
934
- app = FastAPI()
935
-
936
- # We'll initialize the stream in initialize_system()
937
- # For now, just create a placeholder
938
- stream = None
939
-
940
- @app.get("/")
941
- async def root():
942
- return {"message": "Real-time Speaker Diarization API"}
943
-
944
- @app.get("/health")
945
- async def health_check():
946
- return {"status": "healthy", "system_running": diarization_system.is_running}
947
-
948
- @app.post("/initialize")
949
- async def api_initialize():
950
- result = initialize_system()
951
- return {"result": result, "success": "✅" in result}
952
-
953
- @app.post("/start")
954
- async def api_start():
955
- result = start_recording()
956
- return {"result": result, "success": "🎙️" in result}
957
-
958
- @app.post("/stop")
959
- async def api_stop():
960
- result = stop_recording()
961
- return {"result": result, "success": "⏹️" in result}
962
-
963
- @app.post("/clear")
964
- async def api_clear():
965
- result = clear_conversation()
966
- return {"result": result}
967
-
968
- @app.get("/conversation")
969
- async def api_get_conversation():
970
- return {"conversation": get_conversation()}
971
-
972
- @app.get("/status")
973
- async def api_get_status():
974
- return {"status": get_status()}
975
-
976
- @app.post("/settings")
977
- async def api_update_settings(threshold: float, max_speakers: int):
978
- result = update_settings(threshold, max_speakers)
979
- return {"result": result}
980
 
981
  # Main execution
982
  if __name__ == "__main__":
983
  import argparse
984
 
985
  parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System")
986
- parser.add_argument("--mode", choices=["gradio", "api", "both"], default="gradio",
987
- help="Run mode: gradio interface, API only, or both")
988
  parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
989
  parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
990
  parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)")
991
 
992
  args = parser.parse_args()
993
 
994
- if args.mode == "gradio":
995
- # Run Gradio interface only
996
- interface = create_interface()
997
- interface.launch(
 
 
 
998
  server_name=args.host,
999
  server_port=args.port,
1000
  share=True,
@@ -1003,6 +748,8 @@ if __name__ == "__main__":
1003
 
1004
  elif args.mode == "api":
1005
  # Run FastAPI only
 
 
1006
  uvicorn.run(
1007
  app,
1008
  host=args.host,
@@ -1011,20 +758,12 @@ if __name__ == "__main__":
1011
  )
1012
 
1013
  elif args.mode == "both":
1014
- # Run both Gradio and FastAPI
1015
- import multiprocessing
1016
  import threading
1017
 
1018
- def run_gradio():
1019
- interface = create_interface()
1020
- interface.launch(
1021
- server_name=args.host,
1022
- server_port=args.port,
1023
- share=True,
1024
- show_error=True
1025
- )
1026
-
1027
  def run_fastapi():
 
 
1028
  uvicorn.run(
1029
  app,
1030
  host=args.host,
@@ -1036,5 +775,10 @@ if __name__ == "__main__":
1036
  api_thread = threading.Thread(target=run_fastapi, daemon=True)
1037
  api_thread.start()
1038
 
1039
- # Start Gradio in main thread
1040
- run_gradio()
 
 
 
 
 
 
10
  from scipy.spatial.distance import cosine
11
  from RealtimeSTT import AudioToTextRecorder
12
  from fastapi import FastAPI, APIRouter
13
+ from fastrtc import Stream, ReplyOnPause, AudioStreamHandler
14
  import json
15
  import asyncio
16
  import uvicorn
17
  from queue import Queue
18
  import logging
 
19
 
20
  # Set up logging
21
  logging.basicConfig(level=logging.INFO)
 
329
  except Exception as e:
330
  logger.error(f"Model initialization error: {e}")
331
  return False
332
+
333
  def feed_audio(self, audio_data):
334
  """Feed audio data directly to the recorder for live transcription"""
335
  if not self.is_running or not self.recorder:
 
600
  logger.error(f"Error processing audio chunk: {e}")
601
 
602
 
603
+ # Create diarization handler for FastRTC
604
+ class DiarizationAudioHandler(AudioStreamHandler):
605
  def __init__(self, diarization_system):
606
  super().__init__()
607
  self.diarization_system = diarization_system
 
 
608
 
609
+ def receive(self, frame):
610
+ """Process incoming audio frame"""
611
+ if not self.diarization_system.is_running:
612
+ return
613
+
 
 
 
 
 
614
  try:
 
 
 
615
  # Extract audio data
616
+ sample_rate, audio_array = frame
617
 
618
+ # Send audio to diarization system for processing
619
+ self.diarization_system.feed_audio(audio_array)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
620
  except Exception as e:
621
+ logger.error(f"Error processing FastRTC audio: {e}")
622
 
623
+ def copy(self):
624
+ """Return a fresh handler instance"""
625
+ return DiarizationAudioHandler(self.diarization_system)
626
+
627
+ def shutdown(self):
628
+ """Clean up resources"""
629
+ pass
630
+
631
+ def start_up(self):
632
+ """Initialize resources"""
633
+ logger.info("DiarizationAudioHandler started")
 
 
 
 
 
 
 
634
 
635
 
636
+ # Global diarization system instance
637
  diarization_system = RealtimeSpeakerDiarization()
638
 
 
 
 
 
639
  def initialize_system():
640
  """Initialize the diarization system"""
 
641
  try:
642
  success = diarization_system.initialize_models()
643
  if success:
 
 
 
 
 
 
 
 
 
 
 
 
644
  return "✅ System initialized successfully!"
645
  else:
646
  return "❌ Failed to initialize system. Check logs for details."
 
656
  except Exception as e:
657
  return f"❌ Failed to start recording: {str(e)}"
658
 
 
 
 
 
659
  def stop_recording():
660
  """Stop recording and transcription"""
661
  try:
 
694
  except Exception as e:
695
  return f"Error getting status: {str(e)}"
696
 
697
+ # Create handler wrapper function for FastRTC
698
+ def diarization_handler(audio_data):
699
+ """Handler function for FastRTC stream"""
700
+ try:
701
+ # Process the audio data
702
+ diarization_system.process_audio_chunk(audio_data[1], audio_data[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
703
 
704
+ # Just yield the original audio back (echo)
705
+ # This can be changed to just return None since we don't need echo
706
+ # This can be changed to just return None since we don't need echo
707
+ yield audio_data
708
 
709
+ except Exception as e:
710
+ logger.error(f"Error in diarization handler: {e}")
711
+
712
+ # Create FastRTC stream with ReplyOnPause pattern
713
+ stream = Stream(
714
+ handler=ReplyOnPause(diarization_handler),
715
+ modality="audio",
716
+ mode="send-receive",
717
+ ui_args={
718
+ "title": "Real-time Speaker Diarization",
719
+ "description": "Live transcription with automatic speaker identification"
720
+ }
721
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722
 
723
  # Main execution
724
  if __name__ == "__main__":
725
  import argparse
726
 
727
  parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System")
728
+ parser.add_argument("--mode", choices=["ui", "api", "both"], default="ui",
729
+ help="Run mode: FastRTC UI, API only, or both")
730
  parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
731
  parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
732
  parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)")
733
 
734
  args = parser.parse_args()
735
 
736
+ # Initialize the system before running anything
737
+ initialize_system()
738
+ start_recording()
739
+
740
+ if args.mode == "ui":
741
+ # Launch the FastRTC built-in UI
742
+ stream.ui.launch(
743
  server_name=args.host,
744
  server_port=args.port,
745
  share=True,
 
748
 
749
  elif args.mode == "api":
750
  # Run FastAPI only
751
+ app = FastAPI()
752
+ stream.mount(app)
753
  uvicorn.run(
754
  app,
755
  host=args.host,
 
758
  )
759
 
760
  elif args.mode == "both":
761
+ # Run both FastRTC UI and API
 
762
  import threading
763
 
 
 
 
 
 
 
 
 
 
764
  def run_fastapi():
765
+ app = FastAPI()
766
+ stream.mount(app)
767
  uvicorn.run(
768
  app,
769
  host=args.host,
 
775
  api_thread = threading.Thread(target=run_fastapi, daemon=True)
776
  api_thread.start()
777
 
778
+ # Start FastRTC UI in main thread
779
+ stream.ui.launch(
780
+ server_name=args.host,
781
+ server_port=args.port,
782
+ share=True,
783
+ show_error=True
784
+ )