Saiyaswanth007 committed
Commit 9a6051e · Parent(s): 3466e71

Check point 2

Files changed (1): app.py (+458 −920)

app.py CHANGED
@@ -10,15 +10,17 @@ import torchaudio
 from scipy.spatial.distance import cosine
 from RealtimeSTT import AudioToTextRecorder
 from fastapi import FastAPI, APIRouter
-from fastrtc import Stream, AsyncStreamHandler, ReplyOnPause, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials
+from fastrtc import Stream, AsyncStreamHandler
 import json
-import io
-import wave
 import asyncio
 import uvicorn
-import socket
 from queue import Queue
-import time
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # Simplified configuration parameters
 SILENCE_THRESHS = [0, 0.4]
 FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
@@ -32,35 +34,31 @@ MIN_LENGTH_OF_RECORDING = 0.7
 PRE_RECORDING_BUFFER_DURATION = 0.35
 
 # Speaker change detection parameters
-DEFAULT_CHANGE_THRESHOLD = 0.7
+DEFAULT_CHANGE_THRESHOLD = 0.65
 EMBEDDING_HISTORY_SIZE = 5
-MIN_SEGMENT_DURATION = 1.0
+MIN_SEGMENT_DURATION = 1.5
 DEFAULT_MAX_SPEAKERS = 4
-ABSOLUTE_MAX_SPEAKERS = 10
+ABSOLUTE_MAX_SPEAKERS = 8
 
 # Global variables
-FAST_SENTENCE_END = True
 SAMPLE_RATE = 16000
-BUFFER_SIZE = 512
+BUFFER_SIZE = 1024
 CHANNELS = 1
 
-# Speaker colors
+# Speaker colors - more distinguishable colors
 SPEAKER_COLORS = [
-    "#FFFF00",  # Yellow
-    "#FF0000",  # Red
-    "#00FF00",  # Green
-    "#00FFFF",  # Cyan
-    "#FF00FF",  # Magenta
-    "#0000FF",  # Blue
-    "#FF8000",  # Orange
-    "#00FF80",  # Spring Green
-    "#8000FF",  # Purple
-    "#FFFFFF",  # White
+    "#FF6B6B",  # Red
+    "#4ECDC4",  # Teal
+    "#45B7D1",  # Blue
+    "#96CEB4",  # Green
+    "#FFEAA7",  # Yellow
+    "#DDA0DD",  # Plum
+    "#98D8C8",  # Mint
+    "#F7DC6F",  # Gold
 ]
 
 SPEAKER_COLOR_NAMES = [
-    "Yellow", "Red", "Green", "Cyan", "Magenta",
-    "Blue", "Orange", "Spring Green", "Purple", "White"
+    "Red", "Teal", "Blue", "Green", "Yellow", "Plum", "Mint", "Gold"
 ]
 
 
@@ -74,24 +72,11 @@ class SpeechBrainEncoder:
         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
         os.makedirs(self.cache_dir, exist_ok=True)
 
-    def _download_model(self):
-        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
-        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
-        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
-
-        if not os.path.exists(model_path):
-            print(f"Downloading ECAPA-TDNN model to {model_path}...")
-            urllib.request.urlretrieve(model_url, model_path)
-
-        return model_path
-
     def load_model(self):
         """Load the ECAPA-TDNN model"""
         try:
             from speechbrain.pretrained import EncoderClassifier
 
-            model_path = self._download_model()
-
             self.model = EncoderClassifier.from_hparams(
                 source="speechbrain/spkrec-ecapa-voxceleb",
                 savedir=self.cache_dir,
@@ -99,9 +84,10 @@ class SpeechBrainEncoder:
             )
 
             self.model_loaded = True
+            logger.info("ECAPA-TDNN model loaded successfully!")
             return True
         except Exception as e:
-            print(f"Error loading ECAPA-TDNN model: {e}")
+            logger.error(f"Error loading ECAPA-TDNN model: {e}")
             return False
 
     def embed_utterance(self, audio, sr=16000):
@@ -111,10 +97,15 @@ class SpeechBrainEncoder:
 
         try:
             if isinstance(audio, np.ndarray):
-                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
+                # Ensure audio is float32 and properly normalized
+                audio = audio.astype(np.float32)
+                if np.max(np.abs(audio)) > 1.0:
+                    audio = audio / np.max(np.abs(audio))
+                waveform = torch.tensor(audio).unsqueeze(0)
             else:
                 waveform = audio.unsqueeze(0)
 
+            # Resample if necessary
             if sr != 16000:
                 waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
 
@@ -123,7 +114,7 @@ class SpeechBrainEncoder:
 
             return embedding.squeeze().cpu().numpy()
         except Exception as e:
-            print(f"Error extracting embedding: {e}")
+            logger.error(f"Error extracting embedding: {e}")
             return np.zeros(self.embedding_dim)
 
 
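For reference, a minimal sketch of how this encoder class is exercised once loaded (hypothetical usage, not part of the commit; assumes the SpeechBrain checkpoint downloads successfully on first use):

    import numpy as np

    encoder = SpeechBrainEncoder(device="cpu")
    if encoder.load_model():
        # One second of synthetic float32 audio at 16 kHz, already in [-1, 1]
        clip = np.random.uniform(-0.5, 0.5, 16000).astype(np.float32)
        embedding = encoder.embed_utterance(clip, sr=16000)
        print(embedding.shape)  # (192,) for ECAPA-TDNN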
@@ -131,41 +122,60 @@ class AudioProcessor:
     """Processes audio data to extract speaker embeddings"""
     def __init__(self, encoder):
         self.encoder = encoder
+        self.audio_buffer = []
+        self.min_audio_length = int(SAMPLE_RATE * 1.0)  # Minimum 1 second of audio
 
-    def extract_embedding(self, audio_int16):
-        try:
-            float_audio = audio_int16.astype(np.float32) / 32768.0
-
-            if np.abs(float_audio).max() > 1.0:
-                float_audio = float_audio / np.abs(float_audio).max()
-
-            embedding = self.encoder.embed_utterance(float_audio)
-
+    def add_audio_chunk(self, audio_chunk):
+        """Add audio chunk to buffer"""
+        self.audio_buffer.extend(audio_chunk)
+
+        # Keep buffer from getting too large
+        max_buffer_size = int(SAMPLE_RATE * 10)  # 10 seconds max
+        if len(self.audio_buffer) > max_buffer_size:
+            self.audio_buffer = self.audio_buffer[-max_buffer_size:]
+
+    def extract_embedding_from_buffer(self):
+        """Extract embedding from current audio buffer"""
+        if len(self.audio_buffer) < self.min_audio_length:
+            return None
+
+        try:
+            # Use the last portion of the buffer for embedding
+            audio_segment = np.array(self.audio_buffer[-self.min_audio_length:], dtype=np.float32)
+
+            # Normalize audio
+            if np.max(np.abs(audio_segment)) > 0:
+                audio_segment = audio_segment / np.max(np.abs(audio_segment))
+            else:
+                return None
+
+            embedding = self.encoder.embed_utterance(audio_segment)
             return embedding
         except Exception as e:
-            print(f"Embedding extraction error: {e}")
-            return np.zeros(self.encoder.embedding_dim)
+            logger.error(f"Embedding extraction error: {e}")
+            return None
 
 
 class SpeakerChangeDetector:
-    """Speaker change detector that supports a configurable number of speakers"""
+    """Improved speaker change detector"""
     def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
         self.embedding_dim = embedding_dim
         self.change_threshold = change_threshold
         self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
         self.current_speaker = 0
-        self.previous_embeddings = []
-        self.last_change_time = time.time()
-        self.mean_embeddings = [None] * self.max_speakers
         self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
-        self.last_similarity = 0.0
+        self.speaker_centroids = [None] * self.max_speakers
+        self.last_change_time = time.time()
+        self.last_similarity = 1.0
         self.active_speakers = set([0])
+        self.segment_counter = 0
 
     def set_max_speakers(self, max_speakers):
         """Update the maximum number of speakers"""
         new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
 
         if new_max < self.max_speakers:
+            # Remove speakers beyond the new limit
             for speaker_id in list(self.active_speakers):
                 if speaker_id >= new_max:
                     self.active_speakers.discard(speaker_id)
@@ -173,85 +183,85 @@ class SpeakerChangeDetector:
         if self.current_speaker >= new_max:
             self.current_speaker = 0
 
+        # Resize arrays
         if new_max > self.max_speakers:
-            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
             self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
+            self.speaker_centroids.extend([None] * (new_max - self.max_speakers))
         else:
-            self.mean_embeddings = self.mean_embeddings[:new_max]
             self.speaker_embeddings = self.speaker_embeddings[:new_max]
+            self.speaker_centroids = self.speaker_centroids[:new_max]
 
         self.max_speakers = new_max
 
     def set_change_threshold(self, threshold):
         """Update the threshold for detecting speaker changes"""
-        self.change_threshold = max(0.1, min(threshold, 0.99))
+        self.change_threshold = max(0.1, min(threshold, 0.95))
 
     def add_embedding(self, embedding, timestamp=None):
-        """Add a new embedding and check if there's a speaker change"""
+        """Add a new embedding and detect speaker changes"""
         current_time = timestamp or time.time()
-
-        if not self.previous_embeddings:
-            self.previous_embeddings.append(embedding)
-            self.speaker_embeddings[self.current_speaker].append(embedding)
-            if self.mean_embeddings[self.current_speaker] is None:
-                self.mean_embeddings[self.current_speaker] = embedding.copy()
-            return self.current_speaker, 1.0
-
-        current_mean = self.mean_embeddings[self.current_speaker]
-        if current_mean is not None:
-            similarity = 1.0 - cosine(embedding, current_mean)
+        self.segment_counter += 1
+
+        # Initialize first speaker
+        if not self.speaker_embeddings[0]:
+            self.speaker_embeddings[0].append(embedding)
+            self.speaker_centroids[0] = embedding.copy()
+            self.active_speakers.add(0)
+            return 0, 1.0
+
+        # Calculate similarity with current speaker
+        current_centroid = self.speaker_centroids[self.current_speaker]
+        if current_centroid is not None:
+            similarity = 1.0 - cosine(embedding, current_centroid)
         else:
-            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
+            similarity = 0.5
 
         self.last_similarity = similarity
 
+        # Check for speaker change
         time_since_last_change = current_time - self.last_change_time
-        is_speaker_change = False
+        speaker_changed = False
 
-        if time_since_last_change >= MIN_SEGMENT_DURATION:
-            if similarity < self.change_threshold:
-                best_speaker = self.current_speaker
-                best_similarity = similarity
-
-                for speaker_id in range(self.max_speakers):
-                    if speaker_id == self.current_speaker:
-                        continue
-
-                    speaker_mean = self.mean_embeddings[speaker_id]
+        if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold:
+            # Find best matching speaker
+            best_speaker = self.current_speaker
+            best_similarity = similarity
+
+            for speaker_id in self.active_speakers:
+                if speaker_id == self.current_speaker:
+                    continue
 
-                    if speaker_mean is not None:
-                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
-
-                        if speaker_similarity > best_similarity:
-                            best_similarity = speaker_similarity
-                            best_speaker = speaker_id
-
-                if best_speaker != self.current_speaker:
-                    is_speaker_change = True
-                    self.current_speaker = best_speaker
-                elif len(self.active_speakers) < self.max_speakers:
-                    for new_id in range(self.max_speakers):
-                        if new_id not in self.active_speakers:
-                            is_speaker_change = True
-                            self.current_speaker = new_id
-                            self.active_speakers.add(new_id)
-                            break
-
-        if is_speaker_change:
-            self.last_change_time = current_time
-
-        self.previous_embeddings.append(embedding)
-        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
-            self.previous_embeddings.pop(0)
-
+                centroid = self.speaker_centroids[speaker_id]
+                if centroid is not None:
+                    speaker_similarity = 1.0 - cosine(embedding, centroid)
+                    if speaker_similarity > best_similarity and speaker_similarity > self.change_threshold:
+                        best_similarity = speaker_similarity
+                        best_speaker = speaker_id
+
+            # If no good match found and we can add a new speaker
+            if best_speaker == self.current_speaker and len(self.active_speakers) < self.max_speakers:
+                for new_id in range(self.max_speakers):
+                    if new_id not in self.active_speakers:
+                        best_speaker = new_id
+                        self.active_speakers.add(new_id)
+                        break
+
+            if best_speaker != self.current_speaker:
+                self.current_speaker = best_speaker
+                self.last_change_time = current_time
+                speaker_changed = True
+
+        # Update speaker embeddings and centroids
         self.speaker_embeddings[self.current_speaker].append(embedding)
-        self.active_speakers.add(self.current_speaker)
 
-        if len(self.speaker_embeddings[self.current_speaker]) > 30:
-            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
-
+        # Keep only recent embeddings (sliding window)
+        max_embeddings = 20
+        if len(self.speaker_embeddings[self.current_speaker]) > max_embeddings:
+            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-max_embeddings:]
+
+        # Update centroid
         if self.speaker_embeddings[self.current_speaker]:
-            self.mean_embeddings[self.current_speaker] = np.mean(
+            self.speaker_centroids[self.current_speaker] = np.mean(
                 self.speaker_embeddings[self.current_speaker], axis=0
             )
 
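A minimal sketch of how the detector is driven, one embedding per finalized utterance (hypothetical usage; random vectors stand in for real ECAPA-TDNN embeddings):

    import numpy as np

    detector = SpeakerChangeDetector(embedding_dim=192, change_threshold=0.65, max_speakers=4)
    for _ in range(5):
        embedding = np.random.rand(192).astype(np.float32)
        speaker_id, similarity = detector.add_embedding(embedding)
        print(f"Speaker {speaker_id + 1}, similarity to centroid: {similarity:.2f}")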
@@ -264,7 +274,7 @@ class SpeakerChangeDetector:
         return "#FFFFFF"
 
     def get_status_info(self):
-        """Return status information about the speaker change detector"""
+        """Return status information"""
         speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
 
         return {
@@ -273,7 +283,8 @@
             "active_speakers": len(self.active_speakers),
             "max_speakers": self.max_speakers,
             "last_similarity": self.last_similarity,
-            "threshold": self.change_threshold
+            "threshold": self.change_threshold,
+            "segment_counter": self.segment_counter
         }
 
 
@@ -283,173 +294,121 @@ class RealtimeSpeakerDiarization:
         self.audio_processor = None
         self.speaker_detector = None
         self.recorder = None
-        self.sentence_queue = queue.Queue(maxsize=100)  # Add maxsize to prevent unlimited growth
+        self.sentence_queue = queue.Queue()
         self.full_sentences = []
         self.sentence_speakers = []
         self.pending_sentences = []
-        self.displayed_text = ""
-        self.last_realtime_text = ""
+        self.current_conversation = ""
         self.is_running = False
         self.change_threshold = DEFAULT_CHANGE_THRESHOLD
         self.max_speakers = DEFAULT_MAX_SPEAKERS
-        self.current_conversation = ""
-        self.audio_buffer = []
-        # Add locks for thread safety
-        self._state_lock = threading.RLock()  # Reentrant lock for shared state
-        self._audio_lock = threading.Lock()  # Lock for audio processing
+        self.last_transcription = ""
+        self.transcription_lock = threading.Lock()
 
     def initialize_models(self):
         """Initialize the speaker encoder model"""
         try:
             device_str = "cuda" if torch.cuda.is_available() else "cpu"
-            print(f"Using device: {device_str}")
+            logger.info(f"Using device: {device_str}")
 
             self.encoder = SpeechBrainEncoder(device=device_str)
+            success = self.encoder.load_model()
 
-            # Try to load model with timeout
-            import threading
-            load_success = [False]
-
-            def load_model_thread():
-                try:
-                    success = self.encoder.load_model()
-                    load_success[0] = success
-                except Exception as e:
-                    print(f"Error in model loading thread: {e}")
-
-            # Start loading in a thread with timeout
-            load_thread = threading.Thread(target=load_model_thread)
-            load_thread.daemon = True
-            load_thread.start()
-            load_thread.join(timeout=60)  # 60 second timeout for model loading
-
-            if load_success[0]:
+            if success:
                 self.audio_processor = AudioProcessor(self.encoder)
                 self.speaker_detector = SpeakerChangeDetector(
                     embedding_dim=self.encoder.embedding_dim,
                     change_threshold=self.change_threshold,
                     max_speakers=self.max_speakers
                 )
-                print("ECAPA-TDNN model loaded successfully!")
+                logger.info("Models initialized successfully!")
                 return True
             else:
-                print("Failed to load ECAPA-TDNN model or timeout occurred")
-                return self._initialize_fallback()
-
-        except Exception as e:
-            print(f"Model initialization error: {e}")
-            import traceback
-            traceback.print_exc()
-            return self._initialize_fallback()
-
-    def _initialize_fallback(self):
-        """Initialize fallback mode when model loading fails"""
-        try:
-            print("Initializing fallback mode with simple speaker detection...")
-            # Create a simple embedding dimension
-            embedding_dim = 64
-
-            # Create a dummy encoder that produces random embeddings
-            class DummyEncoder:
-                def __init__(self):
-                    self.embedding_dim = embedding_dim
-                    self.model_loaded = True
-
-                def embed_utterance(self, audio, sr=16000):
-                    # Simple energy-based pseudo-embedding
-                    if isinstance(audio, np.ndarray):
-                        # Create a simple feature vector (not a real embedding)
-                        energy = np.mean(np.abs(audio))
-                        # Create a pseudo-random but consistent embedding based on audio energy
-                        np.random.seed(int(energy * 1000))
-                        return np.random.rand(embedding_dim)
-                    return np.random.rand(embedding_dim)
-
-            # Set up system with fallback components
-            self.encoder = DummyEncoder()
-            self.audio_processor = AudioProcessor(self.encoder)
-            self.speaker_detector = SpeakerChangeDetector(
-                embedding_dim=embedding_dim,
-                change_threshold=self.change_threshold,
-                max_speakers=2  # Limit speakers in fallback mode
-            )
-
-            print("Fallback mode initialized - limited functionality!")
-            return True
-
+                logger.error("Failed to load models")
+                return False
         except Exception as e:
-            print(f"Even fallback initialization failed: {e}")
+            logger.error(f"Model initialization error: {e}")
             return False
 
     def live_text_detected(self, text):
         """Callback for real-time transcription updates"""
-        text = text.strip()
-        if text:
-            sentence_delimiters = '.?!。'
-            prob_sentence_end = (
-                len(self.last_realtime_text) > 0
-                and text[-1] in sentence_delimiters
-                and self.last_realtime_text[-1] in sentence_delimiters
-            )
-
-            self.last_realtime_text = text
-
-            if prob_sentence_end and FAST_SENTENCE_END:
-                self.recorder.stop()
-            elif prob_sentence_end:
-                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
-            else:
-                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
+        with self.transcription_lock:
+            self.last_transcription = text.strip()
 
     def process_final_text(self, text):
         """Process final transcribed text with speaker embedding"""
         text = text.strip()
         if text:
             try:
-                bytes_data = self.recorder.last_transcription_bytes
-                self.sentence_queue.put((text, bytes_data), timeout=1.0)  # Added timeout
-                with self._state_lock:
-                    self.pending_sentences.append(text)
+                # Get audio data for this transcription
+                audio_bytes = getattr(self.recorder, 'last_transcription_bytes', None)
+                if audio_bytes:
+                    self.sentence_queue.put((text, audio_bytes))
+                else:
+                    # If no audio bytes, use current speaker
+                    self.sentence_queue.put((text, None))
+
             except Exception as e:
-                print(f"Error processing final text: {e}")
+                logger.error(f"Error processing final text: {e}")
 
     def process_sentence_queue(self):
         """Process sentences in the queue for speaker detection"""
         while self.is_running:
             try:
-                text, bytes_data = self.sentence_queue.get(timeout=1)
+                text, audio_bytes = self.sentence_queue.get(timeout=1)
 
-                # Convert audio data to int16
-                audio_int16 = np.frombuffer(bytes_data, dtype=np.int16)
+                current_speaker = self.speaker_detector.current_speaker
 
-                # Extract speaker embedding
-                speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
-
-                with self._state_lock:
-                    # Store sentence and embedding
-                    self.full_sentences.append((text, speaker_embedding))
-
-                    # Fill in missing speaker assignments
-                    while len(self.sentence_speakers) < len(self.full_sentences) - 1:
-                        self.sentence_speakers.append(0)
+                if audio_bytes:
+                    # Convert audio data and extract embedding
+                    audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
+                    audio_float = audio_int16.astype(np.float32) / 32768.0
 
-                # Detect speaker changes
-                speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
-                self.sentence_speakers.append(speaker_id)
-
-                # Remove from pending
-                if text in self.pending_sentences:
-                    self.pending_sentences.remove(text)
-
-                # Update conversation display
-                self.current_conversation = self.get_formatted_conversation()
+                    # Extract embedding
+                    embedding = self.audio_processor.encoder.embed_utterance(audio_float)
+                    if embedding is not None:
+                        current_speaker, similarity = self.speaker_detector.add_embedding(embedding)
+
+                # Store sentence with speaker
+                with self.transcription_lock:
+                    self.full_sentences.append((text, current_speaker))
 
             except queue.Empty:
                 continue
             except Exception as e:
-                print(f"Error processing sentence: {e}")
-                import traceback
-                traceback.print_exc()
 
     def start_recording(self):
         """Start the recording and transcription process"""
@@ -457,10 +416,10 @@ class RealtimeSpeakerDiarization:
             return "Please initialize models first!"
 
         try:
-            # Setup recorder configuration for manual audio input
             recorder_config = {
                 'spinner': False,
-                'use_microphone': False,  # We'll feed audio manually
                 'model': FINAL_TRANSCRIPTION_MODEL,
                 'language': TRANSCRIPTION_LANGUAGE,
                 'silero_sensitivity': SILERO_SENSITIVITY,
@@ -470,45 +429,29 @@ class RealtimeSpeakerDiarization:
                 'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
                 'min_gap_between_recordings': 0,
                 'enable_realtime_transcription': True,
-                'realtime_processing_pause': 0,
                 'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
                 'on_realtime_transcription_update': self.live_text_detected,
                 'beam_size': FINAL_BEAM_SIZE,
                 'beam_size_realtime': REALTIME_BEAM_SIZE,
-                'buffer_size': BUFFER_SIZE,
                 'sample_rate': SAMPLE_RATE,
-                'external_audio': True,  # Signal that we'll provide audio
             }
 
-            # Make sure we're not running already
-            if hasattr(self, 'is_running') and self.is_running:
-                self.stop_recording()
-                # Short pause to ensure cleanup completes
-                time.sleep(0.5)
-
             self.recorder = AudioToTextRecorder(**recorder_config)
 
-            # Reset state
-            with self._state_lock:
-                self.pending_sentences = []
-                self.last_realtime_text = ""
-
-            # Start sentence processing thread
             self.is_running = True
             self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
             self.sentence_thread.start()
 
-            # Start transcription thread
             self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
             self.transcription_thread.start()
 
-            return "Recording started successfully! FastRTC audio input ready."
 
         except Exception as e:
-            self.is_running = False
-            import traceback
-            traceback.print_exc()
-            return f"Error starting recording: {str(e)}"
 
     def run_transcription(self):
         """Run the transcription loop"""
@@ -516,63 +459,21 @@ class RealtimeSpeakerDiarization:
             while self.is_running:
                 self.recorder.text(self.process_final_text)
         except Exception as e:
-            print(f"Transcription error: {e}")
 
     def stop_recording(self):
         """Stop the recording process"""
         self.is_running = False
         if self.recorder:
             self.recorder.stop()
-
-        # Wait for threads to finish
-        self._cleanup_resources()
-
         return "Recording stopped!"
 
-    def _cleanup_resources(self):
-        """Clean up resources and threads"""
-        try:
-            # Wait for threads to stop gracefully
-            if hasattr(self, 'sentence_thread') and self.sentence_thread is not None:
-                if self.sentence_thread.is_alive():
-                    self.sentence_thread.join(timeout=3.0)
-
-            if hasattr(self, 'transcription_thread') and self.transcription_thread is not None:
-                if self.transcription_thread.is_alive():
-                    self.transcription_thread.join(timeout=3.0)
-
-            # Clean up memory
-            with self._state_lock:
-                # Limit history size to prevent memory leaks
-                if len(self.full_sentences) > 1000:
-                    self.full_sentences = self.full_sentences[-1000:]
-                if len(self.sentence_speakers) > 1000:
-                    self.sentence_speakers = self.sentence_speakers[-1000:]
-
-            # Clear audio buffer
-            with self._audio_lock:
-                self.audio_buffer = []
-
-            # Clear queue
-            while not self.sentence_queue.empty():
-                try:
-                    self.sentence_queue.get_nowait()
-                except:
-                    pass
-
-        except Exception as e:
-            print(f"Error during resource cleanup: {e}")
-            import traceback
-            traceback.print_exc()
-
     def clear_conversation(self):
         """Clear all conversation data"""
-        self.full_sentences = []
-        self.sentence_speakers = []
-        self.pending_sentences = []
-        self.displayed_text = ""
-        self.last_realtime_text = ""
-        self.current_conversation = "Conversation cleared!"
 
         if self.speaker_detector:
             self.speaker_detector = SpeakerChangeDetector(
@@ -595,36 +496,8 @@ class RealtimeSpeakerDiarization:
         return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
 
     def get_formatted_conversation(self):
-        """Get the formatted conversation with speaker colors"""
-        try:
-            sentences_with_style = []
-
-            # Process completed sentences
-            for i, sentence in enumerate(self.full_sentences):
-                sentence_text, _ = sentence
-                if i >= len(self.sentence_speakers):
-                    color = "#FFFFFF"
-                    speaker_name = "Unknown"
-                else:
-                    speaker_id = self.sentence_speakers[i]
-                    color = self.speaker_detector.get_color_for_speaker(speaker_id)
-                    speaker_name = f"Speaker {speaker_id + 1}"
-
-                sentences_with_style.append(
-                    f'<span style="color:{color};"><b>{speaker_name}:</b> {sentence_text}</span>')
-
-            # Add pending sentences
-            for pending_sentence in self.pending_sentences:
-                sentences_with_style.append(
-                    f'<span style="color:#60FFFF;"><b>Processing:</b> {pending_sentence}</span>')
-
-            if sentences_with_style:
-                return "<br><br>".join(sentences_with_style)
-            else:
-                return "Waiting for speech input..."
-
-        except Exception as e:
-            return f"Error formatting conversation: {e}"
 
     def get_status_info(self):
         """Get current status information"""
@@ -640,808 +513,473 @@ class RealtimeSpeakerDiarization:
             f"**Last Similarity:** {status['last_similarity']:.3f}",
             f"**Change Threshold:** {status['threshold']:.2f}",
             f"**Total Sentences:** {len(self.full_sentences)}",
             "",
-            "**Speaker Segment Counts:**"
         ]
 
         for i in range(status['max_speakers']):
             color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
-            status_lines.append(f"Speaker {i+1} ({color_name}): {status['speaker_counts'][i]}")
 
         return "\n".join(status_lines)
 
     except Exception as e:
         return f"Error getting status: {e}"
 
-    def feed_audio_data(self, audio_data):
-        """Feed audio data to the recorder"""
-        if not self.is_running or not self.recorder:
-            return
-
-        try:
-            # Ensure audio is in the correct format (16-bit PCM)
-            if isinstance(audio_data, np.ndarray):
-                if audio_data.dtype != np.int16:
-                    # Convert float to int16
-                    if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
-                        audio_data = (audio_data * 32767).astype(np.int16)
-                    else:
-                        audio_data = audio_data.astype(np.int16)
-
-                # Convert to bytes
-                audio_bytes = audio_data.tobytes()
-            else:
-                audio_bytes = audio_data
-
-            # Use the recorder's internal buffer mechanism
-            if hasattr(self.recorder, 'feed_audio') and callable(self.recorder.feed_audio):
-                self.recorder.feed_audio(audio_bytes)
-            else:
-                # Fallback: direct access to the underlying buffer if the method doesn't exist
-                self.audio_buffer.append(audio_bytes)
-                # Process buffered audio when enough is accumulated
-                if len(self.audio_buffer) > 5:  # Process in small batches
-                    combined = b''.join(self.audio_buffer)
-                    if hasattr(self.recorder, '_process_audio'):
-                        self.recorder._process_audio(combined)
-                    self.audio_buffer = []
-
-        except Exception as e:
-            print(f"Error feeding audio data: {str(e)}")
-            import traceback
-            traceback.print_exc()
-
     def process_audio_chunk(self, audio_data, sample_rate=16000):
         """Process audio chunk from FastRTC input"""
-        if not self.is_running or self.recorder is None:
             return
 
         try:
-            with self._audio_lock:
-                # Use the normalized audio function
-                audio_int16 = self._normalize_audio_format(audio_data, target_dtype=np.int16, target_sample_rate=SAMPLE_RATE)
-
-                # Check if we got valid audio
-                if audio_int16.size == 0:
-                    print("Warning: Empty audio chunk received")
-                    return
-
-                # Resample if needed
-                if sample_rate != SAMPLE_RATE:
-                    audio_int16 = self._resample_audio(audio_int16, sample_rate, SAMPLE_RATE)
-
-                # Convert to bytes for feeding to recorder
-                audio_bytes = audio_int16.tobytes()
-
-                # Feed to recorder
-                self.feed_audio_data(audio_bytes)
-
-        except Exception as e:
-            print(f"Error processing audio chunk: {str(e)}")
-            import traceback
-            traceback.print_exc()
-
-    def _resample_audio(self, audio, orig_sr, target_sr):
-        """Resample audio to target sample rate"""
-        try:
-            import scipy.signal
-
-            # Get the resampling ratio
-            ratio = target_sr / orig_sr
 
-            # Calculate the new length
-            new_length = int(len(audio[0]) * ratio)
 
-            # Resample the audio
-            resampled = scipy.signal.resample(audio[0], new_length)
 
-            # Return in the same shape format
-            return np.expand_dims(resampled, 0)
-        except Exception as e:
-            print(f"Error resampling audio: {e}")
-            return audio
-
-    def _normalize_audio_format(self, audio_data, target_dtype=np.int16, target_sample_rate=SAMPLE_RATE):
-        """Normalize audio data to consistent format
-
-        Args:
-            audio_data: Input audio as numpy array or bytes
-            target_dtype: Target data type (np.int16 or np.float32)
-            target_sample_rate: Target sample rate
 
-        Returns:
-            Normalized audio as numpy array in requested format
-        """
-        try:
-            # Convert bytes to numpy if needed
-            if isinstance(audio_data, bytes):
-                audio_array = np.frombuffer(audio_data, dtype=np.int16)
-            elif isinstance(audio_data, (list, tuple)):
-                audio_array = np.array(audio_data)
-            else:
-                audio_array = audio_data
-
-            # Convert data type as needed
-            if target_dtype == np.int16 and audio_array.dtype != np.int16:
-                if audio_array.dtype == np.float32 or audio_array.dtype == np.float64:
-                    # Check if normalized to [-1, 1] range
-                    if np.max(np.abs(audio_array)) <= 1.0:
-                        audio_array = (audio_array * 32767).astype(np.int16)
-                    else:
-                        audio_array = audio_array.astype(np.int16)
-                else:
-                    audio_array = audio_array.astype(np.int16)
-            elif target_dtype == np.float32 and audio_array.dtype != np.float32:
-                if audio_array.dtype == np.int16:
-                    audio_array = audio_array.astype(np.float32) / 32768.0
-                else:
-                    audio_array = audio_array.astype(np.float32)
-
-            # Ensure mono audio
-            if len(audio_array.shape) > 1 and audio_array.shape[1] > 1:
-                audio_array = np.mean(audio_array, axis=1)
-
-            # Reshape if needed
-            if len(audio_array.shape) == 1:
-                if target_dtype == np.int16:
-                    audio_array = np.expand_dims(audio_array, 0)
 
-            return audio_array
-
         except Exception as e:
-            print(f"Error normalizing audio format: {e}")
-            import traceback
-            traceback.print_exc()
-            # Return empty array of correct type as fallback
-            return np.array([], dtype=target_dtype)
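The dtype conversions above follow the usual PCM convention: divide by 32768 going int16 → float32, scale by 32767 (with clipping) going back. For example:

    import numpy as np

    int16_pcm = np.array([0, 16384, -32768], dtype=np.int16)
    as_float = int16_pcm.astype(np.float32) / 32768.0              # in [-1.0, 1.0)
    back = (np.clip(as_float, -1.0, 1.0) * 32767).astype(np.int16)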
-
 
-# FastRTC Audio Handler for Real-time Diarization
 
 class DiarizationHandler(AsyncStreamHandler):
     def __init__(self, diarization_system):
         super().__init__()
         self.diarization_system = diarization_system
-        self.audio_queue = asyncio.Queue(maxsize=100)  # Use asyncio queue
-        self.is_processing = False
-        self.sample_rate = 16000  # Default sample rate
-        self.processing_task = None
 
     def copy(self):
         """Return a fresh handler for each new stream connection"""
         return DiarizationHandler(self.diarization_system)
 
     async def emit(self):
-        """Not used in this implementation - we only receive audio"""
         return None
 
     async def receive(self, frame):
-        """Receive audio data from FastRTC and process it"""
         try:
             if not self.diarization_system.is_running:
                 return
 
-            # Extract audio data from frame
-            if hasattr(frame, 'data') and frame.data is not None:
-                audio_data = frame.data
-            elif hasattr(frame, 'audio') and frame.audio is not None:
-                audio_data = frame.audio
             else:
-                audio_data = frame
 
-            # Get sample rate from frame if available
-            sample_rate = getattr(frame, 'sample_rate', self.sample_rate)
 
-            # Add to queue - non-blocking with timeout
-            try:
-                # Use put_nowait with try/except to avoid blocking
-                await asyncio.wait_for(
-                    self.audio_queue.put((audio_data, sample_rate)),
-                    timeout=0.1
-                )
-            except asyncio.TimeoutError:
-                # Queue is full, drop this chunk
-                print("Warning: Audio queue full, dropping frame")
-                return
 
         except Exception as e:
-            print(f"Error in FastRTC audio receive: {e}")
-            import traceback
-            traceback.print_exc()
 
-    async def _process_audio_loop(self):
-        """Background task to process audio from queue"""
-        while self.is_processing:
-            try:
-                # Get from queue with timeout to allow checking is_processing flag
-                try:
-                    audio_data, sample_rate = await asyncio.wait_for(
-                        self.audio_queue.get(),
-                        timeout=0.5
-                    )
-                except asyncio.TimeoutError:
-                    # No audio available, check if we should keep running
-                    continue
-
-                # Convert to numpy array if needed
-                if isinstance(audio_data, bytes):
-                    # Convert bytes to numpy array (assuming 16-bit PCM)
-                    audio_array = np.frombuffer(audio_data, dtype=np.int16)
-                    # Normalize to float32 range [-1, 1]
-                    audio_array = audio_array.astype(np.float32) / 32768.0
-                elif isinstance(audio_data, (list, tuple)):
-                    audio_array = np.array(audio_data, dtype=np.float32)
-                elif isinstance(audio_data, np.ndarray):
-                    audio_array = audio_array.astype(np.float32)
-                else:
-                    print(f"Unknown audio data type: {type(audio_data)}")
-                    continue
-
-                # Ensure mono audio
-                if len(audio_array.shape) > 1 and audio_array.shape[1] > 1:
-                    audio_array = np.mean(audio_array, axis=1)
-
-                # Ensure 1D array
-                if len(audio_array.shape) > 1:
-                    audio_array = audio_array.flatten()
-
-                # Process audio through thread pool to avoid blocking event loop
-                await self.process_audio_async(audio_array, sample_rate)
-
-                # Mark as done
-                self.audio_queue.task_done()
-
-            except Exception as e:
-                print(f"Error in audio processing loop: {e}")
-                import traceback
-                traceback.print_exc()
-                # Short sleep to avoid tight loop
-                await asyncio.sleep(0.1)
-
-    async def process_audio_async(self, audio_data, sample_rate=16000):
         """Process audio data asynchronously"""
         try:
-            # Run the audio processing in a thread pool to avoid blocking
             loop = asyncio.get_event_loop()
             await loop.run_in_executor(
                 None,
                 self.diarization_system.process_audio_chunk,
                 audio_data,
-                sample_rate
             )
         except Exception as e:
-            print(f"Error in async audio processing: {e}")
-
-    async def start_up(self) -> None:
-        """Initialize any resources when the stream starts"""
-        print("FastRTC stream started")
-        self.is_processing = True
-
-        # Start background processing task
-        self.processing_task = asyncio.create_task(self._process_audio_loop())
-
-    async def shutdown(self) -> None:
-        """Clean up any resources when the stream ends"""
-        print("FastRTC stream shutting down")
-        self.is_processing = False
-
-        # Wait for processing task to finish
-        if self.processing_task:
-            try:
-                # Cancel and wait for task
-                self.processing_task.cancel()
-                await asyncio.wait([self.processing_task], timeout=2.0)
-            except (asyncio.CancelledError, Exception) as e:
-                print(f"Error cancelling audio processing task: {e}")
-
-        # Clear queue
-        while not self.audio_queue.empty():
-            try:
-                self.audio_queue.get_nowait()
-                self.audio_queue.task_done()
-            except:
-                pass
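The key pattern in the handler's async processing is the executor hop: blocking, CPU-bound work is pushed off the event loop onto the default thread pool so frame handling stays responsive. In isolation:

    import asyncio

    def blocking_work(chunk: bytes) -> int:
        # stands in for numpy/torch processing that would stall the event loop
        return len(chunk)

    async def handle(chunk: bytes) -> int:
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, blocking_work, chunk)

    print(asyncio.run(handle(b"\x00\x01" * 512)))  # 1024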
 
 
 # Global instances
-diarization_system = None  # Will be initialized when RealtimeSpeakerDiarization is available
 audio_handler = None
 
-
 def initialize_system():
     """Initialize the diarization system"""
-    global audio_handler, diarization_system
     try:
-        if diarization_system is None:
-            print("Error: RealtimeSpeakerDiarization not initialized")
-            return "❌ Diarization system not available. Please ensure RealtimeSpeakerDiarization is properly imported."
-
         success = diarization_system.initialize_models()
         if success:
             audio_handler = DiarizationHandler(diarization_system)
-            return "✅ System initialized successfully! Models loaded and FastRTC handler ready."
         else:
-            return "❌ Failed to initialize system. Please check the logs."
     except Exception as e:
-        print(f"Initialization error: {e}")
         return f"❌ Initialization error: {str(e)}"
 
-
 def start_recording():
     """Start recording and transcription"""
     try:
-        if diarization_system is None:
-            return "❌ System not initialized"
         result = diarization_system.start_recording()
-        return f"🎙️ {result} - FastRTC audio streaming is active."
     except Exception as e:
         return f"❌ Failed to start recording: {str(e)}"
 
-
 def stop_recording():
     """Stop recording and transcription"""
     try:
-        if diarization_system is None:
-            return "❌ System not initialized"
         result = diarization_system.stop_recording()
         return f"⏹️ {result}"
     except Exception as e:
         return f"❌ Failed to stop recording: {str(e)}"
 
-
 def clear_conversation():
     """Clear the conversation"""
     try:
-        if diarization_system is None:
-            return "❌ System not initialized"
         result = diarization_system.clear_conversation()
         return f"🗑️ {result}"
     except Exception as e:
         return f"❌ Failed to clear conversation: {str(e)}"
 
-
 def update_settings(threshold, max_speakers):
     """Update system settings"""
     try:
-        if diarization_system is None:
-            return "❌ System not initialized"
         result = diarization_system.update_settings(threshold, max_speakers)
         return f"⚙️ {result}"
     except Exception as e:
         return f"❌ Failed to update settings: {str(e)}"
 
-
 def get_conversation():
     """Get the current conversation"""
     try:
-        if diarization_system is None:
-            return "<i>System not initialized</i>"
         return diarization_system.get_formatted_conversation()
     except Exception as e:
         return f"<i>Error getting conversation: {str(e)}</i>"
 
-
 def get_status():
     """Get system status"""
     try:
-        if diarization_system is None:
-            return "System not initialized"
         return diarization_system.get_status_info()
     except Exception as e:
         return f"Error getting status: {str(e)}"
 
-
 # Create Gradio interface
 def create_interface():
     with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
         gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
-        gr.Markdown("This app performs real-time speech recognition with automatic speaker identification using FastRTC for low-latency audio streaming.")
 
         with gr.Row():
             with gr.Column(scale=2):
-                # Main conversation display
                 conversation_output = gr.HTML(
-                    value="<div style='padding: 20px; background: #f5f5f5; border-radius: 10px;'><i>Click 'Initialize System' to start...</i></div>",
-                    label="Live Conversation",
-                    elem_id="conversation_display"
                 )
 
                 # Control buttons
                 with gr.Row():
                     init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
-                    start_btn = gr.Button("🎙️ Start Recording", variant="primary", size="lg", interactive=False)
-                    stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", size="lg", interactive=False)
                     clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)
 
-                # FastRTC Stream Interface
-                with gr.Row():
-                    gr.HTML("""
-                    <div id="fastrtc-container" style="border: 2px solid #ddd; border-radius: 10px; padding: 20px; margin: 10px 0;">
-                        <h3>🎵 Audio Stream</h3>
-                        <p>FastRTC audio stream will appear here when recording starts.</p>
-                        <div id="stream-status" style="padding: 10px; background: #f8f9fa; border-radius: 5px; margin-top: 10px;">
-                            Status: Waiting for initialization...
-                        </div>
-                    </div>
-                    """)
-
                 # Status display
                 status_output = gr.Textbox(
                     label="System Status",
-                    value="System not initialized. Please click 'Initialize System' to begin.",
-                    lines=6,
-                    interactive=False,
-                    show_copy_button=True
                 )
 
             with gr.Column(scale=1):
-                # Settings panel
                 gr.Markdown("## ⚙️ Settings")
 
                 threshold_slider = gr.Slider(
-                    minimum=0.1,
-                    maximum=0.95,
                     step=0.05,
-                    value=0.5,  # DEFAULT_CHANGE_THRESHOLD
                     label="Speaker Change Sensitivity",
-                    info="Lower = more sensitive to speaker changes"
                 )
 
                 max_speakers_slider = gr.Slider(
                     minimum=2,
-                    maximum=10,  # ABSOLUTE_MAX_SPEAKERS
                     step=1,
-                    value=4,  # DEFAULT_MAX_SPEAKERS
-                    label="Maximum Number of Speakers"
                 )
 
-                update_settings_btn = gr.Button("Update Settings", variant="secondary")
-
-                # Audio settings
-                gr.Markdown("## 🔊 Audio Configuration")
-                with gr.Accordion("Advanced Audio Settings", open=False):
-                    gr.Markdown("""
-                    **Current Configuration:**
-                    - Sample Rate: 16kHz
-                    - Audio Format: 16-bit PCM → Float32 (via AudioProcessor)
-                    - Channels: Mono (stereo converted automatically)
-                    - Buffer Size: 1024 samples for real-time processing
-                    - Processing: Uses existing AudioProcessor.extract_embedding()
-                    """)
 
                 # Instructions
-                gr.Markdown("## 📝 How to Use")
                 gr.Markdown("""
-                1. **Initialize**: Click "Initialize System" to load AI models
-                2. **Start**: Click "Start Recording" to begin processing
-                3. **Connect**: The FastRTC stream will activate automatically
-                4. **Allow Access**: Grant microphone permissions when prompted
-                5. **Speak**: Talk naturally into your microphone
-                6. **Monitor**: Watch real-time transcription with speaker colors
                 """)
-
-                # Performance tips
-                with gr.Accordion("💡 Performance Tips", open=False):
-                    gr.Markdown("""
-                    - Use Chrome/Edge for best FastRTC performance
-                    - Ensure stable internet connection
-                    - Use headphones to prevent echo
-                    - Position microphone 6-12 inches away
-                    - Minimize background noise
-                    - Allow browser microphone access
-                    """)
-
-                # Speaker color legend
-                gr.Markdown("## 🎨 Speaker Colors")
-                speaker_colors = [
-                    ("#FF6B6B", "Red"),
-                    ("#4ECDC4", "Teal"),
-                    ("#45B7D1", "Blue"),
-                    ("#96CEB4", "Green"),
-                    ("#FFEAA7", "Yellow"),
-                    ("#DDA0DD", "Plum"),
-                    ("#98D8C8", "Mint"),
-                    ("#F7DC6F", "Gold")
-                ]
-
-                color_html = ""
-                for i, (color, name) in enumerate(speaker_colors[:4]):
-                    color_html += f'<div style="margin: 3px 0;"><span style="color:{color}; font-size: 16px; font-weight: bold;">●</span> Speaker {i+1} ({name})</div>'
-
-                gr.HTML(f"<div style='font-size: 14px;'>{color_html}</div>")
-
-        # Auto-refresh conversation and status
-        def refresh_display():
-            try:
-                conversation = get_conversation()
-                status = get_status()
-                return conversation, status
-            except Exception as e:
-                error_msg = f"Error refreshing display: {str(e)}"
-                return f"<i>{error_msg}</i>", error_msg
 
         # Event handlers
         def on_initialize():
-            try:
-                result = initialize_system()
-                success = "successfully" in result.lower()
-
-                conversation, status = refresh_display()
-
-                return (
-                    result,  # status_output
-                    gr.update(interactive=success),  # start_btn
-                    gr.update(interactive=success),  # clear_btn
-                    conversation,  # conversation_output
-                )
-            except Exception as e:
-                error_msg = f"❌ Initialization failed: {str(e)}"
-                return (
-                    error_msg,
-                    gr.update(interactive=False),
-                    gr.update(interactive=False),
-                    "<i>System not ready</i>",
-                )
 
         def on_start():
-            try:
-                result = start_recording()
-                return (
-                    result,  # status_output
-                    gr.update(interactive=False),  # start_btn
-                    gr.update(interactive=True),  # stop_btn
-                )
-            except Exception as e:
-                error_msg = f"❌ Failed to start: {str(e)}"
-                return (
-                    error_msg,
-                    gr.update(interactive=True),
-                    gr.update(interactive=False),
-                )
 
         def on_stop():
-            try:
-                result = stop_recording()
-                return (
-                    result,  # status_output
-                    gr.update(interactive=True),  # start_btn
-                    gr.update(interactive=False),  # stop_btn
-                )
-            except Exception as e:
-                error_msg = f"❌ Failed to stop: {str(e)}"
-                return (
-                    error_msg,
-                    gr.update(interactive=False),
-                    gr.update(interactive=True),
-                )
 
         def on_clear():
-            try:
-                result = clear_conversation()
-                conversation, status = refresh_display()
-                return result, conversation
-            except Exception as e:
-                error_msg = f"❌ Failed to clear: {str(e)}"
-                return error_msg, "<i>Error clearing conversation</i>"
 
         def on_update_settings(threshold, max_speakers):
-            try:
-                result = update_settings(threshold, max_speakers)
-                return result
-            except Exception as e:
-                return f"❌ Failed to update settings: {str(e)}"
 
-        # Connect event handlers
         init_btn.click(
-            on_initialize,
-            outputs=[status_output, start_btn, clear_btn, conversation_output]
         )
 
         start_btn.click(
-            on_start,
             outputs=[status_output, start_btn, stop_btn]
         )
 
         stop_btn.click(
-            on_stop,
             outputs=[status_output, start_btn, stop_btn]
         )
 
         clear_btn.click(
-            on_clear,
-            outputs=[status_output, conversation_output]
         )
 
-        update_settings_btn.click(
-            on_update_settings,
             inputs=[threshold_slider, max_speakers_slider],
             outputs=[status_output]
         )
 
-        # Auto-refresh every 2 seconds when active
-        refresh_timer = gr.Timer(2.0)
-        refresh_timer.tick(
-            refresh_display,
-            outputs=[conversation_output, status_output]
         )
 
     return interface
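The removed auto-refresh block relied on Gradio's timer-based polling: `gr.Timer(2.0)` fires a `tick` event every two seconds and rewrites the conversation HTML and status box (`gr.Timer` is the Gradio 4.x-era mechanism; older releases used `every=` on event listeners). A self-contained example of that polling pattern:

    import time
    import gradio as gr

    with gr.Blocks() as demo:
        clock = gr.Textbox(label="Server time")
        timer = gr.Timer(2.0)  # fires while the page is open
        timer.tick(lambda: time.ctime(), outputs=clock)

    demo.launch()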
 
 
-# FastAPI setup for API endpoints
 def create_fastapi_app():
-    """Create FastAPI app with API endpoints"""
-    app = FastAPI(
-        title="Real-time Speaker Diarization",
-        description="Real-time speech recognition with speaker diarization using FastRTC",
-        version="1.0.0"
-    )
 
-    # API Routes
-    router = APIRouter()
 
-    @router.get("/health")
-    async def health_check():
-        """Health check endpoint"""
-        return {
-            "status": "healthy",
-            "timestamp": time.time(),
-            "system_initialized": diarization_system is not None and hasattr(diarization_system, 'encoder') and diarization_system.encoder is not None,
-            "recording_active": diarization_system.is_running if diarization_system and hasattr(diarization_system, 'is_running') else False
-        }
 
-    @router.get("/api/conversation")
     async def get_conversation_api():
-        """Get current conversation"""
         try:
             return {
-                "conversation": get_conversation(),
-                "status": get_status(),
-                "is_recording": diarization_system.is_running if diarization_system and hasattr(diarization_system, 'is_running') else False,
-                "timestamp": time.time()
             }
         except Exception as e:
-            return {"error": str(e), "timestamp": time.time()}
 
-    @router.post("/api/control/{action}")
-    async def control_recording(action: str):
-        """Control recording actions"""
         try:
-            if action == "start":
-                result = start_recording()
-            elif action == "stop":
-                result = stop_recording()
-            elif action == "clear":
-                result = clear_conversation()
-            elif action == "initialize":
-                result = initialize_system()
-            else:
-                return {"error": "Invalid action. Use: start, stop, clear, or initialize"}
-
-            return {
-                "result": result,
-                "is_recording": diarization_system.is_running if diarization_system and hasattr(diarization_system, 'is_running') else False,
-                "timestamp": time.time()
-            }
         except Exception as e:
-            return {"error": str(e), "timestamp": time.time()}
 
-    app.include_router(router)
     return app
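For reference, endpoints like those defined above can be exercised with a plain-stdlib client; a small sketch (host and port are assumptions matching the default launch below, and note this commit removes or reworks several of these routes, so check the resulting app for the surviving paths):

    import json
    import urllib.request

    base = "http://localhost:7860"
    print(json.load(urllib.request.urlopen(f"{base}/health")))
    print(json.load(urllib.request.urlopen(f"{base}/api/conversation")))

    req = urllib.request.Request(f"{base}/api/control/start", method="POST")
    print(json.load(urllib.request.urlopen(req)))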
 
 
-# Function to setup FastRTC stream
-def setup_fastrtc_stream(app):
-    """Setup FastRTC stream with proper configuration"""
-    try:
-        if audio_handler is None:
-            print("Warning: Audio handler not initialized. Initialize system first.")
-            return None
-
-        # Get HuggingFace token for TURN server (optional)
-        hf_token = os.environ.get("HF_TOKEN")
-
-        # Configure RTC settings
-        rtc_config = {
-            "iceServers": [
-                {"urls": "stun:stun.l.google.com:19302"},
-                {"urls": "stun:stun1.l.google.com:19302"}
-            ]
-        }
-
-        # Create FastRTC stream
-        stream = Stream(
-            handler=audio_handler,
-            rtc_configuration=rtc_config,
-            modality="audio",
-            mode="receive"  # We only receive audio, don't send
-        )
-
-        # Mount the stream
-        app.mount("/stream", stream)
-        print("✅ FastRTC stream configured successfully!")
-        return stream
-
-    except Exception as e:
-        print(f"⚠️ Warning: Failed to setup FastRTC stream: {e}")
-        print("Audio streaming may not work properly.")
-        return None
-
-
-# Main application setup
-def create_app(diarization_sys=None):
-    """Create the complete application"""
-    global diarization_system
-
-    # Set the diarization system
-    if diarization_sys is not None:
-        diarization_system = diarization_sys
 
-    # Create FastAPI app
-    fastapi_app = create_fastapi_app()
 
-    # Create Gradio interface
-    gradio_interface = create_interface()
 
-    # Mount Gradio on FastAPI
-    app = gr.mount_gradio_app(fastapi_app, gradio_interface, path="/")
 
-    # Setup FastRTC stream
-    if diarization_system is not None:
-        # Initialize the system if not already done
-        if not hasattr(diarization_system, 'encoder') or diarization_system.encoder is None:
-            diarization_system.initialize_models()
-
-        # Create audio handler if needed
-        global audio_handler
-        if audio_handler is None:
-            audio_handler = DiarizationHandler(diarization_system)
-
-        # Setup and mount the FastRTC stream
-        setup_fastrtc_stream(app)
 
-    return app, gradio_interface
-
-
-# Entry point for HuggingFace Spaces
-if __name__ == "__main__":
-    try:
-        # Import your diarization system here
-        # from your_module import RealtimeSpeakerDiarization
-        diarization_system = RealtimeSpeakerDiarization()
 
-        # Create the application
-        app, interface = create_app()
 
-        # Launch for HuggingFace Spaces
         interface.launch(
-            server_name="0.0.0.0",
-            server_port=7860,
             share=True,
-            show_error=True,
-            quiet=False
         )
-
     except Exception as e:
-        print(f"Failed to launch application: {e}")
-        import traceback
-        traceback.print_exc()
-
-        # Fallback - launch just Gradio interface
-        try:
-            interface = create_interface()
-            interface.launch(
-                server_name="0.0.0.0",
-                server_port=int(os.environ.get("PORT", 7860)),
-                share=False
-            )
-        except Exception as fallback_error:
-            print(f"Fallback launch also failed: {fallback_error}")
 
 
-# Helper function to initialize with your diarization system
-def initialize_with_diarization_system(diarization_sys):
-    """Initialize the application with your diarization system"""
-    global diarization_system
-    diarization_system = diarization_sys
-    return create_app(diarization_sys)
10
  from scipy.spatial.distance import cosine
11
  from RealtimeSTT import AudioToTextRecorder
12
  from fastapi import FastAPI, APIRouter
13
+ from fastrtc import Stream, AsyncStreamHandler
14
  import json
 
 
15
  import asyncio
16
  import uvicorn
 
17
  from queue import Queue
18
+ import logging
19
+
20
+ # Set up logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
  # Simplified configuration parameters
25
  SILENCE_THRESHS = [0, 0.4]
26
  FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
 
34
  PRE_RECORDING_BUFFER_DURATION = 0.35
35
 
36
  # Speaker change detection parameters
37
+ DEFAULT_CHANGE_THRESHOLD = 0.65
38
  EMBEDDING_HISTORY_SIZE = 5
39
+ MIN_SEGMENT_DURATION = 1.5
40
  DEFAULT_MAX_SPEAKERS = 4
41
+ ABSOLUTE_MAX_SPEAKERS = 8
42
 
43
  # Global variables
 
44
  SAMPLE_RATE = 16000
45
+ BUFFER_SIZE = 1024
46
  CHANNELS = 1
47
 
48
+ # Speaker colors - more distinguishable colors
49
  SPEAKER_COLORS = [
50
+ "#FF6B6B", # Red
51
+ "#4ECDC4", # Teal
52
+ "#45B7D1", # Blue
53
+ "#96CEB4", # Green
54
+ "#FFEAA7", # Yellow
55
+ "#DDA0DD", # Plum
56
+ "#98D8C8", # Mint
57
+ "#F7DC6F", # Gold
 
 
58
  ]
59
 
60
  SPEAKER_COLOR_NAMES = [
61
+ "Red", "Teal", "Blue", "Green", "Yellow", "Plum", "Mint", "Gold"
 
62
  ]
63
 
64
 
 
72
  self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
73
  os.makedirs(self.cache_dir, exist_ok=True)
74
 
 
 
 
 
 
 
 
 
 
 
 
75
  def load_model(self):
76
  """Load the ECAPA-TDNN model"""
77
  try:
78
  from speechbrain.pretrained import EncoderClassifier
79
 
 
 
80
  self.model = EncoderClassifier.from_hparams(
81
  source="speechbrain/spkrec-ecapa-voxceleb",
82
  savedir=self.cache_dir,
 
84
  )
85
 
86
  self.model_loaded = True
87
+ logger.info("ECAPA-TDNN model loaded successfully!")
88
  return True
89
  except Exception as e:
90
+ logger.error(f"Error loading ECAPA-TDNN model: {e}")
91
  return False
92
 
93
  def embed_utterance(self, audio, sr=16000):
 
97
 
98
  try:
99
  if isinstance(audio, np.ndarray):
100
+ # Ensure audio is float32 and properly normalized
101
+ audio = audio.astype(np.float32)
102
+ if np.max(np.abs(audio)) > 1.0:
103
+ audio = audio / np.max(np.abs(audio))
104
+ waveform = torch.tensor(audio).unsqueeze(0)
105
  else:
106
  waveform = audio.unsqueeze(0)
107
 
108
+ # Resample if necessary
109
  if sr != 16000:
110
  waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
111
 
 
114
 
115
  return embedding.squeeze().cpu().numpy()
116
  except Exception as e:
117
+ logger.error(f"Error extracting embedding: {e}")
118
  return np.zeros(self.embedding_dim)
119
 
120
 
 
122
  """Processes audio data to extract speaker embeddings"""
123
  def __init__(self, encoder):
124
  self.encoder = encoder
125
+ self.audio_buffer = []
126
+ self.min_audio_length = int(SAMPLE_RATE * 1.0) # Minimum 1 second of audio
127
 
128
+ def add_audio_chunk(self, audio_chunk):
129
+ """Add audio chunk to buffer"""
130
+ self.audio_buffer.extend(audio_chunk)
131
+
132
+ # Keep buffer from getting too large
133
+ max_buffer_size = int(SAMPLE_RATE * 10) # 10 seconds max
134
+ if len(self.audio_buffer) > max_buffer_size:
135
+ self.audio_buffer = self.audio_buffer[-max_buffer_size:]
136
+
137
+ def extract_embedding_from_buffer(self):
138
+ """Extract embedding from current audio buffer"""
139
+ if len(self.audio_buffer) < self.min_audio_length:
140
+ return None
141
 
142
+ try:
143
+ # Use the last portion of the buffer for embedding
144
+ audio_segment = np.array(self.audio_buffer[-self.min_audio_length:], dtype=np.float32)
145
 
146
+ # Normalize audio
147
+ if np.max(np.abs(audio_segment)) > 0:
148
+ audio_segment = audio_segment / np.max(np.abs(audio_segment))
149
+ else:
150
+ return None
151
 
152
+ embedding = self.encoder.embed_utterance(audio_segment)
153
  return embedding
154
  except Exception as e:
155
+ logger.error(f"Embedding extraction error: {e}")
156
+ return None
157
 
158
 
159
  class SpeakerChangeDetector:
160
+ """Improved speaker change detector"""
161
  def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
162
  self.embedding_dim = embedding_dim
163
  self.change_threshold = change_threshold
164
  self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
165
  self.current_speaker = 0
 
 
 
166
  self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
167
+ self.speaker_centroids = [None] * self.max_speakers
168
+ self.last_change_time = time.time()
169
+ self.last_similarity = 1.0
170
  self.active_speakers = set([0])
171
+ self.segment_counter = 0
172
 
173
  def set_max_speakers(self, max_speakers):
174
  """Update the maximum number of speakers"""
175
  new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
176
 
177
  if new_max < self.max_speakers:
178
+ # Remove speakers beyond the new limit
179
  for speaker_id in list(self.active_speakers):
180
  if speaker_id >= new_max:
181
  self.active_speakers.discard(speaker_id)
 
183
  if self.current_speaker >= new_max:
184
  self.current_speaker = 0
185
 
186
+ # Resize arrays
187
  if new_max > self.max_speakers:
 
188
  self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
189
+ self.speaker_centroids.extend([None] * (new_max - self.max_speakers))
190
  else:
 
191
  self.speaker_embeddings = self.speaker_embeddings[:new_max]
192
+ self.speaker_centroids = self.speaker_centroids[:new_max]
193
 
194
  self.max_speakers = new_max
195
 
196
  def set_change_threshold(self, threshold):
197
  """Update the threshold for detecting speaker changes"""
198
+ self.change_threshold = max(0.1, min(threshold, 0.95))
199
 
200
  def add_embedding(self, embedding, timestamp=None):
201
+ """Add a new embedding and detect speaker changes"""
202
  current_time = timestamp or time.time()
203
+ self.segment_counter += 1
204
+
205
+ # Initialize first speaker
206
+ if not self.speaker_embeddings[0]:
207
+ self.speaker_embeddings[0].append(embedding)
208
+ self.speaker_centroids[0] = embedding.copy()
209
+ self.active_speakers.add(0)
210
+ return 0, 1.0
211
+
212
+ # Calculate similarity with current speaker
213
+ current_centroid = self.speaker_centroids[self.current_speaker]
214
+ if current_centroid is not None:
215
+ similarity = 1.0 - cosine(embedding, current_centroid)
216
  else:
217
+ similarity = 0.5
218
 
219
  self.last_similarity = similarity
220
 
221
+ # Check for speaker change
222
  time_since_last_change = current_time - self.last_change_time
223
+ speaker_changed = False
224
 
225
+ if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold:
+            # Find best matching speaker
+            best_speaker = self.current_speaker
+            best_similarity = similarity
+
+            for speaker_id in self.active_speakers:
+                if speaker_id == self.current_speaker:
+                    continue
+
+                centroid = self.speaker_centroids[speaker_id]
+                if centroid is not None:
+                    speaker_similarity = 1.0 - cosine(embedding, centroid)
+                    if speaker_similarity > best_similarity and speaker_similarity > self.change_threshold:
+                        best_similarity = speaker_similarity
+                        best_speaker = speaker_id
+
+            # If no good match found and we can add a new speaker
+            if best_speaker == self.current_speaker and len(self.active_speakers) < self.max_speakers:
+                for new_id in range(self.max_speakers):
+                    if new_id not in self.active_speakers:
+                        best_speaker = new_id
+                        self.active_speakers.add(new_id)
+                        break
+
+            if best_speaker != self.current_speaker:
+                self.current_speaker = best_speaker
+                self.last_change_time = current_time
+                speaker_changed = True
+
+        # Update speaker embeddings and centroids
         self.speaker_embeddings[self.current_speaker].append(embedding)
 
+        # Keep only recent embeddings (sliding window)
+        max_embeddings = 20
+        if len(self.speaker_embeddings[self.current_speaker]) > max_embeddings:
+            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-max_embeddings:]
+
+        # Update centroid
         if self.speaker_embeddings[self.current_speaker]:
+            self.speaker_centroids[self.current_speaker] = np.mean(
                 self.speaker_embeddings[self.current_speaker], axis=0
             )
 
         return "#FFFFFF"
 
     def get_status_info(self):
+        """Return status information"""
         speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
 
         return {
             "current_speaker": self.current_speaker,
             "speaker_counts": speaker_counts,
             "active_speakers": len(self.active_speakers),
             "max_speakers": self.max_speakers,
             "last_similarity": self.last_similarity,
+            "threshold": self.change_threshold,
+            "segment_counter": self.segment_counter
         }
 
 
 class RealtimeSpeakerDiarization:
     def __init__(self):
         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
         self.recorder = None
+        self.sentence_queue = queue.Queue()
         self.full_sentences = []
         self.sentence_speakers = []
         self.pending_sentences = []
+        self.current_conversation = ""
         self.is_running = False
         self.change_threshold = DEFAULT_CHANGE_THRESHOLD
         self.max_speakers = DEFAULT_MAX_SPEAKERS
+        self.last_transcription = ""
+        self.transcription_lock = threading.Lock()
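+        # The lock guards state shared between the recorder callback thread,
+        # the sentence worker thread, and the Gradio polling loop.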
 
     def initialize_models(self):
         """Initialize the speaker encoder model"""
         try:
             device_str = "cuda" if torch.cuda.is_available() else "cpu"
+            logger.info(f"Using device: {device_str}")
 
             self.encoder = SpeechBrainEncoder(device=device_str)
+            success = self.encoder.load_model()
 
+            if success:
                 self.audio_processor = AudioProcessor(self.encoder)
                 self.speaker_detector = SpeakerChangeDetector(
                     embedding_dim=self.encoder.embedding_dim,
                     change_threshold=self.change_threshold,
                     max_speakers=self.max_speakers
                 )
+                logger.info("Models initialized successfully!")
                 return True
             else:
+                logger.error("Failed to load models")
+                return False
         except Exception as e:
+            logger.error(f"Model initialization error: {e}")
             return False
 
     def live_text_detected(self, text):
         """Callback for real-time transcription updates"""
+        with self.transcription_lock:
+            self.last_transcription = text.strip()
 
     def process_final_text(self, text):
         """Process final transcribed text with speaker embedding"""
         text = text.strip()
         if text:
             try:
+                # Get audio data for this transcription
+                audio_bytes = getattr(self.recorder, 'last_transcription_bytes', None)
+                if audio_bytes:
+                    self.sentence_queue.put((text, audio_bytes))
+                else:
+                    # If no audio bytes, use current speaker
+                    self.sentence_queue.put((text, None))
             except Exception as e:
+                logger.error(f"Error processing final text: {e}")
 
     def process_sentence_queue(self):
         """Process sentences in the queue for speaker detection"""
         while self.is_running:
             try:
+                text, audio_bytes = self.sentence_queue.get(timeout=1)
 
+                current_speaker = self.speaker_detector.current_speaker
 
+                if audio_bytes:
+                    # Convert audio data and extract embedding
+                    audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
+                    audio_float = audio_int16.astype(np.float32) / 32768.0
 
+                    # Extract embedding
+                    embedding = self.audio_processor.encoder.embed_utterance(audio_float)
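+                    # Note: embed_utterance returns a zero vector (never None)
+                    # on failure, so the None check below is purely defensive.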
+                    if embedding is not None:
+                        current_speaker, similarity = self.speaker_detector.add_embedding(embedding)
+
+                # Store sentence with speaker
+                with self.transcription_lock:
+                    self.full_sentences.append((text, current_speaker))
+                    self.update_conversation_display()
 
             except queue.Empty:
                 continue
             except Exception as e:
+                logger.error(f"Error processing sentence: {e}")
+
+    def update_conversation_display(self):
+        """Update the conversation display"""
+        try:
+            sentences_with_style = []
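+            # Note: transcribed text is interpolated into HTML unescaped; if
+            # untrusted input is a concern, wrap it in html.escape() first.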
+
+            for sentence_text, speaker_id in self.full_sentences:
+                color = self.speaker_detector.get_color_for_speaker(speaker_id)
+                speaker_name = f"Speaker {speaker_id + 1}"
+                sentences_with_style.append(
+                    f'<span style="color:{color}; font-weight: bold;">{speaker_name}:</span> '
+                    f'<span style="color:#333333;">{sentence_text}</span>'
+                )
+
+            # Add current transcription if available
+            if self.last_transcription:
+                current_color = self.speaker_detector.get_color_for_speaker(self.speaker_detector.current_speaker)
+                current_speaker = f"Speaker {self.speaker_detector.current_speaker + 1}"
+                sentences_with_style.append(
+                    f'<span style="color:{current_color}; font-weight: bold; opacity: 0.7;">{current_speaker}:</span> '
+                    f'<span style="color:#666666; font-style: italic;">{self.last_transcription}...</span>'
+                )
+
+            if sentences_with_style:
+                self.current_conversation = "<br><br>".join(sentences_with_style)
+            else:
+                self.current_conversation = "<i>Waiting for speech input...</i>"
+
+        except Exception as e:
+            logger.error(f"Error updating conversation display: {e}")
+            self.current_conversation = f"<i>Error: {str(e)}</i>"
 
     def start_recording(self):
         """Start the recording and transcription process"""
         if self.encoder is None:
             return "Please initialize models first!"
 
         try:
+            # Setup recorder configuration
             recorder_config = {
                 'spinner': False,
+                'use_microphone': True,  # Changed to True for direct microphone input
423
  'model': FINAL_TRANSCRIPTION_MODEL,
424
  'language': TRANSCRIPTION_LANGUAGE,
425
  'silero_sensitivity': SILERO_SENSITIVITY,
 
429
  'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
430
  'min_gap_between_recordings': 0,
431
  'enable_realtime_transcription': True,
432
+ 'realtime_processing_pause': 0.1,
433
  'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
434
  'on_realtime_transcription_update': self.live_text_detected,
435
  'beam_size': FINAL_BEAM_SIZE,
436
  'beam_size_realtime': REALTIME_BEAM_SIZE,
 
437
  'sample_rate': SAMPLE_RATE,
 
438
  }
439
 
 
 
 
 
 
 
440
  self.recorder = AudioToTextRecorder(**recorder_config)
441
 
442
+ # Start processing threads
 
 
 
 
 
443
  self.is_running = True
444
  self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
445
  self.sentence_thread.start()
446
 
 
447
  self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
448
  self.transcription_thread.start()
449
 
450
+ return "Recording started successfully!"
451
 
452
  except Exception as e:
453
+ logger.error(f"Error starting recording: {e}")
454
+ return f"Error starting recording: {e}"
 
 
455
 
     def run_transcription(self):
         """Run the transcription loop"""
         try:
             while self.is_running:
                 self.recorder.text(self.process_final_text)
         except Exception as e:
+            logger.error(f"Transcription error: {e}")
 
     def stop_recording(self):
         """Stop the recording process"""
         self.is_running = False
         if self.recorder:
             self.recorder.stop()
         return "Recording stopped!"
 
     def clear_conversation(self):
         """Clear all conversation data"""
+        with self.transcription_lock:
+            self.full_sentences = []
+            self.last_transcription = ""
+            self.current_conversation = "Conversation cleared!"
 
         if self.speaker_detector:
             self.speaker_detector = SpeakerChangeDetector(
 
         return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
 
     def get_formatted_conversation(self):
+        """Get the formatted conversation"""
+        return self.current_conversation
 
     def get_status_info(self):
         """Get current status information"""
 
                 f"**Last Similarity:** {status['last_similarity']:.3f}",
                 f"**Change Threshold:** {status['threshold']:.2f}",
                 f"**Total Sentences:** {len(self.full_sentences)}",
+                f"**Segments Processed:** {status['segment_counter']}",
                 "",
+                "**Speaker Activity:**"
             ]
 
             for i in range(status['max_speakers']):
                 color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
+                count = status['speaker_counts'][i]
+                active = "🟢" if count > 0 else "⚫"
+                status_lines.append(f"{active} Speaker {i+1} ({color_name}): {count} segments")
 
             return "\n".join(status_lines)
 
         except Exception as e:
             return f"Error getting status: {e}"
 
     def process_audio_chunk(self, audio_data, sample_rate=16000):
         """Process audio chunk from FastRTC input"""
+        if not self.is_running or self.audio_processor is None:
             return
 
         try:
+            # Ensure audio is float32
+            if isinstance(audio_data, np.ndarray):
+                if audio_data.dtype != np.float32:
+                    audio_data = audio_data.astype(np.float32)
+            else:
+                audio_data = np.array(audio_data, dtype=np.float32)
 
+            # Ensure mono
+            if len(audio_data.shape) > 1:
+                audio_data = np.mean(audio_data, axis=1) if audio_data.shape[1] > 1 else audio_data.flatten()
 
+            # Normalize if needed
+            if np.max(np.abs(audio_data)) > 1.0:
+                audio_data = audio_data / np.max(np.abs(audio_data))
 
+            # Add to audio processor buffer for speaker detection
+            self.audio_processor.add_audio_chunk(audio_data)
 
+            # Periodically extract embeddings for speaker detection
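+            # Caveat: the modulo test below only fires when the buffer length
+            # lands on an exact multiple of SAMPLE_RATE // 2 (8000 samples);
+            # with 512-sample chunks that only lines up every 64,000 samples
+            # (~4 s), so a >= check may be a more reliable trigger.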
+            if len(self.audio_processor.audio_buffer) % (SAMPLE_RATE // 2) == 0:  # Every 0.5 seconds
+                embedding = self.audio_processor.extract_embedding_from_buffer()
+                if embedding is not None:
+                    self.speaker_detector.add_embedding(embedding)
 
         except Exception as e:
+            logger.error(f"Error processing audio chunk: {e}")
 
 
+# FastRTC Audio Handler
 class DiarizationHandler(AsyncStreamHandler):
     def __init__(self, diarization_system):
         super().__init__()
         self.diarization_system = diarization_system
+        self.audio_buffer = []
+        self.buffer_size = BUFFER_SIZE
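+        # At SAMPLE_RATE = 16000, a 512-sample BUFFER_SIZE is ~32 ms of audio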
 
     def copy(self):
         """Return a fresh handler for each new stream connection"""
         return DiarizationHandler(self.diarization_system)
 
     async def emit(self):
+        """Not used - we only receive audio"""
         return None
 
     async def receive(self, frame):
+        """Receive audio data from FastRTC"""
         try:
             if not self.diarization_system.is_running:
                 return
 
+            # Extract audio data
+            audio_data = getattr(frame, 'data', frame)
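+            # Note: depending on the fastrtc version, frames may arrive as
+            # (sample_rate, ndarray) tuples rather than objects with a .data
+            # attribute; the branches below assume a flat sample sequence.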
+
+            # Convert to numpy array
+            if isinstance(audio_data, bytes):
+                audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
+            elif isinstance(audio_data, (list, tuple)):
+                audio_array = np.array(audio_data, dtype=np.float32)
             else:
+                audio_array = np.array(audio_data, dtype=np.float32)
 
+            # Ensure 1D
+            if len(audio_array.shape) > 1:
+                audio_array = audio_array.flatten()
 
+            # Buffer audio chunks
+            self.audio_buffer.extend(audio_array)
+
+            # Process in chunks
+            while len(self.audio_buffer) >= self.buffer_size:
+                chunk = np.array(self.audio_buffer[:self.buffer_size])
+                self.audio_buffer = self.audio_buffer[self.buffer_size:]
+
+                # Process asynchronously
+                await self.process_audio_async(chunk)
 
         except Exception as e:
+            logger.error(f"Error in FastRTC receive: {e}")
 
+    async def process_audio_async(self, audio_data):
         """Process audio data asynchronously"""
         try:
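+            # Offload the blocking embedding work to the default thread pool
+            # so the event loop stays responsive to incoming frames.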
             loop = asyncio.get_running_loop()
             await loop.run_in_executor(
                 None,
                 self.diarization_system.process_audio_chunk,
                 audio_data,
+                SAMPLE_RATE
             )
         except Exception as e:
+            logger.error(f"Error in async audio processing: {e}")
 
 
 # Global instances
+diarization_system = RealtimeSpeakerDiarization()
 audio_handler = None
 
 
 def initialize_system():
     """Initialize the diarization system"""
+    global audio_handler
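+    # The handler is created once models are loaded; create_fastapi_app()
+    # mounts it, so initialize_system() must run before the API is built.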
     try:
         success = diarization_system.initialize_models()
         if success:
             audio_handler = DiarizationHandler(diarization_system)
+            return "✅ System initialized successfully!"
         else:
+            return "❌ Failed to initialize system. Check logs for details."
     except Exception as e:
+        logger.error(f"Initialization error: {e}")
         return f"❌ Initialization error: {str(e)}"
 
 
 def start_recording():
     """Start recording and transcription"""
     try:
         result = diarization_system.start_recording()
+        return f"🎙️ {result}"
     except Exception as e:
         return f"❌ Failed to start recording: {str(e)}"
 
 
 def stop_recording():
     """Stop recording and transcription"""
     try:
         result = diarization_system.stop_recording()
         return f"⏹️ {result}"
     except Exception as e:
         return f"❌ Failed to stop recording: {str(e)}"
 
 
 def clear_conversation():
     """Clear the conversation"""
     try:
         result = diarization_system.clear_conversation()
         return f"🗑️ {result}"
     except Exception as e:
         return f"❌ Failed to clear conversation: {str(e)}"
 
 
 def update_settings(threshold, max_speakers):
     """Update system settings"""
     try:
         result = diarization_system.update_settings(threshold, max_speakers)
         return f"⚙️ {result}"
     except Exception as e:
         return f"❌ Failed to update settings: {str(e)}"
 
 
 def get_conversation():
     """Get the current conversation"""
     try:
         return diarization_system.get_formatted_conversation()
     except Exception as e:
         return f"<i>Error getting conversation: {str(e)}</i>"
 
 
 def get_status():
     """Get system status"""
     try:
         return diarization_system.get_status_info()
     except Exception as e:
         return f"Error getting status: {str(e)}"
 
 # Create Gradio interface
 def create_interface():
     with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
         gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
+        gr.Markdown("Live transcription with automatic speaker identification using FastRTC audio streaming.")
 
         with gr.Row():
             with gr.Column(scale=2):
+                # Conversation display
                 conversation_output = gr.HTML(
+                    value="<div style='padding: 20px; background: #f8f9fa; border-radius: 10px; min-height: 300px;'><i>Click 'Initialize System' to start...</i></div>",
+                    label="Live Conversation"
                 )
 
                 # Control buttons
                 with gr.Row():
                     init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
+                    start_btn = gr.Button("🎙️ Start", variant="primary", size="lg", interactive=False)
+                    stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", interactive=False)
                     clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)
 
                 # Status display
                 status_output = gr.Textbox(
                     label="System Status",
+                    value="Ready to initialize...",
+                    lines=8,
+                    interactive=False
                 )
 
             with gr.Column(scale=1):
+                # Settings
                 gr.Markdown("## ⚙️ Settings")
 
                 threshold_slider = gr.Slider(
+                    minimum=0.3,
+                    maximum=0.9,
                     step=0.05,
+                    value=DEFAULT_CHANGE_THRESHOLD,
                     label="Speaker Change Sensitivity",
+                    info="Lower = more sensitive"
                 )
 
                 max_speakers_slider = gr.Slider(
                     minimum=2,
+                    maximum=ABSOLUTE_MAX_SPEAKERS,
                     step=1,
+                    value=DEFAULT_MAX_SPEAKERS,
+                    label="Maximum Speakers"
                 )
 
+                update_btn = gr.Button("Update Settings", variant="secondary")
 
                 # Instructions
                 gr.Markdown("""
+                ## 📋 Instructions
+                1. **Initialize** the system (loads AI models)
+                2. **Start** recording
+                3. **Speak** - the system will transcribe and identify speakers
+                4. **Monitor** real-time results below
+
+                ## 🎨 Speaker Colors
+                - 🟡 Speaker 1 (Yellow)
+                - 🔴 Speaker 2 (Red)
+                - 🟢 Speaker 3 (Green)
+                - 🩵 Speaker 4 (Cyan)
+                - 🟣 Speaker 5 (Magenta)
+                - 🔵 Speaker 6 (Blue)
+                - 🟠 Speaker 7 (Orange)
+                - 🟩 Speaker 8 (Spring Green)
+                """)
 
         # Event handlers
         def on_initialize():
+            result = initialize_system()
+            if "✅" in result:
+                return result, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
+            else:
+                return result, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
 
         def on_start():
+            result = start_recording()
+            return result, gr.update(interactive=False), gr.update(interactive=True)
 
         def on_stop():
+            result = stop_recording()
+            return result, gr.update(interactive=True), gr.update(interactive=False)
 
         def on_clear():
+            result = clear_conversation()
+            return result
 
         def on_update_settings(threshold, max_speakers):
+            result = update_settings(threshold, int(max_speakers))
+            return result
 
+        def update_display():
+            """Continuously update the conversation display"""
+            conversation = get_conversation()
+            status = get_status()
+            return conversation, status
+
+        # Button event bindings
         init_btn.click(
+            fn=on_initialize,
+            inputs=[],
+            outputs=[status_output, start_btn, stop_btn, clear_btn]
         )
 
         start_btn.click(
+            fn=on_start,
+            inputs=[],
             outputs=[status_output, start_btn, stop_btn]
         )
 
         stop_btn.click(
+            fn=on_stop,
+            inputs=[],
             outputs=[status_output, start_btn, stop_btn]
         )
 
         clear_btn.click(
+            fn=on_clear,
+            inputs=[],
+            outputs=[status_output]
         )
 
+        update_btn.click(
+            fn=on_update_settings,
             inputs=[threshold_slider, max_speakers_slider],
             outputs=[status_output]
         )
 
+        # Auto-refresh conversation display every 500ms
+        interface.load(
+            fn=update_display,
+            inputs=[],
+            outputs=[conversation_output, status_output],
+            every=0.5
         )
 
     return interface
 
 
+# FastAPI integration for FastRTC
 def create_fastapi_app():
+    """Create FastAPI app with FastRTC integration"""
+    app = FastAPI(title="Real-time Speaker Diarization API")
 
+    @app.get("/")
+    async def root():
+        return {"message": "Real-time Speaker Diarization API"}
 
+    @app.get("/status")
+    async def api_status():
+        try:
+            if diarization_system.speaker_detector:
+                status = diarization_system.speaker_detector.get_status_info()
+                return {
+                    "initialized": True,
+                    "running": diarization_system.is_running,
+                    "current_speaker": status["current_speaker"],
+                    "active_speakers": status["active_speakers"],
+                    "max_speakers": status["max_speakers"],
+                    "last_similarity": status["last_similarity"],
+                    "threshold": status["threshold"]
+                }
+            else:
+                return {"initialized": False, "running": False}
+        except Exception as e:
+            return {"error": str(e)}
 
+    @app.get("/conversation")
     async def get_conversation_api():
         try:
             return {
+                "conversation": diarization_system.get_formatted_conversation(),
+                "sentences": len(diarization_system.full_sentences)
             }
         except Exception as e:
+            return {"error": str(e)}
 
+    @app.post("/initialize")
+    async def initialize_api():
         try:
+            result = initialize_system()
+            return {"message": result, "success": "✅" in result}
+        except Exception as e:
+            return {"error": str(e), "success": False}
+
+    @app.post("/start")
+    async def start_api():
+        try:
+            result = start_recording()
+            return {"message": result, "success": "🎙️" in result}
+        except Exception as e:
+            return {"error": str(e), "success": False}
+
+    @app.post("/stop")
+    async def stop_api():
+        try:
+            result = stop_recording()
+            return {"message": result, "success": "⏹️" in result}
+        except Exception as e:
+            return {"error": str(e), "success": False}
+
+    @app.post("/clear")
+    async def clear_api():
+        try:
+            result = clear_conversation()
+            return {"message": result, "success": True}
+        except Exception as e:
+            return {"error": str(e), "success": False}
+
+    # FastRTC stream endpoint: Stream.mount() registers the WebRTC routes on
+    # the app (the documented fastrtc pattern, rather than a raw websocket route)
+    if audio_handler:
+        stream = Stream(handler=audio_handler, modality="audio", mode="send-receive")
+        stream.mount(app)
 
     return app
 
 
+# Main execution
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Real-time Speaker Diarization System')
+    parser.add_argument('--mode', choices=['gradio', 'api', 'both'], default='gradio',
+                        help='Run mode: gradio interface, API only, or both')
+    parser.add_argument('--port', type=int, default=7860,
+                        help='Port to run on (default: 7860)')
+    parser.add_argument('--host', type=str, default='0.0.0.0',
+                        help='Host to bind to (default: 0.0.0.0)')
+
+    args = parser.parse_args()
+
+    if args.mode == 'gradio':
+        # Run Gradio interface only
+        interface = create_interface()
+        interface.launch(
+            server_name=args.host,
+            server_port=args.port,
+            share=True,
+            show_error=True
+        )
+
+    elif args.mode == 'api':
+        # Run FastAPI only
+        app = create_fastapi_app()
+        uvicorn.run(app, host=args.host, port=args.port)
+
+    elif args.mode == 'both':
+        # Run both Gradio and FastAPI; threading is already imported at module level
+
+        # Start FastAPI in a separate thread
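+        # (it listens on args.port + 1 so it does not collide with Gradio)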
+        app = create_fastapi_app()
+        api_thread = threading.Thread(
+            target=lambda: uvicorn.run(app, host=args.host, port=args.port + 1),
+            daemon=True
+        )
+        api_thread.start()
+
+        # Start Gradio interface
+        interface = create_interface()
         interface.launch(
+            server_name=args.host,
+            server_port=args.port,
             share=True,
+            show_error=True
         )
+
+
+# Additional setup for Hugging Face Spaces
+def setup_for_huggingface():
+    """Setup function specifically for Hugging Face Spaces"""
+    # Auto-initialize when running on HF Spaces
+    try:
+        if os.environ.get('SPACE_ID'):  # Running on HF Spaces
+            logger.info("Running on Hugging Face Spaces - Auto-initializing...")
+            initialize_system()
+            logger.info("System ready for Hugging Face Spaces!")
     except Exception as e:
+        logger.error(f"HF Spaces setup error: {e}")
+
+# Call setup for HF Spaces
+setup_for_huggingface()
 
+# For Hugging Face Spaces, create the interface at import time
+interface = create_interface()
 
+# Export the interface; Spaces looks for a module-level `demo` to launch
+demo = interface