Saiyaswanth007 committed
Commit 68a4a19 · 1 Parent(s): c473c8e

Initial Config

Files changed (5)
  1. Dockerfile +16 -0
  2. app.py +975 -0
  3. packages.txt +5 -0
  4. realtime_diarize.py +970 -0
  5. requirements.txt +188 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM python:3.10.0
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
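The CMD above follows the usual Hugging Face Docker-Space convention: uvicorn imports the module named app and serves the module-level ASGI object also named app, on the Spaces default port 7860. A minimal sketch of the entry point that command expects (the real one is the FastAPI app defined in app.py below; the /health route here is illustration only):

    # minimal sketch of the "app:app" entry point expected by the CMD above
    from fastapi import FastAPI

    app = FastAPI()  # must be module-level so "uvicorn app:app" can find it

    @app.get("/health")
    async def health():
        return {"status": "ok"}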
app.py ADDED
@@ -0,0 +1,975 @@
+ import gradio as gr
+ import numpy as np
+ import queue
+ import torch
+ import time
+ import threading
+ import os
+ import urllib.request
+ import torchaudio
+ from scipy.spatial.distance import cosine
+ from scipy.signal import resample
+ from RealtimeSTT import AudioToTextRecorder
+ from fastapi import FastAPI, APIRouter
+ from fastrtc import Stream, AsyncStreamHandler
+ import json
+ import asyncio
+ import uvicorn
+ from queue import Queue
+ import logging
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Simplified configuration parameters
+ SILENCE_THRESHS = [0, 0.4]
+ FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
+ FINAL_BEAM_SIZE = 5
+ REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
+ REALTIME_BEAM_SIZE = 5
+ TRANSCRIPTION_LANGUAGE = "en"
+ SILERO_SENSITIVITY = 0.4
+ WEBRTC_SENSITIVITY = 3
+ MIN_LENGTH_OF_RECORDING = 0.7
+ PRE_RECORDING_BUFFER_DURATION = 0.35
+
+ # Speaker change detection parameters
+ DEFAULT_CHANGE_THRESHOLD = 0.65
+ EMBEDDING_HISTORY_SIZE = 5
+ MIN_SEGMENT_DURATION = 1.5
+ DEFAULT_MAX_SPEAKERS = 4
+ ABSOLUTE_MAX_SPEAKERS = 8
+
+ # Global variables
+ SAMPLE_RATE = 16000
+ BUFFER_SIZE = 1024
+ CHANNELS = 1
+
+ # Speaker colors - more distinguishable colors
+ SPEAKER_COLORS = [
+     "#FF6B6B",  # Red
+     "#4ECDC4",  # Teal
+     "#45B7D1",  # Blue
+     "#96CEB4",  # Green
+     "#FFEAA7",  # Yellow
+     "#DDA0DD",  # Plum
+     "#98D8C8",  # Mint
+     "#F7DC6F",  # Gold
+ ]
+
+ SPEAKER_COLOR_NAMES = [
+     "Red", "Teal", "Blue", "Green", "Yellow", "Plum", "Mint", "Gold"
+ ]
+
+
+ class SpeechBrainEncoder:
+     """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
+     def __init__(self, device="cpu"):
+         self.device = device
+         self.model = None
+         self.embedding_dim = 192
+         self.model_loaded = False
+         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
+         os.makedirs(self.cache_dir, exist_ok=True)
+
+     def load_model(self):
+         """Load the ECAPA-TDNN model"""
+         try:
+             # Import SpeechBrain
+             from speechbrain.pretrained import EncoderClassifier
+
+             # Load the pre-trained model; from_hparams downloads and caches it
+             # on first use (this class has no _download_model helper, so the
+             # explicit download step used in realtime_diarize.py is omitted here)
+             self.model = EncoderClassifier.from_hparams(
+                 source="speechbrain/spkrec-ecapa-voxceleb",
+                 savedir=self.cache_dir,
+                 run_opts={"device": self.device}
+             )
+
+             self.model_loaded = True
+             return True
+         except Exception as e:
+             print(f"Error loading ECAPA-TDNN model: {e}")
+             return False
+
+     def embed_utterance(self, audio, sr=16000):
+         """Extract speaker embedding from audio"""
+         if not self.model_loaded:
+             raise ValueError("Model not loaded. Call load_model() first.")
+
+         try:
+             if isinstance(audio, np.ndarray):
+                 # Ensure audio is float32 and properly normalized
+                 audio = audio.astype(np.float32)
+                 if np.max(np.abs(audio)) > 1.0:
+                     audio = audio / np.max(np.abs(audio))
+                 waveform = torch.tensor(audio).unsqueeze(0)
+             else:
+                 waveform = audio.unsqueeze(0)
+
+             # Resample if necessary
+             if sr != 16000:
+                 waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
+
+             with torch.no_grad():
+                 embedding = self.model.encode_batch(waveform)
+
+             return embedding.squeeze().cpu().numpy()
+         except Exception as e:
+             logger.error(f"Error extracting embedding: {e}")
+             return np.zeros(self.embedding_dim)
+
+
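For reference, downstream code compares these embeddings by cosine similarity (scipy's cosine() returns a distance, so similarity is 1 − distance). A self-contained sketch of the decision rule, with the threshold value taken from DEFAULT_CHANGE_THRESHOLD above and random vectors standing in for real 192-dim ECAPA-TDNN embeddings:

    import numpy as np
    from scipy.spatial.distance import cosine

    rng = np.random.default_rng(0)
    e_new = rng.standard_normal(192)      # stand-in for a new utterance embedding
    centroid = rng.standard_normal(192)   # stand-in for the current speaker's centroid

    similarity = 1.0 - cosine(e_new, centroid)   # cosine() is a distance
    candidate_change = similarity < 0.65         # DEFAULT_CHANGE_THRESHOLD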
+ class AudioProcessor:
+     """Processes audio data to extract speaker embeddings"""
+     def __init__(self, encoder):
+         self.encoder = encoder
+         self.audio_buffer = []
+         self.min_audio_length = int(SAMPLE_RATE * 1.0)  # Minimum 1 second of audio
+
+     def add_audio_chunk(self, audio_chunk):
+         """Add audio chunk to buffer"""
+         self.audio_buffer.extend(audio_chunk)
+
+         # Keep buffer from getting too large
+         max_buffer_size = int(SAMPLE_RATE * 10)  # 10 seconds max
+         if len(self.audio_buffer) > max_buffer_size:
+             self.audio_buffer = self.audio_buffer[-max_buffer_size:]
+
+     def extract_embedding_from_buffer(self):
+         """Extract embedding from current audio buffer"""
+         if len(self.audio_buffer) < self.min_audio_length:
+             return None
+
+         try:
+             # Use the last portion of the buffer for embedding
+             audio_segment = np.array(self.audio_buffer[-self.min_audio_length:], dtype=np.float32)
+
+             # Normalize audio
+             if np.max(np.abs(audio_segment)) > 0:
+                 audio_segment = audio_segment / np.max(np.abs(audio_segment))
+             else:
+                 return None
+
+             embedding = self.encoder.embed_utterance(audio_segment)
+             return embedding
+         except Exception as e:
+             logger.error(f"Embedding extraction error: {e}")
+             return None
+
+
+ class SpeakerChangeDetector:
+     """Improved speaker change detector"""
+     def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
+         self.embedding_dim = embedding_dim
+         self.change_threshold = change_threshold
+         self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
+         self.current_speaker = 0
+         self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
+         self.speaker_centroids = [None] * self.max_speakers
+         self.last_change_time = time.time()
+         self.last_similarity = 1.0
+         self.active_speakers = set([0])
+         self.segment_counter = 0
+
+     def set_max_speakers(self, max_speakers):
+         """Update the maximum number of speakers"""
+         new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
+
+         if new_max < self.max_speakers:
+             # Remove speakers beyond the new limit
+             for speaker_id in list(self.active_speakers):
+                 if speaker_id >= new_max:
+                     self.active_speakers.discard(speaker_id)
+
+             if self.current_speaker >= new_max:
+                 self.current_speaker = 0
+
+         # Resize arrays
+         if new_max > self.max_speakers:
+             self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
+             self.speaker_centroids.extend([None] * (new_max - self.max_speakers))
+         else:
+             self.speaker_embeddings = self.speaker_embeddings[:new_max]
+             self.speaker_centroids = self.speaker_centroids[:new_max]
+
+         self.max_speakers = new_max
+
+     def set_change_threshold(self, threshold):
+         """Update the threshold for detecting speaker changes"""
+         self.change_threshold = max(0.1, min(threshold, 0.95))
+
+     def add_embedding(self, embedding, timestamp=None):
+         """Add a new embedding and detect speaker changes"""
+         current_time = timestamp or time.time()
+         self.segment_counter += 1
+
+         # Initialize first speaker
+         if not self.speaker_embeddings[0]:
+             self.speaker_embeddings[0].append(embedding)
+             self.speaker_centroids[0] = embedding.copy()
+             self.active_speakers.add(0)
+             return 0, 1.0
+
+         # Calculate similarity with current speaker
+         current_centroid = self.speaker_centroids[self.current_speaker]
+         if current_centroid is not None:
+             similarity = 1.0 - cosine(embedding, current_centroid)
+         else:
+             similarity = 0.5
+
+         self.last_similarity = similarity
+
+         # Check for speaker change
+         time_since_last_change = current_time - self.last_change_time
+         speaker_changed = False
+
+         if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold:
+             # Find best matching speaker
+             best_speaker = self.current_speaker
+             best_similarity = similarity
+
+             for speaker_id in self.active_speakers:
+                 if speaker_id == self.current_speaker:
+                     continue
+
+                 centroid = self.speaker_centroids[speaker_id]
+                 if centroid is not None:
+                     speaker_similarity = 1.0 - cosine(embedding, centroid)
+                     if speaker_similarity > best_similarity and speaker_similarity > self.change_threshold:
+                         best_similarity = speaker_similarity
+                         best_speaker = speaker_id
+
+             # If no good match found and we can add a new speaker
+             if best_speaker == self.current_speaker and len(self.active_speakers) < self.max_speakers:
+                 for new_id in range(self.max_speakers):
+                     if new_id not in self.active_speakers:
+                         best_speaker = new_id
+                         self.active_speakers.add(new_id)
+                         break
+
+             if best_speaker != self.current_speaker:
+                 self.current_speaker = best_speaker
+                 self.last_change_time = current_time
+                 speaker_changed = True
+
+         # Update speaker embeddings and centroids
+         self.speaker_embeddings[self.current_speaker].append(embedding)
+
+         # Keep only recent embeddings (sliding window)
+         max_embeddings = 20
+         if len(self.speaker_embeddings[self.current_speaker]) > max_embeddings:
+             self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-max_embeddings:]
+
+         # Update centroid
+         if self.speaker_embeddings[self.current_speaker]:
+             self.speaker_centroids[self.current_speaker] = np.mean(
+                 self.speaker_embeddings[self.current_speaker], axis=0
+             )
+
+         return self.current_speaker, similarity
+
+     def get_color_for_speaker(self, speaker_id):
+         """Return color for speaker ID"""
+         if 0 <= speaker_id < len(SPEAKER_COLORS):
+             return SPEAKER_COLORS[speaker_id]
+         return "#FFFFFF"
+
+     def get_status_info(self):
+         """Return status information"""
+         speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
+
+         return {
+             "current_speaker": self.current_speaker,
+             "speaker_counts": speaker_counts,
+             "active_speakers": len(self.active_speakers),
+             "max_speakers": self.max_speakers,
+             "last_similarity": self.last_similarity,
+             "threshold": self.change_threshold,
+             "segment_counter": self.segment_counter
+         }
+
+
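A toy worked example of the sliding-window centroid update performed at the end of add_embedding (3-dim vectors invented for readability; real embeddings are 192-dim, and the window size is max_embeddings = 20 as above):

    import numpy as np

    window = [np.array([1.0, 0.0, 0.0]), np.array([0.9, 0.1, 0.0])]
    window.append(np.array([0.8, 0.2, 0.0]))   # new embedding arrives
    window = window[-20:]                      # keep only the most recent 20
    centroid = np.mean(window, axis=0)         # -> array([0.9, 0.1, 0.0])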
+ class RealtimeSpeakerDiarization:
+     def __init__(self):
+         self.encoder = None
+         self.audio_processor = None
+         self.speaker_detector = None
+         self.recorder = None
+         self.sentence_queue = queue.Queue()
+         self.full_sentences = []
+         self.sentence_speakers = []
+         self.pending_sentences = []
+         self.current_conversation = ""
+         self.is_running = False
+         self.change_threshold = DEFAULT_CHANGE_THRESHOLD
+         self.max_speakers = DEFAULT_MAX_SPEAKERS
+         self.last_transcription = ""
+         self.transcription_lock = threading.Lock()
+
+     def initialize_models(self):
+         """Initialize the speaker encoder model"""
+         try:
+             device_str = "cuda" if torch.cuda.is_available() else "cpu"
+             logger.info(f"Using device: {device_str}")
+
+             self.encoder = SpeechBrainEncoder(device=device_str)
+             success = self.encoder.load_model()
+
+             if success:
+                 self.audio_processor = AudioProcessor(self.encoder)
+                 self.speaker_detector = SpeakerChangeDetector(
+                     embedding_dim=self.encoder.embedding_dim,
+                     change_threshold=self.change_threshold,
+                     max_speakers=self.max_speakers
+                 )
+                 logger.info("Models initialized successfully!")
+                 return True
+             else:
+                 logger.error("Failed to load models")
+                 return False
+         except Exception as e:
+             logger.error(f"Model initialization error: {e}")
+             return False
+
+     def live_text_detected(self, text):
+         """Callback for real-time transcription updates"""
+         with self.transcription_lock:
+             self.last_transcription = text.strip()
+
+     def process_final_text(self, text):
+         """Process final transcribed text with speaker embedding"""
+         text = text.strip()
+         if text:
+             try:
+                 # Get audio data for this transcription
+                 audio_bytes = getattr(self.recorder, 'last_transcription_bytes', None)
+                 if audio_bytes:
+                     self.sentence_queue.put((text, audio_bytes))
+                 else:
+                     # If no audio bytes, use current speaker
+                     self.sentence_queue.put((text, None))
+
+             except Exception as e:
+                 logger.error(f"Error processing final text: {e}")
+
+     def process_sentence_queue(self):
+         """Process sentences in the queue for speaker detection"""
+         while self.is_running:
+             try:
+                 text, audio_bytes = self.sentence_queue.get(timeout=1)
+
+                 current_speaker = self.speaker_detector.current_speaker
+
+                 if audio_bytes:
+                     # Convert audio data and extract embedding
+                     audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
+                     audio_float = audio_int16.astype(np.float32) / 32768.0
+
+                     # Extract embedding
+                     embedding = self.audio_processor.encoder.embed_utterance(audio_float)
+                     if embedding is not None:
+                         current_speaker, similarity = self.speaker_detector.add_embedding(embedding)
+
+                 # Store sentence with speaker
+                 with self.transcription_lock:
+                     self.full_sentences.append((text, current_speaker))
+                     self.update_conversation_display()
+
+             except queue.Empty:
+                 continue
+             except Exception as e:
+                 logger.error(f"Error processing sentence: {e}")
+
+     def update_conversation_display(self):
+         """Update the conversation display"""
+         try:
+             sentences_with_style = []
+
+             for sentence_text, speaker_id in self.full_sentences:
+                 color = self.speaker_detector.get_color_for_speaker(speaker_id)
+                 speaker_name = f"Speaker {speaker_id + 1}"
+                 sentences_with_style.append(
+                     f'<span style="color:{color}; font-weight: bold;">{speaker_name}:</span> '
+                     f'<span style="color:#333333;">{sentence_text}</span>'
+                 )
+
+             # Add current transcription if available
+             if self.last_transcription:
+                 current_color = self.speaker_detector.get_color_for_speaker(self.speaker_detector.current_speaker)
+                 current_speaker = f"Speaker {self.speaker_detector.current_speaker + 1}"
+                 sentences_with_style.append(
+                     f'<span style="color:{current_color}; font-weight: bold; opacity: 0.7;">{current_speaker}:</span> '
+                     f'<span style="color:#666666; font-style: italic;">{self.last_transcription}...</span>'
+                 )
+
+             if sentences_with_style:
+                 self.current_conversation = "<br><br>".join(sentences_with_style)
+             else:
+                 self.current_conversation = "<i>Waiting for speech input...</i>"
+
+         except Exception as e:
+             logger.error(f"Error updating conversation display: {e}")
+             self.current_conversation = f"<i>Error: {str(e)}</i>"
+
+     def start_recording(self):
+         """Start the recording and transcription process"""
+         if self.encoder is None:
+             return "Please initialize models first!"
+
+         try:
+             # Setup recorder configuration
+             recorder_config = {
+                 'spinner': False,
+                 'use_microphone': False,  # Using FastRTC for audio input
+                 'model': FINAL_TRANSCRIPTION_MODEL,
+                 'language': TRANSCRIPTION_LANGUAGE,
+                 'silero_sensitivity': SILERO_SENSITIVITY,
+                 'webrtc_sensitivity': WEBRTC_SENSITIVITY,
+                 'post_speech_silence_duration': SILENCE_THRESHS[1],
+                 'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
+                 'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
+                 'min_gap_between_recordings': 0,
+                 'enable_realtime_transcription': True,
+                 'realtime_processing_pause': 0.1,
+                 'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
+                 'on_realtime_transcription_update': self.live_text_detected,
+                 'beam_size': FINAL_BEAM_SIZE,
+                 'beam_size_realtime': REALTIME_BEAM_SIZE,
+                 'sample_rate': SAMPLE_RATE,
+             }
+
+             self.recorder = AudioToTextRecorder(**recorder_config)
+
+             # Start processing threads
+             self.is_running = True
+             self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
+             self.sentence_thread.start()
+
+             self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
+             self.transcription_thread.start()
+
+             return "Recording started successfully!"
+
+         except Exception as e:
+             logger.error(f"Error starting recording: {e}")
+             return f"Error starting recording: {e}"
+
+     def run_transcription(self):
+         """Run the transcription loop"""
+         try:
+             logger.info("Starting transcription thread")
+             while self.is_running:
+                 # Just check for final text from recorder, audio is fed externally via FastRTC
+                 self.recorder.text(self.process_final_text)
+                 time.sleep(0.01)  # Small sleep to prevent CPU hogging
+         except Exception as e:
+             logger.error(f"Transcription error: {e}")
+
+     def stop_recording(self):
+         """Stop the recording process"""
+         self.is_running = False
+         if self.recorder:
+             self.recorder.stop()
+         return "Recording stopped!"
+
+     def clear_conversation(self):
+         """Clear all conversation data"""
+         with self.transcription_lock:
+             self.full_sentences = []
+             self.last_transcription = ""
+             self.current_conversation = "Conversation cleared!"
+
+         if self.speaker_detector:
+             self.speaker_detector = SpeakerChangeDetector(
+                 embedding_dim=self.encoder.embedding_dim,
+                 change_threshold=self.change_threshold,
+                 max_speakers=self.max_speakers
+             )
+
+         return "Conversation cleared!"
+
+     def update_settings(self, threshold, max_speakers):
+         """Update speaker detection settings"""
+         self.change_threshold = threshold
+         self.max_speakers = max_speakers
+
+         if self.speaker_detector:
+             self.speaker_detector.set_change_threshold(threshold)
+             self.speaker_detector.set_max_speakers(max_speakers)
+
+         return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"
+
+     def get_formatted_conversation(self):
+         """Get the formatted conversation"""
+         return self.current_conversation
+
+     def get_status_info(self):
+         """Get current status information"""
+         if not self.speaker_detector:
+             return "Speaker detector not initialized"
+
+         try:
+             status = self.speaker_detector.get_status_info()
+
+             status_lines = [
+                 f"**Current Speaker:** {status['current_speaker'] + 1}",
+                 f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}",
+                 f"**Last Similarity:** {status['last_similarity']:.3f}",
+                 f"**Change Threshold:** {status['threshold']:.2f}",
+                 f"**Total Sentences:** {len(self.full_sentences)}",
+                 f"**Segments Processed:** {status['segment_counter']}",
+                 "",
+                 "**Speaker Activity:**"
+             ]
+
+             for i in range(status['max_speakers']):
+                 color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
+                 count = status['speaker_counts'][i]
+                 active = "🟢" if count > 0 else "⚫"
+                 status_lines.append(f"{active} Speaker {i+1} ({color_name}): {count} segments")
+
+             return "\n".join(status_lines)
+
+         except Exception as e:
+             return f"Error getting status: {e}"
+
+     def process_audio_chunk(self, audio_data, sample_rate=16000):
+         """Process audio chunk from FastRTC input"""
+         if not self.is_running or self.audio_processor is None:
+             return
+
+         try:
+             # Ensure audio is float32
+             if isinstance(audio_data, np.ndarray):
+                 if audio_data.dtype != np.float32:
+                     audio_data = audio_data.astype(np.float32)
+             else:
+                 audio_data = np.array(audio_data, dtype=np.float32)
+
+             # Ensure mono
+             if len(audio_data.shape) > 1:
+                 audio_data = np.mean(audio_data, axis=1) if audio_data.shape[1] > 1 else audio_data.flatten()
+
+             # Normalize if needed
+             if np.max(np.abs(audio_data)) > 1.0:
+                 audio_data = audio_data / np.max(np.abs(audio_data))
+
+             # Add to audio processor buffer for speaker detection
+             self.audio_processor.add_audio_chunk(audio_data)
+
+             # Periodically extract embeddings for speaker detection
+             if len(self.audio_processor.audio_buffer) % (SAMPLE_RATE // 2) == 0:  # Every 0.5 seconds
+                 embedding = self.audio_processor.extract_embedding_from_buffer()
+                 if embedding is not None:
+                     self.speaker_detector.add_embedding(embedding)
+
+             # Feed audio to RealtimeSTT recorder
+             if self.recorder and self.is_running:
+                 # Convert float32 [-1.0, 1.0] to int16 for RealtimeSTT
+                 int16_data = (audio_data * 32768.0).astype(np.int16).tobytes()
+                 if sample_rate != 16000:
+                     int16_data = self.resample_audio(int16_data, sample_rate, 16000)
+                 self.recorder.feed_audio(int16_data)
+
+         except Exception as e:
+             logger.error(f"Error processing audio chunk: {e}")
+
+     def resample_audio(self, audio_bytes, from_rate, to_rate):
+         """Resample audio to target sample rate"""
+         try:
+             audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
+             num_samples = len(audio_np)
+             num_target_samples = int(num_samples * to_rate / from_rate)
+
+             resampled = resample(audio_np, num_target_samples)
+
+             return resampled.astype(np.int16).tobytes()
+         except Exception as e:
+             logger.error(f"Error resampling audio: {e}")
+             return audio_bytes
+
+
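A standalone sketch of the two conversions used above, PCM int16 to float32 and scipy's FFT-based resample(); the 48 kHz input rate is an assumption for illustration (WebRTC audio commonly arrives at 48 kHz):

    import numpy as np
    from scipy.signal import resample

    t = np.arange(48000) / 48000.0
    pcm_48k = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)  # 1 s test tone

    audio_f32 = pcm_48k.astype(np.float32) / 32768.0      # int16 -> float32 in [-1, 1)
    n_target = int(len(pcm_48k) * 16000 / 48000)          # 48 kHz -> 16 kHz
    audio_16k = resample(audio_f32, n_target)             # FFT-based resampling
    pcm_16k = np.clip(audio_16k * 32768.0, -32768, 32767).astype(np.int16)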
+ # FastRTC Audio Handler
+ class DiarizationHandler(AsyncStreamHandler):
+     def __init__(self, diarization_system):
+         super().__init__()
+         self.diarization_system = diarization_system
+         self.audio_buffer = []
+         self.buffer_size = BUFFER_SIZE
+
+     def copy(self):
+         """Return a fresh handler for each new stream connection"""
+         return DiarizationHandler(self.diarization_system)
+
+     async def emit(self):
+         """Not used - we only receive audio"""
+         return None
+
+     async def receive(self, frame):
+         """Receive audio data from FastRTC"""
+         try:
+             if not self.diarization_system.is_running:
+                 return
+
+             # Extract audio data
+             audio_data = getattr(frame, 'data', frame)
+
+             # Convert to numpy array
+             if isinstance(audio_data, bytes):
+                 audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
+             elif isinstance(audio_data, (list, tuple)):
+                 sample_rate, audio_array = audio_data
+                 if isinstance(audio_array, (list, tuple)):
+                     audio_array = np.array(audio_array, dtype=np.float32)
+             else:
+                 audio_array = np.array(audio_data, dtype=np.float32)
+
+             # Ensure 1D
+             if len(audio_array.shape) > 1:
+                 audio_array = audio_array.flatten()
+
+             # Buffer audio chunks
+             self.audio_buffer.extend(audio_array)
+
+             # Process in chunks
+             while len(self.audio_buffer) >= self.buffer_size:
+                 chunk = np.array(self.audio_buffer[:self.buffer_size])
+                 self.audio_buffer = self.audio_buffer[self.buffer_size:]
+
+                 # Process asynchronously
+                 await self.process_audio_async(chunk)
+
+         except Exception as e:
+             logger.error(f"Error in FastRTC receive: {e}")
+
+     async def process_audio_async(self, audio_data):
+         """Process audio data asynchronously"""
+         try:
+             loop = asyncio.get_event_loop()
+             await loop.run_in_executor(
+                 None,
+                 self.diarization_system.process_audio_chunk,
+                 audio_data,
+                 SAMPLE_RATE
+             )
+         except Exception as e:
+             logger.error(f"Error in async audio processing: {e}")
+
+
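process_audio_async above uses the standard pattern for keeping blocking, CPU-bound work (embedding extraction, feeding the STT recorder) off the event loop. The same pattern in isolation, here with a dedicated thread pool instead of the default executor (the pool size is an arbitrary choice for this sketch):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    pool = ThreadPoolExecutor(max_workers=2)   # None would use the loop's default pool

    def blocking_work(x):
        return x * 2                           # stand-in for embedding extraction

    async def handler(x):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(pool, blocking_work, x)

    print(asyncio.run(handler(21)))            # -> 42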
+ # Global instances
+ diarization_system = RealtimeSpeakerDiarization()
+ audio_handler = None
+
+ def initialize_system():
+     """Initialize the diarization system"""
+     global audio_handler
+     try:
+         success = diarization_system.initialize_models()
+         if success:
+             audio_handler = DiarizationHandler(diarization_system)
+             return "✅ System initialized successfully!"
+         else:
+             return "❌ Failed to initialize system. Check logs for details."
+     except Exception as e:
+         logger.error(f"Initialization error: {e}")
+         return f"❌ Initialization error: {str(e)}"
+
+ def start_recording():
+     """Start recording and transcription"""
+     try:
+         result = diarization_system.start_recording()
+         return f"🎙️ {result}"
+     except Exception as e:
+         return f"❌ Failed to start recording: {str(e)}"
+
+ def stop_recording():
+     """Stop recording and transcription"""
+     try:
+         result = diarization_system.stop_recording()
+         return f"⏹️ {result}"
+     except Exception as e:
+         return f"❌ Failed to stop recording: {str(e)}"
+
+ def clear_conversation():
+     """Clear the conversation"""
+     try:
+         result = diarization_system.clear_conversation()
+         return f"🗑️ {result}"
+     except Exception as e:
+         return f"❌ Failed to clear conversation: {str(e)}"
+
+ def update_settings(threshold, max_speakers):
+     """Update system settings"""
+     try:
+         result = diarization_system.update_settings(threshold, max_speakers)
+         return f"⚙️ {result}"
+     except Exception as e:
+         return f"❌ Failed to update settings: {str(e)}"
+
+ def get_conversation():
+     """Get the current conversation"""
+     try:
+         return diarization_system.get_formatted_conversation()
+     except Exception as e:
+         return f"<i>Error getting conversation: {str(e)}</i>"
+
+ def get_status():
+     """Get system status"""
+     try:
+         return diarization_system.get_status_info()
+     except Exception as e:
+         return f"Error getting status: {str(e)}"
+
+ # Create Gradio interface
+ def create_interface():
+     with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
+         gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
+         gr.Markdown("Live transcription with automatic speaker identification using FastRTC audio streaming.")
+
+         with gr.Row():
+             with gr.Column(scale=2):
+                 # Conversation display
+                 conversation_output = gr.HTML(
+                     value="<div style='padding: 20px; background: #f8f9fa; border-radius: 10px; min-height: 300px;'><i>Click 'Initialize System' to start...</i></div>",
+                     label="Live Conversation"
+                 )
+
+                 # Control buttons
+                 with gr.Row():
+                     init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
+                     start_btn = gr.Button("🎙️ Start", variant="primary", size="lg", interactive=False)
+                     stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", interactive=False)
+                     clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)
+
+                 # Status display
+                 status_output = gr.Textbox(
+                     label="System Status",
+                     value="Ready to initialize...",
+                     lines=8,
+                     interactive=False
+                 )
+
+             with gr.Column(scale=1):
+                 # Settings
+                 gr.Markdown("## ⚙️ Settings")
+
+                 threshold_slider = gr.Slider(
+                     minimum=0.3,
+                     maximum=0.9,
+                     step=0.05,
+                     value=DEFAULT_CHANGE_THRESHOLD,
+                     label="Speaker Change Sensitivity",
+                     info="Lower = more sensitive"
+                 )
+
+                 max_speakers_slider = gr.Slider(
+                     minimum=2,
+                     maximum=ABSOLUTE_MAX_SPEAKERS,
+                     step=1,
+                     value=DEFAULT_MAX_SPEAKERS,
+                     label="Maximum Speakers"
+                 )
+
+                 update_btn = gr.Button("Update Settings", variant="secondary")
+
+                 # Instructions
+                 gr.Markdown("""
+                 ## 📋 Instructions
+                 1. **Initialize** the system (loads AI models)
+                 2. **Start** recording
+                 3. **Speak** - system will transcribe and identify speakers
+                 4. **Monitor** real-time results below
+
+                 ## 🎨 Speaker Colors
+                 - 🔴 Speaker 1 (Red)
+                 - 🟢 Speaker 2 (Teal)
+                 - 🔵 Speaker 3 (Blue)
+                 - 🟡 Speaker 4 (Green)
+                 - 🟣 Speaker 5 (Yellow)
+                 - 🟤 Speaker 6 (Plum)
+                 - 🟫 Speaker 7 (Mint)
+                 - 🟨 Speaker 8 (Gold)
+                 """)
+
+         # Event handlers
+         def on_initialize():
+             result = initialize_system()
+             if "✅" in result:
+                 return result, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
+             else:
+                 return result, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
+
+         def on_start():
+             result = start_recording()
+             return result, gr.update(interactive=False), gr.update(interactive=True)
+
+         def on_stop():
+             result = stop_recording()
+             return result, gr.update(interactive=True), gr.update(interactive=False)
+
+         def on_clear():
+             result = clear_conversation()
+             return result
+
+         def on_update_settings(threshold, max_speakers):
+             result = update_settings(threshold, int(max_speakers))
+             return result
+
+         def refresh_conversation():
+             return get_conversation()
+
+         def refresh_status():
+             return get_status()
+
+         # Button click handlers
+         init_btn.click(
+             fn=on_initialize,
+             outputs=[status_output, start_btn, stop_btn, clear_btn]
+         )
+
+         start_btn.click(
+             fn=on_start,
+             outputs=[status_output, start_btn, stop_btn]
+         )
+
+         stop_btn.click(
+             fn=on_stop,
+             outputs=[status_output, start_btn, stop_btn]
+         )
+
+         clear_btn.click(
+             fn=on_clear,
+             outputs=[status_output]
+         )
+
+         update_btn.click(
+             fn=on_update_settings,
+             inputs=[threshold_slider, max_speakers_slider],
+             outputs=[status_output]
+         )
+
+         # Auto-refresh conversation display every 1 second
+         conversation_timer = gr.Timer(1)
+         conversation_timer.tick(refresh_conversation, outputs=[conversation_output])
+
+         # Auto-refresh status every 2 seconds
+         status_timer = gr.Timer(2)
+         status_timer.tick(refresh_status, outputs=[status_output])
+
+     return interface
+
+
+ # FastAPI setup for FastRTC integration
+ app = FastAPI()
+
+ @app.get("/")
+ async def root():
+     return {"message": "Real-time Speaker Diarization API"}
+
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy", "system_running": diarization_system.is_running}
+
+ @app.post("/initialize")
+ async def api_initialize():
+     result = initialize_system()
+     return {"result": result, "success": "✅" in result}
+
+ @app.post("/start")
+ async def api_start():
+     result = start_recording()
+     return {"result": result, "success": "🎙️" in result}
+
+ @app.post("/stop")
+ async def api_stop():
+     result = stop_recording()
+     return {"result": result, "success": "⏹️" in result}
+
+ @app.post("/clear")
+ async def api_clear():
+     result = clear_conversation()
+     return {"result": result}
+
+ @app.get("/conversation")
+ async def api_get_conversation():
+     return {"conversation": get_conversation()}
+
+ @app.get("/status")
+ async def api_get_status():
+     return {"status": get_status()}
+
+ @app.post("/settings")
+ async def api_update_settings(threshold: float, max_speakers: int):
+     result = update_settings(threshold, max_speakers)
+     return {"result": result}
+
+ # FastRTC Stream setup
+ if audio_handler:
+     stream = Stream(handler=audio_handler)
+     app.include_router(stream.router, prefix="/stream")
+
+
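Note that audio_handler is still None when the module-level `if audio_handler:` guard above runs, so the /stream route is never actually registered; the handler only comes into existence later, inside initialize_system(). One possible restructuring, sketched here reusing only the calls already present above (Stream(handler=...), include_router), is to register the route once initialization succeeds:

    # sketch: replace the api_initialize endpoint above with a version that
    # mounts the FastRTC stream once the handler exists (no guard against
    # re-registration on repeated calls; illustration only)
    @app.post("/initialize")
    async def api_initialize():
        result = initialize_system()
        if audio_handler is not None:
            stream = Stream(handler=audio_handler)
            app.include_router(stream.router, prefix="/stream")
        return {"result": result, "success": "✅" in result}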
+ # Main execution
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System")
+     parser.add_argument("--mode", choices=["gradio", "api", "both"], default="gradio",
+                         help="Run mode: gradio interface, API only, or both")
+     parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
+     parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
+     parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)")
+
+     args = parser.parse_args()
+
+     if args.mode == "gradio":
+         # Run Gradio interface only
+         interface = create_interface()
+         interface.launch(
+             server_name=args.host,
+             server_port=args.port,
+             share=True,
+             show_error=True
+         )
+
+     elif args.mode == "api":
+         # Run FastAPI only
+         uvicorn.run(
+             app,
+             host=args.host,
+             port=args.port,
+             log_level="info"
+         )
+
+     elif args.mode == "both":
+         # Run both Gradio and FastAPI
+         import multiprocessing
+         import threading
+
+         def run_gradio():
+             interface = create_interface()
+             interface.launch(
+                 server_name=args.host,
+                 server_port=args.port,
+                 share=True,
+                 show_error=True
+             )
+
+         def run_fastapi():
+             uvicorn.run(
+                 app,
+                 host=args.host,
+                 port=args.api_port,
+                 log_level="info"
+             )
+
+         # Start FastAPI in a separate thread
+         api_thread = threading.Thread(target=run_fastapi, daemon=True)
+         api_thread.start()
+
+         # Start Gradio in main thread
+         run_gradio()
packages.txt ADDED
@@ -0,0 +1,5 @@
+ portaudio19-dev
+ libasound2-dev
+ libjack-jackd2-dev
+ pulseaudio
+ pulseaudio-utils
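These are apt packages that Hugging Face Spaces installs from packages.txt before the build; they provide the PortAudio/ALSA/JACK/PulseAudio stack that the Python audio layers link against (RealtimeSTT's microphone path and the soundcard library imported by realtime_diarize.py below). A quick sanity check, assuming the soundcard package is installed:

    # list the capture devices the audio stack exposes to Python
    import soundcard as sc

    print(sc.default_microphone())
    for mic in sc.all_microphones(include_loopback=True):
        print(mic)   # loopback entries let you capture speaker output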
realtime_diarize.py ADDED
@@ -0,0 +1,970 @@
+ from PyQt6.QtWidgets import (QApplication, QTextEdit, QMainWindow, QLabel, QVBoxLayout, QWidget,
+                              QHBoxLayout, QPushButton, QSizePolicy, QGroupBox, QSlider, QSpinBox)
+ from PyQt6.QtCore import Qt, pyqtSignal, QThread, QEvent, QTimer
+ from scipy.spatial.distance import cosine
+ from RealtimeSTT import AudioToTextRecorder
+ import numpy as np
+ import soundcard as sc
+ import queue
+ import torch
+ import time
+ import sys
+ import os
+ import urllib.request
+ import torchaudio
+
+ # Simplified configuration parameters
+ SILENCE_THRESHS = [0, 0.4]
+ FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
+ FINAL_BEAM_SIZE = 5
+ REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
+ REALTIME_BEAM_SIZE = 5
+ TRANSCRIPTION_LANGUAGE = "en"  # Accuracy in languages other than English is very low.
+ SILERO_SENSITIVITY = 0.4
+ WEBRTC_SENSITIVITY = 3
+ MIN_LENGTH_OF_RECORDING = 0.7
+ PRE_RECORDING_BUFFER_DURATION = 0.35
+
+ # Speaker change detection parameters
+ DEFAULT_CHANGE_THRESHOLD = 0.7  # Threshold for detecting speaker change
+ EMBEDDING_HISTORY_SIZE = 5      # Number of embeddings to keep for comparison
+ MIN_SEGMENT_DURATION = 1.0      # Minimum duration before considering a speaker change
+ DEFAULT_MAX_SPEAKERS = 4        # Default maximum number of speakers
+ ABSOLUTE_MAX_SPEAKERS = 10      # Absolute maximum number of speakers allowed
+
+ # Global variables
+ FAST_SENTENCE_END = True
+ USE_MICROPHONE = False
+ SAMPLE_RATE = 16000
+ BUFFER_SIZE = 512
+ CHANNELS = 1
+
+ # Speaker colors - now we have colors for up to 10 speakers
+ SPEAKER_COLORS = [
+     "#FFFF00",  # Yellow
+     "#FF0000",  # Red
+     "#00FF00",  # Green
+     "#00FFFF",  # Cyan
+     "#FF00FF",  # Magenta
+     "#0000FF",  # Blue
+     "#FF8000",  # Orange
+     "#00FF80",  # Spring Green
+     "#8000FF",  # Purple
+     "#FFFFFF",  # White
+ ]
+
+ # Color names for display
+ SPEAKER_COLOR_NAMES = [
+     "Yellow",
+     "Red",
+     "Green",
+     "Cyan",
+     "Magenta",
+     "Blue",
+     "Orange",
+     "Spring Green",
+     "Purple",
+     "White"
+ ]
+
+
+ class SpeechBrainEncoder:
+     """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
+     def __init__(self, device="cpu"):
+         self.device = device
+         self.model = None
+         self.embedding_dim = 192  # ECAPA-TDNN default dimension
+         self.model_loaded = False
+         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
+         os.makedirs(self.cache_dir, exist_ok=True)
+
+     def _download_model(self):
+         """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
+         model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
+         model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
+
+         if not os.path.exists(model_path):
+             print(f"Downloading ECAPA-TDNN model to {model_path}...")
+             urllib.request.urlretrieve(model_url, model_path)
+
+         return model_path
+
+     def load_model(self):
+         """Load the ECAPA-TDNN model"""
+         try:
+             # Import SpeechBrain
+             from speechbrain.pretrained import EncoderClassifier
+
+             # Get model path
+             model_path = self._download_model()
+
+             # Load the pre-trained model
+             self.model = EncoderClassifier.from_hparams(
+                 source="speechbrain/spkrec-ecapa-voxceleb",
+                 savedir=self.cache_dir,
+                 run_opts={"device": self.device}
+             )
+
+             self.model_loaded = True
+             return True
+         except Exception as e:
+             print(f"Error loading ECAPA-TDNN model: {e}")
+             return False
+
+     def embed_utterance(self, audio, sr=16000):
+         """Extract speaker embedding from audio"""
+         if not self.model_loaded:
+             raise ValueError("Model not loaded. Call load_model() first.")
+
+         try:
+             # Convert numpy array to torch tensor
+             if isinstance(audio, np.ndarray):
+                 waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
+             else:
+                 waveform = audio.unsqueeze(0)
+
+             # Ensure sample rate matches model expected rate
+             if sr != 16000:
+                 waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
+
+             # Get embedding
+             with torch.no_grad():
+                 embedding = self.model.encode_batch(waveform)
+
+             return embedding.squeeze().cpu().numpy()
+         except Exception as e:
+             print(f"Error extracting embedding: {e}")
+             return np.zeros(self.embedding_dim)
+
+
+ class AudioProcessor:
+     """Processes audio data to extract speaker embeddings"""
+     def __init__(self, encoder):
+         self.encoder = encoder
+
+     def extract_embedding(self, audio_int16):
+         try:
+             # Convert int16 audio data to float32
+             float_audio = audio_int16.astype(np.float32) / 32768.0
+
+             # Normalize if needed
+             if np.abs(float_audio).max() > 1.0:
+                 float_audio = float_audio / np.abs(float_audio).max()
+
+             # Extract embedding using the loaded encoder
+             embedding = self.encoder.embed_utterance(float_audio)
+
+             return embedding
+         except Exception as e:
+             print(f"Embedding extraction error: {e}")
+             return np.zeros(self.encoder.embedding_dim)
+
+
+ class EncoderLoaderThread(QThread):
+     """Thread for loading the speaker encoder model"""
+     model_loaded = pyqtSignal(object)
+     progress_update = pyqtSignal(str)
+
+     def run(self):
+         try:
+             self.progress_update.emit("Initializing speaker encoder model...")
+
+             # Check device
+             device_str = "cuda" if torch.cuda.is_available() else "cpu"
+             self.progress_update.emit(f"Using device: {device_str}")
+
+             # Create SpeechBrain encoder
+             self.progress_update.emit("Loading ECAPA-TDNN model...")
+             encoder = SpeechBrainEncoder(device=device_str)
+
+             # Load the model
+             success = encoder.load_model()
+
+             if success:
+                 self.progress_update.emit("ECAPA-TDNN model loading complete!")
+                 self.model_loaded.emit(encoder)
+             else:
+                 self.progress_update.emit("Failed to load ECAPA-TDNN model. Using fallback...")
+                 self.model_loaded.emit(None)
+         except Exception as e:
+             self.progress_update.emit(f"Model loading error: {e}")
+             self.model_loaded.emit(None)
+
+
+ class SpeakerChangeDetector:
+     """Modified speaker change detector that supports a configurable number of speakers"""
+     def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
+         self.embedding_dim = embedding_dim
+         self.change_threshold = change_threshold
+         self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)  # Ensure we don't exceed absolute max
+         self.current_speaker = 0  # Initial speaker (0 to max_speakers-1)
+         self.previous_embeddings = []
+         self.last_change_time = time.time()
+         self.mean_embeddings = [None] * self.max_speakers  # Mean embeddings for each speaker
+         self.speaker_embeddings = [[] for _ in range(self.max_speakers)]  # All embeddings for each speaker
+         self.last_similarity = 0.0
+         self.active_speakers = set([0])  # Track which speakers have been detected
+
+     def set_max_speakers(self, max_speakers):
+         """Update the maximum number of speakers"""
+         new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
+
+         # If reducing the number of speakers
+         if new_max < self.max_speakers:
+             # Remove any speakers beyond the new max
+             for speaker_id in list(self.active_speakers):
+                 if speaker_id >= new_max:
+                     self.active_speakers.discard(speaker_id)
+
+             # Ensure current speaker is valid
+             if self.current_speaker >= new_max:
+                 self.current_speaker = 0
+
+         # Expand arrays if increasing max speakers
+         if new_max > self.max_speakers:
+             # Extend mean_embeddings array
+             self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
+
+             # Extend speaker_embeddings array
+             self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
+
+         # Truncate arrays if decreasing max speakers
+         else:
+             self.mean_embeddings = self.mean_embeddings[:new_max]
+             self.speaker_embeddings = self.speaker_embeddings[:new_max]
+
+         self.max_speakers = new_max
+
+     def set_change_threshold(self, threshold):
+         """Update the threshold for detecting speaker changes"""
+         self.change_threshold = max(0.1, min(threshold, 0.99))
+
+     def add_embedding(self, embedding, timestamp=None):
+         """Add a new embedding and check if there's a speaker change"""
+         current_time = timestamp or time.time()
+
+         # Initialize first speaker if no embeddings yet
+         if not self.previous_embeddings:
+             self.previous_embeddings.append(embedding)
+             self.speaker_embeddings[self.current_speaker].append(embedding)
+             if self.mean_embeddings[self.current_speaker] is None:
+                 self.mean_embeddings[self.current_speaker] = embedding.copy()
+             return self.current_speaker, 1.0
+
+         # Calculate similarity with current speaker's mean embedding
+         current_mean = self.mean_embeddings[self.current_speaker]
+         if current_mean is not None:
+             similarity = 1.0 - cosine(embedding, current_mean)
+         else:
+             # If no mean yet, compare with most recent embedding
+             similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
+
+         self.last_similarity = similarity
+
+         # Decide if this is a speaker change
+         time_since_last_change = current_time - self.last_change_time
+         is_speaker_change = False
+
+         # Only consider change if minimum time has passed since last change
+         if time_since_last_change >= MIN_SEGMENT_DURATION:
+             # Check similarity against threshold
+             if similarity < self.change_threshold:
+                 # Compare with all other speakers' means if available
+                 best_speaker = self.current_speaker
+                 best_similarity = similarity
+
+                 # Check each active speaker
+                 for speaker_id in range(self.max_speakers):
+                     if speaker_id == self.current_speaker:
+                         continue
+
+                     speaker_mean = self.mean_embeddings[speaker_id]
+
+                     if speaker_mean is not None:
+                         # Calculate similarity with this speaker
+                         speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
+
+                         # If more similar to this speaker, update best match
+                         if speaker_similarity > best_similarity:
+                             best_similarity = speaker_similarity
+                             best_speaker = speaker_id
+
+                 # If best match is different from current speaker, change speaker
+                 if best_speaker != self.current_speaker:
+                     is_speaker_change = True
+                     self.current_speaker = best_speaker
+                 # If no good match with existing speakers and we haven't used all speakers yet
+                 elif len(self.active_speakers) < self.max_speakers:
+                     # Find the next unused speaker ID
+                     for new_id in range(self.max_speakers):
+                         if new_id not in self.active_speakers:
+                             is_speaker_change = True
+                             self.current_speaker = new_id
+                             self.active_speakers.add(new_id)
+                             break
+
+         # Handle speaker change
+         if is_speaker_change:
+             self.last_change_time = current_time
+
+         # Update embeddings
+         self.previous_embeddings.append(embedding)
+         if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
+             self.previous_embeddings.pop(0)
+
+         # Update current speaker's embeddings and mean
+         self.speaker_embeddings[self.current_speaker].append(embedding)
+         self.active_speakers.add(self.current_speaker)
+
+         if len(self.speaker_embeddings[self.current_speaker]) > 30:  # Limit history size
+             self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
+
+         # Update mean embedding for current speaker
+         if self.speaker_embeddings[self.current_speaker]:
+             self.mean_embeddings[self.current_speaker] = np.mean(
+                 self.speaker_embeddings[self.current_speaker], axis=0
+             )
+
+         return self.current_speaker, similarity
+
+     def get_color_for_speaker(self, speaker_id):
+         """Return color for speaker ID (0 to max_speakers-1)"""
+         if 0 <= speaker_id < len(SPEAKER_COLORS):
+             return SPEAKER_COLORS[speaker_id]
+         return "#FFFFFF"  # Default to white if out of range
+
+     def get_status_info(self):
+         """Return status information about the speaker change detector"""
+         speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
+
+         return {
+             "current_speaker": self.current_speaker,
+             "speaker_counts": speaker_counts,
+             "active_speakers": len(self.active_speakers),
+             "max_speakers": self.max_speakers,
+             "last_similarity": self.last_similarity,
+             "threshold": self.change_threshold
+         }
+
+
+ class TextUpdateThread(QThread):
+     text_update_signal = pyqtSignal(str)
+
+     def __init__(self, text):
+         super().__init__()
+         self.text = text
+
+     def run(self):
+         self.text_update_signal.emit(self.text)
+
+
+ class SentenceWorker(QThread):
+     sentence_update_signal = pyqtSignal(list, list)
+     status_signal = pyqtSignal(str)
+
+     def __init__(self, queue, encoder, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
+         super().__init__()
+         self.queue = queue
+         self.encoder = encoder
+         self._is_running = True
+         self.full_sentences = []
+         self.sentence_speakers = []
+         self.change_threshold = change_threshold
+         self.max_speakers = max_speakers
+
+         # Initialize audio processor for embedding extraction
+         self.audio_processor = AudioProcessor(self.encoder)
+
+         # Initialize speaker change detector
+         self.speaker_detector = SpeakerChangeDetector(
+             embedding_dim=self.encoder.embedding_dim,
+             change_threshold=self.change_threshold,
+             max_speakers=self.max_speakers
+         )
+
+         # Setup monitoring timer
+         self.monitoring_timer = QTimer()
+         self.monitoring_timer.timeout.connect(self.report_status)
+         self.monitoring_timer.start(2000)  # Report every 2 seconds
+
+     def set_change_threshold(self, threshold):
+         """Update change detection threshold"""
+         self.change_threshold = threshold
+         self.speaker_detector.set_change_threshold(threshold)
+
+     def set_max_speakers(self, max_speakers):
+         """Update maximum number of speakers"""
+         self.max_speakers = max_speakers
+         self.speaker_detector.set_max_speakers(max_speakers)
+
+     def run(self):
+         """Main worker thread loop"""
+         while self._is_running:
+             try:
+                 text, bytes = self.queue.get(timeout=1)
+                 self.process_item(text, bytes)
+             except queue.Empty:
+                 continue
+
+     def report_status(self):
+         """Report status information"""
+         # Get status information from speaker detector
+         status = self.speaker_detector.get_status_info()
+
+         # Prepare status message with information for all speakers
+         status_text = f"Current speaker: {status['current_speaker'] + 1}\n"
+         status_text += f"Active speakers: {status['active_speakers']} of {status['max_speakers']}\n"
+
+         # Show segment counts for each speaker
+         for i in range(status['max_speakers']):
+             if i < len(SPEAKER_COLOR_NAMES):
+                 color_name = SPEAKER_COLOR_NAMES[i]
+             else:
+                 color_name = f"Speaker {i+1}"
+             status_text += f"Speaker {i+1} ({color_name}) segments: {status['speaker_counts'][i]}\n"
+
+         status_text += f"Last similarity score: {status['last_similarity']:.3f}\n"
+         status_text += f"Change threshold: {status['threshold']:.2f}\n"
+         status_text += f"Total sentences: {len(self.full_sentences)}"
+
+         # Send to UI
+         self.status_signal.emit(status_text)
+
+     def process_item(self, text, bytes):
+         """Process a new text-audio pair"""
+         # Convert audio data to int16
+         audio_int16 = np.int16(bytes * 32767)
+
+         # Extract speaker embedding
+         speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
+
+         # Store sentence and embedding
+         self.full_sentences.append((text, speaker_embedding))
+
+         # Fill in any missing speaker assignments
+         if len(self.sentence_speakers) < len(self.full_sentences) - 1:
+             while len(self.sentence_speakers) < len(self.full_sentences) - 1:
+                 self.sentence_speakers.append(0)  # Default to first speaker
+
+         # Detect speaker changes
+         speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
+         self.sentence_speakers.append(speaker_id)
+
+         # Send updated data to UI
+         self.sentence_update_signal.emit(self.full_sentences, self.sentence_speakers)
+
+     def stop(self):
+         """Stop the worker thread"""
+         self._is_running = False
+         if self.monitoring_timer.isActive():
+             self.monitoring_timer.stop()
+
+
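SentenceWorker communicates with the UI exclusively through the pyqtSignal declarations at the top of the class, which is the thread-safe way to push data from a QThread into Qt widgets. The mechanism in miniature (a hypothetical Worker/on_status pair, not part of the commit):

    from PyQt6.QtCore import QThread, pyqtSignal

    class Worker(QThread):
        status = pyqtSignal(str)          # declared on the class, like status_signal above

        def run(self):
            self.status.emit("done")      # safe to call from the worker thread

    # in the GUI thread, worker.status.connect(on_status) gives a queued
    # connection across threads, so on_status runs on the GUI event loop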
463
+ class RecordingThread(QThread):
464
+ def __init__(self, recorder):
465
+ super().__init__()
466
+ self.recorder = recorder
467
+ self._is_running = True
468
+
469
+ # Determine input source
470
+ if USE_MICROPHONE:
471
+ self.device_id = str(sc.default_microphone().name)
472
+ self.include_loopback = False
473
+ else:
474
+ self.device_id = str(sc.default_speaker().name)
475
+ self.include_loopback = True
476
+
477
+ def updateDevice(self, device_id, include_loopback):
478
+ self.device_id = device_id
479
+ self.include_loopback = include_loopback
480
+
481
+ def run(self):
482
+ while self._is_running:
483
+ try:
484
+ with sc.get_microphone(id=self.device_id, include_loopback=self.include_loopback).recorder(
485
+ samplerate=SAMPLE_RATE, blocksize=BUFFER_SIZE
486
+ ) as mic:
487
+ # Process audio chunks while device hasn't changed
488
+ current_device = self.device_id
489
+ current_loopback = self.include_loopback
490
+
491
+ while self._is_running and current_device == self.device_id and current_loopback == self.include_loopback:
492
+ # Record audio chunk
493
+ audio_data = mic.record(numframes=BUFFER_SIZE)
494
+
495
+ # Convert stereo to mono if needed
496
+ if audio_data.shape[1] > 1 and CHANNELS == 1:
497
+ audio_data = audio_data[:, 0]
498
+
499
+ # Convert to int16
500
+ audio_int16 = (audio_data.flatten() * 32767).astype(np.int16)
501
+
502
+ # Feed to recorder
503
+ audio_bytes = audio_int16.tobytes()
504
+ self.recorder.feed_audio(audio_bytes)
505
+
506
+ except Exception as e:
507
+ print(f"Recording error: {e}")
508
+ # Wait before retry on error
509
+ time.sleep(1)
510
+
511
+ def stop(self):
512
+ self._is_running = False
513
+
514
+
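+ # --- Illustrative sketch: safe float -> int16 conversion ---
+ # The conversions in this file assume samples stay within [-1, 1]; clipping
+ # first makes out-of-range samples saturate instead of wrapping around.
+ def _sketch_float_to_int16(samples):
+     import numpy as np
+     return (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)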
515
+ class TextRetrievalThread(QThread):
516
+ textRetrievedFinal = pyqtSignal(str, np.ndarray)
517
+ textRetrievedLive = pyqtSignal(str)
518
+ recorderStarted = pyqtSignal()
519
+
520
+ def __init__(self):
521
+ super().__init__()
522
+
523
+ def live_text_detected(self, text):
524
+ self.textRetrievedLive.emit(text)
525
+
526
+ def run(self):
527
+ recorder_config = {
528
+ 'spinner': False,
529
+ 'use_microphone': False,
530
+ 'model': FINAL_TRANSCRIPTION_MODEL,
531
+ 'language': TRANSCRIPTION_LANGUAGE,
532
+ 'silero_sensitivity': SILERO_SENSITIVITY,
533
+ 'webrtc_sensitivity': WEBRTC_SENSITIVITY,
534
+ 'post_speech_silence_duration': SILENCE_THRESHS[1],
535
+ 'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
536
+ 'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
537
+ 'min_gap_between_recordings': 0,
538
+ 'enable_realtime_transcription': True,
539
+ 'realtime_processing_pause': 0,
540
+ 'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
541
+ 'on_realtime_transcription_update': self.live_text_detected,
542
+ 'beam_size': FINAL_BEAM_SIZE,
543
+ 'beam_size_realtime': REALTIME_BEAM_SIZE,
544
+ 'buffer_size': BUFFER_SIZE,
545
+ 'sample_rate': SAMPLE_RATE,
546
+ }
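+ # Note: post_speech_silence_duration starts at SILENCE_THRESHS[1];
+ # process_live_text() shortens it to SILENCE_THRESHS[0] (or, with
+ # FAST_SENTENCE_END, stops the recorder outright) once a sentence end
+ # looks likely, so segments are finalized quickly.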
547
+
548
+ self.recorder = AudioToTextRecorder(**recorder_config)
549
+ self.recorderStarted.emit()
550
+
551
+ def process_text(text):
552
+ audio_data = self.recorder.last_transcription_bytes
553
+ self.textRetrievedFinal.emit(text, audio_data)
554
+
555
+ while True:
556
+ self.recorder.text(process_text)
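+ # Assumption: recorder.text(callback) blocks until a final transcription
+ # is ready and then invokes the callback, so this loop handles one
+ # finished sentence per iteration.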
557
+
558
+
559
+ class MainWindow(QMainWindow):
560
+ def __init__(self):
561
+ super().__init__()
562
+
563
+ self.setWindowTitle("Real-time Speaker Change Detection")
564
+
565
+ self.encoder = None
566
+ self.initialized = False
567
+ self.displayed_text = ""
568
+ self.last_realtime_text = ""
569
+ self.full_sentences = []
570
+ self.sentence_speakers = []
571
+ self.pending_sentences = []
572
+ self.queue = queue.Queue()
573
+ self.recording_thread = None
574
+ self.change_threshold = DEFAULT_CHANGE_THRESHOLD
575
+ self.max_speakers = DEFAULT_MAX_SPEAKERS
576
+
577
+ # Create main horizontal layout
578
+ self.mainLayout = QHBoxLayout()
579
+
580
+ # Add text edit area to main layout
581
+ self.text_edit = QTextEdit(self)
582
+ self.mainLayout.addWidget(self.text_edit, 1)
583
+
584
+ # Create right layout for controls
585
+ self.rightLayout = QVBoxLayout()
586
+ self.rightLayout.setAlignment(Qt.AlignmentFlag.AlignTop)
587
+
588
+ # Create all controls
589
+ self.create_controls()
590
+
591
+ # Create container for right layout
592
+ self.rightContainer = QWidget()
593
+ self.rightContainer.setLayout(self.rightLayout)
594
+ self.mainLayout.addWidget(self.rightContainer, 0)
595
+
596
+ # Set main layout as central widget
597
+ self.centralWidget = QWidget()
598
+ self.centralWidget.setLayout(self.mainLayout)
599
+ self.setCentralWidget(self.centralWidget)
600
+
601
+ self.setStyleSheet("""
602
+ QGroupBox {
603
+ border: 1px solid #555;
604
+ border-radius: 3px;
605
+ margin-top: 10px;
606
+ padding-top: 10px;
607
+ color: #ddd;
608
+ }
609
+ QGroupBox::title {
610
+ subcontrol-origin: margin;
611
+ subcontrol-position: top center;
612
+ padding: 0 5px;
613
+ }
614
+ QLabel {
615
+ color: #ddd;
616
+ }
617
+ QPushButton {
618
+ background: #444;
619
+ color: #ddd;
620
+ border: 1px solid #555;
621
+ padding: 5px;
622
+ margin-bottom: 10px;
623
+ }
624
+ QPushButton:hover {
625
+ background: #555;
626
+ }
627
+ QTextEdit {
628
+ background-color: #1e1e1e;
629
+ color: #ffffff;
630
+ font-family: 'Arial';
631
+ font-size: 16pt;
632
+ }
633
+ QSlider {
634
+ height: 30px;
635
+ }
636
+ QSlider::groove:horizontal {
637
+ height: 8px;
638
+ background: #333;
639
+ margin: 2px 0;
640
+ }
641
+ QSlider::handle:horizontal {
642
+ background: #666;
643
+ border: 1px solid #777;
644
+ width: 18px;
645
+ margin: -8px 0;
646
+ border-radius: 9px;
647
+ }
648
+ """)
649
+
650
+ def create_controls(self):
651
+ # Speaker change threshold control
652
+ self.threshold_group = QGroupBox("Speaker Change Sensitivity")
653
+ threshold_layout = QVBoxLayout()
654
+
655
+ self.threshold_label = QLabel(f"Change threshold: {self.change_threshold:.2f}")
656
+ threshold_layout.addWidget(self.threshold_label)
657
+
658
+ self.threshold_slider = QSlider(Qt.Orientation.Horizontal)
659
+ self.threshold_slider.setMinimum(10)
660
+ self.threshold_slider.setMaximum(95)
661
+ self.threshold_slider.setValue(int(self.change_threshold * 100))
662
+ self.threshold_slider.valueChanged.connect(self.update_threshold)
663
+ threshold_layout.addWidget(self.threshold_slider)
664
+
665
+ self.threshold_explanation = QLabel(
666
+ "If the speakers have similar voices, it would be better to set it above 0.5, and if they have different voices, it would be lower."
667
+ )
668
+ self.threshold_explanation.setWordWrap(True)
669
+ threshold_layout.addWidget(self.threshold_explanation)
670
+
671
+ self.threshold_group.setLayout(threshold_layout)
672
+ self.rightLayout.addWidget(self.threshold_group)
673
+
674
+ # Max speakers control
675
+ self.max_speakers_group = QGroupBox("Maximum Number of Speakers")
676
+ max_speakers_layout = QVBoxLayout()
677
+
678
+ self.max_speakers_label = QLabel(f"Max speakers: {self.max_speakers}")
679
+ max_speakers_layout.addWidget(self.max_speakers_label)
680
+
681
+ self.max_speakers_spinbox = QSpinBox()
682
+ self.max_speakers_spinbox.setMinimum(2)
683
+ self.max_speakers_spinbox.setMaximum(ABSOLUTE_MAX_SPEAKERS)
684
+ self.max_speakers_spinbox.setValue(self.max_speakers)
685
+ self.max_speakers_spinbox.valueChanged.connect(self.update_max_speakers)
686
+ max_speakers_layout.addWidget(self.max_speakers_spinbox)
687
+
688
+ self.max_speakers_explanation = QLabel(
689
+ f"You can set between 2 and {ABSOLUTE_MAX_SPEAKERS} speakers.\n"
690
+ "Changes will apply immediately."
691
+ )
692
+ self.max_speakers_explanation.setWordWrap(True)
693
+ max_speakers_layout.addWidget(self.max_speakers_explanation)
694
+
695
+ self.max_speakers_group.setLayout(max_speakers_layout)
696
+ self.rightLayout.addWidget(self.max_speakers_group)
697
+
698
+ # Speaker color legend - dynamic based on max speakers
699
+ self.legend_group = QGroupBox("Speaker Colors")
700
+ self.legend_layout = QVBoxLayout()
701
+
702
+ # Create speaker labels dynamically
703
+ self.speaker_labels = []
704
+ for i in range(ABSOLUTE_MAX_SPEAKERS):
705
+ color = SPEAKER_COLORS[i]
706
+ color_name = SPEAKER_COLOR_NAMES[i]
707
+ label = QLabel(f"Speaker {i+1} ({color_name}): <span style='color:{color};'>■■■■■</span>")
708
+ self.speaker_labels.append(label)
709
+ if i < self.max_speakers:
710
+ self.legend_layout.addWidget(label)
711
+
712
+ self.legend_group.setLayout(self.legend_layout)
713
+ self.rightLayout.addWidget(self.legend_group)
714
+
715
+ # Status display area
716
+ self.status_group = QGroupBox("Status")
717
+ status_layout = QVBoxLayout()
718
+
719
+ self.status_label = QLabel("Status information will be displayed here.")
720
+ self.status_label.setWordWrap(True)
721
+ status_layout.addWidget(self.status_label)
722
+
723
+ self.status_group.setLayout(status_layout)
724
+ self.rightLayout.addWidget(self.status_group)
725
+
726
+ # Clear button
727
+ self.clear_button = QPushButton("Clear Conversation")
728
+ self.clear_button.clicked.connect(self.clear_state)
729
+ self.clear_button.setEnabled(False)
730
+ self.rightLayout.addWidget(self.clear_button)
731
+
732
+ def update_threshold(self, value):
733
+ """Update speaker change detection threshold"""
734
+ threshold = value / 100.0
735
+ self.change_threshold = threshold
736
+ self.threshold_label.setText(f"Change threshold: {threshold:.2f}")
737
+
738
+ # Update in worker if it exists
739
+ if hasattr(self, 'worker_thread'):
740
+ self.worker_thread.set_change_threshold(threshold)
741
+
742
+ def update_max_speakers(self, value):
743
+ """Update maximum number of speakers"""
744
+ self.max_speakers = value
745
+ self.max_speakers_label.setText(f"Max speakers: {value}")
746
+
747
+ # Update visible speaker labels
748
+ self.update_speaker_labels()
749
+
750
+ # Update in worker if it exists
751
+ if hasattr(self, 'worker_thread'):
752
+ self.worker_thread.set_max_speakers(value)
753
+
754
+ def update_speaker_labels(self):
755
+ """Update which speaker labels are visible based on max_speakers"""
756
+ # Clear all labels first
757
+ for i in range(len(self.speaker_labels)):
758
+ label = self.speaker_labels[i]
759
+ if label.parent():
760
+ self.legend_layout.removeWidget(label)
761
+ label.setParent(None)
762
+
763
+ # Add only the labels for the current max_speakers
764
+ for i in range(min(self.max_speakers, len(self.speaker_labels))):
765
+ self.legend_layout.addWidget(self.speaker_labels[i])
766
+
767
+ def clear_state(self):
768
+ # Clear text edit area
769
+ self.text_edit.clear()
770
+
771
+ # Reset state variables
772
+ self.displayed_text = ""
773
+ self.last_realtime_text = ""
774
+ self.full_sentences = []
775
+ self.sentence_speakers = []
776
+ self.pending_sentences = []
777
+
778
+ if hasattr(self, 'worker_thread'):
779
+ self.worker_thread.full_sentences = []
780
+ self.worker_thread.sentence_speakers = []
781
+ # Reset speaker detector with current threshold and max_speakers
782
+ self.worker_thread.speaker_detector = SpeakerChangeDetector(
783
+ embedding_dim=self.encoder.embedding_dim,
784
+ change_threshold=self.change_threshold,
785
+ max_speakers=self.max_speakers
786
+ )
787
+
788
+ # Display message
789
+ self.text_edit.setHtml("<i>All content cleared. Waiting for new input...</i>")
790
+
791
+ def update_status(self, status_text):
792
+ self.status_label.setText(status_text)
793
+
794
+ def showEvent(self, event):
795
+ super().showEvent(event)
796
+ if event.type() == QEvent.Type.Show:
797
+ if not self.initialized:
798
+ self.initialized = True
799
+ self.resize(1200, 800)
800
+ self.update_text("<i>Initializing application...</i>")
801
+
802
+ QTimer.singleShot(500, self.init)
803
+
804
+ def process_live_text(self, text):
805
+ text = text.strip()
806
+
807
+ if text:
808
+ sentence_delimiters = '.?!。'
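+ # Heuristic: only treat punctuation as a sentence end when both the
+ # current and the previous realtime update end with a delimiter, which
+ # filters out transient punctuation in unstable partial transcriptions.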
809
+ prob_sentence_end = (
810
+ len(self.last_realtime_text) > 0
811
+ and text[-1] in sentence_delimiters
812
+ and self.last_realtime_text[-1] in sentence_delimiters
813
+ )
814
+
815
+ self.last_realtime_text = text
816
+
817
+ if prob_sentence_end:
818
+ if FAST_SENTENCE_END:
819
+ self.text_retrieval_thread.recorder.stop()
820
+ else:
821
+ self.text_retrieval_thread.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
822
+ else:
823
+ self.text_retrieval_thread.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
824
+
825
+ self.text_detected(text)
826
+
827
+ def text_detected(self, text):
828
+ try:
829
+ sentences_with_style = []
830
+ for i, sentence in enumerate(self.full_sentences):
831
+ sentence_text, _ = sentence
832
+ if i >= len(self.sentence_speakers):
833
+ color = "#FFFFFF" # Default white
834
+ else:
835
+ speaker_id = self.sentence_speakers[i]
836
+ color = self.worker_thread.speaker_detector.get_color_for_speaker(speaker_id)
837
+
838
+ sentences_with_style.append(
839
+ f'<span style="color:{color};">{sentence_text}</span>')
840
+
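+ # Sentences still awaiting diarization are rendered in cyan until a
+ # speaker color is assigned.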
841
+ for pending_sentence in self.pending_sentences:
842
+ sentences_with_style.append(
843
+ f'<span style="color:#60FFFF;">{pending_sentence}</span>')
844
+
845
+ new_text = " ".join(sentences_with_style).strip() + " " + text if len(sentences_with_style) > 0 else text
846
+
847
+ if new_text != self.displayed_text:
848
+ self.displayed_text = new_text
849
+ self.update_text(new_text)
850
+ except Exception as e:
851
+ print(f"Error: {e}")
852
+
853
+ def process_final(self, text, audio_data):
854
+ text = text.strip()
855
+ if text:
856
+ try:
857
+ self.pending_sentences.append(text)
858
+ self.queue.put((text, audio_data))
859
+ except Exception as e:
860
+ print(f"Error: {e}")
861
+
862
+ def capture_output_and_feed_to_recorder(self):
863
+ # Use default device settings
864
+ device_id = str(sc.default_speaker().name)
865
+ include_loopback = True
866
+
867
+ self.recording_thread = RecordingThread(self.text_retrieval_thread.recorder)
868
+ # Update with current device settings
869
+ self.recording_thread.updateDevice(device_id, include_loopback)
870
+ self.recording_thread.start()
871
+
872
+ def recorder_ready(self):
873
+ self.update_text("<i>Recording ready</i>")
874
+ self.capture_output_and_feed_to_recorder()
875
+
876
+ def init(self):
877
+ self.update_text("<i>Loading ECAPA-TDNN model... Please wait.</i>")
878
+
879
+ # Start model loading in background thread
880
+ self.start_encoder()
881
+
882
+ def update_loading_status(self, message):
883
+ self.update_text(f"<i>{message}</i>")
884
+
885
+ def start_encoder(self):
886
+ # Create and start encoder loader thread
887
+ self.encoder_loader_thread = EncoderLoaderThread()
888
+ self.encoder_loader_thread.model_loaded.connect(self.on_model_loaded)
889
+ self.encoder_loader_thread.progress_update.connect(self.update_loading_status)
890
+ self.encoder_loader_thread.start()
891
+
892
+ def on_model_loaded(self, encoder):
893
+ # Store loaded encoder model
894
+ self.encoder = encoder
895
+
896
+ if self.encoder is None:
897
+ self.update_text("<i>Failed to load ECAPA-TDNN model. Please check your configuration.</i>")
898
+ return
899
+
900
+ # Enable all controls after model is loaded
901
+ self.clear_button.setEnabled(True)
902
+ self.threshold_slider.setEnabled(True)
903
+
904
+ # Continue initialization
905
+ self.update_text("<i>ECAPA-TDNN model loaded. Starting recorder...</i>")
906
+
907
+ self.text_retrieval_thread = TextRetrievalThread()
908
+ self.text_retrieval_thread.recorderStarted.connect(
909
+ self.recorder_ready)
910
+ self.text_retrieval_thread.textRetrievedLive.connect(
911
+ self.process_live_text)
912
+ self.text_retrieval_thread.textRetrievedFinal.connect(
913
+ self.process_final)
914
+ self.text_retrieval_thread.start()
915
+
916
+ self.worker_thread = SentenceWorker(
917
+ self.queue,
918
+ self.encoder,
919
+ change_threshold=self.change_threshold,
920
+ max_speakers=self.max_speakers
921
+ )
922
+ self.worker_thread.sentence_update_signal.connect(
923
+ self.sentence_updated)
924
+ self.worker_thread.status_signal.connect(
925
+ self.update_status)
926
+ self.worker_thread.start()
927
+
928
+ def sentence_updated(self, full_sentences, sentence_speakers):
929
+ self.pending_text = ""
930
+ self.full_sentences = full_sentences
931
+ self.sentence_speakers = sentence_speakers
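+ # Drop sentences from the pending list once they have been diarized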
932
+ for sentence in self.full_sentences:
933
+ sentence_text, _ = sentence
934
+ if sentence_text in self.pending_sentences:
935
+ self.pending_sentences.remove(sentence_text)
936
+ self.text_detected("")
937
+
938
+ def set_text(self, text):
939
+ self.update_thread = TextUpdateThread(text)
940
+ self.update_thread.text_update_signal.connect(self.update_text)
941
+ self.update_thread.start()
942
+
943
+ def update_text(self, text):
944
+ self.text_edit.setHtml(text)
945
+ self.text_edit.verticalScrollBar().setValue(
946
+ self.text_edit.verticalScrollBar().maximum())
947
+
948
+
949
+ def main():
950
+ app = QApplication(sys.argv)
951
+
952
+ dark_stylesheet = """
953
+ QMainWindow {
954
+ background-color: #323232;
955
+ }
956
+ QTextEdit {
957
+ background-color: #1e1e1e;
958
+ color: #ffffff;
959
+ }
960
+ """
961
+ app.setStyleSheet(dark_stylesheet)
962
+
963
+ main_window = MainWindow()
964
+ main_window.show()
965
+
966
+ sys.exit(app.exec())
967
+
968
+
969
+ if __name__ == "__main__":
970
+ main()
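+ # To run this desktop app locally (assumes PyQt6, soundcard, RealtimeSTT
+ # and the ECAPA-TDNN weights are available): python realtime_diarize.py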
requirements.txt ADDED
@@ -0,0 +1,188 @@
1
+ fastapi
2
+ uvicorn[standard]
3
+ absl-py
4
+ aiohttp
5
+ aiosignal
6
+ annotated-types
7
+ anyascii
8
+ anyio
9
+ asttokens
10
+ attrs
11
+ audioread
12
+ av
13
+ azure-cognitiveservices-speech
14
+ Babel
15
+ bangla
16
+ blinker
17
+ blis
18
+ bnnumerizer
19
+ bnunicodenormalizer
20
+ catalogue
21
+ certifi
22
+ cffi
23
+ charset-normalizer
24
+ click
25
+ cloudpathlib
26
+ colorama
27
+ coloredlogs
28
+ comtypes
29
+ confection
30
+ contourpy
31
+ coqpit
32
+ ctranslate2
33
+ cycler
34
+ cymem
35
+ Cython
36
+ dateparser
37
+ decorator
38
+ distro
39
+ docopt
40
+ einops
41
+ elevenlabs
42
+ emoji
43
+ encodec
44
+ # enum34 intentionally omitted: it shadows the stdlib enum module on Python 3
45
+ executing
46
+ faster-whisper
47
+ fastrtc==0.0.25
48
+ ffmpeg-python
49
+ filelock
50
+ Flask
51
+ flatbuffers
52
+ fonttools
53
+ frozenlist
54
+ fsspec
55
+ future
56
+ g2pkk
57
+ grpcio
58
+ gruut
59
+ gruut-ipa
60
+ gruut_lang_de
61
+ gruut_lang_en
62
+ gruut_lang_es
63
+ gruut_lang_fr
64
+ h11
65
+ halo
66
+ hangul-romanize
67
+ httpcore
68
+ httpx
69
+ huggingface-hub
70
+ humanfriendly
71
+ idna
72
+ inflect
73
+ ipython
74
+ itsdangerous
75
+ jamo
76
+ jedi
77
+ jieba
78
+ Jinja2
79
+ joblib
80
+ jsonlines
81
+ kiwisolver
82
+ langcodes
83
+ lazy_loader
84
+ librosa
85
+ llvmlite
86
+ log-symbols
87
+ Markdown
88
+ MarkupSafe
89
+ matplotlib
90
+ matplotlib-inline
91
+ more-itertools
92
+ mpmath
93
+ msgpack
94
+ multidict
95
+ murmurhash
96
+ networkx
97
+ nltk
98
+ num2words
99
+ numba
100
+ numpy
101
+ onnxruntime
102
+ openai
103
+ openai-whisper
104
+ packaging
105
+ pandas
106
+ parso
107
+ pillow
108
+ platformdirs
109
+ pooch
110
+ preshed
111
+ prompt-toolkit
112
+ protobuf
113
+ psutil
114
+ pure-eval
115
+ pvporcupine
116
+ pyannote-audio
117
+ pycparser
118
+ pydantic
119
+ pydantic_core
120
+ pydub
121
+ Pygments
122
+ pynndescent
123
+ pyparsing
124
+ pypinyin
125
+ pyreadline3
126
+ pysbd
127
+ python-crfsuite
128
+ python-dateutil
129
+ pyttsx3
130
+ pytz
131
+ PyYAML
132
+ RealtimeSTT
133
+ regex
134
+ requests
135
+ safetensors
136
+ scikit-learn
137
+ scipy
138
+ six
139
+ smart-open
140
+ sniffio
141
+ SoundCard
142
+ soundfile
143
+ soxr
144
+ spacy
145
+ spacy-legacy
146
+ spacy-loggers
147
+ spinners
148
+ srsly
149
+ stable-ts
150
+ stack-data
151
+ stanza
152
+ stream2sentence
153
+ SudachiDict-core
154
+ SudachiPy
155
+ sympy
156
+ tensorboard
157
+ tensorboard-data-server
158
+ termcolor
159
+ thinc
160
+ threadpoolctl
161
+ tiktoken
162
+ tokenizers
163
+ gradio-webrtc==0.0.8
164
+
165
+ --extra-index-url https://download.pytorch.org/whl/cpu
166
+ gradio
167
+ # … any other non-PyTorch dependencies …
168
+ torch
169
+ torchaudio
170
+ tqdm
171
+ trainer
172
+ traitlets
173
+ transformers
174
+ typer
175
+ typing_extensions
176
+ tzdata
177
+ tzlocal
178
+ umap-learn
179
+ Unidecode
180
+ urllib3
181
+ wasabi
182
+ wcwidth
183
+ weasel
184
+ webrtcvad
185
+ websockets
186
+ Werkzeug
187
+ yarl
188
+ yt-dlp