Saiyaswanth007 committed
Commit f09b373 · 1 Parent(s): 27be9ef

Check point 4

Files changed (1): app.py +416 -685

app.py CHANGED
@@ -10,15 +10,17 @@ import torchaudio
 from scipy.spatial.distance import cosine
 from RealtimeSTT import AudioToTextRecorder
 from fastapi import FastAPI, APIRouter
-from fastrtc import Stream, AsyncStreamHandler, ReplyOnPause, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials
+from fastrtc import Stream, AsyncStreamHandler
 import json
-import io
-import wave
 import asyncio
 import uvicorn
-import socket
 from queue import Queue
-import time
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # Simplified configuration parameters
 SILENCE_THRESHS = [0, 0.4]
 FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
@@ -32,35 +34,31 @@ MIN_LENGTH_OF_RECORDING = 0.7
 PRE_RECORDING_BUFFER_DURATION = 0.35
 
 # Speaker change detection parameters
-DEFAULT_CHANGE_THRESHOLD = 0.7
+DEFAULT_CHANGE_THRESHOLD = 0.65
 EMBEDDING_HISTORY_SIZE = 5
-MIN_SEGMENT_DURATION = 1.0
+MIN_SEGMENT_DURATION = 1.5
 DEFAULT_MAX_SPEAKERS = 4
-ABSOLUTE_MAX_SPEAKERS = 10
+ABSOLUTE_MAX_SPEAKERS = 8
 
 # Global variables
-FAST_SENTENCE_END = True
 SAMPLE_RATE = 16000
-BUFFER_SIZE = 512
+BUFFER_SIZE = 1024
 CHANNELS = 1
 
-# Speaker colors
+# Speaker colors - more distinguishable colors
 SPEAKER_COLORS = [
-    "#FFFF00",  # Yellow
-    "#FF0000",  # Red
-    "#00FF00",  # Green
-    "#00FFFF",  # Cyan
-    "#FF00FF",  # Magenta
-    "#0000FF",  # Blue
-    "#FF8000",  # Orange
-    "#00FF80",  # Spring Green
-    "#8000FF",  # Purple
-    "#FFFFFF",  # White
+    "#FF6B6B",  # Red
+    "#4ECDC4",  # Teal
+    "#45B7D1",  # Blue
+    "#96CEB4",  # Green
+    "#FFEAA7",  # Yellow
+    "#DDA0DD",  # Plum
+    "#98D8C8",  # Mint
+    "#F7DC6F",  # Gold
 ]
 
 SPEAKER_COLOR_NAMES = [
-    "Yellow", "Red", "Green", "Cyan", "Magenta",
-    "Blue", "Orange", "Spring Green", "Purple", "White"
+    "Red", "Teal", "Blue", "Green", "Yellow", "Plum", "Mint", "Gold"
 ]
 
@@ -74,24 +72,11 @@ class SpeechBrainEncoder:
         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
         os.makedirs(self.cache_dir, exist_ok=True)
 
-    def _download_model(self):
-        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
-        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
-        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
-
-        if not os.path.exists(model_path):
-            print(f"Downloading ECAPA-TDNN model to {model_path}...")
-            urllib.request.urlretrieve(model_url, model_path)
-
-        return model_path
-
     def load_model(self):
         """Load the ECAPA-TDNN model"""
         try:
             from speechbrain.pretrained import EncoderClassifier
 
-            model_path = self._download_model()
-
             self.model = EncoderClassifier.from_hparams(
                 source="speechbrain/spkrec-ecapa-voxceleb",
                 savedir=self.cache_dir,
@@ -99,9 +84,10 @@ class SpeechBrainEncoder:
             )
 
             self.model_loaded = True
+            logger.info("ECAPA-TDNN model loaded successfully!")
             return True
         except Exception as e:
-            print(f"Error loading ECAPA-TDNN model: {e}")
+            logger.error(f"Error loading ECAPA-TDNN model: {e}")
             return False
 
     def embed_utterance(self, audio, sr=16000):
@@ -111,10 +97,15 @@
 
         try:
             if isinstance(audio, np.ndarray):
-                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
+                # Ensure audio is float32 and properly normalized
+                audio = audio.astype(np.float32)
+                if np.max(np.abs(audio)) > 1.0:
+                    audio = audio / np.max(np.abs(audio))
+                waveform = torch.tensor(audio).unsqueeze(0)
             else:
                 waveform = audio.unsqueeze(0)
 
+            # Resample if necessary
             if sr != 16000:
                 waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
 
@@ -123,7 +114,7 @@
 
             return embedding.squeeze().cpu().numpy()
         except Exception as e:
-            print(f"Error extracting embedding: {e}")
+            logger.error(f"Error extracting embedding: {e}")
             return np.zeros(self.embedding_dim)
 
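For reference, the normalize-then-resample path that the new `embed_utterance` applies to numpy input can be exercised in isolation. A minimal sketch following the hunk above (the SpeechBrain encoder call itself is omitted):

```python
import numpy as np
import torch
import torchaudio

def preprocess(audio: np.ndarray, sr: int) -> torch.Tensor:
    audio = audio.astype(np.float32)
    peak = np.max(np.abs(audio))
    if peak > 1.0:                                  # normalize only if outside [-1, 1]
        audio = audio / peak
    waveform = torch.tensor(audio).unsqueeze(0)     # shape (1, num_samples)
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
    return waveform

# 1 s of synthetic 48 kHz audio is resampled down to 16 kHz
print(preprocess(np.random.uniform(-2, 2, 48000), 48000).shape)  # torch.Size([1, 16000])
```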
@@ -131,41 +122,60 @@ class AudioProcessor:
     """Processes audio data to extract speaker embeddings"""
     def __init__(self, encoder):
         self.encoder = encoder
+        self.audio_buffer = []
+        self.min_audio_length = int(SAMPLE_RATE * 1.0)  # Minimum 1 second of audio
 
-    def extract_embedding(self, audio_int16):
-        try:
-            float_audio = audio_int16.astype(np.float32) / 32768.0
+    def add_audio_chunk(self, audio_chunk):
+        """Add audio chunk to buffer"""
+        self.audio_buffer.extend(audio_chunk)
+
+        # Keep buffer from getting too large
+        max_buffer_size = int(SAMPLE_RATE * 10)  # 10 seconds max
+        if len(self.audio_buffer) > max_buffer_size:
+            self.audio_buffer = self.audio_buffer[-max_buffer_size:]
+
+    def extract_embedding_from_buffer(self):
+        """Extract embedding from current audio buffer"""
+        if len(self.audio_buffer) < self.min_audio_length:
+            return None
 
-            if np.abs(float_audio).max() > 1.0:
-                float_audio = float_audio / np.abs(float_audio).max()
+        try:
+            # Use the last portion of the buffer for embedding
+            audio_segment = np.array(self.audio_buffer[-self.min_audio_length:], dtype=np.float32)
 
-            embedding = self.encoder.embed_utterance(float_audio)
+            # Normalize audio
+            if np.max(np.abs(audio_segment)) > 0:
+                audio_segment = audio_segment / np.max(np.abs(audio_segment))
+            else:
+                return None
 
+            embedding = self.encoder.embed_utterance(audio_segment)
             return embedding
         except Exception as e:
-            print(f"Embedding extraction error: {e}")
-            return np.zeros(self.encoder.embedding_dim)
+            logger.error(f"Embedding extraction error: {e}")
+            return None
 
 
 class SpeakerChangeDetector:
-    """Speaker change detector that supports a configurable number of speakers"""
+    """Improved speaker change detector"""
    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
-        self.previous_embeddings = []
-        self.last_change_time = time.time()
-        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
-        self.last_similarity = 0.0
+        self.speaker_centroids = [None] * self.max_speakers
+        self.last_change_time = time.time()
+        self.last_similarity = 1.0
        self.active_speakers = set([0])
+        self.segment_counter = 0
 
    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers"""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
 
        if new_max < self.max_speakers:
+            # Remove speakers beyond the new limit
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
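The new `AudioProcessor` accumulates a rolling buffer instead of embedding each finished utterance directly. A usage sketch with synthetic audio in place of a live stream (`encoder` is assumed to be an already-loaded `SpeechBrainEncoder`):

```python
import numpy as np

processor = AudioProcessor(encoder)  # encoder: loaded SpeechBrainEncoder (assumption)

# Feed ~1.3 s of 512-sample chunks; extraction needs at least 1 s buffered
for _ in range(40):
    processor.add_audio_chunk(np.random.uniform(-0.5, 0.5, 512).astype(np.float32))

embedding = processor.extract_embedding_from_buffer()  # None until >= 1 s is buffered
if embedding is not None:
    print(embedding.shape)  # (192,) for ECAPA-TDNN embeddings
```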
@@ -173,85 +183,85 @@ class SpeakerChangeDetector:
         if self.current_speaker >= new_max:
             self.current_speaker = 0
 
+        # Resize arrays
         if new_max > self.max_speakers:
-            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
             self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
+            self.speaker_centroids.extend([None] * (new_max - self.max_speakers))
         else:
-            self.mean_embeddings = self.mean_embeddings[:new_max]
             self.speaker_embeddings = self.speaker_embeddings[:new_max]
+            self.speaker_centroids = self.speaker_centroids[:new_max]
 
         self.max_speakers = new_max
 
     def set_change_threshold(self, threshold):
         """Update the threshold for detecting speaker changes"""
-        self.change_threshold = max(0.1, min(threshold, 0.99))
+        self.change_threshold = max(0.1, min(threshold, 0.95))
 
     def add_embedding(self, embedding, timestamp=None):
-        """Add a new embedding and check if there's a speaker change"""
+        """Add a new embedding and detect speaker changes"""
         current_time = timestamp or time.time()
-
-        if not self.previous_embeddings:
-            self.previous_embeddings.append(embedding)
-            self.speaker_embeddings[self.current_speaker].append(embedding)
-            if self.mean_embeddings[self.current_speaker] is None:
-                self.mean_embeddings[self.current_speaker] = embedding.copy()
-            return self.current_speaker, 1.0
-
-        current_mean = self.mean_embeddings[self.current_speaker]
-        if current_mean is not None:
-            similarity = 1.0 - cosine(embedding, current_mean)
+        self.segment_counter += 1
+
+        # Initialize first speaker
+        if not self.speaker_embeddings[0]:
+            self.speaker_embeddings[0].append(embedding)
+            self.speaker_centroids[0] = embedding.copy()
+            self.active_speakers.add(0)
+            return 0, 1.0
+
+        # Calculate similarity with current speaker
+        current_centroid = self.speaker_centroids[self.current_speaker]
+        if current_centroid is not None:
+            similarity = 1.0 - cosine(embedding, current_centroid)
         else:
-            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
+            similarity = 0.5
 
         self.last_similarity = similarity
 
+        # Check for speaker change
         time_since_last_change = current_time - self.last_change_time
-        is_speaker_change = False
+        speaker_changed = False
 
-        if time_since_last_change >= MIN_SEGMENT_DURATION:
-            if similarity < self.change_threshold:
-                best_speaker = self.current_speaker
-                best_similarity = similarity
-
-                for speaker_id in range(self.max_speakers):
-                    if speaker_id == self.current_speaker:
-                        continue
-
-                    speaker_mean = self.mean_embeddings[speaker_id]
-
-                    if speaker_mean is not None:
-                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
-
-                        if speaker_similarity > best_similarity:
-                            best_similarity = speaker_similarity
-                            best_speaker = speaker_id
-
-                if best_speaker != self.current_speaker:
-                    is_speaker_change = True
-                    self.current_speaker = best_speaker
-            elif len(self.active_speakers) < self.max_speakers:
-                for new_id in range(self.max_speakers):
-                    if new_id not in self.active_speakers:
-                        is_speaker_change = True
-                        self.current_speaker = new_id
-                        self.active_speakers.add(new_id)
-                        break
-
-        if is_speaker_change:
-            self.last_change_time = current_time
-
-        self.previous_embeddings.append(embedding)
-        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
-            self.previous_embeddings.pop(0)
-
+        if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold:
+            # Find best matching speaker
+            best_speaker = self.current_speaker
+            best_similarity = similarity
+
+            for speaker_id in self.active_speakers:
+                if speaker_id == self.current_speaker:
+                    continue
+
+                centroid = self.speaker_centroids[speaker_id]
+                if centroid is not None:
+                    speaker_similarity = 1.0 - cosine(embedding, centroid)
+                    if speaker_similarity > best_similarity and speaker_similarity > self.change_threshold:
+                        best_similarity = speaker_similarity
+                        best_speaker = speaker_id
+
+            # If no good match found and we can add a new speaker
+            if best_speaker == self.current_speaker and len(self.active_speakers) < self.max_speakers:
+                for new_id in range(self.max_speakers):
+                    if new_id not in self.active_speakers:
+                        best_speaker = new_id
+                        self.active_speakers.add(new_id)
+                        break
+
+            if best_speaker != self.current_speaker:
+                self.current_speaker = best_speaker
+                self.last_change_time = current_time
+                speaker_changed = True
+
+        # Update speaker embeddings and centroids
         self.speaker_embeddings[self.current_speaker].append(embedding)
-        self.active_speakers.add(self.current_speaker)
 
-        if len(self.speaker_embeddings[self.current_speaker]) > 30:
-            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
-
+        # Keep only recent embeddings (sliding window)
+        max_embeddings = 20
+        if len(self.speaker_embeddings[self.current_speaker]) > max_embeddings:
+            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-max_embeddings:]
+
+        # Update centroid
         if self.speaker_embeddings[self.current_speaker]:
-            self.mean_embeddings[self.current_speaker] = np.mean(
+            self.speaker_centroids[self.current_speaker] = np.mean(
                 self.speaker_embeddings[self.current_speaker], axis=0
             )
 
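The reworked detector assigns each embedding by cosine similarity against per-speaker centroids. The core decision rule can be illustrated standalone, with toy 2-D vectors in place of 192-D ECAPA embeddings:

```python
import numpy as np
from scipy.spatial.distance import cosine

centroids = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]  # two enrolled speakers
THRESHOLD = 0.65                                          # DEFAULT_CHANGE_THRESHOLD

def best_speaker(embedding):
    # Cosine similarity = 1 - cosine distance, computed against each centroid
    sims = [1.0 - cosine(embedding, c) for c in centroids]
    best = int(np.argmax(sims))
    return best, sims[best]

print(best_speaker(np.array([0.9, 0.1])))  # (0, ~0.99): stays with speaker 0
print(best_speaker(np.array([0.1, 0.8])))  # (1, ~0.99): would indicate a change
```

In the hunk above, a switch additionally requires that `MIN_SEGMENT_DURATION` has elapsed since the last change and that the winning similarity clears the threshold; otherwise a new speaker slot is opened if one is free.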
@@ -264,7 +274,7 @@
         return "#FFFFFF"
 
     def get_status_info(self):
-        """Return status information about the speaker change detector"""
+        """Return status information"""
         speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
 
         return {
@@ -273,7 +283,8 @@
             "active_speakers": len(self.active_speakers),
             "max_speakers": self.max_speakers,
             "last_similarity": self.last_similarity,
-            "threshold": self.change_threshold
+            "threshold": self.change_threshold,
+            "segment_counter": self.segment_counter
         }
 
@@ -287,19 +298,18 @@ class RealtimeSpeakerDiarization:
         self.full_sentences = []
         self.sentence_speakers = []
         self.pending_sentences = []
-        self.displayed_text = ""
-        self.last_realtime_text = ""
+        self.current_conversation = ""
         self.is_running = False
         self.change_threshold = DEFAULT_CHANGE_THRESHOLD
         self.max_speakers = DEFAULT_MAX_SPEAKERS
-        self.current_conversation = ""
-        self.audio_buffer = []
+        self.last_transcription = ""
+        self.transcription_lock = threading.Lock()
 
     def initialize_models(self):
         """Initialize the speaker encoder model"""
         try:
             device_str = "cuda" if torch.cuda.is_available() else "cpu"
-            print(f"Using device: {device_str}")
+            logger.info(f"Using device: {device_str}")
 
             self.encoder = SpeechBrainEncoder(device=device_str)
             success = self.encoder.load_model()
@@ -311,80 +321,94 @@
                     change_threshold=self.change_threshold,
                     max_speakers=self.max_speakers
                 )
-                print("ECAPA-TDNN model loaded successfully!")
+                logger.info("Models initialized successfully!")
                 return True
             else:
-                print("Failed to load ECAPA-TDNN model")
+                logger.error("Failed to load models")
                 return False
         except Exception as e:
-            print(f"Model initialization error: {e}")
+            logger.error(f"Model initialization error: {e}")
             return False
 
     def live_text_detected(self, text):
         """Callback for real-time transcription updates"""
-        text = text.strip()
-        if text:
-            sentence_delimiters = '.?!。'
-            prob_sentence_end = (
-                len(self.last_realtime_text) > 0
-                and text[-1] in sentence_delimiters
-                and self.last_realtime_text[-1] in sentence_delimiters
-            )
-
-            self.last_realtime_text = text
-
-            if prob_sentence_end and FAST_SENTENCE_END:
-                self.recorder.stop()
-            elif prob_sentence_end:
-                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
-            else:
-                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
+        with self.transcription_lock:
+            self.last_transcription = text.strip()
 
     def process_final_text(self, text):
         """Process final transcribed text with speaker embedding"""
         text = text.strip()
         if text:
             try:
-                bytes_data = self.recorder.last_transcription_bytes
-                self.sentence_queue.put((text, bytes_data))
-                self.pending_sentences.append(text)
+                # Get audio data for this transcription
+                audio_bytes = getattr(self.recorder, 'last_transcription_bytes', None)
+                if audio_bytes:
+                    self.sentence_queue.put((text, audio_bytes))
+                else:
+                    # If no audio bytes, use current speaker
+                    self.sentence_queue.put((text, None))
+
             except Exception as e:
-                print(f"Error processing final text: {e}")
+                logger.error(f"Error processing final text: {e}")
 
     def process_sentence_queue(self):
         """Process sentences in the queue for speaker detection"""
         while self.is_running:
             try:
-                text, bytes_data = self.sentence_queue.get(timeout=1)
-
-                # Convert audio data to int16
-                audio_int16 = np.frombuffer(bytes_data, dtype=np.int16)
-
-                # Extract speaker embedding
-                speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
-
-                # Store sentence and embedding
-                self.full_sentences.append((text, speaker_embedding))
-
-                # Fill in missing speaker assignments
-                while len(self.sentence_speakers) < len(self.full_sentences) - 1:
-                    self.sentence_speakers.append(0)
-
-                # Detect speaker changes
-                speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
-                self.sentence_speakers.append(speaker_id)
-
-                # Remove from pending
-                if text in self.pending_sentences:
-                    self.pending_sentences.remove(text)
-
-                # Update conversation display
-                self.current_conversation = self.get_formatted_conversation()
+                text, audio_bytes = self.sentence_queue.get(timeout=1)
+
+                current_speaker = self.speaker_detector.current_speaker
+
+                if audio_bytes:
+                    # Convert audio data and extract embedding
+                    audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
+                    audio_float = audio_int16.astype(np.float32) / 32768.0
+
+                    # Extract embedding
+                    embedding = self.audio_processor.encoder.embed_utterance(audio_float)
+                    if embedding is not None:
+                        current_speaker, similarity = self.speaker_detector.add_embedding(embedding)
+
+                # Store sentence with speaker
+                with self.transcription_lock:
+                    self.full_sentences.append((text, current_speaker))
+                    self.update_conversation_display()
 
             except queue.Empty:
                 continue
             except Exception as e:
-                print(f"Error processing sentence: {e}")
+                logger.error(f"Error processing sentence: {e}")
+
+    def update_conversation_display(self):
+        """Update the conversation display"""
+        try:
+            sentences_with_style = []
+
+            for sentence_text, speaker_id in self.full_sentences:
+                color = self.speaker_detector.get_color_for_speaker(speaker_id)
+                speaker_name = f"Speaker {speaker_id + 1}"
+                sentences_with_style.append(
+                    f'<span style="color:{color}; font-weight: bold;">{speaker_name}:</span> '
+                    f'<span style="color:#333333;">{sentence_text}</span>'
+                )
+
+            # Add current transcription if available
+            if self.last_transcription:
+                current_color = self.speaker_detector.get_color_for_speaker(self.speaker_detector.current_speaker)
+                current_speaker = f"Speaker {self.speaker_detector.current_speaker + 1}"
+                sentences_with_style.append(
+                    f'<span style="color:{current_color}; font-weight: bold; opacity: 0.7;">{current_speaker}:</span> '
+                    f'<span style="color:#666666; font-style: italic;">{self.last_transcription}...</span>'
+                )
+
+            if sentences_with_style:
+                self.current_conversation = "<br><br>".join(sentences_with_style)
+            else:
+                self.current_conversation = "<i>Waiting for speech input...</i>"
+
+        except Exception as e:
+            logger.error(f"Error updating conversation display: {e}")
+            self.current_conversation = f"<i>Error: {str(e)}</i>"
 
     def start_recording(self):
         """Start the recording and transcription process"""
@@ -392,10 +416,10 @@
             return "Please initialize models first!"
 
         try:
-            # Setup recorder configuration for manual audio input
+            # Setup recorder configuration
             recorder_config = {
                 'spinner': False,
-                'use_microphone': False,  # We'll feed audio manually
+                'use_microphone': True,  # Changed to True for direct microphone input
                 'model': FINAL_TRANSCRIPTION_MODEL,
                 'language': TRANSCRIPTION_LANGUAGE,
                 'silero_sensitivity': SILERO_SENSITIVITY,
@@ -405,29 +429,28 @@
                 'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
                 'min_gap_between_recordings': 0,
                 'enable_realtime_transcription': True,
-                'realtime_processing_pause': 0,
+                'realtime_processing_pause': 0.1,
                 'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
                 'on_realtime_transcription_update': self.live_text_detected,
                 'beam_size': FINAL_BEAM_SIZE,
                 'beam_size_realtime': REALTIME_BEAM_SIZE,
-                'buffer_size': BUFFER_SIZE,
                 'sample_rate': SAMPLE_RATE,
             }
 
             self.recorder = AudioToTextRecorder(**recorder_config)
 
-            # Start sentence processing thread
+            # Start processing threads
             self.is_running = True
             self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
             self.sentence_thread.start()
 
-            # Start transcription thread
             self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
             self.transcription_thread.start()
 
-            return "Recording started successfully! FastRTC audio input ready."
+            return "Recording started successfully!"
 
         except Exception as e:
+            logger.error(f"Error starting recording: {e}")
             return f"Error starting recording: {e}"
 
     def run_transcription(self):
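The recorder is now driven directly from the microphone rather than fed manually. A minimal standalone sketch of the same RealtimeSTT pattern, using only config keys and calls that appear in this commit (exact defaults for the omitted keys are assumptions):

```python
from RealtimeSTT import AudioToTextRecorder

recorder = AudioToTextRecorder(
    spinner=False,
    use_microphone=True,
    model="distil-large-v3",
    enable_realtime_transcription=True,
    realtime_processing_pause=0.1,
    on_realtime_transcription_update=lambda text: print("partial:", text),
)

while True:
    # Blocks until a finalized utterance is available, then invokes the callback,
    # mirroring the run_transcription() loop above
    recorder.text(lambda text: print("final:", text))
```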
@@ -436,7 +459,7 @@
             while self.is_running:
                 self.recorder.text(self.process_final_text)
         except Exception as e:
-            print(f"Transcription error: {e}")
+            logger.error(f"Transcription error: {e}")
 
     def stop_recording(self):
         """Stop the recording process"""
@@ -447,12 +470,10 @@
 
     def clear_conversation(self):
         """Clear all conversation data"""
-        self.full_sentences = []
-        self.sentence_speakers = []
-        self.pending_sentences = []
-        self.displayed_text = ""
-        self.last_realtime_text = ""
-        self.current_conversation = "Conversation cleared!"
+        with self.transcription_lock:
+            self.full_sentences = []
+            self.last_transcription = ""
+            self.current_conversation = "Conversation cleared!"
 
         if self.speaker_detector:
             self.speaker_detector = SpeakerChangeDetector(
@@ -475,36 +496,8 @@
         return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
 
     def get_formatted_conversation(self):
-        """Get the formatted conversation with speaker colors"""
-        try:
-            sentences_with_style = []
-
-            # Process completed sentences
-            for i, sentence in enumerate(self.full_sentences):
-                sentence_text, _ = sentence
-                if i >= len(self.sentence_speakers):
-                    color = "#FFFFFF"
-                    speaker_name = "Unknown"
-                else:
-                    speaker_id = self.sentence_speakers[i]
-                    color = self.speaker_detector.get_color_for_speaker(speaker_id)
-                    speaker_name = f"Speaker {speaker_id + 1}"
-
-                sentences_with_style.append(
-                    f'<span style="color:{color};"><b>{speaker_name}:</b> {sentence_text}</span>')
-
-            # Add pending sentences
-            for pending_sentence in self.pending_sentences:
-                sentences_with_style.append(
-                    f'<span style="color:#60FFFF;"><b>Processing:</b> {pending_sentence}</span>')
-
-            if sentences_with_style:
-                return "<br><br>".join(sentences_with_style)
-            else:
-                return "Waiting for speech input..."
-
-        except Exception as e:
-            return f"Error formatting conversation: {e}"
+        """Get the formatted conversation"""
+        return self.current_conversation
 
     def get_status_info(self):
         """Get current status information"""
@@ -520,689 +513,427 @@ class RealtimeSpeakerDiarization:
                 f"**Last Similarity:** {status['last_similarity']:.3f}",
                 f"**Change Threshold:** {status['threshold']:.2f}",
                 f"**Total Sentences:** {len(self.full_sentences)}",
+                f"**Segments Processed:** {status['segment_counter']}",
                 "",
-                "**Speaker Segment Counts:**"
+                "**Speaker Activity:**"
             ]
 
             for i in range(status['max_speakers']):
                 color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
-                status_lines.append(f"Speaker {i+1} ({color_name}): {status['speaker_counts'][i]}")
+                count = status['speaker_counts'][i]
+                active = "🟢" if count > 0 else "⚫"
+                status_lines.append(f"{active} Speaker {i+1} ({color_name}): {count} segments")
 
             return "\n".join(status_lines)
 
         except Exception as e:
             return f"Error getting status: {e}"
 
-    def feed_audio_data(self, audio_data):
-        """Feed audio data to the recorder"""
-        if not self.is_running or not self.recorder:
-            return
-
-        try:
-            # Ensure audio is in the correct format (16-bit PCM)
-            if isinstance(audio_data, np.ndarray):
-                if audio_data.dtype != np.int16:
-                    # Convert float to int16
-                    if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
-                        audio_data = (audio_data * 32767).astype(np.int16)
-                    else:
-                        audio_data = audio_data.astype(np.int16)
-
-                # Convert to bytes
-                audio_bytes = audio_data.tobytes()
-            else:
-                audio_bytes = audio_data
-
-            # Feed to recorder
-            self.recorder.feed_audio(audio_bytes)
-
-        except Exception as e:
-            print(f"Error feeding audio data: {e}")
-
     def process_audio_chunk(self, audio_data, sample_rate=16000):
         """Process audio chunk from FastRTC input"""
-        if not self.is_running or self.recorder is None:
+        if not self.is_running or self.audio_processor is None:
             return
 
         try:
-            # Convert float audio to int16 for the recorder
-            if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
-                if np.max(np.abs(audio_data)) <= 1.0:
-                    # Float audio is normalized to [-1, 1], convert to int16
-                    audio_int16 = (audio_data * 32767).astype(np.int16)
-                else:
-                    # Audio is already in higher range
-                    audio_int16 = audio_data.astype(np.int16)
+            # Ensure audio is float32
+            if isinstance(audio_data, np.ndarray):
+                if audio_data.dtype != np.float32:
+                    audio_data = audio_data.astype(np.float32)
             else:
-                audio_int16 = audio_data
-
-            # Ensure correct shape (1, N) for the recorder
-            if len(audio_int16.shape) == 1:
-                audio_int16 = np.expand_dims(audio_int16, 0)
-
-            # Resample if needed
-            if sample_rate != SAMPLE_RATE:
-                audio_int16 = self._resample_audio(audio_int16, sample_rate, SAMPLE_RATE)
-
-            # Convert to bytes for feeding to recorder
-            audio_bytes = audio_int16.tobytes()
-
-            # Feed to recorder
-            self.feed_audio_data(audio_bytes)
+                audio_data = np.array(audio_data, dtype=np.float32)
+
+            # Ensure mono
+            if len(audio_data.shape) > 1:
+                audio_data = np.mean(audio_data, axis=1) if audio_data.shape[1] > 1 else audio_data.flatten()
+
+            # Normalize if needed
+            if np.max(np.abs(audio_data)) > 1.0:
+                audio_data = audio_data / np.max(np.abs(audio_data))
+
+            # Add to audio processor buffer for speaker detection
+            self.audio_processor.add_audio_chunk(audio_data)
+
+            # Periodically extract embeddings for speaker detection
+            if len(self.audio_processor.audio_buffer) % (SAMPLE_RATE // 2) == 0:  # Every 0.5 seconds
+                embedding = self.audio_processor.extract_embedding_from_buffer()
+                if embedding is not None:
+                    self.speaker_detector.add_embedding(embedding)
 
-        except Exception as e:
-            print(f"Error processing audio chunk: {e}")
-
-    def _resample_audio(self, audio, orig_sr, target_sr):
-        """Resample audio to target sample rate"""
-        try:
-            import scipy.signal
-
-            # Get the resampling ratio
-            ratio = target_sr / orig_sr
-
-            # Calculate the new length
-            new_length = int(len(audio[0]) * ratio)
-
-            # Resample the audio
-            resampled = scipy.signal.resample(audio[0], new_length)
-
-            # Return in the same shape format
-            return np.expand_dims(resampled, 0)
         except Exception as e:
-            print(f"Error resampling audio: {e}")
-            return audio
-
+            logger.error(f"Error processing audio chunk: {e}")
 
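One caveat worth noting about the new periodic trigger: `len(audio_buffer) % (SAMPLE_RATE // 2) == 0` only fires when the buffer length lands exactly on a multiple of 8000 samples, which arbitrary chunk sizes may never produce. A hypothetical counter-based variant (not part of this commit) fires once per accumulated half second regardless of chunk size:

```python
SAMPLE_RATE = 16000
samples_since_embedding = 0  # would live on the instance in practice (assumption)

def should_extract(chunk_len: int) -> bool:
    """Return True roughly every 0.5 s of accumulated audio."""
    global samples_since_embedding
    samples_since_embedding += chunk_len
    if samples_since_embedding >= SAMPLE_RATE // 2:
        samples_since_embedding = 0
        return True
    return False
```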
-# FastRTC Audio Handler for Real-time Diarization
-
+# FastRTC Audio Handler
 class DiarizationHandler(AsyncStreamHandler):
     def __init__(self, diarization_system):
         super().__init__()
         self.diarization_system = diarization_system
-        self.audio_queue = Queue()
-        self.is_processing = False
-        self.sample_rate = 16000  # Default sample rate
+        self.audio_buffer = []
+        self.buffer_size = BUFFER_SIZE
 
     def copy(self):
         """Return a fresh handler for each new stream connection"""
         return DiarizationHandler(self.diarization_system)
 
     async def emit(self):
-        """Not used in this implementation - we only receive audio"""
+        """Not used - we only receive audio"""
         return None
 
     async def receive(self, frame):
-        """Receive audio data from FastRTC and process it"""
+        """Receive audio data from FastRTC"""
         try:
             if not self.diarization_system.is_running:
                 return
 
-            # Extract audio data from frame
-            if hasattr(frame, 'data') and frame.data is not None:
-                audio_data = frame.data
-            elif hasattr(frame, 'audio') and frame.audio is not None:
-                audio_data = frame.audio
-            else:
-                audio_data = frame
+            # Extract audio data
+            audio_data = getattr(frame, 'data', frame)
 
-            # Convert to numpy array if needed
+            # Convert to numpy array
             if isinstance(audio_data, bytes):
-                # Convert bytes to numpy array (assuming 16-bit PCM)
-                audio_array = np.frombuffer(audio_data, dtype=np.int16)
-                # Normalize to float32 range [-1, 1]
-                audio_array = audio_array.astype(np.float32) / 32768.0
+                audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
             elif isinstance(audio_data, (list, tuple)):
                 audio_array = np.array(audio_data, dtype=np.float32)
-            elif isinstance(audio_data, np.ndarray):
-                audio_array = audio_data.astype(np.float32)
             else:
-                print(f"Unknown audio data type: {type(audio_data)}")
-                return
-
-            # Ensure mono audio
-            if len(audio_array.shape) > 1 and audio_array.shape[1] > 1:
-                audio_array = np.mean(audio_array, axis=1)
+                audio_array = np.array(audio_data, dtype=np.float32)
 
-            # Ensure 1D array
+            # Ensure 1D
             if len(audio_array.shape) > 1:
                 audio_array = audio_array.flatten()
 
-            # Get sample rate from frame if available
-            sample_rate = getattr(frame, 'sample_rate', self.sample_rate)
+            # Buffer audio chunks
+            self.audio_buffer.extend(audio_array)
 
-            # Process audio asynchronously to avoid blocking
-            await self.process_audio_async(audio_array, sample_rate)
+            # Process in chunks
+            while len(self.audio_buffer) >= self.buffer_size:
+                chunk = np.array(self.audio_buffer[:self.buffer_size])
+                self.audio_buffer = self.audio_buffer[self.buffer_size:]
+
+                # Process asynchronously
+                await self.process_audio_async(chunk)
 
         except Exception as e:
-            print(f"Error in FastRTC audio receive: {e}")
-            import traceback
-            traceback.print_exc()
+            logger.error(f"Error in FastRTC receive: {e}")
 
-    async def process_audio_async(self, audio_data, sample_rate=16000):
+    async def process_audio_async(self, audio_data):
         """Process audio data asynchronously"""
         try:
-            # Run the audio processing in a thread pool to avoid blocking
             loop = asyncio.get_event_loop()
             await loop.run_in_executor(
                 None,
                 self.diarization_system.process_audio_chunk,
                 audio_data,
-                sample_rate
+                SAMPLE_RATE
             )
         except Exception as e:
-            print(f"Error in async audio processing: {e}")
-
-    async def start_up(self) -> None:
-        """Initialize any resources when the stream starts"""
-        print("FastRTC stream started")
-        self.is_processing = True
-
-    async def shutdown(self) -> None:
-        """Clean up any resources when the stream ends"""
-        print("FastRTC stream shutting down")
-        self.is_processing = False
 
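The handler's new `receive()` re-chunks arbitrary-sized incoming frames into fixed `BUFFER_SIZE` blocks before handing them to a thread-pool executor (which keeps the blocking embedding work off the event loop serving WebRTC frames). The buffering logic, exercised standalone:

```python
import numpy as np

BUFFER_SIZE = 1024
buffer, chunks = [], []

for frame_len in (300, 900, 2500):          # arbitrary incoming frame sizes
    buffer.extend(np.zeros(frame_len, dtype=np.float32))
    while len(buffer) >= BUFFER_SIZE:
        chunks.append(np.array(buffer[:BUFFER_SIZE]))
        buffer = buffer[BUFFER_SIZE:]

print(len(chunks), len(buffer))             # 3 chunks emitted, 628 samples held back
```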
 # Global instances
-diarization_system = None  # Will be initialized when RealtimeSpeakerDiarization is available
+diarization_system = RealtimeSpeakerDiarization()
 audio_handler = None
 
-
 def initialize_system():
     """Initialize the diarization system"""
-    global audio_handler, diarization_system
+    global audio_handler
     try:
-        if diarization_system is None:
-            print("Error: RealtimeSpeakerDiarization not initialized")
-            return "❌ Diarization system not available. Please ensure RealtimeSpeakerDiarization is properly imported."
-
         success = diarization_system.initialize_models()
         if success:
             audio_handler = DiarizationHandler(diarization_system)
-            return "✅ System initialized successfully! Models loaded and FastRTC handler ready."
+            return "✅ System initialized successfully!"
         else:
-            return "❌ Failed to initialize system. Please check the logs."
+            return "❌ Failed to initialize system. Check logs for details."
     except Exception as e:
-        print(f"Initialization error: {e}")
+        logger.error(f"Initialization error: {e}")
         return f"❌ Initialization error: {str(e)}"
 
-
 def start_recording():
     """Start recording and transcription"""
     try:
-        if diarization_system is None:
-            return "❌ System not initialized"
         result = diarization_system.start_recording()
-        return f"🎙️ {result} - FastRTC audio streaming is active."
+        return f"🎙️ {result}"
     except Exception as e:
         return f"❌ Failed to start recording: {str(e)}"
 
-
 def stop_recording():
     """Stop recording and transcription"""
     try:
-        if diarization_system is None:
-            return "❌ System not initialized"
         result = diarization_system.stop_recording()
         return f"⏹️ {result}"
     except Exception as e:
         return f"❌ Failed to stop recording: {str(e)}"
 
-
 def clear_conversation():
     """Clear the conversation"""
     try:
-        if diarization_system is None:
-            return "❌ System not initialized"
         result = diarization_system.clear_conversation()
         return f"🗑️ {result}"
     except Exception as e:
         return f"❌ Failed to clear conversation: {str(e)}"
 
-
 def update_settings(threshold, max_speakers):
     """Update system settings"""
     try:
-        if diarization_system is None:
-            return "❌ System not initialized"
         result = diarization_system.update_settings(threshold, max_speakers)
         return f"⚙️ {result}"
     except Exception as e:
         return f"❌ Failed to update settings: {str(e)}"
 
-
 def get_conversation():
     """Get the current conversation"""
     try:
-        if diarization_system is None:
-            return "<i>System not initialized</i>"
         return diarization_system.get_formatted_conversation()
     except Exception as e:
         return f"<i>Error getting conversation: {str(e)}</i>"
 
-
 def get_status():
     """Get system status"""
     try:
-        if diarization_system is None:
-            return "System not initialized"
         return diarization_system.get_status_info()
     except Exception as e:
         return f"Error getting status: {str(e)}"
 
-
 # Create Gradio interface
 def create_interface():
     with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
         gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
-        gr.Markdown("This app performs real-time speech recognition with automatic speaker identification using FastRTC for low-latency audio streaming.")
 
         with gr.Row():
             with gr.Column(scale=2):
-                # Main conversation display
                 conversation_output = gr.HTML(
-                    value="<div style='padding: 20px; background: #f5f5f5; border-radius: 10px;'><i>Click 'Initialize System' to start...</i></div>",
-                    label="Live Conversation",
-                    elem_id="conversation_display"
                 )
 
                 # Control buttons
                 with gr.Row():
                     init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
-                    start_btn = gr.Button("🎙️ Start Recording", variant="primary", size="lg", interactive=False)
-                    stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", size="lg", interactive=False)
                     clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)
 
-                # FastRTC Stream Interface
-                with gr.Row():
-                    gr.HTML("""
-                    <div id="fastrtc-container" style="border: 2px solid #ddd; border-radius: 10px; padding: 20px; margin: 10px 0;">
-                        <h3>🎵 Audio Stream</h3>
-                        <p>FastRTC audio stream will appear here when recording starts.</p>
-                        <div id="stream-status" style="padding: 10px; background: #f8f9fa; border-radius: 5px; margin-top: 10px;">
-                            Status: Waiting for initialization...
-                        </div>
-                    </div>
-                    """)
-
                 # Status display
                 status_output = gr.Textbox(
                     label="System Status",
-                    value="System not initialized. Please click 'Initialize System' to begin.",
-                    lines=6,
-                    interactive=False,
-                    show_copy_button=True
                 )
 
             with gr.Column(scale=1):
-                # Settings panel
                 gr.Markdown("## ⚙️ Settings")
 
                 threshold_slider = gr.Slider(
-                    minimum=0.1,
-                    maximum=0.95,
                     step=0.05,
-                    value=0.5,  # DEFAULT_CHANGE_THRESHOLD
                     label="Speaker Change Sensitivity",
-                    info="Lower = more sensitive to speaker changes"
                 )
 
                 max_speakers_slider = gr.Slider(
                     minimum=2,
-                    maximum=10,  # ABSOLUTE_MAX_SPEAKERS
                     step=1,
-                    value=4,  # DEFAULT_MAX_SPEAKERS
-                    label="Maximum Number of Speakers"
                 )
 
-                update_settings_btn = gr.Button("Update Settings", variant="secondary")
-
-                # Audio settings
-                gr.Markdown("## 🔊 Audio Configuration")
-                with gr.Accordion("Advanced Audio Settings", open=False):
-                    gr.Markdown("""
-                    **Current Configuration:**
-                    - Sample Rate: 16kHz
-                    - Audio Format: 16-bit PCM → Float32 (via AudioProcessor)
-                    - Channels: Mono (stereo converted automatically)
-                    - Buffer Size: 1024 samples for real-time processing
-                    - Processing: Uses existing AudioProcessor.extract_embedding()
-                    """)
 
                 # Instructions
-                gr.Markdown("## 📝 How to Use")
                 gr.Markdown("""
-                1. **Initialize**: Click "Initialize System" to load AI models
-                2. **Start**: Click "Start Recording" to begin processing
-                3. **Connect**: The FastRTC stream will activate automatically
-                4. **Allow Access**: Grant microphone permissions when prompted
-                5. **Speak**: Talk naturally into your microphone
-                6. **Monitor**: Watch real-time transcription with speaker colors
-                """)
-
-                # Performance tips
-                with gr.Accordion("💡 Performance Tips", open=False):
-                    gr.Markdown("""
-                    - Use Chrome/Edge for best FastRTC performance
-                    - Ensure stable internet connection
-                    - Use headphones to prevent echo
-                    - Position microphone 6-12 inches away
-                    - Minimize background noise
-                    - Allow browser microphone access
-                    """)
 
-                # Speaker color legend
-                gr.Markdown("## 🎨 Speaker Colors")
-                speaker_colors = [
-                    ("#FF6B6B", "Red"),
-                    ("#4ECDC4", "Teal"),
-                    ("#45B7D1", "Blue"),
-                    ("#96CEB4", "Green"),
-                    ("#FFEAA7", "Yellow"),
-                    ("#DDA0DD", "Plum"),
-                    ("#98D8C8", "Mint"),
-                    ("#F7DC6F", "Gold")
-                ]
-
-                color_html = ""
-                for i, (color, name) in enumerate(speaker_colors[:4]):
-                    color_html += f'<div style="margin: 3px 0;"><span style="color:{color}; font-size: 16px; font-weight: bold;">●</span> Speaker {i+1} ({name})</div>'
-
-                gr.HTML(f"<div style='font-size: 14px;'>{color_html}</div>")
-
-        # Auto-refresh conversation and status
-        def refresh_display():
-            try:
-                conversation = get_conversation()
-                status = get_status()
-                return conversation, status
-            except Exception as e:
-                error_msg = f"Error refreshing display: {str(e)}"
-                return f"<i>{error_msg}</i>", error_msg
 
         # Event handlers
         def on_initialize():
-            try:
-                result = initialize_system()
-                success = "successfully" in result.lower()
-
-                conversation, status = refresh_display()
-
-                return (
-                    result,  # status_output
-                    gr.update(interactive=success),  # start_btn
-                    gr.update(interactive=success),  # clear_btn
-                    conversation,  # conversation_output
-                )
-            except Exception as e:
-                error_msg = f"❌ Initialization failed: {str(e)}"
-                return (
-                    error_msg,
-                    gr.update(interactive=False),
-                    gr.update(interactive=False),
-                    "<i>System not ready</i>",
-                )
 
         def on_start():
-            try:
-                result = start_recording()
-                return (
-                    result,  # status_output
-                    gr.update(interactive=False),  # start_btn
-                    gr.update(interactive=True),  # stop_btn
-                )
-            except Exception as e:
-                error_msg = f"❌ Failed to start: {str(e)}"
-                return (
-                    error_msg,
-                    gr.update(interactive=True),
-                    gr.update(interactive=False),
-                )
 
         def on_stop():
-            try:
-                result = stop_recording()
-                return (
-                    result,  # status_output
-                    gr.update(interactive=True),  # start_btn
-                    gr.update(interactive=False),  # stop_btn
-                )
-            except Exception as e:
-                error_msg = f"❌ Failed to stop: {str(e)}"
-                return (
-                    error_msg,
-                    gr.update(interactive=False),
-                    gr.update(interactive=True),
-                )
 
         def on_clear():
-            try:
-                result = clear_conversation()
-                conversation, status = refresh_display()
-                return result, conversation
-            except Exception as e:
-                error_msg = f"❌ Failed to clear: {str(e)}"
-                return error_msg, "<i>Error clearing conversation</i>"
 
         def on_update_settings(threshold, max_speakers):
-            try:
-                result = update_settings(threshold, max_speakers)
-                return result
-            except Exception as e:
-                return f"❌ Failed to update settings: {str(e)}"
 
-        # Connect event handlers
         init_btn.click(
-            on_initialize,
-            outputs=[status_output, start_btn, clear_btn, conversation_output]
         )
 
         start_btn.click(
-            on_start,
             outputs=[status_output, start_btn, stop_btn]
         )
 
         stop_btn.click(
-            on_stop,
             outputs=[status_output, start_btn, stop_btn]
         )
 
         clear_btn.click(
-            on_clear,
-            outputs=[status_output, conversation_output]
         )
 
-        update_settings_btn.click(
-            on_update_settings,
             inputs=[threshold_slider, max_speakers_slider],
             outputs=[status_output]
         )
 
-        # Auto-refresh every 2 seconds when active
-        refresh_timer = gr.Timer(2.0)
-        refresh_timer.tick(
-            refresh_display,
-            outputs=[conversation_output, status_output]
-        )
 
     return interface
 
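The removed auto-refresh wiring relied on `gr.Timer`. For context, the same polling pattern in a self-contained Blocks app (hypothetical names, for illustration only):

```python
import time
import gradio as gr

with gr.Blocks() as demo:
    html = gr.HTML("<i>waiting...</i>")
    timer = gr.Timer(2.0)  # fires every 2 seconds
    timer.tick(lambda: f"<b>updated at {time.strftime('%H:%M:%S')}</b>", outputs=[html])

demo.launch()
```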
-# FastAPI setup for API endpoints
-def create_fastapi_app():
-    """Create FastAPI app with API endpoints"""
-    app = FastAPI(
-        title="Real-time Speaker Diarization",
-        description="Real-time speech recognition with speaker diarization using FastRTC",
-        version="1.0.0"
-    )
-
-    # API Routes
-    router = APIRouter()
-
-    @router.get("/health")
-    async def health_check():
-        """Health check endpoint"""
-        return {
-            "status": "healthy",
-            "timestamp": time.time(),
-            "system_initialized": diarization_system is not None and hasattr(diarization_system, 'encoder') and diarization_system.encoder is not None,
-            "recording_active": diarization_system.is_running if diarization_system and hasattr(diarization_system, 'is_running') else False
-        }
-
-    @router.get("/api/conversation")
-    async def get_conversation_api():
-        """Get current conversation"""
-        try:
-            return {
-                "conversation": get_conversation(),
-                "status": get_status(),
-                "is_recording": diarization_system.is_running if diarization_system and hasattr(diarization_system, 'is_running') else False,
-                "timestamp": time.time()
-            }
-        except Exception as e:
-            return {"error": str(e), "timestamp": time.time()}
-
-    @router.post("/api/control/{action}")
-    async def control_recording(action: str):
-        """Control recording actions"""
-        try:
-            if action == "start":
-                result = start_recording()
-            elif action == "stop":
-                result = stop_recording()
-            elif action == "clear":
-                result = clear_conversation()
-            elif action == "initialize":
-                result = initialize_system()
-            else:
-                return {"error": "Invalid action. Use: start, stop, clear, or initialize"}
-
-            return {
-                "result": result,
-                "is_recording": diarization_system.is_running if diarization_system and hasattr(diarization_system, 'is_running') else False,
-                "timestamp": time.time()
-            }
-        except Exception as e:
-            return {"error": str(e), "timestamp": time.time()}
-
-    app.include_router(router)
-    return app
 
 
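For reference, the REST endpoints deleted here could be exercised with a plain HTTP client; a sketch assuming the app was served on localhost:7860:

```python
import requests

base = "http://localhost:7860"
print(requests.get(f"{base}/health").json())              # liveness + recording state
print(requests.post(f"{base}/api/control/start").json())  # also: stop, clear, initialize
print(requests.get(f"{base}/api/conversation").json())    # rendered conversation + status
```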
-# Function to setup FastRTC stream
-def setup_fastrtc_stream(app):
-    """Setup FastRTC stream with proper configuration"""
-    try:
-        if audio_handler is None:
-            print("Warning: Audio handler not initialized. Initialize system first.")
-            return None
-
-        # Get HuggingFace token for TURN server (optional)
-        hf_token = os.environ.get("HF_TOKEN")
-
-        # Configure RTC settings
-        rtc_config = {
-            "iceServers": [
-                {"urls": "stun:stun.l.google.com:19302"},
-                {"urls": "stun:stun1.l.google.com:19302"}
-            ]
-        }
-
-        # Create FastRTC stream
-        stream = Stream(
-            handler=audio_handler,
-            rtc_configuration=rtc_config,
-            modality="audio",
-            mode="receive"  # We only receive audio, don't send
-        )
-
-        # Mount the stream
-        app.mount("/stream", stream)
-        print("✅ FastRTC stream configured successfully!")
-        return stream
-
-    except Exception as e:
-        print(f"⚠️ Warning: Failed to setup FastRTC stream: {e}")
-        print("Audio streaming may not work properly.")
-        return None
 
 
-# Main application setup
-def create_app(diarization_sys=None):
-    """Create the complete application"""
-    global diarization_system
-
-    # Set the diarization system
-    if diarization_sys is not None:
-        diarization_system = diarization_sys
-
-    # Create FastAPI app
-    fastapi_app = create_fastapi_app()
-
-    # Create Gradio interface
-    gradio_interface = create_interface()
-
-    # Mount Gradio on FastAPI
-    app = gr.mount_gradio_app(fastapi_app, gradio_interface, path="/")
-
-    # Setup FastRTC stream
-    if diarization_system is not None:
-        # Initialize the system if not already done
-        if not hasattr(diarization_system, 'encoder') or diarization_system.encoder is None:
-            diarization_system.initialize_models()
-
-        # Create audio handler if needed
-        global audio_handler
-        if audio_handler is None:
-            audio_handler = DiarizationHandler(diarization_system)
-
-        # Setup and mount the FastRTC stream
-        setup_fastrtc_stream(app)
-
-    return app, gradio_interface
 
 
-# Entry point for HuggingFace Spaces
 if __name__ == "__main__":
-    try:
-        # Import your diarization system here
-        # from your_module import RealtimeSpeakerDiarization
-        diarization_system = RealtimeSpeakerDiarization()
-
-        # Create the application
-        app, interface = create_app()
-
-        # Launch for HuggingFace Spaces
         interface.launch(
-            server_name="0.0.0.0",
-            server_port=7860,
             share=True,
-            show_error=True,
-            quiet=False
         )
 
-    except Exception as e:
-        print(f"Failed to launch application: {e}")
-        import traceback
-        traceback.print_exc()
-
-        # Fallback - launch just Gradio interface
-        try:
             interface = create_interface()
             interface.launch(
-                server_name="0.0.0.0",
-                server_port=int(os.environ.get("PORT", 7860)),
-                share=False
             )
-        except Exception as fallback_error:
-            print(f"Fallback launch also failed: {fallback_error}")
-
-
-# Helper function to initialize with your diarization system
-def initialize_with_diarization_system(diarization_sys):
-    """Initialize the application with your diarization system"""
-    global diarization_system
-    diarization_system = diarization_sys
-    return create_app(diarization_sys)
10
  from scipy.spatial.distance import cosine
11
  from RealtimeSTT import AudioToTextRecorder
12
  from fastapi import FastAPI, APIRouter
13
+ from fastrtc import Stream, AsyncStreamHandler
14
  import json
 
 
15
  import asyncio
16
  import uvicorn
 
17
  from queue import Queue
18
+ import logging
19
+
20
+ # Set up logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
  # Simplified configuration parameters
25
  SILENCE_THRESHS = [0, 0.4]
26
  FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
 
34
  PRE_RECORDING_BUFFER_DURATION = 0.35
35
 
36
  # Speaker change detection parameters
37
+ DEFAULT_CHANGE_THRESHOLD = 0.65
38
  EMBEDDING_HISTORY_SIZE = 5
39
+ MIN_SEGMENT_DURATION = 1.5
40
  DEFAULT_MAX_SPEAKERS = 4
41
+ ABSOLUTE_MAX_SPEAKERS = 8
42
 
43
  # Global variables
 
44
  SAMPLE_RATE = 16000
45
+ BUFFER_SIZE = 1024
46
  CHANNELS = 1
47
 
48
+ # Speaker colors - more distinguishable colors
49
  SPEAKER_COLORS = [
50
+ "#FF6B6B", # Red
51
+ "#4ECDC4", # Teal
52
+ "#45B7D1", # Blue
53
+ "#96CEB4", # Green
54
+ "#FFEAA7", # Yellow
55
+ "#DDA0DD", # Plum
56
+ "#98D8C8", # Mint
57
+ "#F7DC6F", # Gold
 
 
58
  ]
59
 
60
  SPEAKER_COLOR_NAMES = [
61
+ "Red", "Teal", "Blue", "Green", "Yellow", "Plum", "Mint", "Gold"
 
62
  ]
63
 
64
 
 
72
  self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
73
  os.makedirs(self.cache_dir, exist_ok=True)
74
 
 
 
 
 
 
 
 
 
 
 
 
75
  def load_model(self):
76
  """Load the ECAPA-TDNN model"""
77
  try:
78
  from speechbrain.pretrained import EncoderClassifier
79
 
 
 
80
  self.model = EncoderClassifier.from_hparams(
81
  source="speechbrain/spkrec-ecapa-voxceleb",
82
  savedir=self.cache_dir,
 
84
  )
85
 
86
  self.model_loaded = True
87
+ logger.info("ECAPA-TDNN model loaded successfully!")
88
  return True
89
  except Exception as e:
90
+ logger.error(f"Error loading ECAPA-TDNN model: {e}")
91
  return False
92
 
93
  def embed_utterance(self, audio, sr=16000):
 
97
 
98
  try:
99
  if isinstance(audio, np.ndarray):
100
+ # Ensure audio is float32 and properly normalized
101
+ audio = audio.astype(np.float32)
102
+ if np.max(np.abs(audio)) > 1.0:
103
+ audio = audio / np.max(np.abs(audio))
104
+ waveform = torch.tensor(audio).unsqueeze(0)
105
  else:
106
  waveform = audio.unsqueeze(0)
107
 
108
+ # Resample if necessary
109
  if sr != 16000:
110
  waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
111
 
 
114
 
115
  return embedding.squeeze().cpu().numpy()
116
  except Exception as e:
117
+ logger.error(f"Error extracting embedding: {e}")
118
  return np.zeros(self.embedding_dim)
119
 
120
 
 
122
  """Processes audio data to extract speaker embeddings"""
123
  def __init__(self, encoder):
124
  self.encoder = encoder
125
+ self.audio_buffer = []
126
+ self.min_audio_length = int(SAMPLE_RATE * 1.0) # Minimum 1 second of audio
127
 
128
+ def add_audio_chunk(self, audio_chunk):
129
+ """Add audio chunk to buffer"""
130
+ self.audio_buffer.extend(audio_chunk)
131
+
132
+ # Keep buffer from getting too large
133
+ max_buffer_size = int(SAMPLE_RATE * 10) # 10 seconds max
134
+ if len(self.audio_buffer) > max_buffer_size:
135
+ self.audio_buffer = self.audio_buffer[-max_buffer_size:]
136
+
137
+ def extract_embedding_from_buffer(self):
138
+ """Extract embedding from current audio buffer"""
139
+ if len(self.audio_buffer) < self.min_audio_length:
140
+ return None
141
 
142
+ try:
143
+ # Use the last portion of the buffer for embedding
144
+ audio_segment = np.array(self.audio_buffer[-self.min_audio_length:], dtype=np.float32)
145
 
146
+ # Normalize audio
147
+ if np.max(np.abs(audio_segment)) > 0:
148
+ audio_segment = audio_segment / np.max(np.abs(audio_segment))
149
+ else:
150
+ return None
151
 
152
+ embedding = self.encoder.embed_utterance(audio_segment)
153
  return embedding
154
  except Exception as e:
155
+ logger.error(f"Embedding extraction error: {e}")
156
+ return None
157
 
158
 
159
  class SpeakerChangeDetector:
160
+ """Improved speaker change detector"""
161
  def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
162
  self.embedding_dim = embedding_dim
163
  self.change_threshold = change_threshold
164
  self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
165
  self.current_speaker = 0
 
 
 
166
  self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
167
+ self.speaker_centroids = [None] * self.max_speakers
168
+ self.last_change_time = time.time()
169
+ self.last_similarity = 1.0
170
  self.active_speakers = set([0])
171
+ self.segment_counter = 0
172
 
173
  def set_max_speakers(self, max_speakers):
174
  """Update the maximum number of speakers"""
175
  new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
176
 
177
  if new_max < self.max_speakers:
178
+ # Remove speakers beyond the new limit
179
  for speaker_id in list(self.active_speakers):
180
  if speaker_id >= new_max:
181
  self.active_speakers.discard(speaker_id)
 
183
  if self.current_speaker >= new_max:
184
  self.current_speaker = 0
185
 
186
+ # Resize arrays
187
  if new_max > self.max_speakers:
 
188
  self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
189
+ self.speaker_centroids.extend([None] * (new_max - self.max_speakers))
190
  else:
 
191
  self.speaker_embeddings = self.speaker_embeddings[:new_max]
192
+ self.speaker_centroids = self.speaker_centroids[:new_max]
193
 
194
  self.max_speakers = new_max
195
 
196
  def set_change_threshold(self, threshold):
197
  """Update the threshold for detecting speaker changes"""
198
+ self.change_threshold = max(0.1, min(threshold, 0.95))
199
 
200
  def add_embedding(self, embedding, timestamp=None):
201
+ """Add a new embedding and detect speaker changes"""
202
  current_time = timestamp or time.time()
203
+ self.segment_counter += 1
204
+
205
+ # Initialize first speaker
206
+ if not self.speaker_embeddings[0]:
207
+ self.speaker_embeddings[0].append(embedding)
208
+ self.speaker_centroids[0] = embedding.copy()
209
+ self.active_speakers.add(0)
210
+ return 0, 1.0
211
+
212
+ # Calculate similarity with current speaker
213
+ current_centroid = self.speaker_centroids[self.current_speaker]
214
+ if current_centroid is not None:
215
+ similarity = 1.0 - cosine(embedding, current_centroid)
216
  else:
217
+ similarity = 0.5
218
 
219
  self.last_similarity = similarity
220
 
221
+ # Check for speaker change
222
  time_since_last_change = current_time - self.last_change_time
223
+ speaker_changed = False
224
 
225
+ if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold:
226
+ # Find best matching speaker
227
+ best_speaker = self.current_speaker
228
+ best_similarity = similarity
229
+
230
+ for speaker_id in self.active_speakers:
231
+ if speaker_id == self.current_speaker:
232
+ continue
 
 
233
 
234
+ centroid = self.speaker_centroids[speaker_id]
235
+ if centroid is not None:
236
+ speaker_similarity = 1.0 - cosine(embedding, centroid)
237
+ if speaker_similarity > best_similarity and speaker_similarity > self.change_threshold:
238
+ best_similarity = speaker_similarity
239
+ best_speaker = speaker_id
240
+
241
+ # If no good match found and we can add a new speaker
242
+ if best_speaker == self.current_speaker and len(self.active_speakers) < self.max_speakers:
243
+ for new_id in range(self.max_speakers):
244
+ if new_id not in self.active_speakers:
245
+ best_speaker = new_id
246
+ self.active_speakers.add(new_id)
247
+ break
248
+
249
+ if best_speaker != self.current_speaker:
250
+ self.current_speaker = best_speaker
251
+ self.last_change_time = current_time
252
+ speaker_changed = True
253
+
254
+ # Update speaker embeddings and centroids
 
 
 
 
255
  self.speaker_embeddings[self.current_speaker].append(embedding)
 
256
 
257
+ # Keep only recent embeddings (sliding window)
258
+ max_embeddings = 20
259
+ if len(self.speaker_embeddings[self.current_speaker]) > max_embeddings:
260
+ self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-max_embeddings:]
261
+
262
+ # Update centroid
263
  if self.speaker_embeddings[self.current_speaker]:
264
+ self.speaker_centroids[self.current_speaker] = np.mean(
265
  self.speaker_embeddings[self.current_speaker], axis=0
266
  )
267
 
 
274
  return "#FFFFFF"
275
 
276
  def get_status_info(self):
277
+ """Return status information"""
278
  speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
279
 
280
  return {
 
283
  "active_speakers": len(self.active_speakers),
284
  "max_speakers": self.max_speakers,
285
  "last_similarity": self.last_similarity,
286
+ "threshold": self.change_threshold,
287
+ "segment_counter": self.segment_counter
288
  }
289
 
290
 
 
298
  self.full_sentences = []
299
  self.sentence_speakers = []
300
  self.pending_sentences = []
301
+ self.current_conversation = ""
 
302
  self.is_running = False
303
  self.change_threshold = DEFAULT_CHANGE_THRESHOLD
304
  self.max_speakers = DEFAULT_MAX_SPEAKERS
305
+ self.last_transcription = ""
306
+ self.transcription_lock = threading.Lock()
307
 
308
  def initialize_models(self):
309
  """Initialize the speaker encoder model"""
310
  try:
311
  device_str = "cuda" if torch.cuda.is_available() else "cpu"
312
+ logger.info(f"Using device: {device_str}")
313
 
314
  self.encoder = SpeechBrainEncoder(device=device_str)
315
  success = self.encoder.load_model()
 
321
  change_threshold=self.change_threshold,
322
  max_speakers=self.max_speakers
323
  )
324
+ logger.info("Models initialized successfully!")
325
  return True
326
  else:
327
+ logger.error("Failed to load models")
328
  return False
329
  except Exception as e:
330
+ logger.error(f"Model initialization error: {e}")
331
  return False
332
 
333
  def live_text_detected(self, text):
334
  """Callback for real-time transcription updates"""
335
+ with self.transcription_lock:
336
+ self.last_transcription = text.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
  def process_final_text(self, text):
339
  """Process final transcribed text with speaker embedding"""
340
  text = text.strip()
341
  if text:
342
  try:
343
+ # Get audio data for this transcription
344
+ audio_bytes = getattr(self.recorder, 'last_transcription_bytes', None)
345
+ if audio_bytes:
346
+ self.sentence_queue.put((text, audio_bytes))
347
+ else:
348
+ # If no audio bytes, use current speaker
349
+ self.sentence_queue.put((text, None))
350
+
351
  except Exception as e:
352
+ logger.error(f"Error processing final text: {e}")
353
 
354
  def process_sentence_queue(self):
355
  """Process sentences in the queue for speaker detection"""
356
  while self.is_running:
357
  try:
358
+ text, audio_bytes = self.sentence_queue.get(timeout=1)
 
 
 
 
 
 
359
 
360
+ current_speaker = self.speaker_detector.current_speaker
 
361
 
362
+ if audio_bytes:
363
+ # Convert audio data and extract embedding
364
+ audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
365
+ audio_float = audio_int16.astype(np.float32) / 32768.0
366
+
367
+ # Extract embedding
368
+ embedding = self.audio_processor.encoder.embed_utterance(audio_float)
369
+ if embedding is not None:
370
+ current_speaker, similarity = self.speaker_detector.add_embedding(embedding)
 
 
371
 
372
+ # Store sentence with speaker
373
+ with self.transcription_lock:
374
+ self.full_sentences.append((text, current_speaker))
375
+ self.update_conversation_display()
376
 
377
  except queue.Empty:
378
  continue
379
  except Exception as e:
380
+ logger.error(f"Error processing sentence: {e}")
381
+
382
+ def update_conversation_display(self):
383
+ """Update the conversation display"""
384
+ try:
385
+ sentences_with_style = []
386
+
387
+ for sentence_text, speaker_id in self.full_sentences:
388
+ color = self.speaker_detector.get_color_for_speaker(speaker_id)
389
+ speaker_name = f"Speaker {speaker_id + 1}"
390
+ sentences_with_style.append(
391
+ f'<span style="color:{color}; font-weight: bold;">{speaker_name}:</span> '
392
+ f'<span style="color:#333333;">{sentence_text}</span>'
393
+ )
394
+
395
+ # Add current transcription if available
396
+ if self.last_transcription:
397
+ current_color = self.speaker_detector.get_color_for_speaker(self.speaker_detector.current_speaker)
398
+ current_speaker = f"Speaker {self.speaker_detector.current_speaker + 1}"
399
+ sentences_with_style.append(
400
+ f'<span style="color:{current_color}; font-weight: bold; opacity: 0.7;">{current_speaker}:</span> '
401
+ f'<span style="color:#666666; font-style: italic;">{self.last_transcription}...</span>'
402
+ )
403
+
404
+ if sentences_with_style:
405
+ self.current_conversation = "<br><br>".join(sentences_with_style)
406
+ else:
407
+ self.current_conversation = "<i>Waiting for speech input...</i>"
408
+
409
+ except Exception as e:
410
+ logger.error(f"Error updating conversation display: {e}")
411
+ self.current_conversation = f"<i>Error: {str(e)}</i>"
412
 
413
  def start_recording(self):
414
  """Start the recording and transcription process"""
 
416
  return "Please initialize models first!"
417
 
418
  try:
419
+ # Setup recorder configuration
420
  recorder_config = {
421
  'spinner': False,
422
+ 'use_microphone': True, # Changed to True for direct microphone input
423
  'model': FINAL_TRANSCRIPTION_MODEL,
424
  'language': TRANSCRIPTION_LANGUAGE,
425
  'silero_sensitivity': SILERO_SENSITIVITY,
 
429
  'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
430
  'min_gap_between_recordings': 0,
431
  'enable_realtime_transcription': True,
432
+ 'realtime_processing_pause': 0.1,
433
  'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
434
  'on_realtime_transcription_update': self.live_text_detected,
435
  'beam_size': FINAL_BEAM_SIZE,
436
  'beam_size_realtime': REALTIME_BEAM_SIZE,
 
437
  'sample_rate': SAMPLE_RATE,
438
  }
439
 
440
  self.recorder = AudioToTextRecorder(**recorder_config)
441
 
442
+ # Start processing threads
443
  self.is_running = True
444
  self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
445
  self.sentence_thread.start()
446
 
 
447
  self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
448
  self.transcription_thread.start()
449
 
450
+ return "Recording started successfully!"
451
 
452
  except Exception as e:
453
+ logger.error(f"Error starting recording: {e}")
454
  return f"Error starting recording: {e}"
455
 
456
  def run_transcription(self):
 
459
  while self.is_running:
460
  self.recorder.text(self.process_final_text)
461
  except Exception as e:
462
+ logger.error(f"Transcription error: {e}")
463
 
464
  def stop_recording(self):
465
  """Stop the recording process"""
 
470
 
471
  def clear_conversation(self):
472
  """Clear all conversation data"""
473
+ with self.transcription_lock:
474
+ self.full_sentences = []
475
+ self.last_transcription = ""
476
+ self.current_conversation = "Conversation cleared!"
 
 
477
 
478
  if self.speaker_detector:
479
  self.speaker_detector = SpeakerChangeDetector(
 
496
  return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
497
 
498
  def get_formatted_conversation(self):
499
+ """Get the formatted conversation"""
500
+ return self.current_conversation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
 
502
  def get_status_info(self):
503
  """Get current status information"""
 
513
  f"**Last Similarity:** {status['last_similarity']:.3f}",
514
  f"**Change Threshold:** {status['threshold']:.2f}",
515
  f"**Total Sentences:** {len(self.full_sentences)}",
516
+ f"**Segments Processed:** {status['segment_counter']}",
517
  "",
518
+ "**Speaker Activity:**"
519
  ]
520
 
521
  for i in range(status['max_speakers']):
522
  color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
523
+ count = status['speaker_counts'][i]
524
+ active = "🟢" if count > 0 else "⚫"
525
+ status_lines.append(f"{active} Speaker {i+1} ({color_name}): {count} segments")
526
 
527
  return "\n".join(status_lines)
528
 
529
  except Exception as e:
530
  return f"Error getting status: {e}"
531
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532

      def process_audio_chunk(self, audio_data, sample_rate=16000):
          """Process audio chunk from FastRTC input"""
+         if not self.is_running or self.audio_processor is None:
              return

          try:
+             # Ensure audio is float32
+             if isinstance(audio_data, np.ndarray):
+                 if audio_data.dtype != np.float32:
+                     audio_data = audio_data.astype(np.float32)
              else:
+                 audio_data = np.array(audio_data, dtype=np.float32)

+             # Ensure mono
+             if len(audio_data.shape) > 1:
+                 audio_data = np.mean(audio_data, axis=1) if audio_data.shape[1] > 1 else audio_data.flatten()

+             # Normalize if needed
+             if np.max(np.abs(audio_data)) > 1.0:
+                 audio_data = audio_data / np.max(np.abs(audio_data))

+             # Add to audio processor buffer for speaker detection
+             self.audio_processor.add_audio_chunk(audio_data)

+             # Extract an embedding roughly every 0.5 s of audio. A modulo test
+             # on the buffer length only fires when chunk sizes happen to divide
+             # SAMPLE_RATE // 2 evenly, so count incoming samples explicitly.
+             self._samples_since_embedding = getattr(self, "_samples_since_embedding", 0) + len(audio_data)
+             if self._samples_since_embedding >= SAMPLE_RATE // 2:
+                 self._samples_since_embedding = 0
+                 embedding = self.audio_processor.extract_embedding_from_buffer()
+                 if embedding is not None:
+                     self.speaker_detector.add_embedding(embedding)

          except Exception as e:
+             logger.error(f"Error processing audio chunk: {e}")
 
 
+ # FastRTC Audio Handler
  class DiarizationHandler(AsyncStreamHandler):
      def __init__(self, diarization_system):
          super().__init__()
          self.diarization_system = diarization_system
+         self.audio_buffer = []
+         self.buffer_size = BUFFER_SIZE
 
      def copy(self):
          """Return a fresh handler for each new stream connection"""
          return DiarizationHandler(self.diarization_system)

      async def emit(self):
+         """Not used - we only receive audio"""
          return None

      async def receive(self, frame):
+         """Receive audio data from FastRTC"""
          try:
              if not self.diarization_system.is_running:
                  return

+             # Extract audio data
+             audio_data = getattr(frame, 'data', frame)

+             # Convert to a numpy array
              if isinstance(audio_data, bytes):
+                 audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
              elif isinstance(audio_data, (list, tuple)):
                  audio_array = np.array(audio_data, dtype=np.float32)
              else:
+                 audio_array = np.array(audio_data, dtype=np.float32)

+             # Ensure 1D
              if len(audio_array.shape) > 1:
                  audio_array = audio_array.flatten()

+             # Buffer audio samples
+             self.audio_buffer.extend(audio_array)

+             # Forward fixed-size chunks downstream
+             while len(self.audio_buffer) >= self.buffer_size:
+                 chunk = np.array(self.audio_buffer[:self.buffer_size])
+                 self.audio_buffer = self.audio_buffer[self.buffer_size:]
+
+                 # Process asynchronously
+                 await self.process_audio_async(chunk)

          except Exception as e:
+             logger.error(f"Error in FastRTC receive: {e}")
 
 
      async def process_audio_async(self, audio_data):
          """Process audio data asynchronously"""
          try:
              # Offload the blocking work to the default executor so the event
              # loop keeps servicing incoming frames (get_running_loop is the
              # non-deprecated call inside a coroutine).
              loop = asyncio.get_running_loop()
              await loop.run_in_executor(
                  None,
                  self.diarization_system.process_audio_chunk,
                  audio_data,
+                 SAMPLE_RATE
              )
          except Exception as e:
+             logger.error(f"Error in async audio processing: {e}")


  # Global instances
+ diarization_system = RealtimeSpeakerDiarization()
  audio_handler = None


  def initialize_system():
      """Initialize the diarization system"""
+     global audio_handler
      try:
          success = diarization_system.initialize_models()
          if success:
              audio_handler = DiarizationHandler(diarization_system)
+             return "✅ System initialized successfully!"
          else:
+             return "❌ Failed to initialize system. Check logs for details."
      except Exception as e:
+         logger.error(f"Initialization error: {e}")
          return f"❌ Initialization error: {str(e)}"


  def start_recording():
      """Start recording and transcription"""
      try:
          result = diarization_system.start_recording()
+         return f"🎙️ {result}"
      except Exception as e:
          return f"❌ Failed to start recording: {str(e)}"


  def stop_recording():
      """Stop recording and transcription"""
      try:
          result = diarization_system.stop_recording()
          return f"⏹️ {result}"
      except Exception as e:
          return f"❌ Failed to stop recording: {str(e)}"


  def clear_conversation():
      """Clear the conversation"""
      try:
          result = diarization_system.clear_conversation()
          return f"🗑️ {result}"
      except Exception as e:
          return f"❌ Failed to clear conversation: {str(e)}"


  def update_settings(threshold, max_speakers):
      """Update system settings"""
      try:
          result = diarization_system.update_settings(threshold, max_speakers)
          return f"⚙️ {result}"
      except Exception as e:
          return f"❌ Failed to update settings: {str(e)}"


  def get_conversation():
      """Get the current conversation"""
      try:
          return diarization_system.get_formatted_conversation()
      except Exception as e:
          return f"<i>Error getting conversation: {str(e)}</i>"


  def get_status():
      """Get system status"""
      try:
          return diarization_system.get_status_info()
      except Exception as e:
          return f"Error getting status: {str(e)}"


  # Create Gradio interface
  def create_interface():
      with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
          gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
+         gr.Markdown("Live transcription with automatic speaker identification using FastRTC audio streaming.")

          with gr.Row():
              with gr.Column(scale=2):
+                 # Conversation display
                  conversation_output = gr.HTML(
+                     value="<div style='padding: 20px; background: #f8f9fa; border-radius: 10px; min-height: 300px;'><i>Click 'Initialize System' to start...</i></div>",
+                     label="Live Conversation"
                  )

                  # Control buttons
                  with gr.Row():
                      init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
+                     start_btn = gr.Button("🎙️ Start", variant="primary", size="lg", interactive=False)
+                     stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", interactive=False)
                      clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)

                  # Status display
                  status_output = gr.Textbox(
                      label="System Status",
+                     value="Ready to initialize...",
+                     lines=8,
+                     interactive=False
                  )

              with gr.Column(scale=1):
+                 # Settings
                  gr.Markdown("## ⚙️ Settings")

                  threshold_slider = gr.Slider(
+                     minimum=0.3,
+                     maximum=0.9,
                      step=0.05,
+                     value=DEFAULT_CHANGE_THRESHOLD,
                      label="Speaker Change Sensitivity",
+                     info="Lower = more sensitive"
                  )

                  max_speakers_slider = gr.Slider(
                      minimum=2,
+                     maximum=ABSOLUTE_MAX_SPEAKERS,
                      step=1,
+                     value=DEFAULT_MAX_SPEAKERS,
+                     label="Maximum Speakers"
                  )

+                 update_btn = gr.Button("Update Settings", variant="secondary")

                  # Instructions
                  gr.Markdown("""
+                 ## 📋 Instructions
+                 1. **Initialize** the system (loads AI models)
+                 2. **Start** recording
+                 3. **Speak** - the system will transcribe and identify speakers
+                 4. **Monitor** real-time results below

+                 ## 🎨 Speaker Colors
+                 - 🟡 Speaker 1 (Yellow)
+                 - 🔴 Speaker 2 (Red)
+                 - 🟢 Speaker 3 (Green)
+                 - 🔵 Speaker 4 (Cyan)
+                 - 🟣 Speaker 5 (Magenta)
+                 - 🔵 Speaker 6 (Blue)
+                 - 🟠 Speaker 7 (Orange)
+                 - 🟢 Speaker 8 (Spring Green)
+                 """)

          # Event handlers
          def on_initialize():
+             result = initialize_system()
+             if "✅" in result:
+                 return result, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
+             else:
+                 return result, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)

          def on_start():
+             result = start_recording()
+             return result, gr.update(interactive=False), gr.update(interactive=True)

          def on_stop():
+             result = stop_recording()
+             return result, gr.update(interactive=True), gr.update(interactive=False)

          def on_clear():
+             result = clear_conversation()
+             return result

          def on_update_settings(threshold, max_speakers):
+             result = update_settings(threshold, int(max_speakers))
+             return result

+         def refresh_conversation():
+             return get_conversation()

+         def refresh_status():
+             return get_status()

+         # Button click handlers
          init_btn.click(
+             fn=on_initialize,
+             outputs=[status_output, start_btn, stop_btn, clear_btn]
          )

          start_btn.click(
+             fn=on_start,
              outputs=[status_output, start_btn, stop_btn]
          )

          stop_btn.click(
+             fn=on_stop,
              outputs=[status_output, start_btn, stop_btn]
          )

          clear_btn.click(
+             fn=on_clear,
+             outputs=[status_output]
          )

+         update_btn.click(
+             fn=on_update_settings,
              inputs=[threshold_slider, max_speakers_slider],
              outputs=[status_output]
          )

+         # Timers for auto-refreshing components. gr.Timer takes only the
+         # interval; the callback is attached through its .tick event.
+         conversation_timer = gr.Timer(1)
+         conversation_timer.tick(refresh_conversation, outputs=[conversation_output])
+         status_timer = gr.Timer(2)
+         status_timer.tick(refresh_status, outputs=[status_output])
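+         # On Gradio versions that predate gr.Timer, a rough equivalent is the
+         # load-event polling pattern (sketch):
+         #     interface.load(refresh_conversation, outputs=[conversation_output], every=1)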
 
 
 
      return interface


+ # FastAPI setup for FastRTC integration
+ app = FastAPI()

+ @app.get("/")
+ async def root():
+     return {"message": "Real-time Speaker Diarization API"}

+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy", "system_running": diarization_system.is_running}

+ @app.post("/initialize")
+ async def api_initialize():
+     result = initialize_system()
+     return {"result": result, "success": "✅" in result}

+ @app.post("/start")
+ async def api_start():
+     result = start_recording()
+     return {"result": result, "success": "🎙️" in result}

+ @app.post("/stop")
+ async def api_stop():
+     result = stop_recording()
+     return {"result": result, "success": "⏹️" in result}

+ @app.post("/clear")
+ async def api_clear():
+     result = clear_conversation()
+     return {"result": result}

+ @app.get("/conversation")
+ async def api_get_conversation():
+     return {"conversation": get_conversation()}

+ @app.get("/status")
+ async def api_get_status():
+     return {"status": get_status()}

+ @app.post("/settings")
+ async def api_update_settings(threshold: float, max_speakers: int):
+     result = update_settings(threshold, max_speakers)
+     return {"result": result}

+ # FastRTC Stream setup.
+ # Caveat: audio_handler is None until initialize_system() runs, so this
+ # module-level check only mounts the stream if initialization happened before
+ # import; mounting inside initialize_system() would be more robust.
+ if audio_handler:
+     stream = Stream(handler=audio_handler)
+     app.include_router(stream.router, prefix="/stream")
 

+ # Main execution
  if __name__ == "__main__":
+     import argparse

+     parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System")
+     parser.add_argument("--mode", choices=["gradio", "api", "both"], default="gradio",
+                         help="Run mode: gradio interface, API only, or both")
+     parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
+     parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
+     parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)")

+     args = parser.parse_args()

+     if args.mode == "gradio":
+         # Run the Gradio interface only
+         interface = create_interface()
          interface.launch(
+             server_name=args.host,
+             server_port=args.port,
              share=True,
+             show_error=True
          )

+     elif args.mode == "api":
+         # Run the FastAPI app only
+         uvicorn.run(
+             app,
+             host=args.host,
+             port=args.port,
+             log_level="info"
+         )

+     elif args.mode == "both":
+         # Run both Gradio and FastAPI. threading is already imported at module
+         # level, and multiprocessing is not needed for this setup.

+         def run_gradio():
              interface = create_interface()
              interface.launch(
+                 server_name=args.host,
+                 server_port=args.port,
+                 share=True,
+                 show_error=True
              )

+         def run_fastapi():
+             uvicorn.run(
+                 app,
+                 host=args.host,
+                 port=args.api_port,
+                 log_level="info"
+             )

+         # Start FastAPI in a daemon thread so it exits with the main process
+         api_thread = threading.Thread(target=run_fastapi, daemon=True)
+         api_thread.start()

+         # Start Gradio in the main thread
+         run_gradio()
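+     # Example invocations (sketch; the file name is taken from this commit):
+     #     python app.py                          # Gradio UI on port 7860
+     #     python app.py --mode api --port 8000   # REST API only
+     #     python app.py --mode both              # UI on 7860, API on 8000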