Commit f4275bf · Parent: bd39e10
Transcript
Files changed:
- inference.py +19 -0
- shared.py +57 -8
- test_websocket.py +1 -0
- ui.py +10 -2
inference.py
CHANGED

@@ -10,6 +10,16 @@ import time
 from typing import Set, Dict, Any
 import traceback
 
+# Check for RealtimeSTT and install if needed
+try:
+    from RealtimeSTT import AudioToTextRecorder
+except ImportError:
+    import subprocess
+    import sys
+    print("Installing RealtimeSTT dependency...")
+    subprocess.check_call([sys.executable, "-m", "pip", "install", "RealtimeSTT"])
+    from RealtimeSTT import AudioToTextRecorder
+
 # Set up logging
 logging.basicConfig(
     level=logging.INFO,

@@ -185,6 +195,15 @@ async def shutdown_event():
     try:
         diart.stop_recording()
         logger.info("Recording stopped")
+
+        # Shutdown RealtimeSTT properly if available
+        if hasattr(diart, 'recorder') and diart.recorder:
+            try:
+                diart.recorder.shutdown()
+                logger.info("Transcription model shut down")
+            except Exception as e:
+                logger.error(f"Error shutting down transcription model: {e}")
+
     except Exception as e:
         logger.error(f"Error stopping recording: {e}")
 
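Note on the shutdown hunk: RealtimeSTT keeps its transcription workers alive in background threads, so stopping the recording loop alone can leave the server hanging at exit; `recorder.shutdown()` is the library's cleanup call. A minimal sketch of the teardown ordering, assuming the module-level `diart` diarization object this Space uses:

    # Hedged sketch, not the full handler: stop feeding audio first,
    # then release RealtimeSTT's transcription workers.
    async def shutdown_event():
        diart.stop_recording()
        recorder = getattr(diart, "recorder", None)
        if recorder:
            recorder.shutdown()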
shared.py
CHANGED

@@ -8,6 +8,9 @@ import torchaudio
 from scipy.spatial.distance import cosine
 from scipy.signal import resample
 import logging
+import urllib.request
+# Import RealtimeSTT for transcription
+from RealtimeSTT import AudioToTextRecorder
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)

@@ -64,12 +67,26 @@ class SpeechBrainEncoder:
         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
         os.makedirs(self.cache_dir, exist_ok=True)
 
+    def _download_model(self):
+        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
+        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
+        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
+
+        if not os.path.exists(model_path):
+            print(f"Downloading ECAPA-TDNN model to {model_path}...")
+            urllib.request.urlretrieve(model_url, model_path)
+
+        return model_path
+
     def load_model(self):
         """Load the ECAPA-TDNN model"""
         try:
             # Import SpeechBrain
             from speechbrain.pretrained import EncoderClassifier
 
+            # Get model path
+            model_path = self._download_model()
+
             # Load the pre-trained model
             self.model = EncoderClassifier.from_hparams(
                 source="speechbrain/spkrec-ecapa-voxceleb",

@@ -286,7 +303,7 @@ class RealtimeSpeakerDiarization:
         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
-        self.recorder = None
+        self.recorder = None  # RealtimeSTT recorder
         self.sentence_queue = queue.Queue()
         self.full_sentences = []
         self.sentence_speakers = []

@@ -314,6 +331,25 @@ class RealtimeSpeakerDiarization:
                 change_threshold=self.change_threshold,
                 max_speakers=self.max_speakers
             )
+
+            # Initialize RealtimeSTT transcription model
+            self.recorder = AudioToTextRecorder(
+                spinner=False,
+                use_microphone=False,
+                model=FINAL_TRANSCRIPTION_MODEL,
+                language=TRANSCRIPTION_LANGUAGE,
+                silero_sensitivity=SILERO_SENSITIVITY,
+                webrtc_sensitivity=WEBRTC_SENSITIVITY,
+                post_speech_silence_duration=0.7,
+                min_length_of_recording=MIN_LENGTH_OF_RECORDING,
+                pre_recording_buffer_duration=PRE_RECORDING_BUFFER_DURATION,
+                enable_realtime_transcription=True,
+                realtime_processing_pause=0,
+                realtime_model_type=REALTIME_TRANSCRIPTION_MODEL,
+                on_realtime_transcription_stabilized=self.live_text_detected,
+                on_recording_complete=self.process_final_text
+            )
+
             logger.info("Models initialized successfully!")
             return True
         else:

@@ -416,6 +452,11 @@ class RealtimeSpeakerDiarization:
             self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
             self.sentence_thread.start()
 
+            # Start the RealtimeSTT recorder if not already started
+            if self.recorder and not getattr(self.recorder, '_is_running', False):
+                self.recorder.start()
+                logger.info("RealtimeSTT recorder started")
+
             return "Recording started successfully!"
 
         except Exception as e:

@@ -425,6 +466,15 @@ class RealtimeSpeakerDiarization:
     def stop_recording(self):
         """Stop the recording process"""
        self.is_running = False
+
+        # Stop the RealtimeSTT recorder
+        if self.recorder:
+            try:
+                self.recorder.stop()
+                logger.info("RealtimeSTT recorder stopped")
+            except Exception as e:
+                logger.error(f"Error stopping recorder: {e}")
+
         return "Recording stopped!"
 
     def clear_conversation(self):

@@ -573,6 +623,12 @@ class RealtimeSpeakerDiarization:
         # Add to audio processor buffer for speaker detection
         self.audio_processor.add_audio_chunk(audio_data)
 
+        # Feed to RealtimeSTT for transcription
+        if self.recorder:
+            # Convert to int16 for RealtimeSTT
+            audio_int16 = (audio_data * 32768).astype(np.int16)
+            self.recorder.feed_audio(audio_int16.tobytes())
+
         # Periodically extract embeddings for speaker detection
         embedding = None
         speaker_id = self.speaker_detector.current_speaker

@@ -582,12 +638,6 @@ class RealtimeSpeakerDiarization:
             embedding = self.audio_processor.extract_embedding_from_buffer()
             if embedding is not None:
                 speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
-
-                # Add a simulated sentence for demo purposes
-                if similarity < 0.5:
-                    with self.transcription_lock:
-                        self.full_sentences.append((f"[Audio segment {self.speaker_detector.segment_counter}]", speaker_id))
-                        self.update_conversation_display()
 
         # Return processing result
         return {

@@ -595,7 +645,6 @@ class RealtimeSpeakerDiarization:
             "buffer_size": len(self.audio_processor.audio_buffer),
             "speaker_id": int(speaker_id) if not isinstance(speaker_id, int) else speaker_id,
             "similarity": float(similarity) if embedding is not None and not isinstance(similarity, float) else similarity,
-            "latest_sentence": f"[Audio segment {self.speaker_detector.segment_counter}]" if similarity < 0.5 else None,
             "conversation_html": self.current_conversation
         }
 
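One caveat on the float-to-PCM conversion added to `process_audio_chunk`: multiplying by 32768 overflows int16 when a sample is exactly +1.0 (32768 wraps around to -32768). A safer helper, assuming `audio_data` is float audio in [-1.0, 1.0] as elsewhere in this file:

    import numpy as np

    def float_to_pcm16(audio_data: np.ndarray) -> bytes:
        """Convert float audio in [-1.0, 1.0] to int16 PCM bytes."""
        # Clip first so a full-scale sample cannot wrap to -32768.
        clipped = np.clip(audio_data, -1.0, 1.0)
        return (clipped * 32767.0).astype(np.int16).tobytes()

    # e.g. self.recorder.feed_audio(float_to_pcm16(audio_data))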
test_websocket.py
CHANGED

@@ -15,6 +15,7 @@ async def test_ws():
         audio = (np.random.randn(3200) * 3000).astype(np.int16)
         await websocket.send(audio.tobytes())
         print(f"Sent audio chunk {i+1}/20")
+        await asyncio.sleep(0.05)
 
     try:
         while True:
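The added `asyncio.sleep(0.05)` paces the sender so the server is not flooded with the whole test clip at once. Assuming the usual 16 kHz input rate, each 3200-sample chunk represents 0.2 s of audio, so 50 ms between sends is still about 4x faster than real time. A sketch of exact real-time pacing under that same 16 kHz assumption:

    import asyncio
    import numpy as np

    SAMPLE_RATE = 16_000      # assumed input rate
    CHUNK_SAMPLES = 3_200     # as in the test above -> 0.2 s per chunk

    async def send_realtime(websocket, n_chunks=20):
        # Sleep one chunk-duration per send to mimic a live microphone.
        for i in range(n_chunks):
            audio = (np.random.randn(CHUNK_SAMPLES) * 3000).astype(np.int16)
            await websocket.send(audio.tobytes())
            await asyncio.sleep(CHUNK_SAMPLES / SAMPLE_RATE)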
ui.py
CHANGED

@@ -1,6 +1,6 @@
 import gradio as gr
 from fastapi import FastAPI
-from shared import DEFAULT_CHANGE_THRESHOLD, DEFAULT_MAX_SPEAKERS, ABSOLUTE_MAX_SPEAKERS
+from shared import DEFAULT_CHANGE_THRESHOLD, DEFAULT_MAX_SPEAKERS, ABSOLUTE_MAX_SPEAKERS, FINAL_TRANSCRIPTION_MODEL, REALTIME_TRANSCRIPTION_MODEL
 print(gr.__version__)
 # Connection configuration (separate signaling server from model server)
 # These will be replaced at deployment time with the correct URLs

@@ -23,7 +23,10 @@ def build_ui():
 
     # Header and description
     gr.Markdown("# 🎤 Live Speaker Diarization")
-    gr.Markdown("Real-time speech recognition with automatic speaker identification")
+    gr.Markdown(f"Real-time speech recognition with automatic speaker identification")
+
+    # Add transcription model info
+    gr.Markdown(f"**Using Models:** Final: {FINAL_TRANSCRIPTION_MODEL}, Realtime: {REALTIME_TRANSCRIPTION_MODEL}")
 
     # Status indicator
     connection_status = gr.HTML(

@@ -459,6 +462,11 @@ def build_ui():
                 <li>Threshold: ${threshold}</li>
                 <li>Max Speakers: ${maxSpeakers}</li>
                 </ul>
+                <p>Transcription Models:</p>
+                <ul>
+                    <li>Final: ${window.FINAL_TRANSCRIPTION_MODEL || "distil-large-v3"}</li>
+                    <li>Realtime: ${window.REALTIME_TRANSCRIPTION_MODEL || "distil-small.en"}</li>
+                </ul>
                 `;
             }
         });
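One loose end in the status template: it reads `window.FINAL_TRANSCRIPTION_MODEL` and `window.REALTIME_TRANSCRIPTION_MODEL`, but nothing in this commit sets those globals, so the `|| "distil-large-v3"` / `|| "distil-small.en"` fallbacks always win. A hedged sketch of one way to surface the Python values to the page (Gradio may strip <script> tags from HTML components depending on version, so treat this as illustrative only):

    import gradio as gr
    from shared import FINAL_TRANSCRIPTION_MODEL, REALTIME_TRANSCRIPTION_MODEL

    # Inject the model names as window globals so the template's
    # ${window...} lookups resolve instead of falling back to defaults.
    gr.HTML(f"""
    <script>
      window.FINAL_TRANSCRIPTION_MODEL = {FINAL_TRANSCRIPTION_MODEL!r};
      window.REALTIME_TRANSCRIPTION_MODEL = {REALTIME_TRANSCRIPTION_MODEL!r};
    </script>
    """)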