# utils/helpers.py
"""Helper functions for model loading and embedding generation."""

import gc
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModel,
    RobertaTokenizer,
    RobertaModel,
    BertTokenizer,
    BertModel
)


def load_models(model_names: Optional[List[str]] = None) -> Dict:
    """
    Load specific embedding models with memory optimization.

    Args:
        model_names: List of model names to load. If None, loads all models.

    Returns:
        Dict containing loaded models and tokenizers.
    """
    models_cache = {}

    # Default to all models if none specified
    if model_names is None:
        model_names = ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"]

    # Set device; use half precision on GPU to halve memory usage
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    try:
        # Load Jina v2 Spanish model
        if "jina" in model_names:
            print("Loading Jina embeddings v2 Spanish model...")
            jina_tokenizer = AutoTokenizer.from_pretrained(
                'jinaai/jina-embeddings-v2-base-es',
                trust_remote_code=True
            )
            jina_model = AutoModel.from_pretrained(
                'jinaai/jina-embeddings-v2-base-es',
                trust_remote_code=True,
                torch_dtype=dtype
            ).to(device)
            jina_model.eval()
            models_cache['jina'] = {
                'tokenizer': jina_tokenizer,
                'model': jina_model,
                'device': device,
                'pooling': 'mean'
            }

        # Load RoBERTalex model
        if "robertalex" in model_names:
            print("Loading RoBERTalex model...")
            robertalex_tokenizer = RobertaTokenizer.from_pretrained('PlanTL-GOB-ES/RoBERTalex')
            robertalex_model = RobertaModel.from_pretrained(
                'PlanTL-GOB-ES/RoBERTalex',
                torch_dtype=dtype
            ).to(device)
            robertalex_model.eval()
            models_cache['robertalex'] = {
                'tokenizer': robertalex_tokenizer,
                'model': robertalex_model,
                'device': device,
                'pooling': 'cls'
            }

        # Load Jina v3 model
        if "jina-v3" in model_names:
            print("Loading Jina embeddings v3 model...")
            jina_v3_tokenizer = AutoTokenizer.from_pretrained(
                'jinaai/jina-embeddings-v3',
                trust_remote_code=True
            )
            jina_v3_model = AutoModel.from_pretrained(
                'jinaai/jina-embeddings-v3',
                trust_remote_code=True,
                torch_dtype=dtype
            ).to(device)
            jina_v3_model.eval()
            models_cache['jina-v3'] = {
                'tokenizer': jina_v3_tokenizer,
                'model': jina_v3_model,
                'device': device,
                'pooling': 'mean'
            }

        # Load Legal BERT model
        if "legal-bert" in model_names:
            print("Loading Legal BERT model...")
            legal_bert_tokenizer = BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
            legal_bert_model = BertModel.from_pretrained(
                'nlpaueb/legal-bert-base-uncased',
                torch_dtype=dtype
            ).to(device)
            legal_bert_model.eval()
            models_cache['legal-bert'] = {
                'tokenizer': legal_bert_tokenizer,
                'model': legal_bert_model,
                'device': device,
                'pooling': 'cls'
            }

        # Load Catalan RoBERTa model
        if "roberta-ca" in model_names:
            print("Loading Catalan RoBERTa-large model...")
            roberta_ca_tokenizer = AutoTokenizer.from_pretrained('projecte-aina/roberta-large-ca-v2')
            roberta_ca_model = AutoModel.from_pretrained(
                'projecte-aina/roberta-large-ca-v2',
                torch_dtype=dtype
            ).to(device)
            roberta_ca_model.eval()
            models_cache['roberta-ca'] = {
                'tokenizer': roberta_ca_tokenizer,
                'model': roberta_ca_model,
                'device': device,
                'pooling': 'cls'
            }

        # Force garbage collection after loading
        gc.collect()

        return models_cache

    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise
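
# Usage sketch (assumption: the Hugging Face checkpoints above are reachable;
# the first call downloads weights). Loading a subset keeps memory low:
#
#     models = load_models(["jina", "robertalex"])
#     print(list(models.keys()))  # ['jina', 'robertalex']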


def mean_pooling(model_output, attention_mask):
    """
    Apply mean pooling to get sentence embeddings.

    Args:
        model_output: Model output containing token embeddings.
        attention_mask: Attention mask for valid tokens.

    Returns:
        Pooled embeddings.
    """
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


def get_embeddings(
    texts: List[str],
    model_name: str,
    models_cache: Dict,
    normalize: bool = True,
    max_length: Optional[int] = None
) -> List[List[float]]:
    """
    Generate embeddings for texts using the specified model.

    Args:
        texts: List of texts to embed.
        model_name: Name of the model to use.
        models_cache: Dictionary containing loaded models.
        normalize: Whether to L2-normalize embeddings.
        max_length: Maximum sequence length.

    Returns:
        List of embedding vectors.
    """
    if model_name not in models_cache:
        raise ValueError(
            f"Model {model_name} not available. Choose from: {list(models_cache.keys())}"
        )

    # Guard against an empty input list (would otherwise yield batch_size 0)
    if not texts:
        return []

    tokenizer = models_cache[model_name]['tokenizer']
    model = models_cache[model_name]['model']
    device = models_cache[model_name]['device']
    pooling_strategy = models_cache[model_name]['pooling']

    # Set max length based on model capabilities
    if max_length is None:
        if model_name in ['jina', 'jina-v3']:
            max_length = 8192
        else:  # robertalex, legal-bert, roberta-ca
            max_length = 512

    # Process in batches for memory efficiency;
    # use smaller batches for the larger models
    if model_name in ['jina-v3', 'roberta-ca']:
        batch_size = min(4, len(texts))
    else:
        batch_size = min(8, len(texts))

    all_embeddings = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        # Tokenize inputs
        encoded_input = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        ).to(device)

        # Generate embeddings
        with torch.no_grad():
            model_output = model(**encoded_input)

            if pooling_strategy == 'mean':
                # Mean pooling for Jina models
                embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
            else:
                # CLS token for BERT-based models
                embeddings = model_output.last_hidden_state[:, 0, :]

            # Normalize if requested
            if normalize:
                embeddings = F.normalize(embeddings, p=2, dim=1)

        # Move to CPU and convert to plain Python lists
        batch_embeddings = embeddings.cpu().numpy().tolist()
        all_embeddings.extend(batch_embeddings)

    return all_embeddings


def cleanup_memory():
    """Force garbage collection and clear the CUDA cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def validate_input_texts(texts: List[str]) -> List[str]:
    """
    Validate and clean input texts.

    Args:
        texts: List of input texts.

    Returns:
        Cleaned texts.
    """
    cleaned_texts = []
    for text in texts:
        # Collapse runs of whitespace
        text = ' '.join(text.split())
        # Skip empty texts
        if text:
            cleaned_texts.append(text)

    if not cleaned_texts:
        raise ValueError("No valid texts provided after cleaning")

    return cleaned_texts
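
# Similarity sketch (assumes a `models` dict returned by load_models above).
# Because get_embeddings L2-normalizes by default, cosine similarity reduces
# to a plain dot product:
#
#     texts = validate_input_texts(["Primer contrato.", "Segundo contrato."])
#     vecs = get_embeddings(texts, "jina", models)
#     cos_sim = sum(a * b for a, b in zip(vecs[0], vecs[1]))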


def get_model_info(model_name: str) -> Dict:
    """
    Get detailed information about a model.

    Args:
        model_name: Model identifier.

    Returns:
        Dictionary with model information.
    """
    model_info = {
        'jina': {
            'full_name': 'jinaai/jina-embeddings-v2-base-es',
            'dimensions': 768,
            'max_length': 8192,
            'pooling': 'mean',
            'languages': ['Spanish', 'English']
        },
        'robertalex': {
            'full_name': 'PlanTL-GOB-ES/RoBERTalex',
            'dimensions': 768,
            'max_length': 512,
            'pooling': 'cls',
            'languages': ['Spanish']
        },
        'jina-v3': {
            'full_name': 'jinaai/jina-embeddings-v3',
            'dimensions': 1024,
            'max_length': 8192,
            'pooling': 'mean',
            'languages': ['Multilingual']
        },
        'legal-bert': {
            'full_name': 'nlpaueb/legal-bert-base-uncased',
            'dimensions': 768,
            'max_length': 512,
            'pooling': 'cls',
            'languages': ['English']
        },
        'roberta-ca': {
            'full_name': 'projecte-aina/roberta-large-ca-v2',
            'dimensions': 1024,
            'max_length': 512,
            'pooling': 'cls',
            'languages': ['Catalan']
        }
    }
    return model_info.get(model_name, {})
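
# Minimal smoke test, kept as a sketch: running the module directly loads one
# model, embeds two short strings, and frees GPU memory. Assumes network
# access to the Hugging Face hub on the first run.
if __name__ == "__main__":
    cache = load_models(["robertalex"])
    sample = validate_input_texts(["Hola mundo legal.", "  Texto   de  prueba "])
    vectors = get_embeddings(sample, "robertalex", cache)
    print(f"{len(vectors)} embeddings of dimension {len(vectors[0])}")
    cleanup_memory()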