Saiyaswanth007 committed on
Commit 27be9ef · 1 Parent(s): af84a93

Check point 4

Files changed (1)
  1. app.py +683 -464
app.py CHANGED
@@ -10,18 +10,15 @@ import torchaudio
  from scipy.spatial.distance import cosine
  from RealtimeSTT import AudioToTextRecorder
  from fastapi import FastAPI, APIRouter
- from fastrtc import Stream, AsyncStreamHandler
+ from fastrtc import Stream, AsyncStreamHandler, ReplyOnPause, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials
  import json
+ import io
+ import wave
  import asyncio
  import uvicorn
+ import socket
  from queue import Queue
- import logging
- from fastrtc import WebRTC
-
- # Set up logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
+ import time
  # Simplified configuration parameters
  SILENCE_THRESHS = [0, 0.4]
  FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
@@ -35,31 +32,35 @@ MIN_LENGTH_OF_RECORDING = 0.7
  PRE_RECORDING_BUFFER_DURATION = 0.35

  # Speaker change detection parameters
- DEFAULT_CHANGE_THRESHOLD = 0.65
+ DEFAULT_CHANGE_THRESHOLD = 0.7
  EMBEDDING_HISTORY_SIZE = 5
- MIN_SEGMENT_DURATION = 1.5
+ MIN_SEGMENT_DURATION = 1.0
  DEFAULT_MAX_SPEAKERS = 4
- ABSOLUTE_MAX_SPEAKERS = 8
+ ABSOLUTE_MAX_SPEAKERS = 10

  # Global variables
+ FAST_SENTENCE_END = True
  SAMPLE_RATE = 16000
- BUFFER_SIZE = 1024
+ BUFFER_SIZE = 512
  CHANNELS = 1

- # Speaker colors - more distinguishable colors
+ # Speaker colors
  SPEAKER_COLORS = [
-     "#FF6B6B",  # Red
-     "#4ECDC4",  # Teal
-     "#45B7D1",  # Blue
-     "#96CEB4",  # Green
-     "#FFEAA7",  # Yellow
-     "#DDA0DD",  # Plum
-     "#98D8C8",  # Mint
-     "#F7DC6F",  # Gold
+     "#FFFF00",  # Yellow
+     "#FF0000",  # Red
+     "#00FF00",  # Green
+     "#00FFFF",  # Cyan
+     "#FF00FF",  # Magenta
+     "#0000FF",  # Blue
+     "#FF8000",  # Orange
+     "#00FF80",  # Spring Green
+     "#8000FF",  # Purple
+     "#FFFFFF",  # White
  ]

  SPEAKER_COLOR_NAMES = [
-     "Red", "Teal", "Blue", "Green", "Yellow", "Plum", "Mint", "Gold"
+     "Yellow", "Red", "Green", "Cyan", "Magenta",
+     "Blue", "Orange", "Spring Green", "Purple", "White"
  ]

@@ -73,11 +74,24 @@ class SpeechBrainEncoder:
          self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
          os.makedirs(self.cache_dir, exist_ok=True)

+     def _download_model(self):
+         """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
+         model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
+         model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
+
+         if not os.path.exists(model_path):
+             print(f"Downloading ECAPA-TDNN model to {model_path}...")
+             urllib.request.urlretrieve(model_url, model_path)
+
+         return model_path
+
      def load_model(self):
          """Load the ECAPA-TDNN model"""
          try:
              from speechbrain.pretrained import EncoderClassifier

+             model_path = self._download_model()
+
              self.model = EncoderClassifier.from_hparams(
                  source="speechbrain/spkrec-ecapa-voxceleb",
                  savedir=self.cache_dir,
@@ -85,10 +99,9 @@
              )

              self.model_loaded = True
-             logger.info("ECAPA-TDNN model loaded successfully!")
              return True
          except Exception as e:
-             logger.error(f"Error loading ECAPA-TDNN model: {e}")
+             print(f"Error loading ECAPA-TDNN model: {e}")
              return False

      def embed_utterance(self, audio, sr=16000):
@@ -98,15 +111,10 @@
          try:
              if isinstance(audio, np.ndarray):
-                 # Ensure audio is float32 and properly normalized
-                 audio = audio.astype(np.float32)
-                 if np.max(np.abs(audio)) > 1.0:
-                     audio = audio / np.max(np.abs(audio))
-                 waveform = torch.tensor(audio).unsqueeze(0)
+                 waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
              else:
                  waveform = audio.unsqueeze(0)

-             # Resample if necessary
              if sr != 16000:
                  waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
@@ -115,7 +123,7 @@
              return embedding.squeeze().cpu().numpy()
          except Exception as e:
-             logger.error(f"Error extracting embedding: {e}")
+             print(f"Error extracting embedding: {e}")
              return np.zeros(self.embedding_dim)
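The new AudioProcessor.extract_embedding in the next hunk takes raw 16-bit PCM straight from the recorder and rescales it before calling the encoder. A minimal standalone sketch of that conversion (the function name here is illustrative, not part of the commit):

    import numpy as np

    def pcm16_to_float(audio_int16: np.ndarray) -> np.ndarray:
        # int16 spans [-32768, 32767]; dividing by 32768.0 maps it into [-1.0, 1.0]
        float_audio = audio_int16.astype(np.float32) / 32768.0
        # Defensive renormalization, mirroring the guard in the diff below
        if np.abs(float_audio).max() > 1.0:
            float_audio = float_audio / np.abs(float_audio).max()
        return float_audio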
 
@@ -123,60 +131,41 @@ class AudioProcessor:
      """Processes audio data to extract speaker embeddings"""
      def __init__(self, encoder):
          self.encoder = encoder
-         self.audio_buffer = []
-         self.min_audio_length = int(SAMPLE_RATE * 1.0)  # Minimum 1 second of audio
-
-     def add_audio_chunk(self, audio_chunk):
-         """Add audio chunk to buffer"""
-         self.audio_buffer.extend(audio_chunk)
-
-         # Keep buffer from getting too large
-         max_buffer_size = int(SAMPLE_RATE * 10)  # 10 seconds max
-         if len(self.audio_buffer) > max_buffer_size:
-             self.audio_buffer = self.audio_buffer[-max_buffer_size:]

-     def extract_embedding_from_buffer(self):
-         """Extract embedding from current audio buffer"""
-         if len(self.audio_buffer) < self.min_audio_length:
-             return None
-
+     def extract_embedding(self, audio_int16):
          try:
-             # Use the last portion of the buffer for embedding
-             audio_segment = np.array(self.audio_buffer[-self.min_audio_length:], dtype=np.float32)
-
-             # Normalize audio
-             if np.max(np.abs(audio_segment)) > 0:
-                 audio_segment = audio_segment / np.max(np.abs(audio_segment))
-             else:
-                 return None
-
-             embedding = self.encoder.embed_utterance(audio_segment)
+             float_audio = audio_int16.astype(np.float32) / 32768.0
+
+             if np.abs(float_audio).max() > 1.0:
+                 float_audio = float_audio / np.abs(float_audio).max()
+
+             embedding = self.encoder.embed_utterance(float_audio)
              return embedding
          except Exception as e:
-             logger.error(f"Embedding extraction error: {e}")
-             return None
+             print(f"Embedding extraction error: {e}")
+             return np.zeros(self.encoder.embedding_dim)


  class SpeakerChangeDetector:
-     """Improved speaker change detector"""
+     """Speaker change detector that supports a configurable number of speakers"""
      def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
          self.embedding_dim = embedding_dim
          self.change_threshold = change_threshold
          self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
          self.current_speaker = 0
-         self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
-         self.speaker_centroids = [None] * self.max_speakers
+         self.previous_embeddings = []
          self.last_change_time = time.time()
-         self.last_similarity = 1.0
+         self.mean_embeddings = [None] * self.max_speakers
+         self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
+         self.last_similarity = 0.0
          self.active_speakers = set([0])
-         self.segment_counter = 0

      def set_max_speakers(self, max_speakers):
          """Update the maximum number of speakers"""
          new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)

          if new_max < self.max_speakers:
-             # Remove speakers beyond the new limit
              for speaker_id in list(self.active_speakers):
                  if speaker_id >= new_max:
                      self.active_speakers.discard(speaker_id)
@@ -184,85 +173,85 @@ class SpeakerChangeDetector:
          if self.current_speaker >= new_max:
              self.current_speaker = 0

-         # Resize arrays
          if new_max > self.max_speakers:
+             self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
              self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
-             self.speaker_centroids.extend([None] * (new_max - self.max_speakers))
          else:
+             self.mean_embeddings = self.mean_embeddings[:new_max]
              self.speaker_embeddings = self.speaker_embeddings[:new_max]
-             self.speaker_centroids = self.speaker_centroids[:new_max]

          self.max_speakers = new_max

      def set_change_threshold(self, threshold):
          """Update the threshold for detecting speaker changes"""
-         self.change_threshold = max(0.1, min(threshold, 0.95))
+         self.change_threshold = max(0.1, min(threshold, 0.99))

      def add_embedding(self, embedding, timestamp=None):
-         """Add a new embedding and detect speaker changes"""
+         """Add a new embedding and check if there's a speaker change"""
          current_time = timestamp or time.time()
-         self.segment_counter += 1
-
-         # Initialize first speaker
-         if not self.speaker_embeddings[0]:
-             self.speaker_embeddings[0].append(embedding)
-             self.speaker_centroids[0] = embedding.copy()
-             self.active_speakers.add(0)
-             return 0, 1.0
-
-         # Calculate similarity with current speaker
-         current_centroid = self.speaker_centroids[self.current_speaker]
-         if current_centroid is not None:
-             similarity = 1.0 - cosine(embedding, current_centroid)
+
+         if not self.previous_embeddings:
+             self.previous_embeddings.append(embedding)
+             self.speaker_embeddings[self.current_speaker].append(embedding)
+             if self.mean_embeddings[self.current_speaker] is None:
+                 self.mean_embeddings[self.current_speaker] = embedding.copy()
+             return self.current_speaker, 1.0
+
+         current_mean = self.mean_embeddings[self.current_speaker]
+         if current_mean is not None:
+             similarity = 1.0 - cosine(embedding, current_mean)
          else:
-             similarity = 0.5
+             similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])

          self.last_similarity = similarity

-         # Check for speaker change
          time_since_last_change = current_time - self.last_change_time
-         speaker_changed = False
+         is_speaker_change = False

-         if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold:
-             # Find best matching speaker
-             best_speaker = self.current_speaker
-             best_similarity = similarity
-
-             for speaker_id in self.active_speakers:
-                 if speaker_id == self.current_speaker:
-                     continue
-
-                 centroid = self.speaker_centroids[speaker_id]
-                 if centroid is not None:
-                     speaker_similarity = 1.0 - cosine(embedding, centroid)
-                     if speaker_similarity > best_similarity and speaker_similarity > self.change_threshold:
-                         best_similarity = speaker_similarity
-                         best_speaker = speaker_id
-
-             # If no good match found and we can add a new speaker
-             if best_speaker == self.current_speaker and len(self.active_speakers) < self.max_speakers:
-                 for new_id in range(self.max_speakers):
-                     if new_id not in self.active_speakers:
-                         best_speaker = new_id
-                         self.active_speakers.add(new_id)
-                         break
-
-             if best_speaker != self.current_speaker:
-                 self.current_speaker = best_speaker
-                 self.last_change_time = current_time
-                 speaker_changed = True
-
-         # Update speaker embeddings and centroids
-         self.speaker_embeddings[self.current_speaker].append(embedding)
-
-         # Keep only recent embeddings (sliding window)
-         max_embeddings = 20
-         if len(self.speaker_embeddings[self.current_speaker]) > max_embeddings:
-             self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-max_embeddings:]
-
-         # Update centroid
+         if time_since_last_change >= MIN_SEGMENT_DURATION:
+             if similarity < self.change_threshold:
+                 best_speaker = self.current_speaker
+                 best_similarity = similarity
+
+                 for speaker_id in range(self.max_speakers):
+                     if speaker_id == self.current_speaker:
+                         continue
+
+                     speaker_mean = self.mean_embeddings[speaker_id]
+
+                     if speaker_mean is not None:
+                         speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
+
+                         if speaker_similarity > best_similarity:
+                             best_similarity = speaker_similarity
+                             best_speaker = speaker_id
+
+                 if best_speaker != self.current_speaker:
+                     is_speaker_change = True
+                     self.current_speaker = best_speaker
+                 elif len(self.active_speakers) < self.max_speakers:
+                     for new_id in range(self.max_speakers):
+                         if new_id not in self.active_speakers:
+                             is_speaker_change = True
+                             self.current_speaker = new_id
+                             self.active_speakers.add(new_id)
+                             break
+
+             if is_speaker_change:
+                 self.last_change_time = current_time
+
+         self.previous_embeddings.append(embedding)
+         if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
+             self.previous_embeddings.pop(0)
+
+         self.speaker_embeddings[self.current_speaker].append(embedding)
+         self.active_speakers.add(self.current_speaker)
+
+         if len(self.speaker_embeddings[self.current_speaker]) > 30:
+             self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
+
          if self.speaker_embeddings[self.current_speaker]:
-             self.speaker_centroids[self.current_speaker] = np.mean(
+             self.mean_embeddings[self.current_speaker] = np.mean(
                  self.speaker_embeddings[self.current_speaker], axis=0
              )
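The reworked add_embedding above assigns a segment to whichever speaker's mean embedding is most cosine-similar, opening a new speaker slot only when nothing matches and capacity remains. A self-contained sketch of the core comparison, assuming embedding vectors comparable by cosine distance (the helper name is illustrative):

    import numpy as np
    from scipy.spatial.distance import cosine

    def best_speaker(embedding, mean_embeddings, current, threshold):
        # cosine() returns a distance, so similarity = 1 - distance
        similarity = 1.0 - cosine(embedding, mean_embeddings[current])
        if similarity >= threshold:
            return current, similarity  # stay with the current speaker
        # Otherwise compare against every other known speaker mean
        candidates = [(1.0 - cosine(embedding, m), i)
                      for i, m in enumerate(mean_embeddings)
                      if m is not None and i != current]
        best_sim, best_id = max(candidates, default=(similarity, current))
        return best_id, best_sim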
 
@@ -275,7 +264,7 @@ class SpeakerChangeDetector:
          return "#FFFFFF"

      def get_status_info(self):
-         """Return status information"""
+         """Return status information about the speaker change detector"""
          speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]

          return {
@@ -284,8 +273,7 @@ class SpeakerChangeDetector:
              "active_speakers": len(self.active_speakers),
              "max_speakers": self.max_speakers,
              "last_similarity": self.last_similarity,
-             "threshold": self.change_threshold,
-             "segment_counter": self.segment_counter
+             "threshold": self.change_threshold
          }

@@ -299,18 +287,19 @@ class RealtimeSpeakerDiarization:
          self.full_sentences = []
          self.sentence_speakers = []
          self.pending_sentences = []
-         self.current_conversation = ""
+         self.displayed_text = ""
+         self.last_realtime_text = ""
          self.is_running = False
          self.change_threshold = DEFAULT_CHANGE_THRESHOLD
          self.max_speakers = DEFAULT_MAX_SPEAKERS
-         self.last_transcription = ""
-         self.transcription_lock = threading.Lock()
+         self.current_conversation = ""
+         self.audio_buffer = []

      def initialize_models(self):
          """Initialize the speaker encoder model"""
          try:
              device_str = "cuda" if torch.cuda.is_available() else "cpu"
-             logger.info(f"Using device: {device_str}")
+             print(f"Using device: {device_str}")

              self.encoder = SpeechBrainEncoder(device=device_str)
              success = self.encoder.load_model()
@@ -322,94 +311,80 @@
                      change_threshold=self.change_threshold,
                      max_speakers=self.max_speakers
                  )
-                 logger.info("Models initialized successfully!")
+                 print("ECAPA-TDNN model loaded successfully!")
                  return True
              else:
-                 logger.error("Failed to load models")
+                 print("Failed to load ECAPA-TDNN model")
                  return False
          except Exception as e:
-             logger.error(f"Model initialization error: {e}")
+             print(f"Model initialization error: {e}")
              return False

      def live_text_detected(self, text):
          """Callback for real-time transcription updates"""
-         with self.transcription_lock:
-             self.last_transcription = text.strip()
+         text = text.strip()
+         if text:
+             sentence_delimiters = '.?!。'
+             prob_sentence_end = (
+                 len(self.last_realtime_text) > 0
+                 and text[-1] in sentence_delimiters
+                 and self.last_realtime_text[-1] in sentence_delimiters
+             )
+
+             self.last_realtime_text = text
+
+             if prob_sentence_end and FAST_SENTENCE_END:
+                 self.recorder.stop()
+             elif prob_sentence_end:
+                 self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
+             else:
+                 self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]

      def process_final_text(self, text):
          """Process final transcribed text with speaker embedding"""
          text = text.strip()
          if text:
              try:
-                 # Get audio data for this transcription
-                 audio_bytes = getattr(self.recorder, 'last_transcription_bytes', None)
-                 if audio_bytes:
-                     self.sentence_queue.put((text, audio_bytes))
-                 else:
-                     # If no audio bytes, use current speaker
-                     self.sentence_queue.put((text, None))
-
+                 bytes_data = self.recorder.last_transcription_bytes
+                 self.sentence_queue.put((text, bytes_data))
+                 self.pending_sentences.append(text)
              except Exception as e:
-                 logger.error(f"Error processing final text: {e}")
+                 print(f"Error processing final text: {e}")

      def process_sentence_queue(self):
          """Process sentences in the queue for speaker detection"""
          while self.is_running:
              try:
-                 text, audio_bytes = self.sentence_queue.get(timeout=1)
-
-                 current_speaker = self.speaker_detector.current_speaker
-
-                 if audio_bytes:
-                     # Convert audio data and extract embedding
-                     audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
-                     audio_float = audio_int16.astype(np.float32) / 32768.0
-
-                     # Extract embedding
-                     embedding = self.audio_processor.encoder.embed_utterance(audio_float)
-                     if embedding is not None:
-                         current_speaker, similarity = self.speaker_detector.add_embedding(embedding)
-
-                 # Store sentence with speaker
-                 with self.transcription_lock:
-                     self.full_sentences.append((text, current_speaker))
-                     self.update_conversation_display()
+                 text, bytes_data = self.sentence_queue.get(timeout=1)
+
+                 # Convert audio data to int16
+                 audio_int16 = np.frombuffer(bytes_data, dtype=np.int16)
+
+                 # Extract speaker embedding
+                 speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
+
+                 # Store sentence and embedding
+                 self.full_sentences.append((text, speaker_embedding))
+
+                 # Fill in missing speaker assignments
+                 while len(self.sentence_speakers) < len(self.full_sentences) - 1:
+                     self.sentence_speakers.append(0)
+
+                 # Detect speaker changes
+                 speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
+                 self.sentence_speakers.append(speaker_id)
+
+                 # Remove from pending
+                 if text in self.pending_sentences:
+                     self.pending_sentences.remove(text)
+
+                 # Update conversation display
+                 self.current_conversation = self.get_formatted_conversation()

              except queue.Empty:
                  continue
              except Exception as e:
-                 logger.error(f"Error processing sentence: {e}")
-
-     def update_conversation_display(self):
-         """Update the conversation display"""
-         try:
-             sentences_with_style = []
-
-             for sentence_text, speaker_id in self.full_sentences:
-                 color = self.speaker_detector.get_color_for_speaker(speaker_id)
-                 speaker_name = f"Speaker {speaker_id + 1}"
-                 sentences_with_style.append(
-                     f'<span style="color:{color}; font-weight: bold;">{speaker_name}:</span> '
-                     f'<span style="color:#333333;">{sentence_text}</span>'
-                 )
-
-             # Add current transcription if available
-             if self.last_transcription:
-                 current_color = self.speaker_detector.get_color_for_speaker(self.speaker_detector.current_speaker)
-                 current_speaker = f"Speaker {self.speaker_detector.current_speaker + 1}"
-                 sentences_with_style.append(
-                     f'<span style="color:{current_color}; font-weight: bold; opacity: 0.7;">{current_speaker}:</span> '
-                     f'<span style="color:#666666; font-style: italic;">{self.last_transcription}...</span>'
-                 )
-
-             if sentences_with_style:
-                 self.current_conversation = "<br><br>".join(sentences_with_style)
-             else:
-                 self.current_conversation = "<i>Waiting for speech input...</i>"
-
-         except Exception as e:
-             logger.error(f"Error updating conversation display: {e}")
-             self.current_conversation = f"<i>Error: {str(e)}</i>"
+                 print(f"Error processing sentence: {e}")

      def start_recording(self):
          """Start the recording and transcription process"""
@@ -417,10 +392,10 @@
              return "Please initialize models first!"

          try:
-             # Setup recorder configuration
+             # Setup recorder configuration for manual audio input
              recorder_config = {
                  'spinner': False,
-                 'use_microphone': False,  # Change to False for Hugging Face Spaces
+                 'use_microphone': False,  # We'll feed audio manually
                  'model': FINAL_TRANSCRIPTION_MODEL,
                  'language': TRANSCRIPTION_LANGUAGE,
                  'silero_sensitivity': SILERO_SENSITIVITY,
@@ -430,28 +405,29 @@
                  'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
                  'min_gap_between_recordings': 0,
                  'enable_realtime_transcription': True,
-                 'realtime_processing_pause': 0.1,
+                 'realtime_processing_pause': 0,
                  'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
                  'on_realtime_transcription_update': self.live_text_detected,
                  'beam_size': FINAL_BEAM_SIZE,
                  'beam_size_realtime': REALTIME_BEAM_SIZE,
+                 'buffer_size': BUFFER_SIZE,
                  'sample_rate': SAMPLE_RATE,
              }

              self.recorder = AudioToTextRecorder(**recorder_config)

-             # Start processing threads
+             # Start sentence processing thread
              self.is_running = True
              self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
              self.sentence_thread.start()

+             # Start transcription thread
              self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
              self.transcription_thread.start()

-             return "Recording started successfully!"
+             return "Recording started successfully! FastRTC audio input ready."

          except Exception as e:
-             logger.error(f"Error starting recording: {e}")
              return f"Error starting recording: {e}"

      def run_transcription(self):
@@ -460,7 +436,7 @@
              while self.is_running:
                  self.recorder.text(self.process_final_text)
          except Exception as e:
-             logger.error(f"Transcription error: {e}")
+             print(f"Transcription error: {e}")

      def stop_recording(self):
          """Stop the recording process"""
@@ -471,10 +447,12 @@

      def clear_conversation(self):
          """Clear all conversation data"""
-         with self.transcription_lock:
-             self.full_sentences = []
-             self.last_transcription = ""
-             self.current_conversation = "Conversation cleared!"
+         self.full_sentences = []
+         self.sentence_speakers = []
+         self.pending_sentences = []
+         self.displayed_text = ""
+         self.last_realtime_text = ""
+         self.current_conversation = "Conversation cleared!"

          if self.speaker_detector:
              self.speaker_detector = SpeakerChangeDetector(
@@ -497,8 +475,36 @@
          return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"

      def get_formatted_conversation(self):
-         """Get the formatted conversation"""
-         return self.current_conversation
+         """Get the formatted conversation with speaker colors"""
+         try:
+             sentences_with_style = []
+
+             # Process completed sentences
+             for i, sentence in enumerate(self.full_sentences):
+                 sentence_text, _ = sentence
+                 if i >= len(self.sentence_speakers):
+                     color = "#FFFFFF"
+                     speaker_name = "Unknown"
+                 else:
+                     speaker_id = self.sentence_speakers[i]
+                     color = self.speaker_detector.get_color_for_speaker(speaker_id)
+                     speaker_name = f"Speaker {speaker_id + 1}"
+
+                 sentences_with_style.append(
+                     f'<span style="color:{color};"><b>{speaker_name}:</b> {sentence_text}</span>')
+
+             # Add pending sentences
+             for pending_sentence in self.pending_sentences:
+                 sentences_with_style.append(
+                     f'<span style="color:#60FFFF;"><b>Processing:</b> {pending_sentence}</span>')
+
+             if sentences_with_style:
+                 return "<br><br>".join(sentences_with_style)
+             else:
+                 return "Waiting for speech input..."
+
+         except Exception as e:
+             return f"Error formatting conversation: {e}"

      def get_status_info(self):
          """Get current status information"""
@@ -514,476 +520,689 @@
                  f"**Last Similarity:** {status['last_similarity']:.3f}",
                  f"**Change Threshold:** {status['threshold']:.2f}",
                  f"**Total Sentences:** {len(self.full_sentences)}",
-                 f"**Segments Processed:** {status['segment_counter']}",
                  "",
-                 "**Speaker Activity:**"
+                 "**Speaker Segment Counts:**"
              ]

              for i in range(status['max_speakers']):
                  color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
-                 count = status['speaker_counts'][i]
-                 active = "🟢" if count > 0 else "⚫"
-                 status_lines.append(f"{active} Speaker {i+1} ({color_name}): {count} segments")
+                 status_lines.append(f"Speaker {i+1} ({color_name}): {status['speaker_counts'][i]}")

              return "\n".join(status_lines)

          except Exception as e:
              return f"Error getting status: {e}"

+     def feed_audio_data(self, audio_data):
+         """Feed audio data to the recorder"""
+         if not self.is_running or not self.recorder:
+             return
+
+         try:
+             # Ensure audio is in the correct format (16-bit PCM)
+             if isinstance(audio_data, np.ndarray):
+                 if audio_data.dtype != np.int16:
+                     # Convert float to int16
+                     if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
+                         audio_data = (audio_data * 32767).astype(np.int16)
+                     else:
+                         audio_data = audio_data.astype(np.int16)
+
+                 # Convert to bytes
+                 audio_bytes = audio_data.tobytes()
+             else:
+                 audio_bytes = audio_data
+
+             # Feed to recorder
+             self.recorder.feed_audio(audio_bytes)
+
+         except Exception as e:
+             print(f"Error feeding audio data: {e}")
+
      def process_audio_chunk(self, audio_data, sample_rate=16000):
          """Process audio chunk from FastRTC input"""
-         if not self.is_running or self.audio_processor is None:
+         if not self.is_running or self.recorder is None:
              return

          try:
-             # Ensure audio is float32
-             if isinstance(audio_data, np.ndarray):
-                 if audio_data.dtype != np.float32:
-                     audio_data = audio_data.astype(np.float32)
+             # Convert float audio to int16 for the recorder
+             if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
+                 if np.max(np.abs(audio_data)) <= 1.0:
+                     # Float audio is normalized to [-1, 1], convert to int16
+                     audio_int16 = (audio_data * 32767).astype(np.int16)
+                 else:
+                     # Audio is already in higher range
+                     audio_int16 = audio_data.astype(np.int16)
              else:
-                 audio_data = np.array(audio_data, dtype=np.float32)
-
-             # Ensure mono
-             if len(audio_data.shape) > 1:
-                 audio_data = np.mean(audio_data, axis=1) if audio_data.shape[1] > 1 else audio_data.flatten()
-
-             # Normalize if needed
-             if np.max(np.abs(audio_data)) > 1.0:
-                 audio_data = audio_data / np.max(np.abs(audio_data))
-
-             # Add to audio processor buffer for speaker detection
-             self.audio_processor.add_audio_chunk(audio_data)
-
-             # Periodically extract embeddings for speaker detection
-             if len(self.audio_processor.audio_buffer) % (SAMPLE_RATE // 2) == 0:  # Every 0.5 seconds
-                 embedding = self.audio_processor.extract_embedding_from_buffer()
-                 if embedding is not None:
-                     self.speaker_detector.add_embedding(embedding)
-
+                 audio_int16 = audio_data
+
+             # Ensure correct shape (1, N) for the recorder
+             if len(audio_int16.shape) == 1:
+                 audio_int16 = np.expand_dims(audio_int16, 0)
+
+             # Resample if needed
+             if sample_rate != SAMPLE_RATE:
+                 audio_int16 = self._resample_audio(audio_int16, sample_rate, SAMPLE_RATE)
+
+             # Convert to bytes for feeding to recorder
+             audio_bytes = audio_int16.tobytes()
+
+             # Feed to recorder
+             self.feed_audio_data(audio_bytes)
+
          except Exception as e:
-             logger.error(f"Error processing audio chunk: {e}")
+             print(f"Error processing audio chunk: {e}")
+
+     def _resample_audio(self, audio, orig_sr, target_sr):
+         """Resample audio to target sample rate"""
+         try:
+             import scipy.signal
+
+             # Get the resampling ratio
+             ratio = target_sr / orig_sr
+
+             # Calculate the new length
+             new_length = int(len(audio[0]) * ratio)
+
+             # Resample the audio
+             resampled = scipy.signal.resample(audio[0], new_length)
+
+             # Return in the same shape format
+             return np.expand_dims(resampled, 0)
+         except Exception as e:
+             print(f"Error resampling audio: {e}")
+             return audio
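The _resample_audio helper above uses scipy.signal.resample, which works in terms of output length rather than sample rates, so the new length is derived from the rate ratio. A standalone sketch for mono audio (assumed 1-D input; not part of the commit):

    import numpy as np
    import scipy.signal

    def resample_mono(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        # e.g. 48000 Hz -> 16000 Hz keeps one third of the samples
        new_length = int(len(audio) * target_sr / orig_sr)
        return scipy.signal.resample(audio, new_length)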
- # FastRTC Audio Handler
+ # FastRTC Audio Handler for Real-time Diarization
  class DiarizationHandler(AsyncStreamHandler):
      def __init__(self, diarization_system):
          super().__init__()
          self.diarization_system = diarization_system
-         self.audio_buffer = []
-         self.buffer_size = BUFFER_SIZE
+         self.audio_queue = Queue()
+         self.is_processing = False
+         self.sample_rate = 16000  # Default sample rate

      def copy(self):
          """Return a fresh handler for each new stream connection"""
          return DiarizationHandler(self.diarization_system)

      async def emit(self):
-         """Not used - we only receive audio"""
+         """Not used in this implementation - we only receive audio"""
          return None

      async def receive(self, frame):
-         """Receive audio data from FastRTC"""
+         """Receive audio data from FastRTC and process it"""
          try:
              if not self.diarization_system.is_running:
                  return

-             # Extract audio data
-             audio_data = getattr(frame, 'data', frame)
+             # Extract audio data from frame
+             if hasattr(frame, 'data') and frame.data is not None:
+                 audio_data = frame.data
+             elif hasattr(frame, 'audio') and frame.audio is not None:
+                 audio_data = frame.audio
+             else:
+                 audio_data = frame

-             # Convert to numpy array
+             # Convert to numpy array if needed
              if isinstance(audio_data, bytes):
-                 audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
+                 # Convert bytes to numpy array (assuming 16-bit PCM)
+                 audio_array = np.frombuffer(audio_data, dtype=np.int16)
+                 # Normalize to float32 range [-1, 1]
+                 audio_array = audio_array.astype(np.float32) / 32768.0
              elif isinstance(audio_data, (list, tuple)):
                  audio_array = np.array(audio_data, dtype=np.float32)
+             elif isinstance(audio_data, np.ndarray):
+                 audio_array = audio_data.astype(np.float32)
              else:
-                 audio_array = np.array(audio_data, dtype=np.float32)
+                 print(f"Unknown audio data type: {type(audio_data)}")
+                 return

-             # Ensure 1D
+             # Ensure mono audio
+             if len(audio_array.shape) > 1 and audio_array.shape[1] > 1:
+                 audio_array = np.mean(audio_array, axis=1)
+
+             # Ensure 1D array
              if len(audio_array.shape) > 1:
                  audio_array = audio_array.flatten()

-             # Buffer audio chunks
-             self.audio_buffer.extend(audio_array)
-
-             # Process in chunks
-             while len(self.audio_buffer) >= self.buffer_size:
-                 chunk = np.array(self.audio_buffer[:self.buffer_size])
-                 self.audio_buffer = self.audio_buffer[self.buffer_size:]
-
-                 # Process asynchronously
-                 await self.process_audio_async(chunk)
+             # Get sample rate from frame if available
+             sample_rate = getattr(frame, 'sample_rate', self.sample_rate)
+
+             # Process audio asynchronously to avoid blocking
+             await self.process_audio_async(audio_array, sample_rate)

          except Exception as e:
-             logger.error(f"Error in FastRTC receive: {e}")
+             print(f"Error in FastRTC audio receive: {e}")
+             import traceback
+             traceback.print_exc()

-     async def process_audio_async(self, audio_data):
+     async def process_audio_async(self, audio_data, sample_rate=16000):
          """Process audio data asynchronously"""
          try:
+             # Run the audio processing in a thread pool to avoid blocking
              loop = asyncio.get_event_loop()
              await loop.run_in_executor(
                  None,
                  self.diarization_system.process_audio_chunk,
                  audio_data,
-                 SAMPLE_RATE
+                 sample_rate
              )
          except Exception as e:
-             logger.error(f"Error in async audio processing: {e}")
+             print(f"Error in async audio processing: {e}")
+
+     async def start_up(self) -> None:
+         """Initialize any resources when the stream starts"""
+         print("FastRTC stream started")
+         self.is_processing = True
+
+     async def shutdown(self) -> None:
+         """Clean up any resources when the stream ends"""
+         print("FastRTC stream shutting down")
+         self.is_processing = False
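receive() hands each frame to process_audio_chunk through loop.run_in_executor, so the blocking NumPy/recorder work runs on a worker thread while the event loop keeps accepting frames. The pattern in isolation, with a stand-in for the blocking call (illustrative names, not part of the commit):

    import asyncio

    def blocking_work(chunk):
        return sum(chunk)  # placeholder for CPU-bound processing

    async def handle(chunk):
        loop = asyncio.get_event_loop()
        # None selects the default ThreadPoolExecutor
        return await loop.run_in_executor(None, blocking_work, chunk)

    print(asyncio.run(handle([1, 2, 3])))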
  # Global instances
- diarization_system = RealtimeSpeakerDiarization()
+ diarization_system = None  # Will be initialized when RealtimeSpeakerDiarization is available
  audio_handler = None

  def initialize_system():
      """Initialize the diarization system"""
-     global stream
+     global audio_handler, diarization_system
      try:
+         if diarization_system is None:
+             print("Error: RealtimeSpeakerDiarization not initialized")
+             return "❌ Diarization system not available. Please ensure RealtimeSpeakerDiarization is properly imported."
+
          success = diarization_system.initialize_models()
          if success:
-             # Update the Stream's handler to use our DiarizationHandler
-             stream.handler = DiarizationHandler(diarization_system)
-             return "✅ System initialized successfully!"
+             audio_handler = DiarizationHandler(diarization_system)
+             return "✅ System initialized successfully! Models loaded and FastRTC handler ready."
          else:
-             return "❌ Failed to initialize system. Check logs for details."
+             return "❌ Failed to initialize system. Please check the logs."
      except Exception as e:
-         logger.error(f"Initialization error: {e}")
+         print(f"Initialization error: {e}")
          return f"❌ Initialization error: {str(e)}"

  def start_recording():
      """Start recording and transcription"""
      try:
+         if diarization_system is None:
+             return "❌ System not initialized"
          result = diarization_system.start_recording()
-         return result
+         return f"🎙️ {result} - FastRTC audio streaming is active."
      except Exception as e:
          return f"❌ Failed to start recording: {str(e)}"

- def on_start():
-     result = start_recording()
-     return result, gr.update(interactive=False), gr.update(interactive=True)

  def stop_recording():
      """Stop recording and transcription"""
      try:
+         if diarization_system is None:
+             return "❌ System not initialized"
          result = diarization_system.stop_recording()
          return f"⏹️ {result}"
      except Exception as e:
          return f"❌ Failed to stop recording: {str(e)}"

  def clear_conversation():
      """Clear the conversation"""
      try:
+         if diarization_system is None:
+             return "❌ System not initialized"
          result = diarization_system.clear_conversation()
          return f"🗑️ {result}"
      except Exception as e:
          return f"❌ Failed to clear conversation: {str(e)}"

  def update_settings(threshold, max_speakers):
      """Update system settings"""
      try:
+         if diarization_system is None:
+             return "❌ System not initialized"
          result = diarization_system.update_settings(threshold, max_speakers)
          return f"⚙️ {result}"
      except Exception as e:
          return f"❌ Failed to update settings: {str(e)}"

  def get_conversation():
      """Get the current conversation"""
      try:
+         if diarization_system is None:
+             return "<i>System not initialized</i>"
          return diarization_system.get_formatted_conversation()
      except Exception as e:
          return f"<i>Error getting conversation: {str(e)}</i>"

  def get_status():
      """Get system status"""
      try:
+         if diarization_system is None:
+             return "System not initialized"
          return diarization_system.get_status_info()
      except Exception as e:
          return f"Error getting status: {str(e)}"
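For reference, the previous revision wired a handler into the FastAPI app at import time via a fastrtc Stream, a pattern this checkpoint drops in favor of creating the handler inside initialize_system. Condensed from the code removed further down in this diff (fastrtc calls exactly as they appear there; not independently runnable):

    app = FastAPI()
    stream = Stream(handler=DefaultHandler(), modality="audio", mode="send-receive")
    stream.mount(app)
    # ...and once models were loaded:
    stream.handler = DiarizationHandler(diarization_system)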
  # Create Gradio interface
  def create_interface():
      with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
          gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
-         gr.Markdown("Live transcription with automatic speaker identification using FastRTC audio streaming.")
+         gr.Markdown("This app performs real-time speech recognition with automatic speaker identification using FastRTC for low-latency audio streaming.")

          with gr.Row():
              with gr.Column(scale=2):
-                 # Replace WebRTC with standard Gradio audio component
-                 audio_component = gr.Audio(
-                     label="Audio Input",
-                     sources=["microphone"],
-                     streaming=True
-                 )
-
-                 # Conversation display
+                 # Main conversation display
                  conversation_output = gr.HTML(
-                     value="<div style='padding: 20px; background: #f8f9fa; border-radius: 10px; min-height: 300px;'><i>Click 'Initialize System' to start...</i></div>",
-                     label="Live Conversation"
+                     value="<div style='padding: 20px; background: #f5f5f5; border-radius: 10px;'><i>Click 'Initialize System' to start...</i></div>",
+                     label="Live Conversation",
+                     elem_id="conversation_display"
                  )

                  # Control buttons
                  with gr.Row():
                      init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
-                     start_btn = gr.Button("🎙️ Start", variant="primary", size="lg", interactive=False)
-                     stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", interactive=False)
+                     start_btn = gr.Button("🎙️ Start Recording", variant="primary", size="lg", interactive=False)
+                     stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", size="lg", interactive=False)
                      clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)

+                 # FastRTC Stream Interface
+                 with gr.Row():
+                     gr.HTML("""
+                     <div id="fastrtc-container" style="border: 2px solid #ddd; border-radius: 10px; padding: 20px; margin: 10px 0;">
+                         <h3>🎵 Audio Stream</h3>
+                         <p>FastRTC audio stream will appear here when recording starts.</p>
+                         <div id="stream-status" style="padding: 10px; background: #f8f9fa; border-radius: 5px; margin-top: 10px;">
+                             Status: Waiting for initialization...
+                         </div>
+                     </div>
+                     """)
+
                  # Status display
                  status_output = gr.Textbox(
                      label="System Status",
-                     value="Ready to initialize...",
-                     lines=8,
-                     interactive=False
+                     value="System not initialized. Please click 'Initialize System' to begin.",
+                     lines=6,
+                     interactive=False,
+                     show_copy_button=True
                  )

              with gr.Column(scale=1):
-                 # Settings
+                 # Settings panel
                  gr.Markdown("## ⚙️ Settings")

                  threshold_slider = gr.Slider(
-                     minimum=0.3,
-                     maximum=0.9,
+                     minimum=0.1,
+                     maximum=0.95,
                      step=0.05,
-                     value=DEFAULT_CHANGE_THRESHOLD,
+                     value=0.5,  # DEFAULT_CHANGE_THRESHOLD
                      label="Speaker Change Sensitivity",
-                     info="Lower = more sensitive"
+                     info="Lower = more sensitive to speaker changes"
                  )

                  max_speakers_slider = gr.Slider(
                      minimum=2,
-                     maximum=ABSOLUTE_MAX_SPEAKERS,
+                     maximum=10,  # ABSOLUTE_MAX_SPEAKERS
                      step=1,
-                     value=DEFAULT_MAX_SPEAKERS,
-                     label="Maximum Speakers"
+                     value=4,  # DEFAULT_MAX_SPEAKERS
+                     label="Maximum Number of Speakers"
                  )

-                 update_btn = gr.Button("Update Settings", variant="secondary")
+                 update_settings_btn = gr.Button("Update Settings", variant="secondary")
+
+                 # Audio settings
+                 gr.Markdown("## 🔊 Audio Configuration")
+                 with gr.Accordion("Advanced Audio Settings", open=False):
+                     gr.Markdown("""
+                     **Current Configuration:**
+                     - Sample Rate: 16kHz
+                     - Audio Format: 16-bit PCM → Float32 (via AudioProcessor)
+                     - Channels: Mono (stereo converted automatically)
+                     - Buffer Size: 1024 samples for real-time processing
+                     - Processing: Uses existing AudioProcessor.extract_embedding()
+                     """)

          # Event handlers
          def on_initialize():
-             result = initialize_system()
-             if "✅" in result:
-                 return result, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
-             else:
-                 return result, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)

          def on_start():
-             result = start_recording()
-             return result, gr.update(interactive=False), gr.update(interactive=True)

          def on_stop():
-             result = stop_recording()
-             return result, gr.update(interactive=True), gr.update(interactive=False)

          def on_clear():
-             result = clear_conversation()
-             return result

          def on_update_settings(threshold, max_speakers):
-             result = update_settings(threshold, int(max_speakers))
-             return result
-
-         def refresh_conversation():
-             return get_conversation()
-
-         def refresh_status():
-             return get_status()

-         # Button click handlers
          init_btn.click(
-             fn=on_initialize,
-             outputs=[status_output, start_btn, stop_btn, clear_btn]
          )

          start_btn.click(
-             fn=on_start,
              outputs=[status_output, start_btn, stop_btn]
          )

          stop_btn.click(
-             fn=on_stop,
              outputs=[status_output, start_btn, stop_btn]
          )

          clear_btn.click(
-             fn=on_clear,
-             outputs=[status_output]
          )

-         update_btn.click(
-             fn=on_update_settings,
              inputs=[threshold_slider, max_speakers_slider],
              outputs=[status_output]
          )

-         # Auto-refresh conversation display every 1 second
-         conversation_timer = gr.Timer(1)
-         conversation_timer.tick(refresh_conversation, outputs=[conversation_output])
-
-         # Auto-refresh status every 2 seconds
-         status_timer = gr.Timer(2)
-         status_timer.tick(refresh_status, outputs=[status_output])
-
-         # Process audio from Gradio component
-         def process_audio_input(audio_data):
-             if audio_data is not None and diarization_system.is_running:
-                 # Extract audio data
-                 if isinstance(audio_data, tuple) and len(audio_data) >= 2:
-                     sample_rate, audio_array = audio_data[0], audio_data[1]
-                     diarization_system.process_audio_chunk(audio_array, sample_rate)
-             return get_conversation()
-
-         # Connect audio component to processing function
-         audio_component.stream(
-             fn=process_audio_input,
-             outputs=[conversation_output]
-         )
-
      return interface


- # FastAPI setup for FastRTC integration
- app = FastAPI()
-
- # Create a placeholder handler - will be properly initialized later
- class DefaultHandler(AsyncStreamHandler):
-     def __init__(self):
-         super().__init__()
-
-     async def receive(self, frame):
-         pass
-
-     async def emit(self):
-         return None
-
-     def copy(self):
-         return DefaultHandler()
-
-     async def shutdown(self):
-         pass
-
-     async def start_up(self):
-         pass
-
- # Initialize with placeholder handler
- stream = Stream(handler=DefaultHandler(), modality="audio", mode="send-receive")
- stream.mount(app)
-
- @app.get("/")
- async def root():
-     return {"message": "Real-time Speaker Diarization API"}
-
- @app.get("/health")
- async def health_check():
-     return {"status": "healthy", "system_running": diarization_system.is_running}
-
- @app.post("/initialize")
- async def api_initialize():
-     result = initialize_system()
-     return {"result": result, "success": "✅" in result}
-
- @app.post("/start")
- async def api_start():
-     result = start_recording()
-     return {"result": result, "success": "🎙️" in result}
-
- @app.post("/stop")
- async def api_stop():
-     result = stop_recording()
-     return {"result": result, "success": "⏹️" in result}
-
- @app.post("/clear")
- async def api_clear():
-     result = clear_conversation()
-     return {"result": result}
-
- @app.get("/conversation")
- async def api_get_conversation():
-     return {"conversation": get_conversation()}
-
- @app.get("/status")
- async def api_get_status():
-     return {"status": get_status()}
-
- @app.post("/settings")
- async def api_update_settings(threshold: float, max_speakers: int):
-     result = update_settings(threshold, max_speakers)
-     return {"result": result}
-
- # Main execution
- if __name__ == "__main__":
-     import argparse
-
-     parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System")
-     parser.add_argument("--mode", choices=["gradio", "api", "both"], default="gradio",
-                         help="Run mode: gradio interface, API only, or both")
-     parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
-     parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
-     parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)")
-
-     args = parser.parse_args()
-
-     if args.mode == "gradio":
-         # Run Gradio interface only
-         interface = create_interface()
          interface.launch(
-             server_name=args.host,
-             server_port=args.port,
              share=True,
-             show_error=True
          )
-
-     elif args.mode == "api":
-         # Run FastAPI only
-         uvicorn.run(
-             app,
-             host=args.host,
-             port=args.port,
-             log_level="info"
-         )
-
-     elif args.mode == "both":
-         # Run both Gradio and FastAPI
-         import multiprocessing
-         import threading
-
-         def run_gradio():
              interface = create_interface()
              interface.launch(
-                 server_name=args.host,
-                 server_port=args.port,
-                 share=True,
-                 show_error=True
              )
-
-         def run_fastapi():
-             uvicorn.run(
-                 app,
-                 host=args.host,
-                 port=args.api_port,
-                 log_level="info"
-             )
-
-         # Start FastAPI in a separate thread
-         api_thread = threading.Thread(target=run_fastapi, daemon=True)
-         api_thread.start()
-
-         # Start Gradio in main thread
-         run_gradio()

          # Instructions
+         gr.Markdown("## 📝 How to Use")
          gr.Markdown("""
-         ## 📋 Instructions
-         1. **Initialize** the system (loads AI models)
-         2. **Start** recording
-         3. **Speak** - system will transcribe and identify speakers
-         4. **Monitor** real-time results below
-
-         ## 🎨 Speaker Colors
-         - 🔴 Speaker 1 (Red)
-         - 🟢 Speaker 2 (Teal)
-         - 🔵 Speaker 3 (Blue)
-         - 🟡 Speaker 4 (Green)
-         - 🟣 Speaker 5 (Yellow)
-         - 🟤 Speaker 6 (Plum)
-         - 🟫 Speaker 7 (Mint)
-         - 🟨 Speaker 8 (Gold)
+         1. **Initialize**: Click "Initialize System" to load AI models
+         2. **Start**: Click "Start Recording" to begin processing
+         3. **Connect**: The FastRTC stream will activate automatically
10
  from scipy.spatial.distance import cosine
11
  from RealtimeSTT import AudioToTextRecorder
12
  from fastapi import FastAPI, APIRouter
13
+ from fastrtc import Stream, AsyncStreamHandler, ReplyOnPause, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials
14
  import json
15
+ import io
16
+ import wave
17
  import asyncio
18
  import uvicorn
19
+ import socket
20
  from queue import Queue
21
+ import time
 
 
 
 
 
 
22
  # Simplified configuration parameters
23
  SILENCE_THRESHS = [0, 0.4]
24
  FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
 
32
  PRE_RECORDING_BUFFER_DURATION = 0.35
33
 
34
  # Speaker change detection parameters
35
+ DEFAULT_CHANGE_THRESHOLD = 0.7
36
  EMBEDDING_HISTORY_SIZE = 5
37
+ MIN_SEGMENT_DURATION = 1.0
38
  DEFAULT_MAX_SPEAKERS = 4
39
+ ABSOLUTE_MAX_SPEAKERS = 10
40
 
41
  # Global variables
42
+ FAST_SENTENCE_END = True
43
  SAMPLE_RATE = 16000
44
+ BUFFER_SIZE = 512
45
  CHANNELS = 1
46
 
47
+ # Speaker colors
48
  SPEAKER_COLORS = [
49
+ "#FFFF00", # Yellow
50
+ "#FF0000", # Red
51
+ "#00FF00", # Green
52
+ "#00FFFF", # Cyan
53
+ "#FF00FF", # Magenta
54
+ "#0000FF", # Blue
55
+ "#FF8000", # Orange
56
+ "#00FF80", # Spring Green
57
+ "#8000FF", # Purple
58
+ "#FFFFFF", # White
59
  ]
60
 
61
  SPEAKER_COLOR_NAMES = [
62
+ "Yellow", "Red", "Green", "Cyan", "Magenta",
63
+ "Blue", "Orange", "Spring Green", "Purple", "White"
64
  ]
65
 
66
 
 
74
  self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
75
  os.makedirs(self.cache_dir, exist_ok=True)
76
 
77
+ def _download_model(self):
78
+ """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
79
+ model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
80
+ model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
81
+
82
+ if not os.path.exists(model_path):
83
+ print(f"Downloading ECAPA-TDNN model to {model_path}...")
84
+ urllib.request.urlretrieve(model_url, model_path)
85
+
86
+ return model_path
87
+
88
  def load_model(self):
89
  """Load the ECAPA-TDNN model"""
90
  try:
91
  from speechbrain.pretrained import EncoderClassifier
92
 
93
+ model_path = self._download_model()
94
+
95
  self.model = EncoderClassifier.from_hparams(
96
  source="speechbrain/spkrec-ecapa-voxceleb",
97
  savedir=self.cache_dir,
 
99
  )
100
 
101
  self.model_loaded = True
 
102
  return True
103
  except Exception as e:
104
+ print(f"Error loading ECAPA-TDNN model: {e}")
105
  return False
106
 
107
  def embed_utterance(self, audio, sr=16000):
 
111
 
112
  try:
113
  if isinstance(audio, np.ndarray):
114
+ waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
 
 
 
 
115
  else:
116
  waveform = audio.unsqueeze(0)
117
 
 
118
  if sr != 16000:
119
  waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
120
 
 
123
 
124
  return embedding.squeeze().cpu().numpy()
125
  except Exception as e:
126
+ print(f"Error extracting embedding: {e}")
127
  return np.zeros(self.embedding_dim)
128
 
129
 
 
131
  """Processes audio data to extract speaker embeddings"""
132
  def __init__(self, encoder):
133
  self.encoder = encoder
 
 
 
 
 
 
 
 
 
 
 
134
 
135
+ def extract_embedding(self, audio_int16):
 
 
 
 
136
  try:
137
+ float_audio = audio_int16.astype(np.float32) / 32768.0
 
138
 
139
+ if np.abs(float_audio).max() > 1.0:
140
+ float_audio = float_audio / np.abs(float_audio).max()
141
+
142
+ embedding = self.encoder.embed_utterance(float_audio)
 
143
 
 
144
  return embedding
145
  except Exception as e:
146
+ print(f"Embedding extraction error: {e}")
147
+ return np.zeros(self.encoder.embedding_dim)
148
 
149
 
150
  class SpeakerChangeDetector:
151
+ """Speaker change detector that supports a configurable number of speakers"""
152
  def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
153
  self.embedding_dim = embedding_dim
154
  self.change_threshold = change_threshold
155
  self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
156
  self.current_speaker = 0
157
+ self.previous_embeddings = []
 
158
  self.last_change_time = time.time()
159
+ self.mean_embeddings = [None] * self.max_speakers
160
+ self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
161
+ self.last_similarity = 0.0
162
  self.active_speakers = set([0])
 
163
 
164
  def set_max_speakers(self, max_speakers):
165
  """Update the maximum number of speakers"""
166
  new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
167
 
168
  if new_max < self.max_speakers:
 
169
  for speaker_id in list(self.active_speakers):
170
  if speaker_id >= new_max:
171
  self.active_speakers.discard(speaker_id)
 
173
  if self.current_speaker >= new_max:
174
  self.current_speaker = 0
175
 
 
176
  if new_max > self.max_speakers:
177
+ self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
178
  self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
 
179
  else:
180
+ self.mean_embeddings = self.mean_embeddings[:new_max]
181
  self.speaker_embeddings = self.speaker_embeddings[:new_max]
 
182
 
183
  self.max_speakers = new_max
184
 
185
  def set_change_threshold(self, threshold):
186
  """Update the threshold for detecting speaker changes"""
187
+ self.change_threshold = max(0.1, min(threshold, 0.99))
188
 
189
  def add_embedding(self, embedding, timestamp=None):
190
+ """Add a new embedding and check if there's a speaker change"""
191
  current_time = timestamp or time.time()
192
+
193
+ if not self.previous_embeddings:
194
+ self.previous_embeddings.append(embedding)
195
+ self.speaker_embeddings[self.current_speaker].append(embedding)
196
+ if self.mean_embeddings[self.current_speaker] is None:
197
+ self.mean_embeddings[self.current_speaker] = embedding.copy()
198
+ return self.current_speaker, 1.0
199
+
200
+ current_mean = self.mean_embeddings[self.current_speaker]
201
+ if current_mean is not None:
202
+ similarity = 1.0 - cosine(embedding, current_mean)
 
 
203
  else:
204
+ similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
205
 
206
  self.last_similarity = similarity
207
 
 
208
  time_since_last_change = current_time - self.last_change_time
209
+ is_speaker_change = False
210
 
211
+ if time_since_last_change >= MIN_SEGMENT_DURATION:
212
+ if similarity < self.change_threshold:
213
+ best_speaker = self.current_speaker
214
+ best_similarity = similarity
215
+
216
+ for speaker_id in range(self.max_speakers):
217
+ if speaker_id == self.current_speaker:
218
+ continue
219
+
220
+ speaker_mean = self.mean_embeddings[speaker_id]
221
 
222
+ if speaker_mean is not None:
223
+ speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
224
+
225
+ if speaker_similarity > best_similarity:
226
+ best_similarity = speaker_similarity
227
+ best_speaker = speaker_id
228
+
229
+ if best_speaker != self.current_speaker:
230
+ is_speaker_change = True
231
+ self.current_speaker = best_speaker
232
+ elif len(self.active_speakers) < self.max_speakers:
233
+ for new_id in range(self.max_speakers):
234
+ if new_id not in self.active_speakers:
235
+ is_speaker_change = True
236
+ self.current_speaker = new_id
237
+ self.active_speakers.add(new_id)
238
+ break
239
+
240
+ if is_speaker_change:
241
+ self.last_change_time = current_time
 
 
242
 
243
+ self.previous_embeddings.append(embedding)
244
+ if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
245
+ self.previous_embeddings.pop(0)
 
246
 
247
+ self.speaker_embeddings[self.current_speaker].append(embedding)
248
+ self.active_speakers.add(self.current_speaker)
249
+
250
+ if len(self.speaker_embeddings[self.current_speaker]) > 30:
251
+ self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
252
+
253
  if self.speaker_embeddings[self.current_speaker]:
254
+ self.mean_embeddings[self.current_speaker] = np.mean(
255
  self.speaker_embeddings[self.current_speaker], axis=0
256
  )
257
 
 
264
  return "#FFFFFF"
265
 
266
  def get_status_info(self):
267
+ """Return status information about the speaker change detector"""
268
  speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
269
 
270
  return {
 
273
  "active_speakers": len(self.active_speakers),
274
  "max_speakers": self.max_speakers,
275
  "last_similarity": self.last_similarity,
276
+ "threshold": self.change_threshold
 
277
  }
278
 
279
 
 
287
  self.full_sentences = []
288
  self.sentence_speakers = []
289
  self.pending_sentences = []
290
+ self.displayed_text = ""
291
+ self.last_realtime_text = ""
292
  self.is_running = False
293
  self.change_threshold = DEFAULT_CHANGE_THRESHOLD
294
  self.max_speakers = DEFAULT_MAX_SPEAKERS
295
+ self.current_conversation = ""
296
+ self.audio_buffer = []
297
 
298
  def initialize_models(self):
299
  """Initialize the speaker encoder model"""
300
  try:
301
  device_str = "cuda" if torch.cuda.is_available() else "cpu"
302
+ print(f"Using device: {device_str}")
303
 
304
  self.encoder = SpeechBrainEncoder(device=device_str)
305
  success = self.encoder.load_model()
 
311
  change_threshold=self.change_threshold,
312
  max_speakers=self.max_speakers
313
  )
314
+ print("ECAPA-TDNN model loaded successfully!")
315
  return True
316
  else:
317
+ print("Failed to load ECAPA-TDNN model")
318
  return False
319
  except Exception as e:
320
+ print(f"Model initialization error: {e}")
321
  return False
322
 
323
  def live_text_detected(self, text):
324
  """Callback for real-time transcription updates"""
325
+ text = text.strip()
326
+ if text:
327
+ sentence_delimiters = '.?!。'
328
+ prob_sentence_end = (
329
+ len(self.last_realtime_text) > 0
330
+ and text[-1] in sentence_delimiters
331
+ and self.last_realtime_text[-1] in sentence_delimiters
332
+ )
333
+
334
+ self.last_realtime_text = text
335
+
336
+ if prob_sentence_end and FAST_SENTENCE_END:
337
+ self.recorder.stop()
338
+ elif prob_sentence_end:
339
+ self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
340
+ else:
341
+ self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
342
 
343
  def process_final_text(self, text):
344
  """Process final transcribed text with speaker embedding"""
345
  text = text.strip()
346
  if text:
347
  try:
348
+ bytes_data = self.recorder.last_transcription_bytes
349
+ self.sentence_queue.put((text, bytes_data))
350
+ self.pending_sentences.append(text)
 
 
 
 
 
351
  except Exception as e:
352
+ print(f"Error processing final text: {e}")
353
 
354
  def process_sentence_queue(self):
355
  """Process sentences in the queue for speaker detection"""
356
  while self.is_running:
357
  try:
358
+ text, bytes_data = self.sentence_queue.get(timeout=1)
359
 
360
+ # Convert audio data to int16
361
+ audio_int16 = np.frombuffer(bytes_data, dtype=np.int16)
362
 
363
+ # Extract speaker embedding
364
+ speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
365
+
366
+ # Store sentence and embedding
367
+ self.full_sentences.append((text, speaker_embedding))
 
 
 
 
368
 
369
+ # Fill in missing speaker assignments
370
+ while len(self.sentence_speakers) < len(self.full_sentences) - 1:
371
+ self.sentence_speakers.append(0)
372
+
373
+ # Detect speaker changes
374
+ speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
375
+ self.sentence_speakers.append(speaker_id)
376
+
377
+ # Remove from pending
378
+ if text in self.pending_sentences:
379
+ self.pending_sentences.remove(text)
380
+
381
+ # Update conversation display
382
+ self.current_conversation = self.get_formatted_conversation()
383
 
384
  except queue.Empty:
385
  continue
386
  except Exception as e:
387
+ print(f"Error processing sentence: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
 
389
  def start_recording(self):
390
  """Start the recording and transcription process"""
 
392
  return "Please initialize models first!"
393
 
394
  try:
395
+ # Setup recorder configuration for manual audio input
396
  recorder_config = {
397
  'spinner': False,
398
+ 'use_microphone': False, # We'll feed audio manually
399
  'model': FINAL_TRANSCRIPTION_MODEL,
400
  'language': TRANSCRIPTION_LANGUAGE,
401
  'silero_sensitivity': SILERO_SENSITIVITY,
 
405
  'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
406
  'min_gap_between_recordings': 0,
407
  'enable_realtime_transcription': True,
408
+ 'realtime_processing_pause': 0,
409
  'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
410
  'on_realtime_transcription_update': self.live_text_detected,
411
  'beam_size': FINAL_BEAM_SIZE,
412
  'beam_size_realtime': REALTIME_BEAM_SIZE,
413
+ 'buffer_size': BUFFER_SIZE,
414
  'sample_rate': SAMPLE_RATE,
415
  }
416
 
417
  self.recorder = AudioToTextRecorder(**recorder_config)
418
 
419
+ # Start sentence processing thread
420
  self.is_running = True
421
  self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
422
  self.sentence_thread.start()
423
 
424
+ # Start transcription thread
425
  self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
426
  self.transcription_thread.start()
427
 
428
+ return "Recording started successfully! FastRTC audio input ready."
429
 
430
  except Exception as e:
 
431
  return f"Error starting recording: {e}"
432
 
433
  def run_transcription(self):
 
436
  while self.is_running:
437
  self.recorder.text(self.process_final_text)
438
  except Exception as e:
439
+ print(f"Transcription error: {e}")
440
 
441
  def stop_recording(self):
442
  """Stop the recording process"""
 
447
 
448
  def clear_conversation(self):
449
  """Clear all conversation data"""
450
+ self.full_sentences = []
451
+ self.sentence_speakers = []
452
+ self.pending_sentences = []
453
+ self.displayed_text = ""
454
+ self.last_realtime_text = ""
455
+ self.current_conversation = "Conversation cleared!"
456
 
457
  if self.speaker_detector:
458
  self.speaker_detector = SpeakerChangeDetector(
 
475
  return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
476
 
477
  def get_formatted_conversation(self):
478
+ """Get the formatted conversation with speaker colors"""
479
+ try:
480
+ sentences_with_style = []
481
+
482
+ # Process completed sentences
483
+ for i, sentence in enumerate(self.full_sentences):
484
+ sentence_text, _ = sentence
485
+ if i >= len(self.sentence_speakers):
486
+ color = "#FFFFFF"
487
+ speaker_name = "Unknown"
488
+ else:
489
+ speaker_id = self.sentence_speakers[i]
490
+ color = self.speaker_detector.get_color_for_speaker(speaker_id)
491
+ speaker_name = f"Speaker {speaker_id + 1}"
492
+
493
+ sentences_with_style.append(
494
+ f'<span style="color:{color};"><b>{speaker_name}:</b> {sentence_text}</span>')
495
+
496
+ # Add pending sentences
497
+ for pending_sentence in self.pending_sentences:
498
+ sentences_with_style.append(
499
+ f'<span style="color:#60FFFF;"><b>Processing:</b> {pending_sentence}</span>')
500
+
501
+ if sentences_with_style:
502
+ return "<br><br>".join(sentences_with_style)
503
+ else:
504
+ return "Waiting for speech input..."
505
+
506
+ except Exception as e:
507
+ return f"Error formatting conversation: {e}"
508
 
509
  def get_status_info(self):
510
  """Get current status information"""
 
520
  f"**Last Similarity:** {status['last_similarity']:.3f}",
521
  f"**Change Threshold:** {status['threshold']:.2f}",
522
  f"**Total Sentences:** {len(self.full_sentences)}",
 
523
  "",
524
+ "**Speaker Segment Counts:**"
525
  ]
526
 
527
  for i in range(status['max_speakers']):
528
  color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
529
+ status_lines.append(f"Speaker {i+1} ({color_name}): {status['speaker_counts'][i]}")
 
 
530
 
531
  return "\n".join(status_lines)
532
 
533
  except Exception as e:
534
  return f"Error getting status: {e}"
535
 
536
+ def feed_audio_data(self, audio_data):
537
+ """Feed audio data to the recorder"""
538
+ if not self.is_running or not self.recorder:
539
+ return
540
+
541
+ try:
542
+ # Ensure audio is in the correct format (16-bit PCM)
543
+ if isinstance(audio_data, np.ndarray):
544
+ if audio_data.dtype != np.int16:
545
+ # Convert float to int16
546
+ if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
547
+ audio_data = (audio_data * 32767).astype(np.int16)
548
+ else:
549
+ audio_data = audio_data.astype(np.int16)
550
+
551
+ # Convert to bytes
552
+ audio_bytes = audio_data.tobytes()
553
+ else:
554
+ audio_bytes = audio_data
555
+
556
+ # Feed to recorder
557
+ self.recorder.feed_audio(audio_bytes)
558
+
559
+ except Exception as e:
560
+ print(f"Error feeding audio data: {e}")
561
+
562
  def process_audio_chunk(self, audio_data, sample_rate=16000):
563
  """Process audio chunk from FastRTC input"""
564
+ if not self.is_running or self.recorder is None:
565
  return
566
 
567
  try:
568
+ # Convert float audio to int16 for the recorder
569
+ if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
570
+ if np.max(np.abs(audio_data)) <= 1.0:
571
+ # Float audio is normalized to [-1, 1], convert to int16
572
+ audio_int16 = (audio_data * 32767).astype(np.int16)
573
+ else:
574
+ # Audio is already in higher range
575
+ audio_int16 = audio_data.astype(np.int16)
576
  else:
577
+ audio_int16 = audio_data
578
+
579
+ # Ensure correct shape (1, N) for the recorder
580
+ if len(audio_int16.shape) == 1:
581
+ audio_int16 = np.expand_dims(audio_int16, 0)
582
+
583
+ # Resample if needed
584
+ if sample_rate != SAMPLE_RATE:
585
+ audio_int16 = self._resample_audio(audio_int16, sample_rate, SAMPLE_RATE)
586
+
587
+ # Convert to bytes for feeding to recorder
588
+ audio_bytes = audio_int16.tobytes()
589
 
590
+ # Feed to recorder
591
+ self.feed_audio_data(audio_bytes)
 
592
 
593
+ except Exception as e:
594
+ print(f"Error processing audio chunk: {e}")
 
595
 
596
+ def _resample_audio(self, audio, orig_sr, target_sr):
597
+ """Resample audio to target sample rate"""
598
+ try:
599
+ import scipy.signal
600
 
601
+ # Get the resampling ratio
602
+ ratio = target_sr / orig_sr
603
+
604
+ # Calculate the new length
605
+ new_length = int(len(audio[0]) * ratio)
606
+
607
+ # Resample the audio
608
+ resampled = scipy.signal.resample(audio[0], new_length)
609
+
610
+ # Return in the same shape format
611
+ return np.expand_dims(resampled, 0)
612
  except Exception as e:
613
+ print(f"Error resampling audio: {e}")
614
+ return audio
615
+
616
 
617
+ # FastRTC Audio Handler for Real-time Diarization
618
 
 
619
  class DiarizationHandler(AsyncStreamHandler):
620
  def __init__(self, diarization_system):
621
  super().__init__()
622
  self.diarization_system = diarization_system
623
+ self.audio_queue = Queue()
624
+ self.is_processing = False
625
+ self.sample_rate = 16000 # Default sample rate
626
 
627
  def copy(self):
628
  """Return a fresh handler for each new stream connection"""
629
  return DiarizationHandler(self.diarization_system)
630
 
631
  async def emit(self):
632
+ """Not used in this implementation - we only receive audio"""
633
  return None
634
 
635
  async def receive(self, frame):
636
+ """Receive audio data from FastRTC and process it"""
637
  try:
638
  if not self.diarization_system.is_running:
639
  return
640
 
641
+ # Extract audio data from frame
642
+ if hasattr(frame, 'data') and frame.data is not None:
643
+ audio_data = frame.data
644
+ elif hasattr(frame, 'audio') and frame.audio is not None:
645
+ audio_data = frame.audio
646
+ else:
647
+ audio_data = frame
648
 
649
+ # Convert to numpy array if needed
650
  if isinstance(audio_data, bytes):
651
+ # Convert bytes to numpy array (assuming 16-bit PCM)
652
+ audio_array = np.frombuffer(audio_data, dtype=np.int16)
653
+ # Normalize to float32 range [-1, 1]
654
+ audio_array = audio_array.astype(np.float32) / 32768.0
655
  elif isinstance(audio_data, (list, tuple)):
656
  audio_array = np.array(audio_data, dtype=np.float32)
657
+ elif isinstance(audio_data, np.ndarray):
658
+ audio_array = audio_data.astype(np.float32)
659
  else:
660
+ print(f"Unknown audio data type: {type(audio_data)}")
661
+ return
662
 
663
+ # Ensure mono audio
664
+ if len(audio_array.shape) > 1 and audio_array.shape[1] > 1:
665
+ audio_array = np.mean(audio_array, axis=1)
666
+
667
+ # Ensure 1D array
668
  if len(audio_array.shape) > 1:
669
  audio_array = audio_array.flatten()
670
 
671
+ # Get sample rate from frame if available
672
+ sample_rate = getattr(frame, 'sample_rate', self.sample_rate)
673
 
674
+ # Process audio asynchronously to avoid blocking
675
+ await self.process_audio_async(audio_array, sample_rate)
 
 
 
 
 
676
 
677
  except Exception as e:
678
+ print(f"Error in FastRTC audio receive: {e}")
679
+ import traceback
680
+ traceback.print_exc()
681
 
682
+ async def process_audio_async(self, audio_data, sample_rate=16000):
683
  """Process audio data asynchronously"""
684
  try:
685
+ # Run the audio processing in a thread pool to avoid blocking
686
  loop = asyncio.get_event_loop()
687
  await loop.run_in_executor(
688
  None,
689
  self.diarization_system.process_audio_chunk,
690
  audio_data,
691
+ sample_rate
692
  )
693
  except Exception as e:
694
+ print(f"Error in async audio processing: {e}")
695
+
696
+ async def start_up(self) -> None:
697
+ """Initialize any resources when the stream starts"""
698
+ print("FastRTC stream started")
699
+ self.is_processing = True
700
+
701
+ async def shutdown(self) -> None:
702
+ """Clean up any resources when the stream ends"""
703
+ print("FastRTC stream shutting down")
704
+ self.is_processing = False
705
 
706
 
707
  # Global instances
708
+ diarization_system = None # Will be initialized when RealtimeSpeakerDiarization is available
709
  audio_handler = None
710
 
711
+
712
  def initialize_system():
713
  """Initialize the diarization system"""
714
+ global audio_handler, diarization_system
715
  try:
716
+ if diarization_system is None:
717
+ print("Error: RealtimeSpeakerDiarization not initialized")
718
+ return "❌ Diarization system not available. Please ensure RealtimeSpeakerDiarization is properly imported."
719
+
720
  success = diarization_system.initialize_models()
721
  if success:
722
+ audio_handler = DiarizationHandler(diarization_system)
723
+ return "✅ System initialized successfully! Models loaded and FastRTC handler ready."
 
724
  else:
725
+ return "❌ Failed to initialize system. Please check the logs."
726
  except Exception as e:
727
+ print(f"Initialization error: {e}")
728
  return f"❌ Initialization error: {str(e)}"
729
 
730
+
731
  def start_recording():
732
  """Start recording and transcription"""
733
  try:
734
+ if diarization_system is None:
735
+ return "❌ System not initialized"
736
  result = diarization_system.start_recording()
737
+ return f"🎙️ {result} - FastRTC audio streaming is active."
738
  except Exception as e:
739
  return f"❌ Failed to start recording: {str(e)}"
740
 
 
 
 
741
 
742
  def stop_recording():
743
  """Stop recording and transcription"""
744
  try:
745
+ if diarization_system is None:
746
+ return "❌ System not initialized"
747
  result = diarization_system.stop_recording()
748
  return f"⏹️ {result}"
749
  except Exception as e:
750
  return f"❌ Failed to stop recording: {str(e)}"
751
 
752
+
753
  def clear_conversation():
754
  """Clear the conversation"""
755
  try:
756
+ if diarization_system is None:
757
+ return "❌ System not initialized"
758
  result = diarization_system.clear_conversation()
759
  return f"🗑️ {result}"
760
  except Exception as e:
761
  return f"❌ Failed to clear conversation: {str(e)}"
762
 
763
+
764
  def update_settings(threshold, max_speakers):
765
  """Update system settings"""
766
  try:
767
+ if diarization_system is None:
768
+ return "❌ System not initialized"
769
  result = diarization_system.update_settings(threshold, max_speakers)
770
  return f"⚙️ {result}"
771
  except Exception as e:
772
  return f"❌ Failed to update settings: {str(e)}"
773
 
774
+
775
  def get_conversation():
776
  """Get the current conversation"""
777
  try:
778
+ if diarization_system is None:
779
+ return "<i>System not initialized</i>"
780
  return diarization_system.get_formatted_conversation()
781
  except Exception as e:
782
  return f"<i>Error getting conversation: {str(e)}</i>"
783
 
784
+
785
 def get_status():
     """Get system status"""
     try:
+        if diarization_system is None:
+            return "System not initialized"
         return diarization_system.get_status_info()
     except Exception as e:
         return f"Error getting status: {str(e)}"

+
 # Create Gradio interface
 def create_interface():
     with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
         gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
+        gr.Markdown("This app performs real-time speech recognition with automatic speaker identification, using FastRTC for low-latency audio streaming.")

         with gr.Row():
             with gr.Column(scale=2):
+                # Main conversation display
                 conversation_output = gr.HTML(
+                    value="<div style='padding: 20px; background: #f5f5f5; border-radius: 10px;'><i>Click 'Initialize System' to start...</i></div>",
+                    label="Live Conversation",
+                    elem_id="conversation_display"
                 )

                 # Control buttons
                 with gr.Row():
                     init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
+                    start_btn = gr.Button("🎙️ Start Recording", variant="primary", size="lg", interactive=False)
+                    stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", size="lg", interactive=False)
                     clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)

+                # FastRTC stream interface
+                with gr.Row():
+                    gr.HTML("""
+                    <div id="fastrtc-container" style="border: 2px solid #ddd; border-radius: 10px; padding: 20px; margin: 10px 0;">
+                        <h3>🎵 Audio Stream</h3>
+                        <p>The FastRTC audio stream will appear here when recording starts.</p>
+                        <div id="stream-status" style="padding: 10px; background: #f8f9fa; border-radius: 5px; margin-top: 10px;">
+                            Status: Waiting for initialization...
+                        </div>
+                    </div>
+                    """)
+
                 # Status display
                 status_output = gr.Textbox(
                     label="System Status",
+                    value="System not initialized. Please click 'Initialize System' to begin.",
+                    lines=6,
+                    interactive=False,
+                    show_copy_button=True
                 )

             with gr.Column(scale=1):
+                # Settings panel
                 gr.Markdown("## ⚙️ Settings")

                 threshold_slider = gr.Slider(
+                    minimum=0.1,
+                    maximum=0.95,
                     step=0.05,
+                    value=DEFAULT_CHANGE_THRESHOLD,
                     label="Speaker Change Sensitivity",
+                    info="Lower = more sensitive to speaker changes"
                 )

                 max_speakers_slider = gr.Slider(
                     minimum=2,
+                    maximum=ABSOLUTE_MAX_SPEAKERS,
                     step=1,
+                    value=DEFAULT_MAX_SPEAKERS,
+                    label="Maximum Number of Speakers"
                 )

+                update_settings_btn = gr.Button("Update Settings", variant="secondary")
+
+                # Audio settings
+                gr.Markdown("## 🔊 Audio Configuration")
+                with gr.Accordion("Advanced Audio Settings", open=False):
+                    gr.Markdown("""
+                    **Current Configuration:**
+                    - Sample Rate: 16 kHz
+                    - Audio Format: 16-bit PCM → Float32 (via AudioProcessor)
+                    - Channels: Mono (stereo converted automatically)
+                    - Buffer Size: 1024 samples for real-time processing
+                    - Processing: Uses the existing AudioProcessor.extract_embedding()
+                    """)

                 # Instructions
+                gr.Markdown("## 📝 How to Use")
                 gr.Markdown("""
+                1. **Initialize**: Click "Initialize System" to load AI models
+                2. **Start**: Click "Start Recording" to begin processing
+                3. **Connect**: The FastRTC stream will activate automatically
+                4. **Allow Access**: Grant microphone permissions when prompted
+                5. **Speak**: Talk naturally into your microphone
+                6. **Monitor**: Watch real-time transcription with speaker colors
                 """)
+
+                # Performance tips
+                with gr.Accordion("💡 Performance Tips", open=False):
+                    gr.Markdown("""
+                    - Use Chrome/Edge for best FastRTC performance
+                    - Ensure a stable internet connection
+                    - Use headphones to prevent echo
+                    - Position the microphone 6-12 inches away
+                    - Minimize background noise
+                    - Allow browser microphone access
+                    """)
+
+                # Speaker color legend, built from the module-level
+                # SPEAKER_COLORS / SPEAKER_COLOR_NAMES constants instead of a duplicated list
+                gr.Markdown("## 🎨 Speaker Colors")
+                color_html = ""
+                for i, (color, name) in enumerate(zip(SPEAKER_COLORS[:DEFAULT_MAX_SPEAKERS], SPEAKER_COLOR_NAMES[:DEFAULT_MAX_SPEAKERS])):
+                    color_html += f'<div style="margin: 3px 0;"><span style="color:{color}; font-size: 16px; font-weight: bold;">●</span> Speaker {i+1} ({name})</div>'
+
+                gr.HTML(f"<div style='font-size: 14px;'>{color_html}</div>")
+
+        # Auto-refresh conversation and status
+        def refresh_display():
+            try:
+                conversation = get_conversation()
+                status = get_status()
+                return conversation, status
+            except Exception as e:
+                error_msg = f"Error refreshing display: {str(e)}"
+                return f"<i>{error_msg}</i>", error_msg

         # Event handlers
         def on_initialize():
+            try:
+                result = initialize_system()
+                success = "successfully" in result.lower()
+
+                conversation, status = refresh_display()
+
+                return (
+                    result,                          # status_output
+                    gr.update(interactive=success),  # start_btn
+                    gr.update(interactive=success),  # clear_btn
+                    conversation,                    # conversation_output
+                )
+            except Exception as e:
+                error_msg = f"❌ Initialization failed: {str(e)}"
+                return (
+                    error_msg,
+                    gr.update(interactive=False),
+                    gr.update(interactive=False),
+                    "<i>System not ready</i>",
+                )

         def on_start():
+            try:
+                result = start_recording()
+                return (
+                    result,                        # status_output
+                    gr.update(interactive=False),  # start_btn
+                    gr.update(interactive=True),   # stop_btn
+                )
+            except Exception as e:
+                error_msg = f"❌ Failed to start: {str(e)}"
+                return (
+                    error_msg,
+                    gr.update(interactive=True),
+                    gr.update(interactive=False),
+                )

         def on_stop():
+            try:
+                result = stop_recording()
+                return (
+                    result,                        # status_output
+                    gr.update(interactive=True),   # start_btn
+                    gr.update(interactive=False),  # stop_btn
+                )
+            except Exception as e:
+                error_msg = f"❌ Failed to stop: {str(e)}"
+                return (
+                    error_msg,
+                    gr.update(interactive=False),
+                    gr.update(interactive=True),
+                )

         def on_clear():
+            try:
+                result = clear_conversation()
+                conversation, status = refresh_display()
+                return result, conversation
+            except Exception as e:
+                error_msg = f"❌ Failed to clear: {str(e)}"
+                return error_msg, "<i>Error clearing conversation</i>"

         def on_update_settings(threshold, max_speakers):
+            try:
+                result = update_settings(threshold, max_speakers)
+                return result
+            except Exception as e:
+                return f"❌ Failed to update settings: {str(e)}"

+        # Connect event handlers
         init_btn.click(
+            on_initialize,
+            outputs=[status_output, start_btn, clear_btn, conversation_output]
         )

         start_btn.click(
+            on_start,
             outputs=[status_output, start_btn, stop_btn]
         )

         stop_btn.click(
+            on_stop,
             outputs=[status_output, start_btn, stop_btn]
         )

         clear_btn.click(
+            on_clear,
+            outputs=[status_output, conversation_output]
         )

+        update_settings_btn.click(
+            on_update_settings,
             inputs=[threshold_slider, max_speakers_slider],
             outputs=[status_output]
         )

+        # Auto-refresh every 2 seconds when active
+        refresh_timer = gr.Timer(2.0)
+        refresh_timer.tick(
+            refresh_display,
+            outputs=[conversation_output, status_output]
         )
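+        # Note: gr.Timer is only available in recent Gradio releases. On older
+        # versions, a comparable polling setup (a sketch, not the author's
+        # code) is the `every=` argument on a load event:
+        #   interface.load(refresh_display, outputs=[conversation_output, status_output], every=2)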
     return interface


+# FastAPI setup for API endpoints
+def create_fastapi_app():
+    """Create FastAPI app with API endpoints"""
+    app = FastAPI(
+        title="Real-time Speaker Diarization",
+        description="Real-time speech recognition with speaker diarization using FastRTC",
+        version="1.0.0"
+    )
+
+    # API Routes
+    router = APIRouter()
+
+    @router.get("/health")
+    async def health_check():
+        """Health check endpoint"""
+        return {
+            "status": "healthy",
+            "timestamp": time.time(),
+            "system_initialized": diarization_system is not None and hasattr(diarization_system, 'encoder') and diarization_system.encoder is not None,
+            "recording_active": diarization_system.is_running if diarization_system and hasattr(diarization_system, 'is_running') else False
+        }
+
+    @router.get("/api/conversation")
+    async def get_conversation_api():
+        """Get current conversation"""
+        try:
+            return {
+                "conversation": get_conversation(),
+                "status": get_status(),
+                "is_recording": diarization_system.is_running if diarization_system and hasattr(diarization_system, 'is_running') else False,
+                "timestamp": time.time()
+            }
+        except Exception as e:
+            return {"error": str(e), "timestamp": time.time()}
+
+    @router.post("/api/control/{action}")
+    async def control_recording(action: str):
+        """Control recording actions"""
+        try:
+            if action == "start":
+                result = start_recording()
+            elif action == "stop":
+                result = stop_recording()
+            elif action == "clear":
+                result = clear_conversation()
+            elif action == "initialize":
+                result = initialize_system()
+            else:
+                return {"error": "Invalid action. Use: start, stop, clear, or initialize"}
+
+            return {
+                "result": result,
+                "is_recording": diarization_system.is_running if diarization_system and hasattr(diarization_system, 'is_running') else False,
+                "timestamp": time.time()
+            }
+        except Exception as e:
+            return {"error": str(e), "timestamp": time.time()}
+
+    app.include_router(router)
+    return app
+
+
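+# Once the app is being served, these endpoints can be exercised from a shell
+# (the localhost:7860 address is an assumption; substitute your Space URL):
+#
+#   curl http://localhost:7860/health
+#   curl http://localhost:7860/api/conversation
+#   curl -X POST http://localhost:7860/api/control/initialize
+#   curl -X POST http://localhost:7860/api/control/start
+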
+# Function to set up the FastRTC stream
+def setup_fastrtc_stream(app):
+    """Set up the FastRTC stream with proper configuration"""
+    try:
+        if audio_handler is None:
+            print("Warning: Audio handler not initialized. Initialize system first.")
+            return None
+
+        # HuggingFace token, read for an optional TURN-server setup
+        # (unused with the STUN-only configuration below)
+        hf_token = os.environ.get("HF_TOKEN")
+
+        # Configure RTC settings
+        rtc_config = {
+            "iceServers": [
+                {"urls": "stun:stun.l.google.com:19302"},
+                {"urls": "stun:stun1.l.google.com:19302"}
+            ]
+        }
+
+        # Create FastRTC stream
+        stream = Stream(
+            handler=audio_handler,
+            rtc_configuration=rtc_config,
+            modality="audio",
+            mode="receive"  # We only receive audio, don't send
+        )
+
+        # Mount the stream's WebRTC routes on the FastAPI app
+        stream.mount(app)
+        print("✅ FastRTC stream configured successfully!")
+        return stream
+
+    except Exception as e:
+        print(f"⚠️ Warning: Failed to setup FastRTC stream: {e}")
+        print("Audio streaming may not work properly.")
         return None
+
+
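+# Behind restrictive NATs/firewalls, STUN alone may fail and a TURN server is
+# needed; the hf_token read above could feed such a setup. A sketch with a
+# hypothetical server (URL and credentials are placeholders, not a real service):
+#
+#   rtc_config = {
+#       "iceServers": [{
+#           "urls": "turn:turn.example.com:3478",  # hypothetical TURN server
+#           "username": "webrtc",
+#           "credential": hf_token,
+#       }]
+#   }
+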
+# Main application setup
+def create_app(diarization_sys=None):
+    """Create the complete application"""
+    global diarization_system, audio_handler
+
+    # Set the diarization system
+    if diarization_sys is not None:
+        diarization_system = diarization_sys
+
+    # Create FastAPI app
+    fastapi_app = create_fastapi_app()
+
+    # Create Gradio interface
+    gradio_interface = create_interface()
+
+    # Mount Gradio on FastAPI
+    app = gr.mount_gradio_app(fastapi_app, gradio_interface, path="/")
+
+    # Set up the FastRTC stream
+    if diarization_system is not None:
+        # Initialize the system if not already done
+        if not hasattr(diarization_system, 'encoder') or diarization_system.encoder is None:
+            diarization_system.initialize_models()
+
+        # Create the audio handler if needed
+        if audio_handler is None:
+            audio_handler = DiarizationHandler(diarization_system)
+
+        # Set up and mount the FastRTC stream
+        setup_fastrtc_stream(app)
+
+    return app, gradio_interface
+
+
+# Entry point for HuggingFace Spaces
+if __name__ == "__main__":
+    try:
+        # RealtimeSpeakerDiarization must be defined earlier in this file (or imported)
+        diarization_system = RealtimeSpeakerDiarization()
+
+        # Create the complete application
+        app, interface = create_app()
+
+        # Serve the FastAPI app (Gradio is mounted at "/") so the API routes
+        # and the FastRTC stream are reachable alongside the UI
+        uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
+
+    except Exception as e:
+        print(f"Failed to launch application: {e}")
+        import traceback
+        traceback.print_exc()
+
+        # Fallback - launch just the Gradio interface
+        try:
+            interface = create_interface()
+            interface.launch(
+                server_name="0.0.0.0",
+                server_port=int(os.environ.get("PORT", 7860)),
+                share=False
+            )
+        except Exception as fallback_error:
+            print(f"Fallback launch also failed: {fallback_error}")
+
+
+# Helper function to initialize the application with your diarization system
+def initialize_with_diarization_system(diarization_sys):
+    """Initialize the application with your diarization system"""
+    global diarization_system
+    diarization_system = diarization_sys
+    return create_app(diarization_sys)
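+
+
+# Example (a sketch): embed the app in an external service by constructing the
+# diarization system yourself, then serving the returned FastAPI app:
+#
+#   system = RealtimeSpeakerDiarization()
+#   app, interface = initialize_with_diarization_system(system)
+#   uvicorn.run(app, host="0.0.0.0", port=7860)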