|
"""Real-time speaker change detection GUI.

Transcribes system audio (or a microphone) with RealtimeSTT and colors each
finished sentence by speaker, using SpeechBrain ECAPA-TDNN embeddings and a
cosine-similarity based change detector.
"""

import os
import queue
import sys
import time
import urllib.request

import numpy as np
import soundcard as sc
import torch
import torchaudio
from scipy.spatial.distance import cosine

from PyQt6.QtCore import Qt, pyqtSignal, QThread, QEvent, QTimer
from PyQt6.QtWidgets import (QApplication, QTextEdit, QMainWindow, QLabel,
                             QVBoxLayout, QWidget, QHBoxLayout, QPushButton,
                             QGroupBox, QSlider, QSpinBox)

from RealtimeSTT import AudioToTextRecorder
|
|
|
|
|
# Voice-activity / transcription settings
SILENCE_THRESHS = [0, 0.4]  # post-speech silence (s): [fast sentence end, default]
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35
|
|
|
|
|
# Speaker diarization settings
DEFAULT_CHANGE_THRESHOLD = 0.7  # cosine similarity below this suggests a new speaker
EMBEDDING_HISTORY_SIZE = 5      # recent embeddings kept for fallback comparison
MIN_SEGMENT_DURATION = 1.0      # seconds before another speaker change is allowed
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10      # bounded by the length of SPEAKER_COLORS
|
|
|
|
|
# Audio capture settings
FAST_SENTENCE_END = True   # stop the recorder as soon as a sentence end is likely
USE_MICROPHONE = False     # False: capture system output via a loopback device
SAMPLE_RATE = 16000
BUFFER_SIZE = 512
CHANNELS = 1
|
|
|
|
|
SPEAKER_COLORS = [
    "#FFFF00",  # Yellow
    "#FF0000",  # Red
    "#00FF00",  # Green
    "#00FFFF",  # Cyan
    "#FF00FF",  # Magenta
    "#0000FF",  # Blue
    "#FF8000",  # Orange
    "#00FF80",  # Spring Green
    "#8000FF",  # Purple
    "#FFFFFF",  # White
]
|
|
|
|
|
SPEAKER_COLOR_NAMES = [
    "Yellow", "Red", "Green", "Cyan", "Magenta",
    "Blue", "Orange", "Spring Green", "Purple", "White",
]
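
# The two lists above must stay index-aligned, and ABSOLUTE_MAX_SPEAKERS must
# not exceed their length.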
|
|
|
|
|
class SpeechBrainEncoder: |
|
"""ECAPA-TDNN encoder from SpeechBrain for speaker embeddings""" |
|
def __init__(self, device="cpu"): |
|
self.device = device |
|
self.model = None |
|
self.embedding_dim = 192 |
|
self.model_loaded = False |
|
self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain") |
|
os.makedirs(self.cache_dir, exist_ok=True) |
|
|
|
def _download_model(self): |
|
"""Download pre-trained SpeechBrain ECAPA-TDNN model if not present""" |
|
model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt" |
|
model_path = os.path.join(self.cache_dir, "embedding_model.ckpt") |
|
|
|
if not os.path.exists(model_path): |
|
print(f"Downloading ECAPA-TDNN model to {model_path}...") |
|
urllib.request.urlretrieve(model_url, model_path) |
|
|
|
return model_path |
|
|
|
def load_model(self): |
|
"""Load the ECAPA-TDNN model""" |
|
try: |
|
|
|
            # speechbrain >= 1.0 moved the pretrained interfaces; support both layouts
            try:
                from speechbrain.inference.speaker import EncoderClassifier
            except ImportError:
                from speechbrain.pretrained import EncoderClassifier

            # Warm the cache; from_hparams() manages its own files in savedir
            # and may re-fetch what it needs
            self._download_model()
|
|
|
|
|
self.model = EncoderClassifier.from_hparams( |
|
source="speechbrain/spkrec-ecapa-voxceleb", |
|
savedir=self.cache_dir, |
|
run_opts={"device": self.device} |
|
) |
|
|
|
self.model_loaded = True |
|
return True |
|
except Exception as e: |
|
print(f"Error loading ECAPA-TDNN model: {e}") |
|
return False |
|
|
|
def embed_utterance(self, audio, sr=16000): |
|
"""Extract speaker embedding from audio""" |
|
if not self.model_loaded: |
|
raise ValueError("Model not loaded. Call load_model() first.") |
|
|
|
try: |
|
|
|
            # Accept either a NumPy array or a torch tensor and add a batch dim
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
            else:
                waveform = audio.unsqueeze(0)

            # ECAPA-TDNN expects 16 kHz input
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
|
|
|
|
|
with torch.no_grad(): |
|
embedding = self.model.encode_batch(waveform) |
|
|
|
return embedding.squeeze().cpu().numpy() |
|
except Exception as e: |
|
print(f"Error extracting embedding: {e}") |
|
return np.zeros(self.embedding_dim) |
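
# Minimal usage sketch for SpeechBrainEncoder (audio_float32 is a hypothetical
# 16 kHz mono float32 array in [-1, 1]):
#
#   encoder = SpeechBrainEncoder(device="cpu")
#   if encoder.load_model():
#       embedding = encoder.embed_utterance(audio_float32)  # shape (192,)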
|
|
|
|
|
class AudioProcessor: |
|
"""Processes audio data to extract speaker embeddings""" |
|
def __init__(self, encoder): |
|
self.encoder = encoder |
|
|
|
def extract_embedding(self, audio_int16): |
|
try: |
|
|
|
            # Scale int16 PCM to float32 in [-1, 1]
            float_audio = audio_int16.astype(np.float32) / 32768.0

            # Guard against clipping from the scaling
            if np.abs(float_audio).max() > 1.0:
                float_audio = float_audio / np.abs(float_audio).max()
|
|
|
|
|
embedding = self.encoder.embed_utterance(float_audio) |
|
|
|
return embedding |
|
except Exception as e: |
|
print(f"Embedding extraction error: {e}") |
|
return np.zeros(self.encoder.embedding_dim) |
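
# Note: extract_embedding() takes int16 PCM and rescales it to float32 in
# [-1, 1] before handing the audio to the encoder.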
|
|
|
|
|
class EncoderLoaderThread(QThread): |
|
"""Thread for loading the speaker encoder model""" |
|
model_loaded = pyqtSignal(object) |
|
progress_update = pyqtSignal(str) |
|
|
|
def run(self): |
|
try: |
|
self.progress_update.emit("Initializing speaker encoder model...") |
|
|
|
|
|
device_str = "cuda" if torch.cuda.is_available() else "cpu" |
|
self.progress_update.emit(f"Using device: {device_str}") |
|
|
|
|
|
self.progress_update.emit("Loading ECAPA-TDNN model...") |
|
encoder = SpeechBrainEncoder(device=device_str) |
|
|
|
|
|
success = encoder.load_model() |
|
|
|
if success: |
|
self.progress_update.emit("ECAPA-TDNN model loading complete!") |
|
self.model_loaded.emit(encoder) |
|
else: |
|
self.progress_update.emit("Failed to load ECAPA-TDNN model. Using fallback...") |
|
self.model_loaded.emit(None) |
|
except Exception as e: |
|
self.progress_update.emit(f"Model loading error: {e}") |
|
self.model_loaded.emit(None) |
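
# The encoder loads on a background thread so the UI stays responsive; the
# result (or None on failure) arrives via the model_loaded signal.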
|
|
|
|
|
class SpeakerChangeDetector: |
|
"""Modified speaker change detector that supports a configurable number of speakers""" |
|
def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS): |
|
self.embedding_dim = embedding_dim |
|
self.change_threshold = change_threshold |
|
self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS) |
|
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        # Per-speaker state: running mean embedding plus recent-embedding history
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        # Speaker 0 is active from the start; other slots activate on demand
        self.active_speakers = set([0])
|
|
|
def set_max_speakers(self, max_speakers): |
|
"""Update the maximum number of speakers""" |
|
new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS) |
|
|
|
|
|
        # If shrinking, retire any speakers whose IDs fall out of range
        if new_max < self.max_speakers:
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)

            if self.current_speaker >= new_max:
                self.current_speaker = 0

        # Grow or truncate the per-speaker state to the new size
        if new_max > self.max_speakers:
            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
|
|
|
self.max_speakers = new_max |
|
|
|
def set_change_threshold(self, threshold): |
|
"""Update the threshold for detecting speaker changes""" |
|
self.change_threshold = max(0.1, min(threshold, 0.99)) |
|
|
|
def add_embedding(self, embedding, timestamp=None): |
|
"""Add a new embedding and check if there's a speaker change""" |
|
current_time = timestamp or time.time() |
|
|
|
|
|
        # First embedding ever seen: bootstrap speaker 0 with it
        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0
|
|
|
|
|
        # Compare against the current speaker's mean embedding, falling back to
        # the most recent embedding if no mean exists yet
        current_mean = self.mean_embeddings[self.current_speaker]
        if current_mean is not None:
            similarity = 1.0 - cosine(embedding, current_mean)
        else:
            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
|
|
|
self.last_similarity = similarity |
|
|
|
|
|
time_since_last_change = current_time - self.last_change_time |
|
is_speaker_change = False |
|
|
|
|
|
        # Only consider a change once the current segment is long enough
        if time_since_last_change >= MIN_SEGMENT_DURATION:
            if similarity < self.change_threshold:
                # Low similarity to the current speaker: find the best match
                # among all known speakers
                best_speaker = self.current_speaker
                best_similarity = similarity
|
|
|
|
|
for speaker_id in range(self.max_speakers): |
|
if speaker_id == self.current_speaker: |
|
continue |
|
|
|
speaker_mean = self.mean_embeddings[speaker_id] |
|
|
|
if speaker_mean is not None: |
|
|
|
speaker_similarity = 1.0 - cosine(embedding, speaker_mean) |
|
|
|
|
|
if speaker_similarity > best_similarity: |
|
best_similarity = speaker_similarity |
|
best_speaker = speaker_id |
|
|
|
|
|
if best_speaker != self.current_speaker: |
|
is_speaker_change = True |
|
self.current_speaker = best_speaker |
|
|
|
                # No better existing match: open a new speaker slot if available
                elif len(self.active_speakers) < self.max_speakers:
                    for new_id in range(self.max_speakers):
                        if new_id not in self.active_speakers:
                            is_speaker_change = True
                            self.current_speaker = new_id
                            self.active_speakers.add(new_id)
                            break
|
|
|
|
|
if is_speaker_change: |
|
self.last_change_time = current_time |
|
|
|
|
|
self.previous_embeddings.append(embedding) |
|
if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE: |
|
self.previous_embeddings.pop(0) |
|
|
|
|
|
self.speaker_embeddings[self.current_speaker].append(embedding) |
|
self.active_speakers.add(self.current_speaker) |
|
|
|
        # Keep a bounded history per speaker and refresh that speaker's mean
        if len(self.speaker_embeddings[self.current_speaker]) > 30:
            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]

        if self.speaker_embeddings[self.current_speaker]:
            self.mean_embeddings[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )
|
|
|
return self.current_speaker, similarity |
|
|
|
def get_color_for_speaker(self, speaker_id): |
|
"""Return color for speaker ID (0 to max_speakers-1)""" |
|
if 0 <= speaker_id < len(SPEAKER_COLORS): |
|
return SPEAKER_COLORS[speaker_id] |
|
return "#FFFFFF" |
|
|
|
def get_status_info(self): |
|
"""Return status information about the speaker change detector""" |
|
speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)] |
|
|
|
return { |
|
"current_speaker": self.current_speaker, |
|
"speaker_counts": speaker_counts, |
|
"active_speakers": len(self.active_speakers), |
|
"max_speakers": self.max_speakers, |
|
"last_similarity": self.last_similarity, |
|
"threshold": self.change_threshold |
|
} |
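
# Usage sketch: feed one embedding per finished sentence (speaker_embedding is
# a hypothetical 192-dim vector, e.g. from SpeechBrainEncoder above):
#
#   detector = SpeakerChangeDetector(embedding_dim=192)
#   speaker_id, similarity = detector.add_embedding(speaker_embedding)
#   html_color = detector.get_color_for_speaker(speaker_id)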
|
|
|
|
|
class TextUpdateThread(QThread): |
|
text_update_signal = pyqtSignal(str) |
|
|
|
def __init__(self, text): |
|
super().__init__() |
|
self.text = text |
|
|
|
def run(self): |
|
self.text_update_signal.emit(self.text) |
|
|
|
|
|
class SentenceWorker(QThread): |
|
sentence_update_signal = pyqtSignal(list, list) |
|
status_signal = pyqtSignal(str) |
|
|
|
def __init__(self, queue, encoder, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS): |
|
super().__init__() |
|
self.queue = queue |
|
self.encoder = encoder |
|
self._is_running = True |
|
self.full_sentences = [] |
|
self.sentence_speakers = [] |
|
self.change_threshold = change_threshold |
|
self.max_speakers = max_speakers |
|
|
|
|
|
self.audio_processor = AudioProcessor(self.encoder) |
|
|
|
|
|
self.speaker_detector = SpeakerChangeDetector( |
|
embedding_dim=self.encoder.embedding_dim, |
|
change_threshold=self.change_threshold, |
|
max_speakers=self.max_speakers |
|
) |
|
|
|
|
|
        # The timer lives in the creating (GUI) thread, so report_status() runs
        # there and can safely emit Qt signals
        self.monitoring_timer = QTimer()
        self.monitoring_timer.timeout.connect(self.report_status)
        self.monitoring_timer.start(2000)
|
|
|
def set_change_threshold(self, threshold): |
|
"""Update change detection threshold""" |
|
self.change_threshold = threshold |
|
self.speaker_detector.set_change_threshold(threshold) |
|
|
|
def set_max_speakers(self, max_speakers): |
|
"""Update maximum number of speakers""" |
|
self.max_speakers = max_speakers |
|
self.speaker_detector.set_max_speakers(max_speakers) |
|
|
|
def run(self): |
|
"""Main worker thread loop""" |
|
while self._is_running: |
|
try: |
|
                # The payload is (sentence text, float32 audio array)
                text, audio = self.queue.get(timeout=1)
                self.process_item(text, audio)
|
except queue.Empty: |
|
continue |
|
|
|
def report_status(self): |
|
"""Report status information""" |
|
|
|
status = self.speaker_detector.get_status_info() |
|
|
|
|
|
status_text = f"Current speaker: {status['current_speaker'] + 1}\n" |
|
status_text += f"Active speakers: {status['active_speakers']} of {status['max_speakers']}\n" |
|
|
|
|
|
for i in range(status['max_speakers']): |
|
if i < len(SPEAKER_COLOR_NAMES): |
|
color_name = SPEAKER_COLOR_NAMES[i] |
|
else: |
|
color_name = f"Speaker {i+1}" |
|
status_text += f"Speaker {i+1} ({color_name}) segments: {status['speaker_counts'][i]}\n" |
|
|
|
status_text += f"Last similarity score: {status['last_similarity']:.3f}\n" |
|
status_text += f"Change threshold: {status['threshold']:.2f}\n" |
|
status_text += f"Total sentences: {len(self.full_sentences)}" |
|
|
|
|
|
self.status_signal.emit(status_text) |
|
|
|
    def process_item(self, text, audio):
        """Process a new text-audio pair"""
        # The audio arrives as float32 in [-1, 1]; convert to int16 PCM
        audio_int16 = np.int16(audio * 32767)
|
|
|
|
|
speaker_embedding = self.audio_processor.extract_embedding(audio_int16) |
|
|
|
|
|
self.full_sentences.append((text, speaker_embedding)) |
|
|
|
|
|
        # Backfill speaker 0 for any sentences that missed a speaker assignment
        while len(self.sentence_speakers) < len(self.full_sentences) - 1:
            self.sentence_speakers.append(0)
|
|
|
|
|
speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding) |
|
self.sentence_speakers.append(speaker_id) |
|
|
|
|
|
self.sentence_update_signal.emit(self.full_sentences, self.sentence_speakers) |
|
|
|
    def stop(self):
        """Stop the worker thread"""
        # run() notices the flag within ~1 s via the queue.get() timeout
        self._is_running = False
        if self.monitoring_timer.isActive():
            self.monitoring_timer.stop()
|
|
|
|
|
class RecordingThread(QThread): |
|
def __init__(self, recorder): |
|
super().__init__() |
|
self.recorder = recorder |
|
self._is_running = True |
|
|
|
|
|
        # With USE_MICROPHONE disabled, record the default speaker's output
        # through a loopback device instead of a physical microphone
        if USE_MICROPHONE:
            self.device_id = str(sc.default_microphone().name)
            self.include_loopback = False
        else:
            self.device_id = str(sc.default_speaker().name)
            self.include_loopback = True
|
|
|
def updateDevice(self, device_id, include_loopback): |
|
self.device_id = device_id |
|
self.include_loopback = include_loopback |
|
|
|
def run(self): |
|
while self._is_running: |
|
try: |
|
with sc.get_microphone(id=self.device_id, include_loopback=self.include_loopback).recorder( |
|
samplerate=SAMPLE_RATE, blocksize=BUFFER_SIZE |
|
) as mic: |
|
|
|
current_device = self.device_id |
|
current_loopback = self.include_loopback |
|
|
|
while self._is_running and current_device == self.device_id and current_loopback == self.include_loopback: |
|
|
|
audio_data = mic.record(numframes=BUFFER_SIZE) |
|
|
|
|
|
                        # Down-mix to mono if the capture device is multi-channel
                        if audio_data.shape[1] > 1 and CHANNELS == 1:
                            audio_data = audio_data[:, 0]

                        # Convert float [-1, 1] samples to int16 PCM and feed
                        # them to the RealtimeSTT recorder
                        audio_int16 = (audio_data.flatten() * 32767).astype(np.int16)
                        audio_bytes = audio_int16.tobytes()
                        self.recorder.feed_audio(audio_bytes)
|
|
|
            except Exception as e:
                print(f"Recording error: {e}")
                # Back off briefly before reopening the capture device
                time.sleep(1)
|
|
|
def stop(self): |
|
self._is_running = False |
|
|
|
|
|
class TextRetrievalThread(QThread): |
|
textRetrievedFinal = pyqtSignal(str, np.ndarray) |
|
textRetrievedLive = pyqtSignal(str) |
|
recorderStarted = pyqtSignal() |
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
def live_text_detected(self, text): |
|
self.textRetrievedLive.emit(text) |
|
|
|
def run(self): |
|
recorder_config = { |
|
'spinner': False, |
|
'use_microphone': False, |
|
'model': FINAL_TRANSCRIPTION_MODEL, |
|
'language': TRANSCRIPTION_LANGUAGE, |
|
'silero_sensitivity': SILERO_SENSITIVITY, |
|
'webrtc_sensitivity': WEBRTC_SENSITIVITY, |
|
'post_speech_silence_duration': SILENCE_THRESHS[1], |
|
'min_length_of_recording': MIN_LENGTH_OF_RECORDING, |
|
'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION, |
|
'min_gap_between_recordings': 0, |
|
'enable_realtime_transcription': True, |
|
'realtime_processing_pause': 0, |
|
'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL, |
|
'on_realtime_transcription_update': self.live_text_detected, |
|
'beam_size': FINAL_BEAM_SIZE, |
|
'beam_size_realtime': REALTIME_BEAM_SIZE, |
|
'buffer_size': BUFFER_SIZE, |
|
'sample_rate': SAMPLE_RATE, |
|
} |
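
        # Note: buffer_size and sample_rate must match RecordingThread, since
        # audio is pushed in externally through recorder.feed_audio().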
|
|
|
self.recorder = AudioToTextRecorder(**recorder_config) |
|
self.recorderStarted.emit() |
|
|
|
        def process_text(text):
            # RealtimeSTT exposes the audio of the last transcription via
            # last_transcription_bytes (a float32 array, despite the name)
            audio = self.recorder.last_transcription_bytes
            self.textRetrievedFinal.emit(text, audio)

        # text() blocks until a sentence is finalized, then invokes the callback
        while True:
            self.recorder.text(process_text)
|
|
|
|
|
class MainWindow(QMainWindow): |
|
def __init__(self): |
|
super().__init__() |
|
|
|
self.setWindowTitle("Real-time Speaker Change Detection") |
|
|
|
self.encoder = None |
|
self.initialized = False |
|
self.displayed_text = "" |
|
self.last_realtime_text = "" |
|
self.full_sentences = [] |
|
self.sentence_speakers = [] |
|
self.pending_sentences = [] |
|
self.queue = queue.Queue() |
|
self.recording_thread = None |
|
self.change_threshold = DEFAULT_CHANGE_THRESHOLD |
|
self.max_speakers = DEFAULT_MAX_SPEAKERS |
|
|
|
|
|
self.mainLayout = QHBoxLayout() |
|
|
|
|
|
self.text_edit = QTextEdit(self) |
|
self.mainLayout.addWidget(self.text_edit, 1) |
|
|
|
|
|
self.rightLayout = QVBoxLayout() |
|
self.rightLayout.setAlignment(Qt.AlignmentFlag.AlignTop) |
|
|
|
|
|
self.create_controls() |
|
|
|
|
|
self.rightContainer = QWidget() |
|
self.rightContainer.setLayout(self.rightLayout) |
|
self.mainLayout.addWidget(self.rightContainer, 0) |
|
|
|
|
|
self.centralWidget = QWidget() |
|
self.centralWidget.setLayout(self.mainLayout) |
|
self.setCentralWidget(self.centralWidget) |
|
|
|
self.setStyleSheet(""" |
|
QGroupBox { |
|
border: 1px solid #555; |
|
border-radius: 3px; |
|
margin-top: 10px; |
|
padding-top: 10px; |
|
color: #ddd; |
|
} |
|
QGroupBox::title { |
|
subcontrol-origin: margin; |
|
subcontrol-position: top center; |
|
padding: 0 5px; |
|
} |
|
QLabel { |
|
color: #ddd; |
|
} |
|
QPushButton { |
|
background: #444; |
|
color: #ddd; |
|
border: 1px solid #555; |
|
padding: 5px; |
|
margin-bottom: 10px; |
|
} |
|
QPushButton:hover { |
|
background: #555; |
|
} |
|
QTextEdit { |
|
background-color: #1e1e1e; |
|
color: #ffffff; |
|
font-family: 'Arial'; |
|
font-size: 16pt; |
|
} |
|
QSlider { |
|
height: 30px; |
|
} |
|
QSlider::groove:horizontal { |
|
height: 8px; |
|
background: #333; |
|
margin: 2px 0; |
|
} |
|
QSlider::handle:horizontal { |
|
background: #666; |
|
border: 1px solid #777; |
|
width: 18px; |
|
margin: -8px 0; |
|
border-radius: 9px; |
|
} |
|
""") |
|
|
|
def create_controls(self): |
|
|
|
self.threshold_group = QGroupBox("Speaker Change Sensitivity") |
|
threshold_layout = QVBoxLayout() |
|
|
|
self.threshold_label = QLabel(f"Change threshold: {self.change_threshold:.2f}") |
|
threshold_layout.addWidget(self.threshold_label) |
|
|
|
self.threshold_slider = QSlider(Qt.Orientation.Horizontal) |
|
self.threshold_slider.setMinimum(10) |
|
self.threshold_slider.setMaximum(95) |
|
self.threshold_slider.setValue(int(self.change_threshold * 100)) |
|
self.threshold_slider.valueChanged.connect(self.update_threshold) |
|
threshold_layout.addWidget(self.threshold_slider) |
|
|
|
        self.threshold_explanation = QLabel(
            "Higher values make speaker changes more likely to be detected. "
            "If the speakers sound similar, set the threshold above 0.5; "
            "clearly distinct voices work with lower values."
        )
|
self.threshold_explanation.setWordWrap(True) |
|
threshold_layout.addWidget(self.threshold_explanation) |
|
|
|
self.threshold_group.setLayout(threshold_layout) |
|
self.rightLayout.addWidget(self.threshold_group) |
|
|
|
|
|
self.max_speakers_group = QGroupBox("Maximum Number of Speakers") |
|
max_speakers_layout = QVBoxLayout() |
|
|
|
self.max_speakers_label = QLabel(f"Max speakers: {self.max_speakers}") |
|
max_speakers_layout.addWidget(self.max_speakers_label) |
|
|
|
self.max_speakers_spinbox = QSpinBox() |
|
self.max_speakers_spinbox.setMinimum(2) |
|
self.max_speakers_spinbox.setMaximum(ABSOLUTE_MAX_SPEAKERS) |
|
self.max_speakers_spinbox.setValue(self.max_speakers) |
|
self.max_speakers_spinbox.valueChanged.connect(self.update_max_speakers) |
|
max_speakers_layout.addWidget(self.max_speakers_spinbox) |
|
|
|
self.max_speakers_explanation = QLabel( |
|
f"You can set between 2 and {ABSOLUTE_MAX_SPEAKERS} speakers.\n" |
|
"Changes will apply immediately." |
|
) |
|
self.max_speakers_explanation.setWordWrap(True) |
|
max_speakers_layout.addWidget(self.max_speakers_explanation) |
|
|
|
self.max_speakers_group.setLayout(max_speakers_layout) |
|
self.rightLayout.addWidget(self.max_speakers_group) |
|
|
|
|
|
self.legend_group = QGroupBox("Speaker Colors") |
|
self.legend_layout = QVBoxLayout() |
|
|
|
|
|
self.speaker_labels = [] |
|
for i in range(ABSOLUTE_MAX_SPEAKERS): |
|
color = SPEAKER_COLORS[i] |
|
color_name = SPEAKER_COLOR_NAMES[i] |
|
label = QLabel(f"Speaker {i+1} ({color_name}): <span style='color:{color};'>■■■■■</span>") |
|
self.speaker_labels.append(label) |
|
if i < self.max_speakers: |
|
self.legend_layout.addWidget(label) |
|
|
|
self.legend_group.setLayout(self.legend_layout) |
|
self.rightLayout.addWidget(self.legend_group) |
|
|
|
|
|
self.status_group = QGroupBox("Status") |
|
status_layout = QVBoxLayout() |
|
|
|
self.status_label = QLabel("Status information will be displayed here.") |
|
self.status_label.setWordWrap(True) |
|
status_layout.addWidget(self.status_label) |
|
|
|
self.status_group.setLayout(status_layout) |
|
self.rightLayout.addWidget(self.status_group) |
|
|
|
|
|
self.clear_button = QPushButton("Clear Conversation") |
|
self.clear_button.clicked.connect(self.clear_state) |
|
self.clear_button.setEnabled(False) |
|
self.rightLayout.addWidget(self.clear_button) |
|
|
|
def update_threshold(self, value): |
|
"""Update speaker change detection threshold""" |
|
threshold = value / 100.0 |
|
self.change_threshold = threshold |
|
self.threshold_label.setText(f"Change threshold: {threshold:.2f}") |
|
|
|
|
|
if hasattr(self, 'worker_thread'): |
|
self.worker_thread.set_change_threshold(threshold) |
|
|
|
def update_max_speakers(self, value): |
|
"""Update maximum number of speakers""" |
|
self.max_speakers = value |
|
self.max_speakers_label.setText(f"Max speakers: {value}") |
|
|
|
|
|
self.update_speaker_labels() |
|
|
|
|
|
if hasattr(self, 'worker_thread'): |
|
self.worker_thread.set_max_speakers(value) |
|
|
|
def update_speaker_labels(self): |
|
"""Update which speaker labels are visible based on max_speakers""" |
|
|
|
for i in range(len(self.speaker_labels)): |
|
label = self.speaker_labels[i] |
|
if label.parent(): |
|
self.legend_layout.removeWidget(label) |
|
label.setParent(None) |
|
|
|
|
|
for i in range(min(self.max_speakers, len(self.speaker_labels))): |
|
self.legend_layout.addWidget(self.speaker_labels[i]) |
|
|
|
def clear_state(self): |
|
|
|
self.text_edit.clear() |
|
|
|
|
|
self.displayed_text = "" |
|
self.last_realtime_text = "" |
|
self.full_sentences = [] |
|
self.sentence_speakers = [] |
|
self.pending_sentences = [] |
|
|
|
if hasattr(self, 'worker_thread'): |
|
self.worker_thread.full_sentences = [] |
|
self.worker_thread.sentence_speakers = [] |
|
|
|
self.worker_thread.speaker_detector = SpeakerChangeDetector( |
|
embedding_dim=self.encoder.embedding_dim, |
|
change_threshold=self.change_threshold, |
|
max_speakers=self.max_speakers |
|
) |
|
|
|
|
|
self.text_edit.setHtml("<i>All content cleared. Waiting for new input...</i>") |
|
|
|
def update_status(self, status_text): |
|
self.status_label.setText(status_text) |
|
|
|
def showEvent(self, event): |
|
super().showEvent(event) |
|
if event.type() == QEvent.Type.Show: |
|
if not self.initialized: |
|
self.initialized = True |
|
self.resize(1200, 800) |
|
self.update_text("<i>Initializing application...</i>") |
|
|
|
                # Defer heavy initialization so the window paints first
                QTimer.singleShot(500, self.init)
|
|
|
def process_live_text(self, text): |
|
text = text.strip() |
|
|
|
if text: |
|
            # A sentence end is likely when two consecutive realtime updates
            # both end with a sentence delimiter
            sentence_delimiters = '.?!。'
            prob_sentence_end = (
                len(self.last_realtime_text) > 0
                and text[-1] in sentence_delimiters
                and self.last_realtime_text[-1] in sentence_delimiters
            )
|
|
|
self.last_realtime_text = text |
|
|
|
            if prob_sentence_end:
                if FAST_SENTENCE_END:
                    # Force the recorder to finalize the sentence immediately
                    self.text_retrieval_thread.recorder.stop()
                else:
                    self.text_retrieval_thread.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
            else:
                self.text_retrieval_thread.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
|
|
|
self.text_detected(text) |
|
|
|
def text_detected(self, text): |
|
try: |
|
sentences_with_style = [] |
|
for i, sentence in enumerate(self.full_sentences): |
|
sentence_text, _ = sentence |
|
if i >= len(self.sentence_speakers): |
|
color = "#FFFFFF" |
|
else: |
|
speaker_id = self.sentence_speakers[i] |
|
color = self.worker_thread.speaker_detector.get_color_for_speaker(speaker_id) |
|
|
|
sentences_with_style.append( |
|
f'<span style="color:{color};">{sentence_text}</span>') |
|
|
|
            # Sentences still awaiting diarization render in a neutral cyan
            for pending_sentence in self.pending_sentences:
                sentences_with_style.append(
                    f'<span style="color:#60FFFF;">{pending_sentence}</span>')

            new_text = (" ".join(sentences_with_style).strip() + " " + text
                        if sentences_with_style else text)
|
|
|
if new_text != self.displayed_text: |
|
self.displayed_text = new_text |
|
self.update_text(new_text) |
|
except Exception as e: |
|
print(f"Error: {e}") |
|
|
|
    def process_final(self, text, audio):
        """Queue a finalized sentence and its audio for speaker assignment"""
        text = text.strip()
        if text:
            try:
                self.pending_sentences.append(text)
                self.queue.put((text, audio))
|
except Exception as e: |
|
print(f"Error: {e}") |
|
|
|
    def capture_output_and_feed_to_recorder(self):
        # Default to capturing system output via loopback; updateDevice() can
        # switch the source later
        device_id = str(sc.default_speaker().name)
        include_loopback = True
|
|
|
self.recording_thread = RecordingThread(self.text_retrieval_thread.recorder) |
|
|
|
self.recording_thread.updateDevice(device_id, include_loopback) |
|
self.recording_thread.start() |
|
|
|
def recorder_ready(self): |
|
self.update_text("<i>Recording ready</i>") |
|
self.capture_output_and_feed_to_recorder() |
|
|
|
def init(self): |
|
self.update_text("<i>Loading ECAPA-TDNN model... Please wait.</i>") |
|
|
|
|
|
self.start_encoder() |
|
|
|
def update_loading_status(self, message): |
|
self.update_text(f"<i>{message}</i>") |
|
|
|
def start_encoder(self): |
|
|
|
self.encoder_loader_thread = EncoderLoaderThread() |
|
self.encoder_loader_thread.model_loaded.connect(self.on_model_loaded) |
|
self.encoder_loader_thread.progress_update.connect(self.update_loading_status) |
|
self.encoder_loader_thread.start() |
|
|
|
def on_model_loaded(self, encoder): |
|
|
|
self.encoder = encoder |
|
|
|
if self.encoder is None: |
|
self.update_text("<i>Failed to load ECAPA-TDNN model. Please check your configuration.</i>") |
|
return |
|
|
|
|
|
self.clear_button.setEnabled(True) |
|
self.threshold_slider.setEnabled(True) |
|
|
|
|
|
self.update_text("<i>ECAPA-TDNN model loaded. Starting recorder...</i>") |
|
|
|
self.text_retrieval_thread = TextRetrievalThread() |
|
self.text_retrieval_thread.recorderStarted.connect( |
|
self.recorder_ready) |
|
self.text_retrieval_thread.textRetrievedLive.connect( |
|
self.process_live_text) |
|
self.text_retrieval_thread.textRetrievedFinal.connect( |
|
self.process_final) |
|
self.text_retrieval_thread.start() |
|
|
|
self.worker_thread = SentenceWorker( |
|
self.queue, |
|
self.encoder, |
|
change_threshold=self.change_threshold, |
|
max_speakers=self.max_speakers |
|
) |
|
self.worker_thread.sentence_update_signal.connect( |
|
self.sentence_updated) |
|
self.worker_thread.status_signal.connect( |
|
self.update_status) |
|
self.worker_thread.start() |
|
|
|
def sentence_updated(self, full_sentences, sentence_speakers): |
|
|
self.full_sentences = full_sentences |
|
self.sentence_speakers = sentence_speakers |
|
for sentence in self.full_sentences: |
|
sentence_text, _ = sentence |
|
if sentence_text in self.pending_sentences: |
|
self.pending_sentences.remove(sentence_text) |
|
        # Re-render the transcript with the new speaker assignments
        self.text_detected("")
|
|
|
def set_text(self, text): |
|
self.update_thread = TextUpdateThread(text) |
|
self.update_thread.text_update_signal.connect(self.update_text) |
|
self.update_thread.start() |
|
|
|
def update_text(self, text): |
|
self.text_edit.setHtml(text) |
|
self.text_edit.verticalScrollBar().setValue( |
|
self.text_edit.verticalScrollBar().maximum()) |
|
|
|
|
|
def main(): |
|
app = QApplication(sys.argv) |
|
|
|
dark_stylesheet = """ |
|
QMainWindow { |
|
background-color: #323232; |
|
} |
|
QTextEdit { |
|
background-color: #1e1e1e; |
|
color: #ffffff; |
|
} |
|
""" |
|
app.setStyleSheet(dark_stylesheet) |
|
|
|
main_window = MainWindow() |
|
main_window.show() |
|
|
|
sys.exit(app.exec()) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|