Jordi Catafal commited on
Commit
0a6cb95
·
1 Parent(s): 8c3e1fb

Add Jina v3 and Legal-BERT models - total 4 models

Browse files
Files changed (8) hide show
  1. Dockerfile +10 -1
  2. README.md +110 -25
  3. app.py +25 -7
  4. models/__init__.py +1 -0
  5. models/schemas.py +5 -5
  6. requirements.txt +2 -1
  7. utils/__init__.py +7 -0
  8. utils/helpers.py +70 -12
Dockerfile CHANGED
@@ -5,6 +5,14 @@ ENV PYTHONUNBUFFERED=1
5
  ENV TRANSFORMERS_CACHE=/app/cache
6
  ENV HF_HOME=/app/cache
7
  ENV PYTORCH_CUDA_ALLOC_CONF=garbage_collection_threshold:0.6,max_split_size_mb:128
 
 
 
 
 
 
 
 
8
 
9
  # Create non-root user
10
  RUN useradd -m -u 1000 user
@@ -18,7 +26,8 @@ WORKDIR /app
18
  # Copy requirements and install dependencies
19
  COPY --chown=user requirements.txt .
20
  RUN pip install --no-cache-dir --upgrade pip && \
21
- pip install --no-cache-dir -r requirements.txt
 
22
 
23
  # Copy application code
24
  COPY --chown=user . .
 
5
  ENV TRANSFORMERS_CACHE=/app/cache
6
  ENV HF_HOME=/app/cache
7
  ENV PYTORCH_CUDA_ALLOC_CONF=garbage_collection_threshold:0.6,max_split_size_mb:128
8
+ # Add this to handle the larger models
9
+ ENV TRANSFORMERS_OFFLINE=0
10
+ ENV HF_HUB_ENABLE_HF_TRANSFER=1
11
+
12
+ # Install system dependencies for better performance
13
+ RUN apt-get update && apt-get install -y \
14
+ build-essential \
15
+ && rm -rf /var/lib/apt/lists/*
16
 
17
  # Create non-root user
18
  RUN useradd -m -u 1000 user
 
26
  # Copy requirements and install dependencies
27
  COPY --chown=user requirements.txt .
28
  RUN pip install --no-cache-dir --upgrade pip && \
29
+ pip install --no-cache-dir -r requirements.txt && \
30
+ pip install --no-cache-dir hf_transfer
31
 
32
  # Copy application code
33
  COPY --chown=user . .
README.md CHANGED
@@ -11,9 +11,9 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
11
 
12
  --------------------------------
13
 
14
- # Spanish Embeddings API
15
 
16
- A high-performance API for generating embeddings from Spanish text using state-of-the-art models. This API provides access to two specialized models optimized for different use cases.
17
 
18
  ## 🚀 Quick Start
19
 
@@ -26,7 +26,9 @@ A high-performance API for generating embeddings from Spanish text using state-o
26
  | Model | Max Tokens | Languages | Dimensions | Best Use Case |
27
  |-------|------------|-----------|------------|---------------|
28
  | **jina** | 8,192 | Spanish, English | 768 | General purpose, long documents, cross-lingual tasks |
29
- | **robertalex** | 512 | Spanish | 768 | Legal documents, formal Spanish, domain-specific text |
 
 
30
 
31
  ## 🔗 API Endpoints
32
 
@@ -64,7 +66,7 @@ import numpy as np
64
 
65
  API_URL = "https://aurasystems-spanish-embeddings-api.hf.space"
66
 
67
- # Example 1: Basic usage
68
  response = requests.post(
69
  f"{API_URL}/embed",
70
  json={
@@ -78,13 +80,24 @@ result = response.json()
78
  embeddings = result["embeddings"]
79
  print(f"Generated {len(embeddings)} embeddings of {result['dimensions']} dimensions")
80
 
81
- # Example 2: Using with numpy for similarity
82
- embeddings_array = np.array(embeddings)
83
- similarity = np.dot(embeddings_array[0], embeddings_array[1])
84
- print(f"Cosine similarity: {similarity:.4f}")
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- # Example 3: Legal text with RoBERTalex
87
- legal_response = requests.post(
88
  f"{API_URL}/embed",
89
  json={
90
  "texts": [
@@ -95,12 +108,38 @@ legal_response = requests.post(
95
  "normalize": True
96
  }
97
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  ```
99
 
100
  ### cURL
101
 
102
  ```bash
103
- # Basic embedding generation
104
  curl -X POST "https://aurasystems-spanish-embeddings-api.hf.space/embed" \
105
  -H "Content-Type: application/json" \
106
  -d '{
@@ -109,17 +148,35 @@ curl -X POST "https://aurasystems-spanish-embeddings-api.hf.space/embed" \
109
  "normalize": true
110
  }'
111
 
112
- # With custom max length
113
  curl -X POST "https://aurasystems-spanish-embeddings-api.hf.space/embed" \
114
  -H "Content-Type: application/json" \
115
  -d '{
116
- "texts": ["Documento muy largo..."],
117
- "model": "jina",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  "normalize": true,
119
- "max_length": 2048
120
  }'
121
 
122
- # Get model information
123
  curl "https://aurasystems-spanish-embeddings-api.hf.space/models"
124
  ```
125
 
@@ -169,10 +226,16 @@ from langchain.embeddings.base import Embeddings
169
  from typing import List
170
  import requests
171
 
172
- class SpanishEmbeddings(Embeddings):
173
- """Custom LangChain embeddings class for Spanish text"""
174
 
175
- def __init__(self, model: str = "jina"):
 
 
 
 
 
 
176
  self.api_url = "https://aurasystems-spanish-embeddings-api.hf.space/embed"
177
  self.model = model
178
 
@@ -191,13 +254,35 @@ class SpanishEmbeddings(Embeddings):
191
  def embed_query(self, text: str) -> List[float]:
192
  return self.embed_documents([text])[0]
193
 
194
- # Usage with LangChain
195
- embeddings = SpanishEmbeddings(model="jina")
196
- doc_embeddings = embeddings.embed_documents([
197
- "Primer documento",
198
- "Segundo documento"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  ])
200
- query_embedding = embeddings.embed_query("consulta de búsqueda")
201
  ```
202
 
203
  ## 📋 Request/Response Formats
 
11
 
12
  --------------------------------
13
 
14
+ # Spanish & Legal Embeddings API
15
 
16
+ A high-performance API for generating embeddings from Spanish, English, and multilingual text using state-of-the-art models. This API provides access to four specialized models optimized for different use cases and languages.
17
 
18
  ## 🚀 Quick Start
19
 
 
26
  | Model | Max Tokens | Languages | Dimensions | Best Use Case |
27
  |-------|------------|-----------|------------|---------------|
28
  | **jina** | 8,192 | Spanish, English | 768 | General purpose, long documents, cross-lingual tasks |
29
+ | **robertalex** | 512 | Spanish | 768 | Spanish legal documents, formal Spanish |
30
+ | **jina-v3** | 8,192 | Multilingual (30+ languages) | 1,024 | Superior multilingual embeddings, long context |
31
+ | **legal-bert** | 512 | English | 768 | English legal documents, contracts, law texts |
32
 
33
  ## 🔗 API Endpoints
34
 
 
66
 
67
  API_URL = "https://aurasystems-spanish-embeddings-api.hf.space"
68
 
69
+ # Example 1: Basic usage with Jina v2 Spanish
70
  response = requests.post(
71
  f"{API_URL}/embed",
72
  json={
 
80
  embeddings = result["embeddings"]
81
  print(f"Generated {len(embeddings)} embeddings of {result['dimensions']} dimensions")
82
 
83
+ # Example 2: Using Jina v3 for multilingual texts
84
+ multilingual_response = requests.post(
85
+ f"{API_URL}/embed",
86
+ json={
87
+ "texts": [
88
+ "Hello world", # English
89
+ "Hola mundo", # Spanish
90
+ "Bonjour le monde", # French
91
+ "Hallo Welt" # German
92
+ ],
93
+ "model": "jina-v3",
94
+ "normalize": True
95
+ }
96
+ )
97
+ print(f"Jina v3 dimensions: {multilingual_response.json()['dimensions']}") # 1024 dims
98
 
99
+ # Example 3: Legal text with RoBERTalex (Spanish)
100
+ spanish_legal_response = requests.post(
101
  f"{API_URL}/embed",
102
  json={
103
  "texts": [
 
108
  "normalize": True
109
  }
110
  )
111
+
112
+ # Example 4: Legal text with Legal-BERT (English)
113
+ english_legal_response = requests.post(
114
+ f"{API_URL}/embed",
115
+ json={
116
+ "texts": [
117
+ "The contract shall be valid from the date of signature",
118
+ "This agreement is governed by the laws of the state"
119
+ ],
120
+ "model": "legal-bert",
121
+ "normalize": True
122
+ }
123
+ )
124
+
125
+ # Example 5: Compare similarity across models
126
+ text = "artificial intelligence and law"
127
+ models_comparison = {}
128
+
129
+ for model in ["jina", "jina-v3", "legal-bert"]:
130
+ resp = requests.post(
131
+ f"{API_URL}/embed",
132
+ json={"texts": [text], "model": model, "normalize": True}
133
+ )
134
+ models_comparison[model] = resp.json()["dimensions"]
135
+
136
+ print("Embedding dimensions by model:", models_comparison)
137
  ```
138
 
139
  ### cURL
140
 
141
  ```bash
142
+ # Basic embedding generation with Jina v2 Spanish
143
  curl -X POST "https://aurasystems-spanish-embeddings-api.hf.space/embed" \
144
  -H "Content-Type: application/json" \
145
  -d '{
 
148
  "normalize": true
149
  }'
150
 
151
+ # Using Jina v3 for multilingual embeddings
152
  curl -X POST "https://aurasystems-spanish-embeddings-api.hf.space/embed" \
153
  -H "Content-Type: application/json" \
154
  -d '{
155
+ "texts": ["Hello world", "Hola mundo", "Bonjour le monde"],
156
+ "model": "jina-v3",
157
+ "normalize": true
158
+ }'
159
+
160
+ # English legal text with Legal-BERT
161
+ curl -X POST "https://aurasystems-spanish-embeddings-api.hf.space/embed" \
162
+ -H "Content-Type: application/json" \
163
+ -d '{
164
+ "texts": ["This agreement is legally binding"],
165
+ "model": "legal-bert",
166
+ "normalize": true
167
+ }'
168
+
169
+ # Spanish legal text with RoBERTalex
170
+ curl -X POST "https://aurasystems-spanish-embeddings-api.hf.space/embed" \
171
+ -H "Content-Type: application/json" \
172
+ -d '{
173
+ "texts": ["Artículo primero de la constitución"],
174
+ "model": "robertalex",
175
  "normalize": true,
176
+ "max_length": 512
177
  }'
178
 
179
+ # Get all model information
180
  curl "https://aurasystems-spanish-embeddings-api.hf.space/models"
181
  ```
182
 
 
226
  from typing import List
227
  import requests
228
 
229
+ class MultilingualEmbeddings(Embeddings):
230
+ """Custom LangChain embeddings class for multilingual text"""
231
 
232
+ def __init__(self, model: str = "jina-v3"):
233
+ """
234
+ Initialize embeddings
235
+
236
+ Args:
237
+ model: One of "jina", "robertalex", "jina-v3", "legal-bert"
238
+ """
239
  self.api_url = "https://aurasystems-spanish-embeddings-api.hf.space/embed"
240
  self.model = model
241
 
 
254
  def embed_query(self, text: str) -> List[float]:
255
  return self.embed_documents([text])[0]
256
 
257
+ # Usage examples with different models
258
+ # Spanish embeddings
259
+ spanish_embeddings = MultilingualEmbeddings(model="jina")
260
+ spanish_docs = spanish_embeddings.embed_documents([
261
+ "Primer documento en español",
262
+ "Segundo documento en español"
263
+ ])
264
+
265
+ # Multilingual embeddings with Jina v3
266
+ multilingual_embeddings = MultilingualEmbeddings(model="jina-v3")
267
+ mixed_docs = multilingual_embeddings.embed_documents([
268
+ "English document",
269
+ "Documento en español",
270
+ "Document en français"
271
+ ])
272
+
273
+ # Legal embeddings for English
274
+ legal_embeddings = MultilingualEmbeddings(model="legal-bert")
275
+ legal_docs = legal_embeddings.embed_documents([
276
+ "This contract is governed by English law",
277
+ "The party shall indemnify and hold harmless"
278
+ ])
279
+
280
+ # Spanish legal embeddings
281
+ spanish_legal_embeddings = MultilingualEmbeddings(model="robertalex")
282
+ spanish_legal_docs = spanish_legal_embeddings.embed_documents([
283
+ "Artículo 1: De los derechos fundamentales",
284
+ "La presente ley entrará en vigor"
285
  ])
 
286
  ```
287
 
288
  ## 📋 Request/Response Formats
app.py CHANGED
@@ -9,9 +9,9 @@ from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
9
  from utils.helpers import load_models, get_embeddings, cleanup_memory
10
 
11
  app = FastAPI(
12
- title="Spanish Embedding API",
13
- description="Dual Spanish embedding models API",
14
- version="1.0.0"
15
  )
16
 
17
  # Global model cache
@@ -22,13 +22,13 @@ async def startup_event():
22
  """Load models on startup"""
23
  global models_cache
24
  models_cache = load_models()
25
- print("Models loaded successfully!")
26
 
27
  @app.get("/")
28
  async def root():
29
  return {
30
- "message": "Spanish Embedding API",
31
- "models": ["jina", "robertalex"],
32
  "status": "running",
33
  "docs": "/docs"
34
  }
@@ -88,6 +88,24 @@ async def list_models():
88
  languages=["Spanish"],
89
  model_type="legal domain",
90
  description="Spanish legal domain specialized embeddings"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  )
92
  ]
93
 
@@ -96,7 +114,7 @@ async def health_check():
96
  """Health check endpoint"""
97
  return {
98
  "status": "healthy",
99
- "models_loaded": len(models_cache) == 2,
100
  "available_models": list(models_cache.keys())
101
  }
102
 
 
9
  from utils.helpers import load_models, get_embeddings, cleanup_memory
10
 
11
  app = FastAPI(
12
+ title="Spanish & Legal Embedding API",
13
+ description="Multi-model embedding API for Spanish and Legal texts",
14
+ version="2.0.0"
15
  )
16
 
17
  # Global model cache
 
22
  """Load models on startup"""
23
  global models_cache
24
  models_cache = load_models()
25
+ print("All models loaded successfully!")
26
 
27
  @app.get("/")
28
  async def root():
29
  return {
30
+ "message": "Spanish & Legal Embedding API",
31
+ "models": ["jina", "robertalex", "jina-v3", "legal-bert"],
32
  "status": "running",
33
  "docs": "/docs"
34
  }
 
88
  languages=["Spanish"],
89
  model_type="legal domain",
90
  description="Spanish legal domain specialized embeddings"
91
+ ),
92
+ ModelInfo(
93
+ model_id="jina-v3",
94
+ name="jinaai/jina-embeddings-v3",
95
+ dimensions=1024,
96
+ max_sequence_length=8192,
97
+ languages=["Multilingual"],
98
+ model_type="multilingual",
99
+ description="Latest Jina v3 with superior multilingual performance"
100
+ ),
101
+ ModelInfo(
102
+ model_id="legal-bert",
103
+ name="nlpaueb/legal-bert-base-uncased",
104
+ dimensions=768,
105
+ max_sequence_length=512,
106
+ languages=["English"],
107
+ model_type="legal domain",
108
+ description="English legal domain BERT model"
109
  )
110
  ]
111
 
 
114
  """Health check endpoint"""
115
  return {
116
  "status": "healthy",
117
+ "models_loaded": len(models_cache) == 4,
118
  "available_models": list(models_cache.keys())
119
  }
120
 
models/__init__.py CHANGED
@@ -1,3 +1,4 @@
 
1
  # models/__init__.py
2
  """Models package for embedding API schemas and configurations"""
3
 
 
1
+
2
  # models/__init__.py
3
  """Models package for embedding API schemas and configurations"""
4
 
models/schemas.py CHANGED
@@ -11,7 +11,7 @@ class EmbeddingRequest(BaseModel):
11
  description="List of texts to embed",
12
  example=["Hola mundo", "¿Cómo estás?"]
13
  )
14
- model: Literal["jina", "robertalex"] = Field(
15
  default="jina",
16
  description="Model to use for embeddings"
17
  )
@@ -39,10 +39,10 @@ class EmbeddingRequest(BaseModel):
39
  def validate_max_length(cls, v, values):
40
  if v is not None:
41
  model = values.get('model', 'jina')
42
- if model == 'jina' and v > 8192:
43
- raise ValueError("Max length for Jina model is 8192")
44
- elif model == 'robertalex' and v > 512:
45
- raise ValueError("Max length for RoBERTalex model is 512")
46
  if v < 1:
47
  raise ValueError("Max length must be positive")
48
  return v
 
11
  description="List of texts to embed",
12
  example=["Hola mundo", "¿Cómo estás?"]
13
  )
14
+ model: Literal["jina", "robertalex", "jina-v3", "legal-bert"] = Field(
15
  default="jina",
16
  description="Model to use for embeddings"
17
  )
 
39
  def validate_max_length(cls, v, values):
40
  if v is not None:
41
  model = values.get('model', 'jina')
42
+ if model in ['jina', 'jina-v3'] and v > 8192:
43
+ raise ValueError(f"Max length for {model} model is 8192")
44
+ elif model in ['robertalex', 'legal-bert'] and v > 512:
45
+ raise ValueError(f"Max length for {model} model is 512")
46
  if v < 1:
47
  raise ValueError("Max length must be positive")
48
  return v
requirements.txt CHANGED
@@ -7,4 +7,5 @@ numpy<2.0.0
7
  scikit-learn==1.3.2
8
  pydantic==2.5.0
9
  huggingface-hub==0.19.4
10
- python-multipart==0.0.6
 
 
7
  scikit-learn==1.3.2
8
  pydantic==2.5.0
9
  huggingface-hub==0.19.4
10
+ python-multipart==0.0.6
11
+ protobuf>=3.20.0
utils/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+
2
+ # utils/__init__.py
3
+ """Utils package for helper functions"""
4
+
5
+ from .helpers import load_models, get_embeddings, cleanup_memory, validate_input_texts, get_model_info
6
+
7
+ __all__ = ['load_models', 'get_embeddings', 'cleanup_memory', 'validate_input_texts', 'get_model_info']
utils/helpers.py CHANGED
@@ -3,14 +3,18 @@
3
 
4
  import torch
5
  import torch.nn.functional as F
6
- from transformers import AutoTokenizer, AutoModel, RobertaTokenizer, RobertaModel
 
 
 
 
7
  from typing import List, Dict, Optional
8
  import gc
9
  import os
10
 
11
  def load_models() -> Dict:
12
  """
13
- Load both embedding models with memory optimization
14
 
15
  Returns:
16
  Dict containing loaded models and tokenizers
@@ -21,8 +25,8 @@ def load_models() -> Dict:
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
 
23
  try:
24
- # Load Jina model
25
- print("Loading Jina embeddings model...")
26
  jina_tokenizer = AutoTokenizer.from_pretrained(
27
  'jinaai/jina-embeddings-v2-base-es',
28
  trust_remote_code=True
@@ -43,16 +47,52 @@ def load_models() -> Dict:
43
  ).to(device)
44
  robertalex_model.eval()
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  models_cache = {
47
  'jina': {
48
  'tokenizer': jina_tokenizer,
49
  'model': jina_model,
50
- 'device': device
 
51
  },
52
  'robertalex': {
53
  'tokenizer': robertalex_tokenizer,
54
  'model': robertalex_model,
55
- 'device': device
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  }
57
  }
58
 
@@ -92,7 +132,7 @@ def get_embeddings(
92
 
93
  Args:
94
  texts: List of texts to embed
95
- model_name: Name of model to use ('jina' or 'robertalex')
96
  models_cache: Dictionary containing loaded models
97
  normalize: Whether to normalize embeddings
98
  max_length: Maximum sequence length
@@ -101,15 +141,19 @@ def get_embeddings(
101
  List of embedding vectors
102
  """
103
  if model_name not in models_cache:
104
- raise ValueError(f"Model {model_name} not available. Choose 'jina' or 'robertalex'")
105
 
106
  tokenizer = models_cache[model_name]['tokenizer']
107
  model = models_cache[model_name]['model']
108
  device = models_cache[model_name]['device']
 
109
 
110
  # Set max length based on model capabilities
111
  if max_length is None:
112
- max_length = 8192 if model_name == 'jina' else 512
 
 
 
113
 
114
  # Process in batches for memory efficiency
115
  batch_size = 8 if len(texts) > 8 else len(texts)
@@ -131,11 +175,11 @@ def get_embeddings(
131
  with torch.no_grad():
132
  model_output = model(**encoded_input)
133
 
134
- if model_name == 'jina':
135
- # Jina models require mean pooling
136
  embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
137
  else:
138
- # RoBERTalex: use [CLS] token embedding
139
  embeddings = model_output.last_hidden_state[:, 0, :]
140
 
141
  # Normalize if requested
@@ -201,6 +245,20 @@ def get_model_info(model_name: str) -> Dict:
201
  'max_length': 512,
202
  'pooling': 'cls',
203
  'languages': ['Spanish']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  }
205
  }
206
 
 
3
 
4
  import torch
5
  import torch.nn.functional as F
6
+ from transformers import (
7
+ AutoTokenizer, AutoModel,
8
+ RobertaTokenizer, RobertaModel,
9
+ BertTokenizer, BertModel
10
+ )
11
  from typing import List, Dict, Optional
12
  import gc
13
  import os
14
 
15
  def load_models() -> Dict:
16
  """
17
+ Load all embedding models with memory optimization
18
 
19
  Returns:
20
  Dict containing loaded models and tokenizers
 
25
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
 
27
  try:
28
+ # Load Jina v2 Spanish model
29
+ print("Loading Jina embeddings v2 Spanish model...")
30
  jina_tokenizer = AutoTokenizer.from_pretrained(
31
  'jinaai/jina-embeddings-v2-base-es',
32
  trust_remote_code=True
 
47
  ).to(device)
48
  robertalex_model.eval()
49
 
50
+ # Load Jina v3 model
51
+ print("Loading Jina embeddings v3 model...")
52
+ jina_v3_tokenizer = AutoTokenizer.from_pretrained(
53
+ 'jinaai/jina-embeddings-v3',
54
+ trust_remote_code=True
55
+ )
56
+ jina_v3_model = AutoModel.from_pretrained(
57
+ 'jinaai/jina-embeddings-v3',
58
+ trust_remote_code=True,
59
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
60
+ ).to(device)
61
+ jina_v3_model.eval()
62
+
63
+ # Load Legal BERT model
64
+ print("Loading Legal BERT model...")
65
+ legal_bert_tokenizer = BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
66
+ legal_bert_model = BertModel.from_pretrained(
67
+ 'nlpaueb/legal-bert-base-uncased',
68
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
69
+ ).to(device)
70
+ legal_bert_model.eval()
71
+
72
  models_cache = {
73
  'jina': {
74
  'tokenizer': jina_tokenizer,
75
  'model': jina_model,
76
+ 'device': device,
77
+ 'pooling': 'mean'
78
  },
79
  'robertalex': {
80
  'tokenizer': robertalex_tokenizer,
81
  'model': robertalex_model,
82
+ 'device': device,
83
+ 'pooling': 'cls'
84
+ },
85
+ 'jina-v3': {
86
+ 'tokenizer': jina_v3_tokenizer,
87
+ 'model': jina_v3_model,
88
+ 'device': device,
89
+ 'pooling': 'mean'
90
+ },
91
+ 'legal-bert': {
92
+ 'tokenizer': legal_bert_tokenizer,
93
+ 'model': legal_bert_model,
94
+ 'device': device,
95
+ 'pooling': 'cls'
96
  }
97
  }
98
 
 
132
 
133
  Args:
134
  texts: List of texts to embed
135
+ model_name: Name of model to use
136
  models_cache: Dictionary containing loaded models
137
  normalize: Whether to normalize embeddings
138
  max_length: Maximum sequence length
 
141
  List of embedding vectors
142
  """
143
  if model_name not in models_cache:
144
+ raise ValueError(f"Model {model_name} not available. Choose from: {list(models_cache.keys())}")
145
 
146
  tokenizer = models_cache[model_name]['tokenizer']
147
  model = models_cache[model_name]['model']
148
  device = models_cache[model_name]['device']
149
+ pooling_strategy = models_cache[model_name]['pooling']
150
 
151
  # Set max length based on model capabilities
152
  if max_length is None:
153
+ if model_name in ['jina', 'jina-v3']:
154
+ max_length = 8192
155
+ else: # robertalex, legal-bert
156
+ max_length = 512
157
 
158
  # Process in batches for memory efficiency
159
  batch_size = 8 if len(texts) > 8 else len(texts)
 
175
  with torch.no_grad():
176
  model_output = model(**encoded_input)
177
 
178
+ if pooling_strategy == 'mean':
179
+ # Mean pooling for Jina models
180
  embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
181
  else:
182
+ # CLS token for BERT-based models
183
  embeddings = model_output.last_hidden_state[:, 0, :]
184
 
185
  # Normalize if requested
 
245
  'max_length': 512,
246
  'pooling': 'cls',
247
  'languages': ['Spanish']
248
+ },
249
+ 'jina-v3': {
250
+ 'full_name': 'jinaai/jina-embeddings-v3',
251
+ 'dimensions': 1024,
252
+ 'max_length': 8192,
253
+ 'pooling': 'mean',
254
+ 'languages': ['Multilingual']
255
+ },
256
+ 'legal-bert': {
257
+ 'full_name': 'nlpaueb/legal-bert-base-uncased',
258
+ 'dimensions': 768,
259
+ 'max_length': 512,
260
+ 'pooling': 'cls',
261
+ 'languages': ['English']
262
  }
263
  }
264