Saiyaswanth007 committed · Commit fd289b1 · Parent: 7208f76

Fixing Real-time

Files changed (1):
  1. app.py +324 -293
app.py CHANGED
@@ -5,16 +5,13 @@ import torchaudio
 import time
 import os
 import urllib.request
-from scipy.spatial.distance import cosine
-import threading
 import queue
-from collections import deque
-import asyncio
-from typing import Generator, Tuple, List, Optional
-import whisper
-from transformers import pipeline

-# Configuration parameters (keeping original models)
 FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
 FINAL_BEAM_SIZE = 5
 REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
@@ -31,11 +28,24 @@ EMBEDDING_HISTORY_SIZE = 5
 MIN_SEGMENT_DURATION = 1.0
 DEFAULT_MAX_SPEAKERS = 4
 ABSOLUTE_MAX_SPEAKERS = 10
 SAMPLE_RATE = 16000
-CHUNK_DURATION = 2.0  # Process audio in 2-second chunks

-# Speaker labels
-SPEAKER_LABELS = [f"Speaker {i+1}" for i in range(ABSOLUTE_MAX_SPEAKERS)]

 class SpeechBrainEncoder:
     """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
@@ -47,11 +57,24 @@ class SpeechBrainEncoder:
         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
         os.makedirs(self.cache_dir, exist_ok=True)

     def load_model(self):
         """Load the ECAPA-TDNN model"""
         try:
             from speechbrain.pretrained import EncoderClassifier

             self.model = EncoderClassifier.from_hparams(
                 source="speechbrain/spkrec-ecapa-voxceleb",
                 savedir=self.cache_dir,
@@ -87,8 +110,28 @@ class SpeechBrainEncoder:
             return np.zeros(self.embedding_dim)


 class SpeakerChangeDetector:
-    """Speaker change detector that supports configurable number of speakers"""
     def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
         self.embedding_dim = embedding_dim
         self.change_threshold = change_threshold
@@ -196,373 +239,361 @@ class SpeakerChangeDetector:
         )

         return self.current_speaker, similarity
-
-
-class AudioProcessor:
-    """Processes audio data to extract speaker embeddings"""
-    def __init__(self, encoder):
-        self.encoder = encoder

-    def extract_embedding(self, audio_data):
-        try:
-            # Ensure audio is float32 and normalized
-            if audio_data.dtype != np.float32:
-                audio_data = audio_data.astype(np.float32)
-
-            # Normalize if needed
-            if np.abs(audio_data).max() > 1.0:
-                audio_data = audio_data / np.abs(audio_data).max()
-
-            # Extract embedding using the loaded encoder
-            embedding = self.encoder.embed_utterance(audio_data)
-
-            return embedding
-        except Exception as e:
-            print(f"Embedding extraction error: {e}")
-            return np.zeros(self.encoder.embedding_dim)


-class RealTimeSpeakerDiarization:
-    """Main class for real-time speaker diarization with FastRTC"""
-    def __init__(self, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
-        self.transcription_pipeline = None
-        self.change_threshold = change_threshold
-        self.max_speakers = max_speakers
-        self.transcript_history = []
-        self.is_initialized = False

-        # Audio processing
-        self.audio_buffer = deque(maxlen=int(SAMPLE_RATE * 10))  # 10 second buffer
-        self.processing_queue = queue.Queue()
-        self.last_processed_time = 0
-        self.current_transcript = ""

-    def initialize(self):
-        """Initialize the speaker diarization system"""
-        if self.is_initialized:
-            return True
-
         try:
             device_str = "cuda" if torch.cuda.is_available() else "cpu"
-            print(f"Initializing models on {device_str}...")

-            # Initialize speaker encoder
             self.encoder = SpeechBrainEncoder(device=device_str)
             success = self.encoder.load_model()

-            if not success:
-                return False

-            # Initialize transcription pipeline
-            self.transcription_pipeline = pipeline(
-                "automatic-speech-recognition",
-                model=f"openai/whisper-{REALTIME_TRANSCRIPTION_MODEL}",
-                device=0 if torch.cuda.is_available() else -1,
-                return_timestamps=True
-            )

-            self.audio_processor = AudioProcessor(self.encoder)
-            self.speaker_detector = SpeakerChangeDetector(
-                embedding_dim=self.encoder.embedding_dim,
-                change_threshold=self.change_threshold,
-                max_speakers=self.max_speakers
-            )

-            self.is_initialized = True
-            print("Speaker diarization system initialized successfully!")
             return True

         except Exception as e:
-            print(f"Initialization error: {e}")
             return False

-    def update_settings(self, change_threshold, max_speakers):
-        """Update diarization settings"""
-        self.change_threshold = change_threshold
-        self.max_speakers = max_speakers

-        if self.speaker_detector:
-            self.speaker_detector.set_change_threshold(change_threshold)
-            self.speaker_detector.set_max_speakers(max_speakers)

-    def process_audio_stream(self, audio_chunk, sample_rate):
-        """Process real-time audio stream from FastRTC"""
-        if not self.is_initialized:
-            return self.get_current_transcript(), "System not initialized"
-
         try:
-            # Convert to numpy array if needed
-            if hasattr(audio_chunk, 'numpy'):
-                audio_data = audio_chunk.numpy()
             else:
-                audio_data = np.array(audio_chunk)

-            # Handle different audio formats
-            if len(audio_data.shape) > 1:
-                audio_data = audio_data.mean(axis=1)  # Convert to mono

-            # Resample if needed
-            if sample_rate != SAMPLE_RATE:
-                audio_data = torchaudio.functional.resample(
-                    torch.tensor(audio_data), sample_rate, SAMPLE_RATE
-                ).numpy()

-            # Add to buffer
-            self.audio_buffer.extend(audio_data)

-            # Process if we have enough audio
-            current_time = time.time()
-            if (current_time - self.last_processed_time) >= CHUNK_DURATION:
-                self.process_buffered_audio()
-                self.last_processed_time = current_time

-            return self.get_current_transcript(), f"Processing... Buffer: {len(self.audio_buffer)} samples"

         except Exception as e:
-            error_msg = f"Error processing audio stream: {str(e)}"
-            print(error_msg)
-            return self.get_current_transcript(), error_msg

-    def process_buffered_audio(self):
-        """Process buffered audio for transcription and speaker diarization"""
-        if len(self.audio_buffer) < int(SAMPLE_RATE * MIN_LENGTH_OF_RECORDING):
-            return
-
         try:
-            # Get audio data from buffer
-            audio_data = np.array(list(self.audio_buffer))
-
-            # Transcribe audio
-            if len(audio_data) > 0:
-                result = self.transcription_pipeline(
-                    audio_data,
-                    return_timestamps=True,
-                    generate_kwargs={"language": TRANSCRIPTION_LANGUAGE}
-                )
-
-                transcription = result["text"].strip()
-
-                if transcription and len(transcription) > 0:
-                    # Extract speaker embedding
-                    embedding = self.audio_processor.extract_embedding(audio_data)
-
-                    # Detect speaker
-                    speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
-
-                    # Format text with speaker label
-                    speaker_label = SPEAKER_LABELS[speaker_id]
-                    formatted_text = f"{speaker_label}: {transcription}"
-
-                    # Add to transcript
-                    self.add_to_transcript(formatted_text)
-
-                    print(f"Transcribed: {formatted_text} (Similarity: {similarity:.3f})")

-            # Clear part of the buffer to prevent memory issues
-            if len(self.audio_buffer) > SAMPLE_RATE * 5:  # Keep last 5 seconds
-                self.audio_buffer = deque(list(self.audio_buffer)[-SAMPLE_RATE * 3:], maxlen=int(SAMPLE_RATE * 10))
-
         except Exception as e:
-            print(f"Error in process_buffered_audio: {e}")

-    def get_current_transcript(self):
-        """Get the current transcript"""
-        return "\n".join(self.transcript_history) if self.transcript_history else "Listening..."
-
-    def add_to_transcript(self, formatted_text: str):
-        """Add formatted text to transcript history"""
-        self.transcript_history.append(formatted_text)

-        # Keep only last 50 entries to prevent memory issues
-        if len(self.transcript_history) > 50:
-            self.transcript_history = self.transcript_history[-50:]

     def clear_transcript(self):
-        """Clear transcript history and reset speaker detector"""
-        self.transcript_history = []
-        self.audio_buffer.clear()

         if self.speaker_detector:
             self.speaker_detector = SpeakerChangeDetector(
                 embedding_dim=self.encoder.embedding_dim,
                 change_threshold=self.change_threshold,
                 max_speakers=self.max_speakers
             )
-
-    def get_status(self):
-        """Get current system status"""
-        if not self.is_initialized:
-            return "System not initialized"
-
-        if self.speaker_detector:
-            active_speakers = len(self.speaker_detector.active_speakers)
-            current_speaker = self.speaker_detector.current_speaker + 1
-            similarity = self.speaker_detector.last_similarity
-            return f"Active: {active_speakers} speakers | Current: Speaker {current_speaker} | Similarity: {similarity:.3f}"
-
-        return "Ready"


 # Global instance
-diarization_system = RealTimeSpeakerDiarization()


-def initialize_system():
-    """Initialize the diarization system"""
-    success = diarization_system.initialize()
-    if success:
-        return "✅ Speaker diarization system initialized successfully!"
-    else:
-        return "❌ Failed to initialize speaker diarization system. Please check your setup."
-
-
-def process_realtime_audio(audio_stream, change_threshold, max_speakers):
-    """Process real-time audio stream from FastRTC"""
-    if not diarization_system.is_initialized:
-        return "Please initialize the system first.", "System not ready"
-
-    # Update settings
-    diarization_system.update_settings(change_threshold, max_speakers)
-
-    if audio_stream is None:
-        return diarization_system.get_current_transcript(), diarization_system.get_status()

-    # Process the audio stream
-    transcript, status = diarization_system.process_audio_stream(audio_stream, SAMPLE_RATE)

-    return transcript, diarization_system.get_status()


-def clear_conversation():
-    """Clear the conversation transcript"""
-    diarization_system.clear_transcript()
-    return "Conversation cleared. Listening...", "Ready"


-def create_gradio_interface():
-    """Create and return the Gradio interface with FastRTC"""
-    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
-        gr.Markdown("# 🎙️ Real-time Speaker Diarization with FastRTC")
-        gr.Markdown("Speak into your microphone for real-time speaker diarization and transcription.")
-
-        # Initialization section
-        with gr.Row():
-            init_btn = gr.Button("🚀 Initialize System", variant="primary", scale=1)
-            init_status = gr.Textbox(label="System Status", interactive=False, scale=2)

-        # Settings section
         with gr.Row():
-            with gr.Column():
                 change_threshold = gr.Slider(
-                    minimum=0.1,
-                    maximum=0.9,
                     value=DEFAULT_CHANGE_THRESHOLD,
                     step=0.05,
                     label="Speaker Change Threshold",
                     info="Lower values = more sensitive to speaker changes"
                 )
-            with gr.Column():
                 max_speakers = gr.Slider(
                     minimum=2,
                     maximum=ABSOLUTE_MAX_SPEAKERS,
                     value=DEFAULT_MAX_SPEAKERS,
                     step=1,
-                    label="Maximum Number of Speakers",
                     info="Maximum number of speakers to detect"
                 )
-
-        # FastRTC Audio Input
-        with gr.Row():
-            with gr.Column():
-                # FastRTC component for real-time audio
-                audio_input = gr.FastRTC(
-                    audio=True,
-                    video=False,
-                    label="🎤 Real-time Audio Input",
-                    audio_sample_rate=SAMPLE_RATE,
-                    audio_channels=1
-                )

-                clear_btn = gr.Button("🗑️ Clear Conversation", variant="stop")
-
-            with gr.Column():
-                current_status = gr.Textbox(
-                    label="Current Status",
-                    interactive=False,
-                    value="Click Initialize to start"
-                )
-
-        # Output section
-        transcript_output = gr.Textbox(
-            label="🔴 Live Transcript with Speaker Labels",
-            lines=15,
-            max_lines=25,
-            interactive=False,
-            value="Click Initialize, then start speaking...",
-            autoscroll=True
-        )
-
-        # Event handlers
-        init_btn.click(
-            fn=initialize_system,
-            outputs=[init_status]
-        )

-        # FastRTC stream processing
         audio_input.stream(
-            fn=process_realtime_audio,
             inputs=[audio_input, change_threshold, max_speakers],
-            outputs=[transcript_output, current_status],
-            time_limit=30  # Process in 30-second chunks
         )

         clear_btn.click(
-            fn=clear_conversation,
-            outputs=[transcript_output, current_status]
         )

-        # Instructions
-        with gr.Accordion("📋 Instructions", open=False):
-            gr.Markdown("""
-            ## How to Use:
-
-            1. **Initialize**: Click "🚀 Initialize System" to load the AI models (this may take a moment)
-            2. **Allow Microphone**: Your browser will ask for microphone permission - please allow it
-            3. **Adjust Settings**:
-               - **Speaker Change Threshold**:
-                 - Lower (0.3-0.5) for speakers with different voices
-                 - Higher (0.6-0.8) for speakers with similar voices
-               - **Max Speakers**: Set expected number of speakers (2-10)
-            4. **Start Speaking**: The system will automatically transcribe and identify speakers
-            5. **View Results**: See real-time transcript with speaker labels (Speaker 1, Speaker 2, etc.)
-            6. **Clear**: Use "Clear Conversation" to reset and start fresh
-
-            ## Features:
-            - ✅ Real-time audio processing via FastRTC
-            - ✅ Automatic speech recognition with Whisper
-            - ✅ Speaker diarization with ECAPA-TDNN
-            - ✅ Live transcript with speaker labels
-            - ✅ Configurable sensitivity settings
-            - ✅ Support for up to 10 speakers
-
-            ## Tips:
-            - Speak clearly and allow brief pauses between speakers
-            - The system learns speaker characteristics over time
-            - Better results with distinct speaker voices
-            - Ensure good microphone quality for best performance
-            """)

-    return demo


 if __name__ == "__main__":
-    # Create and launch the Gradio interface
-    demo = create_gradio_interface()
-    demo.launch(
-        share=True,
         server_name="0.0.0.0",
         server_port=7860,
-        show_error=True
     )

 import time
 import os
 import urllib.request
 import queue
+import threading
+from scipy.spatial.distance import cosine
+from RealtimeSTT import AudioToTextRecorder

+# Configuration parameters (kept same as original)
+SILENCE_THRESHS = [0, 0.4]
 FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
 FINAL_BEAM_SIZE = 5
 REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"

 MIN_SEGMENT_DURATION = 1.0
 DEFAULT_MAX_SPEAKERS = 4
 ABSOLUTE_MAX_SPEAKERS = 10
+
+# Audio parameters
+FAST_SENTENCE_END = True
 SAMPLE_RATE = 16000
+BUFFER_SIZE = 512
+CHANNELS = 1
+
+# Speaker colors for HTML display
+SPEAKER_COLORS = [
+    "#FFFF00", "#FF0000", "#00FF00", "#00FFFF", "#FF00FF",
+    "#0000FF", "#FF8000", "#00FF80", "#8000FF", "#FFFFFF"
+]
+
+SPEAKER_COLOR_NAMES = [
+    "Yellow", "Red", "Green", "Cyan", "Magenta",
+    "Blue", "Orange", "Spring Green", "Purple", "White"
+]

 class SpeechBrainEncoder:
     """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""

         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
         os.makedirs(self.cache_dir, exist_ok=True)

+    def _download_model(self):
+        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
+        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
+        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
+
+        if not os.path.exists(model_path):
+            print(f"Downloading ECAPA-TDNN model to {model_path}...")
+            urllib.request.urlretrieve(model_url, model_path)
+
+        return model_path
+
     def load_model(self):
         """Load the ECAPA-TDNN model"""
         try:
             from speechbrain.pretrained import EncoderClassifier

+            model_path = self._download_model()
+
             self.model = EncoderClassifier.from_hparams(
                 source="speechbrain/spkrec-ecapa-voxceleb",
                 savedir=self.cache_dir,

             return np.zeros(self.embedding_dim)


+class AudioProcessor:
+    """Processes audio data to extract speaker embeddings"""
+    def __init__(self, encoder):
+        self.encoder = encoder
+
+    def extract_embedding(self, audio_int16):
+        try:
+            float_audio = audio_int16.astype(np.float32) / 32768.0
+
+            if np.abs(float_audio).max() > 1.0:
+                float_audio = float_audio / np.abs(float_audio).max()
+
+            embedding = self.encoder.embed_utterance(float_audio)
+
+            return embedding
+        except Exception as e:
+            print(f"Embedding extraction error: {e}")
+            return np.zeros(self.encoder.embedding_dim)
+
+
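
For reference, a minimal sketch of the kind of call the `embed_utterance` wrapper is expected to make against SpeechBrain's pretrained encoder (the one second of silence is a placeholder input):

import numpy as np
import torch
from speechbrain.pretrained import EncoderClassifier

classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb")
wav = np.zeros(16000, dtype=np.float32)               # placeholder: 1 s of 16 kHz audio
with torch.no_grad():
    emb = classifier.encode_batch(torch.from_numpy(wav).unsqueeze(0))
embedding = emb.squeeze().cpu().numpy()               # ECAPA-TDNN yields a 192-dim vector
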
 class SpeakerChangeDetector:
+    """Speaker change detector with configurable number of speakers"""
     def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
         self.embedding_dim = embedding_dim
         self.change_threshold = change_threshold

         )

         return self.current_speaker, similarity

+    def get_color_for_speaker(self, speaker_id):
+        """Return color for speaker ID"""
+        if 0 <= speaker_id < len(SPEAKER_COLORS):
+            return SPEAKER_COLORS[speaker_id]
+        return "#FFFFFF"


+class RealtimeASRDiarization:
+    """Main class for real-time ASR with speaker diarization"""
+    def __init__(self):
         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
+        self.recorder = None
+        self.is_recording = False
+        self.full_sentences = []
+        self.sentence_speakers = []
+        self.pending_sentences = []
+        self.last_realtime_text = ""
+        self.sentence_queue = queue.Queue()
+        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
+        self.max_speakers = DEFAULT_MAX_SPEAKERS

+        # Initialize model
+        self.initialize_model()

+    def initialize_model(self):
+        """Initialize the speaker encoder model"""
         try:
             device_str = "cuda" if torch.cuda.is_available() else "cpu"
+            print(f"Using device: {device_str}")

             self.encoder = SpeechBrainEncoder(device=device_str)
             success = self.encoder.load_model()

+            if success:
+                print("ECAPA-TDNN model loaded successfully!")
+                self.audio_processor = AudioProcessor(self.encoder)
+                self.speaker_detector = SpeakerChangeDetector(
+                    embedding_dim=self.encoder.embedding_dim,
+                    change_threshold=self.change_threshold,
+                    max_speakers=self.max_speakers
+                )
+
+                # Start sentence processing thread
+                self.sentence_thread = threading.Thread(target=self.process_sentences, daemon=True)
+                self.sentence_thread.start()
+
+            else:
+                print("Failed to load ECAPA-TDNN model")
+
+        except Exception as e:
+            print(f"Model initialization error: {e}")
+
+    def process_sentences(self):
+        """Process sentences in background thread"""
+        while True:
+            try:
+                text, audio_bytes = self.sentence_queue.get(timeout=1)
+                self.process_sentence(text, audio_bytes)
+            except queue.Empty:
+                continue
+
+ def process_sentence(self, text, audio_bytes):
307
+ """Process a sentence with speaker diarization"""
308
+ if self.audio_processor is None or self.speaker_detector is None:
309
+ return
310
 
311
+ try:
312
+ # Convert audio data to int16
313
+ audio_int16 = np.int16(audio_bytes * 32767)
314
+
315
+ # Extract speaker embedding
316
+ speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
317
+
318
+ # Store sentence and embedding
319
+ self.full_sentences.append((text, speaker_embedding))
320
+
321
+ # Fill in any missing speaker assignments
322
+ while len(self.sentence_speakers) < len(self.full_sentences) - 1:
323
+ self.sentence_speakers.append(0)
324
+
325
+ # Detect speaker changes
326
+ speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
327
+ self.sentence_speakers.append(speaker_id)
328
+
329
+ # Remove from pending
330
+ if text in self.pending_sentences:
331
+ self.pending_sentences.remove(text)
332
 
333
+ except Exception as e:
334
+ print(f"Error processing sentence: {e}")
335
+
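
A hypothetical illustration of the decision `add_embedding` makes (its body sits outside this hunk): cosine similarity between the new embedding and the current speaker's, compared against the change threshold:

import numpy as np
from scipy.spatial.distance import cosine

def is_speaker_change(new_emb, current_emb, change_threshold):
    similarity = 1.0 - cosine(new_emb, current_emb)   # cosine() returns a distance
    return similarity < change_threshold, similarity

emb = np.ones(192)
print(is_speaker_change(emb, emb, 0.5))               # identical vectors -> (False, 1.0)
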
+    def setup_recorder(self):
+        """Setup the audio recorder"""
+        try:
+            recorder_config = {
+                'spinner': False,
+                'use_microphone': False,
+                'model': FINAL_TRANSCRIPTION_MODEL,
+                'language': TRANSCRIPTION_LANGUAGE,
+                'silero_sensitivity': SILERO_SENSITIVITY,
+                'webrtc_sensitivity': WEBRTC_SENSITIVITY,
+                'post_speech_silence_duration': SILENCE_THRESHS[1],
+                'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
+                'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
+                'min_gap_between_recordings': 0,
+                'enable_realtime_transcription': True,
+                'realtime_processing_pause': 0,
+                'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
+                'on_realtime_transcription_update': self.live_text_detected,
+                'beam_size': FINAL_BEAM_SIZE,
+                'beam_size_realtime': REALTIME_BEAM_SIZE,
+                'buffer_size': BUFFER_SIZE,
+                'sample_rate': SAMPLE_RATE,
+            }

+            self.recorder = AudioToTextRecorder(**recorder_config)
             return True

         except Exception as e:
+            print(f"Error setting up recorder: {e}")
             return False

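
A minimal standalone sketch of driving RealtimeSTT with pushed audio, using only pieces that appear in the config above (`use_microphone=False`, `feed_audio`, `text`); the silent chunk and the "en" language code are placeholders:

import numpy as np
from RealtimeSTT import AudioToTextRecorder

recorder = AudioToTextRecorder(
    use_microphone=False,        # audio is pushed in via feed_audio()
    model="distil-small.en",
    language="en",               # placeholder for TRANSCRIPTION_LANGUAGE
)

recorder.feed_audio(np.zeros(512, dtype=np.int16).tobytes())  # one 16-bit PCM chunk
recorder.text(lambda text: print("final:", text))             # callback runs when an utterance is finalized
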
+    def live_text_detected(self, text):
+        """Handle live text detection"""
+        text = text.strip()
+        if not text:
+            return
+
+        sentence_delimiters = '.?!。'
+        prob_sentence_end = (
+            len(self.last_realtime_text) > 0
+            and text[-1] in sentence_delimiters
+            and self.last_realtime_text[-1] in sentence_delimiters
+        )

+        self.last_realtime_text = text
+
+        if prob_sentence_end:
+            if FAST_SENTENCE_END:
+                self.recorder.stop()
+            else:
+                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
+        else:
+            self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]

+    def process_audio_chunk(self, audio_chunk):
+        """Process incoming audio chunk from FastRTC"""
+        if self.recorder is None:
+            if not self.setup_recorder():
+                return "Failed to setup recorder"

         try:
+            # Convert audio to the format expected by the recorder
+            if isinstance(audio_chunk, tuple):
+                sample_rate, audio_data = audio_chunk
             else:
+                audio_data = audio_chunk
+                sample_rate = SAMPLE_RATE

+            # Ensure audio is in the right format
+            if audio_data.dtype != np.int16:
+                if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
+                    audio_data = (audio_data * 32767).astype(np.int16)
+                else:
+                    audio_data = audio_data.astype(np.int16)

+            # Convert to bytes and feed to recorder
+            audio_bytes = audio_data.tobytes()
+            self.recorder.feed_audio(audio_bytes)

+            # Process final text if available
+            def process_final_text(text):
+                text = text.strip()
+                if text:
+                    self.pending_sentences.append(text)
+                    audio_bytes = self.recorder.last_transcription_bytes
+                    self.sentence_queue.put((text, audio_bytes))

+            # Get transcription
+            self.recorder.text(process_final_text)

+            return self.get_formatted_transcript()

         except Exception as e:
+            print(f"Error processing audio: {e}")
+            return f"Error: {e}"
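
The same normalization as a standalone helper, with optional resampling added for chunks that do not already arrive at 16 kHz (the handler above assumes they do):

import numpy as np
import torch
import torchaudio

def to_int16_pcm_bytes(audio_chunk, target_rate=16000):
    if isinstance(audio_chunk, tuple):               # Gradio streams (rate, samples)
        sample_rate, audio_data = audio_chunk
    else:
        sample_rate, audio_data = target_rate, audio_chunk

    if audio_data.ndim > 1:                          # downmix stereo to mono
        audio_data = audio_data.mean(axis=1)

    audio_data = audio_data.astype(np.float32)
    if np.abs(audio_data).max() > 1.0:               # int16-range values -> [-1, 1]
        audio_data /= 32768.0

    if sample_rate != target_rate:
        audio_data = torchaudio.functional.resample(
            torch.from_numpy(audio_data), sample_rate, target_rate
        ).numpy()

    return (audio_data * 32767).astype(np.int16).tobytes()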
 
+    def get_formatted_transcript(self):
+        """Get formatted transcript with speaker labels"""
         try:
+            transcript_parts = []
+
+            # Add completed sentences with speaker labels
+            for i, (sentence_text, _) in enumerate(self.full_sentences):
+                if i < len(self.sentence_speakers):
+                    speaker_id = self.sentence_speakers[i]
+                    speaker_label = f"Speaker {speaker_id + 1}"
+                    transcript_parts.append(f"{speaker_label}: {sentence_text}")
+
+            # Add pending sentences
+            for pending in self.pending_sentences:
+                transcript_parts.append(f"[Processing]: {pending}")
+
+            # Add current live text
+            if self.last_realtime_text:
+                transcript_parts.append(f"[Live]: {self.last_realtime_text}")
+
+            return "\n".join(transcript_parts)

         except Exception as e:
+            print(f"Error formatting transcript: {e}")
+            return "Error formatting transcript"

+    def update_settings(self, change_threshold, max_speakers):
+        """Update diarization settings"""
+        self.change_threshold = change_threshold
+        self.max_speakers = max_speakers

+        if self.speaker_detector:
+            self.speaker_detector.set_change_threshold(change_threshold)
+            self.speaker_detector.set_max_speakers(max_speakers)

     def clear_transcript(self):
+        """Clear all transcript data"""
+        self.full_sentences = []
+        self.sentence_speakers = []
+        self.pending_sentences = []
+        self.last_realtime_text = ""
+
         if self.speaker_detector:
             self.speaker_detector = SpeakerChangeDetector(
                 embedding_dim=self.encoder.embedding_dim,
                 change_threshold=self.change_threshold,
                 max_speakers=self.max_speakers
             )


 # Global instance
+asr_diarization = RealtimeASRDiarization()


+def process_audio_stream(audio_chunk, change_threshold, max_speakers):
+    """Process audio stream and return transcript"""
+    # Update settings if changed
+    asr_diarization.update_settings(change_threshold, max_speakers)

+    # Process audio
+    transcript = asr_diarization.process_audio_chunk(audio_chunk)

+    return transcript


+def clear_transcript():
+    """Clear the transcript"""
+    asr_diarization.clear_transcript()
+    return "Transcript cleared. Ready for new input..."


+def create_interface():
+    """Create Gradio interface with FastRTC"""
+
+    with gr.Blocks(title="Real-time Speaker Diarization") as iface:
+        gr.Markdown("# Real-time ASR with Speaker Diarization")
+        gr.Markdown("Speak into your microphone to see real-time transcription with speaker labels!")

         with gr.Row():
+            with gr.Column(scale=3):
+                # Audio input with FastRTC
+                audio_input = gr.Audio(
+                    sources=["microphone"],
+                    streaming=True,
+                    label="Microphone Input"
+                )
+
+                # Transcript output
+                transcript_output = gr.Textbox(
+                    label="Live Transcript with Speaker Labels",
+                    lines=15,
+                    max_lines=20,
+                    value="Ready to start transcription...",
+                    interactive=False
+                )
+
+            with gr.Column(scale=1):
+                gr.Markdown("### Settings")
+
+                # Speaker change threshold
                 change_threshold = gr.Slider(
+                    minimum=0.1,
+                    maximum=0.95,
                     value=DEFAULT_CHANGE_THRESHOLD,
                     step=0.05,
                     label="Speaker Change Threshold",
                     info="Lower values = more sensitive to speaker changes"
                 )
+
+                # Max speakers
                 max_speakers = gr.Slider(
                     minimum=2,
                     maximum=ABSOLUTE_MAX_SPEAKERS,
                     value=DEFAULT_MAX_SPEAKERS,
                     step=1,
+                    label="Maximum Speakers",
                     info="Maximum number of speakers to detect"
                 )

+                # Clear button
+                clear_btn = gr.Button("Clear Transcript", variant="secondary")
+
+                gr.Markdown("### Speaker Colors")
+                color_info = "\n".join([
+                    f"Speaker {i+1}: {SPEAKER_COLOR_NAMES[i]}"
+                    for i in range(min(DEFAULT_MAX_SPEAKERS, len(SPEAKER_COLOR_NAMES)))
+                ])
+                gr.Markdown(color_info)

+        # Set up streaming
         audio_input.stream(
+            fn=process_audio_stream,
             inputs=[audio_input, change_threshold, max_speakers],
+            outputs=[transcript_output],
+            show_progress=False
         )

+        # Clear button functionality
         clear_btn.click(
+            fn=clear_transcript,
+            outputs=[transcript_output]
         )

+        gr.Markdown("""
+        ### Instructions:
+        1. Allow microphone access when prompted
+        2. Start speaking - transcription will appear in real-time
+        3. Different speakers will be automatically detected and labeled
+        4. Adjust the threshold if speaker changes aren't detected properly
+        5. Use the clear button to reset the transcript
+
+        ### Notes:
+        - The system works best with clear audio and distinct speakers
+        - It may take a moment to load the speaker recognition model on first use
+        - Lower threshold values make the system more sensitive to speaker changes
+        """)

+    return iface


 if __name__ == "__main__":
+    # Create and launch the interface
+    iface = create_interface()
+    iface.launch(
         server_name="0.0.0.0",
         server_port=7860,
+        share=True
     )
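
To see the streaming contract in isolation, a toy interface wired the same way; each microphone chunk reaches the handler (typically as a `(sample_rate, samples)` tuple) together with the current slider values. The 0.65 default below is a placeholder:

import gradio as gr

def chunk_info(chunk, threshold, speakers):
    if chunk is None:
        return "waiting for audio..."
    rate, samples = chunk
    return f"{len(samples)} samples at {rate} Hz (threshold={threshold}, max speakers={speakers})"

with gr.Blocks() as toy:
    mic = gr.Audio(sources=["microphone"], streaming=True)
    thr = gr.Slider(0.1, 0.95, value=0.65, step=0.05)    # placeholder default
    spk = gr.Slider(2, 10, value=4, step=1)
    out = gr.Textbox()
    mic.stream(fn=chunk_info, inputs=[mic, thr, spk], outputs=[out])

toy.launch()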