moimobrian committed
Commit bbcbb55 · 1 Parent(s): d3910a8

Initial deployment

Files changed (3)
  1. Dockerfile +17 -0
  2. main.py +441 -0
  3. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,17 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ COPY requirements.txt .
+
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY main.py .
+
+ EXPOSE 7860
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,441 @@
+ import os
+ import logging
+ from contextlib import asynccontextmanager
+ from typing import List, Optional, Literal, Dict, Any
+
+ import torch
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel, ConfigDict
+ from sentence_transformers import SparseEncoder
+ from transformers import AutoTokenizer
+
+ # --------------------------------------------------------------------------------------
+ # Logging
+ # --------------------------------------------------------------------------------------
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger("main")
+
+ # --------------------------------------------------------------------------------------
+ # Device selection — intentionally NEVER choose MPS for SPLADE due to sparse-op gaps
+ # --------------------------------------------------------------------------------------
+
+ def choose_device() -> str:
+     if torch.cuda.is_available():
+         return "cuda"
+     # Avoid MPS for SPLADE (missing sparse ops). Default to CPU instead.
+     return "cpu"
+
+ DEVICE = choose_device()
+ logger.info(f"Selected device: {DEVICE}")
+
+ # --------------------------------------------------------------------------------------
+ # Model loading
+ # --------------------------------------------------------------------------------------
+ MODEL_ID = "sparse-encoder/splade-robbert-dutch-base-v1"
+
+
+ def load_sparse_encoder(model_id: str, device: str) -> SparseEncoder:
+     """Load SparseEncoder. Prefer safetensors when available, but fall back to .bin.
+     Torch >= 2.6 is required by Transformers to load .bin safely.
+     """
+     # Do NOT force safetensors globally; some repos only publish .bin
+     os.environ.pop("TRANSFORMERS_USE_SAFETENSORS", None)
+     try:
+         logger.info(f"Loading Dutch SPLADE model on {device}...")
+         m = SparseEncoder(model_id, device=device, model_kwargs={"use_safetensors": True})
+         return m
+     except OSError as e:
+         msg = str(e)
+         if "does not appear to have a file named model.safetensors" in msg:
+             logger.info("No safetensors in repo; retrying with .bin weights.")
+             return SparseEncoder(model_id, device=device)
+         raise
+
+
+ model: Optional[SparseEncoder] = None
+ # Tokenizer for mapping vocab ids -> readable tokens in explanations
+ tokenizer: Optional[AutoTokenizer] = None
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     global model, tokenizer
+     try:
+         model = load_sparse_encoder(MODEL_ID, DEVICE)
+         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+         logger.info("Model & tokenizer loaded.")
+         yield
+     except Exception as e:
+         logger.error(f"Failed to load model: {e}")
+         raise
+     finally:
+         # Allow GC to clean up if server stops
+         pass
+
+
+ app = FastAPI(title="Sparse Embedding API", lifespan=lifespan)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # --------------------------------------------------------------------------------------
+ # Schemas
+ # --------------------------------------------------------------------------------------
+
+
+ class HealthResponse(BaseModel):
+     # Pydantic v2 warns about names starting with model_; allow them explicitly
+     model_config = ConfigDict(protected_namespaces=())
+
+     model_loaded: bool
+     model_name: str
+     device: str
+
+
+ class EmbeddingsRequest(BaseModel):
+     texts: List[str]
+     mode: Literal["query", "document"] = "query"
+     normalize: bool = True
+     # Keep payloads light; 0/None means no cap
+     max_active_dims: Optional[int] = 0
+
+
+ class EmbeddingRow(BaseModel):
+     indices: List[int]
+     weights: List[float]
+
+
+ class EmbeddingsResponse(BaseModel):
+     data: List[EmbeddingRow]
+     dim: int
+     info: Dict[str, Any]
+
+
+ # --- Similarity API ---
+ class SimilarityRequest(BaseModel):
+     queries: List[str]
+     documents: List[str]
+     normalize: bool = True
+     max_active_dims: Optional[int] = 0
+     top_k: Optional[int] = 5
+
+
+ class SimilarityHit(BaseModel):
+     doc_index: int
+     score: float
+     text: str
+
+
+ class SimilarityResponse(BaseModel):
+     results: List[List[SimilarityHit]]  # one list per query
+     info: Dict[str, Any]
+
+
+ # --- Explain API ---
+ class TokenContribution(BaseModel):
+     token_id: int
+     token: str
+     query_weight: float
+     doc_weight: float
+     contribution: float
+
+
+ class ExplainRequest(BaseModel):
+     query: str
+     document: str
+     normalize: bool = True
+     max_active_dims: Optional[int] = 0
+     top_k_tokens: int = 15
+
+
+ class ExplainResponse(BaseModel):
+     score: float
+     top_tokens: List[TokenContribution]
+     info: Dict[str, Any]
+
+
+ # --------------------------------------------------------------------------------------
+ # Helpers
+ # --------------------------------------------------------------------------------------
+
+
+ def torch_sparse_batch_to_rows(t: torch.Tensor) -> List[Dict[str, Any]]:
+     """Convert a 2D torch sparse tensor [batch, dim] to a list of {indices, weights} per row."""
+     if not isinstance(t, torch.Tensor):
+         raise TypeError("Expected a torch.Tensor from SparseEncoder")
+     if not t.is_sparse:
+         # Dense fallback (shouldn't happen with SparseEncoder). Convert per-row.
+         t = t.to("cpu")
+         rows = []
+         for r in t:
+             nz = torch.nonzero(r, as_tuple=True)[0]
+             rows.append({"indices": nz.tolist(), "weights": r[nz].tolist()})
+         return rows
+
+     # COO expected; coalesce and split by row
+     t = t.coalesce()  # merge duplicates
+     idx = t.indices()  # [2, nnz]
+     vals = t.values()  # [nnz]
+     batch_size = t.size(0)
+
+     rows_out: List[Dict[str, Any]] = []
+     row_ids = idx[0]
+     col_ids = idx[1]
+
+     # For each row, mask and gather its entries
+     for i in range(batch_size):
+         m = row_ids == i
+         if torch.count_nonzero(m) == 0:
+             rows_out.append({"indices": [], "weights": []})
+             continue
+         cols_i = col_ids[m].to("cpu")
+         vals_i = vals[m].to("cpu")
+         rows_out.append({"indices": cols_i.tolist(), "weights": vals_i.tolist()})
+     return rows_out
+
+
+ def top_token_contributions(q_row: Dict[str, Any], d_row: Dict[str, Any], k: int) -> List[Dict[str, Any]]:
+     """Intersect query/doc indices and score tokens by the product of their weights."""
+     q_map = {int(i): float(w) for i, w in zip(q_row.get("indices", []), q_row.get("weights", []))}
+     contribs = []
+     for i, dw in zip(d_row.get("indices", []), d_row.get("weights", [])):
+         i = int(i)
+         dw = float(dw)
+         qw = q_map.get(i)
+         if qw is not None:
+             contribs.append((i, qw, dw, qw * dw))
+     contribs.sort(key=lambda t: t[3], reverse=True)
+     top = contribs[: k if k > 0 else 15]  # fall back to 15 tokens when k <= 0
+     out: List[Dict[str, Any]] = []
+     for tok_id, qw, dw, c in top:
+         try:
+             # RobBERT uses RoBERTa/BPE-style tokens (Ġ denotes a leading space)
+             tok = tokenizer.convert_ids_to_tokens([tok_id])[0]
+             pretty = tok.replace("Ġ", " ").replace("▁", " ")
+         except Exception:
+             tok = pretty = str(tok_id)
+         out.append({
+             "token_id": tok_id,
+             "token": pretty,
+             "query_weight": qw,
+             "doc_weight": dw,
+             "contribution": c,
+         })
+     return out
+
+
+ # --------------------------------------------------------------------------------------
+ # Routes
+ # --------------------------------------------------------------------------------------
+
+
+ @app.get("/health", response_model=HealthResponse)
+ async def health() -> HealthResponse:
+     return HealthResponse(
+         model_loaded=model is not None,
+         model_name=MODEL_ID,
+         device=DEVICE,
+     )
+
+
+ @app.post("/embeddings", response_model=EmbeddingsResponse)
+ async def embeddings(req: EmbeddingsRequest) -> EmbeddingsResponse:
+     if model is None:
+         raise HTTPException(status_code=503, detail="Model not loaded")
+     if not req.texts:
+         raise HTTPException(status_code=400, detail="'texts' must be a non-empty list")
+
+     max_k = req.max_active_dims or None
+
+     logger.info(f"Processing {len(req.texts)} texts in {req.mode} mode")
+
+     try:
+         if req.mode == "query":
+             embs = model.encode_query(
+                 req.texts,
+                 convert_to_tensor=True,
+                 device=DEVICE,
+                 normalize=req.normalize,
+                 max_active_dims=max_k,
+             )
+         else:
+             embs = model.encode_document(
+                 req.texts,
+                 convert_to_tensor=True,
+                 device=DEVICE,
+                 normalize=req.normalize,
+                 max_active_dims=max_k,
+             )
+
+         rows = torch_sparse_batch_to_rows(embs)
+         # Model card states ~50k dims; we can read the 2nd dimension from the tensor
+         dim = int(embs.size(1)) if isinstance(embs, torch.Tensor) else 0
+
+         return EmbeddingsResponse(
+             data=[EmbeddingRow(**r) for r in rows],
+             dim=dim,
+             info={
+                 "mode": req.mode,
+                 "normalize": req.normalize,
+                 "max_active_dims": max_k,
+                 "device": DEVICE,
+             },
+         )
+     except RuntimeError as e:
+         # If anything MPS-related sneaks in, hard-move to CPU and retry once
+         msg = str(e)
+         if "MPS" in msg or "to_sparse" in msg:
+             logger.warning("Encountered MPS/sparse op issue; retrying on CPU.")
+             try:
+                 model.to("cpu")
+                 if req.mode == "query":
+                     embs = model.encode_query(
+                         req.texts,
+                         convert_to_tensor=True,
+                         device="cpu",
+                         normalize=req.normalize,
+                         max_active_dims=max_k,
+                     )
+                 else:
+                     embs = model.encode_document(
+                         req.texts,
+                         convert_to_tensor=True,
+                         device="cpu",
+                         normalize=req.normalize,
+                         max_active_dims=max_k,
+                     )
+                 rows = torch_sparse_batch_to_rows(embs)
+                 dim = int(embs.size(1)) if isinstance(embs, torch.Tensor) else 0
+                 return EmbeddingsResponse(
+                     data=[EmbeddingRow(**r) for r in rows],
+                     dim=dim,
+                     info={
+                         "mode": req.mode,
+                         "normalize": req.normalize,
+                         "max_active_dims": max_k,
+                         "device": "cpu",
+                         "retry": True,
+                     },
+                 )
+             except Exception:
+                 logger.exception("CPU retry failed")
+                 raise HTTPException(status_code=500, detail=msg)
+         # Unknown runtime error
+         logger.exception("Error generating embeddings")
+         raise HTTPException(status_code=500, detail=msg)
+     except Exception as e:
+         logger.exception("Error generating embeddings")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/similarity", response_model=SimilarityResponse)
+ async def similarity(req: SimilarityRequest) -> SimilarityResponse:
+     if model is None:
+         raise HTTPException(status_code=503, detail="Model not loaded")
+     if not req.queries:
+         raise HTTPException(status_code=400, detail="'queries' must be a non-empty list")
+     if not req.documents:
+         raise HTTPException(status_code=400, detail="'documents' must be a non-empty list")
+
+     max_k = req.max_active_dims or None
+
+     try:
+         q = model.encode_query(
+             req.queries,
+             convert_to_tensor=True,
+             device=DEVICE,
+             normalize=req.normalize,
+             max_active_dims=max_k,
+         )
+         d = model.encode_document(
+             req.documents,
+             convert_to_tensor=True,
+             device=DEVICE,
+             normalize=req.normalize,
+             max_active_dims=max_k,
+         )
+         scores = model.similarity(q, d).to("cpu")  # [num_queries, num_docs]
+
+         results: List[List[SimilarityHit]] = []
+         k = min(req.top_k or 5, len(req.documents))
+         for i in range(scores.size(0)):
+             vals, idxs = torch.topk(scores[i], k=k)
+             q_hits: List[SimilarityHit] = []
+             for v, j in zip(vals.tolist(), idxs.tolist()):
+                 q_hits.append(SimilarityHit(doc_index=j, score=float(v), text=req.documents[j]))
+             results.append(q_hits)
+
+         return SimilarityResponse(
+             results=results,
+             info={
+                 "normalize": req.normalize,
+                 "max_active_dims": max_k,
+                 "device": DEVICE,
+             },
+         )
+     except Exception as e:
+         logger.exception("Error computing similarity")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/explain", response_model=ExplainResponse)
+ async def explain(req: ExplainRequest) -> ExplainResponse:
+     if model is None or tokenizer is None:
+         raise HTTPException(status_code=503, detail="Model/tokenizer not loaded")
+
+     max_k = req.max_active_dims or None
+
+     try:
+         q = model.encode_query(
+             [req.query],
+             convert_to_tensor=True,
+             device=DEVICE,
+             normalize=req.normalize,
+             max_active_dims=max_k,
+         )
+         d = model.encode_document(
+             [req.document],
+             convert_to_tensor=True,
+             device=DEVICE,
+             normalize=req.normalize,
+             max_active_dims=max_k,
+         )
+         score = float(model.similarity(q, d)[0, 0].item())
+
+         q_row = torch_sparse_batch_to_rows(q)[0]
+         d_row = torch_sparse_batch_to_rows(d)[0]
+         tokens = top_token_contributions(q_row, d_row, req.top_k_tokens)
+
+         return ExplainResponse(
+             score=score,
+             top_tokens=[TokenContribution(**t) for t in tokens],
+             info={
+                 "normalize": req.normalize,
+                 "max_active_dims": max_k,
+                 "device": DEVICE,
+             },
+         )
+     except Exception as e:
+         logger.exception("Error explaining match")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ # --------------------------------------------------------------------------------------
+ # Local dev runner
+ # --------------------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(
+         "main:app",
+         host="0.0.0.0",
+         port=8000,
+         reload=True,
+         log_level="info",
+     )
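
For reference, a minimal client sketch (not part of the commit) that exercises the three endpoints using only the standard library. It assumes the dev runner is up on port 8000 (use 7860 for the Docker image); the Dutch example strings are illustrative:

import json
from urllib import request

BASE = "http://localhost:8000"

def post(path: str, payload: dict) -> dict:
    # POST a JSON payload and decode the JSON response
    req = request.Request(
        BASE + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        return json.load(resp)

# Sparse embedding: one {indices, weights} row per input text, plus the vocab dim
emb = post("/embeddings", {"texts": ["fiets kopen"], "mode": "query"})
print(emb["dim"], len(emb["data"][0]["indices"]))

# Ranked retrieval: each hit carries doc_index, score, and the original text
sim = post("/similarity", {
    "queries": ["fiets kopen"],
    "documents": ["Tweedehands fietsen te koop.", "Het weer wordt morgen beter."],
    "top_k": 2,
})
print(sim["results"][0][0])

# Explanation: top overlapping tokens scored by query_weight * doc_weight
exp = post("/explain", {"query": "fiets kopen", "document": "Tweedehands fietsen te koop."})
print(exp["score"], exp["top_tokens"][:3])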
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ numpy==1.26.4
+ fastapi==0.115.0
+ uvicorn[standard]==0.32.0
+ sentence-transformers==5.0.0
+ torch>=2.6.0
+ scipy==1.13.1
+ pydantic==2.9.2
+ python-multipart==0.0.9