Saiyaswanth007 committed on
Commit f4275bf · 1 Parent(s): bd39e10

Transcript

Files changed (4)
  1. inference.py +19 -0
  2. shared.py +57 -8
  3. test_websocket.py +1 -0
  4. ui.py +10 -2
inference.py CHANGED
@@ -10,6 +10,16 @@ import time
 from typing import Set, Dict, Any
 import traceback
 
+# Check for RealtimeSTT and install if needed
+try:
+    from RealtimeSTT import AudioToTextRecorder
+except ImportError:
+    import subprocess
+    import sys
+    print("Installing RealtimeSTT dependency...")
+    subprocess.check_call([sys.executable, "-m", "pip", "install", "RealtimeSTT"])
+    from RealtimeSTT import AudioToTextRecorder
+
 # Set up logging
 logging.basicConfig(
     level=logging.INFO,
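
A note on the install-on-import guard above: the same pattern can be factored into a small reusable helper. The sketch below is illustrative only; `ensure_package` is a hypothetical name and is not part of this commit, which inlines the logic instead.

```python
# Sketch of the import-or-install fallback used in inference.py, as a reusable helper.
import importlib
import subprocess
import sys


def ensure_package(module_name, pip_name=None):
    """Import module_name, installing pip_name (defaults to module_name) on ImportError."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name or module_name])
        return importlib.import_module(module_name)


# Usage mirroring the hunk above:
# AudioToTextRecorder = ensure_package("RealtimeSTT").AudioToTextRecorder
```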
 
@@ -185,6 +195,15 @@ async def shutdown_event():
     try:
         diart.stop_recording()
         logger.info("Recording stopped")
+
+        # Shutdown RealtimeSTT properly if available
+        if hasattr(diart, 'recorder') and diart.recorder:
+            try:
+                diart.recorder.shutdown()
+                logger.info("Transcription model shut down")
+            except Exception as e:
+                logger.error(f"Error shutting down transcription model: {e}")
+
     except Exception as e:
         logger.error(f"Error stopping recording: {e}")
 
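
For readers tracing the shutdown path, the same ordering is shown below as a self-contained sketch: stop the diarization loop first, then release the RealtimeSTT model. `DummyDiart` only stands in for the RealtimeSpeakerDiarization instance that inference.py actually holds; nothing else is taken from the commit.

```python
# Isolated sketch of the stop-then-shutdown ordering added above.
from contextlib import suppress
from fastapi import FastAPI


class DummyDiart:
    recorder = None  # would be an AudioToTextRecorder after model init

    def stop_recording(self):
        return "Recording stopped!"


diart = DummyDiart()
app = FastAPI()


@app.on_event("shutdown")
async def shutdown_event():
    diart.stop_recording()
    if getattr(diart, "recorder", None):
        with suppress(Exception):
            diart.recorder.shutdown()  # frees the transcription model's resources
```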
 
shared.py CHANGED
@@ -8,6 +8,9 @@ import torchaudio
 from scipy.spatial.distance import cosine
 from scipy.signal import resample
 import logging
+import urllib.request
+# Import RealtimeSTT for transcription
+from RealtimeSTT import AudioToTextRecorder
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -64,12 +67,26 @@ class SpeechBrainEncoder:
         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
         os.makedirs(self.cache_dir, exist_ok=True)
 
+    def _download_model(self):
+        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
+        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
+        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
+
+        if not os.path.exists(model_path):
+            print(f"Downloading ECAPA-TDNN model to {model_path}...")
+            urllib.request.urlretrieve(model_url, model_path)
+
+        return model_path
+
     def load_model(self):
         """Load the ECAPA-TDNN model"""
         try:
             # Import SpeechBrain
             from speechbrain.pretrained import EncoderClassifier
 
+            # Get model path
+            model_path = self._download_model()
+
             # Load the pre-trained model
             self.model = EncoderClassifier.from_hparams(
                 source="speechbrain/spkrec-ecapa-voxceleb",
@@ -286,7 +303,7 @@ class RealtimeSpeakerDiarization:
         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
-        self.recorder = None
+        self.recorder = None  # RealtimeSTT recorder
         self.sentence_queue = queue.Queue()
         self.full_sentences = []
         self.sentence_speakers = []
@@ -314,6 +331,25 @@ class RealtimeSpeakerDiarization:
                 change_threshold=self.change_threshold,
                 max_speakers=self.max_speakers
             )
+
+            # Initialize RealtimeSTT transcription model
+            self.recorder = AudioToTextRecorder(
+                spinner=False,
+                use_microphone=False,
+                model=FINAL_TRANSCRIPTION_MODEL,
+                language=TRANSCRIPTION_LANGUAGE,
+                silero_sensitivity=SILERO_SENSITIVITY,
+                webrtc_sensitivity=WEBRTC_SENSITIVITY,
+                post_speech_silence_duration=0.7,
+                min_length_of_recording=MIN_LENGTH_OF_RECORDING,
+                pre_recording_buffer_duration=PRE_RECORDING_BUFFER_DURATION,
+                enable_realtime_transcription=True,
+                realtime_processing_pause=0,
+                realtime_model_type=REALTIME_TRANSCRIPTION_MODEL,
+                on_realtime_transcription_stabilized=self.live_text_detected,
+                on_recording_complete=self.process_final_text
+            )
+
             logger.info("Models initialized successfully!")
             return True
         else:
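
As a quick orientation to the recorder wiring added in initialize_models(): it runs in push mode (use_microphone=False), with audio supplied through feed_audio() elsewhere in this commit. The sketch below is a minimal standalone version of that setup; it uses only parameters that appear in the diff, and the model names are stand-ins for the FINAL_TRANSCRIPTION_MODEL / REALTIME_TRANSCRIPTION_MODEL constants defined elsewhere in shared.py.

```python
# Minimal push-mode sketch of the recorder configured in initialize_models().
import numpy as np
from RealtimeSTT import AudioToTextRecorder


def on_stabilized(text):
    print("live:", text)


recorder = AudioToTextRecorder(
    spinner=False,
    use_microphone=False,                    # audio is pushed in via feed_audio()
    model="distil-large-v3",                 # stand-in for FINAL_TRANSCRIPTION_MODEL
    enable_realtime_transcription=True,
    realtime_model_type="distil-small.en",   # stand-in for REALTIME_TRANSCRIPTION_MODEL
    on_realtime_transcription_stabilized=on_stabilized,
)

# Feed 200 ms of silence as 16-bit PCM bytes, the format process_audio_chunk() produces.
chunk = np.zeros(3200, dtype=np.float32)
recorder.feed_audio((chunk * 32767).astype(np.int16).tobytes())
```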
 
 
@@ -416,6 +452,11 @@ class RealtimeSpeakerDiarization:
             self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
             self.sentence_thread.start()
 
+            # Start the RealtimeSTT recorder if not already started
+            if self.recorder and not getattr(self.recorder, '_is_running', False):
+                self.recorder.start()
+                logger.info("RealtimeSTT recorder started")
+
             return "Recording started successfully!"
 
         except Exception as e:
@@ -425,6 +466,15 @@ class RealtimeSpeakerDiarization:
     def stop_recording(self):
         """Stop the recording process"""
        self.is_running = False
+
+        # Stop the RealtimeSTT recorder
+        if self.recorder:
+            try:
+                self.recorder.stop()
+                logger.info("RealtimeSTT recorder stopped")
+            except Exception as e:
+                logger.error(f"Error stopping recorder: {e}")
+
         return "Recording stopped!"
 
     def clear_conversation(self):
@@ -573,6 +623,12 @@ class RealtimeSpeakerDiarization:
             # Add to audio processor buffer for speaker detection
             self.audio_processor.add_audio_chunk(audio_data)
 
+            # Feed to RealtimeSTT for transcription
+            if self.recorder:
+                # Convert to int16 for RealtimeSTT
+                audio_int16 = (audio_data * 32768).astype(np.int16)
+                self.recorder.feed_audio(audio_int16.tobytes())
+
             # Periodically extract embeddings for speaker detection
             embedding = None
             speaker_id = self.speaker_detector.current_speaker
@@ -582,12 +638,6 @@ class RealtimeSpeakerDiarization:
                 embedding = self.audio_processor.extract_embedding_from_buffer()
                 if embedding is not None:
                     speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
-
-                    # Add a simulated sentence for demo purposes
-                    if similarity < 0.5:
-                        with self.transcription_lock:
-                            self.full_sentences.append((f"[Audio segment {self.speaker_detector.segment_counter}]", speaker_id))
-                            self.update_conversation_display()
 
             # Return processing result
             return {
@@ -595,7 +645,6 @@ class RealtimeSpeakerDiarization:
                 "buffer_size": len(self.audio_processor.audio_buffer),
                 "speaker_id": int(speaker_id) if not isinstance(speaker_id, int) else speaker_id,
                 "similarity": float(similarity) if embedding is not None and not isinstance(similarity, float) else similarity,
-                "latest_sentence": f"[Audio segment {self.speaker_detector.segment_counter}]" if similarity < 0.5 else None,
                 "conversation_html": self.current_conversation
             }
 
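
The one data-format detail worth calling out in process_audio_chunk() is the float32-to-PCM conversion that now happens before feed_audio(). A standalone sketch of that step follows; the np.clip guard is an extra precaution and is not something the commit itself adds.

```python
# Sketch of the float32 -> 16-bit PCM conversion performed before recorder.feed_audio().
import numpy as np


def float32_to_pcm16(audio):
    """Convert float32 samples in [-1.0, 1.0] to little-endian int16 bytes."""
    scaled = np.clip(audio, -1.0, 1.0) * 32767
    return scaled.astype(np.int16).tobytes()


silence = np.zeros(1600, dtype=np.float32)  # 100 ms at 16 kHz
pcm = float32_to_pcm16(silence)
assert len(pcm) == 3200                     # 2 bytes per sample
```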
 
test_websocket.py CHANGED
@@ -15,6 +15,7 @@ async def test_ws():
             audio = (np.random.randn(3200) * 3000).astype(np.int16)
             await websocket.send(audio.tobytes())
             print(f"Sent audio chunk {i+1}/20")
+            await asyncio.sleep(0.05)
 
         try:
             while True:
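
The added sleep throttles the sender: 3200 int16 samples at 16 kHz is 200 ms of audio, so 0.05 s between sends is still roughly 4x faster than real time. Below is a sketch of a real-time-paced variant; the 16 kHz sample rate, the `websockets` client library, and the URI are assumptions rather than facts taken from this file.

```python
# Sketch of a real-time-paced test client, sleeping one chunk's duration between sends.
import asyncio
import numpy as np
import websockets  # assumed to be the client library test_websocket.py already uses


async def stream_noise(uri, chunks=20, sample_rate=16000, chunk_samples=3200):
    chunk_duration = chunk_samples / sample_rate  # 0.2 s per chunk
    async with websockets.connect(uri) as ws:
        for i in range(chunks):
            audio = (np.random.randn(chunk_samples) * 3000).astype(np.int16)
            await ws.send(audio.tobytes())
            await asyncio.sleep(chunk_duration)


# asyncio.run(stream_noise("ws://localhost:7860/ws"))  # URI is illustrative only
```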
ui.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 from fastapi import FastAPI
-from shared import DEFAULT_CHANGE_THRESHOLD, DEFAULT_MAX_SPEAKERS, ABSOLUTE_MAX_SPEAKERS
+from shared import DEFAULT_CHANGE_THRESHOLD, DEFAULT_MAX_SPEAKERS, ABSOLUTE_MAX_SPEAKERS, FINAL_TRANSCRIPTION_MODEL, REALTIME_TRANSCRIPTION_MODEL
 print(gr.__version__)
 # Connection configuration (separate signaling server from model server)
 # These will be replaced at deployment time with the correct URLs
@@ -23,7 +23,10 @@ def build_ui():
 
         # Header and description
         gr.Markdown("# 🎤 Live Speaker Diarization")
-        gr.Markdown("Real-time speech recognition with automatic speaker identification")
+        gr.Markdown(f"Real-time speech recognition with automatic speaker identification")
+
+        # Add transcription model info
+        gr.Markdown(f"**Using Models:** Final: {FINAL_TRANSCRIPTION_MODEL}, Realtime: {REALTIME_TRANSCRIPTION_MODEL}")
 
         # Status indicator
         connection_status = gr.HTML(
@@ -459,6 +462,11 @@ def build_ui():
                         <li>Threshold: ${threshold}</li>
                         <li>Max Speakers: ${maxSpeakers}</li>
                     </ul>
+                    <p>Transcription Models:</p>
+                    <ul>
+                        <li>Final: ${window.FINAL_TRANSCRIPTION_MODEL || "distil-large-v3"}</li>
+                        <li>Realtime: ${window.REALTIME_TRANSCRIPTION_MODEL || "distil-small.en"}</li>
+                    </ul>
                 `;
             }
         });
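
The ui.py change only surfaces the transcription model names in the header. A minimal standalone sketch of that pattern is below; the constant values are the fallback defaults the JS template uses and are assumed here rather than imported from shared.py.

```python
# Minimal sketch of surfacing the model names in the Gradio header, as ui.py now does.
import gradio as gr

FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"      # stand-in for the value in shared.py
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"   # stand-in for the value in shared.py

with gr.Blocks() as demo:
    gr.Markdown("# 🎤 Live Speaker Diarization")
    gr.Markdown(
        f"**Using Models:** Final: {FINAL_TRANSCRIPTION_MODEL}, "
        f"Realtime: {REALTIME_TRANSCRIPTION_MODEL}"
    )

if __name__ == "__main__":
    demo.launch()
```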