Saiyaswanth007 committed on
Commit
7609dee
·
1 Parent(s): 1a722f5

Removed soundcard as it isn't supported on Hugging Face Spaces
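`soundcard` needs a host capture device (speaker loopback or microphone), and Hugging Face Spaces containers expose no audio hardware, so input now comes from the browser through Gradio instead. A minimal sketch of that input path, assuming only `gradio` and `numpy` (the handler name is illustrative, not the one used in the file):

```python
# Sketch only: browser-supplied audio replacing the soundcard loopback capture.
# Assumes gradio and numpy are installed; handle_audio is an illustrative name.
import gradio as gr
import numpy as np

def handle_audio(audio):
    # gr.Audio(type="numpy") delivers (sample_rate, samples) or None.
    if audio is None:
        return "No audio received"
    sample_rate, samples = audio
    if samples.dtype == np.int16:
        samples = samples.astype(np.float32) / 32768.0  # same scaling the new AudioProcessor uses
    if samples.ndim > 1:
        samples = samples.mean(axis=1)  # stereo -> mono
    return f"Received {len(samples) / sample_rate:.1f}s of audio at {sample_rate} Hz"

demo = gr.Interface(handle_audio, gr.Audio(sources=["microphone"], type="numpy"), "text")

if __name__ == "__main__":
    demo.launch()
```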

Files changed (1)
  1. realtime_diarize.py +507 -449
realtime_diarize.py CHANGED
@@ -1,523 +1,581 @@
-import os
-import sys
-import time
-import queue
-import threading
-import signal
-import atexit
-from contextlib import contextmanager
-import warnings
-warnings.filterwarnings("ignore", category=UserWarning)
-
 import numpy as np
 import torch
 import torchaudio
 from scipy.spatial.distance import cosine
 
-try:
-    import soundcard as sc
-except ImportError:
-    print("soundcard not found. Install with: pip install soundcard")
-    sys.exit(1)
 
-try:
-    from RealtimeSTT import AudioToTextRecorder
-except ImportError:
-    print("RealtimeSTT not found. Install with: pip install RealtimeSTT")
-    sys.exit(1)
 
-# Configuration
-class Config:
-    # Audio settings
-    SAMPLE_RATE = 16000
-    BUFFER_SIZE = 1024
-    CHANNELS = 1
-
-    # Transcription settings
-    FINAL_MODEL = "distil-large-v3"
-    REALTIME_MODEL = "distil-small.en"
-    LANGUAGE = "en"
-    BEAM_SIZE = 5
-    REALTIME_BEAM_SIZE = 3
-
-    # Voice activity detection
-    SILENCE_THRESHOLD = 0.4
-    MIN_RECORDING_LENGTH = 0.5
-    PRE_RECORDING_BUFFER = 0.2
-    SILERO_SENSITIVITY = 0.4
-    WEBRTC_SENSITIVITY = 3
-
-    # Speaker detection
-    CHANGE_THRESHOLD = 0.65
-    MAX_SPEAKERS = 4
-    MIN_SEGMENT_DURATION = 1.0
-    EMBEDDING_HISTORY_SIZE = 3
-    SPEAKER_MEMORY_SIZE = 20
 
-# Console colors for speakers
-COLORS = [
-    '\033[93m',  # Yellow
-    '\033[91m',  # Red
-    '\033[92m',  # Green
-    '\033[96m',  # Cyan
-    '\033[95m',  # Magenta
-    '\033[94m',  # Blue
-    '\033[97m',  # White
-    '\033[33m',  # Orange
 ]
-RESET = '\033[0m'
-LIVE_COLOR = '\033[90m'
 
-class SpeakerEncoder:
-    """Simplified speaker encoder using torchaudio transforms"""
-
     def __init__(self, device="cpu"):
         self.device = device
-        self.embedding_dim = 128
         self.model_loaded = False
-        self._setup_model()
 
-    def _setup_model(self):
-        """Setup a simple MFCC-based feature extractor"""
         try:
-            self.mfcc_transform = torchaudio.transforms.MFCC(
-                sample_rate=Config.SAMPLE_RATE,
-                n_mfcc=13,
-                melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23}
-            ).to(self.device)
             self.model_loaded = True
-            print("Simple MFCC-based encoder initialized")
         except Exception as e:
-            print(f"Error setting up encoder: {e}")
-            self.model_loaded = False
 
-    def extract_embedding(self, audio):
         """Extract speaker embedding from audio"""
         if not self.model_loaded:
-            return np.zeros(self.embedding_dim)
 
         try:
-            # Ensure audio is float32 and normalized
             if isinstance(audio, np.ndarray):
-                audio = torch.from_numpy(audio).float()
-
-            # Normalize audio
-            if audio.abs().max() > 0:
-                audio = audio / audio.abs().max()
 
-            # Add batch dimension if needed
-            if audio.dim() == 1:
-                audio = audio.unsqueeze(0)
 
-            # Extract MFCC features
             with torch.no_grad():
-                mfcc = self.mfcc_transform(audio)
-                # Simple statistics-based embedding
-                embedding = torch.cat([
-                    mfcc.mean(dim=2).flatten(),
-                    mfcc.std(dim=2).flatten(),
-                    mfcc.max(dim=2)[0].flatten(),
-                    mfcc.min(dim=2)[0].flatten()
-                ])
 
-                # Pad or truncate to fixed size
-                if embedding.size(0) > self.embedding_dim:
-                    embedding = embedding[:self.embedding_dim]
-                elif embedding.size(0) < self.embedding_dim:
-                    padding = torch.zeros(self.embedding_dim - embedding.size(0))
-                    embedding = torch.cat([embedding, padding])
-
-            return embedding.cpu().numpy()
-
         except Exception as e:
             print(f"Error extracting embedding: {e}")
             return np.zeros(self.embedding_dim)
 
-class SpeakerDetector:
-    """Speaker change detection using embeddings"""
 
-    def __init__(self, threshold=Config.CHANGE_THRESHOLD, max_speakers=Config.MAX_SPEAKERS):
-        self.threshold = threshold
-        self.max_speakers = max_speakers
         self.current_speaker = 0
-        self.speaker_embeddings = [[] for _ in range(max_speakers)]
-        self.speaker_centroids = [None] * max_speakers
         self.last_change_time = time.time()
-        self.active_speakers = {0}
-
-    def detect_speaker(self, embedding):
-        """Detect current speaker from embedding"""
-        current_time = time.time()
-
-        # Initialize first speaker
-        if not self.speaker_embeddings[0]:
-            self.speaker_embeddings[0].append(embedding)
-            self.speaker_centroids[0] = embedding.copy()
-            return 0, 1.0
-
-        # Calculate similarity with current speaker
-        current_centroid = self.speaker_centroids[self.current_speaker]
-        if current_centroid is not None:
-            similarity = 1.0 - cosine(embedding, current_centroid)
-        else:
-            similarity = 0.0
-
-        # Check if enough time has passed for a speaker change
-        if current_time - self.last_change_time < Config.MIN_SEGMENT_DURATION:
-            self._update_speaker_model(self.current_speaker, embedding)
-            return self.current_speaker, similarity
-
-        # Check for speaker change
-        if similarity < self.threshold:
-            # Find best matching existing speaker
-            best_speaker = self.current_speaker
-            best_similarity = similarity
 
-            for speaker_id in self.active_speakers:
-                if speaker_id == self.current_speaker:
-                    continue
 
-                centroid = self.speaker_centroids[speaker_id]
-                if centroid is not None:
-                    sim = 1.0 - cosine(embedding, centroid)
-                    if sim > best_similarity and sim > self.threshold:
-                        best_similarity = sim
-                        best_speaker = speaker_id
-
-            # Create new speaker if no good match and slots available
-            if (best_speaker == self.current_speaker and
-                len(self.active_speakers) < self.max_speakers):
-                for new_id in range(self.max_speakers):
-                    if new_id not in self.active_speakers:
-                        best_speaker = new_id
-                        best_similarity = 0.0
-                        self.active_speakers.add(new_id)
-                        break
 
-            # Update current speaker if changed
-            if best_speaker != self.current_speaker:
-                self.current_speaker = best_speaker
-                self.last_change_time = current_time
-                similarity = best_similarity
-
-        # Update speaker model
-        self._update_speaker_model(self.current_speaker, embedding)
-        return self.current_speaker, similarity
-
-    def _update_speaker_model(self, speaker_id, embedding):
-        """Update speaker model with new embedding"""
-        self.speaker_embeddings[speaker_id].append(embedding)
-
-        # Keep only recent embeddings
-        if len(self.speaker_embeddings[speaker_id]) > Config.SPEAKER_MEMORY_SIZE:
-            self.speaker_embeddings[speaker_id] = \
-                self.speaker_embeddings[speaker_id][-Config.SPEAKER_MEMORY_SIZE:]
-
-        # Update centroid
-        if self.speaker_embeddings[speaker_id]:
-            self.speaker_centroids[speaker_id] = np.mean(
-                self.speaker_embeddings[speaker_id], axis=0
             )
-
-class AudioRecorder:
-    """Handles audio recording from system audio"""
-
-    def __init__(self, audio_queue):
-        self.audio_queue = audio_queue
-        self.running = False
-        self.thread = None
-
-    def start(self):
-        """Start recording"""
-        self.running = True
-        self.thread = threading.Thread(target=self._record_loop, daemon=True)
-        self.thread.start()
-        print("Audio recording started")
 
-    def stop(self):
-        """Stop recording"""
-        self.running = False
-        if self.thread and self.thread.is_alive():
-            self.thread.join(timeout=2)
 
-    def _record_loop(self):
-        """Main recording loop"""
-        try:
-            # Try to use system audio (loopback)
-            try:
-                device = sc.default_speaker()
-                with device.recorder(
-                    samplerate=Config.SAMPLE_RATE,
-                    blocksize=Config.BUFFER_SIZE,
-                    channels=Config.CHANNELS
-                ) as recorder:
-                    print(f"Recording from: {device.name}")
-                    while self.running:
-                        data = recorder.record(numframes=Config.BUFFER_SIZE)
-                        if data is not None and len(data) > 0:
-                            # Convert to mono if needed
-                            if data.ndim > 1:
-                                data = data[:, 0]
-                            self.audio_queue.put(data.flatten())
-
-            except Exception as e:
-                print(f"Loopback recording failed: {e}")
-                print("Falling back to microphone...")
-
-                # Fallback to microphone
-                mic = sc.default_microphone()
-                with mic.recorder(
-                    samplerate=Config.SAMPLE_RATE,
-                    blocksize=Config.BUFFER_SIZE,
-                    channels=Config.CHANNELS
-                ) as recorder:
-                    print(f"Recording from microphone: {mic.name}")
-                    while self.running:
-                        data = recorder.record(numframes=Config.BUFFER_SIZE)
-                        if data is not None and len(data) > 0:
-                            if data.ndim > 1:
-                                data = data[:, 0]
-                            self.audio_queue.put(data.flatten())
-
-        except Exception as e:
-            print(f"Recording error: {e}")
-            self.running = False
 
-class TranscriptionProcessor:
-    """Handles transcription and speaker detection"""
-
     def __init__(self):
-        self.encoder = SpeakerEncoder()
-        self.detector = SpeakerDetector()
-        self.recorder = None
-        self.audio_queue = queue.Queue(maxsize=100)
-        self.audio_recorder = AudioRecorder(self.audio_queue)
-        self.processing_thread = None
-        self.running = False
-
-    def setup(self):
-        """Setup transcription recorder"""
         try:
-            self.recorder = AudioToTextRecorder(
-                spinner=False,
-                use_microphone=False,
-                model=Config.FINAL_MODEL,
-                language=Config.LANGUAGE,
-                silero_sensitivity=Config.SILERO_SENSITIVITY,
-                webrtc_sensitivity=Config.WEBRTC_SENSITIVITY,
-                post_speech_silence_duration=Config.SILENCE_THRESHOLD,
-                min_length_of_recording=Config.MIN_RECORDING_LENGTH,
-                pre_recording_buffer_duration=Config.PRE_RECORDING_BUFFER,
-                enable_realtime_transcription=True,
-                realtime_model_type=Config.REALTIME_MODEL,
-                beam_size=Config.BEAM_SIZE,
-                beam_size_realtime=Config.REALTIME_BEAM_SIZE,
-                on_realtime_transcription_update=self._on_live_text,
-            )
-            print("Transcription recorder setup complete")
-            return True
         except Exception as e:
-            print(f"Transcription setup failed: {e}")
             return False
 
-    def start(self):
-        """Start processing"""
-        if not self.setup():
-            return False
-
-        self.running = True
 
-        # Start audio recording
-        self.audio_recorder.start()
 
-        # Start audio processing thread
-        self.processing_thread = threading.Thread(target=self._process_audio, daemon=True)
-        self.processing_thread.start()
 
-        # Start transcription
-        self._start_transcription()
 
-        return True
 
-    def stop(self):
-        """Stop processing"""
-        print("\nStopping transcription...")
-        self.running = False
 
-        if self.audio_recorder:
-            self.audio_recorder.stop()
 
-        if self.processing_thread and self.processing_thread.is_alive():
-            self.processing_thread.join(timeout=2)
 
-        if self.recorder:
-            try:
-                self.recorder.shutdown()
-            except:
-                pass
-
-    def _process_audio(self):
-        """Process audio chunks for speaker detection"""
-        audio_buffer = []
-
-        while self.running:
-            try:
-                # Get audio chunk
-                chunk = self.audio_queue.get(timeout=0.1)
-                audio_buffer.extend(chunk)
-
-                # Process when we have enough audio (about 1 second)
-                if len(audio_buffer) >= Config.SAMPLE_RATE:
-                    audio_array = np.array(audio_buffer[:Config.SAMPLE_RATE])
-                    audio_buffer = audio_buffer[Config.SAMPLE_RATE//2:]  # 50% overlap
-
-                    # Convert to int16 for recorder
-                    audio_int16 = (audio_array * 32767).astype(np.int16)
 
-                    # Feed to transcription recorder
-                    if self.recorder:
-                        self.recorder.feed_audio(audio_int16.tobytes())
-
-            except queue.Empty:
-                continue
-            except Exception as e:
-                if self.running:
-                    print(f"Audio processing error: {e}")
-
-    def _start_transcription(self):
-        """Start transcription loop"""
-        def transcription_loop():
-            while self.running:
-                try:
-                    text = self.recorder.text()
-                    if text and text.strip():
-                        self._process_final_text(text)
-                except Exception as e:
-                    if self.running:
-                        print(f"Transcription error: {e}")
-                    break
-
-        transcription_thread = threading.Thread(target=transcription_loop, daemon=True)
-        transcription_thread.start()
-
-    def _on_live_text(self, text):
-        """Handle live transcription updates"""
-        if text and text.strip():
-            print(f"\r{LIVE_COLOR}[Live] {text}{RESET}", end="", flush=True)
 
-    def _process_final_text(self, text):
-        """Process final transcription with speaker detection"""
-        # Clear live text line
-        print("\r" + " " * 80 + "\r", end="")
 
         try:
-            # Get recent audio for speaker detection
-            recent_audio = []
-            temp_queue = []
 
-            # Collect recent audio chunks
-            for _ in range(min(10, self.audio_queue.qsize())):
-                try:
-                    chunk = self.audio_queue.get_nowait()
-                    recent_audio.extend(chunk)
-                    temp_queue.append(chunk)
-                except queue.Empty:
-                    break
 
-            # Put chunks back
-            for chunk in reversed(temp_queue):
-                try:
-                    self.audio_queue.put_nowait(chunk)
-                except queue.Full:
-                    break
 
-            # Extract speaker embedding if we have audio
-            if recent_audio:
-                audio_tensor = torch.FloatTensor(recent_audio[-Config.SAMPLE_RATE:])
-                embedding = self.encoder.extract_embedding(audio_tensor)
-                speaker_id, similarity = self.detector.detect_speaker(embedding)
-            else:
-                speaker_id, similarity = 0, 1.0
-
-            # Display with speaker color
-            color = COLORS[speaker_id % len(COLORS)]
-            print(f"{color}Speaker {speaker_id + 1}: {text}{RESET}")
 
         except Exception as e:
-            print(f"Error processing text: {e}")
-            print(f"Text: {text}")
 
-class RealTimeSpeakerDetection:
-    """Main application class"""
-
-    def __init__(self):
-        self.processor = None
-        self.running = False
-
-        # Setup signal handlers for clean shutdown
-        signal.signal(signal.SIGINT, self._signal_handler)
-        signal.signal(signal.SIGTERM, self._signal_handler)
-        atexit.register(self.cleanup)
-
-    def _signal_handler(self, signum, frame):
-        """Handle shutdown signals"""
-        print(f"\nReceived signal {signum}, shutting down...")
-        self.stop()
-
-    def start(self):
-        """Start the application"""
-        print("=== Real-time Speaker Detection and Transcription ===")
-        print("Initializing...")
-
-        self.processor = TranscriptionProcessor()
-
-        if not self.processor.start():
-            print("Failed to start. Check your audio setup and dependencies.")
-            return False
 
-        self.running = True
 
-        print("=" * 60)
-        print("System ready! Listening for audio...")
-        print("Different speakers will be shown in different colors.")
-        print("Press Ctrl+C to stop.")
-        print("=" * 60)
 
-        # Keep main thread alive
-        try:
-            while self.running:
-                time.sleep(1)
-        except KeyboardInterrupt:
-            pass
 
-        return True
-
-    def stop(self):
-        """Stop the application"""
-        if not self.running:
-            return
-
-        self.running = False
 
-        if self.processor:
-            self.processor.stop()
 
-        print("System stopped.")
 
-    def cleanup(self):
-        """Cleanup resources"""
-        self.stop()
 
-def main():
-    """Main entry point"""
-    app = RealTimeSpeakerDetection()
-
-    try:
-        app.start()
-    except Exception as e:
-        print(f"Application error: {e}")
-    finally:
-        app.cleanup()
 
 if __name__ == "__main__":
-    main()
+import gradio as gr
 import numpy as np
+import queue
 import torch
+import time
+import threading
+import os
+import urllib.request
 import torchaudio
 from scipy.spatial.distance import cosine
+import json
+import io
+import wave
 
+# Simplified configuration parameters
+SILENCE_THRESHS = [0, 0.4]
+FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
+FINAL_BEAM_SIZE = 5
+REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
+REALTIME_BEAM_SIZE = 5
+TRANSCRIPTION_LANGUAGE = "en"
+SILERO_SENSITIVITY = 0.4
+WEBRTC_SENSITIVITY = 3
+MIN_LENGTH_OF_RECORDING = 0.7
+PRE_RECORDING_BUFFER_DURATION = 0.35
 
+# Speaker change detection parameters
+DEFAULT_CHANGE_THRESHOLD = 0.7
+EMBEDDING_HISTORY_SIZE = 5
+MIN_SEGMENT_DURATION = 1.0
+DEFAULT_MAX_SPEAKERS = 4
+ABSOLUTE_MAX_SPEAKERS = 10
 
+# Global variables
+FAST_SENTENCE_END = True
+SAMPLE_RATE = 16000
+BUFFER_SIZE = 512
+CHANNELS = 1
 
+# Speaker colors
+SPEAKER_COLORS = [
+    "#FFFF00",  # Yellow
+    "#FF0000",  # Red
+    "#00FF00",  # Green
+    "#00FFFF",  # Cyan
+    "#FF00FF",  # Magenta
+    "#0000FF",  # Blue
+    "#FF8000",  # Orange
+    "#00FF80",  # Spring Green
+    "#8000FF",  # Purple
+    "#FFFFFF",  # White
 ]
 
+SPEAKER_COLOR_NAMES = [
+    "Yellow", "Red", "Green", "Cyan", "Magenta",
+    "Blue", "Orange", "Spring Green", "Purple", "White"
+]
+
+
+class SpeechBrainEncoder:
+    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
     def __init__(self, device="cpu"):
         self.device = device
+        self.model = None
+        self.embedding_dim = 192
         self.model_loaded = False
+        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
+        os.makedirs(self.cache_dir, exist_ok=True)
 
+    def load_model(self):
+        """Load the ECAPA-TDNN model"""
         try:
+            from speechbrain.pretrained import EncoderClassifier
+
+            self.model = EncoderClassifier.from_hparams(
+                source="speechbrain/spkrec-ecapa-voxceleb",
+                savedir=self.cache_dir,
+                run_opts={"device": self.device}
+            )
+
             self.model_loaded = True
+            print("ECAPA-TDNN model loaded successfully!")
+            return True
         except Exception as e:
+            print(f"SpeechBrain not available: {e}")
+            return False
 
+    def embed_utterance(self, audio, sr=16000):
         """Extract speaker embedding from audio"""
         if not self.model_loaded:
+            raise ValueError("Model not loaded. Call load_model() first.")
 
         try:
             if isinstance(audio, np.ndarray):
+                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
+            else:
+                waveform = audio.unsqueeze(0)
 
+            if sr != 16000:
+                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
 
             with torch.no_grad():
+                embedding = self.model.encode_batch(waveform)
 
+            return embedding.squeeze().cpu().numpy()
         except Exception as e:
             print(f"Error extracting embedding: {e}")
             return np.zeros(self.embedding_dim)
 
+
+class AudioProcessor:
+    """Processes audio data to extract speaker embeddings"""
+    def __init__(self, encoder):
+        self.encoder = encoder
 
+    def extract_embedding(self, audio_data, sample_rate=16000):
+        try:
+            # Ensure audio is float32 and normalized
+            if audio_data.dtype == np.int16:
+                float_audio = audio_data.astype(np.float32) / 32768.0
+            else:
+                float_audio = audio_data.astype(np.float32)
+
+            # Normalize if needed
+            if np.abs(float_audio).max() > 1.0:
+                float_audio = float_audio / np.abs(float_audio).max()
+
+            embedding = self.encoder.embed_utterance(float_audio, sample_rate)
+            return embedding
+
+        except Exception as e:
+            print(f"Embedding extraction error: {e}")
+            return np.zeros(self.encoder.embedding_dim)
+
+
+class SpeakerChangeDetector:
+    """Speaker change detector that supports a configurable number of speakers"""
+    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
+        self.embedding_dim = embedding_dim
+        self.change_threshold = change_threshold
+        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
         self.current_speaker = 0
+        self.previous_embeddings = []
         self.last_change_time = time.time()
+        self.mean_embeddings = [None] * self.max_speakers
+        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
+        self.last_similarity = 0.0
+        self.active_speakers = set([0])
+
+    def set_max_speakers(self, max_speakers):
+        """Update the maximum number of speakers"""
+        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
+
+        if new_max < self.max_speakers:
+            for speaker_id in list(self.active_speakers):
+                if speaker_id >= new_max:
+                    self.active_speakers.discard(speaker_id)
 
+            if self.current_speaker >= new_max:
+                self.current_speaker = 0
+
+        if new_max > self.max_speakers:
+            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
+            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
+        else:
+            self.mean_embeddings = self.mean_embeddings[:new_max]
+            self.speaker_embeddings = self.speaker_embeddings[:new_max]
+
+        self.max_speakers = new_max
+
+    def set_change_threshold(self, threshold):
+        """Update the threshold for detecting speaker changes"""
+        self.change_threshold = max(0.1, min(threshold, 0.99))
+
+    def add_embedding(self, embedding, timestamp=None):
+        """Add a new embedding and check if there's a speaker change"""
+        current_time = timestamp or time.time()
+
+        if not self.previous_embeddings:
+            self.previous_embeddings.append(embedding)
+            self.speaker_embeddings[self.current_speaker].append(embedding)
+            if self.mean_embeddings[self.current_speaker] is None:
+                self.mean_embeddings[self.current_speaker] = embedding.copy()
+            return self.current_speaker, 1.0
+
+        current_mean = self.mean_embeddings[self.current_speaker]
+        if current_mean is not None:
+            similarity = 1.0 - cosine(embedding, current_mean)
+        else:
+            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
+
+        self.last_similarity = similarity
+
+        time_since_last_change = current_time - self.last_change_time
+        is_speaker_change = False
+
+        if time_since_last_change >= MIN_SEGMENT_DURATION:
+            if similarity < self.change_threshold:
+                best_speaker = self.current_speaker
+                best_similarity = similarity
+
+                for speaker_id in range(self.max_speakers):
+                    if speaker_id == self.current_speaker:
+                        continue
+
+                    speaker_mean = self.mean_embeddings[speaker_id]
 
+                    if speaker_mean is not None:
+                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
+
+                        if speaker_similarity > best_similarity:
+                            best_similarity = speaker_similarity
+                            best_speaker = speaker_id
+
+                if best_speaker != self.current_speaker:
+                    is_speaker_change = True
+                    self.current_speaker = best_speaker
+                elif len(self.active_speakers) < self.max_speakers:
+                    for new_id in range(self.max_speakers):
+                        if new_id not in self.active_speakers:
+                            is_speaker_change = True
+                            self.current_speaker = new_id
+                            self.active_speakers.add(new_id)
+                            break
+
+        if is_speaker_change:
+            self.last_change_time = current_time
+
+        self.previous_embeddings.append(embedding)
+        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
+            self.previous_embeddings.pop(0)
+
+        self.speaker_embeddings[self.current_speaker].append(embedding)
+        self.active_speakers.add(self.current_speaker)
+
+        if len(self.speaker_embeddings[self.current_speaker]) > 30:
+            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
 
+        if self.speaker_embeddings[self.current_speaker]:
+            self.mean_embeddings[self.current_speaker] = np.mean(
+                self.speaker_embeddings[self.current_speaker], axis=0
             )
+
+        return self.current_speaker, similarity
 
+    def get_color_for_speaker(self, speaker_id):
+        """Return color for speaker ID"""
+        if 0 <= speaker_id < len(SPEAKER_COLORS):
+            return SPEAKER_COLORS[speaker_id]
+        return "#FFFFFF"
 
+    def get_status_info(self):
+        """Return status information about the speaker change detector"""
+        speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
+
+        return {
+            "current_speaker": self.current_speaker,
+            "speaker_counts": speaker_counts,
+            "active_speakers": len(self.active_speakers),
+            "max_speakers": self.max_speakers,
+            "last_similarity": self.last_similarity,
+            "threshold": self.change_threshold
+        }
 
+
+class GradioSpeakerDiarization:
     def __init__(self):
+        self.encoder = None
+        self.audio_processor = None
+        self.speaker_detector = None
+        self.full_sentences = []
+        self.sentence_speakers = []
+        self.is_initialized = False
+        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
+        self.max_speakers = DEFAULT_MAX_SPEAKERS
+
+    def initialize_models(self):
+        """Initialize the speaker encoder model"""
         try:
+            device_str = "cuda" if torch.cuda.is_available() else "cpu"
+            print(f"Using device: {device_str}")
+
+            # Load SpeechBrain encoder
+            self.encoder = SpeechBrainEncoder(device=device_str)
+            success = self.encoder.load_model()
+
+            if success:
+                self.audio_processor = AudioProcessor(self.encoder)
+                self.speaker_detector = SpeakerChangeDetector(
+                    embedding_dim=self.encoder.embedding_dim,
+                    change_threshold=self.change_threshold,
+                    max_speakers=self.max_speakers
+                )
+                self.is_initialized = True
+                return True
+            else:
+                return False
+
         except Exception as e:
+            print(f"Model initialization error: {e}")
             return False
 
+    def transcribe_audio(self, audio_input):
+        """Process audio input and perform transcription with speaker diarization"""
+        if not self.is_initialized:
+            return "❌ Please initialize the system first!", self.get_formatted_conversation(), self.get_status_info()
 
+        if audio_input is None:
+            return "No audio received", self.get_formatted_conversation(), self.get_status_info()
 
+        try:
+            # Handle different audio input formats
+            if isinstance(audio_input, tuple):
+                sample_rate, audio_data = audio_input
+            else:
+                # Assume it's a file path
+                import librosa
+                audio_data, sample_rate = librosa.load(audio_input, sr=16000)
+
+            # Ensure audio is in the right format
+            if len(audio_data.shape) > 1:
+                audio_data = audio_data.mean(axis=1)  # Convert to mono
+
+            # Perform simple transcription (placeholder - you'd want to integrate with Whisper or similar)
+            # For now, we'll just do speaker diarization
+            transcription = f"Audio segment {len(self.full_sentences) + 1} (duration: {len(audio_data)/sample_rate:.1f}s)"
+
+            # Extract speaker embedding
+            speaker_embedding = self.audio_processor.extract_embedding(audio_data, sample_rate)
+
+            # Store sentence and embedding
+            self.full_sentences.append((transcription, speaker_embedding))
+
+            # Detect speaker changes
+            speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
+            self.sentence_speakers.append(speaker_id)
+
+            status_msg = f"✅ Processed audio segment. Detected as Speaker {speaker_id + 1} (similarity: {similarity:.3f})"
+
+            return status_msg, self.get_formatted_conversation(), self.get_status_info()
+
+        except Exception as e:
+            error_msg = f"❌ Error processing audio: {str(e)}"
+            return error_msg, self.get_formatted_conversation(), self.get_status_info()
+
+    def clear_conversation(self):
+        """Clear all conversation data"""
+        self.full_sentences = []
+        self.sentence_speakers = []
 
+        if self.speaker_detector:
+            self.speaker_detector = SpeakerChangeDetector(
+                embedding_dim=self.encoder.embedding_dim,
+                change_threshold=self.change_threshold,
+                max_speakers=self.max_speakers
+            )
 
+        return "Conversation cleared!", self.get_formatted_conversation(), self.get_status_info()
 
+    def update_settings(self, threshold, max_speakers):
+        """Update speaker detection settings"""
+        self.change_threshold = threshold
+        self.max_speakers = max_speakers
+
+        if self.speaker_detector:
+            self.speaker_detector.set_change_threshold(threshold)
+            self.speaker_detector.set_max_speakers(max_speakers)
 
+        status_msg = f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
+        return status_msg, self.get_formatted_conversation(), self.get_status_info()
+
+    def get_formatted_conversation(self):
+        """Get the formatted conversation with speaker colors"""
+        try:
+            if not self.full_sentences:
+                return "No audio processed yet. Upload an audio file or record using the microphone."
 
+            sentences_with_style = []
 
+            for i, sentence in enumerate(self.full_sentences):
+                sentence_text, _ = sentence
+                if i >= len(self.sentence_speakers):
+                    color = "#FFFFFF"
+                    speaker_name = "Unknown"
+                else:
+                    speaker_id = self.sentence_speakers[i]
+                    color = self.speaker_detector.get_color_for_speaker(speaker_id)
+                    speaker_name = f"Speaker {speaker_id + 1}"
 
+                sentences_with_style.append(
+                    f'<span style="color:{color};"><b>{speaker_name}:</b> {sentence_text}</span>')
+
+            return "<br><br>".join(sentences_with_style)
+
+        except Exception as e:
+            return f"Error formatting conversation: {e}"
 
+    def get_status_info(self):
+        """Get current status information"""
+        if not self.speaker_detector:
+            return "Speaker detector not initialized"
 
         try:
+            status = self.speaker_detector.get_status_info()
 
+            status_lines = [
+                f"**Current Speaker:** {status['current_speaker'] + 1}",
+                f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}",
+                f"**Last Similarity:** {status['last_similarity']:.3f}",
+                f"**Change Threshold:** {status['threshold']:.2f}",
+                f"**Total Segments:** {len(self.full_sentences)}",
+                "",
+                "**Speaker Segment Counts:**"
+            ]
 
+            for i in range(status['max_speakers']):
+                color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
+                status_lines.append(f"Speaker {i+1} ({color_name}): {status['speaker_counts'][i]}")
 
+            return "\n".join(status_lines)
 
         except Exception as e:
+            return f"Error getting status: {e}"
 
+
+# Global instance
+diarization_system = GradioSpeakerDiarization()
+
+
+def initialize_system():
+    """Initialize the diarization system"""
+    success = diarization_system.initialize_models()
+    if success:
+        return "✅ System initialized successfully! Models loaded.", "", ""
+    else:
+        return "❌ Failed to initialize system. Please check the logs.", "", ""
+
+
+def process_audio(audio):
+    """Process uploaded or recorded audio"""
+    return diarization_system.transcribe_audio(audio)
+
+
+def clear_conversation():
+    """Clear the conversation"""
+    return diarization_system.clear_conversation()
+
+
+def update_settings(threshold, max_speakers):
+    """Update system settings"""
+    return diarization_system.update_settings(threshold, max_speakers)
+
+
+# Create Gradio interface
+def create_interface():
+    with gr.Blocks(title="Speaker Diarization", theme=gr.themes.Soft()) as app:
+        gr.Markdown("# 🎤 Audio Speaker Diarization")
+        gr.Markdown("Upload audio files or record directly to identify different speakers using voice characteristics.")
 
+        with gr.Row():
+            with gr.Column(scale=2):
+                # Initialize button
+                with gr.Row():
+                    init_btn = gr.Button("🔧 Initialize System", variant="primary", size="lg")
+
+                # Audio input options
+                gr.Markdown("### 📁 Audio Input")
+                with gr.Tab("Upload Audio File"):
+                    audio_file = gr.Audio(
+                        label="Upload Audio File",
+                        type="filepath",
+                        sources=["upload"]
+                    )
+                    process_file_btn = gr.Button("Process Audio File", variant="secondary")
+
+                with gr.Tab("Record Audio"):
+                    audio_mic = gr.Audio(
+                        label="Record Audio",
+                        type="numpy",
+                        sources=["microphone"]
+                    )
+                    process_mic_btn = gr.Button("Process Recording", variant="secondary")
+
+                # Results display
+                status_output = gr.Textbox(
+                    label="Status",
+                    value="Click 'Initialize System' to start...",
+                    lines=2,
+                    interactive=False
+                )
+
+                conversation_output = gr.HTML(
+                    value="<i>System not initialized...</i>",
+                    label="Speaker Analysis Results"
+                )
+
+                # Control buttons
+                with gr.Row():
+                    clear_btn = gr.Button("🗑️ Clear Results", variant="stop")
+
+            with gr.Column(scale=1):
+                # Settings panel
+                gr.Markdown("## ⚙️ Settings")
+
+                threshold_slider = gr.Slider(
+                    minimum=0.1,
+                    maximum=0.95,
+                    step=0.05,
+                    value=DEFAULT_CHANGE_THRESHOLD,
+                    label="Speaker Change Sensitivity",
+                    info="Lower = more sensitive to speaker changes"
+                )
+
+                max_speakers_slider = gr.Slider(
+                    minimum=2,
+                    maximum=ABSOLUTE_MAX_SPEAKERS,
+                    step=1,
+                    value=DEFAULT_MAX_SPEAKERS,
+                    label="Maximum Number of Speakers"
+                )
+
+                update_settings_btn = gr.Button("Update Settings", variant="secondary")
+
+                # System status
+                system_status = gr.Textbox(
+                    label="System Status",
+                    value="System not initialized",
+                    lines=12,
+                    interactive=False
+                )
+
+                # Speaker color legend
+                gr.Markdown("## 🎨 Speaker Colors")
+                color_info = []
+                for i, (color, name) in enumerate(zip(SPEAKER_COLORS[:DEFAULT_MAX_SPEAKERS], SPEAKER_COLOR_NAMES[:DEFAULT_MAX_SPEAKERS])):
+                    color_info.append(f'<span style="color:{color};">●</span> Speaker {i+1} ({name})')
+
+                gr.HTML("<br>".join(color_info))
 
+        # Event handlers
+        init_btn.click(
+            initialize_system,
+            outputs=[status_output, conversation_output, system_status]
+        )
 
+        process_file_btn.click(
+            process_audio,
+            inputs=[audio_file],
+            outputs=[status_output, conversation_output, system_status]
+        )
 
+        process_mic_btn.click(
+            process_audio,
+            inputs=[audio_mic],
+            outputs=[status_output, conversation_output, system_status]
+        )
 
+        clear_btn.click(
+            clear_conversation,
+            outputs=[status_output, conversation_output, system_status]
+        )
 
+        update_settings_btn.click(
+            update_settings,
+            inputs=[threshold_slider, max_speakers_slider],
+            outputs=[status_output, conversation_output, system_status]
+        )
 
+    return app
 
 
 if __name__ == "__main__":
+    app = create_interface()
+    app.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=True
+    )
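For reference, an untested sketch of driving the new pipeline headlessly, assuming the file above is importable as `realtime_diarize` and that `gradio`, `speechbrain`, `torch`, `torchaudio`, and `scipy` are installed; the random segments merely stand in for real speech:

```python
# Untested sketch: exercising the diarization pipeline without the UI.
import numpy as np
from realtime_diarize import GradioSpeakerDiarization, SAMPLE_RATE

system = GradioSpeakerDiarization()
if not system.initialize_models():  # downloads the ECAPA-TDNN checkpoint on first run
    raise SystemExit("Encoder failed to load")

for _ in range(2):
    # transcribe_audio accepts the same (sample_rate, samples) tuple gr.Audio produces.
    segment = np.random.uniform(-1.0, 1.0, SAMPLE_RATE).astype(np.float32)
    status, conversation_html, status_text = system.transcribe_audio((SAMPLE_RATE, segment))
    print(status)

print(system.get_status_info())
```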