BiBiER / data_loading /feature_extractor.py
farbverlauf's picture
CPU
92da7ef
# data_loading/feature_extractor.py
import torch
import logging
import numpy as np
import torch.nn.functional as F
from transformers import (
AutoFeatureExtractor,
AutoModel,
AutoTokenizer,
AutoModelForAudioClassification,
Wav2Vec2Processor
)
from data_loading.pretrained_extractors import EmotionModel, get_model_mamba, Mamba
# DEVICE = torch.device('cuda')
DEVICE = torch.device('cpu')
class PretrainedAudioEmbeddingExtractor:
"""
Извлекает эмбеддинги из аудио, используя модель (например 'amiriparian/ExHuBERT'),
с учётом pooling, нормализации и т.д.
"""
def __init__(self, config):
"""
Ожидается, что в config есть поля:
- audio_model_name (str) : название модели (ExHuBERT и т.п.)
- emb_device (str) : "cpu" или "cuda"
- audio_pooling (str | None) : "mean", "cls", "max", "min", "last" или None (пропустить пуллинг)
- emb_normalize (bool) : делать ли L2-нормализацию выхода
- max_audio_frames (int) : ограничение длины по временной оси (если 0 - не ограничивать)
"""
self.config = config
self.device = config.emb_device
self.model_name = config.audio_model_name
self.pooling = config.audio_pooling # может быть None
self.normalize_output = config.emb_normalize
self.max_audio_frames = getattr(config, "max_audio_frames", 0)
self.audio_classifier_checkpoint = config.audio_classifier_checkpoint
# Инициализируем processor и audio_embedder
self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
self.audio_embedder = EmotionModel.from_pretrained(self.model_name).to(self.device)
# Загружаем модель
self.classifier_model = self.load_classifier_model_from_checkpoint(self.audio_classifier_checkpoint)
def extract(self, waveform: torch.Tensor, sample_rate=16000):
"""
Извлекает эмбеддинги из аудиоданных.
:param waveform: Тензор формы (T).
:param sample_rate: Частота дискретизации (int).
:return: Тензоры:
вернётся (B, classes), (B, sequence_length, hidden_dim).
"""
embeddings = self.process_audio(waveform, sample_rate)
tensor_emb = torch.tensor(embeddings, dtype=torch.float32).to(self.device)
lengths = [tensor_emb.shape[1]]
with torch.no_grad():
logits, hidden = self.classifier_model(tensor_emb, lengths, with_features=True)
# Если pooling=None => вернём (B, seq_len, hidden_dim)
if hidden.dim() == 3:
if self.pooling is None:
emb = hidden
else:
if self.pooling == "mean":
emb = hidden.mean(dim=1)
elif self.pooling == "cls":
emb = hidden[:, 0, :]
elif self.pooling == "max":
emb, _ = hidden.max(dim=1)
elif self.pooling == "min":
emb, _ = hidden.min(dim=1)
elif self.pooling == "last":
emb = hidden[:, -1, :]
elif self.pooling == "sum":
emb = hidden.sum(dim=1)
else:
emb = hidden.mean(dim=1)
else:
# На всякий случай, если получилось (B, hidden_dim)
emb = hidden
if self.normalize_output and emb.dim() == 2:
emb = F.normalize(emb, p=2, dim=1)
return logits, emb
def process_audio(self, signal: np.ndarray, sampling_rate: int) -> np.ndarray:
inputs = self.processor(signal, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
input_values = inputs["input_values"].to(self.device)
with torch.no_grad():
outputs = self.audio_embedder(input_values)
embeddings = outputs
return embeddings.detach().cpu().numpy()
def load_classifier_model_from_checkpoint(self, checkpoint_path):
if checkpoint_path == "best_audio_model.pt":
checkpoint = torch.load(checkpoint_path, map_location=self.device)
exp_params = checkpoint['exp_params']
classifier_model = get_model_mamba(exp_params).to(self.device)
classifier_model.load_state_dict(checkpoint['model_state_dict'])
elif checkpoint_path == "best_audio_model_2.pt":
model_params = {
"input_size": 1024,
"d_model": 256,
"num_layers": 2,
"num_classes": 7,
"dropout": 0.2
}
classifier_model = get_model_mamba(model_params).to(self.device)
classifier_model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
classifier_model.eval()
return classifier_model
class AudioEmbeddingExtractor:
"""
Извлекает эмбеддинги из аудио, используя модель (например 'amiriparian/ExHuBERT'),
с учётом pooling, нормализации и т.д.
"""
def __init__(self, config):
"""
Ожидается, что в config есть поля:
- audio_model_name (str) : название модели (ExHuBERT и т.п.)
- emb_device (str) : "cpu" или "cuda"
- audio_pooling (str | None) : "mean", "cls", "max", "min", "last" или None (пропустить пуллинг)
- emb_normalize (bool) : делать ли L2-нормализацию выхода
- max_audio_frames (int) : ограничение длины по временной оси (если 0 - не ограничивать)
"""
self.config = config
self.device = config.emb_device
self.model_name = config.audio_model_name
self.pooling = config.audio_pooling # может быть None
self.normalize_output = config.emb_normalize
# self.max_audio_frames = getattr(config, "max_audio_frames", 0)
self.max_audio_frames = config.sample_rate * config.wav_length
# Попробуем загрузить feature_extractor (не у всех моделей доступен)
try:
self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_name)
logging.info(f"[Audio] Using AutoFeatureExtractor for '{self.model_name}'")
except Exception as e:
self.feature_extractor = None
logging.warning(f"[Audio] No built-in FeatureExtractor found. Model={self.model_name}. Error: {e}")
# Загружаем модель
# Если у модели нет head-классификации, бывает достаточно AutoModel
try:
self.model = AutoModel.from_pretrained(
self.model_name,
output_hidden_states=True # чтобы точно был last_hidden_state
).to(self.device)
logging.info(f"[Audio] Loaded AutoModel with output_hidden_states=True: {self.model_name}")
except Exception as e:
logging.warning(f"[Audio] Fallback to AudioClassification model. Reason: {e}")
self.model = AutoModelForAudioClassification.from_pretrained(
self.model_name,
output_hidden_states=True
).to(self.device)
def extract(self, waveform_batch: torch.Tensor, sample_rate=16000):
"""
Извлекает эмбеддинги из аудиоданных.
:param waveform_batch: Тензор формы (B, T) или (B, 1, T).
:param sample_rate: Частота дискретизации (int).
:return: Тензор:
- если pooling != None, будет (B, hidden_dim)
- если pooling == None и last_hidden_state имел форму (B, seq_len, hidden_dim),
вернётся (B, seq_len, hidden_dim).
"""
# Если пришло (B, 1, T), уберём ось "1"
if waveform_batch.dim() == 3 and waveform_batch.shape[1] == 1:
waveform_batch = waveform_batch.squeeze(1) # -> (B, T)
# Усечение по времени, если нужно
if self.max_audio_frames > 0 and waveform_batch.shape[1] > self.max_audio_frames:
waveform_batch = waveform_batch[:, :self.max_audio_frames]
# Если есть feature_extractor - используем
if self.feature_extractor is not None:
inputs = self.feature_extractor(
waveform_batch,
sampling_rate=sample_rate,
return_tensors="pt",
truncation=True,
max_length=self.max_audio_frames if self.max_audio_frames > 0 else None
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self.model(input_values=inputs["input_values"])
else:
# Иначе подадим напрямую "input_values" на модель
inputs = {"input_values": waveform_batch.to(self.device)}
outputs = self.model(**inputs)
# Теперь outputs может быть BaseModelOutput (с last_hidden_state, hidden_states, etc.)
# Или SequenceClassifierOutput (с logits), если это модель-классификатор
if hasattr(outputs, "last_hidden_state"):
# (B, seq_len, hidden_dim)
hidden = outputs.last_hidden_state
# logging.debug(f"[Audio] last_hidden_state shape: {hidden.shape}")
elif hasattr(outputs, "logits"):
# logits: (B, num_labels)
# Для пуллинга по "seq_len" притворимся, что seq_len=1
hidden = outputs.logits.unsqueeze(1) # (B,1,num_labels)
logging.debug(f"[Audio] Found logits shape: {outputs.logits.shape} => hidden={hidden.shape}")
else:
# Модель может сразу возвращать тензор
hidden = outputs
# Если у нас 2D-тензор (B, hidden_dim), значит всё уже спулено
if hidden.dim() == 2:
emb = hidden
elif hidden.dim() == 3:
# (B, seq_len, hidden_dim)
if self.pooling is None:
# Возвращаем как есть
emb = hidden
else:
# Выполним пуллинг
if self.pooling == "mean":
emb = hidden.mean(dim=1)
elif self.pooling == "cls":
emb = hidden[:, 0, :] # [B, hidden_dim]
elif self.pooling == "max":
emb, _ = hidden.max(dim=1)
elif self.pooling == "min":
emb, _ = hidden.min(dim=1)
elif self.pooling == "last":
emb = hidden[:, -1, :]
else:
emb = hidden.mean(dim=1) # на всякий случай fallback
else:
# На всякий: если ещё какая-то форма
raise ValueError(f"[Audio] Unexpected hidden shape={hidden.shape}, pooling={self.pooling}")
if self.normalize_output and emb.dim() == 2:
emb = F.normalize(emb, p=2, dim=1)
return emb
class TextEmbeddingExtractor:
"""
Извлекает эмбеддинги из текста (например 'jinaai/jina-embeddings-v3'),
с учётом pooling (None, mean, cls, и т.д.), нормализации и усечения.
"""
def __init__(self, config):
"""
Параметры в config:
- text_model_name (str)
- emb_device (str)
- text_pooling (str | None)
- emb_normalize (bool)
- max_tokens (int)
"""
self.config = config
self.device = config.emb_device
self.model_name = config.text_model_name
self.pooling = config.text_pooling # может быть None
self.normalize_output = config.emb_normalize
self.max_tokens = config.max_tokens
# trust_remote_code=True нужно для моделей вроде jina
logging.info(f"[Text] Loading tokenizer for {self.model_name} with trust_remote_code=True")
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True
)
logging.info(f"[Text] Loading model for {self.model_name} with trust_remote_code=True")
self.model = AutoModel.from_pretrained(
self.model_name,
trust_remote_code=True,
output_hidden_states=True, # хотим иметь last_hidden_state
force_download=False
).to(self.device)
def extract(self, text_list):
"""
:param text_list: список строк (или одна строка)
:return: тензор (B, hidden_dim) или (B, seq_len, hidden_dim), если pooling=None
"""
if isinstance(text_list, str):
text_list = [text_list]
inputs = self.tokenizer(
text_list,
padding="max_length",
truncation=True,
max_length=self.max_tokens,
return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
# Обычно у AutoModel last_hidden_state.shape = (B, seq_len, hidden_dim)
hidden = outputs.last_hidden_state
# logging.debug(f"[Text] last_hidden_state shape: {hidden.shape}")
# Если pooling=None => вернём (B, seq_len, hidden_dim)
if hidden.dim() == 3:
if self.pooling is None:
emb = hidden
else:
if self.pooling == "mean":
emb = hidden.mean(dim=1)
elif self.pooling == "cls":
emb = hidden[:, 0, :]
elif self.pooling == "max":
emb, _ = hidden.max(dim=1)
elif self.pooling == "min":
emb, _ = hidden.min(dim=1)
elif self.pooling == "last":
emb = hidden[:, -1, :]
elif self.pooling == "sum":
emb = hidden.sum(dim=1)
else:
emb = hidden.mean(dim=1)
else:
# На всякий случай, если получилось (B, hidden_dim)
emb = hidden
if self.normalize_output and emb.dim() == 2:
emb = F.normalize(emb, p=2, dim=1)
return emb
class PretrainedTextEmbeddingExtractor:
"""
Извлекает эмбеддинги из текста (например 'jinaai/jina-embeddings-v3'),
с учётом pooling (None, mean, cls, и т.д.), нормализации и усечения.
"""
def __init__(self, config):
"""
Параметры в config:
- text_model_name (str)
- emb_device (str)
- text_pooling (str | None)
- emb_normalize (bool)
- max_tokens (int)
"""
self.config = config
self.device = config.emb_device
self.model_name = config.text_model_name
self.pooling = config.text_pooling # может быть None
self.normalize_output = config.emb_normalize
self.max_tokens = config.max_tokens
self.text_classifier_checkpoint = config.text_classifier_checkpoint
self.model = Mamba(num_layers = 2, d_input = 1024, d_model = 512, num_classes=7, model_name=self.model_name, max_tokens=self.max_tokens, pooling=None).to(self.device)
checkpoint = torch.load(self.text_classifier_checkpoint, map_location=DEVICE)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.eval()
def extract(self, text_list):
"""
:param text_list: список строк (или одна строка)
:return: тензор (B, hidden_dim) или (B, seq_len, hidden_dim), если pooling=None
"""
if isinstance(text_list, str):
text_list = [text_list]
with torch.no_grad():
logits, hidden = self.model(text_list, with_features=True)
if hidden.dim() == 3:
if self.pooling is None:
emb = hidden
else:
if self.pooling == "mean":
emb = hidden.mean(dim=1)
elif self.pooling == "cls":
emb = hidden[:, 0, :]
elif self.pooling == "max":
emb, _ = hidden.max(dim=1)
elif self.pooling == "min":
emb, _ = hidden.min(dim=1)
elif self.pooling == "last":
emb = hidden[:, -1, :]
elif self.pooling == "sum":
emb = hidden.sum(dim=1)
else:
emb = hidden.mean(dim=1)
else:
# На всякий случай, если получилось (B, hidden_dim)
emb = hidden
if self.normalize_output and emb.dim() == 2:
emb = F.normalize(emb, p=2, dim=1)
return logits, emb