# data_loading/feature_extractor.py
import torch
import logging
import numpy as np
import torch.nn.functional as F
from transformers import (
    AutoFeatureExtractor,
    AutoModel,
    AutoTokenizer,
    AutoModelForAudioClassification,
    Wav2Vec2Processor
)
from data_loading.pretrained_extractors import EmotionModel, get_model_mamba, Mamba

# DEVICE = torch.device('cuda')
DEVICE = torch.device('cpu')
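
# The extractors below all read their settings from a single config object.
# A minimal sketch of such a config (illustrative values only; the model names are the
# examples from the class docstrings and the checkpoint paths are placeholders):
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(
#       emb_device="cpu",
#       emb_normalize=True,
#       # audio
#       audio_model_name="amiriparian/ExHuBERT",
#       audio_pooling="mean",
#       max_audio_frames=0,
#       sample_rate=16000,
#       wav_length=4,
#       audio_classifier_checkpoint="best_audio_model.pt",
#       # text
#       text_model_name="jinaai/jina-embeddings-v3",
#       text_pooling="mean",
#       max_tokens=128,
#       text_classifier_checkpoint="best_text_model.pt",
#   )
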
class PretrainedAudioEmbeddingExtractor:
    """
    Extracts embeddings from audio with a pretrained model (e.g. 'amiriparian/ExHuBERT'),
    applying pooling, normalization, etc.
    """

    def __init__(self, config):
        """
        Expected fields in config:
        - audio_model_name (str)            : model name (ExHuBERT, etc.)
        - emb_device (str)                  : "cpu" or "cuda"
        - audio_pooling (str | None)        : "mean", "cls", "max", "min", "last" or None (skip pooling)
        - emb_normalize (bool)              : whether to L2-normalize the output
        - max_audio_frames (int)            : limit along the time axis (0 = no limit)
        - audio_classifier_checkpoint (str) : path to the classifier checkpoint
        """
        self.config = config
        self.device = config.emb_device
        self.model_name = config.audio_model_name
        self.pooling = config.audio_pooling  # may be None
        self.normalize_output = config.emb_normalize
        self.max_audio_frames = getattr(config, "max_audio_frames", 0)
        self.audio_classifier_checkpoint = config.audio_classifier_checkpoint

        # Initialize the processor and the audio embedder
        self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
        self.audio_embedder = EmotionModel.from_pretrained(self.model_name).to(self.device)

        # Load the classifier model
        self.classifier_model = self.load_classifier_model_from_checkpoint(self.audio_classifier_checkpoint)
    def extract(self, waveform: torch.Tensor, sample_rate=16000):
        """
        Extracts embeddings from audio data.
        :param waveform: Tensor of shape (T).
        :param sample_rate: Sampling rate (int).
        :return: Tensors: (B, num_classes) logits and (B, sequence_length, hidden_dim) features
                 (or (B, hidden_dim) after pooling).
        """
        embeddings = self.process_audio(waveform, sample_rate)
        tensor_emb = torch.tensor(embeddings, dtype=torch.float32).to(self.device)
        lengths = [tensor_emb.shape[1]]

        with torch.no_grad():
            logits, hidden = self.classifier_model(tensor_emb, lengths, with_features=True)

        # If pooling is None => return (B, seq_len, hidden_dim)
        if hidden.dim() == 3:
            if self.pooling is None:
                emb = hidden
            else:
                if self.pooling == "mean":
                    emb = hidden.mean(dim=1)
                elif self.pooling == "cls":
                    emb = hidden[:, 0, :]
                elif self.pooling == "max":
                    emb, _ = hidden.max(dim=1)
                elif self.pooling == "min":
                    emb, _ = hidden.min(dim=1)
                elif self.pooling == "last":
                    emb = hidden[:, -1, :]
                elif self.pooling == "sum":
                    emb = hidden.sum(dim=1)
                else:
                    emb = hidden.mean(dim=1)
        else:
            # Just in case the output is already (B, hidden_dim)
            emb = hidden

        if self.normalize_output and emb.dim() == 2:
            emb = F.normalize(emb, p=2, dim=1)
        return logits, emb

    def process_audio(self, signal: np.ndarray, sampling_rate: int) -> np.ndarray:
        inputs = self.processor(signal, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
        input_values = inputs["input_values"].to(self.device)
        with torch.no_grad():
            outputs = self.audio_embedder(input_values)
        embeddings = outputs
        return embeddings.detach().cpu().numpy()

    def load_classifier_model_from_checkpoint(self, checkpoint_path):
        if checkpoint_path == "best_audio_model.pt":
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            exp_params = checkpoint['exp_params']
            classifier_model = get_model_mamba(exp_params).to(self.device)
            classifier_model.load_state_dict(checkpoint['model_state_dict'])
        elif checkpoint_path == "best_audio_model_2.pt":
            model_params = {
                "input_size": 1024,
                "d_model": 256,
                "num_layers": 2,
                "num_classes": 7,
                "dropout": 0.2
            }
            classifier_model = get_model_mamba(model_params).to(self.device)
            classifier_model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        else:
            raise ValueError(f"Unknown audio classifier checkpoint: {checkpoint_path}")
        classifier_model.eval()
        return classifier_model
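
# Usage sketch (assumes a config like the one sketched near the top of this file and a
# mono 16 kHz waveform; shapes follow the extract() docstring):
#
#   extractor = PretrainedAudioEmbeddingExtractor(config)
#   logits, emb = extractor.extract(waveform, sample_rate=16000)
#   # logits: (B, num_classes); emb: (B, hidden_dim) with pooling, else (B, seq_len, hidden_dim)
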
class AudioEmbeddingExtractor:
    """
    Extracts embeddings from audio with a pretrained model (e.g. 'amiriparian/ExHuBERT'),
    applying pooling, normalization, etc.
    """

    def __init__(self, config):
        """
        Expected fields in config:
        - audio_model_name (str)     : model name (ExHuBERT, etc.)
        - emb_device (str)           : "cpu" or "cuda"
        - audio_pooling (str | None) : "mean", "cls", "max", "min", "last" or None (skip pooling)
        - emb_normalize (bool)       : whether to L2-normalize the output
        - sample_rate (int), wav_length : their product limits the length along the time axis
        """
        self.config = config
        self.device = config.emb_device
        self.model_name = config.audio_model_name
        self.pooling = config.audio_pooling  # may be None
        self.normalize_output = config.emb_normalize
        # self.max_audio_frames = getattr(config, "max_audio_frames", 0)
        self.max_audio_frames = config.sample_rate * config.wav_length

        # Try to load a feature_extractor (not every model provides one)
        try:
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_name)
            logging.info(f"[Audio] Using AutoFeatureExtractor for '{self.model_name}'")
        except Exception as e:
            self.feature_extractor = None
            logging.warning(f"[Audio] No built-in FeatureExtractor found. Model={self.model_name}. Error: {e}")

        # Load the model
        # If the model has no classification head, AutoModel is usually enough
        try:
            self.model = AutoModel.from_pretrained(
                self.model_name,
                output_hidden_states=True  # make sure last_hidden_state is available
            ).to(self.device)
            logging.info(f"[Audio] Loaded AutoModel with output_hidden_states=True: {self.model_name}")
        except Exception as e:
            logging.warning(f"[Audio] Fallback to AudioClassification model. Reason: {e}")
            self.model = AutoModelForAudioClassification.from_pretrained(
                self.model_name,
                output_hidden_states=True
            ).to(self.device)
    def extract(self, waveform_batch: torch.Tensor, sample_rate=16000):
        """
        Extracts embeddings from audio data.
        :param waveform_batch: Tensor of shape (B, T) or (B, 1, T).
        :param sample_rate: Sampling rate (int).
        :return: Tensor:
                 - if pooling != None, shape (B, hidden_dim)
                 - if pooling == None and last_hidden_state has shape (B, seq_len, hidden_dim),
                   the result is (B, seq_len, hidden_dim).
        """
        # If the input is (B, 1, T), drop the channel axis
        if waveform_batch.dim() == 3 and waveform_batch.shape[1] == 1:
            waveform_batch = waveform_batch.squeeze(1)  # -> (B, T)

        # Truncate along the time axis if needed
        if self.max_audio_frames > 0 and waveform_batch.shape[1] > self.max_audio_frames:
            waveform_batch = waveform_batch[:, :self.max_audio_frames]

        # Use the feature_extractor when available
        if self.feature_extractor is not None:
            inputs = self.feature_extractor(
                waveform_batch,
                sampling_rate=sample_rate,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_audio_frames if self.max_audio_frames > 0 else None
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():  # inference only, consistent with the other extractors
                outputs = self.model(input_values=inputs["input_values"])
        else:
            # Otherwise feed "input_values" to the model directly
            inputs = {"input_values": waveform_batch.to(self.device)}
            with torch.no_grad():
                outputs = self.model(**inputs)

        # outputs may be a BaseModelOutput (with last_hidden_state, hidden_states, etc.)
        # or a SequenceClassifierOutput (with logits) if this is a classification model
        if hasattr(outputs, "last_hidden_state"):
            # (B, seq_len, hidden_dim)
            hidden = outputs.last_hidden_state
            # logging.debug(f"[Audio] last_hidden_state shape: {hidden.shape}")
        elif hasattr(outputs, "logits"):
            # logits: (B, num_labels)
            # Pretend seq_len=1 so the pooling code below still applies
            hidden = outputs.logits.unsqueeze(1)  # (B, 1, num_labels)
            logging.debug(f"[Audio] Found logits shape: {outputs.logits.shape} => hidden={hidden.shape}")
        else:
            # Some models return a bare tensor
            hidden = outputs

        # A 2D tensor (B, hidden_dim) is already pooled
        if hidden.dim() == 2:
            emb = hidden
        elif hidden.dim() == 3:
            # (B, seq_len, hidden_dim)
            if self.pooling is None:
                # Return as is
                emb = hidden
            else:
                # Apply pooling
                if self.pooling == "mean":
                    emb = hidden.mean(dim=1)
                elif self.pooling == "cls":
                    emb = hidden[:, 0, :]  # [B, hidden_dim]
                elif self.pooling == "max":
                    emb, _ = hidden.max(dim=1)
                elif self.pooling == "min":
                    emb, _ = hidden.min(dim=1)
                elif self.pooling == "last":
                    emb = hidden[:, -1, :]
                else:
                    emb = hidden.mean(dim=1)  # fallback
        else:
            # Any other shape is unexpected
            raise ValueError(f"[Audio] Unexpected hidden shape={hidden.shape}, pooling={self.pooling}")

        if self.normalize_output and emb.dim() == 2:
            emb = F.normalize(emb, p=2, dim=1)
        return emb
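
# Usage sketch (assumes a (B, T) float waveform batch at 16 kHz and the config sketched
# near the top of this file):
#
#   extractor = AudioEmbeddingExtractor(config)
#   emb = extractor.extract(waveform_batch, sample_rate=16000)
#   # emb: (B, hidden_dim) with pooling, else (B, seq_len, hidden_dim)
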
class TextEmbeddingExtractor:
    """
    Extracts embeddings from text (e.g. with 'jinaai/jina-embeddings-v3'),
    supporting pooling (None, mean, cls, etc.), normalization and truncation.
    """

    def __init__(self, config):
        """
        Expected fields in config:
        - text_model_name (str)
        - emb_device (str)
        - text_pooling (str | None)
        - emb_normalize (bool)
        - max_tokens (int)
        """
        self.config = config
        self.device = config.emb_device
        self.model_name = config.text_model_name
        self.pooling = config.text_pooling  # may be None
        self.normalize_output = config.emb_normalize
        self.max_tokens = config.max_tokens

        # trust_remote_code=True is required for models like jina
        logging.info(f"[Text] Loading tokenizer for {self.model_name} with trust_remote_code=True")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True
        )
        logging.info(f"[Text] Loading model for {self.model_name} with trust_remote_code=True")
        self.model = AutoModel.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            output_hidden_states=True,  # we want last_hidden_state
            force_download=False
        ).to(self.device)
    def extract(self, text_list):
        """
        :param text_list: list of strings (or a single string)
        :return: tensor of shape (B, hidden_dim), or (B, seq_len, hidden_dim) if pooling is None
        """
        if isinstance(text_list, str):
            text_list = [text_list]

        inputs = self.tokenizer(
            text_list,
            padding="max_length",
            truncation=True,
            max_length=self.max_tokens,
            return_tensors="pt"
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        # For AutoModel, last_hidden_state.shape is usually (B, seq_len, hidden_dim)
        hidden = outputs.last_hidden_state
        # logging.debug(f"[Text] last_hidden_state shape: {hidden.shape}")

        # If pooling is None => return (B, seq_len, hidden_dim)
        if hidden.dim() == 3:
            if self.pooling is None:
                emb = hidden
            else:
                if self.pooling == "mean":
                    emb = hidden.mean(dim=1)
                elif self.pooling == "cls":
                    emb = hidden[:, 0, :]
                elif self.pooling == "max":
                    emb, _ = hidden.max(dim=1)
                elif self.pooling == "min":
                    emb, _ = hidden.min(dim=1)
                elif self.pooling == "last":
                    emb = hidden[:, -1, :]
                elif self.pooling == "sum":
                    emb = hidden.sum(dim=1)
                else:
                    emb = hidden.mean(dim=1)
        else:
            # Just in case the output is already (B, hidden_dim)
            emb = hidden

        if self.normalize_output and emb.dim() == 2:
            emb = F.normalize(emb, p=2, dim=1)
        return emb
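
# Usage sketch (accepts a single string or a list of strings; config as sketched near
# the top of this file):
#
#   extractor = TextEmbeddingExtractor(config)
#   emb = extractor.extract(["first utterance", "second utterance"])
#   # emb: (B, hidden_dim) with pooling, else (B, seq_len, hidden_dim)
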
class PretrainedTextEmbeddingExtractor:
    """
    Extracts embeddings from text (e.g. with 'jinaai/jina-embeddings-v3'),
    supporting pooling (None, mean, cls, etc.), normalization and truncation.
    """

    def __init__(self, config):
        """
        Expected fields in config:
        - text_model_name (str)
        - emb_device (str)
        - text_pooling (str | None)
        - emb_normalize (bool)
        - max_tokens (int)
        - text_classifier_checkpoint (str)
        """
        self.config = config
        self.device = config.emb_device
        self.model_name = config.text_model_name
        self.pooling = config.text_pooling  # may be None
        self.normalize_output = config.emb_normalize
        self.max_tokens = config.max_tokens
        self.text_classifier_checkpoint = config.text_classifier_checkpoint

        self.model = Mamba(
            num_layers=2,
            d_input=1024,
            d_model=512,
            num_classes=7,
            model_name=self.model_name,
            max_tokens=self.max_tokens,
            pooling=None
        ).to(self.device)
        checkpoint = torch.load(self.text_classifier_checkpoint, map_location=DEVICE)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
    def extract(self, text_list):
        """
        :param text_list: list of strings (or a single string)
        :return: logits and a tensor of shape (B, hidden_dim),
                 or (B, seq_len, hidden_dim) if pooling is None
        """
        if isinstance(text_list, str):
            text_list = [text_list]

        with torch.no_grad():
            logits, hidden = self.model(text_list, with_features=True)

        if hidden.dim() == 3:
            if self.pooling is None:
                emb = hidden
            else:
                if self.pooling == "mean":
                    emb = hidden.mean(dim=1)
                elif self.pooling == "cls":
                    emb = hidden[:, 0, :]
                elif self.pooling == "max":
                    emb, _ = hidden.max(dim=1)
                elif self.pooling == "min":
                    emb, _ = hidden.min(dim=1)
                elif self.pooling == "last":
                    emb = hidden[:, -1, :]
                elif self.pooling == "sum":
                    emb = hidden.sum(dim=1)
                else:
                    emb = hidden.mean(dim=1)
        else:
            # Just in case the output is already (B, hidden_dim)
            emb = hidden

        if self.normalize_output and emb.dim() == 2:
            emb = F.normalize(emb, p=2, dim=1)
        return logits, emb
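

if __name__ == "__main__":
    # Minimal smoke test (sketch): the model name is the example from the class docstring
    # and the checkpoint path is a placeholder; adjust both before running.
    from types import SimpleNamespace

    _cfg = SimpleNamespace(
        text_model_name="jinaai/jina-embeddings-v3",
        emb_device="cpu",
        text_pooling="mean",
        emb_normalize=True,
        max_tokens=128,
        text_classifier_checkpoint="best_text_model.pt",  # placeholder path
    )
    _extractor = PretrainedTextEmbeddingExtractor(_cfg)
    _logits, _emb = _extractor.extract(["A short example sentence."])
    print(_logits.shape, _emb.shape)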