Saiyaswanth007 committed on
Commit 78869ff · 1 Parent(s): b3cc831

Code fixing

Files changed (1)
  1. app.py +204 -176
app.py CHANGED
@@ -10,16 +10,12 @@ import torchaudio
 from scipy.spatial.distance import cosine
 from RealtimeSTT import AudioToTextRecorder
 from fastapi import FastAPI
+from fastrtc import Stream, AsyncStreamHandler, ReplyOnPause, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials
 import json
 import io
 import wave
 import asyncio
 import uvicorn
-import logging
-
-# Configure logging to reduce noise
-logging.getLogger("uvicorn").setLevel(logging.WARNING)
-logging.getLogger("gradio").setLevel(logging.WARNING)
 
 # Simplified configuration parameters
 SILENCE_THRESHS = [0, 0.4]
@@ -76,15 +72,23 @@ class SpeechBrainEncoder:
         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
         os.makedirs(self.cache_dir, exist_ok=True)
 
+    def _download_model(self):
+        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
+        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
+        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
+
+        if not os.path.exists(model_path):
+            print(f"Downloading ECAPA-TDNN model to {model_path}...")
+            urllib.request.urlretrieve(model_url, model_path)
+
+        return model_path
+
     def load_model(self):
-        """Load the ECAPA-TDNN model with error handling"""
+        """Load the ECAPA-TDNN model"""
         try:
-            # Try to import speechbrain
-            try:
-                from speechbrain.pretrained import EncoderClassifier
-            except ImportError:
-                print("SpeechBrain not available. Using fallback embedding model.")
-                return self._load_fallback_model()
+            from speechbrain.pretrained import EncoderClassifier
+
+            model_path = self._download_model()
 
             self.model = EncoderClassifier.from_hparams(
                 source="speechbrain/spkrec-ecapa-voxceleb",
@@ -93,17 +97,10 @@ class SpeechBrainEncoder:
             )
 
             self.model_loaded = True
-            print("ECAPA-TDNN model loaded successfully!")
             return True
         except Exception as e:
             print(f"Error loading ECAPA-TDNN model: {e}")
-            return self._load_fallback_model()
-
-    def _load_fallback_model(self):
-        """Fallback to a simple embedding model if SpeechBrain is not available"""
-        print("Using fallback embedding model (simple spectral features)")
-        self.model_loaded = True
-        return True
+            return False
 
     def embed_utterance(self, audio, sr=16000):
         """Extract speaker embedding from audio"""
@@ -111,48 +108,21 @@ class SpeechBrainEncoder:
             raise ValueError("Model not loaded. Call load_model() first.")
 
         try:
-            if self.model is not None:
-                # Use SpeechBrain model
-                if isinstance(audio, np.ndarray):
-                    waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
-                else:
-                    waveform = audio.unsqueeze(0)
-
-                if sr != 16000:
-                    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
-
-                with torch.no_grad():
-                    embedding = self.model.encode_batch(waveform)
-
-                return embedding.squeeze().cpu().numpy()
+            if isinstance(audio, np.ndarray):
+                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
             else:
-                # Use fallback method - simple spectral features
-                return self._extract_simple_features(audio)
-        except Exception as e:
-            print(f"Error extracting embedding: {e}")
-            return self._extract_simple_features(audio)
-
-    def _extract_simple_features(self, audio):
-        """Simple fallback feature extraction"""
-        try:
-            # Ensure audio is numpy array
-            if isinstance(audio, torch.Tensor):
-                audio = audio.numpy()
+                waveform = audio.unsqueeze(0)
 
-            # Basic spectral features as a fallback
-            fft = np.fft.fft(audio)
-            magnitude = np.abs(fft)
+            if sr != 16000:
+                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
 
-            # Take first 192 features to match expected embedding dimension
-            features = magnitude[:self.embedding_dim] if len(magnitude) >= self.embedding_dim else np.pad(magnitude, (0, self.embedding_dim - len(magnitude)))
-
-            # Normalize
-            features = features / (np.linalg.norm(features) + 1e-8)
-
-            return features.astype(np.float32)
+            with torch.no_grad():
+                embedding = self.model.encode_batch(waveform)
+
+            return embedding.squeeze().cpu().numpy()
         except Exception as e:
-            print(f"Error in fallback feature extraction: {e}")
-            return np.random.randn(self.embedding_dim).astype(np.float32)
+            print(f"Error extracting embedding: {e}")
+            return np.zeros(self.embedding_dim)
 
 
 class AudioProcessor:
@@ -321,7 +291,6 @@ class RealtimeSpeakerDiarization:
         self.change_threshold = DEFAULT_CHANGE_THRESHOLD
         self.max_speakers = DEFAULT_MAX_SPEAKERS
         self.current_conversation = ""
-        self.audio_buffer = []
 
     def initialize_models(self):
        """Initialize the speaker encoder model"""
@@ -339,10 +308,10 @@
                     change_threshold=self.change_threshold,
                     max_speakers=self.max_speakers
                 )
-                print("Speaker diarization model loaded successfully!")
+                print("ECAPA-TDNN model loaded successfully!")
                 return True
             else:
-                print("Failed to load speaker diarization model")
+                print("Failed to load ECAPA-TDNN model")
                 return False
         except Exception as e:
             print(f"Model initialization error: {e}")
@@ -362,31 +331,19 @@
         self.last_realtime_text = text
 
         if prob_sentence_end and FAST_SENTENCE_END:
-            if self.recorder:
-                self.recorder.stop()
+            self.recorder.stop()
         elif prob_sentence_end:
-            if self.recorder:
-                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
+            self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
         else:
-            if self.recorder:
-                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
+            self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
 
     def process_final_text(self, text):
         """Process final transcribed text with speaker embedding"""
         text = text.strip()
         if text:
             try:
-                if self.recorder and hasattr(self.recorder, 'last_transcription_bytes'):
-                    bytes_data = self.recorder.last_transcription_bytes
-                    self.sentence_queue.put((text, bytes_data))
-                else:
-                    # Use audio buffer as fallback
-                    if self.audio_buffer:
-                        audio_data = np.concatenate(self.audio_buffer)
-                        bytes_data = audio_data.tobytes()
-                        self.sentence_queue.put((text, bytes_data))
-                        self.audio_buffer = []  # Clear buffer after use
-
+                bytes_data = self.recorder.last_transcription_bytes
+                self.sentence_queue.put((text, bytes_data))
                 self.pending_sentences.append(text)
             except Exception as e:
                 print(f"Error processing final text: {e}")
@@ -432,51 +389,40 @@
             return "Please initialize models first!"
 
         try:
-            # Check if RealtimeSTT is available
-            try:
-                from RealtimeSTT import AudioToTextRecorder
-                recorder_available = True
-            except ImportError:
-                print("RealtimeSTT not available. Using simulated audio processing.")
-                recorder_available = False
-
-            if recorder_available:
-                # Setup recorder configuration
-                recorder_config = {
-                    'spinner': False,
-                    'use_microphone': True,
-                    'model': FINAL_TRANSCRIPTION_MODEL,
-                    'language': TRANSCRIPTION_LANGUAGE,
-                    'silero_sensitivity': SILERO_SENSITIVITY,
-                    'webrtc_sensitivity': WEBRTC_SENSITIVITY,
-                    'post_speech_silence_duration': SILENCE_THRESHS[1],
-                    'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
-                    'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
-                    'min_gap_between_recordings': 0,
-                    'enable_realtime_transcription': True,
-                    'realtime_processing_pause': 0,
-                    'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
-                    'on_realtime_transcription_update': self.live_text_detected,
-                    'beam_size': FINAL_BEAM_SIZE,
-                    'beam_size_realtime': REALTIME_BEAM_SIZE,
-                    'buffer_size': BUFFER_SIZE,
-                    'sample_rate': SAMPLE_RATE,
-                }
-
-                self.recorder = AudioToTextRecorder(**recorder_config)
+            # Setup recorder configuration for WebRTC input
+            recorder_config = {
+                'spinner': False,
+                'use_microphone': False,  # We'll feed audio manually
+                'model': FINAL_TRANSCRIPTION_MODEL,
+                'language': TRANSCRIPTION_LANGUAGE,
+                'silero_sensitivity': SILERO_SENSITIVITY,
+                'webrtc_sensitivity': WEBRTC_SENSITIVITY,
+                'post_speech_silence_duration': SILENCE_THRESHS[1],
+                'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
+                'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
+                'min_gap_between_recordings': 0,
+                'enable_realtime_transcription': True,
+                'realtime_processing_pause': 0,
+                'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
+                'on_realtime_transcription_update': self.live_text_detected,
+                'beam_size': FINAL_BEAM_SIZE,
+                'beam_size_realtime': REALTIME_BEAM_SIZE,
+                'buffer_size': BUFFER_SIZE,
+                'sample_rate': SAMPLE_RATE,
+            }
+
+            self.recorder = AudioToTextRecorder(**recorder_config)
 
             # Start sentence processing thread
             self.is_running = True
             self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
             self.sentence_thread.start()
 
-            if recorder_available:
-                # Start transcription thread
-                self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
-                self.transcription_thread.start()
-                return "Recording started successfully! Please speak into your microphone."
-            else:
-                return "Simulation mode active. Speaker diarization ready for audio input."
+            # Start transcription thread
+            self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
+            self.transcription_thread.start()
+
+            return "Recording started successfully! FastRTC audio input ready."
 
         except Exception as e:
             return f"Error starting recording: {e}"
@@ -484,7 +430,7 @@
     def run_transcription(self):
        """Run the transcription loop"""
         try:
-            while self.is_running and self.recorder:
+            while self.is_running:
                 self.recorder.text(self.process_final_text)
         except Exception as e:
             print(f"Transcription error: {e}")
@@ -493,10 +439,7 @@
        """Stop the recording process"""
        self.is_running = False
        if self.recorder:
-            try:
-                self.recorder.stop()
-            except:
-                pass
+            self.recorder.stop()
        return "Recording stopped!"
 
    def clear_conversation(self):
@@ -507,7 +450,6 @@
        self.displayed_text = ""
        self.last_realtime_text = ""
        self.current_conversation = "Conversation cleared!"
-        self.audio_buffer = []
 
        if self.speaker_detector:
            self.speaker_detector = SpeakerChangeDetector(
@@ -589,42 +531,43 @@
            return f"Error getting status: {e}"
 
    def process_audio(self, audio_data):
-        """Process audio data from external sources"""
-        if not self.is_running:
+        """Process audio data from FastRTC"""
+        if not self.is_running or not self.recorder:
            return
 
        try:
-            # Handle different audio data formats
-            if isinstance(audio_data, tuple) and len(audio_data) == 2:
-                sample_rate, audio_array = audio_data
-            else:
-                audio_array = audio_data
-                sample_rate = SAMPLE_RATE
+            # Extract audio data from FastRTC format (sample_rate, numpy_array)
+            sample_rate, audio_array = audio_data
 
            # Convert to int16 format
            if audio_array.dtype != np.int16:
-                if audio_array.dtype == np.float32 or audio_array.dtype == np.float64:
-                    audio_array = (audio_array * 32767).astype(np.int16)
-                else:
-                    audio_array = audio_array.astype(np.int16)
-
-                # Store in buffer for later processing
-                self.audio_buffer.append(audio_array)
+                audio_array = (audio_array * 32767).astype(np.int16)
 
-            # Process if we have enough audio data
-            if len(self.audio_buffer) > 10:  # Process every ~0.5 seconds of audio
-                combined_audio = np.concatenate(self.audio_buffer)
-
-                # Simulate transcription for demonstration
-                if len(combined_audio) > SAMPLE_RATE:  # At least 1 second of audio
-                    # In a real implementation, this would be transcribed text
-                    demo_text = f"Sample speech segment {len(self.full_sentences) + 1}"
-                    self.process_final_text(demo_text)
-
-                self.audio_buffer = []  # Clear buffer
-
+            # Convert to bytes and feed to recorder
+            audio_bytes = audio_array.tobytes()
+            self.recorder.feed_audio(audio_bytes)
        except Exception as e:
-            print(f"Error processing audio: {e}")
+            print(f"Error processing FastRTC audio: {e}")
+
+
+# FastRTC Audio Handler
+class DiarizationHandler(AsyncStreamHandler):
+    def __init__(self, diarization_system):
+        super().__init__()
+        self.diarization_system = diarization_system
+
+    def copy(self):
+        # Return a fresh handler for each new stream connection
+        return DiarizationHandler(self.diarization_system)
+
+    async def emit(self):
+        """Not used in this implementation"""
+        return None
+
+    async def receive(self, data):
+        """Receive audio data from FastRTC and process it"""
+        if self.diarization_system.is_running:
+            self.diarization_system.process_audio(data)
 
 
 # Global instance
@@ -670,6 +613,61 @@ def get_status():
    return diarization_system.get_status_info()
 
 
+# Get Cloudflare TURN credentials for FastRTC
+async def get_cloudflare_credentials():
+    # Check if HF_TOKEN is set in environment
+    hf_token = os.environ.get("HF_TOKEN")
+
+    # If not set, use a default Hugging Face token if available
+    if not hf_token:
+        # Log a warning that user should set their own token
+        print("Warning: HF_TOKEN environment variable not set. Please set your own Hugging Face token.")
+        # Try to use the Hugging Face token from the environment
+        from huggingface_hub import HfApi
+        try:
+            api = HfApi()
+            hf_token = api.token
+            if not hf_token:
+                print("Error: No Hugging Face token available. TURN relay may not work properly.")
+        except:
+            print("Error: Failed to get Hugging Face token. TURN relay may not work properly.")
+
+    # Get Cloudflare TURN credentials using the Hugging Face token
+    if hf_token:
+        try:
+            return await get_cloudflare_turn_credentials_async(hf_token=hf_token)
+        except Exception as e:
+            print(f"Error getting Cloudflare TURN credentials: {e}")
+
+    # Fallback to a default configuration that may not work
+    return {
+        "iceServers": [
+            {
+                "urls": "stun:stun.l.google.com:19302"
+            }
+        ]
+    }
+
+
+# Setup FastRTC stream handler with TURN server configuration
+def setup_fastrtc_handler():
+    """Set up FastRTC audio stream handler with TURN server configuration"""
+    handler = DiarizationHandler(diarization_system)
+
+    # Get server-side credentials (longer TTL)
+    server_credentials = get_cloudflare_turn_credentials(ttl=360000)
+
+    stream = Stream(
+        handler=handler,
+        modality="audio",
+        mode="receive",
+        rtc_configuration=get_cloudflare_credentials,  # Async function for client-side credentials
+        server_rtc_configuration=server_credentials  # Server-side credentials with longer TTL
+    )
+
+    return stream
+
+
 # Create Gradio interface
 def create_interface():
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Monochrome()) as interface:
@@ -678,13 +676,31 @@ def create_interface():
 
        with gr.Row():
            with gr.Column(scale=2):
-                # Audio input component
-                audio_input = gr.Audio(
-                    label="🎙️ Audio Input",
-                    sources=["microphone"],
-                    type="numpy",
-                    streaming=True
-                )
+                # FastRTC Audio Component
+                fastrtc_html = gr.HTML("""
+                <div class="fastrtc-container" style="margin-bottom: 20px;">
+                    <h3>🎙️ FastRTC Audio Input</h3>
+                    <p>Click the button below to start the audio stream:</p>
+                    <button id="start-fastrtc" style="background: #3498db; color: white; padding: 10px 20px; border: none; border-radius: 5px; cursor: pointer;">
+                        Start FastRTC Audio
+                    </button>
+                    <div id="fastrtc-status" style="margin-top: 10px; font-style: italic;">Not connected</div>
+                    <script>
+                        document.getElementById('start-fastrtc').addEventListener('click', function() {
+                            document.getElementById('fastrtc-status').textContent = 'Connecting...';
+                            // FastRTC will initialize the connection
+                            fetch('/start-rtc', { method: 'POST' })
+                                .then(response => response.text())
+                                .then(data => {
+                                    document.getElementById('fastrtc-status').textContent = 'Connected! Speak now...';
+                                })
+                                .catch(error => {
+                                    document.getElementById('fastrtc-status').textContent = 'Connection error: ' + error;
+                                });
+                        });
+                    </script>
+                </div>
+                """)
 
                # Main conversation display
                conversation_output = gr.HTML(
@@ -735,9 +751,11 @@
                gr.Markdown("""
                1. Click **Initialize System** to load models
                2. Click **Start Recording** to begin processing
-                3. Use the microphone input above to record audio
-                4. Watch real-time transcription with speaker labels
-                5. Adjust settings as needed
+                3. Click **Start FastRTC Audio** to connect your microphone
+                4. Allow microphone access when prompted
+                5. Speak into your microphone
+                6. Watch real-time transcription with speaker labels
+                7. Adjust settings as needed
                """)
 
                # Speaker color legend
@@ -747,13 +765,20 @@
                    color_info.append(f'<span style="color:{color};">■</span> Speaker {i+1} ({name})')
 
                gr.HTML("<br>".join(color_info[:DEFAULT_MAX_SPEAKERS]))
-
-        # Audio processing function
-        def process_audio_stream(audio_data):
-            if audio_data is not None and diarization_system.is_running:
-                diarization_system.process_audio(audio_data)
-                return diarization_system.get_formatted_conversation()
-            return None
+
+                # FastRTC Integration Notice
+                gr.Markdown("""
+                ## ℹ️ About FastRTC
+                This app uses FastRTC for low-latency audio streaming.
+                For optimal performance, use a modern browser and allow microphone access when prompted.
+                """)
+
+                # Hugging Face Token Information
+                gr.Markdown("""
+                ## 🔑 Hugging Face Token
+                This app uses Cloudflare TURN server via Hugging Face integration.
+                If audio connection fails, set your HF_TOKEN environment variable in the Space settings.
+                """)
 
        # Auto-refresh conversation and status
        def refresh_display():
@@ -822,13 +847,6 @@
            outputs=[status_output]
        )
 
-        # Connect audio input to processing function
-        audio_input.stream(
-            process_audio_stream,
-            inputs=[audio_input],
-            outputs=[conversation_output]
-        )
-
        # Auto-refresh every 2 seconds when recording
        refresh_timer = gr.Timer(2.0)
        refresh_timer.tick(
@@ -848,6 +866,16 @@ gradio_interface = create_interface()
 # 3) Mount Gradio onto FastAPI at root
 app = gr.mount_gradio_app(app, gradio_interface, path="/")
 
-# 4) Local dev via uvicorn; HF Spaces will auto-detect 'app' and ignore this
+# 4) Initialize and mount FastRTC stream on the same app
+rtc_stream = setup_fastrtc_handler()
+rtc_stream.mount(app)
+
+# 5) Expose an endpoint to trigger the client-side RTC handshake
+@app.post("/start-rtc")
+async def start_rtc():
+    await rtc_stream.start_client()
+    return {"status": "success"}
+
+# 6) Local dev via uvicorn; HF Spaces will auto-detect 'app' and ignore this
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)