Commit 876be23
Parent(s): 534a53d

Check point 4

app.py CHANGED
@@ -10,12 +10,13 @@ import torchaudio
 from scipy.spatial.distance import cosine
 from RealtimeSTT import AudioToTextRecorder
 from fastapi import FastAPI, APIRouter
-from fastrtc import Stream, AsyncStreamHandler
+from fastrtc import Stream, AsyncStreamHandler
 import json
 import asyncio
 import uvicorn
 from queue import Queue
 import logging
+from gradio_webrtc import WebRTC
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -330,10 +331,47 @@ class RealtimeSpeakerDiarization:
             logger.error(f"Model initialization error: {e}")
             return False
 
+    def feed_audio(self, audio_data):
+        """Feed audio data directly to the recorder for live transcription"""
+        if not self.is_running or not self.recorder:
+            return
+
+        try:
+            # Normalize if needed
+            if isinstance(audio_data, np.ndarray):
+                if audio_data.dtype != np.float32:
+                    audio_data = audio_data.astype(np.float32)
+
+                # Convert to int16 for the recorder
+                audio_int16 = (audio_data * 32767).astype(np.int16)
+                audio_bytes = audio_int16.tobytes()
+
+                # Feed to recorder
+                self.recorder.feed_audio(audio_bytes)
+
+                # Also process for speaker detection
+                self.process_audio_chunk(audio_data)
+
+            elif isinstance(audio_data, bytes):
+                # Feed raw bytes directly
+                self.recorder.feed_audio(audio_data)
+
+                # Convert to float for speaker detection
+                audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
+                audio_float = audio_int16.astype(np.float32) / 32768.0
+                self.process_audio_chunk(audio_float)
+
+            logger.debug("Audio fed to recorder")
+        except Exception as e:
+            logger.error(f"Error feeding audio: {e}")
+
     def live_text_detected(self, text):
         """Callback for real-time transcription updates"""
         with self.transcription_lock:
             self.last_transcription = text.strip()
+
+            # Update the display immediately on new transcription
+            self.update_conversation_display()
 
     def process_final_text(self, text):
         """Process final transcribed text with speaker embedding"""
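Note on the new feed_audio method: it bridges two audio representations, float32 arrays in [-1.0, 1.0] (the WebRTC side) and raw int16 PCM bytes (what the RealtimeSTT recorder consumes). The round trip can be sanity-checked in isolation; a minimal sketch in plain NumPy that mirrors the two conversion paths above (the 440 Hz test tone and 16 kHz rate are illustrative assumptions, not values from this commit):

import numpy as np

# A 100 ms, 440 Hz test tone in float32, standing in for a WebRTC chunk (assumed 16 kHz).
sr = 16000
t = np.arange(int(0.1 * sr)) / sr
chunk = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

# float32 [-1.0, 1.0] -> int16 bytes, as done before self.recorder.feed_audio(...)
audio_int16 = (chunk * 32767).astype(np.int16)
raw = audio_int16.tobytes()

# bytes -> float32, the inverse path used for speaker detection
restored = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0

# Mixing the 32767/32768 conventions costs at most about one quantization step.
assert np.abs(restored - chunk).max() < 2.0 / 32768

The asymmetric constants (multiply by 32767, divide by 32768) are both common conventions; mixing them, as the method does, loses at most roughly one quantization step per sample, as the assertion shows.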
@@ -652,7 +690,9 @@ class DiarizationHandler(AsyncStreamHandler):
 # Global instances
 diarization_system = RealtimeSpeakerDiarization()
 
-# We'll initialize the stream
+# We'll initialize the stream in initialize_system()
+# For now, just create a placeholder
+stream = None
 
 def initialize_system():
     """Initialize the diarization system"""
@@ -666,9 +706,12 @@ def initialize_system():
         stream = Stream(
             handler=handler,
             modality="audio",
-            mode="send-receive"
-            stream_name="audio_stream"  # Match the stream_name in WebRTC component
+            mode="send-receive"
         )
+
+        # Mount the stream to the FastAPI app
+        stream.mount(app)
+
         return "✅ System initialized successfully!"
     else:
         return "❌ Failed to initialize system. Check logs for details."
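One caveat on the flow above: initialize_system() assigns stream = Stream(...) in its local scope, so without a global stream declaration the module-level stream = None placeholder added in the previous hunk is never replaced. A minimal sketch of the intended shape; diarization_system and DiarizationHandler are the instance and class defined elsewhere in app.py, and the initialize() success check is a hypothetical stand-in for the condition elided from the hunk:

from fastapi import FastAPI
from fastrtc import Stream

app = FastAPI()
stream = None  # module-level placeholder, as in the diff


def initialize_system():
    """Initialize the diarization system and mount the WebRTC stream."""
    global stream  # without this, the assignment below only binds a local name
    if diarization_system.initialize():  # hypothetical check; real condition not shown in the hunk
        handler = DiarizationHandler(diarization_system)
        stream = Stream(handler=handler, modality="audio", mode="send-receive")
        stream.mount(app)  # attach the FastRTC WebRTC endpoints to the FastAPI app
        return "✅ System initialized successfully!"
    else:
        return "❌ Failed to initialize system. Check logs for details."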
@@ -737,7 +780,6 @@ def create_interface():
         # Replace standard Audio with WebRTC component
         audio_component = WebRTC(
             label="Audio Input",
-            stream_name="audio_stream",
             modality="audio",
             mode="send-receive"
         )
@@ -869,6 +911,21 @@ def create_interface():
         # Auto-refresh status every 2 seconds
         status_timer = gr.Timer(2)
         status_timer.tick(refresh_status, outputs=[status_output])
+
+        # Connect the WebRTC component to our processing function
+        def process_webrtc_audio(audio_data):
+            if audio_data is not None and diarization_system.is_running:
+                try:
+                    # Feed audio to our diarization system
+                    diarization_system.feed_audio(audio_data)
+                except Exception as e:
+                    logger.error(f"Error processing WebRTC audio: {e}")
+            return get_conversation()
+
+        audio_component.stream(
+            fn=process_webrtc_audio,
+            outputs=[conversation_output]
+        )
 
     return interface
 
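A note on the .stream() wiring above: the call registers outputs but no inputs, so process_webrtc_audio may never be handed audio frames. The usage documented for gradio_webrtc passes the WebRTC component itself as the input. A hedged sketch of that shape (time_limit is an arbitrary assumed value, and documented send-receive examples usually route outputs back to the component itself rather than to a separate textbox):

audio_component.stream(
    fn=process_webrtc_audio,
    inputs=[audio_component],       # the component supplies the incoming audio frames
    outputs=[conversation_output],  # as in the diff; a separate output may need the library's additional-outputs mechanism
    time_limit=60,                  # assumed per-session cap in seconds; optional
)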
@@ -876,29 +933,9 @@ def create_interface():
 # FastAPI setup for FastRTC integration
 app = FastAPI()
 
-#
-
-
-        super().__init__()
-
-    async def receive(self, frame):
-        pass
-
-    async def emit(self):
-        return None
-
-    def copy(self):
-        return DefaultHandler()
-
-    async def shutdown(self):
-        pass
-
-    async def start_up(self):
-        pass
-
-# Initialize with placeholder handler
-stream = Stream(handler=DefaultHandler(), modality="audio", mode="send-receive")
-stream.mount(app)
+# We'll initialize the stream in initialize_system()
+# For now, just create a placeholder
+stream = None
 
 @app.get("/")
 async def root():
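The block removed above was already broken: method bodies (super().__init__(), receive, emit, and so on) with no enclosing class statement, so replacing it with the stream = None placeholder is a clean-up. Note, though, that the same three placeholder lines now appear twice in the file, here and after the global instances in the earlier hunk. For reference, a syntactically complete no-op handler built from exactly the methods visible in the removed lines would look like the sketch below; the precise AsyncStreamHandler interface is an assumption:

from fastrtc import AsyncStreamHandler


class DefaultHandler(AsyncStreamHandler):
    """No-op placeholder handler, reconstructed from the removed block."""

    def __init__(self):
        super().__init__()

    async def start_up(self):
        pass  # nothing to initialize

    async def receive(self, frame):
        pass  # drop incoming frames

    async def emit(self):
        return None  # nothing to send back

    def copy(self):
        return DefaultHandler()

    async def shutdown(self):
        pass  # nothing to release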