Saiyaswanth007 commited on
Commit
c263c26
ยท
1 Parent(s): 29b89b3

Code error correction

Browse files
Files changed (1) hide show
  1. app.py +213 -170
app.py CHANGED
@@ -9,8 +9,13 @@ import urllib.request
9
  import torchaudio
10
  from scipy.spatial.distance import cosine
11
  import json
12
- import io
13
- import wave
 
 
 
 
 
14
 
15
  # Simplified configuration parameters
16
  SILENCE_THRESHS = [0, 0.4]
@@ -34,8 +39,9 @@ ABSOLUTE_MAX_SPEAKERS = 10
34
  # Global variables
35
  FAST_SENTENCE_END = True
36
  SAMPLE_RATE = 16000
37
- BUFFER_SIZE = 512
38
  CHANNELS = 1
 
39
 
40
  # Speaker colors
41
  SPEAKER_COLORS = [
@@ -73,7 +79,7 @@ class SpeechBrainEncoder:
73
  model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
74
 
75
  if not os.path.exists(model_path):
76
- print(f"Downloading ECAPA-TDNN model to {model_path}...")
77
  urllib.request.urlretrieve(model_url, model_path)
78
 
79
  return model_path
@@ -94,7 +100,7 @@ class SpeechBrainEncoder:
94
  self.model_loaded = True
95
  return True
96
  except Exception as e:
97
- print(f"Error loading ECAPA-TDNN model: {e}")
98
  return False
99
 
100
  def embed_utterance(self, audio, sr=16000):
@@ -116,7 +122,7 @@ class SpeechBrainEncoder:
116
 
117
  return embedding.squeeze().cpu().numpy()
118
  except Exception as e:
119
- print(f"Error extracting embedding: {e}")
120
  return np.zeros(self.embedding_dim)
121
 
122
 
@@ -135,7 +141,7 @@ class AudioProcessor:
135
 
136
  return embedding
137
  except Exception as e:
138
- print(f"Embedding extraction error: {e}")
139
  return np.zeros(self.encoder.embedding_dim)
140
 
141
 
@@ -270,83 +276,105 @@ class SpeakerChangeDetector:
270
 
271
 
272
  class WhisperTranscriber:
273
- """Simple Whisper transcriber for audio chunks"""
274
  def __init__(self, model_name="distil-large-v3"):
275
  self.model = None
276
  self.processor = None
277
  self.model_name = model_name
278
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
279
 
280
  def load_model(self):
281
  """Load Whisper model"""
282
  try:
283
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
284
 
285
- self.processor = WhisperProcessor.from_pretrained(f"distil-whisper/{self.model_name}")
286
- self.model = WhisperForConditionalGeneration.from_pretrained(f"distil-whisper/{self.model_name}")
287
- self.model.to(self.device)
 
 
 
 
 
 
 
 
 
288
 
 
289
  return True
290
  except Exception as e:
291
- print(f"Error loading Whisper model: {e}")
292
  return False
293
 
294
  def transcribe(self, audio_array, sample_rate=16000):
295
  """Transcribe audio array"""
 
 
 
296
  try:
297
- if self.model is None:
 
298
  return ""
299
 
300
- # Ensure audio is the right sample rate
301
  if sample_rate != 16000:
302
- audio_array = torchaudio.functional.resample(
303
- torch.tensor(audio_array).float(),
304
- orig_freq=sample_rate,
305
- new_freq=16000
306
- ).numpy()
 
 
 
 
 
 
 
307
 
308
- # Process audio
309
- inputs = self.processor(audio_array, sampling_rate=16000, return_tensors="pt")
310
- inputs = inputs.to(self.device)
311
 
312
- # Generate transcription
313
  with torch.no_grad():
314
- predicted_ids = self.model.generate(inputs["input_features"])
315
-
316
- # Decode transcription
317
- transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
 
 
 
318
 
319
- return transcription[0] if transcription else ""
320
 
 
321
  except Exception as e:
322
- print(f"Transcription error: {e}")
323
  return ""
324
 
325
 
326
- class RealtimeSpeakerDiarization:
327
  def __init__(self):
328
  self.encoder = None
329
  self.audio_processor = None
330
  self.speaker_detector = None
331
  self.transcriber = None
332
- self.audio_buffer = []
333
  self.processing_thread = None
334
- self.sentence_queue = queue.Queue()
335
  self.full_sentences = []
336
  self.sentence_speakers = []
337
- self.pending_sentences = []
338
- self.displayed_text = ""
339
  self.is_running = False
340
  self.change_threshold = DEFAULT_CHANGE_THRESHOLD
341
  self.max_speakers = DEFAULT_MAX_SPEAKERS
342
- self.audio_chunks = []
343
- self.chunk_counter = 0
 
 
344
 
345
  def initialize_models(self):
346
  """Initialize the speaker encoder and transcription models"""
347
  try:
348
  device_str = "cuda" if torch.cuda.is_available() else "cpu"
349
- print(f"Using device: {device_str}")
350
 
351
  # Initialize speaker encoder
352
  self.encoder = SpeechBrainEncoder(device=device_str)
@@ -363,124 +391,131 @@ class RealtimeSpeakerDiarization:
363
  change_threshold=self.change_threshold,
364
  max_speakers=self.max_speakers
365
  )
366
- print("Models loaded successfully!")
367
  return True
368
  else:
369
- print("Failed to load models")
370
  return False
371
  except Exception as e:
372
- print(f"Model initialization error: {e}")
373
  return False
374
 
375
- def process_audio_stream(self, audio_data, sample_rate):
376
- """Process incoming audio stream data"""
377
- if not self.is_running or self.encoder is None:
378
  return
379
 
380
  try:
381
- # Convert audio data to numpy array if needed
382
- if isinstance(audio_data, tuple):
383
- sample_rate, audio_array = audio_data
384
- else:
385
- audio_array = audio_data
386
-
387
- # Ensure audio is float32 and normalized
388
- if audio_array.dtype != np.float32:
389
- if audio_array.dtype == np.int16:
390
- audio_array = audio_array.astype(np.float32) / 32768.0
391
- else:
392
- audio_array = audio_array.astype(np.float32)
393
-
394
- # Ensure mono audio
395
- if len(audio_array.shape) > 1 and audio_array.shape[1] > 1:
396
- audio_array = np.mean(audio_array, axis=1)
397
-
398
- # Add to buffer
399
- self.audio_buffer.extend(audio_array.flatten())
400
-
401
- # Process when we have enough audio (about 2 seconds)
402
- target_length = int(sample_rate * 2.0)
403
- if len(self.audio_buffer) >= target_length:
404
- self.process_audio_chunk()
405
 
406
- except Exception as e:
407
- print(f"Error processing audio stream: {e}")
408
-
409
- def process_audio_chunk(self):
410
- """Process accumulated audio chunk"""
411
- try:
412
- if len(self.audio_buffer) < SAMPLE_RATE: # Need at least 1 second
413
- return
414
-
415
- # Get audio chunk
416
- audio_chunk = np.array(self.audio_buffer[:int(SAMPLE_RATE * 2)])
417
- self.audio_buffer = self.audio_buffer[int(SAMPLE_RATE * 1.5):] # Keep some overlap
418
-
419
- # Transcribe audio
420
- transcription = self.transcriber.transcribe(audio_chunk, SAMPLE_RATE)
421
-
422
- if transcription.strip():
423
- # Extract speaker embedding
424
- speaker_embedding = self.audio_processor.extract_embedding(audio_chunk)
425
 
426
- # Add to queue for processing
427
- self.sentence_queue.put((transcription.strip(), speaker_embedding))
428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  except Exception as e:
430
- print(f"Error processing audio chunk: {e}")
431
 
432
- def process_sentence_queue(self):
433
- """Process sentences in the queue for speaker detection"""
434
  while self.is_running:
435
  try:
436
- text, speaker_embedding = self.sentence_queue.get(timeout=1)
437
-
438
- # Store sentence and embedding
439
- self.full_sentences.append((text, speaker_embedding))
440
 
441
- # Fill in missing speaker assignments
442
- while len(self.sentence_speakers) < len(self.full_sentences) - 1:
443
- self.sentence_speakers.append(0)
444
 
445
- # Detect speaker changes
446
- speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
447
- self.sentence_speakers.append(speaker_id)
448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
  except queue.Empty:
450
  continue
451
  except Exception as e:
452
- print(f"Error processing sentence: {e}")
453
 
454
  def start_recording(self):
455
- """Start the recording and transcription process"""
456
- if self.encoder is None:
457
  return "Please initialize models first!"
458
 
459
  try:
460
- # Start sentence processing thread
461
  self.is_running = True
462
- self.processing_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
 
 
 
 
 
 
 
 
 
 
 
463
  self.processing_thread.start()
464
 
465
- return "Recording started successfully! Start speaking into your microphone."
 
466
 
467
  except Exception as e:
 
468
  return f"Error starting recording: {e}"
469
 
470
  def stop_recording(self):
471
  """Stop the recording process"""
472
  self.is_running = False
473
- self.audio_buffer = []
474
  return "Recording stopped!"
475
 
476
  def clear_conversation(self):
477
  """Clear all conversation data"""
478
  self.full_sentences = []
479
  self.sentence_speakers = []
480
- self.pending_sentences = []
481
- self.displayed_text = ""
482
  self.audio_buffer = []
483
 
 
 
 
 
 
 
 
484
  if self.speaker_detector:
485
  self.speaker_detector = SpeakerChangeDetector(
486
  embedding_dim=self.encoder.embedding_dim,
@@ -504,26 +539,24 @@ class RealtimeSpeakerDiarization:
504
  def get_formatted_conversation(self):
505
  """Get the formatted conversation with speaker colors"""
506
  try:
 
 
 
507
  sentences_with_style = []
508
 
509
- # Process completed sentences
510
- for i, sentence in enumerate(self.full_sentences):
511
- sentence_text, _ = sentence
512
  if i >= len(self.sentence_speakers):
513
  color = "#FFFFFF"
514
- speaker_name = "Speaker ?"
515
  else:
516
- speaker_id = self.sentence_speakers[i]
517
  color = self.speaker_detector.get_color_for_speaker(speaker_id)
518
  speaker_name = f"Speaker {speaker_id + 1}"
519
-
520
  sentences_with_style.append(
521
- f'<span style="color:{color};"><b>{speaker_name}:</b> {sentence_text}</span>')
522
 
523
- if sentences_with_style:
524
- return "<br><br>".join(sentences_with_style)
525
- else:
526
- return "Waiting for speech input..."
527
 
528
  except Exception as e:
529
  return f"Error formatting conversation: {e}"
@@ -535,6 +568,7 @@ class RealtimeSpeakerDiarization:
535
 
536
  try:
537
  status = self.speaker_detector.get_status_info()
 
538
 
539
  status_lines = [
540
  f"**Current Speaker:** {status['current_speaker'] + 1}",
@@ -542,7 +576,8 @@ class RealtimeSpeakerDiarization:
542
  f"**Last Similarity:** {status['last_similarity']:.3f}",
543
  f"**Change Threshold:** {status['threshold']:.2f}",
544
  f"**Total Sentences:** {len(self.full_sentences)}",
545
- f"**Audio Buffer Size:** {len(self.audio_buffer)}",
 
546
  "",
547
  "**Speaker Segment Counts:**"
548
  ]
@@ -558,7 +593,7 @@ class RealtimeSpeakerDiarization:
558
 
559
 
560
  # Global instance
561
- diarization_system = RealtimeSpeakerDiarization()
562
 
563
 
564
  def initialize_system():
@@ -600,49 +635,56 @@ def get_status():
600
  return diarization_system.get_status_info()
601
 
602
 
603
- def process_audio(audio_data):
604
- """Process audio from Gradio audio input"""
605
- if audio_data is not None:
606
- sample_rate, audio_array = audio_data
607
- diarization_system.process_audio_stream(audio_array, sample_rate)
 
608
  return get_conversation(), get_status()
609
 
610
 
611
- # Create Gradio interface
612
  def create_interface():
613
- with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Monochrome()) as app:
614
- gr.Markdown("# ๐ŸŽค Real-time Speech Recognition with Speaker Diarization")
615
- gr.Markdown("This app performs real-time speech recognition with automatic speaker identification and color-coding using your browser's microphone.")
616
 
617
  with gr.Row():
618
  with gr.Column(scale=2):
619
- # Audio input
620
  audio_input = gr.Audio(
621
- source="microphone",
622
  type="numpy",
623
  streaming=True,
624
- label="๐ŸŽ™๏ธ Microphone Input"
 
 
 
 
625
  )
626
 
627
  # Main conversation display
628
  conversation_output = gr.HTML(
629
- value="<i>Click 'Initialize System' to start...</i>",
630
- label="Live Conversation"
 
631
  )
632
 
633
  # Control buttons
634
  with gr.Row():
635
- init_btn = gr.Button("๐Ÿ”ง Initialize System", variant="secondary")
636
- start_btn = gr.Button("๐ŸŽ™๏ธ Start Recording", variant="primary", interactive=False)
637
- stop_btn = gr.Button("โน๏ธ Stop Recording", variant="stop", interactive=False)
638
- clear_btn = gr.Button("๐Ÿ—‘๏ธ Clear Conversation", interactive=False)
639
 
640
  # Status display
641
  status_output = gr.Textbox(
642
  label="System Status",
643
  value="System not initialized",
644
  lines=10,
645
- interactive=False
 
646
  )
647
 
648
  with gr.Column(scale=1):
@@ -655,7 +697,7 @@ def create_interface():
655
  step=0.05,
656
  value=DEFAULT_CHANGE_THRESHOLD,
657
  label="Speaker Change Sensitivity",
658
- info="Lower values = more sensitive to speaker changes"
659
  )
660
 
661
  max_speakers_slider = gr.Slider(
@@ -666,26 +708,23 @@ def create_interface():
666
  label="Maximum Number of Speakers"
667
  )
668
 
669
- update_settings_btn = gr.Button("Update Settings")
670
 
671
  # Speaker color legend
672
  gr.Markdown("## ๐ŸŽจ Speaker Colors")
673
  color_info = []
674
  for i, (color, name) in enumerate(zip(SPEAKER_COLORS, SPEAKER_COLOR_NAMES)):
675
- color_info.append(f'<span style="color:{color};">โ– </span> Speaker {i+1} ({name})')
676
 
677
  gr.HTML("<br>".join(color_info[:DEFAULT_MAX_SPEAKERS]))
678
 
679
- # Instructions
 
680
  gr.Markdown("""
681
- ## ๐Ÿ“‹ Instructions
682
- 1. **Initialize System** - Load AI models
683
- 2. **Allow microphone access** when prompted
684
- 3. **Start Recording** - Begin real-time processing
685
- 4. **Speak naturally** - The system will detect different speakers
686
- 5. **Stop Recording** when done
687
-
688
- **Note:** Processing happens in real-time with ~2 second chunks for better accuracy.
689
  """)
690
 
691
  # Event handlers
@@ -693,25 +732,25 @@ def create_interface():
693
  result = initialize_system()
694
  if "successfully" in result:
695
  return (
696
- result,
697
  gr.update(interactive=True), # start_btn
698
  gr.update(interactive=True), # clear_btn
699
- get_conversation(),
700
- get_status()
701
  )
702
  else:
703
  return (
704
- result,
705
  gr.update(interactive=False), # start_btn
706
  gr.update(interactive=False), # clear_btn
707
- get_conversation(),
708
- get_status()
709
  )
710
 
711
  def on_start():
712
  result = start_recording()
713
  return (
714
- result,
715
  gr.update(interactive=False), # start_btn
716
  gr.update(interactive=True), # stop_btn
717
  )
@@ -719,11 +758,15 @@ def create_interface():
719
  def on_stop():
720
  result = stop_recording()
721
  return (
722
- result,
723
  gr.update(interactive=True), # start_btn
724
  gr.update(interactive=False), # stop_btn
725
  )
726
 
 
 
 
 
727
  # Connect event handlers
728
  init_btn.click(
729
  on_initialize,
@@ -751,19 +794,19 @@ def create_interface():
751
  outputs=[status_output]
752
  )
753
 
754
- # Process streaming audio
755
  audio_input.stream(
756
- process_audio,
757
  inputs=[audio_input],
758
  outputs=[conversation_output, status_output],
759
- time_limit=60,
760
- stream_every=0.5
761
  )
762
 
763
- # Auto-refresh every 3 seconds
764
- refresh_timer = gr.Timer(3.0)
765
  refresh_timer.tick(
766
- lambda: (get_conversation(), get_status()),
767
  outputs=[conversation_output, status_output]
768
  )
769
 
 
9
  import torchaudio
10
  from scipy.spatial.distance import cosine
11
  import json
12
+ import asyncio
13
+ from typing import Iterator
14
+ import logging
15
+
16
+ # Configure logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
 
20
  # Simplified configuration parameters
21
  SILENCE_THRESHS = [0, 0.4]
 
39
  # Global variables
40
  FAST_SENTENCE_END = True
41
  SAMPLE_RATE = 16000
42
+ BUFFER_SIZE = 1024
43
  CHANNELS = 1
44
+ CHUNK_DURATION_MS = 100 # 100ms chunks for FastRTC
45
 
46
  # Speaker colors
47
  SPEAKER_COLORS = [
 
79
  model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
80
 
81
  if not os.path.exists(model_path):
82
+ logger.info(f"Downloading ECAPA-TDNN model to {model_path}...")
83
  urllib.request.urlretrieve(model_url, model_path)
84
 
85
  return model_path
 
100
  self.model_loaded = True
101
  return True
102
  except Exception as e:
103
+ logger.error(f"Error loading ECAPA-TDNN model: {e}")
104
  return False
105
 
106
  def embed_utterance(self, audio, sr=16000):
 
122
 
123
  return embedding.squeeze().cpu().numpy()
124
  except Exception as e:
125
+ logger.error(f"Error extracting embedding: {e}")
126
  return np.zeros(self.embedding_dim)
127
 
128
 
 
141
 
142
  return embedding
143
  except Exception as e:
144
+ logger.error(f"Embedding extraction error: {e}")
145
  return np.zeros(self.encoder.embedding_dim)
146
 
147
 
 
276
 
277
 
278
  class WhisperTranscriber:
279
+ """Whisper transcriber using transformers with FastRTC optimization"""
280
  def __init__(self, model_name="distil-large-v3"):
281
  self.model = None
282
  self.processor = None
283
  self.model_name = model_name
284
+ self.model_loaded = False
285
 
286
  def load_model(self):
287
  """Load Whisper model"""
288
  try:
289
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
290
 
291
+ model_id = f"distil-whisper/distil-{self.model_name}" if "distil" in self.model_name else f"openai/whisper-{self.model_name}"
292
+
293
+ self.processor = WhisperProcessor.from_pretrained(model_id)
294
+ self.model = WhisperForConditionalGeneration.from_pretrained(
295
+ model_id,
296
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
297
+ low_cpu_mem_usage=True,
298
+ use_safetensors=True
299
+ )
300
+
301
+ if torch.cuda.is_available():
302
+ self.model = self.model.cuda()
303
 
304
+ self.model_loaded = True
305
  return True
306
  except Exception as e:
307
+ logger.error(f"Error loading Whisper model: {e}")
308
  return False
309
 
310
  def transcribe(self, audio_array, sample_rate=16000):
311
  """Transcribe audio array"""
312
+ if not self.model_loaded:
313
+ return ""
314
+
315
  try:
316
+ # Ensure audio is the right length and format
317
+ if len(audio_array) < 1600: # Less than 0.1 seconds
318
  return ""
319
 
320
+ # Resample if needed
321
  if sample_rate != 16000:
322
+ import torchaudio.functional as F
323
+ audio_tensor = torch.tensor(audio_array, dtype=torch.float32)
324
+ audio_array = F.resample(audio_tensor, sample_rate, 16000).numpy()
325
+
326
+ # Process with Whisper
327
+ inputs = self.processor(
328
+ audio_array,
329
+ sampling_rate=16000,
330
+ return_tensors="pt",
331
+ truncation=False,
332
+ padding=True
333
+ )
334
 
335
+ if torch.cuda.is_available():
336
+ inputs = {k: v.cuda() for k, v in inputs.items()}
 
337
 
 
338
  with torch.no_grad():
339
+ predicted_ids = self.model.generate(
340
+ inputs["input_features"],
341
+ max_length=448,
342
+ num_beams=1,
343
+ do_sample=False,
344
+ use_cache=True
345
+ )
346
 
347
+ transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
348
 
349
+ return transcription.strip()
350
  except Exception as e:
351
+ logger.error(f"Transcription error: {e}")
352
  return ""
353
 
354
 
355
+ class FastRTCSpeakerDiarization:
356
  def __init__(self):
357
  self.encoder = None
358
  self.audio_processor = None
359
  self.speaker_detector = None
360
  self.transcriber = None
361
+ self.audio_queue = queue.Queue(maxsize=100)
362
  self.processing_thread = None
 
363
  self.full_sentences = []
364
  self.sentence_speakers = []
 
 
365
  self.is_running = False
366
  self.change_threshold = DEFAULT_CHANGE_THRESHOLD
367
  self.max_speakers = DEFAULT_MAX_SPEAKERS
368
+ self.audio_buffer = []
369
+ self.buffer_duration = 3.0 # seconds
370
+ self.last_transcription_time = time.time()
371
+ self.chunk_size = int(SAMPLE_RATE * CHUNK_DURATION_MS / 1000)
372
 
373
  def initialize_models(self):
374
  """Initialize the speaker encoder and transcription models"""
375
  try:
376
  device_str = "cuda" if torch.cuda.is_available() else "cpu"
377
+ logger.info(f"Using device: {device_str}")
378
 
379
  # Initialize speaker encoder
380
  self.encoder = SpeechBrainEncoder(device=device_str)
 
391
  change_threshold=self.change_threshold,
392
  max_speakers=self.max_speakers
393
  )
394
+ logger.info("Models loaded successfully!")
395
  return True
396
  else:
397
+ logger.error("Failed to load models")
398
  return False
399
  except Exception as e:
400
+ logger.error(f"Model initialization error: {e}")
401
  return False
402
 
403
+ def process_audio_chunk(self, audio_chunk: np.ndarray, sample_rate: int):
404
+ """Process individual audio chunk from FastRTC"""
405
+ if not self.is_running or audio_chunk is None:
406
  return
407
 
408
  try:
409
+ # Ensure audio chunk is in correct format
410
+ if isinstance(audio_chunk, np.ndarray):
411
+ # Ensure mono audio
412
+ if len(audio_chunk.shape) > 1:
413
+ audio_chunk = audio_chunk.mean(axis=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
+ # Normalize audio
416
+ if audio_chunk.dtype != np.float32:
417
+ audio_chunk = audio_chunk.astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
 
419
+ if np.abs(audio_chunk).max() > 1.0:
420
+ audio_chunk = audio_chunk / np.abs(audio_chunk).max()
421
 
422
+ # Add to buffer
423
+ self.audio_buffer.extend(audio_chunk)
424
+
425
+ # Keep buffer to specified duration
426
+ max_buffer_length = int(self.buffer_duration * sample_rate)
427
+ if len(self.audio_buffer) > max_buffer_length:
428
+ self.audio_buffer = self.audio_buffer[-max_buffer_length:]
429
+
430
+ # Process if enough audio accumulated and enough time passed
431
+ current_time = time.time()
432
+ if (current_time - self.last_transcription_time > 1.5 and
433
+ len(self.audio_buffer) > sample_rate * 0.8): # At least 0.8 seconds
434
+
435
+ if not self.audio_queue.full():
436
+ self.audio_queue.put((np.array(self.audio_buffer[-int(sample_rate * 2):]), sample_rate))
437
+ self.last_transcription_time = current_time
438
+
439
  except Exception as e:
440
+ logger.error(f"Audio chunk processing error: {e}")
441
 
442
+ def process_audio_queue(self):
443
+ """Process audio from the queue"""
444
  while self.is_running:
445
  try:
446
+ audio_data, sample_rate = self.audio_queue.get(timeout=1)
 
 
 
447
 
448
+ if len(audio_data) < 1600: # Skip very short audio
449
+ continue
 
450
 
451
+ # Transcribe audio
452
+ transcription = self.transcriber.transcribe(audio_data, sample_rate)
 
453
 
454
+ if transcription and len(transcription.strip()) > 0:
455
+ # Extract speaker embedding
456
+ speaker_embedding = self.audio_processor.extract_embedding(audio_data)
457
+
458
+ # Detect speaker
459
+ speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
460
+
461
+ # Store results
462
+ self.full_sentences.append(transcription.strip())
463
+ self.sentence_speakers.append(speaker_id)
464
+
465
+ logger.info(f"Processed: Speaker {speaker_id + 1}: {transcription.strip()[:50]}...")
466
+
467
  except queue.Empty:
468
  continue
469
  except Exception as e:
470
+ logger.error(f"Error processing audio queue: {e}")
471
 
472
  def start_recording(self):
473
+ """Start the recording and processing"""
474
+ if self.encoder is None or self.transcriber is None:
475
  return "Please initialize models first!"
476
 
477
  try:
 
478
  self.is_running = True
479
+ self.audio_buffer = []
480
+ self.last_transcription_time = time.time()
481
+
482
+ # Clear the queue
483
+ while not self.audio_queue.empty():
484
+ try:
485
+ self.audio_queue.get_nowait()
486
+ except queue.Empty:
487
+ break
488
+
489
+ # Start processing thread
490
+ self.processing_thread = threading.Thread(target=self.process_audio_queue, daemon=True)
491
  self.processing_thread.start()
492
 
493
+ logger.info("Recording started successfully!")
494
+ return "Recording started successfully!"
495
 
496
  except Exception as e:
497
+ logger.error(f"Error starting recording: {e}")
498
  return f"Error starting recording: {e}"
499
 
500
  def stop_recording(self):
501
  """Stop the recording process"""
502
  self.is_running = False
503
+ logger.info("Recording stopped!")
504
  return "Recording stopped!"
505
 
506
  def clear_conversation(self):
507
  """Clear all conversation data"""
508
  self.full_sentences = []
509
  self.sentence_speakers = []
 
 
510
  self.audio_buffer = []
511
 
512
+ # Clear the queue
513
+ while not self.audio_queue.empty():
514
+ try:
515
+ self.audio_queue.get_nowait()
516
+ except queue.Empty:
517
+ break
518
+
519
  if self.speaker_detector:
520
  self.speaker_detector = SpeakerChangeDetector(
521
  embedding_dim=self.encoder.embedding_dim,
 
539
  def get_formatted_conversation(self):
540
  """Get the formatted conversation with speaker colors"""
541
  try:
542
+ if not self.full_sentences:
543
+ return "Waiting for speech input... ๐ŸŽค"
544
+
545
  sentences_with_style = []
546
 
547
+ for i, sentence in enumerate(self.full_sentences[-10:]): # Show last 10 sentences
 
 
548
  if i >= len(self.sentence_speakers):
549
  color = "#FFFFFF"
550
+ speaker_name = "Unknown"
551
  else:
552
+ speaker_id = self.sentence_speakers[-(10-i) if len(self.sentence_speakers) >= 10 else i]
553
  color = self.speaker_detector.get_color_for_speaker(speaker_id)
554
  speaker_name = f"Speaker {speaker_id + 1}"
555
+
556
  sentences_with_style.append(
557
+ f'<p><span style="color:{color}; font-weight: bold;">{speaker_name}:</span> {sentence}</p>')
558
 
559
+ return "".join(sentences_with_style)
 
 
 
560
 
561
  except Exception as e:
562
  return f"Error formatting conversation: {e}"
 
568
 
569
  try:
570
  status = self.speaker_detector.get_status_info()
571
+ queue_size = self.audio_queue.qsize()
572
 
573
  status_lines = [
574
  f"**Current Speaker:** {status['current_speaker'] + 1}",
 
576
  f"**Last Similarity:** {status['last_similarity']:.3f}",
577
  f"**Change Threshold:** {status['threshold']:.2f}",
578
  f"**Total Sentences:** {len(self.full_sentences)}",
579
+ f"**Buffer Length:** {len(self.audio_buffer)} samples",
580
+ f"**Queue Size:** {queue_size}",
581
  "",
582
  "**Speaker Segment Counts:**"
583
  ]
 
593
 
594
 
595
  # Global instance
596
+ diarization_system = FastRTCSpeakerDiarization()
597
 
598
 
599
  def initialize_system():
 
635
  return diarization_system.get_status_info()
636
 
637
 
638
+ def process_audio_stream(audio_stream):
639
+ """Process streaming audio from FastRTC"""
640
+ if audio_stream is not None and diarization_system.is_running:
641
+ sample_rate, audio_data = audio_stream
642
+ diarization_system.process_audio_chunk(audio_data, sample_rate)
643
+
644
  return get_conversation(), get_status()
645
 
646
 
647
+ # Create Gradio interface with FastRTC
648
  def create_interface():
649
+ with gr.Blocks(title="FastRTC Real-time Speaker Diarization", theme=gr.themes.Soft()) as app:
650
+ gr.Markdown("# ๐ŸŽค FastRTC Real-time Speech Recognition with Speaker Diarization")
651
+ gr.Markdown("This app uses Hugging Face FastRTC for real-time audio streaming with automatic speaker identification and color-coding.")
652
 
653
  with gr.Row():
654
  with gr.Column(scale=2):
655
+ # FastRTC Audio input for real-time streaming
656
  audio_input = gr.Audio(
657
+ sources=["microphone"],
658
  type="numpy",
659
  streaming=True,
660
+ label="๐ŸŽ™๏ธ FastRTC Microphone Input",
661
+ format="wav",
662
+ show_download_button=False,
663
+ container=True,
664
+ elem_id="fastrtc_audio"
665
  )
666
 
667
  # Main conversation display
668
  conversation_output = gr.HTML(
669
+ value="<i>Click 'Initialize System' and then 'Start Recording' to begin...</i>",
670
+ label="Live Conversation",
671
+ elem_id="conversation_display"
672
  )
673
 
674
  # Control buttons
675
  with gr.Row():
676
+ init_btn = gr.Button("๐Ÿ”ง Initialize System", variant="secondary", size="lg")
677
+ start_btn = gr.Button("๐ŸŽ™๏ธ Start Recording", variant="primary", interactive=False, size="lg")
678
+ stop_btn = gr.Button("โน๏ธ Stop Recording", variant="stop", interactive=False, size="lg")
679
+ clear_btn = gr.Button("๐Ÿ—‘๏ธ Clear", interactive=False, size="lg")
680
 
681
  # Status display
682
  status_output = gr.Textbox(
683
  label="System Status",
684
  value="System not initialized",
685
  lines=10,
686
+ interactive=False,
687
+ show_copy_button=True
688
  )
689
 
690
  with gr.Column(scale=1):
 
697
  step=0.05,
698
  value=DEFAULT_CHANGE_THRESHOLD,
699
  label="Speaker Change Sensitivity",
700
+ info="Lower = more sensitive to changes"
701
  )
702
 
703
  max_speakers_slider = gr.Slider(
 
708
  label="Maximum Number of Speakers"
709
  )
710
 
711
+ update_settings_btn = gr.Button("Update Settings", variant="secondary")
712
 
713
  # Speaker color legend
714
  gr.Markdown("## ๐ŸŽจ Speaker Colors")
715
  color_info = []
716
  for i, (color, name) in enumerate(zip(SPEAKER_COLORS, SPEAKER_COLOR_NAMES)):
717
+ color_info.append(f'<span style="color:{color}; font-size: 16px;">โ—</span> Speaker {i+1} ({name})')
718
 
719
  gr.HTML("<br>".join(color_info[:DEFAULT_MAX_SPEAKERS]))
720
 
721
+ # Performance info
722
+ gr.Markdown("## ๐Ÿ“Š Performance")
723
  gr.Markdown("""
724
+ - **FastRTC**: Low-latency audio streaming
725
+ - **Whisper**: distil-large-v3 for transcription
726
+ - **ECAPA-TDNN**: Speaker embeddings
727
+ - **Real-time**: ~100ms processing chunks
 
 
 
 
728
  """)
729
 
730
  # Event handlers
 
732
  result = initialize_system()
733
  if "successfully" in result:
734
  return (
735
+ result, # status_output
736
  gr.update(interactive=True), # start_btn
737
  gr.update(interactive=True), # clear_btn
738
+ get_conversation(), # conversation_output
739
+ get_status() # status_output update
740
  )
741
  else:
742
  return (
743
+ result, # status_output
744
  gr.update(interactive=False), # start_btn
745
  gr.update(interactive=False), # clear_btn
746
+ get_conversation(), # conversation_output
747
+ get_status() # status_output update
748
  )
749
 
750
  def on_start():
751
  result = start_recording()
752
  return (
753
+ result, # status_output
754
  gr.update(interactive=False), # start_btn
755
  gr.update(interactive=True), # stop_btn
756
  )
 
758
  def on_stop():
759
  result = stop_recording()
760
  return (
761
+ result, # status_output
762
  gr.update(interactive=True), # start_btn
763
  gr.update(interactive=False), # stop_btn
764
  )
765
 
766
+ # Auto-refresh function
767
+ def refresh_display():
768
+ return get_conversation(), get_status()
769
+
770
  # Connect event handlers
771
  init_btn.click(
772
  on_initialize,
 
794
  outputs=[status_output]
795
  )
796
 
797
+ # FastRTC streaming audio processing
798
  audio_input.stream(
799
+ process_audio_stream,
800
  inputs=[audio_input],
801
  outputs=[conversation_output, status_output],
802
+ stream_every=0.1, # Process every 100ms
803
+ time_limit=None
804
  )
805
 
806
+ # Auto-refresh timer
807
+ refresh_timer = gr.Timer(2.0)
808
  refresh_timer.tick(
809
+ refresh_display,
810
  outputs=[conversation_output, status_output]
811
  )
812