Spaces:
Sleeping
Sleeping
Commit
·
2809642
1
Parent(s):
f541218
Check point 4
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ import os
|
|
8 |
import urllib.request
|
9 |
import torchaudio
|
10 |
from scipy.spatial.distance import cosine
|
|
|
11 |
from RealtimeSTT import AudioToTextRecorder
|
12 |
from fastapi import FastAPI, APIRouter
|
13 |
from fastrtc import Stream, AsyncStreamHandler
|
@@ -419,7 +420,7 @@ class RealtimeSpeakerDiarization:
|
|
419 |
# Setup recorder configuration
|
420 |
recorder_config = {
|
421 |
'spinner': False,
|
422 |
-
'use_microphone': False, #
|
423 |
'model': FINAL_TRANSCRIPTION_MODEL,
|
424 |
'language': TRANSCRIPTION_LANGUAGE,
|
425 |
'silero_sensitivity': SILERO_SENSITIVITY,
|
@@ -456,8 +457,11 @@ class RealtimeSpeakerDiarization:
|
|
456 |
def run_transcription(self):
|
457 |
"""Run the transcription loop"""
|
458 |
try:
|
|
|
459 |
while self.is_running:
|
460 |
-
|
|
|
|
|
461 |
except Exception as e:
|
462 |
logger.error(f"Transcription error: {e}")
|
463 |
|
@@ -559,14 +563,30 @@ class RealtimeSpeakerDiarization:
|
|
559 |
if embedding is not None:
|
560 |
self.speaker_detector.add_embedding(embedding)
|
561 |
|
562 |
-
# Feed audio to
|
563 |
-
if self.recorder:
|
564 |
-
# Convert float32
|
565 |
-
|
566 |
-
|
|
|
|
|
567 |
|
568 |
except Exception as e:
|
569 |
logger.error(f"Error processing audio chunk: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
570 |
|
571 |
|
572 |
# FastRTC Audio Handler
|
@@ -598,7 +618,9 @@ class DiarizationHandler(AsyncStreamHandler):
|
|
598 |
if isinstance(audio_data, bytes):
|
599 |
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
|
600 |
elif isinstance(audio_data, (list, tuple)):
|
601 |
-
audio_array =
|
|
|
|
|
602 |
else:
|
603 |
audio_array = np.array(audio_data, dtype=np.float32)
|
604 |
|
@@ -636,18 +658,7 @@ class DiarizationHandler(AsyncStreamHandler):
|
|
636 |
|
637 |
# Global instances
|
638 |
diarization_system = RealtimeSpeakerDiarization()
|
639 |
-
|
640 |
-
# FastAPI setup for FastRTC integration
|
641 |
-
app = FastAPI()
|
642 |
-
|
643 |
-
# Initialize an empty handler (will be set properly in initialize_system function)
|
644 |
-
audio_handler = DiarizationHandler(diarization_system)
|
645 |
-
|
646 |
-
# Create FastRTC stream
|
647 |
-
stream = Stream(handler=audio_handler)
|
648 |
-
|
649 |
-
# Include FastRTC router in FastAPI app
|
650 |
-
app.include_router(stream.router, prefix="/stream")
|
651 |
|
652 |
def initialize_system():
|
653 |
"""Initialize the diarization system"""
|
@@ -656,8 +667,6 @@ def initialize_system():
|
|
656 |
success = diarization_system.initialize_models()
|
657 |
if success:
|
658 |
audio_handler = DiarizationHandler(diarization_system)
|
659 |
-
# Update the stream's handler
|
660 |
-
stream.handler = audio_handler
|
661 |
return "✅ System initialized successfully!"
|
662 |
else:
|
663 |
return "❌ Failed to initialize system. Check logs for details."
|
@@ -665,13 +674,6 @@ def initialize_system():
|
|
665 |
logger.error(f"Initialization error: {e}")
|
666 |
return f"❌ Initialization error: {str(e)}"
|
667 |
|
668 |
-
# Add startup event to initialize the system
|
669 |
-
@app.on_event("startup")
|
670 |
-
async def startup_event():
|
671 |
-
logger.info("Initializing diarization system on startup...")
|
672 |
-
result = initialize_system()
|
673 |
-
logger.info(f"Initialization result: {result}")
|
674 |
-
|
675 |
def start_recording():
|
676 |
"""Start recording and transcription"""
|
677 |
try:
|
@@ -857,6 +859,9 @@ def create_interface():
|
|
857 |
return interface
|
858 |
|
859 |
|
|
|
|
|
|
|
860 |
@app.get("/")
|
861 |
async def root():
|
862 |
return {"message": "Real-time Speaker Diarization API"}
|
@@ -898,6 +903,12 @@ async def api_update_settings(threshold: float, max_speakers: int):
|
|
898 |
result = update_settings(threshold, max_speakers)
|
899 |
return {"result": result}
|
900 |
|
|
|
|
|
|
|
|
|
|
|
|
|
901 |
# Main execution
|
902 |
if __name__ == "__main__":
|
903 |
import argparse
|
|
|
8 |
import urllib.request
|
9 |
import torchaudio
|
10 |
from scipy.spatial.distance import cosine
|
11 |
+
from scipy.signal import resample
|
12 |
from RealtimeSTT import AudioToTextRecorder
|
13 |
from fastapi import FastAPI, APIRouter
|
14 |
from fastrtc import Stream, AsyncStreamHandler
|
|
|
420 |
# Setup recorder configuration
|
421 |
recorder_config = {
|
422 |
'spinner': False,
|
423 |
+
'use_microphone': False, # Using FastRTC for audio input
|
424 |
'model': FINAL_TRANSCRIPTION_MODEL,
|
425 |
'language': TRANSCRIPTION_LANGUAGE,
|
426 |
'silero_sensitivity': SILERO_SENSITIVITY,
|
|
|
457 |
def run_transcription(self):
|
458 |
"""Run the transcription loop"""
|
459 |
try:
|
460 |
+
logger.info("Starting transcription thread")
|
461 |
while self.is_running:
|
462 |
+
# Just check for final text from recorder, audio is fed externally via FastRTC
|
463 |
+
text = self.recorder.text(self.process_final_text)
|
464 |
+
time.sleep(0.01) # Small sleep to prevent CPU hogging
|
465 |
except Exception as e:
|
466 |
logger.error(f"Transcription error: {e}")
|
467 |
|
|
|
563 |
if embedding is not None:
|
564 |
self.speaker_detector.add_embedding(embedding)
|
565 |
|
566 |
+
# Feed audio to RealtimeSTT recorder
|
567 |
+
if self.recorder and self.is_running:
|
568 |
+
# Convert float32 [-1.0, 1.0] to int16 for RealtimeSTT
|
569 |
+
int16_data = (audio_data * 32768.0).astype(np.int16).tobytes()
|
570 |
+
if sample_rate != 16000:
|
571 |
+
int16_data = self.resample_audio(int16_data, sample_rate, 16000)
|
572 |
+
self.recorder.feed_audio(int16_data)
|
573 |
|
574 |
except Exception as e:
|
575 |
logger.error(f"Error processing audio chunk: {e}")
|
576 |
+
|
577 |
+
def resample_audio(self, audio_bytes, from_rate, to_rate):
|
578 |
+
"""Resample audio to target sample rate"""
|
579 |
+
try:
|
580 |
+
audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
|
581 |
+
num_samples = len(audio_np)
|
582 |
+
num_target_samples = int(num_samples * to_rate / from_rate)
|
583 |
+
|
584 |
+
resampled = resample(audio_np, num_target_samples)
|
585 |
+
|
586 |
+
return resampled.astype(np.int16).tobytes()
|
587 |
+
except Exception as e:
|
588 |
+
logger.error(f"Error resampling audio: {e}")
|
589 |
+
return audio_bytes
|
590 |
|
591 |
|
592 |
# FastRTC Audio Handler
|
|
|
618 |
if isinstance(audio_data, bytes):
|
619 |
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
|
620 |
elif isinstance(audio_data, (list, tuple)):
|
621 |
+
sample_rate, audio_array = audio_data
|
622 |
+
if isinstance(audio_array, (list, tuple)):
|
623 |
+
audio_array = np.array(audio_array, dtype=np.float32)
|
624 |
else:
|
625 |
audio_array = np.array(audio_data, dtype=np.float32)
|
626 |
|
|
|
658 |
|
659 |
# Global instances
|
660 |
diarization_system = RealtimeSpeakerDiarization()
|
661 |
+
audio_handler = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
662 |
|
663 |
def initialize_system():
|
664 |
"""Initialize the diarization system"""
|
|
|
667 |
success = diarization_system.initialize_models()
|
668 |
if success:
|
669 |
audio_handler = DiarizationHandler(diarization_system)
|
|
|
|
|
670 |
return "✅ System initialized successfully!"
|
671 |
else:
|
672 |
return "❌ Failed to initialize system. Check logs for details."
|
|
|
674 |
logger.error(f"Initialization error: {e}")
|
675 |
return f"❌ Initialization error: {str(e)}"
|
676 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
677 |
def start_recording():
|
678 |
"""Start recording and transcription"""
|
679 |
try:
|
|
|
859 |
return interface
|
860 |
|
861 |
|
862 |
+
# FastAPI setup for FastRTC integration
|
863 |
+
app = FastAPI()
|
864 |
+
|
865 |
@app.get("/")
|
866 |
async def root():
|
867 |
return {"message": "Real-time Speaker Diarization API"}
|
|
|
903 |
result = update_settings(threshold, max_speakers)
|
904 |
return {"result": result}
|
905 |
|
906 |
+
# FastRTC Stream setup
|
907 |
+
if audio_handler:
|
908 |
+
stream = Stream(handler=audio_handler)
|
909 |
+
app.include_router(stream.router, prefix="/stream")
|
910 |
+
|
911 |
+
|
912 |
# Main execution
|
913 |
if __name__ == "__main__":
|
914 |
import argparse
|