Saiyaswanth007 committed
Commit 25dcfd9 · Parent: 9e0d933

Changed Gradio approximate real-time streaming to FastRTC

Files changed (1):
  app.py +341 -128
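
Note on the new streaming path: microphone chunks now arrive through a streaming `gr.Audio` component, get buffered into one-second windows with 50% overlap, and are pushed to RealtimeSTT via `feed_audio` (see `WebRTCAudioProcessor.process_audio` in the diff below). A minimal sketch of that windowing scheme, assuming 16 kHz int16 samples; `ChunkBuffer` is a hypothetical name, not a class in app.py:

```python
# Sketch of the 1-second / 50%-overlap windowing used by the new streaming path.
# ChunkBuffer is a hypothetical stand-in; app.py does this inside WebRTCAudioProcessor.
import numpy as np

class ChunkBuffer:
    """Accumulate streamed int16 samples; emit fixed windows with 50% overlap."""

    def __init__(self, window_samples=16000):  # 1 s at 16 kHz
        self.window = window_samples
        self.samples = np.zeros(0, dtype=np.int16)

    def push(self, chunk):
        """Append one chunk; yield each full window as little-endian PCM bytes."""
        self.samples = np.concatenate([self.samples, chunk.astype(np.int16)])
        while len(self.samples) >= self.window:
            yield self.samples[:self.window].tobytes()
            # Drop only half a window so consecutive windows overlap by 50%
            self.samples = self.samples[self.window // 2:]

buf = ChunkBuffer()
for _ in range(8):  # simulate eight 0.25 s microphone chunks
    chunk = (np.random.randn(4000) * 1000).astype(np.int16)
    for window_bytes in buf.push(chunk):
        print(f"window ready: {len(window_bytes)} bytes")  # -> recorder.feed_audio(window_bytes)
```

The 50% overlap trades some duplicate transcription work for robustness: a window boundary can never completely split a short utterance.
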
app.py CHANGED
@@ -8,6 +8,7 @@ import os
 import urllib.request
 import torchaudio
 from scipy.spatial.distance import cosine
+from RealtimeSTT import AudioToTextRecorder
 import json
 import io
 import wave
@@ -57,9 +58,6 @@ SPEAKER_COLOR_NAMES = [
 ]
 
 
-
-
-
 class SpeechBrainEncoder:
     """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
     def __init__(self, device="cpu"):
@@ -70,11 +68,24 @@ class SpeechBrainEncoder:
         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
         os.makedirs(self.cache_dir, exist_ok=True)
 
+    def _download_model(self):
+        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
+        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
+        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
+
+        if not os.path.exists(model_path):
+            print(f"Downloading ECAPA-TDNN model to {model_path}...")
+            urllib.request.urlretrieve(model_url, model_path)
+
+        return model_path
+
     def load_model(self):
         """Load the ECAPA-TDNN model"""
         try:
             from speechbrain.pretrained import EncoderClassifier
 
+            model_path = self._download_model()
+
             self.model = EncoderClassifier.from_hparams(
                 source="speechbrain/spkrec-ecapa-voxceleb",
                 savedir=self.cache_dir,
@@ -82,10 +93,9 @@
             )
 
             self.model_loaded = True
-            print("ECAPA-TDNN model loaded successfully!")
             return True
         except Exception as e:
-            print(f"SpeechBrain not available: {e}")
+            print(f"Error loading ECAPA-TDNN model: {e}")
             return False
 
     def embed_utterance(self, audio, sr=16000):
@@ -116,21 +126,16 @@ class AudioProcessor:
     def __init__(self, encoder):
         self.encoder = encoder
 
-    def extract_embedding(self, audio_data, sample_rate=16000):
+    def extract_embedding(self, audio_int16):
         try:
-            # Ensure audio is float32 and normalized
-            if audio_data.dtype == np.int16:
-                float_audio = audio_data.astype(np.float32) / 32768.0
-            else:
-                float_audio = audio_data.astype(np.float32)
+            float_audio = audio_int16.astype(np.float32) / 32768.0
 
-            # Normalize if needed
             if np.abs(float_audio).max() > 1.0:
                 float_audio = float_audio / np.abs(float_audio).max()
 
-            embedding = self.encoder.embed_utterance(float_audio, sample_rate)
-            return embedding
+            embedding = self.encoder.embed_utterance(float_audio)
 
+            return embedding
         except Exception as e:
             print(f"Embedding extraction error: {e}")
             return np.zeros(self.encoder.embedding_dim)
@@ -266,14 +271,68 @@ class SpeakerChangeDetector:
         }
 
 
-class GradioSpeakerDiarization:
+class WebRTCAudioProcessor:
+    """Processes WebRTC audio streams for speaker diarization"""
+    def __init__(self, diarization_system):
+        self.diarization_system = diarization_system
+        self.audio_buffer = []
+        self.buffer_lock = threading.Lock()
+        self.processing_thread = None
+        self.is_processing = False
+
+    def process_audio(self, audio_data, sample_rate):
+        """Process incoming audio data from WebRTC"""
+        try:
+            # Convert audio data to numpy array if needed
+            if isinstance(audio_data, bytes):
+                audio_array = np.frombuffer(audio_data, dtype=np.int16)
+            elif isinstance(audio_data, tuple):
+                # Handle tuple format (sample_rate, audio_array)
+                sample_rate, audio_array = audio_data
+                if isinstance(audio_array, np.ndarray):
+                    if audio_array.dtype != np.int16:
+                        audio_array = (audio_array * 32767).astype(np.int16)
+                else:
+                    audio_array = np.array(audio_array, dtype=np.int16)
+            else:
+                audio_array = np.array(audio_data, dtype=np.int16)
+
+            # Ensure mono audio
+            if len(audio_array.shape) > 1:
+                audio_array = audio_array[:, 0]
+
+            # Add to buffer
+            with self.buffer_lock:
+                self.audio_buffer.extend(audio_array)
+
+                # Process buffer when it's large enough (1 second of audio)
+                if len(self.audio_buffer) >= sample_rate:
+                    buffer_to_process = np.array(self.audio_buffer[:sample_rate])
+                    self.audio_buffer = self.audio_buffer[sample_rate // 2:]  # Keep 50% overlap
+
+                    # Feed buffered audio to the recorder
+                    if self.diarization_system.recorder:
+                        audio_bytes = buffer_to_process.tobytes()
+                        self.diarization_system.recorder.feed_audio(audio_bytes)
+
+        except Exception as e:
+            print(f"Error processing WebRTC audio: {e}")
+
+
+class RealtimeSpeakerDiarization:
     def __init__(self):
         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
+        self.recorder = None
+        self.webrtc_processor = None
+        self.sentence_queue = queue.Queue()
         self.full_sentences = []
         self.sentence_speakers = []
-        self.is_initialized = False
+        self.pending_sentences = []
+        self.displayed_text = ""
+        self.last_realtime_text = ""
+        self.is_running = False
         self.change_threshold = DEFAULT_CHANGE_THRESHOLD
         self.max_speakers = DEFAULT_MAX_SPEAKERS
 
@@ -283,7 +342,6 @@ class GradioSpeakerDiarization:
             device_str = "cuda" if torch.cuda.is_available() else "cpu"
             print(f"Using device: {device_str}")
 
-            # Load SpeechBrain encoder
             self.encoder = SpeechBrainEncoder(device=device_str)
             success = self.encoder.load_model()
 
@@ -294,62 +352,145 @@ class GradioSpeakerDiarization:
                     change_threshold=self.change_threshold,
                     max_speakers=self.max_speakers
                 )
-                self.is_initialized = True
+                self.webrtc_processor = WebRTCAudioProcessor(self)
+                print("ECAPA-TDNN model loaded successfully!")
                 return True
             else:
+                print("Failed to load ECAPA-TDNN model")
                 return False
-
         except Exception as e:
             print(f"Model initialization error: {e}")
             return False
 
-    def transcribe_audio(self, audio_input):
-        """Process audio input and perform transcription with speaker diarization"""
-        if not self.is_initialized:
-            return "❌ Please initialize the system first!", self.get_formatted_conversation(), self.get_status_info()
-
-        if audio_input is None:
-            return "No audio received", self.get_formatted_conversation(), self.get_status_info()
+    def live_text_detected(self, text):
+        """Callback for real-time transcription updates"""
+        text = text.strip()
+        if text:
+            sentence_delimiters = '.?!。'
+            prob_sentence_end = (
+                len(self.last_realtime_text) > 0
+                and text[-1] in sentence_delimiters
+                and self.last_realtime_text[-1] in sentence_delimiters
+            )
+
+            self.last_realtime_text = text
+
+            if prob_sentence_end and FAST_SENTENCE_END:
+                self.recorder.stop()
+            elif prob_sentence_end:
+                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
+            else:
+                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
+
+    def process_final_text(self, text):
+        """Process final transcribed text with speaker embedding"""
+        text = text.strip()
+        if text:
+            try:
+                bytes_data = self.recorder.last_transcription_bytes
+                self.sentence_queue.put((text, bytes_data))
+                self.pending_sentences.append(text)
+            except Exception as e:
+                print(f"Error processing final text: {e}")
+
+    def process_sentence_queue(self):
+        """Process sentences in the queue for speaker detection"""
+        while self.is_running:
+            try:
+                text, bytes_data = self.sentence_queue.get(timeout=1)
+
+                # Convert audio data to int16
+                audio_int16 = np.int16(bytes_data * 32767)
+
+                # Extract speaker embedding
+                speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
+
+                # Store sentence and embedding
+                self.full_sentences.append((text, speaker_embedding))
+
+                # Fill in missing speaker assignments
+                while len(self.sentence_speakers) < len(self.full_sentences) - 1:
+                    self.sentence_speakers.append(0)
+
+                # Detect speaker changes
+                speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
+                self.sentence_speakers.append(speaker_id)
+
+                # Remove from pending
+                if text in self.pending_sentences:
+                    self.pending_sentences.remove(text)
+
+            except queue.Empty:
+                continue
+            except Exception as e:
+                print(f"Error processing sentence: {e}")
+
+    def start_recording(self):
+        """Start the recording and transcription process"""
+        if self.encoder is None:
+            return "Please initialize models first!"
 
         try:
-            # Handle different audio input formats
-            if isinstance(audio_input, tuple):
-                sample_rate, audio_data = audio_input
-            else:
-                # Assume it's a file path
-                import librosa
-                audio_data, sample_rate = librosa.load(audio_input, sr=16000)
-
-            # Ensure audio is in the right format
-            if len(audio_data.shape) > 1:
-                audio_data = audio_data.mean(axis=1)  # Convert to mono
-
-            # Perform simple transcription (placeholder - you'd want to integrate with Whisper or similar)
-            # For now, we'll just do speaker diarization
-            transcription = f"Audio segment {len(self.full_sentences) + 1} (duration: {len(audio_data)/sample_rate:.1f}s)"
-
-            # Extract speaker embedding
-            speaker_embedding = self.audio_processor.extract_embedding(audio_data, sample_rate)
-
-            # Store sentence and embedding
-            self.full_sentences.append((transcription, speaker_embedding))
-
-            # Detect speaker changes
-            speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
-            self.sentence_speakers.append(speaker_id)
-
-            status_msg = f"✅ Processed audio segment. Detected as Speaker {speaker_id + 1} (similarity: {similarity:.3f})"
-
-            return status_msg, self.get_formatted_conversation(), self.get_status_info()
-
+            # Setup recorder configuration for WebRTC input
+            recorder_config = {
+                'spinner': False,
+                'use_microphone': False,  # We'll feed audio manually
+                'model': FINAL_TRANSCRIPTION_MODEL,
+                'language': TRANSCRIPTION_LANGUAGE,
+                'silero_sensitivity': SILERO_SENSITIVITY,
+                'webrtc_sensitivity': WEBRTC_SENSITIVITY,
+                'post_speech_silence_duration': SILENCE_THRESHS[1],
+                'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
+                'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
+                'min_gap_between_recordings': 0,
+                'enable_realtime_transcription': True,
+                'realtime_processing_pause': 0,
+                'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
+                'on_realtime_transcription_update': self.live_text_detected,
+                'beam_size': FINAL_BEAM_SIZE,
+                'beam_size_realtime': REALTIME_BEAM_SIZE,
+                'buffer_size': BUFFER_SIZE,
+                'sample_rate': SAMPLE_RATE,
+            }
+
+            self.recorder = AudioToTextRecorder(**recorder_config)
+
+            # Start sentence processing thread
+            self.is_running = True
+            self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
+            self.sentence_thread.start()
+
+            # Start transcription thread
+            self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
+            self.transcription_thread.start()
+
+            return "Recording started successfully! WebRTC audio input ready."
+
         except Exception as e:
-            error_msg = f"❌ Error processing audio: {str(e)}"
-            return error_msg, self.get_formatted_conversation(), self.get_status_info()
+            return f"Error starting recording: {e}"
+
+    def run_transcription(self):
+        """Run the transcription loop"""
+        try:
+            while self.is_running:
+                self.recorder.text(self.process_final_text)
+        except Exception as e:
+            print(f"Transcription error: {e}")
+
+    def stop_recording(self):
+        """Stop the recording process"""
+        self.is_running = False
+        if self.recorder:
+            self.recorder.stop()
+        return "Recording stopped!"
 
     def clear_conversation(self):
         """Clear all conversation data"""
         self.full_sentences = []
         self.sentence_speakers = []
+        self.pending_sentences = []
+        self.displayed_text = ""
+        self.last_realtime_text = ""
 
         if self.speaker_detector:
             self.speaker_detector = SpeakerChangeDetector(
@@ -358,7 +499,7 @@ class GradioSpeakerDiarization:
                 max_speakers=self.max_speakers
             )
 
-        return "Conversation cleared!", self.get_formatted_conversation(), self.get_status_info()
+        return "Conversation cleared!"
 
     def update_settings(self, threshold, max_speakers):
         """Update speaker detection settings"""
@@ -369,22 +510,18 @@ class GradioSpeakerDiarization:
         self.speaker_detector.set_change_threshold(threshold)
         self.speaker_detector.set_max_speakers(max_speakers)
 
-        status_msg = f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
-        return status_msg, self.get_formatted_conversation(), self.get_status_info()
+        return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
 
     def get_formatted_conversation(self):
         """Get the formatted conversation with speaker colors"""
         try:
-            if not self.full_sentences:
-                return "No audio processed yet. Upload an audio file or record using the microphone."
-
             sentences_with_style = []
 
+            # Process completed sentences
             for i, sentence in enumerate(self.full_sentences):
                 sentence_text, _ = sentence
                 if i >= len(self.sentence_speakers):
                     color = "#FFFFFF"
-                    speaker_name = "Unknown"
                 else:
                     speaker_id = self.sentence_speakers[i]
                     color = self.speaker_detector.get_color_for_speaker(speaker_id)
@@ -393,7 +530,15 @@ def create_interface():
                 sentences_with_style.append(
                     f'<span style="color:{color};"><b>{speaker_name}:</b> {sentence_text}</span>')
 
-            return "<br><br>".join(sentences_with_style)
+            # Add pending sentences
+            for pending_sentence in self.pending_sentences:
+                sentences_with_style.append(
+                    f'<span style="color:#60FFFF;"><b>Processing:</b> {pending_sentence}</span>')
+
+            if sentences_with_style:
+                return "<br><br>".join(sentences_with_style)
+            else:
+                return "Waiting for speech input..."
 
         except Exception as e:
             return f"Error formatting conversation: {e}"
@@ -411,7 +556,7 @@ class GradioSpeakerDiarization:
             f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}",
             f"**Last Similarity:** {status['last_similarity']:.3f}",
             f"**Change Threshold:** {status['threshold']:.2f}",
-            f"**Total Segments:** {len(self.full_sentences)}",
+            f"**Total Sentences:** {len(self.full_sentences)}",
             "",
             "**Speaker Segment Counts:**"
         ]
@@ -427,21 +572,26 @@ class GradioSpeakerDiarization:
 
 
 # Global instance
-diarization_system = GradioSpeakerDiarization()
+diarization_system = RealtimeSpeakerDiarization()
 
 
 def initialize_system():
     """Initialize the diarization system"""
     success = diarization_system.initialize_models()
     if success:
-        return "✅ System initialized successfully! Models loaded.", "", ""
+        return "✅ System initialized successfully! Models loaded."
     else:
-        return "❌ Failed to initialize system. Please check the logs.", "", ""
+        return "❌ Failed to initialize system. Please check the logs."
+
+
+def start_recording():
+    """Start recording and transcription"""
+    return diarization_system.start_recording()
 
 
-def process_audio(audio):
-    """Process uploaded or recorded audio"""
-    return diarization_system.transcribe_audio(audio)
+def stop_recording():
+    """Stop recording and transcription"""
+    return diarization_system.stop_recording()
 
 
 def clear_conversation():
@@ -454,52 +604,59 @@ def update_settings(threshold, max_speakers):
     return diarization_system.update_settings(threshold, max_speakers)
 
 
+def get_conversation():
+    """Get the current conversation"""
+    return diarization_system.get_formatted_conversation()
+
+
+def get_status():
+    """Get system status"""
+    return diarization_system.get_status_info()
+
+
+def process_audio_stream(audio):
+    """Process audio stream from WebRTC"""
+    if diarization_system.webrtc_processor and diarization_system.is_running:
+        diarization_system.webrtc_processor.process_audio(audio, SAMPLE_RATE)
+    return None
+
+
 # Create Gradio interface
 def create_interface():
-    with gr.Blocks(title="Speaker Diarization", theme=gr.themes.Soft()) as app:
-        gr.Markdown("# 🎤 Audio Speaker Diarization")
-        gr.Markdown("Upload audio files or record directly to identify different speakers using voice characteristics.")
+    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Dark()) as app:
+        gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
+        gr.Markdown("This app performs real-time speech recognition with automatic speaker identification and color-coding using WebRTC.")
 
         with gr.Row():
             with gr.Column(scale=2):
-                # Initialize button
-                with gr.Row():
-                    init_btn = gr.Button("🔧 Initialize System", variant="primary", size="lg")
-
-                # Audio input options
-                gr.Markdown("### 📁 Audio Input")
-                with gr.Tab("Upload Audio File"):
-                    audio_file = gr.Audio(
-                        label="Upload Audio File",
-                        type="filepath",
-                        sources=["upload"]
-                    )
-                    process_file_btn = gr.Button("Process Audio File", variant="secondary")
-
-                with gr.Tab("Record Audio"):
-                    audio_mic = gr.Audio(
-                        label="Record Audio",
-                        type="numpy",
-                        sources=["microphone"]
-                    )
-                    process_mic_btn = gr.Button("Process Recording", variant="secondary")
-
-                # Results display
-                status_output = gr.Textbox(
-                    label="Status",
-                    value="Click 'Initialize System' to start...",
-                    lines=2,
-                    interactive=False
+                # WebRTC Audio Input
+                audio_input = gr.Audio(
+                    sources=["microphone"],
+                    streaming=True,
+                    label="🎙️ Microphone Input",
+                    type="numpy"
                 )
 
+                # Main conversation display
                 conversation_output = gr.HTML(
-                    value="<i>System not initialized...</i>",
-                    label="Speaker Analysis Results"
+                    value="<i>Click 'Initialize System' to start...</i>",
+                    label="Live Conversation"
                 )
 
                 # Control buttons
                 with gr.Row():
-                    clear_btn = gr.Button("🗑️ Clear Results", variant="stop")
+                    init_btn = gr.Button("🔧 Initialize System", variant="secondary")
+                    start_btn = gr.Button("🎙️ Start Recording", variant="primary", interactive=False)
+                    stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", interactive=False)
+                    clear_btn = gr.Button("🗑️ Clear Conversation", interactive=False)
+
+                # Status display
+                status_output = gr.Textbox(
+                    label="System Status",
+                    value="System not initialized",
+                    lines=8,
+                    interactive=False
+                )
 
             with gr.Column(scale=1):
                 # Settings panel
@@ -511,7 +668,7 @@ def create_interface():
                     step=0.05,
                     value=DEFAULT_CHANGE_THRESHOLD,
                     label="Speaker Change Sensitivity",
-                    info="Lower = more sensitive to speaker changes"
+                    info="Lower values = more sensitive to speaker changes"
                 )
 
                 max_speakers_slider = gr.Slider(
@@ -522,51 +679,106 @@ def create_interface():
                     label="Maximum Number of Speakers"
                 )
 
-                update_settings_btn = gr.Button("Update Settings", variant="secondary")
+                update_settings_btn = gr.Button("Update Settings")
 
-                # System status
-                system_status = gr.Textbox(
-                    label="System Status",
-                    value="System not initialized",
-                    lines=12,
-                    interactive=False
-                )
+                # Instructions
+                gr.Markdown("## 📝 Instructions")
+                gr.Markdown("""
+                1. Click **Initialize System** to load models
+                2. Click **Start Recording** to begin processing
+                3. Allow microphone access when prompted
+                4. Speak into your microphone
+                5. Watch real-time transcription with speaker labels
+                6. Adjust settings as needed
+                """)
 
                 # Speaker color legend
                 gr.Markdown("## 🎨 Speaker Colors")
                 color_info = []
-                for i, (color, name) in enumerate(zip(SPEAKER_COLORS[:DEFAULT_MAX_SPEAKERS], SPEAKER_COLOR_NAMES[:DEFAULT_MAX_SPEAKERS])):
-                    color_info.append(f'<span style="color:{color};">●</span> Speaker {i+1} ({name})')
+                for i, (color, name) in enumerate(zip(SPEAKER_COLORS, SPEAKER_COLOR_NAMES)):
+                    color_info.append(f'<span style="color:{color};">■</span> Speaker {i+1} ({name})')
 
-                gr.HTML("<br>".join(color_info))
+                gr.HTML("<br>".join(color_info[:DEFAULT_MAX_SPEAKERS]))
+
+        # Auto-refresh conversation and status
+        def refresh_display():
+            return get_conversation(), get_status()
 
         # Event handlers
+        def on_initialize():
+            result = initialize_system()
+            if "successfully" in result:
+                return (
+                    result,
+                    gr.update(interactive=True),   # start_btn
+                    gr.update(interactive=True),   # clear_btn
+                    get_conversation(),
+                    get_status()
+                )
+            else:
+                return (
+                    result,
+                    gr.update(interactive=False),  # start_btn
+                    gr.update(interactive=False),  # clear_btn
+                    get_conversation(),
+                    get_status()
+                )
+
+        def on_start():
+            result = start_recording()
+            return (
+                result,
+                gr.update(interactive=False),  # start_btn
+                gr.update(interactive=True),   # stop_btn
+            )
+
+        def on_stop():
+            result = stop_recording()
+            return (
+                result,
+                gr.update(interactive=True),   # start_btn
+                gr.update(interactive=False),  # stop_btn
+            )
+
+        # Connect event handlers
         init_btn.click(
-            initialize_system,
-            outputs=[status_output, conversation_output, system_status]
+            on_initialize,
+            outputs=[status_output, start_btn, clear_btn, conversation_output, status_output]
         )
 
-        process_file_btn.click(
-            process_audio,
-            inputs=[audio_file],
-            outputs=[status_output, conversation_output, system_status]
+        start_btn.click(
+            on_start,
+            outputs=[status_output, start_btn, stop_btn]
         )
 
-        process_mic_btn.click(
-            process_audio,
-            inputs=[audio_mic],
-            outputs=[status_output, conversation_output, system_status]
+        stop_btn.click(
+            on_stop,
+            outputs=[status_output, start_btn, stop_btn]
         )
 
         clear_btn.click(
             clear_conversation,
-            outputs=[status_output, conversation_output, system_status]
+            outputs=[status_output]
         )
 
         update_settings_btn.click(
             update_settings,
             inputs=[threshold_slider, max_speakers_slider],
-            outputs=[status_output, conversation_output, system_status]
+            outputs=[status_output]
+        )
+
+        # Connect WebRTC audio stream to processing
+        audio_input.stream(
+            process_audio_stream,
+            inputs=[audio_input],
+            outputs=[]
+        )
+
+        # Auto-refresh every 2 seconds when recording
+        refresh_timer = gr.Timer(2.0)
+        refresh_timer.tick(
+            refresh_display,
+            outputs=[conversation_output, status_output]
         )
 
     return app
@@ -576,5 +788,6 @@ if __name__ == "__main__":
     app = create_interface()
     app.launch(
         server_name="0.0.0.0",
-        server_port=7860
+        server_port=7860,
+        share=True
     )
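
The speaker attribution above leans on the pre-existing `SpeakerChangeDetector`, which the diff only shows in use: `add_embedding(embedding)` returns a `(speaker_id, similarity)` pair and honors `set_change_threshold` / `set_max_speakers`. A rough sketch of how such a detector can be built from the `scipy.spatial.distance.cosine` import at the top of app.py; `SimpleSpeakerTracker` and its 0.65 default are hypothetical stand-ins, not the implementation in the file:

```python
# Illustrative speaker tracker built on cosine similarity; hypothetical class,
# not the SpeakerChangeDetector defined earlier in app.py.
import numpy as np
from scipy.spatial.distance import cosine

class SimpleSpeakerTracker:
    def __init__(self, change_threshold=0.65, max_speakers=4):  # placeholder defaults
        self.change_threshold = change_threshold
        self.max_speakers = max_speakers
        self.centroids = []  # one running mean embedding per known speaker

    def add_embedding(self, emb):
        """Return (speaker_id, similarity) for one utterance embedding."""
        if not self.centroids:
            self.centroids.append(np.asarray(emb, dtype=np.float64))
            return 0, 1.0
        # scipy's cosine() is a distance, so similarity = 1 - distance
        sims = [1.0 - cosine(emb, c) for c in self.centroids]
        best = int(np.argmax(sims))
        if sims[best] >= self.change_threshold:
            # Same speaker as an existing centroid: refine it slightly
            self.centroids[best] = 0.9 * self.centroids[best] + 0.1 * np.asarray(emb)
            return best, sims[best]
        if len(self.centroids) < self.max_speakers:
            # Dissimilar to every known speaker and capacity remains: open a new one
            self.centroids.append(np.asarray(emb, dtype=np.float64))
            return len(self.centroids) - 1, sims[best]
        return best, sims[best]  # at capacity: fall back to the closest speaker
```

Since scipy's `cosine` is a distance, `1.0 - cosine(a, b)` recovers similarity; new speaker slots open only while fewer than `max_speakers` centroids exist, matching the slider-capped behavior in the UI.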