Saiyaswanth007 committed
Commit 640dd0e · Parent(s): af81629

Code changes

Files changed (1)
  1. app.py +509 -290
app.py CHANGED
@@ -1,96 +1,149 @@
  import gradio as gr
  import numpy as np
- import torch
- import torchaudio
- import threading
  import queue
  import time
  import os
  import urllib.request
  from scipy.spatial.distance import cosine
- from collections import deque
- import tempfile
- import librosa

- # Configuration parameters
- FINAL_TRANSCRIPTION_MODEL = "openai/whisper-small"
  TRANSCRIPTION_LANGUAGE = "en"
  DEFAULT_CHANGE_THRESHOLD = 0.7
  EMBEDDING_HISTORY_SIZE = 5
  MIN_SEGMENT_DURATION = 1.0
  DEFAULT_MAX_SPEAKERS = 4
- ABSOLUTE_MAX_SPEAKERS = 6
  SAMPLE_RATE = 16000

- # Speaker colors for up to 6 speakers
  SPEAKER_COLORS = [
-     "#FFD700",  # Gold
-     "#FF6B6B",  # Red
-     "#4ECDC4",  # Teal
-     "#45B7D1",  # Blue
-     "#96CEB4",  # Green
-     "#FFEAA7",  # Yellow
  ]

  SPEAKER_COLOR_NAMES = [
-     "Gold", "Red", "Teal", "Blue", "Green", "Yellow"
  ]


  class SpeechBrainEncoder:
-     """Simplified encoder for speaker embeddings using torch audio features"""
      def __init__(self, device="cpu"):
          self.device = device
-         self.embedding_dim = 128
-         self.model_loaded = True

      def load_model(self):
-         """Model loading simulation"""
-         return True

      def embed_utterance(self, audio, sr=16000):
-         """Extract simple spectral features as speaker embedding"""
          try:
              if isinstance(audio, np.ndarray):
-                 waveform = torch.tensor(audio, dtype=torch.float32)
              else:
-                 waveform = audio
-
-             if len(waveform.shape) == 1:
-                 waveform = waveform.unsqueeze(0)

-             # Resample if needed
              if sr != 16000:
                  waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)

-             # Extract MFCC features as a simple embedding
-             mfcc_transform = torchaudio.transforms.MFCC(
-                 sample_rate=16000,
-                 n_mfcc=13,
-                 melkwargs={'n_mels': 40}
-             )
-
-             mfcc = mfcc_transform(waveform)
-             # Take mean across time dimension and flatten
-             embedding = mfcc.mean(dim=2).flatten()
-
-             # Pad or truncate to fixed size
-             if len(embedding) > self.embedding_dim:
-                 embedding = embedding[:self.embedding_dim]
-             elif len(embedding) < self.embedding_dim:
-                 padding = torch.zeros(self.embedding_dim - len(embedding))
-                 embedding = torch.cat([embedding, padding])
-
-             return embedding.numpy()
          except Exception as e:
-             print(f"Error extracting embedding: {e}")
-             return np.random.randn(self.embedding_dim)


  class SpeakerChangeDetector:
-     """Speaker change detector for real-time diarization"""
-     def __init__(self, embedding_dim=128, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
          self.embedding_dim = embedding_dim
          self.change_threshold = change_threshold
          self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
@@ -110,6 +163,7 @@ class SpeakerChangeDetector:
          for speaker_id in list(self.active_speakers):
              if speaker_id >= new_max:
                  self.active_speakers.discard(speaker_id)
          if self.current_speaker >= new_max:
              self.current_speaker = 0
@@ -161,6 +215,7 @@ class SpeakerChangeDetector:

          if speaker_mean is not None:
              speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
              if speaker_similarity > best_similarity:
                  best_similarity = speaker_similarity
                  best_speaker = speaker_id
@@ -201,317 +256,481 @@ class SpeakerChangeDetector:
          if 0 <= speaker_id < len(SPEAKER_COLORS):
              return SPEAKER_COLORS[speaker_id]
          return "#FFFFFF"


- class RealTimeASRDiarization:
-     """Main class for real-time ASR with speaker diarization"""
      def __init__(self):
-         self.encoder = SpeechBrainEncoder()
-         self.encoder.load_model()
-         self.speaker_detector = SpeakerChangeDetector()
-         self.transcription_queue = queue.Queue()
-         self.conversation_history = []
-         self.is_processing = False
-
-         # Load Whisper model
          try:
-             import whisper
-             self.whisper_model = whisper.load_model("base")
-         except ImportError:
-             print("Whisper not available, using mock transcription")
-             self.whisper_model = None

-     def transcribe_audio(self, audio_data, sr=16000):
-         """Transcribe audio using Whisper"""
          try:
-             if self.whisper_model is None:
-                 return "Mock transcription: Hello, this is a test."

-             # Ensure audio is the right format
-             if isinstance(audio_data, tuple):
-                 sr, audio_data = audio_data

-             if len(audio_data.shape) > 1:
-                 audio_data = audio_data.mean(axis=1)

-             # Normalize audio
-             audio_data = audio_data.astype(np.float32)
-             if np.abs(audio_data).max() > 1.0:
-                 audio_data = audio_data / np.abs(audio_data).max()

-             # Resample to 16kHz if needed
-             if sr != 16000:
-                 audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)

-             # Transcribe
-             result = self.whisper_model.transcribe(audio_data, language="en")
-             return result["text"].strip()

          except Exception as e:
              print(f"Transcription error: {e}")
-             return ""

-     def extract_speaker_embedding(self, audio_data, sr=16000):
-         """Extract speaker embedding from audio"""
-         return self.encoder.embed_utterance(audio_data, sr)

-     def process_audio_segment(self, audio_data, sr=16000):
-         """Process an audio segment for transcription and speaker identification"""
-         if len(audio_data) < sr * 0.5:  # Skip very short segments
-             return None, None, None
-
-         # Transcribe the audio
-         transcription = self.transcribe_audio(audio_data, sr)
-
-         if not transcription:
-             return None, None, None
-
-         # Extract speaker embedding
-         embedding = self.extract_speaker_embedding(audio_data, sr)
-
-         # Detect speaker
-         speaker_id, similarity = self.speaker_detector.add_embedding(embedding)

-         return transcription, speaker_id, similarity

-     def update_conversation(self, transcription, speaker_id):
-         """Update conversation history with new transcription"""
-         speaker_name = f"Speaker {speaker_id + 1}"
-         color = self.speaker_detector.get_color_for_speaker(speaker_id)
-
-         entry = {
-             "speaker": speaker_name,
-             "text": transcription,
-             "color": color,
-             "timestamp": time.time()
-         }

-         self.conversation_history.append(entry)
-         return entry
-
-     def format_conversation_html(self):
-         """Format conversation history as HTML"""
-         if not self.conversation_history:
-             return "<p><i>No conversation yet. Start speaking to see real-time transcription with speaker diarization.</i></p>"
-
-         html_parts = []
-         for entry in self.conversation_history:
-             html_parts.append(
-                 f'<p><span style="color: {entry["color"]}; font-weight: bold;">'
-                 f'{entry["speaker"]}:</span> {entry["text"]}</p>'
-             )

-         return "".join(html_parts)

      def get_status_info(self):
          """Get current status information"""
-         status = {
-             "active_speakers": len(self.speaker_detector.active_speakers),
-             "max_speakers": self.speaker_detector.max_speakers,
-             "current_speaker": self.speaker_detector.current_speaker + 1,
-             "total_segments": len(self.conversation_history),
-             "threshold": self.speaker_detector.change_threshold
-         }
-         return status
-
-     def clear_conversation(self):
-         """Clear conversation history and reset speaker detector"""
-         self.conversation_history = []
-         self.speaker_detector = SpeakerChangeDetector(
-             change_threshold=self.speaker_detector.change_threshold,
-             max_speakers=self.speaker_detector.max_speakers
-         )
-
-     def set_parameters(self, threshold, max_speakers):
-         """Update parameters"""
-         self.speaker_detector.set_change_threshold(threshold)
-         self.speaker_detector.set_max_speakers(max_speakers)


  # Global instance
- asr_system = RealTimeASRDiarization()


- def process_audio_realtime(audio_data, threshold, max_speakers):
-     """Process audio in real-time"""
-     global asr_system
-
-     if audio_data is None:
-         return asr_system.format_conversation_html(), get_status_display()
-
-     # Update parameters
-     asr_system.set_parameters(threshold, max_speakers)
-
-     try:
-         # Process the audio segment
-         sr, audio_array = audio_data
-
-         # Convert to float32 and normalize
-         if audio_array.dtype != np.float32:
-             audio_array = audio_array.astype(np.float32)
-             if audio_array.dtype == np.int16:
-                 audio_array = audio_array / 32768.0
-             elif audio_array.dtype == np.int32:
-                 audio_array = audio_array / 2147483648.0
-
-         # Process the audio segment
-         transcription, speaker_id, similarity = asr_system.process_audio_segment(audio_array, sr)
-
-         if transcription and speaker_id is not None:
-             # Update conversation
-             asr_system.update_conversation(transcription, speaker_id)
-
-     except Exception as e:
-         print(f"Error processing audio: {e}")
-
-     return asr_system.format_conversation_html(), get_status_display()


- def get_status_display():
-     """Get formatted status display"""
-     status = asr_system.get_status_info()
-
-     status_html = f"""
-     <div style="font-family: monospace; font-size: 12px;">
-         <strong>Status:</strong><br>
-         Current Speaker: {status['current_speaker']}<br>
-         Active Speakers: {status['active_speakers']} / {status['max_speakers']}<br>
-         Total Segments: {status['total_segments']}<br>
-         Threshold: {status['threshold']:.2f}<br>
-     </div>
-     """
-
-     return status_html


  def clear_conversation():
      """Clear the conversation"""
-     global asr_system
-     asr_system.clear_conversation()
-     return asr_system.format_conversation_html(), get_status_display()


  def create_interface():
-     """Create Gradio interface"""
-
-     with gr.Blocks(
-         title="Real-time ASR with Speaker Diarization",
-         theme=gr.themes.Soft(),
-         css="""
-         .conversation-box {
-             height: 400px;
-             overflow-y: auto;
-             border: 1px solid #ddd;
-             padding: 10px;
-             background-color: #f9f9f9;
-         }
-         .status-box {
-             border: 1px solid #ccc;
-             padding: 10px;
-             background-color: #f0f0f0;
-         }
-         """
-     ) as demo:
-
-         gr.Markdown(
-             """
-             # 🎤 Real-time ASR with Live Speaker Diarization
-
-             This application provides real-time speech recognition with speaker diarization.
-             It can distinguish between different speakers and display their conversations in different colors.
-
-             **Instructions:**
-             1. Adjust the speaker change threshold and maximum speakers
-             2. Click the microphone button to start recording
-             3. Speak naturally - the system will detect speaker changes and transcribe speech
-             4. Each speaker will be assigned a different color
-             """
-         )

          with gr.Row():
-             with gr.Column(scale=3):
                  # Main conversation display
-                 conversation_display = gr.HTML(
-                     value="<p><i>Click the microphone to start recording...</i></p>",
-                     elem_classes=["conversation-box"]
                  )

-                 # Audio input
-                 audio_input = gr.Audio(
-                     source="microphone",
-                     type="numpy",
-                     streaming=True,
-                     label="🎤 Microphone Input"
-                 )

              with gr.Column(scale=1):
-                 # Controls
-                 gr.Markdown("### Controls")

                  threshold_slider = gr.Slider(
                      minimum=0.1,
-                     maximum=0.9,
-                     value=DEFAULT_CHANGE_THRESHOLD,
                      step=0.05,
-                     label="Speaker Change Threshold",
-                     info="Higher values = less sensitive to speaker changes"
                  )

                  max_speakers_slider = gr.Slider(
                      minimum=2,
                      maximum=ABSOLUTE_MAX_SPEAKERS,
-                     value=DEFAULT_MAX_SPEAKERS,
                      step=1,
-                     label="Maximum Speakers",
-                     info="Maximum number of different speakers to detect"
                  )

-                 clear_btn = gr.Button("🗑️ Clear Conversation", variant="secondary")
-
-                 # Status display
-                 gr.Markdown("### Status")
-                 status_display = gr.HTML(
-                     value=get_status_display(),
-                     elem_classes=["status-box"]
-                 )

                  # Speaker color legend
-                 gr.Markdown("### Speaker Colors")
-                 legend_html = ""
-                 for i in range(ABSOLUTE_MAX_SPEAKERS):
-                     color = SPEAKER_COLORS[i]
-                     name = SPEAKER_COLOR_NAMES[i]
-                     legend_html += f'<p><span style="color: {color}; font-weight: bold;">● Speaker {i+1} ({name})</span></p>'

-                 gr.HTML(legend_html)

          # Event handlers
-         audio_input.change(
-             fn=process_audio_realtime,
-             inputs=[audio_input, threshold_slider, max_speakers_slider],
-             outputs=[conversation_display, status_display],
-             show_progress=False
          )

          clear_btn.click(
-             fn=clear_conversation,
-             outputs=[conversation_display, status_display]
          )

-         # Update status periodically
-         demo.load(
-             fn=lambda: (asr_system.format_conversation_html(), get_status_display()),
-             outputs=[conversation_display, status_display],
-             every=2
          )

-     return demo


  if __name__ == "__main__":
-     # Create and launch the interface
-     demo = create_interface()
-     demo.launch(
          server_name="0.0.0.0",
          server_port=7860,
          share=True
 
@@ -1,96 +1,149 @@
  import gradio as gr
  import numpy as np
+ import soundcard as sc
  import queue
+ import torch
  import time
+ import threading
  import os
  import urllib.request
+ import torchaudio
  from scipy.spatial.distance import cosine
+ from RealtimeSTT import AudioToTextRecorder
+ import json

+ # Simplified configuration parameters
+ SILENCE_THRESHS = [0, 0.4]
+ FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
+ FINAL_BEAM_SIZE = 5
+ REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
+ REALTIME_BEAM_SIZE = 5
  TRANSCRIPTION_LANGUAGE = "en"
+ SILERO_SENSITIVITY = 0.4
+ WEBRTC_SENSITIVITY = 3
+ MIN_LENGTH_OF_RECORDING = 0.7
+ PRE_RECORDING_BUFFER_DURATION = 0.35
+
+ # Speaker change detection parameters
  DEFAULT_CHANGE_THRESHOLD = 0.7
  EMBEDDING_HISTORY_SIZE = 5
  MIN_SEGMENT_DURATION = 1.0
  DEFAULT_MAX_SPEAKERS = 4
+ ABSOLUTE_MAX_SPEAKERS = 10
+
+ # Global variables
+ FAST_SENTENCE_END = True
+ USE_MICROPHONE = False
  SAMPLE_RATE = 16000
+ BUFFER_SIZE = 512
+ CHANNELS = 1

+ # Speaker colors
  SPEAKER_COLORS = [
+     "#FFFF00",  # Yellow
+     "#FF0000",  # Red
+     "#00FF00",  # Green
+     "#00FFFF",  # Cyan
+     "#FF00FF",  # Magenta
+     "#0000FF",  # Blue
+     "#FF8000",  # Orange
+     "#00FF80",  # Spring Green
+     "#8000FF",  # Purple
+     "#FFFFFF",  # White
  ]

  SPEAKER_COLOR_NAMES = [
+     "Yellow", "Red", "Green", "Cyan", "Magenta",
+     "Blue", "Orange", "Spring Green", "Purple", "White"
  ]


  class SpeechBrainEncoder:
+     """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
      def __init__(self, device="cpu"):
          self.device = device
+         self.model = None
+         self.embedding_dim = 192
+         self.model_loaded = False
+         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
+         os.makedirs(self.cache_dir, exist_ok=True)
+
+     def _download_model(self):
+         """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
+         model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
+         model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
+
+         if not os.path.exists(model_path):
+             print(f"Downloading ECAPA-TDNN model to {model_path}...")
+             urllib.request.urlretrieve(model_url, model_path)

+         return model_path
+
      def load_model(self):
+         """Load the ECAPA-TDNN model"""
+         try:
+             from speechbrain.pretrained import EncoderClassifier
+
+             model_path = self._download_model()
+
+             self.model = EncoderClassifier.from_hparams(
+                 source="speechbrain/spkrec-ecapa-voxceleb",
+                 savedir=self.cache_dir,
+                 run_opts={"device": self.device}
+             )
+
+             self.model_loaded = True
+             return True
+         except Exception as e:
+             print(f"Error loading ECAPA-TDNN model: {e}")
+             return False

      def embed_utterance(self, audio, sr=16000):
+         """Extract speaker embedding from audio"""
+         if not self.model_loaded:
+             raise ValueError("Model not loaded. Call load_model() first.")
+
          try:
              if isinstance(audio, np.ndarray):
+                 waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
              else:
+                 waveform = audio.unsqueeze(0)

              if sr != 16000:
                  waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)

+             with torch.no_grad():
+                 embedding = self.model.encode_batch(waveform)
+
+             return embedding.squeeze().cpu().numpy()
+         except Exception as e:
+             print(f"Error extracting embedding: {e}")
+             return np.zeros(self.embedding_dim)
+
+
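The encoder added above wraps SpeechBrain's pretrained ECAPA-TDNN speaker-verification model. For readers who want to exercise the same embedding call outside the app, a minimal sketch, assuming speechbrain and torch are installed; the silent test waveform is illustrative, not part of the commit:

    import torch
    from speechbrain.pretrained import EncoderClassifier

    # Same pretrained model id the diff uses; downloads on first run
    classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb")

    # One second of 16 kHz audio stands in for real speech here
    waveform = torch.zeros(1, 16000)

    with torch.no_grad():
        emb = classifier.encode_batch(waveform)  # shape [1, 1, 192]

    print(emb.squeeze().shape)  # torch.Size([192]), matching embedding_dim = 192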
+ class AudioProcessor:
+     """Processes audio data to extract speaker embeddings"""
+     def __init__(self, encoder):
+         self.encoder = encoder
+
+     def extract_embedding(self, audio_int16):
+         try:
+             float_audio = audio_int16.astype(np.float32) / 32768.0

+             if np.abs(float_audio).max() > 1.0:
+                 float_audio = float_audio / np.abs(float_audio).max()

+             embedding = self.encoder.embed_utterance(float_audio)

+             return embedding
          except Exception as e:
+             print(f"Embedding extraction error: {e}")
+             return np.zeros(self.encoder.embedding_dim)


  class SpeakerChangeDetector:
+     """Speaker change detector that supports a configurable number of speakers"""
+     def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
          self.embedding_dim = embedding_dim
          self.change_threshold = change_threshold
          self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
@@ -110,6 +163,7 @@ class SpeakerChangeDetector:
          for speaker_id in list(self.active_speakers):
              if speaker_id >= new_max:
                  self.active_speakers.discard(speaker_id)
+
          if self.current_speaker >= new_max:
              self.current_speaker = 0
@@ -161,6 +215,7 @@ class SpeakerChangeDetector:

          if speaker_mean is not None:
              speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
+
              if speaker_similarity > best_similarity:
                  best_similarity = speaker_similarity
                  best_speaker = speaker_id
@@ -201,317 +256,481 @@ class SpeakerChangeDetector:
          if 0 <= speaker_id < len(SPEAKER_COLORS):
              return SPEAKER_COLORS[speaker_id]
          return "#FFFFFF"
+
+     def get_status_info(self):
+         """Return status information about the speaker change detector"""
+         speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
+
+         return {
+             "current_speaker": self.current_speaker,
+             "speaker_counts": speaker_counts,
+             "active_speakers": len(self.active_speakers),
+             "max_speakers": self.max_speakers,
+             "last_similarity": self.last_similarity,
+             "threshold": self.change_threshold
+         }
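The body of add_embedding falls outside these hunks, but the surrounding context shows the rule: cosine similarity against each speaker's mean embedding, compared to change_threshold. A minimal self-contained sketch of that rule, with illustrative names rather than the app's actual API:

    import numpy as np
    from scipy.spatial.distance import cosine

    def assign_speaker(embedding, speaker_means, threshold=0.7, max_speakers=10):
        """Pick the most similar known speaker; open a new slot when the
        best similarity falls below the threshold and room remains."""
        if not speaker_means:
            return 0, 1.0  # first segment defines speaker 0
        sims = [1.0 - cosine(embedding, m) for m in speaker_means]
        best = int(np.argmax(sims))
        if sims[best] >= threshold or len(speaker_means) >= max_speakers:
            return best, sims[best]             # stay with the closest speaker
        return len(speaker_means), sims[best]   # likely a new speaker

    means = [np.random.randn(192) for _ in range(2)]  # toy speaker means
    speaker_id, sim = assign_speaker(np.random.randn(192), means)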


+ class RealtimeSpeakerDiarization:
      def __init__(self):
+         self.encoder = None
+         self.audio_processor = None
+         self.speaker_detector = None
+         self.recorder = None
+         self.recording_thread = None
+         self.sentence_queue = queue.Queue()
+         self.full_sentences = []
+         self.sentence_speakers = []
+         self.pending_sentences = []
+         self.displayed_text = ""
+         self.last_realtime_text = ""
+         self.is_running = False
+         self.change_threshold = DEFAULT_CHANGE_THRESHOLD
+         self.max_speakers = DEFAULT_MAX_SPEAKERS
+
+     def initialize_models(self):
+         """Initialize the speaker encoder model"""
          try:
+             device_str = "cuda" if torch.cuda.is_available() else "cpu"
+             print(f"Using device: {device_str}")
+
+             self.encoder = SpeechBrainEncoder(device=device_str)
+             success = self.encoder.load_model()
+
+             if success:
+                 self.audio_processor = AudioProcessor(self.encoder)
+                 self.speaker_detector = SpeakerChangeDetector(
+                     embedding_dim=self.encoder.embedding_dim,
+                     change_threshold=self.change_threshold,
+                     max_speakers=self.max_speakers
+                 )
+                 print("ECAPA-TDNN model loaded successfully!")
+                 return True
+             else:
+                 print("Failed to load ECAPA-TDNN model")
+                 return False
+         except Exception as e:
+             print(f"Model initialization error: {e}")
+             return False
+
+     def live_text_detected(self, text):
+         """Callback for real-time transcription updates"""
+         text = text.strip()
+         if text:
+             sentence_delimiters = '.?!。'
+             prob_sentence_end = (
+                 len(self.last_realtime_text) > 0
+                 and text[-1] in sentence_delimiters
+                 and self.last_realtime_text[-1] in sentence_delimiters
+             )
+
+             self.last_realtime_text = text
+
+             if prob_sentence_end and FAST_SENTENCE_END:
+                 self.recorder.stop()
+             elif prob_sentence_end:
+                 self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
+             else:
+                 self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
+
+     def process_final_text(self, text):
+         """Process final transcribed text with speaker embedding"""
+         text = text.strip()
+         if text:
+             try:
+                 bytes_data = self.recorder.last_transcription_bytes
+                 self.sentence_queue.put((text, bytes_data))
+                 self.pending_sentences.append(text)
+             except Exception as e:
+                 print(f"Error processing final text: {e}")

+     def process_sentence_queue(self):
+         """Process sentences in the queue for speaker detection"""
+         while self.is_running:
+             try:
+                 text, bytes_data = self.sentence_queue.get(timeout=1)
+
+                 # Convert audio data to int16
+                 audio_int16 = np.int16(bytes_data * 32767)
+
+                 # Extract speaker embedding
+                 speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
+
+                 # Store sentence and embedding
+                 self.full_sentences.append((text, speaker_embedding))
+
+                 # Fill in missing speaker assignments
+                 while len(self.sentence_speakers) < len(self.full_sentences) - 1:
+                     self.sentence_speakers.append(0)
+
+                 # Detect speaker changes
+                 speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
+                 self.sentence_speakers.append(speaker_id)
+
+                 # Remove from pending
+                 if text in self.pending_sentences:
+                     self.pending_sentences.remove(text)
+
+             except queue.Empty:
+                 continue
+             except Exception as e:
+                 print(f"Error processing sentence: {e}")
+
+     def start_recording(self):
+         """Start the recording and transcription process"""
+         if self.encoder is None:
+             return "Please initialize models first!"
+
          try:
+             # Setup recorder configuration
+             recorder_config = {
+                 'spinner': False,
+                 'use_microphone': USE_MICROPHONE,
+                 'model': FINAL_TRANSCRIPTION_MODEL,
+                 'language': TRANSCRIPTION_LANGUAGE,
+                 'silero_sensitivity': SILERO_SENSITIVITY,
+                 'webrtc_sensitivity': WEBRTC_SENSITIVITY,
+                 'post_speech_silence_duration': SILENCE_THRESHS[1],
+                 'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
+                 'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
+                 'min_gap_between_recordings': 0,
+                 'enable_realtime_transcription': True,
+                 'realtime_processing_pause': 0,
+                 'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
+                 'on_realtime_transcription_update': self.live_text_detected,
+                 'beam_size': FINAL_BEAM_SIZE,
+                 'beam_size_realtime': REALTIME_BEAM_SIZE,
+                 'buffer_size': BUFFER_SIZE,
+                 'sample_rate': SAMPLE_RATE,
+             }
+
+             self.recorder = AudioToTextRecorder(**recorder_config)

+             # Start sentence processing thread
+             self.is_running = True
+             self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
+             self.sentence_thread.start()

+             # Start audio capture thread
+             self.audio_thread = threading.Thread(target=self.capture_audio, daemon=True)
+             self.audio_thread.start()

+             # Start transcription thread
+             self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
+             self.transcription_thread.start()

+             return "Recording started successfully!"

+         except Exception as e:
+             return f"Error starting recording: {e}"
+
+     def capture_audio(self):
+         """Capture audio from default speaker/microphone"""
+         try:
+             device_id = str(sc.default_speaker().name if not USE_MICROPHONE else sc.default_microphone().name)
+             include_loopback = not USE_MICROPHONE

+             with sc.get_microphone(id=device_id, include_loopback=include_loopback).recorder(
+                 samplerate=SAMPLE_RATE, blocksize=BUFFER_SIZE
+             ) as mic:
+                 while self.is_running:
+                     audio_data = mic.record(numframes=BUFFER_SIZE)
+
+                     if audio_data.shape[1] > 1 and CHANNELS == 1:
+                         audio_data = audio_data[:, 0]
+
+                     audio_int16 = (audio_data.flatten() * 32767).astype(np.int16)
+                     audio_bytes = audio_int16.tobytes()
+                     self.recorder.feed_audio(audio_bytes)
+
+         except Exception as e:
+             print(f"Audio capture error: {e}")
+
+     def run_transcription(self):
+         """Run the transcription loop"""
+         try:
+             while self.is_running:
+                 self.recorder.text(self.process_final_text)
          except Exception as e:
              print(f"Transcription error: {e}")

+     def stop_recording(self):
+         """Stop the recording process"""
+         self.is_running = False
+         if self.recorder:
+             self.recorder.stop()
+         return "Recording stopped!"

+     def clear_conversation(self):
+         """Clear all conversation data"""
+         self.full_sentences = []
+         self.sentence_speakers = []
+         self.pending_sentences = []
+         self.displayed_text = ""
+         self.last_realtime_text = ""
+
+         if self.speaker_detector:
+             self.speaker_detector = SpeakerChangeDetector(
+                 embedding_dim=self.encoder.embedding_dim,
+                 change_threshold=self.change_threshold,
+                 max_speakers=self.max_speakers
+             )

+         return "Conversation cleared!"

+     def update_settings(self, threshold, max_speakers):
+         """Update speaker detection settings"""
+         self.change_threshold = threshold
+         self.max_speakers = max_speakers

+         if self.speaker_detector:
+             self.speaker_detector.set_change_threshold(threshold)
+             self.speaker_detector.set_max_speakers(max_speakers)

+         return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
+
+     def get_formatted_conversation(self):
+         """Get the formatted conversation with speaker colors"""
+         try:
+             sentences_with_style = []
+
+             # Process completed sentences
+             for i, sentence in enumerate(self.full_sentences):
+                 sentence_text, _ = sentence
+                 if i >= len(self.sentence_speakers):
+                     color = "#FFFFFF"
+                 else:
+                     speaker_id = self.sentence_speakers[i]
+                     color = self.speaker_detector.get_color_for_speaker(speaker_id)
+                 speaker_name = f"Speaker {speaker_id + 1}"
+
+                 sentences_with_style.append(
+                     f'<span style="color:{color};"><b>{speaker_name}:</b> {sentence_text}</span>')
+
+             # Add pending sentences
+             for pending_sentence in self.pending_sentences:
+                 sentences_with_style.append(
+                     f'<span style="color:#60FFFF;"><b>Processing:</b> {pending_sentence}</span>')
+
+             if sentences_with_style:
+                 return "<br><br>".join(sentences_with_style)
+             else:
+                 return "Waiting for speech input..."
+
+         except Exception as e:
+             return f"Error formatting conversation: {e}"

      def get_status_info(self):
          """Get current status information"""
+         if not self.speaker_detector:
+             return "Speaker detector not initialized"
+
+         try:
+             status = self.speaker_detector.get_status_info()
+
+             status_lines = [
+                 f"**Current Speaker:** {status['current_speaker'] + 1}",
+                 f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}",
+                 f"**Last Similarity:** {status['last_similarity']:.3f}",
+                 f"**Change Threshold:** {status['threshold']:.2f}",
+                 f"**Total Sentences:** {len(self.full_sentences)}",
+                 "",
+                 "**Speaker Segment Counts:**"
+             ]
+
+             for i in range(status['max_speakers']):
+                 color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
+                 status_lines.append(f"Speaker {i+1} ({color_name}): {status['speaker_counts'][i]}")
+
+             return "\n".join(status_lines)
+
+         except Exception as e:
+             return f"Error getting status: {e}"


  # Global instance
+ diarization_system = RealtimeSpeakerDiarization()


+ def initialize_system():
+     """Initialize the diarization system"""
+     success = diarization_system.initialize_models()
+     if success:
+         return "✅ System initialized successfully! Models loaded."
+     else:
+         return "❌ Failed to initialize system. Please check the logs."


+ def start_recording():
+     """Start recording and transcription"""
+     return diarization_system.start_recording()
+
+
+ def stop_recording():
+     """Stop recording and transcription"""
+     return diarization_system.stop_recording()


  def clear_conversation():
      """Clear the conversation"""
+     return diarization_system.clear_conversation()
+
+
+ def update_settings(threshold, max_speakers):
+     """Update system settings"""
+     return diarization_system.update_settings(threshold, max_speakers)


+ def get_conversation():
+     """Get the current conversation"""
+     return diarization_system.get_formatted_conversation()

+
+ def get_status():
+     """Get system status"""
+     return diarization_system.get_status_info()
+
+
+ # Create Gradio interface
  def create_interface():
+     with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Dark()) as app:
+         gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
+         gr.Markdown("This app performs real-time speech recognition with automatic speaker identification and color-coding.")

          with gr.Row():
+             with gr.Column(scale=2):
                  # Main conversation display
+                 conversation_output = gr.HTML(
+                     value="<i>Click 'Initialize System' to start...</i>",
+                     label="Live Conversation"
                  )

+                 # Control buttons
+                 with gr.Row():
+                     init_btn = gr.Button("🔧 Initialize System", variant="secondary")
+                     start_btn = gr.Button("🎙️ Start Recording", variant="primary", interactive=False)
+                     stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", interactive=False)
+                     clear_btn = gr.Button("🗑️ Clear Conversation", interactive=False)

+                 # Status display
+                 status_output = gr.Textbox(
+                     label="System Status",
+                     value="System not initialized",
+                     lines=8,
+                     interactive=False
+                 )
+
              with gr.Column(scale=1):
+                 # Settings panel
+                 gr.Markdown("## ⚙️ Settings")

                  threshold_slider = gr.Slider(
                      minimum=0.1,
+                     maximum=0.95,
                      step=0.05,
+                     value=DEFAULT_CHANGE_THRESHOLD,
+                     label="Speaker Change Sensitivity",
+                     info="Lower values = more sensitive to speaker changes"
                  )

                  max_speakers_slider = gr.Slider(
                      minimum=2,
                      maximum=ABSOLUTE_MAX_SPEAKERS,
                      step=1,
+                     value=DEFAULT_MAX_SPEAKERS,
+                     label="Maximum Number of Speakers"
                  )

+                 update_settings_btn = gr.Button("Update Settings")

                  # Speaker color legend
+                 gr.Markdown("## 🎨 Speaker Colors")
+                 color_info = []
+                 for i, (color, name) in enumerate(zip(SPEAKER_COLORS, SPEAKER_COLOR_NAMES)):
+                     color_info.append(f'<span style="color:{color};">■</span> Speaker {i+1} ({name})')

+                 gr.HTML("<br>".join(color_info[:DEFAULT_MAX_SPEAKERS]))
+
+         # Auto-refresh conversation and status
+         def refresh_display():
+             return get_conversation(), get_status()

          # Event handlers
+         def on_initialize():
+             result = initialize_system()
+             if "successfully" in result:
+                 return (
+                     result,
+                     gr.update(interactive=True),   # start_btn
+                     gr.update(interactive=True),   # clear_btn
+                     get_conversation(),
+                     get_status()
+                 )
+             else:
+                 return (
+                     result,
+                     gr.update(interactive=False),  # start_btn
+                     gr.update(interactive=False),  # clear_btn
+                     get_conversation(),
+                     get_status()
+                 )
+
+         def on_start():
+             result = start_recording()
+             return (
+                 result,
+                 gr.update(interactive=False),  # start_btn
+                 gr.update(interactive=True),   # stop_btn
+             )
+
+         def on_stop():
+             result = stop_recording()
+             return (
+                 result,
+                 gr.update(interactive=True),   # start_btn
+                 gr.update(interactive=False),  # stop_btn
+             )
+
+         # Connect event handlers
+         init_btn.click(
+             on_initialize,
+             outputs=[status_output, start_btn, clear_btn, conversation_output, status_output]
+         )
+
+         start_btn.click(
+             on_start,
+             outputs=[status_output, start_btn, stop_btn]
+         )
+
+         stop_btn.click(
+             on_stop,
+             outputs=[status_output, start_btn, stop_btn]
          )

          clear_btn.click(
+             clear_conversation,
+             outputs=[status_output]
+         )
+
+         update_settings_btn.click(
+             update_settings,
+             inputs=[threshold_slider, max_speakers_slider],
+             outputs=[status_output]
          )

+         # Auto-refresh every 2 seconds when recording
+         refresh_timer = gr.Timer(2.0)
+         refresh_timer.tick(
+             refresh_display,
+             outputs=[conversation_output, status_output]
          )

+     return app


  if __name__ == "__main__":
+     app = create_interface()
+     app.launch(
          server_name="0.0.0.0",
          server_port=7860,
          share=True
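Because USE_MICROPHONE is False, the new capture_audio thread records the machine's own playback through soundcard's loopback support and feeds raw int16 PCM into the recorder. The capture path in isolation, a sketch with an illustrative frame count:

    import numpy as np
    import soundcard as sc

    SAMPLE_RATE = 16000
    BUFFER_SIZE = 512

    # Open the default *speaker* as a recordable loopback device
    speaker = sc.default_speaker()
    loopback = sc.get_microphone(id=str(speaker.name), include_loopback=True)

    with loopback.recorder(samplerate=SAMPLE_RATE, blocksize=BUFFER_SIZE) as rec:
        for _ in range(100):  # roughly 3 seconds of audio, illustrative
            frames = rec.record(numframes=BUFFER_SIZE)  # float32, shape [n, channels]
            mono = frames[:, 0] if frames.ndim > 1 else frames
            pcm16 = (mono * 32767).astype(np.int16)     # the format fed to feed_audio()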