Jordi Catafal commited on
Commit
c3aef13
·
1 Parent(s): f26f739

initial deployment

Browse files
embeddings_api/Dockerfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9-slim
2
+
3
+ # Set environment variables
4
+ ENV PYTHONUNBUFFERED=1
5
+ ENV TRANSFORMERS_CACHE=/app/cache
6
+ ENV HF_HOME=/app/cache
7
+ ENV PYTORCH_CUDA_ALLOC_CONF=garbage_collection_threshold:0.6,max_split_size_mb:128
8
+
9
+ # Create non-root user
10
+ RUN useradd -m -u 1000 user
11
+ USER user
12
+ ENV HOME=/home/user \
13
+ PATH=/home/user/.local/bin:$PATH
14
+
15
+ # Set working directory
16
+ WORKDIR /app
17
+
18
+ # Copy requirements and install dependencies
19
+ COPY --chown=user requirements.txt .
20
+ RUN pip install --no-cache-dir --upgrade pip && \
21
+ pip install --no-cache-dir -r requirements.txt
22
+
23
+ # Copy application code
24
+ COPY --chown=user . .
25
+
26
+ # Create cache directory
27
+ RUN mkdir -p /app/cache
28
+
29
+ # Expose port
30
+ EXPOSE 7860
31
+
32
+ # Run the application
33
+ CMD ["python", "app.py"]
embeddings_api/app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from typing import List
3
+ import torch
4
+ import uvicorn
5
+ import gc
6
+ import os
7
+
8
+ from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
9
+ from utils.helpers import load_models, get_embeddings, cleanup_memory
10
+
11
+ app = FastAPI(
12
+ title="Spanish Embedding API",
13
+ description="Dual Spanish embedding models API",
14
+ version="1.0.0"
15
+ )
16
+
17
+ # Global model cache
18
+ models_cache = {}
19
+
20
+ @app.on_event("startup")
21
+ async def startup_event():
22
+ """Load models on startup"""
23
+ global models_cache
24
+ models_cache = load_models()
25
+ print("Models loaded successfully!")
26
+
27
+ @app.get("/")
28
+ async def root():
29
+ return {
30
+ "message": "Spanish Embedding API",
31
+ "models": ["jina", "robertalex"],
32
+ "status": "running",
33
+ "docs": "/docs"
34
+ }
35
+
36
+ @app.post("/embed", response_model=EmbeddingResponse)
37
+ async def create_embeddings(request: EmbeddingRequest):
38
+ """Generate embeddings for input texts"""
39
+ try:
40
+ if not request.texts:
41
+ raise HTTPException(status_code=400, detail="No texts provided")
42
+
43
+ if len(request.texts) > 50: # Rate limiting
44
+ raise HTTPException(status_code=400, detail="Maximum 50 texts per request")
45
+
46
+ embeddings = get_embeddings(
47
+ request.texts,
48
+ request.model,
49
+ models_cache,
50
+ request.normalize,
51
+ request.max_length
52
+ )
53
+
54
+ # Cleanup memory after large batches
55
+ if len(request.texts) > 20:
56
+ cleanup_memory()
57
+
58
+ return EmbeddingResponse(
59
+ embeddings=embeddings,
60
+ model_used=request.model,
61
+ dimensions=len(embeddings[0]) if embeddings else 0,
62
+ num_texts=len(request.texts)
63
+ )
64
+
65
+ except ValueError as e:
66
+ raise HTTPException(status_code=400, detail=str(e))
67
+ except Exception as e:
68
+ raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
69
+
70
+ @app.get("/models", response_model=List[ModelInfo])
71
+ async def list_models():
72
+ """List available models and their specifications"""
73
+ return [
74
+ ModelInfo(
75
+ model_id="jina",
76
+ name="jinaai/jina-embeddings-v2-base-es",
77
+ dimensions=768,
78
+ max_sequence_length=8192,
79
+ languages=["Spanish", "English"],
80
+ model_type="bilingual",
81
+ description="Bilingual Spanish-English embeddings with long context support"
82
+ ),
83
+ ModelInfo(
84
+ model_id="robertalex",
85
+ name="PlanTL-GOB-ES/RoBERTalex",
86
+ dimensions=768,
87
+ max_sequence_length=512,
88
+ languages=["Spanish"],
89
+ model_type="legal domain",
90
+ description="Spanish legal domain specialized embeddings"
91
+ )
92
+ ]
93
+
94
+ @app.get("/health")
95
+ async def health_check():
96
+ """Health check endpoint"""
97
+ return {
98
+ "status": "healthy",
99
+ "models_loaded": len(models_cache) == 2,
100
+ "available_models": list(models_cache.keys())
101
+ }
102
+
103
+ if __name__ == "__main__":
104
+ # Set multi-threading for CPU
105
+ torch.set_num_threads(8)
106
+ torch.set_num_interop_threads(1)
107
+
108
+ uvicorn.run(app, host="0.0.0.0", port=7860)
embeddings_api/models/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # models/__init__.py
2
+ """Models package for embedding API schemas and configurations"""
3
+
4
+ from .schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
5
+
6
+ __all__ = ['EmbeddingRequest', 'EmbeddingResponse', 'ModelInfo']
embeddings_api/models/schemas.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/schemas.py
2
+ """Pydantic models for request/response validation"""
3
+
4
+ from pydantic import BaseModel, Field, validator
5
+ from typing import List, Optional, Literal
6
+
7
+ class EmbeddingRequest(BaseModel):
8
+ """Request model for embedding generation"""
9
+ texts: List[str] = Field(
10
+ ...,
11
+ description="List of texts to embed",
12
+ example=["Hola mundo", "¿Cómo estás?"]
13
+ )
14
+ model: Literal["jina", "robertalex"] = Field(
15
+ default="jina",
16
+ description="Model to use for embeddings"
17
+ )
18
+ normalize: bool = Field(
19
+ default=True,
20
+ description="Whether to normalize embeddings to unit length"
21
+ )
22
+ max_length: Optional[int] = Field(
23
+ default=None,
24
+ description="Maximum sequence length (uses model default if not specified)"
25
+ )
26
+
27
+ @validator('texts')
28
+ def validate_texts(cls, v):
29
+ if not v:
30
+ raise ValueError("At least one text must be provided")
31
+ if len(v) > 50:
32
+ raise ValueError("Maximum 50 texts per request")
33
+ # Check for empty strings
34
+ if any(not text.strip() for text in v):
35
+ raise ValueError("Empty texts are not allowed")
36
+ return v
37
+
38
+ @validator('max_length')
39
+ def validate_max_length(cls, v, values):
40
+ if v is not None:
41
+ model = values.get('model', 'jina')
42
+ if model == 'jina' and v > 8192:
43
+ raise ValueError("Max length for Jina model is 8192")
44
+ elif model == 'robertalex' and v > 512:
45
+ raise ValueError("Max length for RoBERTalex model is 512")
46
+ if v < 1:
47
+ raise ValueError("Max length must be positive")
48
+ return v
49
+
50
+ class EmbeddingResponse(BaseModel):
51
+ """Response model for embedding generation"""
52
+ embeddings: List[List[float]] = Field(
53
+ ...,
54
+ description="List of embedding vectors"
55
+ )
56
+ model_used: str = Field(
57
+ ...,
58
+ description="Model that was used"
59
+ )
60
+ dimensions: int = Field(
61
+ ...,
62
+ description="Dimension of embedding vectors"
63
+ )
64
+ num_texts: int = Field(
65
+ ...,
66
+ description="Number of texts processed"
67
+ )
68
+
69
+ class ModelInfo(BaseModel):
70
+ """Information about available models"""
71
+ model_id: str = Field(
72
+ ...,
73
+ description="Model identifier for API calls"
74
+ )
75
+ name: str = Field(
76
+ ...,
77
+ description="Full Hugging Face model name"
78
+ )
79
+ dimensions: int = Field(
80
+ ...,
81
+ description="Output embedding dimensions"
82
+ )
83
+ max_sequence_length: int = Field(
84
+ ...,
85
+ description="Maximum input sequence length"
86
+ )
87
+ languages: List[str] = Field(
88
+ ...,
89
+ description="Supported languages"
90
+ )
91
+ model_type: str = Field(
92
+ ...,
93
+ description="Type/domain of model"
94
+ )
95
+ description: str = Field(
96
+ ...,
97
+ description="Model description"
98
+ )
99
+
100
+ class ErrorResponse(BaseModel):
101
+ """Error response model"""
102
+ detail: str = Field(
103
+ ...,
104
+ description="Error message"
105
+ )
106
+ error_type: Optional[str] = Field(
107
+ default=None,
108
+ description="Type of error"
109
+ )
embeddings_api/requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ transformers==4.36.0
4
+ torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu
5
+ sentence-transformers==2.2.2
6
+ numpy<2.0.0
7
+ scikit-learn==1.3.2
8
+ pydantic==2.5.0
9
+ huggingface-hub==0.19.4
10
+ python-multipart==0.0.6
embeddings_api/utils/helpers.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/helpers.py
2
+ """Helper functions for model loading and embedding generation"""
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers import AutoTokenizer, AutoModel, RobertaTokenizer, RobertaModel
7
+ from typing import List, Dict, Optional
8
+ import gc
9
+ import os
10
+
11
+ def load_models() -> Dict:
12
+ """
13
+ Load both embedding models with memory optimization
14
+
15
+ Returns:
16
+ Dict containing loaded models and tokenizers
17
+ """
18
+ models_cache = {}
19
+
20
+ # Set device
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+
23
+ try:
24
+ # Load Jina model
25
+ print("Loading Jina embeddings model...")
26
+ jina_tokenizer = AutoTokenizer.from_pretrained(
27
+ 'jinaai/jina-embeddings-v2-base-es',
28
+ trust_remote_code=True
29
+ )
30
+ jina_model = AutoModel.from_pretrained(
31
+ 'jinaai/jina-embeddings-v2-base-es',
32
+ trust_remote_code=True,
33
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
34
+ ).to(device)
35
+ jina_model.eval()
36
+
37
+ # Load RoBERTalex model
38
+ print("Loading RoBERTalex model...")
39
+ robertalex_tokenizer = RobertaTokenizer.from_pretrained('PlanTL-GOB-ES/RoBERTalex')
40
+ robertalex_model = RobertaModel.from_pretrained(
41
+ 'PlanTL-GOB-ES/RoBERTalex',
42
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
43
+ ).to(device)
44
+ robertalex_model.eval()
45
+
46
+ models_cache = {
47
+ 'jina': {
48
+ 'tokenizer': jina_tokenizer,
49
+ 'model': jina_model,
50
+ 'device': device
51
+ },
52
+ 'robertalex': {
53
+ 'tokenizer': robertalex_tokenizer,
54
+ 'model': robertalex_model,
55
+ 'device': device
56
+ }
57
+ }
58
+
59
+ # Force garbage collection after loading
60
+ gc.collect()
61
+
62
+ return models_cache
63
+
64
+ except Exception as e:
65
+ print(f"Error loading models: {str(e)}")
66
+ raise
67
+
68
+ def mean_pooling(model_output, attention_mask):
69
+ """
70
+ Apply mean pooling to get sentence embeddings
71
+
72
+ Args:
73
+ model_output: Model output containing token embeddings
74
+ attention_mask: Attention mask for valid tokens
75
+
76
+ Returns:
77
+ Pooled embeddings
78
+ """
79
+ token_embeddings = model_output[0]
80
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
81
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
82
+
83
+ def get_embeddings(
84
+ texts: List[str],
85
+ model_name: str,
86
+ models_cache: Dict,
87
+ normalize: bool = True,
88
+ max_length: Optional[int] = None
89
+ ) -> List[List[float]]:
90
+ """
91
+ Generate embeddings for texts using specified model
92
+
93
+ Args:
94
+ texts: List of texts to embed
95
+ model_name: Name of model to use ('jina' or 'robertalex')
96
+ models_cache: Dictionary containing loaded models
97
+ normalize: Whether to normalize embeddings
98
+ max_length: Maximum sequence length
99
+
100
+ Returns:
101
+ List of embedding vectors
102
+ """
103
+ if model_name not in models_cache:
104
+ raise ValueError(f"Model {model_name} not available. Choose 'jina' or 'robertalex'")
105
+
106
+ tokenizer = models_cache[model_name]['tokenizer']
107
+ model = models_cache[model_name]['model']
108
+ device = models_cache[model_name]['device']
109
+
110
+ # Set max length based on model capabilities
111
+ if max_length is None:
112
+ max_length = 8192 if model_name == 'jina' else 512
113
+
114
+ # Process in batches for memory efficiency
115
+ batch_size = 8 if len(texts) > 8 else len(texts)
116
+ all_embeddings = []
117
+
118
+ for i in range(0, len(texts), batch_size):
119
+ batch_texts = texts[i:i + batch_size]
120
+
121
+ # Tokenize inputs
122
+ encoded_input = tokenizer(
123
+ batch_texts,
124
+ padding=True,
125
+ truncation=True,
126
+ max_length=max_length,
127
+ return_tensors='pt'
128
+ ).to(device)
129
+
130
+ # Generate embeddings
131
+ with torch.no_grad():
132
+ model_output = model(**encoded_input)
133
+
134
+ if model_name == 'jina':
135
+ # Jina models require mean pooling
136
+ embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
137
+ else:
138
+ # RoBERTalex: use [CLS] token embedding
139
+ embeddings = model_output.last_hidden_state[:, 0, :]
140
+
141
+ # Normalize if requested
142
+ if normalize:
143
+ embeddings = F.normalize(embeddings, p=2, dim=1)
144
+
145
+ # Convert to CPU and list
146
+ batch_embeddings = embeddings.cpu().numpy().tolist()
147
+ all_embeddings.extend(batch_embeddings)
148
+
149
+ return all_embeddings
150
+
151
+ def cleanup_memory():
152
+ """Force garbage collection and clear cache"""
153
+ gc.collect()
154
+ if torch.cuda.is_available():
155
+ torch.cuda.empty_cache()
156
+
157
+ def validate_input_texts(texts: List[str]) -> List[str]:
158
+ """
159
+ Validate and clean input texts
160
+
161
+ Args:
162
+ texts: List of input texts
163
+
164
+ Returns:
165
+ Cleaned texts
166
+ """
167
+ cleaned_texts = []
168
+ for text in texts:
169
+ # Remove excess whitespace
170
+ text = ' '.join(text.split())
171
+ # Skip empty texts
172
+ if text:
173
+ cleaned_texts.append(text)
174
+
175
+ if not cleaned_texts:
176
+ raise ValueError("No valid texts provided after cleaning")
177
+
178
+ return cleaned_texts
179
+
180
+ def get_model_info(model_name: str) -> Dict:
181
+ """
182
+ Get detailed information about a model
183
+
184
+ Args:
185
+ model_name: Model identifier
186
+
187
+ Returns:
188
+ Dictionary with model information
189
+ """
190
+ model_info = {
191
+ 'jina': {
192
+ 'full_name': 'jinaai/jina-embeddings-v2-base-es',
193
+ 'dimensions': 768,
194
+ 'max_length': 8192,
195
+ 'pooling': 'mean',
196
+ 'languages': ['Spanish', 'English']
197
+ },
198
+ 'robertalex': {
199
+ 'full_name': 'PlanTL-GOB-ES/RoBERTalex',
200
+ 'dimensions': 768,
201
+ 'max_length': 512,
202
+ 'pooling': 'cls',
203
+ 'languages': ['Spanish']
204
+ }
205
+ }
206
+
207
+ return model_info.get(model_name, {})