Saiyaswanth007 committed
Commit af81629 · 1 Parent(s): 10c2754

requirements

Files changed (2)
  1. app.py +415 -350
  2. requirements.txt +2 -2
app.py CHANGED
@@ -2,452 +2,517 @@ import gradio as gr
  import numpy as np
  import torch
  import torchaudio
  from scipy.spatial.distance import cosine
  import tempfile
- import os
- import warnings
- warnings.filterwarnings("ignore", category=UserWarning)

- try:
-     from transformers import pipeline
- except ImportError:
-     print("transformers not found. Install with: pip install transformers")
-
- # Configuration
- class Config:
-     # Audio settings
-     SAMPLE_RATE = 16000
-
-     # Speaker detection
-     CHANGE_THRESHOLD = 0.65
-     MAX_SPEAKERS = 4
-     MIN_SEGMENT_DURATION = 1.0
-     EMBEDDING_HISTORY_SIZE = 3
-     SPEAKER_MEMORY_SIZE = 20

- # Console colors for speakers (HTML version)
  SPEAKER_COLORS = [
      "#FFD700",  # Gold
      "#FF6B6B",  # Red
      "#4ECDC4",  # Teal
      "#45B7D1",  # Blue
-     "#96CEB4",  # Mint
-     "#FFEAA7",  # Light Yellow
-     "#DDA0DD",  # Plum
-     "#98D8C8",  # Mint Green
  ]

- class SpeakerEncoder:
-     """Simplified speaker encoder using torchaudio transforms"""
-
      def __init__(self, device="cpu"):
          self.device = device
          self.embedding_dim = 128
-         self.model_loaded = False
-         self._setup_model()
-
-     def _setup_model(self):
-         """Setup a simple MFCC-based feature extractor"""
-         try:
-             self.mfcc_transform = torchaudio.transforms.MFCC(
-                 sample_rate=Config.SAMPLE_RATE,
-                 n_mfcc=13,
-                 melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23}
-             ).to(self.device)
-             self.model_loaded = True
-             print("Simple MFCC-based encoder initialized")
-         except Exception as e:
-             print(f"Error setting up encoder: {e}")
-             self.model_loaded = False
-
-     def extract_embedding(self, audio):
-         """Extract speaker embedding from audio"""
-         if not self.model_loaded:
-             return np.zeros(self.embedding_dim)
          try:
-             # Ensure audio is float32 and normalized
              if isinstance(audio, np.ndarray):
-                 audio = torch.from_numpy(audio).float()

-             # Normalize audio
-             if audio.abs().max() > 0:
-                 audio = audio / audio.abs().max()

-             # Add batch dimension if needed
-             if audio.dim() == 1:
-                 audio = audio.unsqueeze(0)

-             # Extract MFCC features
-             with torch.no_grad():
-                 mfcc = self.mfcc_transform(audio)
-                 # Simple statistics-based embedding
-                 embedding = torch.cat([
-                     mfcc.mean(dim=2).flatten(),
-                     mfcc.std(dim=2).flatten(),
-                     mfcc.max(dim=2)[0].flatten(),
-                     mfcc.min(dim=2)[0].flatten()
-                 ])

-             # Pad or truncate to fixed size
-             if embedding.size(0) > self.embedding_dim:
-                 embedding = embedding[:self.embedding_dim]
-             elif embedding.size(0) < self.embedding_dim:
-                 padding = torch.zeros(self.embedding_dim - embedding.size(0))
-                 embedding = torch.cat([embedding, padding])
-
-             return embedding.cpu().numpy()

          except Exception as e:
              print(f"Error extracting embedding: {e}")
-             return np.zeros(self.embedding_dim)

- class SpeakerDetector:
-     """Speaker change detection using embeddings"""
-
-     def __init__(self, threshold=Config.CHANGE_THRESHOLD, max_speakers=Config.MAX_SPEAKERS):
-         self.threshold = threshold
-         self.max_speakers = max_speakers
-         self.current_speaker = 0
-         self.speaker_embeddings = [[] for _ in range(max_speakers)]
-         self.speaker_centroids = [None] * max_speakers
-         self.active_speakers = {0}
-
-     def reset(self):
-         """Reset speaker detection state"""
          self.current_speaker = 0
          self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
-         self.speaker_centroids = [None] * self.max_speakers
-         self.active_speakers = {0}
-
-     def detect_speaker(self, embedding):
-         """Detect current speaker from embedding"""
-         # Initialize first speaker
-         if not self.speaker_embeddings[0]:
-             self.speaker_embeddings[0].append(embedding)
-             self.speaker_centroids[0] = embedding.copy()
-             return 0, 1.0
-
-         # Calculate similarity with current speaker
-         current_centroid = self.speaker_centroids[self.current_speaker]
-         if current_centroid is not None:
-             similarity = 1.0 - cosine(embedding, current_centroid)
          else:
-             similarity = 0.0

-         # Check for speaker change
-         if similarity < self.threshold:
-             # Find best matching existing speaker
-             best_speaker = self.current_speaker
-             best_similarity = similarity
-
-             for speaker_id in self.active_speakers:
-                 if speaker_id == self.current_speaker:
-                     continue

-                 centroid = self.speaker_centroids[speaker_id]
-                 if centroid is not None:
-                     sim = 1.0 - cosine(embedding, centroid)
-                     if sim > best_similarity and sim > self.threshold:
-                         best_similarity = sim
-                         best_speaker = speaker_id
-
-             # Create new speaker if no good match and slots available
-             if (best_speaker == self.current_speaker and
-                 len(self.active_speakers) < self.max_speakers):
-                 for new_id in range(self.max_speakers):
-                     if new_id not in self.active_speakers:
-                         best_speaker = new_id
-                         best_similarity = 0.0
-                         self.active_speakers.add(new_id)
-                         break

-             # Update current speaker if changed
-             if best_speaker != self.current_speaker:
-                 self.current_speaker = best_speaker
-                 similarity = best_similarity

-         # Update speaker model
-         self._update_speaker_model(self.current_speaker, embedding)
          return self.current_speaker, similarity

-     def _update_speaker_model(self, speaker_id, embedding):
-         """Update speaker model with new embedding"""
-         self.speaker_embeddings[speaker_id].append(embedding)
-
-         # Keep only recent embeddings
-         if len(self.speaker_embeddings[speaker_id]) > Config.SPEAKER_MEMORY_SIZE:
-             self.speaker_embeddings[speaker_id] = \
-                 self.speaker_embeddings[speaker_id][-Config.SPEAKER_MEMORY_SIZE:]
-
-         # Update centroid
-         if self.speaker_embeddings[speaker_id]:
-             self.speaker_centroids[speaker_id] = np.mean(
-                 self.speaker_embeddings[speaker_id], axis=0
-             )

- class AudioProcessor:
-     """Handles audio processing and transcription"""
-
      def __init__(self):
-         self.encoder = SpeakerEncoder()
-         self.detector = SpeakerDetector()

-         # Initialize Whisper model for transcription
          try:
-             self.transcriber = pipeline(
-                 "automatic-speech-recognition",
-                 model="openai/whisper-base",
-                 chunk_length_s=30,
-                 device=0 if torch.cuda.is_available() else -1
-             )
-             print("Whisper model loaded successfully")
-         except Exception as e:
-             print(f"Error loading Whisper model: {e}")
-             self.transcriber = None

-     def process_audio_file(self, audio_file):
-         """Process uploaded audio file"""
-         if audio_file is None:
-             return "Please upload an audio file.", ""
-
          try:
-             # Reset speaker detection for new file
-             self.detector.reset()

-             # Load audio file
-             waveform, sample_rate = torchaudio.load(audio_file)

-             # Convert to mono if stereo
-             if waveform.shape[0] > 1:
-                 waveform = waveform.mean(dim=0, keepdim=True)

-             # Resample to 16kHz if needed
-             if sample_rate != Config.SAMPLE_RATE:
-                 resampler = torchaudio.transforms.Resample(sample_rate, Config.SAMPLE_RATE)
-                 waveform = resampler(waveform)
-
-             # Convert to numpy
-             audio_data = waveform.squeeze().numpy()
-
-             # Transcribe entire audio
-             if self.transcriber:
-                 transcription_result = self.transcriber(audio_file)
-                 full_transcription = transcription_result['text']
-             else:
-                 full_transcription = "Transcription service unavailable"
-
-             # Process audio in chunks for speaker detection
-             chunk_duration = 3.0  # 3 second chunks
-             chunk_samples = int(chunk_duration * Config.SAMPLE_RATE)
-             results = []

-             for i in range(0, len(audio_data), chunk_samples // 2):  # 50% overlap
-                 chunk = audio_data[i:i + chunk_samples]
-
-                 if len(chunk) < Config.SAMPLE_RATE:  # Skip chunks less than 1 second
-                     continue
-
-                 # Extract speaker embedding
-                 embedding = self.encoder.extract_embedding(chunk)
-                 speaker_id, similarity = self.detector.detect_speaker(embedding)
-
-                 # Get timestamp
-                 start_time = i / Config.SAMPLE_RATE
-                 end_time = (i + len(chunk)) / Config.SAMPLE_RATE
-
-                 # Transcribe chunk
-                 if self.transcriber and len(chunk) > Config.SAMPLE_RATE:
-                     # Save chunk temporarily for transcription
-                     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
-                         torchaudio.save(tmp_file.name, torch.tensor(chunk).unsqueeze(0), Config.SAMPLE_RATE)
-                         chunk_result = self.transcriber(tmp_file.name)
-                         chunk_text = chunk_result['text'].strip()
-                         os.unlink(tmp_file.name)  # Clean up temp file
-                 else:
-                     chunk_text = ""
-
-                 if chunk_text:  # Only add if there's actual text
-                     results.append({
-                         'speaker_id': speaker_id,
-                         'start_time': start_time,
-                         'end_time': end_time,
-                         'text': chunk_text,
-                         'similarity': similarity
-                     })

-             # Format results
-             formatted_output = self._format_results(results)
-             return formatted_output, full_transcription

          except Exception as e:
-             return f"Error processing audio: {str(e)}", ""

-     def _format_results(self, results):
-         """Format results with speaker colors"""
-         if not results:
-             return "No speech detected in the audio file."
-
-         formatted_lines = []
-         formatted_lines.append("🎤 **Speaker Diarization Results**\n")
-
-         for result in results:
-             speaker_id = result['speaker_id']
-             start_time = result['start_time']
-             end_time = result['end_time']
-             text = result['text']
-             similarity = result['similarity']
-
-             color = SPEAKER_COLORS[speaker_id % len(SPEAKER_COLORS)]
-
-             # Format timestamp
-             start_min, start_sec = divmod(int(start_time), 60)
-             end_min, end_sec = divmod(int(end_time), 60)
-             timestamp = f"[{start_min:02d}:{start_sec:02d} - {end_min:02d}:{end_sec:02d}]"
-
-             # Create colored HTML output
-             formatted_lines.append(
-                 f'<div style="margin-bottom: 10px; padding: 8px; border-left: 4px solid {color}; background-color: {color}20;">'
-                 f'<strong style="color: {color};">Speaker {speaker_id + 1}</strong> '
-                 f'<span style="color: #666; font-size: 0.9em;">{timestamp}</span><br>'
-                 f'<span style="color: #333;">{text}</span>'
-                 f'</div>'
              )

-         return "".join(formatted_lines)

- # Global processor instance
- processor = AudioProcessor()

- def process_audio(audio_file, sensitivity):
-     """Process audio file with speaker detection"""
-     if audio_file is None:
-         return "Please upload an audio file.", ""

-     # Update sensitivity
-     processor.detector.threshold = sensitivity

-     # Process the audio
-     diarized_output, full_transcription = processor.process_audio_file(audio_file)

-     return diarized_output, full_transcription

- # Create Gradio interface
  def create_interface():
      """Create Gradio interface"""

      with gr.Blocks(
          theme=gr.themes.Soft(),
-         title="Speaker Diarization & Transcription",
          css="""
-         .gradio-container {
-             max-width: 1200px !important;
          }
-         .speaker-output {
-             font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
          }
          """
      ) as demo:

          gr.Markdown(
              """
-             # 🎙️ Speaker Diarization & Transcription

-             Upload an audio file to automatically detect different speakers and transcribe their speech.
-             The system will identify speaker changes and display each speaker's text in different colors.
              """
          )

          with gr.Row():
-             with gr.Column(scale=1):
                  audio_input = gr.Audio(
-                     label="Upload Audio File",
-                     type="filepath",
-                     sources=["upload", "microphone"]
                  )

-                 sensitivity_slider = gr.Slider(
                      minimum=0.1,
-                     maximum=1.0,
-                     value=0.65,
                      step=0.05,
-                     label="Speaker Change Sensitivity",
-                     info="Lower values = more sensitive to speaker changes"
                  )

-                 process_btn = gr.Button("🎯 Process Audio", variant="primary", size="lg")

-                 gr.Markdown(
-                     """
-                     ### Instructions:
-                     1. Upload an audio file (WAV, MP3, etc.)
-                     2. Adjust sensitivity if needed
-                     3. Click "Process Audio"
-                     4. View results with speaker colors
-
-                     ### Tips:
-                     - Works best with clear speech
-                     - Supports multiple file formats
-                     - Different speakers shown in different colors
-                     - Processing may take a moment for longer files
-                     """
                  )
-
-             with gr.Column(scale=2):
-                 with gr.Tabs():
-                     with gr.TabItem("🎨 Speaker Diarization"):
-                         diarized_output = gr.HTML(
-                             label="Speaker Diarization Results",
-                             elem_classes=["speaker-output"]
-                         )
-
-                     with gr.TabItem("📝 Full Transcription"):
-                         full_transcription = gr.Textbox(
-                             label="Complete Transcription",
-                             lines=15,
-                             max_lines=20,
-                             show_copy_button=True
-                         )

          # Event handlers
-         process_btn.click(
-             fn=process_audio,
-             inputs=[audio_input, sensitivity_slider],
-             outputs=[diarized_output, full_transcription],
-             show_progress=True
          )

-         # Auto-process when audio is uploaded
-         audio_input.change(
-             fn=process_audio,
-             inputs=[audio_input, sensitivity_slider],
-             outputs=[diarized_output, full_transcription],
-             show_progress=True
          )

-         gr.Markdown(
-             """
-             ---
-             ### About
-             This application uses:
-             - **MFCC features** for speaker embedding extraction
-             - **Cosine similarity** for speaker change detection
-             - **OpenAI Whisper** for speech-to-text transcription
-             - **Gradio** for the web interface
-
-             **Note**: This is a simplified speaker diarization system. For production use,
-             consider more advanced speaker embedding models like speechbrain or pyannote.audio.
-             """
          )

      return demo

- # Create and launch the interface
  if __name__ == "__main__":
      demo = create_interface()
      demo.launch(
          server_name="0.0.0.0",
          server_port=7860,
-         share=False,
-         show_error=True
      )
 
  import numpy as np
  import torch
  import torchaudio
+ import threading
+ import queue
+ import time
+ import os
+ import urllib.request
  from scipy.spatial.distance import cosine
+ from collections import deque
  import tempfile
+ import librosa

+ # Configuration parameters
+ FINAL_TRANSCRIPTION_MODEL = "openai/whisper-small"
+ TRANSCRIPTION_LANGUAGE = "en"
+ DEFAULT_CHANGE_THRESHOLD = 0.7
+ EMBEDDING_HISTORY_SIZE = 5
+ MIN_SEGMENT_DURATION = 1.0
+ DEFAULT_MAX_SPEAKERS = 4
+ ABSOLUTE_MAX_SPEAKERS = 6
+ SAMPLE_RATE = 16000

+ # Speaker colors for up to 6 speakers
  SPEAKER_COLORS = [
      "#FFD700",  # Gold
      "#FF6B6B",  # Red
      "#4ECDC4",  # Teal
      "#45B7D1",  # Blue
+     "#96CEB4",  # Green
+     "#FFEAA7",  # Yellow
  ]

+ SPEAKER_COLOR_NAMES = [
+     "Gold", "Red", "Teal", "Blue", "Green", "Yellow"
+ ]
+
+
+ class SpeechBrainEncoder:
+     """Simplified encoder for speaker embeddings using torch audio features"""
      def __init__(self, device="cpu"):
          self.device = device
          self.embedding_dim = 128
+         self.model_loaded = True

+     def load_model(self):
+         """Model loading simulation"""
+         return True
+
+     def embed_utterance(self, audio, sr=16000):
+         """Extract simple spectral features as speaker embedding"""
          try:
              if isinstance(audio, np.ndarray):
+                 waveform = torch.tensor(audio, dtype=torch.float32)
+             else:
+                 waveform = audio
+
+             if len(waveform.shape) == 1:
+                 waveform = waveform.unsqueeze(0)

+             # Resample if needed
+             if sr != 16000:
+                 waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
+
+             # Extract MFCC features as a simple embedding
+             mfcc_transform = torchaudio.transforms.MFCC(
+                 sample_rate=16000,
+                 n_mfcc=13,
+                 melkwargs={'n_mels': 40}
+             )

+             mfcc = mfcc_transform(waveform)
+             # Take mean across time dimension and flatten
+             embedding = mfcc.mean(dim=2).flatten()

+             # Pad or truncate to fixed size
+             if len(embedding) > self.embedding_dim:
+                 embedding = embedding[:self.embedding_dim]
+             elif len(embedding) < self.embedding_dim:
+                 padding = torch.zeros(self.embedding_dim - len(embedding))
+                 embedding = torch.cat([embedding, padding])

+             return embedding.numpy()

          except Exception as e:
              print(f"Error extracting embedding: {e}")
+             return np.random.randn(self.embedding_dim)

+
+ class SpeakerChangeDetector:
+     """Speaker change detector for real-time diarization"""
+     def __init__(self, embedding_dim=128, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
+         self.embedding_dim = embedding_dim
+         self.change_threshold = change_threshold
+         self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
          self.current_speaker = 0
+         self.previous_embeddings = []
+         self.last_change_time = time.time()
+         self.mean_embeddings = [None] * self.max_speakers
          self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
+         self.last_similarity = 0.0
+         self.active_speakers = set([0])
+
+     def set_max_speakers(self, max_speakers):
+         """Update the maximum number of speakers"""
+         new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
+
+         if new_max < self.max_speakers:
+             for speaker_id in list(self.active_speakers):
+                 if speaker_id >= new_max:
+                     self.active_speakers.discard(speaker_id)
+             if self.current_speaker >= new_max:
+                 self.current_speaker = 0
+
+         if new_max > self.max_speakers:
+             self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
+             self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
          else:
+             self.mean_embeddings = self.mean_embeddings[:new_max]
+             self.speaker_embeddings = self.speaker_embeddings[:new_max]

+         self.max_speakers = new_max
+
+     def set_change_threshold(self, threshold):
+         """Update the threshold for detecting speaker changes"""
+         self.change_threshold = max(0.1, min(threshold, 0.99))
+
+     def add_embedding(self, embedding, timestamp=None):
+         """Add a new embedding and check if there's a speaker change"""
+         current_time = timestamp or time.time()
+
+         if not self.previous_embeddings:
+             self.previous_embeddings.append(embedding)
+             self.speaker_embeddings[self.current_speaker].append(embedding)
+             if self.mean_embeddings[self.current_speaker] is None:
+                 self.mean_embeddings[self.current_speaker] = embedding.copy()
+             return self.current_speaker, 1.0
+
+         current_mean = self.mean_embeddings[self.current_speaker]
+         if current_mean is not None:
+             similarity = 1.0 - cosine(embedding, current_mean)
+         else:
+             similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
+
+         self.last_similarity = similarity
+
+         time_since_last_change = current_time - self.last_change_time
+         is_speaker_change = False
+
+         if time_since_last_change >= MIN_SEGMENT_DURATION:
+             if similarity < self.change_threshold:
+                 best_speaker = self.current_speaker
+                 best_similarity = similarity
+
+                 for speaker_id in range(self.max_speakers):
+                     if speaker_id == self.current_speaker:
+                         continue
+
+                     speaker_mean = self.mean_embeddings[speaker_id]

+                     if speaker_mean is not None:
+                         speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
+                         if speaker_similarity > best_similarity:
+                             best_similarity = speaker_similarity
+                             best_speaker = speaker_id
+
+                 if best_speaker != self.current_speaker:
+                     is_speaker_change = True
+                     self.current_speaker = best_speaker
+                 elif len(self.active_speakers) < self.max_speakers:
+                     for new_id in range(self.max_speakers):
+                         if new_id not in self.active_speakers:
+                             is_speaker_change = True
+                             self.current_speaker = new_id
+                             self.active_speakers.add(new_id)
+                             break
+
+             if is_speaker_change:
+                 self.last_change_time = current_time
+
+         self.previous_embeddings.append(embedding)
+         if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
+             self.previous_embeddings.pop(0)
+
+         self.speaker_embeddings[self.current_speaker].append(embedding)
+         self.active_speakers.add(self.current_speaker)
+
+         if len(self.speaker_embeddings[self.current_speaker]) > 30:
+             self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]

+         if self.speaker_embeddings[self.current_speaker]:
+             self.mean_embeddings[self.current_speaker] = np.mean(
+                 self.speaker_embeddings[self.current_speaker], axis=0
+             )

          return self.current_speaker, similarity

+     def get_color_for_speaker(self, speaker_id):
+         """Return color for speaker ID"""
+         if 0 <= speaker_id < len(SPEAKER_COLORS):
+             return SPEAKER_COLORS[speaker_id]
+         return "#FFFFFF"

+
+ class RealTimeASRDiarization:
+     """Main class for real-time ASR with speaker diarization"""
      def __init__(self):
+         self.encoder = SpeechBrainEncoder()
+         self.encoder.load_model()
+         self.speaker_detector = SpeakerChangeDetector()
+         self.transcription_queue = queue.Queue()
+         self.conversation_history = []
+         self.is_processing = False

+         # Load Whisper model
          try:
+             import whisper
+             self.whisper_model = whisper.load_model("base")
+         except ImportError:
+             print("Whisper not available, using mock transcription")
+             self.whisper_model = None

+     def transcribe_audio(self, audio_data, sr=16000):
+         """Transcribe audio using Whisper"""
          try:
+             if self.whisper_model is None:
+                 return "Mock transcription: Hello, this is a test."

+             # Ensure audio is the right format
+             if isinstance(audio_data, tuple):
+                 sr, audio_data = audio_data

+             if len(audio_data.shape) > 1:
+                 audio_data = audio_data.mean(axis=1)

+             # Normalize audio
+             audio_data = audio_data.astype(np.float32)
+             if np.abs(audio_data).max() > 1.0:
+                 audio_data = audio_data / np.abs(audio_data).max()

+             # Resample to 16kHz if needed
+             if sr != 16000:
+                 audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)

+             # Transcribe
+             result = self.whisper_model.transcribe(audio_data, language="en")
+             return result["text"].strip()

          except Exception as e:
+             print(f"Transcription error: {e}")
+             return ""

+     def extract_speaker_embedding(self, audio_data, sr=16000):
+         """Extract speaker embedding from audio"""
+         return self.encoder.embed_utterance(audio_data, sr)
+
+     def process_audio_segment(self, audio_data, sr=16000):
+         """Process an audio segment for transcription and speaker identification"""
+         if len(audio_data) < sr * 0.5:  # Skip very short segments
+             return None, None, None
+
+         # Transcribe the audio
+         transcription = self.transcribe_audio(audio_data, sr)
+
+         if not transcription:
+             return None, None, None
+
+         # Extract speaker embedding
+         embedding = self.extract_speaker_embedding(audio_data, sr)
+
+         # Detect speaker
+         speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
+
+         return transcription, speaker_id, similarity
+
+     def update_conversation(self, transcription, speaker_id):
+         """Update conversation history with new transcription"""
+         speaker_name = f"Speaker {speaker_id + 1}"
+         color = self.speaker_detector.get_color_for_speaker(speaker_id)
+
+         entry = {
+             "speaker": speaker_name,
+             "text": transcription,
+             "color": color,
+             "timestamp": time.time()
+         }
+
+         self.conversation_history.append(entry)
+         return entry
+
+     def format_conversation_html(self):
+         """Format conversation history as HTML"""
+         if not self.conversation_history:
+             return "<p><i>No conversation yet. Start speaking to see real-time transcription with speaker diarization.</i></p>"
+
+         html_parts = []
+         for entry in self.conversation_history:
+             html_parts.append(
+                 f'<p><span style="color: {entry["color"]}; font-weight: bold;">'
+                 f'{entry["speaker"]}:</span> {entry["text"]}</p>'
              )

+         return "".join(html_parts)
+
+     def get_status_info(self):
+         """Get current status information"""
+         status = {
+             "active_speakers": len(self.speaker_detector.active_speakers),
+             "max_speakers": self.speaker_detector.max_speakers,
+             "current_speaker": self.speaker_detector.current_speaker + 1,
+             "total_segments": len(self.conversation_history),
+             "threshold": self.speaker_detector.change_threshold
+         }
+         return status
+
+     def clear_conversation(self):
+         """Clear conversation history and reset speaker detector"""
+         self.conversation_history = []
+         self.speaker_detector = SpeakerChangeDetector(
+             change_threshold=self.speaker_detector.change_threshold,
+             max_speakers=self.speaker_detector.max_speakers
+         )
+
+     def set_parameters(self, threshold, max_speakers):
+         """Update parameters"""
+         self.speaker_detector.set_change_threshold(threshold)
+         self.speaker_detector.set_max_speakers(max_speakers)
+
+
+ # Global instance
+ asr_system = RealTimeASRDiarization()

+ def process_audio_realtime(audio_data, threshold, max_speakers):
+     """Process audio in real-time"""
+     global asr_system

+     if audio_data is None:
+         return asr_system.format_conversation_html(), get_status_display()

+     # Update parameters
+     asr_system.set_parameters(threshold, max_speakers)

+     try:
+         # Process the audio segment
+         sr, audio_array = audio_data
+
+         # Convert to float32, scaling integer PCM into [-1, 1]
+         # (the dtype must be checked before casting, otherwise the scaling never runs)
+         if audio_array.dtype == np.int16:
+             audio_array = audio_array.astype(np.float32) / 32768.0
+         elif audio_array.dtype == np.int32:
+             audio_array = audio_array.astype(np.float32) / 2147483648.0
+         elif audio_array.dtype != np.float32:
+             audio_array = audio_array.astype(np.float32)
+
+         # Process the audio segment
+         transcription, speaker_id, similarity = asr_system.process_audio_segment(audio_array, sr)
+
+         if transcription and speaker_id is not None:
+             # Update conversation
+             asr_system.update_conversation(transcription, speaker_id)
+
+     except Exception as e:
+         print(f"Error processing audio: {e}")
+
+     return asr_system.format_conversation_html(), get_status_display()
+
+
+ def get_status_display():
+     """Get formatted status display"""
+     status = asr_system.get_status_info()
+
+     status_html = f"""
+     <div style="font-family: monospace; font-size: 12px;">
+     <strong>Status:</strong><br>
+     Current Speaker: {status['current_speaker']}<br>
+     Active Speakers: {status['active_speakers']} / {status['max_speakers']}<br>
+     Total Segments: {status['total_segments']}<br>
+     Threshold: {status['threshold']:.2f}<br>
+     </div>
+     """
+
+     return status_html
+
+
+ def clear_conversation():
+     """Clear the conversation"""
+     global asr_system
+     asr_system.clear_conversation()
+     return asr_system.format_conversation_html(), get_status_display()
+

  def create_interface():
      """Create Gradio interface"""

      with gr.Blocks(
+         title="Real-time ASR with Speaker Diarization",
          theme=gr.themes.Soft(),
          css="""
+         .conversation-box {
+             height: 400px;
+             overflow-y: auto;
+             border: 1px solid #ddd;
+             padding: 10px;
+             background-color: #f9f9f9;
          }
+         .status-box {
+             border: 1px solid #ccc;
+             padding: 10px;
+             background-color: #f0f0f0;
          }
          """
      ) as demo:

          gr.Markdown(
              """
+             # 🎤 Real-time ASR with Live Speaker Diarization
+
+             This application provides real-time speech recognition with speaker diarization.
+             It can distinguish between different speakers and display their conversations in different colors.

+             **Instructions:**
+             1. Adjust the speaker change threshold and maximum speakers
+             2. Click the microphone button to start recording
+             3. Speak naturally - the system will detect speaker changes and transcribe speech
+             4. Each speaker will be assigned a different color
              """
          )

          with gr.Row():
+             with gr.Column(scale=3):
+                 # Main conversation display
+                 conversation_display = gr.HTML(
+                     value="<p><i>Click the microphone to start recording...</i></p>",
+                     elem_classes=["conversation-box"]
+                 )
+
+                 # Audio input
                  audio_input = gr.Audio(
+                     source="microphone",
+                     type="numpy",
+                     streaming=True,
+                     label="🎤 Microphone Input"
                  )

+             with gr.Column(scale=1):
+                 # Controls
+                 gr.Markdown("### Controls")
+
+                 threshold_slider = gr.Slider(
                      minimum=0.1,
+                     maximum=0.9,
+                     value=DEFAULT_CHANGE_THRESHOLD,
                      step=0.05,
+                     label="Speaker Change Threshold",
+                     info="Higher values = less sensitive to speaker changes"
                  )

+                 max_speakers_slider = gr.Slider(
+                     minimum=2,
+                     maximum=ABSOLUTE_MAX_SPEAKERS,
+                     value=DEFAULT_MAX_SPEAKERS,
+                     step=1,
+                     label="Maximum Speakers",
+                     info="Maximum number of different speakers to detect"
+                 )

+                 clear_btn = gr.Button("🗑️ Clear Conversation", variant="secondary")
+
+                 # Status display
+                 gr.Markdown("### Status")
+                 status_display = gr.HTML(
+                     value=get_status_display(),
+                     elem_classes=["status-box"]
                  )
+
+                 # Speaker color legend
+                 gr.Markdown("### Speaker Colors")
+                 legend_html = ""
+                 for i in range(ABSOLUTE_MAX_SPEAKERS):
+                     color = SPEAKER_COLORS[i]
+                     name = SPEAKER_COLOR_NAMES[i]
+                     legend_html += f'<p><span style="color: {color}; font-weight: bold;">● Speaker {i+1} ({name})</span></p>'
+
+                 gr.HTML(legend_html)

          # Event handlers
+         audio_input.change(
+             fn=process_audio_realtime,
+             inputs=[audio_input, threshold_slider, max_speakers_slider],
+             outputs=[conversation_display, status_display],
+             show_progress=False
          )

+         clear_btn.click(
+             fn=clear_conversation,
+             outputs=[conversation_display, status_display]
          )

+         # Update status periodically
+         demo.load(
+             fn=lambda: (asr_system.format_conversation_html(), get_status_display()),
+             outputs=[conversation_display, status_display],
+             every=2
          )

      return demo

+
  if __name__ == "__main__":
+     # Create and launch the interface
      demo = create_interface()
      demo.launch(
          server_name="0.0.0.0",
          server_port=7860,
+         share=True
      )
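
Both the removed and the added version of app.py rely on the same underlying idea: compress each audio chunk into a fixed-size MFCC-statistics vector and flag a speaker change when the cosine similarity to the current speaker's running mean drops below a threshold. The sketch below is a minimal, self-contained illustration of that idea only; the helper names (`embed_chunk`, `is_speaker_change`) are illustrative and not part of the committed code.

```python
import numpy as np
import torch
import torchaudio
from scipy.spatial.distance import cosine

SAMPLE_RATE = 16000
EMBEDDING_DIM = 128

# MFCC front-end, similar in spirit to the encoders in app.py
_mfcc = torchaudio.transforms.MFCC(
    sample_rate=SAMPLE_RATE, n_mfcc=13, melkwargs={"n_mels": 40}
)

def embed_chunk(chunk: np.ndarray) -> np.ndarray:
    """Average MFCC frames over time, then pad/truncate to a fixed length."""
    wav = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0)  # (1, samples)
    feats = _mfcc(wav).mean(dim=2).flatten()                     # (n_mfcc,)
    out = torch.zeros(EMBEDDING_DIM)
    out[: min(EMBEDDING_DIM, feats.numel())] = feats[:EMBEDDING_DIM]
    return out.numpy()

def is_speaker_change(embedding: np.ndarray, centroid: np.ndarray,
                      threshold: float = 0.7) -> bool:
    """Declare a change when cosine similarity to the current centroid is low."""
    similarity = 1.0 - cosine(embedding, centroid)
    return similarity < threshold
```

With the default threshold of 0.7 (DEFAULT_CHANGE_THRESHOLD in the new code), a chunk whose similarity to the current speaker's centroid falls below 0.7 is treated as a candidate for a different speaker.
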
requirements.txt CHANGED
@@ -128,7 +128,7 @@ pytz==2024.1
  PyYAML==6.0.1
  RealTimeSTT==0.1.13
  regex==2023.12.25
- requests==2.31.0
+ requests==2.32.3
  safetensors==0.4.2
  scikit-learn==1.4.1.post1
  scipy==1.15.2
@@ -163,7 +163,7 @@ absl-py==2.1.0
  # … any other non-PyTorch dependencies …
  torch==2.2.2+cpu
  torchaudio==2.2.2+cpu
- tqdm==4.66.2
+ tqdm==4.67.1
  trainer==0.0.36
  traitlets==5.14.2
  transformers==4.39.2
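
The requirements.txt change is limited to the two pins above (requests 2.31.0 → 2.32.3, tqdm 4.66.2 → 4.67.1). A quick sanity check of an installed environment against the new pins might look like the following sketch; it is illustrative only and not part of the repository.

```python
from importlib.metadata import version

# New pins introduced by this commit
expected = {"requests": "2.32.3", "tqdm": "4.67.1"}

for package, pinned in expected.items():
    installed = version(package)
    flag = "OK" if installed == pinned else f"MISMATCH (pinned {pinned})"
    print(f"{package}: installed {installed} -> {flag}")
```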