Spaces: Sleeping
Commit af84a93 · Parent: 98185ad
Check point 4
app.py
CHANGED
@@ -16,7 +16,7 @@ import asyncio
 import uvicorn
 from queue import Queue
 import logging
-from
+from fastrtc import WebRTC
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -331,47 +331,10 @@ class RealtimeSpeakerDiarization:
             logger.error(f"Model initialization error: {e}")
             return False
 
-    def feed_audio(self, audio_data):
-        """Feed audio data directly to the recorder for live transcription"""
-        if not self.is_running or not self.recorder:
-            return
-
-        try:
-            # Normalize if needed
-            if isinstance(audio_data, np.ndarray):
-                if audio_data.dtype != np.float32:
-                    audio_data = audio_data.astype(np.float32)
-
-                # Convert to int16 for the recorder
-                audio_int16 = (audio_data * 32767).astype(np.int16)
-                audio_bytes = audio_int16.tobytes()
-
-                # Feed to recorder
-                self.recorder.feed_audio(audio_bytes)
-
-                # Also process for speaker detection
-                self.process_audio_chunk(audio_data)
-
-            elif isinstance(audio_data, bytes):
-                # Feed raw bytes directly
-                self.recorder.feed_audio(audio_data)
-
-                # Convert to float for speaker detection
-                audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
-                audio_float = audio_int16.astype(np.float32) / 32768.0
-                self.process_audio_chunk(audio_float)
-
-            logger.debug("Audio fed to recorder")
-        except Exception as e:
-            logger.error(f"Error feeding audio: {e}")
-
     def live_text_detected(self, text):
         """Callback for real-time transcription updates"""
         with self.transcription_lock:
             self.last_transcription = text.strip()
-
-            # Update the display immediately on new transcription
-            self.update_conversation_display()
 
     def process_final_text(self, text):
         """Process final transcribed text with speaker embedding"""
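Note: the removed `feed_audio` scaled float32 samples by 32767 into int16 bytes for the recorder, and the handler below divides by 32768 on the way back. As a standalone reference (a minimal numpy sketch, not part of the commit; variable names are illustrative):

import numpy as np

# float32 samples in [-1.0, 1.0] -> 16-bit PCM bytes (what the recorder consumed)
chunk = np.array([0.0, 0.5, -0.5], dtype=np.float32)
pcm_bytes = (chunk * 32767).astype(np.int16).tobytes()

# 16-bit PCM bytes -> float32 samples (what speaker detection consumed)
restored = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0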
@@ -626,33 +589,18 @@ class DiarizationHandler(AsyncStreamHandler):
             # Extract audio data
             audio_data = getattr(frame, 'data', frame)
 
-            #
-            if isinstance(audio_data,
-
+            # Convert to numpy array
+            if isinstance(audio_data, bytes):
+                audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
+            elif isinstance(audio_data, (list, tuple)):
+                audio_array = np.array(audio_data, dtype=np.float32)
             else:
-
-                sample_rate = SAMPLE_RATE  # Use default sample rate
-
-            # Convert to numpy array
-            if isinstance(audio_data, bytes):
-                audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
-            elif isinstance(audio_data, (list, tuple)):
-                audio_array = np.array(audio_data, dtype=np.float32)
-            else:
-                audio_array = np.array(audio_data, dtype=np.float32)
+                audio_array = np.array(audio_data, dtype=np.float32)
 
             # Ensure 1D
             if len(audio_array.shape) > 1:
                 audio_array = audio_array.flatten()
 
-            # Send audio to recorder for live transcription
-            if self.diarization_system.recorder:
-                try:
-                    self.diarization_system.recorder.feed_audio(audio_array)
-                    logger.info("Fed audio to recorder")
-                except Exception as e:
-                    logger.error(f"Error feeding audio to recorder: {e}")
-
             # Buffer audio chunks
             self.audio_buffer.extend(audio_array)
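One behavioral note on the flatten above: for multi-channel input, `flatten()` interleaves channels rather than downmixing them. A quick numpy illustration (standalone sketch; a true mono downmix would use `mean(axis=1)`):

import numpy as np

stereo = np.zeros((480, 2), dtype=np.float32)  # e.g. 480 stereo frames
mono_interleaved = stereo.flatten()            # shape (960,): L, R, L, R, ...
mono_downmix = stereo.mean(axis=1)             # shape (480,): averaged channels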
@@ -679,20 +627,11 @@ class DiarizationHandler(AsyncStreamHandler):
             )
         except Exception as e:
             logger.error(f"Error in async audio processing: {e}")
-
-    async def start_up(self):
-        logger.info("DiarizationHandler started")
-
-    async def shutdown(self):
-        logger.info("DiarizationHandler shutdown")
 
 
 # Global instances
 diarization_system = RealtimeSpeakerDiarization()
-
-# We'll initialize the stream in initialize_system()
-# For now, just create a placeholder
-stream = None
+audio_handler = None
 
 def initialize_system():
     """Initialize the diarization system"""
@@ -700,18 +639,8 @@ def initialize_system():
     try:
         success = diarization_system.initialize_models()
         if success:
-            #
-            handler = DiarizationHandler(diarization_system)
-            # Update the Stream's handler
-            stream = Stream(
-                handler=handler,
-                modality="audio",
-                mode="send-receive"
-            )
-
-            # Mount the stream to the FastAPI app
-            stream.mount(app)
-
+            # Update the Stream's handler to use our DiarizationHandler
+            stream.handler = DiarizationHandler(diarization_system)
             return "✅ System initialized successfully!"
         else:
             return "❌ Failed to initialize system. Check logs for details."
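Why this hunk works where the removed code did not: `stream = Stream(...)` inside the function bound a local name (there is no `global stream`), so the module-level placeholder stayed `None` and each call mounted a fresh Stream. The new code instead mutates the one Stream mounted at import time (see the last hunk). For contrast, a minimal sketch of the `global` rebinding the old structure would have needed (illustrative only, not the commit's approach):

def initialize_system():
    global stream  # without this, the assignment below only shadows the module-level name
    stream = Stream(handler=DiarizationHandler(diarization_system),
                    modality="audio", mode="send-receive")
    stream.mount(app)  # the commit avoids repeat mounting by doing this once at import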
@@ -777,11 +706,11 @@ def create_interface():
 
         with gr.Row():
             with gr.Column(scale=2):
-                # Replace standard
-                audio_component =
+                # Replace WebRTC with standard Gradio audio component
+                audio_component = gr.Audio(
                     label="Audio Input",
-
-
+                    sources=["microphone"],
+                    streaming=True
                 )
 
                 # Conversation display
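For context on the new component: with `streaming=True`, Gradio delivers microphone chunks to the bound function as `(sample_rate, numpy_array)` tuples, which is the shape `process_audio_input` in the next hunk unpacks. A self-contained sketch of that wiring (the names `on_chunk`, `mic`, and `status` are illustrative, not from the commit):

import gradio as gr
import numpy as np

def on_chunk(audio):
    # Streaming audio arrives as (sample_rate, np.ndarray) tuples.
    sample_rate, samples = audio
    return f"got {np.atleast_1d(samples).shape[0]} samples at {sample_rate} Hz"

with gr.Blocks() as demo:
    mic = gr.Audio(sources=["microphone"], streaming=True, label="Mic")
    status = gr.Textbox(label="Status")
    mic.stream(fn=on_chunk, inputs=[mic], outputs=[status])

demo.launch()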
@@ -912,15 +841,50 @@ def create_interface():
         status_timer = gr.Timer(2)
         status_timer.tick(refresh_status, outputs=[status_output])
 
+        # Process audio from Gradio component
+        def process_audio_input(audio_data):
+            if audio_data is not None and diarization_system.is_running:
+                # Extract audio data
+                if isinstance(audio_data, tuple) and len(audio_data) >= 2:
+                    sample_rate, audio_array = audio_data[0], audio_data[1]
+                    diarization_system.process_audio_chunk(audio_array, sample_rate)
+            return get_conversation()
+
+        # Connect audio component to processing function
+        audio_component.stream(
+            fn=process_audio_input,
+            outputs=[conversation_output]
+        )
+
     return interface
 
 
 # FastAPI setup for FastRTC integration
 app = FastAPI()
 
-#
-
-
+# Create a placeholder handler - will be properly initialized later
+class DefaultHandler(AsyncStreamHandler):
+    def __init__(self):
+        super().__init__()
+
+    async def receive(self, frame):
+        pass
+
+    async def emit(self):
+        return None
+
+    def copy(self):
+        return DefaultHandler()
+
+    async def shutdown(self):
+        pass
+
+    async def start_up(self):
+        pass
+
+# Initialize with placeholder handler
+stream = Stream(handler=DefaultHandler(), modality="audio", mode="send-receive")
+stream.mount(app)
 
 @app.get("/")
 async def root():
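Since `stream.mount(app)` now runs at import time, the FastAPI app is complete as soon as the module loads, and the real handler is swapped in later by `initialize_system()`. A minimal way to serve it, assuming the file is `app.py` (host and port are illustrative; 7860 is just the usual Spaces port):

import uvicorn

if __name__ == "__main__":
    # Serve the FastAPI app that the fastrtc Stream is mounted on.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)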