from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List

import torch
import uvicorn

from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
from utils.helpers import load_models, get_embeddings, cleanup_memory

# Cache of loaded models, keyed by model id (e.g. "jina-v3")
models_cache = {}

# Model loaded eagerly at startup; all other models load on first request
STARTUP_MODEL = "jina-v3"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler for startup and shutdown"""
    global models_cache
    try:
        print(f"Loading startup model: {STARTUP_MODEL}...")
        models_cache = load_models([STARTUP_MODEL])
        print(f"Startup model loaded successfully: {list(models_cache.keys())}")
    except Exception as e:
        # Keep the API up even if the startup model fails; it can still be loaded on demand.
        print(f"Failed to load startup model: {str(e)}")
    try:
        yield
    finally:
        cleanup_memory()


def ensure_model_loaded(model_name: str, max_length_limit: int):
    """Load a specific model on demand if not already loaded"""
    # Note: max_length_limit is currently unused here; length checks happen in
    # validate_request_for_model.
    global models_cache
    if model_name not in models_cache:
        try:
            print(f"Loading model on demand: {model_name}...")
            new_models = load_models([model_name])
            models_cache.update(new_models)
            print(f"Model {model_name} loaded successfully!")
        except Exception as e:
            print(f"Failed to load model {model_name}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model {model_name} loading failed: {str(e)}")


def validate_request_for_model(request: EmbeddingRequest, model_name: str, max_length_limit: int):
    """Validate request parameters for a specific model"""
    if not request.texts:
        raise HTTPException(status_code=400, detail="No texts provided")

    if len(request.texts) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 texts per request")

    if request.max_length is not None and request.max_length > max_length_limit:
        raise HTTPException(status_code=400, detail=f"Max length for {model_name} is {max_length_limit}")


app = FastAPI(
    title="Multilingual & Legal Embedding API",
    description="Multi-model embedding API with dedicated endpoints per model",
    version="4.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    return {
        "message": "Multilingual & Legal Embedding API - Endpoint Per Model",
        "version": "4.0.0",
        "status": "running",
        "docs": "/docs",
        "startup_model": STARTUP_MODEL,
        "available_endpoints": {
            "jina-v3": "/embed/jina-v3",
            "roberta-ca": "/embed/roberta-ca",
            "jina": "/embed/jina",
            "robertalex": "/embed/robertalex",
            "legal-bert": "/embed/legal-bert"
        }
    }


@app.post("/embed/jina-v3", response_model=EmbeddingResponse)
async def embed_jina_v3(request: EmbeddingRequest):
    """Generate embeddings using the Jina v3 model (multilingual)"""
    try:
        ensure_model_loaded("jina-v3", 8192)
        validate_request_for_model(request, "jina-v3", 8192)

        embeddings = get_embeddings(
            request.texts,
            "jina-v3",
            models_cache,
            request.normalize,
            request.max_length
        )

        return EmbeddingResponse(
            embeddings=embeddings,
            model_used="jina-v3",
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        # Re-raise validation/loading errors unchanged instead of masking them as 500s
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
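
# Example request for the endpoint above (a sketch; it assumes the EmbeddingRequest schema
# exposes the `texts`, `normalize`, and `max_length` fields referenced in the handlers,
# with the latter two optional):
#
#   curl -X POST http://localhost:7860/embed/jina-v3 \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["contract law", "dret mercantil"], "normalize": true}'
#
# The response mirrors EmbeddingResponse as constructed above, e.g.
#   {"embeddings": [[...], [...]], "model_used": "jina-v3", "dimensions": 1024, "num_texts": 2}

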
@app.post("/embed/roberta-ca", response_model=EmbeddingResponse)
async def embed_roberta_ca(request: EmbeddingRequest):
    """Generate embeddings using the Catalan RoBERTa model"""
    try:
        ensure_model_loaded("roberta-ca", 512)
        validate_request_for_model(request, "roberta-ca", 512)

        embeddings = get_embeddings(
            request.texts,
            "roberta-ca",
            models_cache,
            request.normalize,
            request.max_length
        )

        return EmbeddingResponse(
            embeddings=embeddings,
            model_used="roberta-ca",
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")


@app.post("/embed/jina", response_model=EmbeddingResponse)
async def embed_jina(request: EmbeddingRequest):
    """Generate embeddings using the Jina v2 Spanish/English model"""
    try:
        ensure_model_loaded("jina", 8192)
        validate_request_for_model(request, "jina", 8192)

        embeddings = get_embeddings(
            request.texts,
            "jina",
            models_cache,
            request.normalize,
            request.max_length
        )

        return EmbeddingResponse(
            embeddings=embeddings,
            model_used="jina",
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")


@app.post("/embed/robertalex", response_model=EmbeddingResponse)
async def embed_robertalex(request: EmbeddingRequest):
    """Generate embeddings using the RoBERTalex Spanish legal model"""
    try:
        ensure_model_loaded("robertalex", 512)
        validate_request_for_model(request, "robertalex", 512)

        embeddings = get_embeddings(
            request.texts,
            "robertalex",
            models_cache,
            request.normalize,
            request.max_length
        )

        return EmbeddingResponse(
            embeddings=embeddings,
            model_used="robertalex",
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")


@app.post("/embed/legal-bert", response_model=EmbeddingResponse)
async def embed_legal_bert(request: EmbeddingRequest):
    """Generate embeddings using the Legal BERT English model"""
    try:
        ensure_model_loaded("legal-bert", 512)
        validate_request_for_model(request, "legal-bert", 512)

        embeddings = get_embeddings(
            request.texts,
            "legal-bert",
            models_cache,
            request.normalize,
            request.max_length
        )

        return EmbeddingResponse(
            embeddings=embeddings,
            model_used="legal-bert",
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")


@app.get("/models", response_model=List[ModelInfo])
async def list_models():
    """List available models and their specifications"""
    return [
        ModelInfo(
            model_id="jina-v3",
            name="jinaai/jina-embeddings-v3",
            dimensions=1024,
            max_sequence_length=8192,
            languages=["Multilingual"],
            model_type="multilingual",
            description="Latest Jina v3 with superior multilingual performance - loaded at startup"
        ),
        ModelInfo(
            model_id="roberta-ca",
            name="projecte-aina/roberta-large-ca-v2",
            dimensions=1024,
            max_sequence_length=512,
            languages=["Catalan"],
            model_type="general",
            description="Catalan RoBERTa-large model trained on a large corpus"
        ),
        ModelInfo(
            model_id="jina",
            name="jinaai/jina-embeddings-v2-base-es",
            dimensions=768,
            max_sequence_length=8192,
            languages=["Spanish", "English"],
            model_type="bilingual",
            description="Bilingual Spanish-English embeddings with long context support"
        ),
        ModelInfo(
            model_id="robertalex",
            name="PlanTL-GOB-ES/RoBERTalex",
            dimensions=768,
            max_sequence_length=512,
            languages=["Spanish"],
            model_type="legal domain",
            description="Spanish legal domain specialized embeddings"
        ),
        ModelInfo(
            model_id="legal-bert",
            name="nlpaueb/legal-bert-base-uncased",
            dimensions=768,
            max_sequence_length=512,
            languages=["English"],
            model_type="legal domain",
            description="English legal domain BERT model"
        )
    ]
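
# Client-side sketch (hypothetical helper, not part of this service): read the per-model
# limit from /models, then call the matching /embed/<model_id> endpoint.
#
#   import requests
#
#   def embed(texts, model_id="jina-v3", base_url="http://localhost:7860"):
#       info = {m["model_id"]: m for m in requests.get(f"{base_url}/models").json()}
#       payload = {"texts": texts, "normalize": True,
#                  "max_length": info[model_id]["max_sequence_length"]}
#       return requests.post(f"{base_url}/embed/{model_id}", json=payload).json()["embeddings"]

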
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    startup_loaded = STARTUP_MODEL in models_cache

    return {
        "status": "healthy" if startup_loaded else "partial",
        "startup_model": STARTUP_MODEL,
        "startup_model_loaded": startup_loaded,
        "available_models": list(models_cache.keys()),
        "models_count": len(models_cache),
        "endpoints": {
            "jina-v3": f"/embed/jina-v3 {'(ready)' if 'jina-v3' in models_cache else '(loads on demand)'}",
            "roberta-ca": f"/embed/roberta-ca {'(ready)' if 'roberta-ca' in models_cache else '(loads on demand)'}",
            "jina": f"/embed/jina {'(ready)' if 'jina' in models_cache else '(loads on demand)'}",
            "robertalex": f"/embed/robertalex {'(ready)' if 'robertalex' in models_cache else '(loads on demand)'}",
            "legal-bert": f"/embed/legal-bert {'(ready)' if 'legal-bert' in models_cache else '(loads on demand)'}"
        }
    }
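
# Example probe: `curl http://localhost:7860/health` returns "healthy" once the startup
# model is cached, and "partial" while the API is up but jina-v3 has not been loaded yet.

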
if __name__ == "__main__":
    # Limit PyTorch to 8 intra-op threads and a single inter-op thread for CPU inference
    torch.set_num_threads(8)
    torch.set_num_interop_threads(1)

    uvicorn.run(app, host="0.0.0.0", port=7860)
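
# Alternative launch via the uvicorn CLI (assumes this file is named main.py; adjust the
# module path otherwise): `uvicorn main:app --host 0.0.0.0 --port 7860`. Note that the
# torch thread settings above only take effect when this file is run directly.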