import os
import logging
from contextlib import asynccontextmanager
from typing import List, Optional, Literal, Dict, Any

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict
from sentence_transformers import SparseEncoder
from transformers import AutoTokenizer
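# Dependency note: SparseEncoder ships with sentence-transformers v5+ (earlier
# releases do not include it); fastapi, uvicorn, torch and transformers are the
# other runtime requirements.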
# --------------------------------------------------------------------------------------
# Logging
# --------------------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("main")
# --------------------------------------------------------------------------------------
# Device selection — intentionally NEVER choose MPS for SPLADE due to sparse-op gaps
# --------------------------------------------------------------------------------------
def choose_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    # Avoid MPS for SPLADE (missing sparse ops). Default to CPU instead.
    return "cpu"


DEVICE = choose_device()
logger.info(f"Selected device: {DEVICE}")
# --------------------------------------------------------------------------------------
# Model loading
# --------------------------------------------------------------------------------------
MODEL_ID = "sparse-encoder/splade-robbert-dutch-base-v1"


def load_sparse_encoder(model_id: str, device: str) -> SparseEncoder:
    """Load SparseEncoder. Prefer safetensors when available, but fall back to .bin.

    Torch >= 2.6 is required by Transformers to load .bin safely.
    """
    # Do NOT force safetensors globally; some repos only publish .bin
    os.environ.pop("TRANSFORMERS_USE_SAFETENSORS", None)
    try:
        logger.info(f"Loading Dutch SPLADE model on {device}...")
        m = SparseEncoder(model_id, device=device, model_kwargs={"use_safetensors": True})
        return m
    except OSError as e:
        msg = str(e)
        if "does not appear to have a file named model.safetensors" in msg:
            logger.info("No safetensors in repo; retrying with .bin weights.")
            return SparseEncoder(model_id, device=device)
        raise
model: Optional[SparseEncoder] = None
# Tokenizer for mapping vocab ids -> readable tokens in explanations
tokenizer: Optional[AutoTokenizer] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    global model, tokenizer
    try:
        model = load_sparse_encoder(MODEL_ID, DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        logger.info("Model & tokenizer loaded.")
        yield
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    finally:
        # Allow GC to clean up if server stops
        pass
app = FastAPI(title="Sparse Embedding API", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --------------------------------------------------------------------------------------
# Schemas
# --------------------------------------------------------------------------------------
class HealthResponse(BaseModel):
    # Pydantic v2 warns about names starting with model_; allow them explicitly
    model_config = ConfigDict(protected_namespaces=())

    model_loaded: bool
    model_name: str
    device: str


class EmbeddingsRequest(BaseModel):
    texts: List[str]
    mode: Literal["query", "document"] = "query"
    normalize: bool = True
    # Keep payloads light; 0/None means no cap
    max_active_dims: Optional[int] = 0


class EmbeddingRow(BaseModel):
    indices: List[int]
    weights: List[float]


class EmbeddingsResponse(BaseModel):
    data: List[EmbeddingRow]
    dim: int
    info: Dict[str, Any]


# --- Similarity API ---
class SimilarityRequest(BaseModel):
    queries: List[str]
    documents: List[str]
    normalize: bool = True
    max_active_dims: Optional[int] = 0
    top_k: Optional[int] = 5


class SimilarityHit(BaseModel):
    doc_index: int
    score: float
    text: str


class SimilarityResponse(BaseModel):
    results: List[List[SimilarityHit]]  # one list per query
    info: Dict[str, Any]


# --- Explain API ---
class TokenContribution(BaseModel):
    token_id: int
    token: str
    query_weight: float
    doc_weight: float
    contribution: float


class ExplainRequest(BaseModel):
    query: str
    document: str
    normalize: bool = True
    max_active_dims: Optional[int] = 0
    top_k_tokens: int = 15


class ExplainResponse(BaseModel):
    score: float
    top_tokens: List[TokenContribution]
    info: Dict[str, Any]
# --------------------------------------------------------------------------------------
# Helpers
# --------------------------------------------------------------------------------------
def torch_sparse_batch_to_rows(t: torch.Tensor) -> List[Dict[str, Any]]:
    """Convert a 2D torch sparse tensor [batch, dim] to a list of {indices, weights} per row."""
    if not isinstance(t, torch.Tensor):
        raise TypeError("Expected a torch.Tensor from SparseEncoder")
    if not t.is_sparse:
        # Dense fallback (shouldn't happen with SparseEncoder). Convert per-row.
        t = t.to("cpu")
        rows = []
        for r in t:
            nz = torch.nonzero(r, as_tuple=True)[0]
            rows.append({"indices": nz.tolist(), "weights": r[nz].tolist()})
        return rows
    # COO expected; coalesce and split by row
    t = t.coalesce()      # merge duplicates
    idx = t.indices()     # [2, nnz]
    vals = t.values()     # [nnz]
    batch_size = t.size(0)
    rows_out: List[Dict[str, Any]] = []
    row_ids = idx[0]
    col_ids = idx[1]
    # For each row, mask and gather its entries
    for i in range(batch_size):
        m = row_ids == i
        if torch.count_nonzero(m) == 0:
            rows_out.append({"indices": [], "weights": []})
            continue
        cols_i = col_ids[m].to("cpu")
        vals_i = vals[m].to("cpu")
        rows_out.append({"indices": cols_i.tolist(), "weights": vals_i.tolist()})
    return rows_out
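# Worked example (hypothetical values): a 1 x 6 sparse tensor with nonzeros 0.7 at
# column 2 and 0.3 at column 5 comes back as
#   [{"indices": [2, 5], "weights": [0.7, 0.3]}]
# i.e. one compact {indices, weights} pair per input text.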
def top_token_contributions(q_row: Dict[str, Any], d_row: Dict[str, Any], k: int) -> List[Dict[str, Any]]:
    """Intersect query/doc indices and score tokens by the product of their weights."""
    q_map = {int(i): float(w) for i, w in zip(q_row.get("indices", []), q_row.get("weights", []))}
    contribs = []
    for i, dw in zip(d_row.get("indices", []), d_row.get("weights", [])):
        i = int(i)
        dw = float(dw)
        qw = q_map.get(i)
        if qw is not None:
            contribs.append((i, qw, dw, qw * dw))
    contribs.sort(key=lambda t: t[3], reverse=True)
    # Keep the k strongest contributions; fall back to 15 when k is zero or negative
    top = contribs[:k] if k > 0 else contribs[:15]
    out: List[Dict[str, Any]] = []
    for tok_id, qw, dw, c in top:
        try:
            # RobBERT uses RoBERTa/BPE-style tokens (Ġ denotes a leading space)
            tok = tokenizer.convert_ids_to_tokens([tok_id])[0]
            pretty = tok.replace("Ġ", " ").replace("▁", " ")
        except Exception:
            tok = pretty = str(tok_id)
        out.append({
            "token_id": tok_id,
            "token": pretty,
            "query_weight": qw,
            "doc_weight": dw,
            "contribution": c,
        })
    return out
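# Note: SPLADE similarity is a dot product over vocabulary dimensions, so each
# overlapping token contributes query_weight * doc_weight; summed over all
# overlaps (before top-k truncation) this should reproduce the score reported
# by /explain, assuming the model's similarity function is the default dot
# product.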
# --------------------------------------------------------------------------------------
# Routes (decorator paths below are assumed)
# --------------------------------------------------------------------------------------
@app.get("/health")
async def health() -> HealthResponse:
    return HealthResponse(
        model_loaded=model is not None,
        model_name=MODEL_ID,
        device=DEVICE,
    )
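# Example (assumes a local server on port 8000):
#   curl http://localhost:8000/health
# returns {"model_loaded": ..., "model_name": ..., "device": ...} per HealthResponse.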
@app.post("/embeddings")
async def embeddings(req: EmbeddingsRequest) -> EmbeddingsResponse:
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not req.texts:
        raise HTTPException(status_code=400, detail="'texts' must be a non-empty list")
    max_k = req.max_active_dims or None
    logger.info(f"Processing {len(req.texts)} texts in {req.mode} mode")
    try:
        if req.mode == "query":
            embs = model.encode_query(
                req.texts,
                convert_to_tensor=True,
                device=DEVICE,
                normalize=req.normalize,
                max_active_dims=max_k,
            )
        else:
            embs = model.encode_document(
                req.texts,
                convert_to_tensor=True,
                device=DEVICE,
                normalize=req.normalize,
                max_active_dims=max_k,
            )
        rows = torch_sparse_batch_to_rows(embs)
        # Model card states ~50k dims; read the 2nd dimension from the tensor
        dim = int(embs.size(1)) if isinstance(embs, torch.Tensor) else 0
        return EmbeddingsResponse(
            data=[EmbeddingRow(**r) for r in rows],
            dim=dim,
            info={
                "mode": req.mode,
                "normalize": req.normalize,
                "max_active_dims": max_k,
                "device": DEVICE,
            },
        )
    except RuntimeError as e:
        # If anything MPS-related sneaks in, hard-move to CPU and retry once
        msg = str(e)
        if "MPS" in msg or "to_sparse" in msg:
            logger.warning("Encountered MPS/sparse op issue; retrying on CPU.")
            try:
                model.to("cpu")
                if req.mode == "query":
                    embs = model.encode_query(
                        req.texts,
                        convert_to_tensor=True,
                        device="cpu",
                        normalize=req.normalize,
                        max_active_dims=max_k,
                    )
                else:
                    embs = model.encode_document(
                        req.texts,
                        convert_to_tensor=True,
                        device="cpu",
                        normalize=req.normalize,
                        max_active_dims=max_k,
                    )
                rows = torch_sparse_batch_to_rows(embs)
                dim = int(embs.size(1)) if isinstance(embs, torch.Tensor) else 0
                return EmbeddingsResponse(
                    data=[EmbeddingRow(**r) for r in rows],
                    dim=dim,
                    info={
                        "mode": req.mode,
                        "normalize": req.normalize,
                        "max_active_dims": max_k,
                        "device": "cpu",
                        "retry": True,
                    },
                )
            except Exception:
                logger.exception("CPU retry failed")
                raise HTTPException(status_code=500, detail=msg)
        # Unknown runtime error
        logger.exception("Error generating embeddings")
        raise HTTPException(status_code=500, detail=msg)
    except Exception as e:
        logger.exception("Error generating embeddings")
        raise HTTPException(status_code=500, detail=str(e))
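# Example request (hypothetical payload, assuming a local server on port 8000):
#   curl -X POST http://localhost:8000/embeddings \
#     -H "Content-Type: application/json" \
#     -d '{"texts": ["Wat is de hoofdstad van Nederland?"], "mode": "query"}'
# The response holds one {indices, weights} row per input text plus the vocab size in "dim".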
@app.post("/similarity")
async def similarity(req: SimilarityRequest) -> SimilarityResponse:
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not req.queries:
        raise HTTPException(status_code=400, detail="'queries' must be a non-empty list")
    if not req.documents:
        raise HTTPException(status_code=400, detail="'documents' must be a non-empty list")
    max_k = req.max_active_dims or None
    try:
        q = model.encode_query(
            req.queries,
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        d = model.encode_document(
            req.documents,
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        scores = model.similarity(q, d).to("cpu")  # [num_queries, num_docs]
        results: List[List[SimilarityHit]] = []
        k = min(req.top_k or 5, len(req.documents))
        for i in range(scores.size(0)):
            vals, idxs = torch.topk(scores[i], k=k)
            q_hits: List[SimilarityHit] = []
            for v, j in zip(vals.tolist(), idxs.tolist()):
                q_hits.append(SimilarityHit(doc_index=j, score=float(v), text=req.documents[j]))
            results.append(q_hits)
        return SimilarityResponse(
            results=results,
            info={
                "normalize": req.normalize,
                "max_active_dims": max_k,
                "device": DEVICE,
            },
        )
    except Exception as e:
        logger.exception("Error computing similarity")
        raise HTTPException(status_code=500, detail=str(e))
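# Example request (hypothetical payload): rank two documents for one query.
#   curl -X POST http://localhost:8000/similarity \
#     -H "Content-Type: application/json" \
#     -d '{"queries": ["fiets kopen"], "documents": ["Tweedehands fietsen te koop.", "Het weerbericht voor morgen."], "top_k": 2}'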
@app.post("/explain")
async def explain(req: ExplainRequest) -> ExplainResponse:
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model/tokenizer not loaded")
    max_k = req.max_active_dims or None
    try:
        q = model.encode_query(
            [req.query],
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        d = model.encode_document(
            [req.document],
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        score = float(model.similarity(q, d)[0, 0].item())
        q_row = torch_sparse_batch_to_rows(q)[0]
        d_row = torch_sparse_batch_to_rows(d)[0]
        tokens = top_token_contributions(q_row, d_row, req.top_k_tokens)
        return ExplainResponse(
            score=score,
            top_tokens=[TokenContribution(**t) for t in tokens],
            info={
                "normalize": req.normalize,
                "max_active_dims": max_k,
                "device": DEVICE,
            },
        )
    except Exception as e:
        logger.exception("Error explaining match")
        raise HTTPException(status_code=500, detail=str(e))
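# Example request (hypothetical payload): see which tokens drive the match.
#   curl -X POST http://localhost:8000/explain \
#     -H "Content-Type: application/json" \
#     -d '{"query": "fiets kopen", "document": "Tweedehands fietsen te koop.", "top_k_tokens": 5}'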
# --------------------------------------------------------------------------------------
# Local dev runner
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )