Spaces:
Sleeping
Sleeping
Commit
·
534a53d
1
Parent(s):
2b9c901
Check point 4
Browse files
app.py
CHANGED
@@ -10,7 +10,7 @@ import torchaudio
|
|
10 |
from scipy.spatial.distance import cosine
|
11 |
from RealtimeSTT import AudioToTextRecorder
|
12 |
from fastapi import FastAPI, APIRouter
|
13 |
-
from fastrtc import Stream,
|
14 |
import json
|
15 |
import asyncio
|
16 |
import uvicorn
|
@@ -329,48 +329,11 @@ class RealtimeSpeakerDiarization:
|
|
329 |
except Exception as e:
|
330 |
logger.error(f"Model initialization error: {e}")
|
331 |
return False
|
332 |
-
|
333 |
-
def feed_audio(self, audio_data):
|
334 |
-
"""Feed audio data directly to the recorder for live transcription"""
|
335 |
-
if not self.is_running or not self.recorder:
|
336 |
-
return
|
337 |
-
|
338 |
-
try:
|
339 |
-
# Normalize if needed
|
340 |
-
if isinstance(audio_data, np.ndarray):
|
341 |
-
if audio_data.dtype != np.float32:
|
342 |
-
audio_data = audio_data.astype(np.float32)
|
343 |
-
|
344 |
-
# Convert to int16 for the recorder
|
345 |
-
audio_int16 = (audio_data * 32767).astype(np.int16)
|
346 |
-
audio_bytes = audio_int16.tobytes()
|
347 |
-
|
348 |
-
# Feed to recorder
|
349 |
-
self.recorder.feed_audio(audio_bytes)
|
350 |
-
|
351 |
-
# Also process for speaker detection
|
352 |
-
self.process_audio_chunk(audio_data)
|
353 |
-
|
354 |
-
elif isinstance(audio_data, bytes):
|
355 |
-
# Feed raw bytes directly
|
356 |
-
self.recorder.feed_audio(audio_data)
|
357 |
-
|
358 |
-
# Convert to float for speaker detection
|
359 |
-
audio_int16 = np.frombuffer(audio_data, dtype=np.int16)
|
360 |
-
audio_float = audio_int16.astype(np.float32) / 32768.0
|
361 |
-
self.process_audio_chunk(audio_float)
|
362 |
-
|
363 |
-
logger.debug("Audio fed to recorder")
|
364 |
-
except Exception as e:
|
365 |
-
logger.error(f"Error feeding audio: {e}")
|
366 |
|
367 |
def live_text_detected(self, text):
|
368 |
"""Callback for real-time transcription updates"""
|
369 |
with self.transcription_lock:
|
370 |
self.last_transcription = text.strip()
|
371 |
-
|
372 |
-
# Update the display immediately on new transcription
|
373 |
-
self.update_conversation_display()
|
374 |
|
375 |
def process_final_text(self, text):
|
376 |
"""Process final transcribed text with speaker embedding"""
|
@@ -600,47 +563,112 @@ class RealtimeSpeakerDiarization:
|
|
600 |
logger.error(f"Error processing audio chunk: {e}")
|
601 |
|
602 |
|
603 |
-
#
|
604 |
-
class
|
605 |
def __init__(self, diarization_system):
|
606 |
super().__init__()
|
607 |
self.diarization_system = diarization_system
|
|
|
|
|
608 |
|
609 |
-
def
|
610 |
-
"""
|
611 |
-
|
612 |
-
|
613 |
-
|
|
|
|
|
|
|
|
|
|
|
614 |
try:
|
|
|
|
|
|
|
615 |
# Extract audio data
|
616 |
-
|
617 |
|
618 |
-
#
|
619 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
620 |
except Exception as e:
|
621 |
-
logger.error(f"Error
|
622 |
-
|
623 |
-
def copy(self):
|
624 |
-
"""Return a fresh handler instance"""
|
625 |
-
return DiarizationAudioHandler(self.diarization_system)
|
626 |
|
627 |
-
def
|
628 |
-
"""
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
634 |
|
635 |
|
636 |
-
# Global
|
637 |
diarization_system = RealtimeSpeakerDiarization()
|
638 |
|
|
|
|
|
639 |
def initialize_system():
|
640 |
"""Initialize the diarization system"""
|
|
|
641 |
try:
|
642 |
success = diarization_system.initialize_models()
|
643 |
if success:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
644 |
return "✅ System initialized successfully!"
|
645 |
else:
|
646 |
return "❌ Failed to initialize system. Check logs for details."
|
@@ -656,6 +684,10 @@ def start_recording():
|
|
656 |
except Exception as e:
|
657 |
return f"❌ Failed to start recording: {str(e)}"
|
658 |
|
|
|
|
|
|
|
|
|
659 |
def stop_recording():
|
660 |
"""Stop recording and transcription"""
|
661 |
try:
|
@@ -694,56 +726,238 @@ def get_status():
|
|
694 |
except Exception as e:
|
695 |
return f"Error getting status: {str(e)}"
|
696 |
|
697 |
-
# Create
|
698 |
-
def
|
699 |
-
""
|
700 |
-
|
701 |
-
|
702 |
-
diarization_system.process_audio_chunk(audio_data[1], audio_data[0])
|
703 |
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
708 |
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
724 |
|
725 |
# Main execution
|
726 |
if __name__ == "__main__":
|
727 |
import argparse
|
728 |
-
import os
|
729 |
|
730 |
parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System")
|
731 |
-
parser.add_argument("--mode", choices=["
|
732 |
-
help="Run mode:
|
733 |
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
|
734 |
-
parser.add_argument("--port", type=int, default=
|
735 |
-
help="Port to bind to")
|
736 |
parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)")
|
737 |
|
738 |
args = parser.parse_args()
|
739 |
|
740 |
-
|
741 |
-
|
742 |
-
|
743 |
-
|
744 |
-
if args.mode == "ui":
|
745 |
-
# Launch the FastRTC built-in UI
|
746 |
-
stream.ui.launch(
|
747 |
server_name=args.host,
|
748 |
server_port=args.port,
|
749 |
share=True,
|
@@ -752,8 +966,6 @@ if __name__ == "__main__":
|
|
752 |
|
753 |
elif args.mode == "api":
|
754 |
# Run FastAPI only
|
755 |
-
app = FastAPI()
|
756 |
-
stream.mount(app)
|
757 |
uvicorn.run(
|
758 |
app,
|
759 |
host=args.host,
|
@@ -762,12 +974,20 @@ if __name__ == "__main__":
|
|
762 |
)
|
763 |
|
764 |
elif args.mode == "both":
|
765 |
-
# Run both
|
|
|
766 |
import threading
|
767 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
768 |
def run_fastapi():
|
769 |
-
app = FastAPI()
|
770 |
-
stream.mount(app)
|
771 |
uvicorn.run(
|
772 |
app,
|
773 |
host=args.host,
|
@@ -779,10 +999,5 @@ if __name__ == "__main__":
|
|
779 |
api_thread = threading.Thread(target=run_fastapi, daemon=True)
|
780 |
api_thread.start()
|
781 |
|
782 |
-
# Start
|
783 |
-
|
784 |
-
server_name=args.host,
|
785 |
-
server_port=args.port,
|
786 |
-
share=True,
|
787 |
-
show_error=True
|
788 |
-
)
|
|
|
10 |
from scipy.spatial.distance import cosine
|
11 |
from RealtimeSTT import AudioToTextRecorder
|
12 |
from fastapi import FastAPI, APIRouter
|
13 |
+
from fastrtc import Stream, AsyncStreamHandler, WebRTC
|
14 |
import json
|
15 |
import asyncio
|
16 |
import uvicorn
|
|
|
329 |
except Exception as e:
|
330 |
logger.error(f"Model initialization error: {e}")
|
331 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
|
333 |
def live_text_detected(self, text):
|
334 |
"""Callback for real-time transcription updates"""
|
335 |
with self.transcription_lock:
|
336 |
self.last_transcription = text.strip()
|
|
|
|
|
|
|
337 |
|
338 |
def process_final_text(self, text):
|
339 |
"""Process final transcribed text with speaker embedding"""
|
|
|
563 |
logger.error(f"Error processing audio chunk: {e}")
|
564 |
|
565 |
|
566 |
+
# FastRTC Audio Handler
|
567 |
+
class DiarizationHandler(AsyncStreamHandler):
|
568 |
def __init__(self, diarization_system):
|
569 |
super().__init__()
|
570 |
self.diarization_system = diarization_system
|
571 |
+
self.audio_buffer = []
|
572 |
+
self.buffer_size = BUFFER_SIZE
|
573 |
|
574 |
+
def copy(self):
|
575 |
+
"""Return a fresh handler for each new stream connection"""
|
576 |
+
return DiarizationHandler(self.diarization_system)
|
577 |
+
|
578 |
+
async def emit(self):
|
579 |
+
"""Not used - we only receive audio"""
|
580 |
+
return None
|
581 |
+
|
582 |
+
async def receive(self, frame):
|
583 |
+
"""Receive audio data from FastRTC"""
|
584 |
try:
|
585 |
+
if not self.diarization_system.is_running:
|
586 |
+
return
|
587 |
+
|
588 |
# Extract audio data
|
589 |
+
audio_data = getattr(frame, 'data', frame)
|
590 |
|
591 |
+
# Check if this is a tuple (sample_rate, audio_array)
|
592 |
+
if isinstance(audio_data, tuple) and len(audio_data) >= 2:
|
593 |
+
sample_rate, audio_array = audio_data
|
594 |
+
else:
|
595 |
+
# If not a tuple, assume it's raw audio bytes/array
|
596 |
+
sample_rate = SAMPLE_RATE # Use default sample rate
|
597 |
+
|
598 |
+
# Convert to numpy array
|
599 |
+
if isinstance(audio_data, bytes):
|
600 |
+
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
|
601 |
+
elif isinstance(audio_data, (list, tuple)):
|
602 |
+
audio_array = np.array(audio_data, dtype=np.float32)
|
603 |
+
else:
|
604 |
+
audio_array = np.array(audio_data, dtype=np.float32)
|
605 |
+
|
606 |
+
# Ensure 1D
|
607 |
+
if len(audio_array.shape) > 1:
|
608 |
+
audio_array = audio_array.flatten()
|
609 |
+
|
610 |
+
# Send audio to recorder for live transcription
|
611 |
+
if self.diarization_system.recorder:
|
612 |
+
try:
|
613 |
+
self.diarization_system.recorder.feed_audio(audio_array)
|
614 |
+
logger.info("Fed audio to recorder")
|
615 |
+
except Exception as e:
|
616 |
+
logger.error(f"Error feeding audio to recorder: {e}")
|
617 |
+
|
618 |
+
# Buffer audio chunks
|
619 |
+
self.audio_buffer.extend(audio_array)
|
620 |
+
|
621 |
+
# Process in chunks
|
622 |
+
while len(self.audio_buffer) >= self.buffer_size:
|
623 |
+
chunk = np.array(self.audio_buffer[:self.buffer_size])
|
624 |
+
self.audio_buffer = self.audio_buffer[self.buffer_size:]
|
625 |
+
|
626 |
+
# Process asynchronously
|
627 |
+
await self.process_audio_async(chunk)
|
628 |
+
|
629 |
except Exception as e:
|
630 |
+
logger.error(f"Error in FastRTC receive: {e}")
|
|
|
|
|
|
|
|
|
631 |
|
632 |
+
async def process_audio_async(self, audio_data):
|
633 |
+
"""Process audio data asynchronously"""
|
634 |
+
try:
|
635 |
+
loop = asyncio.get_event_loop()
|
636 |
+
await loop.run_in_executor(
|
637 |
+
None,
|
638 |
+
self.diarization_system.process_audio_chunk,
|
639 |
+
audio_data,
|
640 |
+
SAMPLE_RATE
|
641 |
+
)
|
642 |
+
except Exception as e:
|
643 |
+
logger.error(f"Error in async audio processing: {e}")
|
644 |
+
|
645 |
+
async def start_up(self):
|
646 |
+
logger.info("DiarizationHandler started")
|
647 |
+
|
648 |
+
async def shutdown(self):
|
649 |
+
logger.info("DiarizationHandler shutdown")
|
650 |
|
651 |
|
652 |
+
# Global instances
|
653 |
diarization_system = RealtimeSpeakerDiarization()
|
654 |
|
655 |
+
# We'll initialize the stream properly in initialize_system()
|
656 |
+
|
657 |
def initialize_system():
|
658 |
"""Initialize the diarization system"""
|
659 |
+
global stream
|
660 |
try:
|
661 |
success = diarization_system.initialize_models()
|
662 |
if success:
|
663 |
+
# Create a DiarizationHandler linked to our system
|
664 |
+
handler = DiarizationHandler(diarization_system)
|
665 |
+
# Update the Stream's handler
|
666 |
+
stream = Stream(
|
667 |
+
handler=handler,
|
668 |
+
modality="audio",
|
669 |
+
mode="send-receive",
|
670 |
+
stream_name="audio_stream" # Match the stream_name in WebRTC component
|
671 |
+
)
|
672 |
return "✅ System initialized successfully!"
|
673 |
else:
|
674 |
return "❌ Failed to initialize system. Check logs for details."
|
|
|
684 |
except Exception as e:
|
685 |
return f"❌ Failed to start recording: {str(e)}"
|
686 |
|
687 |
+
def on_start():
|
688 |
+
result = start_recording()
|
689 |
+
return result, gr.update(interactive=False), gr.update(interactive=True)
|
690 |
+
|
691 |
def stop_recording():
|
692 |
"""Stop recording and transcription"""
|
693 |
try:
|
|
|
726 |
except Exception as e:
|
727 |
return f"Error getting status: {str(e)}"
|
728 |
|
729 |
+
# Create Gradio interface
|
730 |
+
def create_interface():
|
731 |
+
with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
|
732 |
+
gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
|
733 |
+
gr.Markdown("Live transcription with automatic speaker identification using FastRTC audio streaming.")
|
|
|
734 |
|
735 |
+
with gr.Row():
|
736 |
+
with gr.Column(scale=2):
|
737 |
+
# Replace standard Audio with WebRTC component
|
738 |
+
audio_component = WebRTC(
|
739 |
+
label="Audio Input",
|
740 |
+
stream_name="audio_stream",
|
741 |
+
modality="audio",
|
742 |
+
mode="send-receive"
|
743 |
+
)
|
744 |
+
|
745 |
+
# Conversation display
|
746 |
+
conversation_output = gr.HTML(
|
747 |
+
value="<div style='padding: 20px; background: #f8f9fa; border-radius: 10px; min-height: 300px;'><i>Click 'Initialize System' to start...</i></div>",
|
748 |
+
label="Live Conversation"
|
749 |
+
)
|
750 |
+
|
751 |
+
# Control buttons
|
752 |
+
with gr.Row():
|
753 |
+
init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
|
754 |
+
start_btn = gr.Button("🎙️ Start", variant="primary", size="lg", interactive=False)
|
755 |
+
stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", interactive=False)
|
756 |
+
clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)
|
757 |
+
|
758 |
+
# Status display
|
759 |
+
status_output = gr.Textbox(
|
760 |
+
label="System Status",
|
761 |
+
value="Ready to initialize...",
|
762 |
+
lines=8,
|
763 |
+
interactive=False
|
764 |
+
)
|
765 |
+
|
766 |
+
with gr.Column(scale=1):
|
767 |
+
# Settings
|
768 |
+
gr.Markdown("## ⚙️ Settings")
|
769 |
+
|
770 |
+
threshold_slider = gr.Slider(
|
771 |
+
minimum=0.3,
|
772 |
+
maximum=0.9,
|
773 |
+
step=0.05,
|
774 |
+
value=DEFAULT_CHANGE_THRESHOLD,
|
775 |
+
label="Speaker Change Sensitivity",
|
776 |
+
info="Lower = more sensitive"
|
777 |
+
)
|
778 |
+
|
779 |
+
max_speakers_slider = gr.Slider(
|
780 |
+
minimum=2,
|
781 |
+
maximum=ABSOLUTE_MAX_SPEAKERS,
|
782 |
+
step=1,
|
783 |
+
value=DEFAULT_MAX_SPEAKERS,
|
784 |
+
label="Maximum Speakers"
|
785 |
+
)
|
786 |
+
|
787 |
+
update_btn = gr.Button("Update Settings", variant="secondary")
|
788 |
+
|
789 |
+
# Instructions
|
790 |
+
gr.Markdown("""
|
791 |
+
## 📋 Instructions
|
792 |
+
1. **Initialize** the system (loads AI models)
|
793 |
+
2. **Start** recording
|
794 |
+
3. **Speak** - system will transcribe and identify speakers
|
795 |
+
4. **Monitor** real-time results below
|
796 |
+
|
797 |
+
## 🎨 Speaker Colors
|
798 |
+
- 🔴 Speaker 1 (Red)
|
799 |
+
- 🟢 Speaker 2 (Teal)
|
800 |
+
- 🔵 Speaker 3 (Blue)
|
801 |
+
- 🟡 Speaker 4 (Green)
|
802 |
+
- 🟣 Speaker 5 (Yellow)
|
803 |
+
- 🟤 Speaker 6 (Plum)
|
804 |
+
- 🟫 Speaker 7 (Mint)
|
805 |
+
- 🟨 Speaker 8 (Gold)
|
806 |
+
""")
|
807 |
|
808 |
+
# Event handlers
|
809 |
+
def on_initialize():
|
810 |
+
result = initialize_system()
|
811 |
+
if "✅" in result:
|
812 |
+
return result, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
|
813 |
+
else:
|
814 |
+
return result, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
|
815 |
+
|
816 |
+
def on_start():
|
817 |
+
result = start_recording()
|
818 |
+
return result, gr.update(interactive=False), gr.update(interactive=True)
|
819 |
+
|
820 |
+
def on_stop():
|
821 |
+
result = stop_recording()
|
822 |
+
return result, gr.update(interactive=True), gr.update(interactive=False)
|
823 |
+
|
824 |
+
def on_clear():
|
825 |
+
result = clear_conversation()
|
826 |
+
return result
|
827 |
+
|
828 |
+
def on_update_settings(threshold, max_speakers):
|
829 |
+
result = update_settings(threshold, int(max_speakers))
|
830 |
+
return result
|
831 |
+
|
832 |
+
def refresh_conversation():
|
833 |
+
return get_conversation()
|
834 |
+
|
835 |
+
def refresh_status():
|
836 |
+
return get_status()
|
837 |
+
|
838 |
+
# Button click handlers
|
839 |
+
init_btn.click(
|
840 |
+
fn=on_initialize,
|
841 |
+
outputs=[status_output, start_btn, stop_btn, clear_btn]
|
842 |
+
)
|
843 |
+
|
844 |
+
start_btn.click(
|
845 |
+
fn=on_start,
|
846 |
+
outputs=[status_output, start_btn, stop_btn]
|
847 |
+
)
|
848 |
+
|
849 |
+
stop_btn.click(
|
850 |
+
fn=on_stop,
|
851 |
+
outputs=[status_output, start_btn, stop_btn]
|
852 |
+
)
|
853 |
+
|
854 |
+
clear_btn.click(
|
855 |
+
fn=on_clear,
|
856 |
+
outputs=[status_output]
|
857 |
+
)
|
858 |
+
|
859 |
+
update_btn.click(
|
860 |
+
fn=on_update_settings,
|
861 |
+
inputs=[threshold_slider, max_speakers_slider],
|
862 |
+
outputs=[status_output]
|
863 |
+
)
|
864 |
+
|
865 |
+
# Auto-refresh conversation display every 1 second
|
866 |
+
conversation_timer = gr.Timer(1)
|
867 |
+
conversation_timer.tick(refresh_conversation, outputs=[conversation_output])
|
868 |
+
|
869 |
+
# Auto-refresh status every 2 seconds
|
870 |
+
status_timer = gr.Timer(2)
|
871 |
+
status_timer.tick(refresh_status, outputs=[status_output])
|
872 |
+
|
873 |
+
return interface
|
874 |
+
|
875 |
+
|
876 |
+
# FastAPI setup for FastRTC integration
|
877 |
+
app = FastAPI()
|
878 |
+
|
879 |
+
# Create a placeholder handler - will be properly initialized later
|
880 |
+
class DefaultHandler(AsyncStreamHandler):
|
881 |
+
def __init__(self):
|
882 |
+
super().__init__()
|
883 |
+
|
884 |
+
async def receive(self, frame):
|
885 |
+
pass
|
886 |
+
|
887 |
+
async def emit(self):
|
888 |
+
return None
|
889 |
+
|
890 |
+
def copy(self):
|
891 |
+
return DefaultHandler()
|
892 |
+
|
893 |
+
async def shutdown(self):
|
894 |
+
pass
|
895 |
+
|
896 |
+
async def start_up(self):
|
897 |
+
pass
|
898 |
+
|
899 |
+
# Initialize with placeholder handler
|
900 |
+
stream = Stream(handler=DefaultHandler(), modality="audio", mode="send-receive")
|
901 |
+
stream.mount(app)
|
902 |
+
|
903 |
+
@app.get("/")
|
904 |
+
async def root():
|
905 |
+
return {"message": "Real-time Speaker Diarization API"}
|
906 |
+
|
907 |
+
@app.get("/health")
|
908 |
+
async def health_check():
|
909 |
+
return {"status": "healthy", "system_running": diarization_system.is_running}
|
910 |
+
|
911 |
+
@app.post("/initialize")
|
912 |
+
async def api_initialize():
|
913 |
+
result = initialize_system()
|
914 |
+
return {"result": result, "success": "✅" in result}
|
915 |
+
|
916 |
+
@app.post("/start")
|
917 |
+
async def api_start():
|
918 |
+
result = start_recording()
|
919 |
+
return {"result": result, "success": "🎙️" in result}
|
920 |
+
|
921 |
+
@app.post("/stop")
|
922 |
+
async def api_stop():
|
923 |
+
result = stop_recording()
|
924 |
+
return {"result": result, "success": "⏹️" in result}
|
925 |
+
|
926 |
+
@app.post("/clear")
|
927 |
+
async def api_clear():
|
928 |
+
result = clear_conversation()
|
929 |
+
return {"result": result}
|
930 |
+
|
931 |
+
@app.get("/conversation")
|
932 |
+
async def api_get_conversation():
|
933 |
+
return {"conversation": get_conversation()}
|
934 |
+
|
935 |
+
@app.get("/status")
|
936 |
+
async def api_get_status():
|
937 |
+
return {"status": get_status()}
|
938 |
+
|
939 |
+
@app.post("/settings")
|
940 |
+
async def api_update_settings(threshold: float, max_speakers: int):
|
941 |
+
result = update_settings(threshold, max_speakers)
|
942 |
+
return {"result": result}
|
943 |
|
944 |
# Main execution
|
945 |
if __name__ == "__main__":
|
946 |
import argparse
|
|
|
947 |
|
948 |
parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System")
|
949 |
+
parser.add_argument("--mode", choices=["gradio", "api", "both"], default="gradio",
|
950 |
+
help="Run mode: gradio interface, API only, or both")
|
951 |
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
|
952 |
+
parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
|
|
|
953 |
parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)")
|
954 |
|
955 |
args = parser.parse_args()
|
956 |
|
957 |
+
if args.mode == "gradio":
|
958 |
+
# Run Gradio interface only
|
959 |
+
interface = create_interface()
|
960 |
+
interface.launch(
|
|
|
|
|
|
|
961 |
server_name=args.host,
|
962 |
server_port=args.port,
|
963 |
share=True,
|
|
|
966 |
|
967 |
elif args.mode == "api":
|
968 |
# Run FastAPI only
|
|
|
|
|
969 |
uvicorn.run(
|
970 |
app,
|
971 |
host=args.host,
|
|
|
974 |
)
|
975 |
|
976 |
elif args.mode == "both":
|
977 |
+
# Run both Gradio and FastAPI
|
978 |
+
import multiprocessing
|
979 |
import threading
|
980 |
|
981 |
+
def run_gradio():
|
982 |
+
interface = create_interface()
|
983 |
+
interface.launch(
|
984 |
+
server_name=args.host,
|
985 |
+
server_port=args.port,
|
986 |
+
share=True,
|
987 |
+
show_error=True
|
988 |
+
)
|
989 |
+
|
990 |
def run_fastapi():
|
|
|
|
|
991 |
uvicorn.run(
|
992 |
app,
|
993 |
host=args.host,
|
|
|
999 |
api_thread = threading.Thread(target=run_fastapi, daemon=True)
|
1000 |
api_thread.start()
|
1001 |
|
1002 |
+
# Start Gradio in main thread
|
1003 |
+
run_gradio()
|
|
|
|
|
|
|
|
|
|