Saiyaswanth007 committed
Commit 7208f76 · 1 Parent(s): b9dea2c

Fixing real-time audio

Files changed (1)
  1. app.py +183 -124
app.py CHANGED
@@ -11,6 +11,8 @@ import queue
 from collections import deque
 import asyncio
 from typing import Generator, Tuple, List, Optional
+import whisper
+from transformers import pipeline
 
 # Configuration parameters (keeping original models)
 FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
@@ -30,6 +32,7 @@ MIN_SEGMENT_DURATION = 1.0
 DEFAULT_MAX_SPEAKERS = 4
 ABSOLUTE_MAX_SPEAKERS = 10
 SAMPLE_RATE = 16000
+CHUNK_DURATION = 2.0  # Process audio in 2-second chunks
 
 # Speaker labels
 SPEAKER_LABELS = [f"Speaker {i+1}" for i in range(ABSOLUTE_MAX_SPEAKERS)]
@@ -220,35 +223,46 @@ class AudioProcessor:
 
 
 class RealTimeSpeakerDiarization:
-    """Main class for real-time speaker diarization"""
+    """Main class for real-time speaker diarization with FastRTC"""
     def __init__(self, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
+        self.transcription_pipeline = None
         self.change_threshold = change_threshold
        self.max_speakers = max_speakers
         self.transcript_history = []
         self.is_initialized = False
 
-        # Threading components
-        self.audio_queue = queue.Queue()
-        self.processing_thread = None
-        self.running = False
+        # Audio processing
+        self.audio_buffer = deque(maxlen=int(SAMPLE_RATE * 10))  # 10 second buffer
+        self.processing_queue = queue.Queue()
+        self.last_processed_time = 0
+        self.current_transcript = ""
 
-    async def initialize(self):
+    def initialize(self):
         """Initialize the speaker diarization system"""
         if self.is_initialized:
             return True
 
         try:
             device_str = "cuda" if torch.cuda.is_available() else "cpu"
-            print(f"Initializing ECAPA-TDNN model on {device_str}...")
+            print(f"Initializing models on {device_str}...")
 
+            # Initialize speaker encoder
             self.encoder = SpeechBrainEncoder(device=device_str)
             success = self.encoder.load_model()
 
             if not success:
                 return False
+
+            # Initialize transcription pipeline
+            self.transcription_pipeline = pipeline(
+                "automatic-speech-recognition",
+                model=f"openai/whisper-{REALTIME_TRANSCRIPTION_MODEL}",
+                device=0 if torch.cuda.is_available() else -1,
+                return_timestamps=True
+            )
 
             self.audio_processor = AudioProcessor(self.encoder)
             self.speaker_detector = SpeakerChangeDetector(
@@ -274,31 +288,89 @@
         self.speaker_detector.set_change_threshold(change_threshold)
         self.speaker_detector.set_max_speakers(max_speakers)
 
-    def process_audio_segment(self, audio_data: np.ndarray, text: str) -> Tuple[int, str]:
-        """Process an audio segment and return speaker ID and formatted text"""
+    def process_audio_stream(self, audio_chunk, sample_rate):
+        """Process real-time audio stream from FastRTC"""
         if not self.is_initialized:
-            return 0, text
+            return self.get_current_transcript(), "System not initialized"
 
         try:
-            # Extract speaker embedding
-            embedding = self.audio_processor.extract_embedding(audio_data)
+            # Convert to numpy array if needed
+            if hasattr(audio_chunk, 'numpy'):
+                audio_data = audio_chunk.numpy()
+            else:
+                audio_data = np.array(audio_chunk)
+
+            # Handle different audio formats
+            if len(audio_data.shape) > 1:
+                audio_data = audio_data.mean(axis=1)  # Convert to mono
+
+            # Resample if needed
+            if sample_rate != SAMPLE_RATE:
+                audio_data = torchaudio.functional.resample(
+                    torch.tensor(audio_data), sample_rate, SAMPLE_RATE
+                ).numpy()
+
+            # Add to buffer
+            self.audio_buffer.extend(audio_data)
 
-            # Detect speaker
-            speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
+            # Process if we have enough audio
+            current_time = time.time()
+            if (current_time - self.last_processed_time) >= CHUNK_DURATION:
+                self.process_buffered_audio()
+                self.last_processed_time = current_time
 
-            # Format text with speaker label
-            speaker_label = SPEAKER_LABELS[speaker_id]
-            formatted_text = f"{speaker_label}: {text}"
+            return self.get_current_transcript(), f"Processing... Buffer: {len(self.audio_buffer)} samples"
+
+        except Exception as e:
+            error_msg = f"Error processing audio stream: {str(e)}"
+            print(error_msg)
+            return self.get_current_transcript(), error_msg
+
+    def process_buffered_audio(self):
+        """Process buffered audio for transcription and speaker diarization"""
+        if len(self.audio_buffer) < int(SAMPLE_RATE * MIN_LENGTH_OF_RECORDING):
+            return
+
+        try:
+            # Get audio data from buffer
+            audio_data = np.array(list(self.audio_buffer))
 
-            return speaker_id, formatted_text
+            # Transcribe audio
+            if len(audio_data) > 0:
+                result = self.transcription_pipeline(
+                    audio_data,
+                    return_timestamps=True,
+                    generate_kwargs={"language": TRANSCRIPTION_LANGUAGE}
+                )
+
+                transcription = result["text"].strip()
+
+                if transcription and len(transcription) > 0:
+                    # Extract speaker embedding
+                    embedding = self.audio_processor.extract_embedding(audio_data)
+
+                    # Detect speaker
+                    speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
+
+                    # Format text with speaker label
+                    speaker_label = SPEAKER_LABELS[speaker_id]
+                    formatted_text = f"{speaker_label}: {transcription}"
+
+                    # Add to transcript
+                    self.add_to_transcript(formatted_text)
+
+                    print(f"Transcribed: {formatted_text} (Similarity: {similarity:.3f})")
 
+            # Trim the buffer to prevent memory issues (keep the last 3 seconds)
+            if len(self.audio_buffer) > SAMPLE_RATE * 5:
+                self.audio_buffer = deque(list(self.audio_buffer)[-SAMPLE_RATE * 3:], maxlen=int(SAMPLE_RATE * 10))
+
         except Exception as e:
-            print(f"Error processing audio segment: {e}")
-            return 0, f"Speaker 1: {text}"
+            print(f"Error in process_buffered_audio: {e}")
 
-    def get_transcript_history(self):
-        """Get the formatted transcript history"""
-        return "\n".join(self.transcript_history)
+    def get_current_transcript(self):
+        """Get the current transcript"""
+        return "\n".join(self.transcript_history) if self.transcript_history else "Listening..."
 
     def add_to_transcript(self, formatted_text: str):
         """Add formatted text to transcript history"""
@@ -311,82 +383,74 @@
     def clear_transcript(self):
         """Clear transcript history and reset speaker detector"""
         self.transcript_history = []
+        self.audio_buffer.clear()
         if self.speaker_detector:
             self.speaker_detector = SpeakerChangeDetector(
                 embedding_dim=self.encoder.embedding_dim,
                 change_threshold=self.change_threshold,
                 max_speakers=self.max_speakers
             )
+
+    def get_status(self):
+        """Get current system status"""
+        if not self.is_initialized:
+            return "System not initialized"
+
+        if self.speaker_detector:
+            active_speakers = len(self.speaker_detector.active_speakers)
+            current_speaker = self.speaker_detector.current_speaker + 1
+            similarity = self.speaker_detector.last_similarity
+            return f"Active: {active_speakers} speakers | Current: Speaker {current_speaker} | Similarity: {similarity:.3f}"
+
+        return "Ready"
 
 
 # Global instance
 diarization_system = RealTimeSpeakerDiarization()
 
 
-async def initialize_system():
+def initialize_system():
     """Initialize the diarization system"""
-    success = await diarization_system.initialize()
+    success = diarization_system.initialize()
     if success:
         return "✅ Speaker diarization system initialized successfully!"
     else:
         return "❌ Failed to initialize speaker diarization system. Please check your setup."
 
 
-def process_audio_with_transcript(audio_data, sample_rate, transcription_text, change_threshold, max_speakers):
-    """Process audio with transcription for speaker diarization"""
+def process_realtime_audio(audio_stream, change_threshold, max_speakers):
+    """Process real-time audio stream from FastRTC"""
     if not diarization_system.is_initialized:
-        return "Please initialize the system first.", ""
+        return "Please initialize the system first.", "System not ready"
 
-    if audio_data is None or transcription_text.strip() == "":
-        return diarization_system.get_transcript_history(), ""
+    # Update settings
+    diarization_system.update_settings(change_threshold, max_speakers)
 
-    try:
-        # Update settings
-        diarization_system.update_settings(change_threshold, max_speakers)
-
-        # Convert audio to the right format
-        if len(audio_data.shape) > 1:
-            audio_data = audio_data.mean(axis=1)  # Convert to mono
-
-        # Resample if needed
-        if sample_rate != SAMPLE_RATE:
-            audio_data = torchaudio.functional.resample(
-                torch.tensor(audio_data), sample_rate, SAMPLE_RATE
-            ).numpy()
-
-        # Process the audio segment
-        speaker_id, formatted_text = diarization_system.process_audio_segment(audio_data, transcription_text)
-
-        # Add to transcript
-        diarization_system.add_to_transcript(formatted_text)
-
-        # Return updated transcript and current speaker info
-        transcript = diarization_system.get_transcript_history()
-        current_speaker_info = f"Current Speaker: {SPEAKER_LABELS[speaker_id]}"
-
-        return transcript, current_speaker_info
-
-    except Exception as e:
-        error_msg = f"Error processing audio: {str(e)}"
-        return diarization_system.get_transcript_history(), error_msg
+    if audio_stream is None:
+        return diarization_system.get_current_transcript(), diarization_system.get_status()
+
+    # Process the audio stream
+    transcript, status = diarization_system.process_audio_stream(audio_stream, SAMPLE_RATE)
+
+    return transcript, diarization_system.get_status()
 
 
 def clear_conversation():
     """Clear the conversation transcript"""
     diarization_system.clear_transcript()
-    return "", "Conversation cleared."
+    return "Conversation cleared. Listening...", "Ready"
 
 
 def create_gradio_interface():
-    """Create and return the Gradio interface"""
+    """Create and return the Gradio interface with FastRTC"""
     with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
-        gr.Markdown("# 🎙️ Real-time Speaker Diarization with ASR")
-        gr.Markdown("Upload audio with transcription to perform real-time speaker diarization.")
+        gr.Markdown("# 🎙️ Real-time Speaker Diarization with FastRTC")
+        gr.Markdown("Speak into your microphone for real-time speaker diarization and transcription.")
 
         # Initialization section
         with gr.Row():
-            init_btn = gr.Button("🚀 Initialize System", variant="primary")
-            init_status = gr.Textbox(label="Initialization Status", interactive=False)
+            init_btn = gr.Button("🚀 Initialize System", variant="primary", scale=1)
+            init_status = gr.Textbox(label="System Status", interactive=False, scale=2)
 
         # Settings section
         with gr.Row():
@@ -409,35 +473,35 @@
                 info="Maximum number of speakers to detect"
             )
 
-        # Audio input and transcription
+        # FastRTC Audio Input
        with gr.Row():
             with gr.Column():
-                audio_input = gr.Audio(
-                    label="Audio Input",
-                    type="numpy",
-                    format="wav"
+                # FastRTC component for real-time audio
+                audio_input = gr.FastRTC(
+                    audio=True,
+                    video=False,
+                    label="🎤 Real-time Audio Input",
+                    audio_sample_rate=SAMPLE_RATE,
+                    audio_channels=1
                 )
-                transcription_input = gr.Textbox(
-                    label="Transcription Text",
-                    placeholder="Enter the transcription of the audio...",
-                    lines=3
-                )
-                process_btn = gr.Button("🎯 Process Audio", variant="secondary")
+
+                clear_btn = gr.Button("🗑️ Clear Conversation", variant="stop")
 
             with gr.Column():
-                current_speaker = gr.Textbox(
-                    label="Current Speaker",
-                    interactive=False
+                current_status = gr.Textbox(
+                    label="Current Status",
+                    interactive=False,
+                    value="Click Initialize to start"
                 )
-                clear_btn = gr.Button("🗑️ Clear Conversation", variant="stop")
 
         # Output section
         transcript_output = gr.Textbox(
-            label="Live Transcript with Speaker Labels",
+            label="🔴 Live Transcript with Speaker Labels",
             lines=15,
-            max_lines=20,
+            max_lines=25,
             interactive=False,
-            placeholder="Processed transcript will appear here..."
+            value="Click Initialize, then start speaking...",
+            autoscroll=True
        )
 
         # Event handlers
@@ -446,54 +510,49 @@ def create_gradio_interface():
             outputs=[init_status]
         )
 
-        process_btn.click(
-            fn=process_audio_with_transcript,
-            inputs=[
-                audio_input,
-                gr.Number(value=SAMPLE_RATE, visible=False),  # Hidden sample rate
-                transcription_input,
-                change_threshold,
-                max_speakers
-            ],
-            outputs=[transcript_output, current_speaker]
+        # FastRTC stream processing
+        audio_input.stream(
+            fn=process_realtime_audio,
+            inputs=[audio_input, change_threshold, max_speakers],
+            outputs=[transcript_output, current_status],
+            time_limit=30  # Process in 30-second chunks
        )
 
         clear_btn.click(
             fn=clear_conversation,
-            outputs=[transcript_output, current_speaker]
-        )
-
-        # Auto-process when audio and transcription are provided
-        audio_input.change(
-            fn=process_audio_with_transcript,
-            inputs=[
-                audio_input,
-                gr.Number(value=SAMPLE_RATE, visible=False),
-                transcription_input,
-                change_threshold,
-                max_speakers
-            ],
-            outputs=[transcript_output, current_speaker]
+            outputs=[transcript_output, current_status]
        )
 
         # Instructions
-        gr.Markdown("""
-        ## Instructions:
-        1. **Initialize**: Click "Initialize System" to load the speaker diarization models
-        2. **Upload Audio**: Upload an audio file (WAV format recommended)
-        3. **Add Transcription**: Enter the transcription text for the audio
-        4. **Adjust Settings**:
-           - **Speaker Change Threshold**: Lower values detect speaker changes more easily
-           - **Max Speakers**: Set the maximum number of speakers you expect
-        5. **Process**: Click "Process Audio" or the system will auto-process
-        6. **View Results**: See the transcript with speaker labels (Speaker 1, Speaker 2, etc.)
-
-        ## Tips:
-        - For similar-sounding speakers, increase the threshold (0.6-0.8)
-        - For different-sounding speakers, lower threshold works better (0.3-0.5)
-        - The system maintains speaker consistency across the conversation
-        - Use "Clear Conversation" to reset the speaker memory
-        """)
+        with gr.Accordion("📋 Instructions", open=False):
+            gr.Markdown("""
+            ## How to Use:
+
+            1. **Initialize**: Click "🚀 Initialize System" to load the AI models (this may take a moment)
+            2. **Allow Microphone**: Your browser will ask for microphone permission - please allow it
+            3. **Adjust Settings**:
+               - **Speaker Change Threshold**:
+                 - Lower (0.3-0.5) for speakers with different voices
+                 - Higher (0.6-0.8) for speakers with similar voices
+               - **Max Speakers**: Set expected number of speakers (2-10)
+            4. **Start Speaking**: The system will automatically transcribe and identify speakers
+            5. **View Results**: See real-time transcript with speaker labels (Speaker 1, Speaker 2, etc.)
+            6. **Clear**: Use "Clear Conversation" to reset and start fresh
+
+            ## Features:
+            - ✅ Real-time audio processing via FastRTC
+            - ✅ Automatic speech recognition with Whisper
+            - ✅ Speaker diarization with ECAPA-TDNN
+            - ✅ Live transcript with speaker labels
+            - ✅ Configurable sensitivity settings
+            - ✅ Support for up to 10 speakers
+
+            ## Tips:
+            - Speak clearly and allow brief pauses between speakers
+            - The system learns speaker characteristics over time
+            - Better results with distinct speaker voices
+            - Ensure good microphone quality for best performance
+            """)
 
     return demo
 
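A few notes on this change follow; the code sketches are illustrative, not part of the commit. First, the new imports route transcription through `transformers.pipeline`, so `import whisper` is never used after this change and could be dropped. Second, `f"openai/whisper-{REALTIME_TRANSCRIPTION_MODEL}"` only resolves to a real checkpoint if that constant (defined outside this diff) is a stock size such as `base` or `small`. A minimal sketch of the ASR path under that assumption:

```python
# Sketch only: "base" and "english" stand in for REALTIME_TRANSCRIPTION_MODEL
# and TRANSCRIPTION_LANGUAGE, which are defined elsewhere in app.py.
import numpy as np
import torch
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base",
    device=0 if torch.cuda.is_available() else -1,
    return_timestamps=True,
)

# A raw numpy array is interpreted as mono float32 at the model's rate (16 kHz).
audio = np.zeros(16000, dtype=np.float32)  # one second of silence
result = asr(audio, generate_kwargs={"language": "english"})
print(result["text"])
```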
 
 
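The `CHUNK_DURATION` gate in `process_audio_stream()` compares wall-clock time, not buffered samples, so the buffered path can be exercised without a microphone by feeding chunks with real pauses between them. A hypothetical smoke test, assuming the models load on your machine:

```python
import time
import numpy as np

SAMPLE_RATE = 16000

assert diarization_system.initialize()  # loads the encoder and the ASR pipeline

chunk = np.zeros(int(SAMPLE_RATE * 0.5), dtype=np.float32)  # 0.5 s of silence
for _ in range(6):
    transcript, status = diarization_system.process_audio_stream(chunk, SAMPLE_RATE)
    time.sleep(0.5)  # let the wall-clock CHUNK_DURATION gate elapse
print(status)  # e.g. "Processing... Buffer: 48000 samples"
```

Note that `process_realtime_audio()` discards the `status` returned by `process_audio_stream()` and reports `get_status()` instead, so stream errors surface only on stdout.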
 
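One flag for a follow-up commit: `gr.FastRTC` is not a component in the core `gradio` package, so `create_gradio_interface()` will raise `AttributeError` as written; WebRTC components for Gradio normally come from the separate `fastrtc` (formerly `gradio_webrtc`) library. A hedged sketch of equivalent wiring using Gradio's built-in streaming `gr.Audio` instead; note that a streaming `gr.Audio` delivers `(sample_rate, numpy_array)` tuples, so `process_realtime_audio` would need to unpack them rather than assume a bare array:

```python
# Sketch under the assumption that plain Gradio streaming is acceptable in
# place of FastRTC; slider defaults are illustrative.
import gradio as gr

with gr.Blocks(title="Real-time Speaker Diarization") as demo:
    change_threshold = gr.Slider(0.1, 0.95, value=0.65, step=0.05,
                                 label="Speaker Change Threshold")
    max_speakers = gr.Slider(2, 10, value=4, step=1, label="Max Speakers")
    audio_input = gr.Audio(sources=["microphone"], streaming=True,
                           label="🎤 Real-time Audio Input")
    transcript_output = gr.Textbox(lines=15, interactive=False,
                                   label="🔴 Live Transcript with Speaker Labels")
    current_status = gr.Textbox(interactive=False, label="Current Status")

    # Each streamed chunk arrives as (sample_rate, np.ndarray).
    audio_input.stream(
        fn=process_realtime_audio,  # from app.py; would need to unpack the tuple
        inputs=[audio_input, change_threshold, max_speakers],
        outputs=[transcript_output, current_status],
    )

demo.launch()
```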