Saiyaswanth007 committed on
Commit 66992f6 · 1 Parent(s): 4611564

Code Update

Files changed (1)
  1. app.py +453 -0
app.py ADDED
@@ -0,0 +1,453 @@
import gradio as gr
import numpy as np
import torch
import torchaudio
from scipy.spatial.distance import cosine
import tempfile
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

try:
    from transformers import pipeline
except ImportError:
    print("transformers not found. Install with: pip install transformers")

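# Dependencies used by this module (all imported above): gradio, numpy, torch,
# torchaudio, scipy, and transformers (the transformers pipeline provides the
# Whisper transcriber used below).
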
# Configuration
class Config:
    # Audio settings
    SAMPLE_RATE = 16000

    # Speaker detection
    CHANGE_THRESHOLD = 0.65
    MAX_SPEAKERS = 4
    MIN_SEGMENT_DURATION = 1.0
    EMBEDDING_HISTORY_SIZE = 3
    SPEAKER_MEMORY_SIZE = 20

# Colors used to highlight speakers in the HTML output
SPEAKER_COLORS = [
    "#FFD700",  # Gold
    "#FF6B6B",  # Red
    "#4ECDC4",  # Teal
    "#45B7D1",  # Blue
    "#96CEB4",  # Mint
    "#FFEAA7",  # Light Yellow
    "#DDA0DD",  # Plum
    "#98D8C8",  # Mint Green
]

class SpeakerEncoder:
    """Simplified speaker encoder using torchaudio transforms"""

    def __init__(self, device="cpu"):
        self.device = device
        self.embedding_dim = 128
        self.model_loaded = False
        self._setup_model()

    def _setup_model(self):
        """Set up a simple MFCC-based feature extractor"""
        try:
            self.mfcc_transform = torchaudio.transforms.MFCC(
                sample_rate=Config.SAMPLE_RATE,
                n_mfcc=13,
                melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23}
            ).to(self.device)
            self.model_loaded = True
            print("Simple MFCC-based encoder initialized")
        except Exception as e:
            print(f"Error setting up encoder: {e}")
            self.model_loaded = False

    def extract_embedding(self, audio):
        """Extract a speaker embedding from audio"""
        if not self.model_loaded:
            return np.zeros(self.embedding_dim)

        try:
            # Ensure audio is float32 and normalized
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio).float()

            # Normalize audio
            if audio.abs().max() > 0:
                audio = audio / audio.abs().max()

            # Add batch dimension if needed
            if audio.dim() == 1:
                audio = audio.unsqueeze(0)

            # Extract MFCC features
            with torch.no_grad():
                mfcc = self.mfcc_transform(audio)
                # Simple statistics-based embedding
                embedding = torch.cat([
                    mfcc.mean(dim=2).flatten(),
                    mfcc.std(dim=2).flatten(),
                    mfcc.max(dim=2)[0].flatten(),
                    mfcc.min(dim=2)[0].flatten()
                ])

            # Pad or truncate to a fixed size
            if embedding.size(0) > self.embedding_dim:
                embedding = embedding[:self.embedding_dim]
            elif embedding.size(0) < self.embedding_dim:
                padding = torch.zeros(self.embedding_dim - embedding.size(0))
                embedding = torch.cat([embedding, padding])

            return embedding.cpu().numpy()

        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)

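# NOTE: with n_mfcc=13 and four pooled statistics (mean, std, max, min), the raw
# feature vector produced above has 13 * 4 = 52 values; it is zero-padded up to
# embedding_dim=128 before being returned by SpeakerEncoder.extract_embedding.
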
class SpeakerDetector:
    """Speaker change detection using embeddings"""

    def __init__(self, threshold=Config.CHANGE_THRESHOLD, max_speakers=Config.MAX_SPEAKERS):
        self.threshold = threshold
        self.max_speakers = max_speakers
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(max_speakers)]
        self.speaker_centroids = [None] * max_speakers
        self.active_speakers = {0}

    def reset(self):
        """Reset speaker detection state"""
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.speaker_centroids = [None] * self.max_speakers
        self.active_speakers = {0}

    def detect_speaker(self, embedding):
        """Detect the current speaker from an embedding"""
        # Initialize first speaker
        if not self.speaker_embeddings[0]:
            self.speaker_embeddings[0].append(embedding)
            self.speaker_centroids[0] = embedding.copy()
            return 0, 1.0

        # Calculate similarity with the current speaker
        current_centroid = self.speaker_centroids[self.current_speaker]
        if current_centroid is not None:
            similarity = 1.0 - cosine(embedding, current_centroid)
        else:
            similarity = 0.0

        # Check for speaker change
        if similarity < self.threshold:
            # Find the best matching existing speaker
            best_speaker = self.current_speaker
            best_similarity = similarity

            for speaker_id in self.active_speakers:
                if speaker_id == self.current_speaker:
                    continue

                centroid = self.speaker_centroids[speaker_id]
                if centroid is not None:
                    sim = 1.0 - cosine(embedding, centroid)
                    if sim > best_similarity and sim > self.threshold:
                        best_similarity = sim
                        best_speaker = speaker_id

            # Create a new speaker if there is no good match and slots are available
            if (best_speaker == self.current_speaker and
                    len(self.active_speakers) < self.max_speakers):
                for new_id in range(self.max_speakers):
                    if new_id not in self.active_speakers:
                        best_speaker = new_id
                        best_similarity = 0.0
                        self.active_speakers.add(new_id)
                        break

            # Update current speaker if changed
            if best_speaker != self.current_speaker:
                self.current_speaker = best_speaker
                similarity = best_similarity

        # Update speaker model
        self._update_speaker_model(self.current_speaker, embedding)
        return self.current_speaker, similarity

    def _update_speaker_model(self, speaker_id, embedding):
        """Update the speaker model with a new embedding"""
        self.speaker_embeddings[speaker_id].append(embedding)

        # Keep only recent embeddings
        if len(self.speaker_embeddings[speaker_id]) > Config.SPEAKER_MEMORY_SIZE:
            self.speaker_embeddings[speaker_id] = \
                self.speaker_embeddings[speaker_id][-Config.SPEAKER_MEMORY_SIZE:]

        # Update centroid
        if self.speaker_embeddings[speaker_id]:
            self.speaker_centroids[speaker_id] = np.mean(
                self.speaker_embeddings[speaker_id], axis=0
            )

class AudioProcessor:
    """Handles audio processing and transcription"""

    def __init__(self):
        self.encoder = SpeakerEncoder()
        self.detector = SpeakerDetector()

        # Initialize Whisper model for transcription
        try:
            self.transcriber = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-base",
                chunk_length_s=30,
                device=0 if torch.cuda.is_available() else -1
            )
            print("Whisper model loaded successfully")
        except Exception as e:
            print(f"Error loading Whisper model: {e}")
            self.transcriber = None

    def process_audio_file(self, audio_file):
        """Process an uploaded audio file"""
        if audio_file is None:
            return "Please upload an audio file.", ""

        try:
            # Reset speaker detection for the new file
            self.detector.reset()

            # Load audio file
            waveform, sample_rate = torchaudio.load(audio_file)

            # Convert to mono if stereo
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            # Resample to 16 kHz if needed
            if sample_rate != Config.SAMPLE_RATE:
                resampler = torchaudio.transforms.Resample(sample_rate, Config.SAMPLE_RATE)
                waveform = resampler(waveform)

            # Convert to numpy
            audio_data = waveform.squeeze().numpy()

            # Transcribe the entire audio
            if self.transcriber:
                transcription_result = self.transcriber(audio_file)
                full_transcription = transcription_result['text']
            else:
                full_transcription = "Transcription service unavailable"

            # Process audio in chunks for speaker detection
            chunk_duration = 3.0  # 3-second chunks
            chunk_samples = int(chunk_duration * Config.SAMPLE_RATE)
            results = []

            for i in range(0, len(audio_data), chunk_samples // 2):  # 50% overlap
                chunk = audio_data[i:i + chunk_samples]

                if len(chunk) < Config.SAMPLE_RATE:  # Skip chunks shorter than 1 second
                    continue

                # Extract speaker embedding
                embedding = self.encoder.extract_embedding(chunk)
                speaker_id, similarity = self.detector.detect_speaker(embedding)

                # Get timestamps
                start_time = i / Config.SAMPLE_RATE
                end_time = (i + len(chunk)) / Config.SAMPLE_RATE

                # Transcribe chunk
                if self.transcriber and len(chunk) > Config.SAMPLE_RATE:
                    # Save the chunk temporarily for transcription
                    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                        torchaudio.save(tmp_file.name, torch.tensor(chunk).unsqueeze(0), Config.SAMPLE_RATE)
                        chunk_result = self.transcriber(tmp_file.name)
                        chunk_text = chunk_result['text'].strip()
                    os.unlink(tmp_file.name)  # Clean up temp file
                else:
                    chunk_text = ""

                if chunk_text:  # Only add segments with actual text
                    results.append({
                        'speaker_id': speaker_id,
                        'start_time': start_time,
                        'end_time': end_time,
                        'text': chunk_text,
                        'similarity': similarity
                    })

            # Format results
            formatted_output = self._format_results(results)
            return formatted_output, full_transcription

        except Exception as e:
            return f"Error processing audio: {str(e)}", ""

    def _format_results(self, results):
        """Format results with speaker colors"""
        if not results:
            return "No speech detected in the audio file."

        formatted_lines = []
        formatted_lines.append("<h3>🎤 Speaker Diarization Results</h3>")

        for result in results:
            speaker_id = result['speaker_id']
            start_time = result['start_time']
            end_time = result['end_time']
            text = result['text']
            similarity = result['similarity']

            color = SPEAKER_COLORS[speaker_id % len(SPEAKER_COLORS)]

            # Format timestamp as MM:SS
            start_min, start_sec = divmod(int(start_time), 60)
            end_min, end_sec = divmod(int(end_time), 60)
            timestamp = f"[{start_min:02d}:{start_sec:02d} - {end_min:02d}:{end_sec:02d}]"

            # Create colored HTML output
            formatted_lines.append(
                f'<div style="margin-bottom: 10px; padding: 8px; border-left: 4px solid {color}; background-color: {color}20;">'
                f'<strong style="color: {color};">Speaker {speaker_id + 1}</strong> '
                f'<span style="color: #666; font-size: 0.9em;">{timestamp}</span><br>'
                f'<span style="color: #333;">{text}</span>'
                f'</div>'
            )

        return "".join(formatted_lines)

# Global processor instance
processor = AudioProcessor()

def process_audio(audio_file, sensitivity):
    """Process an audio file with speaker detection"""
    if audio_file is None:
        return "Please upload an audio file.", ""

    # Update sensitivity threshold
    processor.detector.threshold = sensitivity

    # Process the audio
    diarized_output, full_transcription = processor.process_audio_file(audio_file)

    return diarized_output, full_transcription

# Create Gradio interface
def create_interface():
    """Create the Gradio interface"""

    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Speaker Diarization & Transcription",
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        .speaker-output {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
        }
        """
    ) as demo:

        gr.Markdown(
            """
            # 🎙️ Speaker Diarization & Transcription

            Upload an audio file to automatically detect different speakers and transcribe their speech.
            The system identifies speaker changes and displays each speaker's text in a different color.
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    label="Upload Audio File",
                    type="filepath",
                    sources=["upload", "microphone"]
                )

                sensitivity_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.65,
                    step=0.05,
                    label="Speaker Change Sensitivity",
                    info="Higher values = more sensitive to speaker changes"
                )

                process_btn = gr.Button("🎯 Process Audio", variant="primary", size="lg")

                gr.Markdown(
                    """
                    ### Instructions:
                    1. Upload an audio file (WAV, MP3, etc.)
                    2. Adjust sensitivity if needed
                    3. Click "Process Audio"
                    4. View results with speaker colors

                    ### Tips:
                    - Works best with clear speech
                    - Supports multiple file formats
                    - Different speakers are shown in different colors
                    - Processing may take a moment for longer files
                    """
                )

            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("🎨 Speaker Diarization"):
                        diarized_output = gr.HTML(
                            label="Speaker Diarization Results",
                            elem_classes=["speaker-output"]
                        )

                    with gr.TabItem("📝 Full Transcription"):
                        full_transcription = gr.Textbox(
                            label="Complete Transcription",
                            lines=15,
                            max_lines=20,
                            show_copy_button=True
                        )

        # Event handlers
        process_btn.click(
            fn=process_audio,
            inputs=[audio_input, sensitivity_slider],
            outputs=[diarized_output, full_transcription],
            show_progress=True
        )

        # Auto-process when audio is uploaded
        audio_input.change(
            fn=process_audio,
            inputs=[audio_input, sensitivity_slider],
            outputs=[diarized_output, full_transcription],
            show_progress=True
        )

        gr.Markdown(
            """
            ---
            ### About
            This application uses:
            - **MFCC features** for speaker embedding extraction
            - **Cosine similarity** for speaker change detection
            - **OpenAI Whisper** for speech-to-text transcription
            - **Gradio** for the web interface

            **Note**: This is a simplified speaker diarization system. For production use,
            consider more advanced speaker embedding models such as speechbrain or pyannote.audio.
            """
        )

    return demo

# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
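The speaker-change decision in SpeakerDetector reduces to a thresholded cosine similarity. The following standalone sketch (illustrative only, not part of app.py; the helper name same_speaker and the synthetic vectors are made up for the example) shows the same comparison on synthetic 128-dimensional embeddings:

import numpy as np
from scipy.spatial.distance import cosine

THRESHOLD = 0.65  # mirrors Config.CHANGE_THRESHOLD

def same_speaker(embedding, centroid, threshold=THRESHOLD):
    # The detector treats similarity >= threshold as "same speaker"
    similarity = 1.0 - cosine(embedding, centroid)
    return similarity >= threshold, similarity

rng = np.random.default_rng(0)
centroid = rng.normal(size=128)
close = centroid + 0.05 * rng.normal(size=128)  # small perturbation -> same speaker
far = rng.normal(size=128)                      # unrelated vector -> likely a change

print(same_speaker(close, centroid))  # expected: (True, similarity close to 1)
print(same_speaker(far, centroid))    # expected: (False, similarity near 0)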