Jordi Catafal committed
Commit 5861022 · Parent: 0610fdd

trying hybrid approach

__pycache__/app.cpython-311.pyc CHANGED
Binary files a/__pycache__/app.cpython-311.pyc and b/__pycache__/app.cpython-311.pyc differ
 
app.py CHANGED
@@ -1,5 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
+from contextlib import asynccontextmanager
 from typing import List
 import torch
 import uvicorn
@@ -7,10 +8,53 @@ import uvicorn
 from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
 from utils.helpers import load_models, get_embeddings, cleanup_memory
 
+# Global model cache
+models_cache = {}
+
+# Models to load at startup (most frequently used)
+STARTUP_MODELS = ["jina-v3", "roberta-ca"]
+# Models to load on demand
+ON_DEMAND_MODELS = ["jina", "robertalex", "legal-bert"]
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Application lifespan handler for startup and shutdown"""
+    # Startup - load priority models
+    try:
+        global models_cache
+        print(f"Loading startup models: {STARTUP_MODELS}...")
+        models_cache = load_models(STARTUP_MODELS)
+        print(f"Startup models loaded successfully: {list(models_cache.keys())}")
+        yield
+    except Exception as e:
+        print(f"Failed to load startup models: {str(e)}")
+        # Continue anyway - models can be loaded on demand
+        yield
+    finally:
+        # Shutdown - cleanup resources
+        cleanup_memory()
+
+def ensure_model_loaded(model_name: str):
+    """Load a specific model on demand if not already loaded"""
+    global models_cache
+    if model_name not in models_cache:
+        if model_name in ON_DEMAND_MODELS:
+            try:
+                print(f"Loading model on demand: {model_name}...")
+                new_models = load_models([model_name])
+                models_cache.update(new_models)
+                print(f"Model {model_name} loaded successfully!")
+            except Exception as e:
+                print(f"Failed to load model {model_name}: {str(e)}")
+                raise HTTPException(status_code=500, detail=f"Model {model_name} loading failed: {str(e)}")
+        else:
+            raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
+
 app = FastAPI(
     title="Multilingual & Legal Embedding API",
     description="Multi-model embedding API for Spanish, Catalan, English and Legal texts",
-    version="3.0.0"
+    version="3.0.0",
+    lifespan=lifespan
 )
 
 # Add CORS middleware to allow cross-origin requests
@@ -22,21 +66,6 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Global model cache - loaded on demand
-models_cache = {}
-
-def ensure_models_loaded():
-    """Load models on first request if not already loaded"""
-    global models_cache
-    if not models_cache:
-        try:
-            print("Loading models on demand...")
-            models_cache = load_models()
-            print("All models loaded successfully!")
-        except Exception as e:
-            print(f"Failed to load models: {str(e)}")
-            raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
-
 @app.get("/")
 async def root():
     return {
@@ -51,8 +80,8 @@ async def root():
 async def create_embeddings(request: EmbeddingRequest):
     """Generate embeddings for input texts"""
     try:
-        # Load models on first request
-        ensure_models_loaded()
+        # Load specific model on demand if needed
+        ensure_model_loaded(request.model)
 
         if not request.texts:
             raise HTTPException(status_code=400, detail="No texts provided")
@@ -138,14 +167,18 @@ async def list_models():
 @app.get("/health")
 async def health_check():
     """Health check endpoint"""
-    models_loaded = len(models_cache) == 5
+    startup_models_loaded = all(model in models_cache for model in STARTUP_MODELS)
+    all_models_loaded = len(models_cache) == 5
+
     return {
-        "status": "healthy" if models_loaded else "ready",
-        "models_loaded": models_loaded,
+        "status": "healthy" if startup_models_loaded else "partial",
+        "startup_models_loaded": startup_models_loaded,
+        "all_models_loaded": all_models_loaded,
         "available_models": list(models_cache.keys()),
-        "expected_models": ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"],
+        "startup_models": STARTUP_MODELS,
+        "on_demand_models": ON_DEMAND_MODELS,
         "models_count": len(models_cache),
-        "note": "Models load on first embedding request" if not models_loaded else "All models ready"
+        "note": f"Startup models: {STARTUP_MODELS} | On-demand: {ON_DEMAND_MODELS}"
     }
 
 if __name__ == "__main__":
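
The net effect of the app.py change: the old ensure_models_loaded(), which loaded all five models on the first /embed request, is replaced by a lifespan hook that eagerly loads STARTUP_MODELS at boot, plus an ensure_model_loaded() that lazily loads only the requested on-demand model. A minimal, self-contained sketch of that pattern, with a stub loader standing in for the real load_models() (all names in the sketch are illustrative, not part of the commit; run it with `uvicorn sketch:app`):

```python
# sketch.py - hybrid eager/lazy loading pattern, with stub "models"
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException

STARTUP = ["a"]          # loaded eagerly at boot
ON_DEMAND = ["b", "c"]   # loaded lazily on first use
cache = {}

def stub_load(names):
    """Stand-in for the expensive load_models() in utils/helpers.py."""
    return {name: f"model-{name}" for name in names}

@asynccontextmanager
async def lifespan(app: FastAPI):
    cache.update(stub_load(STARTUP))   # eager phase, before serving
    yield
    cache.clear()                      # shutdown cleanup

app = FastAPI(lifespan=lifespan)

@app.get("/use/{name}")
async def use(name: str):
    if name not in cache:
        if name not in ON_DEMAND:
            raise HTTPException(status_code=400, detail=f"Unknown model: {name}")
        cache.update(stub_load([name]))  # lazy phase, first request only
    return {"model": cache[name], "loaded": sorted(cache)}
```

One caveat worth noting: load_models() is synchronous, so an on-demand load inside the async /embed handler blocks the event loop for the duration of that model's first request; subsequent requests hit the cache.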
test_hybrid.py ADDED
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+"""
+Test script for hybrid model loading
+"""
+
+import requests
+import json
+import time
+
+def test_hybrid_api(base_url="https://aurasystems-spanish-embeddings-api.hf.space"):
+    """Test the hybrid API"""
+
+    print(f"Testing hybrid API at {base_url}")
+
+    # Test health endpoint first
+    try:
+        response = requests.get(f"{base_url}/health")
+        print(f"✓ Health endpoint: {response.status_code}")
+        if response.status_code == 200:
+            health_data = response.json()
+            print(f"  Startup models loaded: {health_data.get('startup_models_loaded', False)}")
+            print(f"  Available models: {health_data.get('available_models', [])}")
+            print(f"  Note: {health_data.get('note', 'N/A')}")
+        else:
+            print(f"  Error: {response.text}")
+    except Exception as e:
+        print(f"✗ Health endpoint failed: {e}")
+        return False
+
+    # Test startup model (jina-v3)
+    try:
+        payload = {
+            "texts": ["Hola mundo", "Bonjour le monde"],
+            "model": "jina-v3",
+            "normalize": True
+        }
+        response = requests.post(f"{base_url}/embed", json=payload)
+        print(f"✓ Startup model (jina-v3): {response.status_code}")
+        if response.status_code == 200:
+            data = response.json()
+            print(f"  Generated {data.get('num_texts', 0)} embeddings")
+            print(f"  Dimensions: {data.get('dimensions', 0)}")
+        else:
+            print(f"  Error: {response.text}")
+    except Exception as e:
+        print(f"✗ Startup model test failed: {e}")
+
+    # Test startup model (roberta-ca)
+    try:
+        payload = {
+            "texts": ["Bon dia", "Com estàs?"],
+            "model": "roberta-ca",
+            "normalize": True
+        }
+        response = requests.post(f"{base_url}/embed", json=payload)
+        print(f"✓ Startup model (roberta-ca): {response.status_code}")
+        if response.status_code == 200:
+            data = response.json()
+            print(f"  Generated {data.get('num_texts', 0)} embeddings")
+            print(f"  Dimensions: {data.get('dimensions', 0)}")
+        else:
+            print(f"  Error: {response.text}")
+    except Exception as e:
+        print(f"✗ Startup model test failed: {e}")
+
+    # Test on-demand model (jina)
+    try:
+        payload = {
+            "texts": ["Texto en español"],
+            "model": "jina",
+            "normalize": True
+        }
+        response = requests.post(f"{base_url}/embed", json=payload)
+        print(f"✓ On-demand model (jina): {response.status_code}")
+        if response.status_code == 200:
+            data = response.json()
+            print(f"  Generated {data.get('num_texts', 0)} embeddings")
+            print(f"  Dimensions: {data.get('dimensions', 0)}")
+        else:
+            print(f"  Error: {response.text}")
+    except Exception as e:
+        print(f"✗ On-demand model test failed: {e}")
+
+    # Check health again to see all models
+    try:
+        response = requests.get(f"{base_url}/health")
+        if response.status_code == 200:
+            health_data = response.json()
+            print(f"✓ Final health check:")
+            print(f"  All models loaded: {health_data.get('all_models_loaded', False)}")
+            print(f"  Available models: {health_data.get('available_models', [])}")
+    except Exception as e:
+        print(f"✗ Final health check failed: {e}")
+
+    return True
+
+if __name__ == "__main__":
+    test_hybrid_api()
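
test_hybrid.py imports json and time but uses neither. A hypothetical companion check that puts time to work, timing the first (load-triggering) and second (cached) request to one on-demand model so the hybrid split shows up in the latency (the function name and timeout value are illustrative, not part of the commit):

```python
# Time a cold (load-triggering) and a warm (cached) request to one
# on-demand model; the gap is the on-demand load cost.
import time

import requests

def time_on_demand(base_url="https://aurasystems-spanish-embeddings-api.hf.space",
                   model="legal-bert"):
    payload = {"texts": ["texto de prueba"], "model": model, "normalize": True}
    for attempt in ("cold", "warm"):
        start = time.perf_counter()
        response = requests.post(f"{base_url}/embed", json=payload, timeout=300)
        elapsed = time.perf_counter() - start
        print(f"{model} ({attempt}): HTTP {response.status_code} in {elapsed:.1f}s")

if __name__ == "__main__":
    time_on_demand()
```

On a freshly started Space the cold call should be markedly slower than the warm one, since it pays for the model download and load.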
utils/__pycache__/helpers.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/helpers.cpython-311.pyc and b/utils/__pycache__/helpers.cpython-311.pyc differ
 
utils/helpers.py CHANGED
@@ -12,104 +12,118 @@ from typing import List, Dict, Optional
 import gc
 import os
 
-def load_models() -> Dict:
+def load_models(model_names: List[str] = None) -> Dict:
     """
-    Load all embedding models with memory optimization
+    Load specific embedding models with memory optimization
+
+    Args:
+        model_names: List of model names to load. If None, loads all models.
 
     Returns:
        Dict containing loaded models and tokenizers
    """
     models_cache = {}
 
+    # Default to all models if none specified
+    if model_names is None:
+        model_names = ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"]
+
     # Set device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     try:
         # Load Jina v2 Spanish model
-        print("Loading Jina embeddings v2 Spanish model...")
-        jina_tokenizer = AutoTokenizer.from_pretrained(
-            'jinaai/jina-embeddings-v2-base-es',
-            trust_remote_code=True
-        )
-        jina_model = AutoModel.from_pretrained(
-            'jinaai/jina-embeddings-v2-base-es',
-            trust_remote_code=True,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-        ).to(device)
-        jina_model.eval()
-
-        # Load RoBERTalex model
-        print("Loading RoBERTalex model...")
-        robertalex_tokenizer = RobertaTokenizer.from_pretrained('PlanTL-GOB-ES/RoBERTalex')
-        robertalex_model = RobertaModel.from_pretrained(
-            'PlanTL-GOB-ES/RoBERTalex',
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-        ).to(device)
-        robertalex_model.eval()
-
-        # Load Jina v3 model
-        print("Loading Jina embeddings v3 model...")
-        jina_v3_tokenizer = AutoTokenizer.from_pretrained(
-            'jinaai/jina-embeddings-v3',
-            trust_remote_code=True
-        )
-        jina_v3_model = AutoModel.from_pretrained(
-            'jinaai/jina-embeddings-v3',
-            trust_remote_code=True,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-        ).to(device)
-        jina_v3_model.eval()
-
-        # Load Legal BERT model
-        print("Loading Legal BERT model...")
-        legal_bert_tokenizer = BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
-        legal_bert_model = BertModel.from_pretrained(
-            'nlpaueb/legal-bert-base-uncased',
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-        ).to(device)
-        legal_bert_model.eval()
-
-        # Load Catalan RoBERTa model
-        print("Loading Catalan RoBERTa-large model...")
-        roberta_ca_tokenizer = AutoTokenizer.from_pretrained('projecte-aina/roberta-large-ca-v2')
-        roberta_ca_model = AutoModel.from_pretrained(
-            'projecte-aina/roberta-large-ca-v2',
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-        ).to(device)
-        roberta_ca_model.eval()
-
-        models_cache = {
-            'jina': {
+        if "jina" in model_names:
+            print("Loading Jina embeddings v2 Spanish model...")
+            jina_tokenizer = AutoTokenizer.from_pretrained(
+                'jinaai/jina-embeddings-v2-base-es',
+                trust_remote_code=True
+            )
+            jina_model = AutoModel.from_pretrained(
+                'jinaai/jina-embeddings-v2-base-es',
+                trust_remote_code=True,
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+            ).to(device)
+            jina_model.eval()
+
+            models_cache['jina'] = {
                 'tokenizer': jina_tokenizer,
                 'model': jina_model,
                 'device': device,
                 'pooling': 'mean'
-            },
-            'robertalex': {
+            }
+
+        # Load RoBERTalex model
+        if "robertalex" in model_names:
+            print("Loading RoBERTalex model...")
+            robertalex_tokenizer = RobertaTokenizer.from_pretrained('PlanTL-GOB-ES/RoBERTalex')
+            robertalex_model = RobertaModel.from_pretrained(
+                'PlanTL-GOB-ES/RoBERTalex',
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+            ).to(device)
+            robertalex_model.eval()
+
+            models_cache['robertalex'] = {
                 'tokenizer': robertalex_tokenizer,
                 'model': robertalex_model,
                 'device': device,
                 'pooling': 'cls'
-            },
-            'jina-v3': {
+            }
+
+        # Load Jina v3 model
+        if "jina-v3" in model_names:
+            print("Loading Jina embeddings v3 model...")
+            jina_v3_tokenizer = AutoTokenizer.from_pretrained(
+                'jinaai/jina-embeddings-v3',
+                trust_remote_code=True
+            )
+            jina_v3_model = AutoModel.from_pretrained(
+                'jinaai/jina-embeddings-v3',
+                trust_remote_code=True,
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+            ).to(device)
+            jina_v3_model.eval()
+
+            models_cache['jina-v3'] = {
                 'tokenizer': jina_v3_tokenizer,
                 'model': jina_v3_model,
                 'device': device,
                 'pooling': 'mean'
-            },
-            'legal-bert': {
+            }
+
+        # Load Legal BERT model
+        if "legal-bert" in model_names:
+            print("Loading Legal BERT model...")
+            legal_bert_tokenizer = BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
+            legal_bert_model = BertModel.from_pretrained(
+                'nlpaueb/legal-bert-base-uncased',
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+            ).to(device)
+            legal_bert_model.eval()
+
+            models_cache['legal-bert'] = {
                 'tokenizer': legal_bert_tokenizer,
                 'model': legal_bert_model,
                 'device': device,
                 'pooling': 'cls'
-            },
-            'roberta-ca': {
+            }
+
+        # Load Catalan RoBERTa model
+        if "roberta-ca" in model_names:
+            print("Loading Catalan RoBERTa-large model...")
+            roberta_ca_tokenizer = AutoTokenizer.from_pretrained('projecte-aina/roberta-large-ca-v2')
+            roberta_ca_model = AutoModel.from_pretrained(
+                'projecte-aina/roberta-large-ca-v2',
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+            ).to(device)
+            roberta_ca_model.eval()
+
+            models_cache['roberta-ca'] = {
                 'tokenizer': roberta_ca_tokenizer,
                 'model': roberta_ca_model,
                 'device': device,
                 'pooling': 'cls'
             }
-    }
 
     # Force garbage collection after loading
     gc.collect()
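
The new load_models() repeats the same load-and-register steps five times, one if-block per model. A sketch of an alternative, table-driven shape (the MODEL_CONFIGS table is illustrative and not part of the commit; it also leans on AutoTokenizer/AutoModel to resolve the RoBERTa and BERT classes that the commit instantiates explicitly):

```python
# Illustrative refactor (not in the commit): collapse the five per-model
# branches of load_models() into one loop over a config table.
from typing import Dict, List, Optional

import torch
from transformers import AutoModel, AutoTokenizer

# name -> (HF repo, pooling strategy, trust_remote_code)
MODEL_CONFIGS = {
    "jina": ("jinaai/jina-embeddings-v2-base-es", "mean", True),
    "robertalex": ("PlanTL-GOB-ES/RoBERTalex", "cls", False),
    "jina-v3": ("jinaai/jina-embeddings-v3", "mean", True),
    "legal-bert": ("nlpaueb/legal-bert-base-uncased", "cls", False),
    "roberta-ca": ("projecte-aina/roberta-large-ca-v2", "cls", False),
}

def load_models(model_names: Optional[List[str]] = None) -> Dict:
    names = model_names or list(MODEL_CONFIGS)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    cache = {}
    for name in names:
        repo, pooling, remote = MODEL_CONFIGS[name]  # KeyError on unknown name
        print(f"Loading {name} ({repo})...")
        tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=remote)
        model = AutoModel.from_pretrained(
            repo, trust_remote_code=remote, torch_dtype=dtype
        ).to(device)
        model.eval()
        cache[name] = {"tokenizer": tokenizer, "model": model,
                       "device": device, "pooling": pooling}
    return cache
```

Behaviour differs in one small way: an unknown name raises KeyError here, where the commit's version silently skips it.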