|
|
|
"""Helper functions for model loading and embedding generation""" |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from transformers import ( |
|
AutoTokenizer, AutoModel, |
|
RobertaTokenizer, RobertaModel, |
|
BertTokenizer, BertModel |
|
) |
|
from typing import List, Dict, Optional |
|
import gc |
|
import os |
|
|
|
def load_models(model_names: Optional[List[str]] = None) -> Dict:
|
""" |
|
Load specific embedding models with memory optimization |
|
|
|
Args: |
|
model_names: List of model names to load. If None, loads all models. |
|
|
|
Returns: |
|
Dict containing loaded models and tokenizers |
|
""" |
|
models_cache = {} |
|
|
|
|
|
if model_names is None: |
|
model_names = ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"] |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
try: |
|
|
|
if "jina" in model_names: |
|
print("Loading Jina embeddings v2 Spanish model...") |
|
jina_tokenizer = AutoTokenizer.from_pretrained( |
|
'jinaai/jina-embeddings-v2-base-es', |
|
trust_remote_code=True |
|
) |
|
jina_model = AutoModel.from_pretrained( |
|
'jinaai/jina-embeddings-v2-base-es', |
|
trust_remote_code=True, |
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 |
|
).to(device) |
|
jina_model.eval() |
|
|
|
models_cache['jina'] = { |
|
'tokenizer': jina_tokenizer, |
|
'model': jina_model, |
|
'device': device, |
|
'pooling': 'mean' |
|
} |
|
|
|
|
|
if "robertalex" in model_names: |
|
print("Loading RoBERTalex model...") |
|
robertalex_tokenizer = RobertaTokenizer.from_pretrained('PlanTL-GOB-ES/RoBERTalex') |
|
robertalex_model = RobertaModel.from_pretrained( |
|
'PlanTL-GOB-ES/RoBERTalex', |
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 |
|
).to(device) |
|
robertalex_model.eval() |
|
|
|
models_cache['robertalex'] = { |
|
'tokenizer': robertalex_tokenizer, |
|
'model': robertalex_model, |
|
'device': device, |
|
'pooling': 'cls' |
|
} |
|
|
|
|
|
if "jina-v3" in model_names: |
|
print("Loading Jina embeddings v3 model...") |
|
jina_v3_tokenizer = AutoTokenizer.from_pretrained( |
|
'jinaai/jina-embeddings-v3', |
|
trust_remote_code=True |
|
) |
|
jina_v3_model = AutoModel.from_pretrained( |
|
'jinaai/jina-embeddings-v3', |
|
trust_remote_code=True, |
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 |
|
).to(device) |
|
jina_v3_model.eval() |
|
|
|
models_cache['jina-v3'] = { |
|
'tokenizer': jina_v3_tokenizer, |
|
'model': jina_v3_model, |
|
'device': device, |
|
'pooling': 'mean' |
|
} |
|
|
|
|
|
if "legal-bert" in model_names: |
|
print("Loading Legal BERT model...") |
|
legal_bert_tokenizer = BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased') |
|
legal_bert_model = BertModel.from_pretrained( |
|
'nlpaueb/legal-bert-base-uncased', |
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 |
|
).to(device) |
|
legal_bert_model.eval() |
|
|
|
models_cache['legal-bert'] = { |
|
'tokenizer': legal_bert_tokenizer, |
|
'model': legal_bert_model, |
|
'device': device, |
|
'pooling': 'cls' |
|
} |
|
|
|
|
|
if "roberta-ca" in model_names: |
|
print("Loading Catalan RoBERTa-large model...") |
|
roberta_ca_tokenizer = AutoTokenizer.from_pretrained('projecte-aina/roberta-large-ca-v2') |
|
roberta_ca_model = AutoModel.from_pretrained( |
|
'projecte-aina/roberta-large-ca-v2', |
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 |
|
).to(device) |
|
roberta_ca_model.eval() |
|
|
|
models_cache['roberta-ca'] = { |
|
'tokenizer': roberta_ca_tokenizer, |
|
'model': roberta_ca_model, |
|
'device': device, |
|
'pooling': 'cls' |
|
} |
|
|
|
|
|
gc.collect() |
|
|
|
return models_cache |
|
|
|
except Exception as e: |
|
print(f"Error loading models: {str(e)}") |
|
raise |
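
# Example usage (illustrative sketch; loading models downloads weights from the
# Hugging Face Hub on first use):
#
#     models = load_models(["jina", "robertalex"])
#     print(models["jina"]["pooling"])   # -> "mean"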
|
|
|
def mean_pooling(model_output, attention_mask): |
|
""" |
|
Apply mean pooling to get sentence embeddings |
|
|
|
Args: |
|
model_output: Model output containing token embeddings |
|
attention_mask: Attention mask for valid tokens |
|
|
|
Returns: |
|
Pooled embeddings |
|
""" |
|
    token_embeddings = model_output[0]  # first element holds token-level embeddings: (batch, seq_len, hidden)

    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

    # Sum the embeddings of valid tokens and divide by the token count (clamped to avoid division by zero)
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
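
# Shape sketch for mean_pooling (assumed example: a batch of 2 sentences, 10 tokens
# each, from a 768-dimensional model):
#     model_output[0]  -> (2, 10, 768)   token embeddings
#     attention_mask   -> (2, 10)
#     returned tensor  -> (2, 768)       one vector per sentence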
|
|
|
def get_embeddings( |
|
texts: List[str], |
|
model_name: str, |
|
models_cache: Dict, |
|
normalize: bool = True, |
|
max_length: Optional[int] = None |
|
) -> List[List[float]]: |
|
""" |
|
Generate embeddings for texts using specified model |
|
|
|
Args: |
|
texts: List of texts to embed |
|
model_name: Name of model to use |
|
models_cache: Dictionary containing loaded models |
|
normalize: Whether to normalize embeddings |
|
max_length: Maximum sequence length |
|
|
|
Returns: |
|
List of embedding vectors |
|
""" |
|
if model_name not in models_cache: |
|
raise ValueError(f"Model {model_name} not available. Choose from: {list(models_cache.keys())}") |
|
|
|
tokenizer = models_cache[model_name]['tokenizer'] |
|
model = models_cache[model_name]['model'] |
|
device = models_cache[model_name]['device'] |
|
pooling_strategy = models_cache[model_name]['pooling'] |
|
|
|
|
|
if max_length is None: |
|
if model_name in ['jina', 'jina-v3']: |
|
max_length = 8192 |
|
else: |
|
max_length = 512 |
|
|
|
|
|
|
|
    # Heavier models (large hidden size or long context) get smaller batches to limit peak memory
    if model_name in ['jina-v3', 'roberta-ca']:
        batch_size = min(4, len(texts))
    else:
        batch_size = min(8, len(texts))
|
|
|
all_embeddings = [] |
|
|
|
for i in range(0, len(texts), batch_size): |
|
batch_texts = texts[i:i + batch_size] |
|
|
|
|
|
encoded_input = tokenizer( |
|
batch_texts, |
|
padding=True, |
|
truncation=True, |
|
max_length=max_length, |
|
return_tensors='pt' |
|
).to(device) |
|
|
|
|
|
with torch.no_grad(): |
|
model_output = model(**encoded_input) |
|
|
|
        if pooling_strategy == 'mean':
            # Average token embeddings, weighted by the attention mask
            embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        else:
            # Use the [CLS] token embedding (first position of the last hidden state)
            embeddings = model_output.last_hidden_state[:, 0, :]
|
|
|
|
|
        # L2-normalize so that dot products between embeddings equal cosine similarity
        if normalize:
            embeddings = F.normalize(embeddings, p=2, dim=1)
|
|
|
|
|
batch_embeddings = embeddings.cpu().numpy().tolist() |
|
all_embeddings.extend(batch_embeddings) |
|
|
|
return all_embeddings |
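
# Example usage (illustrative; with normalize=True the dot product of two vectors
# equals their cosine similarity):
#
#     models = load_models(["jina"])
#     vecs = get_embeddings(["contrato de arrendamiento", "lease agreement"], "jina", models)
#     similarity = sum(a * b for a, b in zip(vecs[0], vecs[1]))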
|
|
|
def cleanup_memory(): |
|
"""Force garbage collection and clear cache""" |
|
gc.collect() |
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
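
# Typical pattern (sketch): drop references to a loaded model before cleaning up, e.g.
#
#     del models["jina"]
#     cleanup_memory()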
|
|
|
def validate_input_texts(texts: List[str]) -> List[str]: |
|
""" |
|
Validate and clean input texts |
|
|
|
Args: |
|
texts: List of input texts |
|
|
|
Returns: |
|
Cleaned texts |
|
""" |
|
cleaned_texts = [] |
|
    for text in texts:
        # Collapse runs of whitespace and strip leading/trailing spaces
        text = ' '.join(text.split())

        # Keep only non-empty texts
        if text:
            cleaned_texts.append(text)
|
|
|
if not cleaned_texts: |
|
raise ValueError("No valid texts provided after cleaning") |
|
|
|
return cleaned_texts |
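
# Example (illustrative): internal whitespace is collapsed and empty strings are dropped.
#
#     validate_input_texts(["  hola   mundo ", ""])   # -> ["hola mundo"]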
|
|
|
def get_model_info(model_name: str) -> Dict: |
|
""" |
|
Get detailed information about a model |
|
|
|
Args: |
|
model_name: Model identifier |
|
|
|
Returns: |
|
Dictionary with model information |
|
""" |
|
model_info = { |
|
'jina': { |
|
'full_name': 'jinaai/jina-embeddings-v2-base-es', |
|
'dimensions': 768, |
|
'max_length': 8192, |
|
'pooling': 'mean', |
|
'languages': ['Spanish', 'English'] |
|
}, |
|
'robertalex': { |
|
'full_name': 'PlanTL-GOB-ES/RoBERTalex', |
|
'dimensions': 768, |
|
'max_length': 512, |
|
'pooling': 'cls', |
|
'languages': ['Spanish'] |
|
}, |
|
'jina-v3': { |
|
'full_name': 'jinaai/jina-embeddings-v3', |
|
'dimensions': 1024, |
|
'max_length': 8192, |
|
'pooling': 'mean', |
|
'languages': ['Multilingual'] |
|
}, |
|
'legal-bert': { |
|
'full_name': 'nlpaueb/legal-bert-base-uncased', |
|
'dimensions': 768, |
|
'max_length': 512, |
|
'pooling': 'cls', |
|
'languages': ['English'] |
|
}, |
|
'roberta-ca': { |
|
'full_name': 'projecte-aina/roberta-large-ca-v2', |
|
'dimensions': 1024, |
|
'max_length': 512, |
|
'pooling': 'cls', |
|
'languages': ['Catalan'] |
|
} |
|
} |
|
|
|
return model_info.get(model_name, {}) |
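

# Minimal end-to-end sketch of the intended workflow (illustrative only; running it
# downloads the selected model from the Hugging Face Hub).
if __name__ == "__main__":
    sample_texts = validate_input_texts([
        "El contrato establece las obligaciones de las partes.",
        "La sentencia fue recurrida ante el tribunal superior.",
    ])
    cache = load_models(["jina"])
    vectors = get_embeddings(sample_texts, "jina", cache)
    print(f"Generated {len(vectors)} embeddings of dimension {len(vectors[0])}")
    print(get_model_info("jina"))
    cleanup_memory()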