import os
import logging
from contextlib import asynccontextmanager
from typing import List, Optional, Literal, Dict, Any
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict
from sentence_transformers import SparseEncoder
from transformers import AutoTokenizer
# --------------------------------------------------------------------------------------
# Logging
# --------------------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("main")
# --------------------------------------------------------------------------------------
# Device selection — intentionally NEVER choose MPS for SPLADE due to sparse-op gaps
# --------------------------------------------------------------------------------------
def choose_device() -> str:
if torch.cuda.is_available():
return "cuda"
# Avoid MPS for SPLADE (missing sparse ops). Default to CPU instead.
return "cpu"
DEVICE = choose_device()
logger.info(f"Selected device: {DEVICE}")
# --------------------------------------------------------------------------------------
# Model loading
# --------------------------------------------------------------------------------------
MODEL_ID = "sparse-encoder/splade-robbert-dutch-base-v1"
def load_sparse_encoder(model_id: str, device: str) -> SparseEncoder:
"""Load SparseEncoder. Prefer safetensors when available, but fall back to .bin.
Torch >= 2.6 is required by Transformers to load .bin safely.
"""
# Do NOT force safetensors globally; some repos only publish .bin
os.environ.pop("TRANSFORMERS_USE_SAFETENSORS", None)
try:
logger.info(f"Loading Dutch SPLADE model on {device}...")
m = SparseEncoder(model_id, device=device, model_kwargs={"use_safetensors": True})
return m
except OSError as e:
msg = str(e)
if "does not appear to have a file named model.safetensors" in msg:
logger.info("No safetensors in repo; retrying with .bin weights.")
return SparseEncoder(model_id, device=device)
raise
model: Optional[SparseEncoder] = None
# Tokenizer for mapping vocab ids -> readable tokens in explanations
tokenizer: Optional[AutoTokenizer] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global model, tokenizer
try:
model = load_sparse_encoder(MODEL_ID, DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
logger.info("Model & tokenizer loaded.")
yield
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
finally:
# Allow GC to clean up if server stops
pass
app = FastAPI(title="Sparse Embedding API", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
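# The wildcard origin is convenient for demos; a production deployment would
# typically pin this to known origins instead, e.g. (hypothetical origin):
#   allow_origins=["https://example.com"]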
# --------------------------------------------------------------------------------------
# Schemas
# --------------------------------------------------------------------------------------
class HealthResponse(BaseModel):
# Pydantic v2 warns about names starting with model_; allow them explicitly
model_config = ConfigDict(protected_namespaces=())
model_loaded: bool
model_name: str
device: str
class EmbeddingsRequest(BaseModel):
texts: List[str]
mode: Literal["query", "document"] = "query"
normalize: bool = True
# Keep payloads light; 0/None means no cap
max_active_dims: Optional[int] = 0
class EmbeddingRow(BaseModel):
indices: List[int]
weights: List[float]
class EmbeddingsResponse(BaseModel):
data: List[EmbeddingRow]
dim: int
info: Dict[str, Any]
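# Illustrative /embeddings exchange (a sketch; the indices, weights, and dim are
# made-up values, not real model output):
#   request:  {"texts": ["Ik zoek een fiets"], "mode": "query", "normalize": true}
#   response: {"data": [{"indices": [412, 9023], "weights": [1.42, 0.87]}],
#              "dim": 50000,
#              "info": {"mode": "query", "normalize": true,
#                       "max_active_dims": null, "device": "cpu"}}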
# --- Similarity API ---
class SimilarityRequest(BaseModel):
queries: List[str]
documents: List[str]
normalize: bool = True
max_active_dims: Optional[int] = 0
top_k: Optional[int] = 5
class SimilarityHit(BaseModel):
doc_index: int
score: float
text: str
class SimilarityResponse(BaseModel):
results: List[List[SimilarityHit]] # one list per query
info: Dict[str, Any]
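# Illustrative /similarity exchange (scores are made up; shapes follow the
# schema above):
#   request:  {"queries": ["fiets kopen"],
#              "documents": ["Ik koop een nieuwe fiets", "Het regent vandaag"],
#              "top_k": 1}
#   response: {"results": [[{"doc_index": 0, "score": 7.31,
#                            "text": "Ik koop een nieuwe fiets"}]],
#              "info": {...}}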
# --- Explain API ---
class TokenContribution(BaseModel):
token_id: int
token: str
query_weight: float
doc_weight: float
contribution: float
class ExplainRequest(BaseModel):
query: str
document: str
normalize: bool = True
max_active_dims: Optional[int] = 0
top_k_tokens: int = 15
class ExplainResponse(BaseModel):
score: float
top_tokens: List[TokenContribution]
info: Dict[str, Any]
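# Illustrative /explain response (token text and all numbers are made up):
#   {"score": 7.31,
#    "top_tokens": [{"token_id": 412, "token": " fiets", "query_weight": 1.42,
#                    "doc_weight": 1.10, "contribution": 1.56}],
#    "info": {...}}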
# --------------------------------------------------------------------------------------
# Helpers
# --------------------------------------------------------------------------------------
def torch_sparse_batch_to_rows(t: torch.Tensor) -> List[Dict[str, Any]]:
"""Convert a 2D torch sparse tensor [batch, dim] to list of {indices, weights} per row."""
if not isinstance(t, torch.Tensor):
raise TypeError("Expected a torch.Tensor from SparseEncoder")
if not t.is_sparse:
# Dense fallback (shouldn't happen with SparseEncoder). Convert per-row.
t = t.to("cpu")
rows = []
for r in t:
nz = torch.nonzero(r, as_tuple=True)[0]
rows.append({"indices": nz.tolist(), "weights": r[nz].tolist()})
return rows
# COO expected; coalesce and split by row
t = t.coalesce() # merge duplicates
idx = t.indices() # [2, nnz]
vals = t.values() # [nnz]
batch_size = t.size(0)
rows_out: List[Dict[str, Any]] = []
row_ids = idx[0]
col_ids = idx[1]
# For each row, mask and gather its entries
for i in range(batch_size):
m = row_ids == i
if torch.count_nonzero(m) == 0:
rows_out.append({"indices": [], "weights": []})
continue
cols_i = col_ids[m].to("cpu")
vals_i = vals[m].to("cpu")
rows_out.append({"indices": cols_i.tolist(), "weights": vals_i.tolist()})
return rows_out
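# A minimal sketch of the expected mapping (illustrative values):
#   >>> t = torch.sparse_coo_tensor([[0, 1], [1, 4]], [0.5, 2.0], size=(2, 6))
#   >>> torch_sparse_batch_to_rows(t)
#   [{'indices': [1], 'weights': [0.5]}, {'indices': [4], 'weights': [2.0]}]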
def top_token_contributions(q_row: Dict[str, Any], d_row: Dict[str, Any], k: int) -> List[Dict[str, Any]]:
"""Intersect query/doc indices and score tokens by product of weights."""
q_map = {int(i): float(w) for i, w in zip(q_row.get("indices", []), q_row.get("weights", []))}
contribs = []
for i, dw in zip(d_row.get("indices", []), d_row.get("weights", [])):
i = int(i)
dw = float(dw)
qw = q_map.get(i)
if qw is not None:
contribs.append((i, qw, dw, qw * dw))
contribs.sort(key=lambda t: t[3], reverse=True)
    # Fall back to 15 tokens when the requested k is not positive
    top = contribs[:k] if k > 0 else contribs[:15]
out: List[Dict[str, Any]] = []
for tok_id, qw, dw, c in top:
try:
# RobBERT uses RoBERTa/BPE-style tokens (Ġ denotes a leading space)
tok = tokenizer.convert_ids_to_tokens([tok_id])[0]
pretty = tok.replace("Ġ", " ").replace("▁", " ")
except Exception:
tok = pretty = str(tok_id)
out.append({
"token_id": tok_id,
"token": pretty,
"query_weight": qw,
"doc_weight": dw,
"contribution": c,
})
return out
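# Shape of the result (illustrative weights; the token string depends on the
# tokenizer, so it is elided here):
#   top_token_contributions({"indices": [7], "weights": [1.2]},
#                           {"indices": [7, 9], "weights": [0.5, 3.0]}, k=5)
#   -> [{"token_id": 7, "token": "...", "query_weight": 1.2,
#        "doc_weight": 0.5, "contribution": 0.6}]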
# --------------------------------------------------------------------------------------
# Routes
# --------------------------------------------------------------------------------------
@app.get("/")
async def root():
return {
"message": "Dutch SPLADE Embedding API",
"docs": "https://moimobrian-py-api.hf.space/docs",
"health": "https://moimobrian-py-api.hf.space/health"
}
@app.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
return HealthResponse(
model_loaded=model is not None,
model_name=MODEL_ID,
device=DEVICE,
)
@app.post("/embeddings", response_model=EmbeddingsResponse)
async def embeddings(req: EmbeddingsRequest) -> EmbeddingsResponse:
if model is None:
raise HTTPException(status_code=503, detail="Model not loaded")
if not req.texts:
raise HTTPException(status_code=400, detail="'texts' must be a non-empty list")
prompt_name = "query" if req.mode == "query" else "document"
max_k = req.max_active_dims or None
logger.info(f"Processing {len(req.texts)} texts in {req.mode} mode")
try:
if req.mode == "query":
embs = model.encode_query(
req.texts,
convert_to_tensor=True,
device=DEVICE,
normalize=req.normalize,
max_active_dims=max_k,
)
else:
embs = model.encode_document(
req.texts,
convert_to_tensor=True,
device=DEVICE,
normalize=req.normalize,
max_active_dims=max_k,
)
rows = torch_sparse_batch_to_rows(embs)
# Model card states ~50k dims; we can read the 2nd dimension from the tensor
dim = int(embs.size(1)) if isinstance(embs, torch.Tensor) else 0
return EmbeddingsResponse(
data=[EmbeddingRow(**r) for r in rows],
dim=dim,
info={
"mode": req.mode,
"normalize": req.normalize,
"max_active_dims": max_k,
"device": DEVICE,
},
)
except RuntimeError as e:
# If anything MPS-related sneaks in, hard-move to CPU and retry once
msg = str(e)
if "MPS" in msg or "to_sparse" in msg:
logger.warning("Encountered MPS/sparse op issue; retrying on CPU.")
try:
model.to("cpu")
if req.mode == "query":
embs = model.encode_query(
req.texts,
convert_to_tensor=True,
device="cpu",
normalize=req.normalize,
max_active_dims=max_k,
)
else:
embs = model.encode_document(
req.texts,
convert_to_tensor=True,
device="cpu",
normalize=req.normalize,
max_active_dims=max_k,
)
rows = torch_sparse_batch_to_rows(embs)
dim = int(embs.size(1)) if isinstance(embs, torch.Tensor) else 0
return EmbeddingsResponse(
data=[EmbeddingRow(**r) for r in rows],
dim=dim,
info={
"mode": req.mode,
"normalize": req.normalize,
"max_active_dims": max_k,
"device": "cpu",
"retry": True,
},
)
except Exception:
logger.exception("CPU retry failed")
raise HTTPException(status_code=500, detail=msg)
# Unknown runtime error
logger.exception("Error generating embeddings")
raise HTTPException(status_code=500, detail=msg)
except Exception as e:
logger.exception("Error generating embeddings")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/similarity", response_model=SimilarityResponse)
async def similarity(req: SimilarityRequest) -> SimilarityResponse:
if model is None:
raise HTTPException(status_code=503, detail="Model not loaded")
if not req.queries:
raise HTTPException(status_code=400, detail="'queries' must be a non-empty list")
if not req.documents:
raise HTTPException(status_code=400, detail="'documents' must be a non-empty list")
max_k = req.max_active_dims or None
try:
q = model.encode_query(
req.queries,
convert_to_tensor=True,
device=DEVICE,
normalize=req.normalize,
max_active_dims=max_k,
)
d = model.encode_document(
req.documents,
convert_to_tensor=True,
device=DEVICE,
normalize=req.normalize,
max_active_dims=max_k,
)
scores = model.similarity(q, d).to("cpu") # [num_queries, num_docs]
results: List[List[SimilarityHit]] = []
k = min(req.top_k or 5, len(req.documents))
for i in range(scores.size(0)):
vals, idxs = torch.topk(scores[i], k=k)
q_hits: List[SimilarityHit] = []
for v, j in zip(vals.tolist(), idxs.tolist()):
q_hits.append(SimilarityHit(doc_index=j, score=float(v), text=req.documents[j]))
results.append(q_hits)
return SimilarityResponse(
results=results,
info={
"normalize": req.normalize,
"max_active_dims": max_k,
"device": DEVICE,
},
)
except Exception as e:
logger.exception("Error computing similarity")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/explain", response_model=ExplainResponse)
async def explain(req: ExplainRequest) -> ExplainResponse:
if model is None or tokenizer is None:
raise HTTPException(status_code=503, detail="Model/tokenizer not loaded")
max_k = req.max_active_dims or None
try:
q = model.encode_query(
[req.query],
convert_to_tensor=True,
device=DEVICE,
normalize=req.normalize,
max_active_dims=max_k,
)
d = model.encode_document(
[req.document],
convert_to_tensor=True,
device=DEVICE,
normalize=req.normalize,
max_active_dims=max_k,
)
score = float(model.similarity(q, d)[0, 0].item())
q_row = torch_sparse_batch_to_rows(q)[0]
d_row = torch_sparse_batch_to_rows(d)[0]
tokens = top_token_contributions(q_row, d_row, req.top_k_tokens)
return ExplainResponse(
score=score,
top_tokens=[TokenContribution(**t) for t in tokens],
info={
"normalize": req.normalize,
"max_active_dims": max_k,
"device": DEVICE,
},
)
except Exception as e:
logger.exception("Error explaining match")
raise HTTPException(status_code=500, detail=str(e))
# --------------------------------------------------------------------------------------
# Local dev runner
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
reload=True,
log_level="info",
)
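# A minimal client sketch for local smoke-testing (assumes the server is running
# on localhost:8000 and that the `requests` package is installed):
#
#   import requests
#   r = requests.post(
#       "http://localhost:8000/embeddings",
#       json={"texts": ["Ik zoek een nieuwe fiets"], "mode": "query"},
#   )
#   row = r.json()["data"][0]
#   print(r.json()["dim"], len(row["indices"]))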