Jordi Catafal
trying hybrid approach
5861022
# utils/helpers.py
"""Helper functions for model loading and embedding generation"""
import torch
import torch.nn.functional as F
from transformers import (
AutoTokenizer, AutoModel,
RobertaTokenizer, RobertaModel,
BertTokenizer, BertModel
)
from typing import List, Dict, Optional
import gc
import os
# Static description of every supported model, listed in the order in which
# the models are loaded. Fields per row: cache key, Hugging Face repo id,
# pooling strategy recorded in the cache entry, whether the repo is loaded
# with trust_remote_code=True, and the progress message printed before loading.
_MODEL_SPECS = (
    ("jina", "jinaai/jina-embeddings-v2-base-es", "mean", True,
     "Loading Jina embeddings v2 Spanish model..."),
    ("robertalex", "PlanTL-GOB-ES/RoBERTalex", "cls", False,
     "Loading RoBERTalex model..."),
    ("jina-v3", "jinaai/jina-embeddings-v3", "mean", True,
     "Loading Jina embeddings v3 model..."),
    ("legal-bert", "nlpaueb/legal-bert-base-uncased", "cls", False,
     "Loading Legal BERT model..."),
    ("roberta-ca", "projecte-aina/roberta-large-ca-v2", "cls", False,
     "Loading Catalan RoBERTa-large model..."),
)


def _loader_classes(name):
    """Return the (tokenizer class, model class) pair used to load ``name``."""
    # Resolved at call time so the transformers classes are only looked up
    # when a model is actually being loaded.
    if name == "robertalex":
        return RobertaTokenizer, RobertaModel
    if name == "legal-bert":
        return BertTokenizer, BertModel
    return AutoTokenizer, AutoModel


def _load_model_entry(name, repo_id, pooling, trust_remote_code, device, dtype):
    """Load one tokenizer/model pair and return its cache entry."""
    tokenizer_cls, model_cls = _loader_classes(name)
    # Only the repos flagged in _MODEL_SPECS receive trust_remote_code; the
    # others are loaded without that keyword at all.
    extra = {"trust_remote_code": True} if trust_remote_code else {}
    tokenizer = tokenizer_cls.from_pretrained(repo_id, **extra)
    model = model_cls.from_pretrained(repo_id, torch_dtype=dtype, **extra).to(device)
    model.eval()
    return {
        'tokenizer': tokenizer,
        'model': model,
        'device': device,
        'pooling': pooling
    }


def load_models(model_names: Optional[List[str]] = None) -> Dict:
    """
    Load specific embedding models with memory optimization
    Args:
        model_names: List of model names to load. If None, loads all models.
            Names that are not supported are reported with a printed warning
            and otherwise ignored.
    Returns:
        Dict mapping each loaded model name to a dict with the keys
        'tokenizer', 'model', 'device' and 'pooling'. Entries are inserted
        in the fixed order of the supported-model table, regardless of the
        order of ``model_names``.
    Raises:
        Exception: any error raised while loading is printed and re-raised.
    """
    models_cache = {}
    supported = tuple(spec[0] for spec in _MODEL_SPECS)

    # Default to all models if none specified
    if model_names is None:
        model_names = list(supported)
    else:
        # A misspelled name would otherwise just be missing from the result
        # with no indication of why.
        unknown = [name for name in model_names if name not in supported]
        if unknown:
            print(f"Warning: ignoring unknown model name(s): {unknown}")

    # Half precision is requested only when a CUDA device is available
    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_available else "cpu")
    dtype = torch.float16 if cuda_available else torch.float32

    try:
        for name, repo_id, pooling, trust_remote_code, message in _MODEL_SPECS:
            if name not in model_names:
                continue
            print(message)
            models_cache[name] = _load_model_entry(
                name, repo_id, pooling, trust_remote_code, device, dtype
            )
        # Force garbage collection after loading
        gc.collect()
        return models_cache
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise
def mean_pooling(model_output, attention_mask):
    """
    Average token embeddings over the positions selected by the attention mask
    Args:
        model_output: Model output; its first element (index 0) is used as
            the token embeddings
        attention_mask: Mask that is broadcast over the embedding dimension
            and used to weight each token position
    Returns:
        Masked sum of the token embeddings along dimension 1 divided by the
        masked token count (clamped to at least 1e-9 to avoid dividing by zero)
    """
    hidden_states = model_output[0]
    # Broadcast the mask to the embeddings' shape and make it floating point
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_sum = (hidden_states * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return masked_sum / token_counts
def get_embeddings(
    texts: List[str],
    model_name: str,
    models_cache: Dict,
    normalize: bool = True,
    max_length: Optional[int] = None
) -> List[List[float]]:
    """
    Generate embeddings for texts using specified model
    Args:
        texts: List of texts to embed. An empty list yields an empty result.
        model_name: Name of model to use; must be a key of ``models_cache``
        models_cache: Dictionary containing loaded models. The entry for
            ``model_name`` must provide 'tokenizer', 'model', 'device' and
            'pooling'.
        normalize: Whether to L2-normalize each embedding
        max_length: Maximum sequence length passed to the tokenizer. If None,
            8192 is used for 'jina' and 'jina-v3' and 512 for every other model.
    Returns:
        List of embedding vectors, one per input text, in input order
    Raises:
        ValueError: if ``model_name`` is not present in ``models_cache``
    """
    if model_name not in models_cache:
        raise ValueError(f"Model {model_name} not available. Choose from: {list(models_cache.keys())}")

    # Nothing to embed. Without this guard the batch size below would be 0
    # and range() would fail with an unrelated "arg 3 must not be zero" error.
    if len(texts) == 0:
        return []

    entry = models_cache[model_name]
    tokenizer = entry['tokenizer']
    model = entry['model']
    device = entry['device']
    pooling_strategy = entry['pooling']

    # Set max length based on model capabilities
    if max_length is None:
        max_length = 8192 if model_name in ('jina', 'jina-v3') else 512

    # Process in batches for memory efficiency; the large models get a
    # smaller batch.
    batch_cap = 4 if model_name in ('jina-v3', 'roberta-ca') else 8
    batch_size = min(batch_cap, len(texts))

    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]

        # Tokenize inputs
        encoded_input = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        ).to(device)

        # Generate embeddings
        with torch.no_grad():
            model_output = model(**encoded_input)
            if pooling_strategy == 'mean':
                # Mask-weighted average over token positions
                embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
            else:
                # Hidden state of the first token of each sequence
                embeddings = model_output.last_hidden_state[:, 0, :]

            # Normalize if requested
            if normalize:
                embeddings = F.normalize(embeddings, p=2, dim=1)

        # Convert to CPU and plain Python lists
        all_embeddings.extend(embeddings.cpu().numpy().tolist())

    return all_embeddings
def cleanup_memory():
    """Run the garbage collector, then empty the CUDA cache when CUDA is available"""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def validate_input_texts(texts: List[str]) -> List[str]:
    """
    Collapse whitespace in each text and drop the ones that end up empty
    Args:
        texts: List of input texts
    Returns:
        The texts with every run of whitespace replaced by a single space
        and leading/trailing whitespace removed, in input order, without
        the entries that are empty after cleaning
    Raises:
        ValueError: if no text remains after cleaning
    """
    # split() with no argument splits on any whitespace run, so re-joining
    # with one space both trims and collapses.
    collapsed = (' '.join(raw.split()) for raw in texts)
    cleaned_texts = [candidate for candidate in collapsed if candidate]
    if cleaned_texts:
        return cleaned_texts
    raise ValueError("No valid texts provided after cleaning")
def get_model_info(model_name: str) -> Dict:
    """
    Look up the static description of a supported model
    Args:
        model_name: Model identifier
    Returns:
        Dictionary with the keys 'full_name', 'dimensions', 'max_length',
        'pooling' and 'languages', or an empty dictionary when the
        identifier is not one of the listed models
    """
    def describe(full_name, dimensions, max_length, pooling, languages):
        # Assemble one record with the keys in a fixed order
        return {
            'full_name': full_name,
            'dimensions': dimensions,
            'max_length': max_length,
            'pooling': pooling,
            'languages': languages
        }

    # Built on every call, so each caller receives its own dictionary
    catalog = {
        'jina': describe(
            'jinaai/jina-embeddings-v2-base-es', 768, 8192, 'mean',
            ['Spanish', 'English']
        ),
        'robertalex': describe(
            'PlanTL-GOB-ES/RoBERTalex', 768, 512, 'cls',
            ['Spanish']
        ),
        'jina-v3': describe(
            'jinaai/jina-embeddings-v3', 1024, 8192, 'mean',
            ['Multilingual']
        ),
        'legal-bert': describe(
            'nlpaueb/legal-bert-base-uncased', 768, 512, 'cls',
            ['English']
        ),
        'roberta-ca': describe(
            'projecte-aina/roberta-large-ca-v2', 1024, 512, 'cls',
            ['Catalan']
        )
    }

    try:
        return catalog[model_name]
    except KeyError:
        return {}