import os
import logging
from contextlib import asynccontextmanager
from typing import List, Optional, Literal, Dict, Any

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict
from sentence_transformers import SparseEncoder
from transformers import AutoTokenizer

# --------------------------------------------------------------------------------------
# Logging
# --------------------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("main")

# --------------------------------------------------------------------------------------
# Device selection — intentionally NEVER choose MPS for SPLADE due to sparse-op gaps
# --------------------------------------------------------------------------------------
def choose_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    # Avoid MPS for SPLADE (missing sparse ops). Default to CPU instead.
    return "cpu"

DEVICE = choose_device()
logger.info(f"Selected device: {DEVICE}")

# --------------------------------------------------------------------------------------
# Model loading
# --------------------------------------------------------------------------------------
MODEL_ID = "sparse-encoder/splade-robbert-dutch-base-v1"

def load_sparse_encoder(model_id: str, device: str) -> SparseEncoder:
    """Load SparseEncoder. Prefer safetensors when available, but fall back to .bin.

    Torch >= 2.6 is required by Transformers to load .bin safely.
    """
    # Do NOT force safetensors globally; some repos only publish .bin
    os.environ.pop("TRANSFORMERS_USE_SAFETENSORS", None)
    try:
        logger.info(f"Loading Dutch SPLADE model on {device}...")
        m = SparseEncoder(model_id, device=device, model_kwargs={"use_safetensors": True})
        return m
    except OSError as e:
        msg = str(e)
        if "does not appear to have a file named model.safetensors" in msg:
            logger.info("No safetensors in repo; retrying with .bin weights.")
            return SparseEncoder(model_id, device=device)
        raise

model: Optional[SparseEncoder] = None
# Tokenizer for mapping vocab ids -> readable tokens in explanations
tokenizer: Optional[AutoTokenizer] = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    global model, tokenizer
    try:
        model = load_sparse_encoder(MODEL_ID, DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        logger.info("Model & tokenizer loaded.")
        yield
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    finally:
        # Allow GC to clean up if server stops
        pass

app = FastAPI(title="Sparse Embedding API", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --------------------------------------------------------------------------------------
# Schemas
# --------------------------------------------------------------------------------------
class HealthResponse(BaseModel):
    # Pydantic v2 warns about names starting with model_; allow them explicitly
    model_config = ConfigDict(protected_namespaces=())

    model_loaded: bool
    model_name: str
    device: str

class EmbeddingsRequest(BaseModel):
    texts: List[str]
    mode: Literal["query", "document"] = "query"
    normalize: bool = True
    # Keep payloads light; 0/None means no cap
    max_active_dims: Optional[int] = 0

class EmbeddingRow(BaseModel):
    indices: List[int]
    weights: List[float]

class EmbeddingsResponse(BaseModel):
    data: List[EmbeddingRow]
    dim: int
    info: Dict[str, Any]
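# Note: each EmbeddingRow is a sparse vector in coordinate form. `indices` holds the
# vocabulary ids with non-zero activation and `weights` the matching SPLADE scores.
# Illustrative (made-up) example: indices=[17, 4203], weights=[0.8, 1.3] describes a
# vector that is zero everywhere except at dimensions 17 and 4203.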
# --- Similarity API ---
class SimilarityRequest(BaseModel):
    queries: List[str]
    documents: List[str]
    normalize: bool = True
    max_active_dims: Optional[int] = 0
    top_k: Optional[int] = 5

class SimilarityHit(BaseModel):
    doc_index: int
    score: float
    text: str

class SimilarityResponse(BaseModel):
    results: List[List[SimilarityHit]]  # one list per query
    info: Dict[str, Any]

# --- Explain API ---
class TokenContribution(BaseModel):
    token_id: int
    token: str
    query_weight: float
    doc_weight: float
    contribution: float

class ExplainRequest(BaseModel):
    query: str
    document: str
    normalize: bool = True
    max_active_dims: Optional[int] = 0
    top_k_tokens: int = 15

class ExplainResponse(BaseModel):
    score: float
    top_tokens: List[TokenContribution]
    info: Dict[str, Any]

# --------------------------------------------------------------------------------------
# Helpers
# --------------------------------------------------------------------------------------
def torch_sparse_batch_to_rows(t: torch.Tensor) -> List[Dict[str, Any]]:
    """Convert a 2D torch sparse tensor [batch, dim] to a list of {indices, weights} per row."""
    if not isinstance(t, torch.Tensor):
        raise TypeError("Expected a torch.Tensor from SparseEncoder")
    if not t.is_sparse:
        # Dense fallback (shouldn't happen with SparseEncoder). Convert per-row.
        t = t.to("cpu")
        rows = []
        for r in t:
            nz = torch.nonzero(r, as_tuple=True)[0]
            rows.append({"indices": nz.tolist(), "weights": r[nz].tolist()})
        return rows

    # COO expected; coalesce and split by row
    t = t.coalesce()   # merge duplicates
    idx = t.indices()  # [2, nnz]
    vals = t.values()  # [nnz]
    batch_size = t.size(0)
    rows_out: List[Dict[str, Any]] = []
    row_ids = idx[0]
    col_ids = idx[1]
    # For each row, mask and gather its entries
    for i in range(batch_size):
        m = row_ids == i
        if torch.count_nonzero(m) == 0:
            rows_out.append({"indices": [], "weights": []})
            continue
        cols_i = col_ids[m].to("cpu")
        vals_i = vals[m].to("cpu")
        rows_out.append({"indices": cols_i.tolist(), "weights": vals_i.tolist()})
    return rows_out

def top_token_contributions(q_row: Dict[str, Any], d_row: Dict[str, Any], k: int) -> List[Dict[str, Any]]:
    """Intersect query/doc indices and score tokens by the product of their weights."""
    q_map = {int(i): float(w) for i, w in zip(q_row.get("indices", []), q_row.get("weights", []))}
    contribs = []
    for i, dw in zip(d_row.get("indices", []), d_row.get("weights", [])):
        i = int(i)
        dw = float(dw)
        qw = q_map.get(i)
        if qw is not None:
            contribs.append((i, qw, dw, qw * dw))
    contribs.sort(key=lambda t: t[3], reverse=True)
    top = contribs[: max(k, 0) or 15]

    out: List[Dict[str, Any]] = []
    for tok_id, qw, dw, c in top:
        try:
            # RobBERT uses RoBERTa/BPE-style tokens (Ġ denotes a leading space)
            tok = tokenizer.convert_ids_to_tokens([tok_id])[0]
            pretty = tok.replace("Ġ", " ").replace("▁", " ")
        except Exception:
            tok = pretty = str(tok_id)
        out.append({
            "token_id": tok_id,
            "token": pretty,
            "query_weight": qw,
            "doc_weight": dw,
            "contribution": c,
        })
    return out

# --------------------------------------------------------------------------------------
# Routes
# --------------------------------------------------------------------------------------
@app.get("/")
async def root():
    return {
        "message": "Dutch SPLADE Embedding API",
        "docs": "https://moimobrian-py-api.hf.space/docs",
        "health": "https://moimobrian-py-api.hf.space/health",
    }

@app.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
    return HealthResponse(
        model_loaded=model is not None,
        model_name=MODEL_ID,
        device=DEVICE,
    )
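# Hypothetical request/response shape for /embeddings (values are illustrative, not real
# model output; `dim` is read from the tensor and is on the order of the ~50k vocab size
# mentioned on the model card):
#   POST /embeddings  {"texts": ["fiets kopen in Amsterdam"], "mode": "query"}
#   -> {"data": [{"indices": [1203, 5842, 30117], "weights": [1.42, 0.97, 0.31]}],
#       "dim": <vocab size>,
#       "info": {"mode": "query", "normalize": true, "max_active_dims": null, "device": "cpu"}}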
@app.post("/embeddings", response_model=EmbeddingsResponse)
async def embeddings(req: EmbeddingsRequest) -> EmbeddingsResponse:
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not req.texts:
        raise HTTPException(status_code=400, detail="'texts' must be a non-empty list")

    max_k = req.max_active_dims or None
    logger.info(f"Processing {len(req.texts)} texts in {req.mode} mode")
    try:
        if req.mode == "query":
            embs = model.encode_query(
                req.texts,
                convert_to_tensor=True,
                device=DEVICE,
                normalize=req.normalize,
                max_active_dims=max_k,
            )
        else:
            embs = model.encode_document(
                req.texts,
                convert_to_tensor=True,
                device=DEVICE,
                normalize=req.normalize,
                max_active_dims=max_k,
            )
        rows = torch_sparse_batch_to_rows(embs)
        # Model card states ~50k dims; we can read the 2nd dimension from the tensor
        dim = int(embs.size(1)) if isinstance(embs, torch.Tensor) else 0
        return EmbeddingsResponse(
            data=[EmbeddingRow(**r) for r in rows],
            dim=dim,
            info={
                "mode": req.mode,
                "normalize": req.normalize,
                "max_active_dims": max_k,
                "device": DEVICE,
            },
        )
    except RuntimeError as e:
        # If anything MPS-related sneaks in, hard-move to CPU and retry once
        msg = str(e)
        if "MPS" in msg or "to_sparse" in msg:
            logger.warning("Encountered MPS/sparse op issue; retrying on CPU.")
            try:
                model.to("cpu")
                if req.mode == "query":
                    embs = model.encode_query(
                        req.texts,
                        convert_to_tensor=True,
                        device="cpu",
                        normalize=req.normalize,
                        max_active_dims=max_k,
                    )
                else:
                    embs = model.encode_document(
                        req.texts,
                        convert_to_tensor=True,
                        device="cpu",
                        normalize=req.normalize,
                        max_active_dims=max_k,
                    )
                rows = torch_sparse_batch_to_rows(embs)
                dim = int(embs.size(1)) if isinstance(embs, torch.Tensor) else 0
                return EmbeddingsResponse(
                    data=[EmbeddingRow(**r) for r in rows],
                    dim=dim,
                    info={
                        "mode": req.mode,
                        "normalize": req.normalize,
                        "max_active_dims": max_k,
                        "device": "cpu",
                        "retry": True,
                    },
                )
            except Exception:
                logger.exception("CPU retry failed")
                raise HTTPException(status_code=500, detail=msg)
        # Unknown runtime error
        logger.exception("Error generating embeddings")
        raise HTTPException(status_code=500, detail=msg)
    except Exception as e:
        logger.exception("Error generating embeddings")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/similarity", response_model=SimilarityResponse)
async def similarity(req: SimilarityRequest) -> SimilarityResponse:
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not req.queries:
        raise HTTPException(status_code=400, detail="'queries' must be a non-empty list")
    if not req.documents:
        raise HTTPException(status_code=400, detail="'documents' must be a non-empty list")

    max_k = req.max_active_dims or None
    try:
        q = model.encode_query(
            req.queries,
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        d = model.encode_document(
            req.documents,
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        scores = model.similarity(q, d).to("cpu")  # [num_queries, num_docs]
        results: List[List[SimilarityHit]] = []
        k = min(req.top_k or 5, len(req.documents))
        for i in range(scores.size(0)):
            vals, idxs = torch.topk(scores[i], k=k)
            q_hits: List[SimilarityHit] = []
            for v, j in zip(vals.tolist(), idxs.tolist()):
                q_hits.append(SimilarityHit(doc_index=j, score=float(v), text=req.documents[j]))
            results.append(q_hits)
        return SimilarityResponse(
            results=results,
            info={
                "normalize": req.normalize,
                "max_active_dims": max_k,
                "device": DEVICE,
            },
        )
    except Exception as e:
        logger.exception("Error computing similarity")
        raise HTTPException(status_code=500, detail=str(e))
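# Reading /explain output: each TokenContribution reports contribution = query_weight *
# doc_weight for a vocabulary id that is active in both encodings. If the encoder's
# similarity is a plain dot product (typical for SPLADE-style models; an assumption here,
# not verified against this checkpoint), summing the contributions over all shared ids
# recovers the reported score.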
@app.post("/explain", response_model=ExplainResponse)
async def explain(req: ExplainRequest) -> ExplainResponse:
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model/tokenizer not loaded")

    max_k = req.max_active_dims or None
    try:
        q = model.encode_query(
            [req.query],
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        d = model.encode_document(
            [req.document],
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        score = float(model.similarity(q, d)[0, 0].item())
        q_row = torch_sparse_batch_to_rows(q)[0]
        d_row = torch_sparse_batch_to_rows(d)[0]
        tokens = top_token_contributions(q_row, d_row, req.top_k_tokens)
        return ExplainResponse(
            score=score,
            top_tokens=[TokenContribution(**t) for t in tokens],
            info={
                "normalize": req.normalize,
                "max_active_dims": max_k,
                "device": DEVICE,
            },
        )
    except Exception as e:
        logger.exception("Error explaining match")
        raise HTTPException(status_code=500, detail=str(e))

# --------------------------------------------------------------------------------------
# Local dev runner
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )
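# Example smoke test against a locally running instance (a sketch only; assumes the
# server is up on port 8000 and that `httpx` is installed, which is not a dependency
# of this module):
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:8000/similarity",
#       json={"queries": ["fiets"], "documents": ["Ik koop een fiets.", "Het regent."], "top_k": 1},
#   )
#   print(resp.json()["results"][0][0])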