sushil3125 committed
Commit 32c0b8d · 1 Parent(s): b6b7d0e
Files changed (1)
  1. app.py +77 -9
app.py CHANGED
@@ -1,26 +1,94 @@

Before (removed lines marked `-`):

```diff
 from fastapi import FastAPI
 from pydantic import BaseModel
 from sentence_transformers import SentenceTransformer
 from typing import List
 
-# Load the pre-trained sentence transformer model
 bge_small_model = SentenceTransformer('BAAI/bge-small-en-v1.5', device="cpu")
 all_mp_net_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device="cpu")
 
-# Initialize FastAPI app
-app = FastAPI()
 
-# Request body model
 class TextInput(BaseModel):
-    text: List[str]  # List of sentences or text data
     model_name: str
 
-# Route to calculate embeddings
 @app.post("/get-embedding/")
 async def get_embedding(input: TextInput):
-    # Generate embeddings using the sentence transformer model
-    if input.model_name == "BM":
-        embeddings = all_mp_net_model.encode(input.text)
     else:
         embeddings = bge_small_model.encode(input.text)
         return {"embeddings": embeddings.tolist()}
```
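As the unmarked context lines indicate, the original version's only `return` sits inside the `else` branch, so a `"BM"` request computed MPNet embeddings but fell through and the endpoint responded with null; the revision below routes `"BM"` before this branch is reached. For reference, a minimal client sketch for the endpoint, assuming the app is served locally with `uvicorn app:app --port 8000` (the `requests` client, host, and port are assumptions, not part of this commit):

```python
import requests

# Hypothetical local call; adjust the URL to wherever app.py is deployed.
resp = requests.post(
    "http://localhost:8000/get-embedding/",
    json={"text": ["hello world"], "model_name": "BM"},
)
print(resp.json())
```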
 
After (added lines marked `+`):

```diff
 from fastapi import FastAPI
 from pydantic import BaseModel
 from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, AutoModelForMaskedLM
 from typing import List
+import torch
+from functools import lru_cache
+import logging
 
+# 🔧 Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# 🚀 Initialize FastAPI app
+app = FastAPI()
+logger.info("Starting FastAPI application")
+
+# 🔌 Load SentenceTransformer models
+logger.info("Loading BGE small model...")
 bge_small_model = SentenceTransformer('BAAI/bge-small-en-v1.5', device="cpu")
+logger.info("Loaded BGE small model")
+
+logger.info("Loading All-MPNet model...")
 all_mp_net_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device="cpu")
+logger.info("Loaded All-MPNet model")
 
+# 🔌 Load SPLADE model
+logger.info("Loading SPLADE model...")
+SPLADE_MODEL = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-ensembledistil", trust_remote_code=True)
+SPLADE_TOKENIZER = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
+SPLADE_MODEL.eval()
+logger.info("Loaded SPLADE model")
 
+# 📦 Request and response models
 class TextInput(BaseModel):
+    text: List[str]
     model_name: str
 
+class SparseVector(BaseModel):
+    indices: List[int]
+    values: List[float]
+
+# 🧠 LRU cacheable versions
+@lru_cache(maxsize=1000)
+def encode_dense_cached(model_name: str, text: str):
+    logger.info(f"Encoding dense text with model {model_name}: {text}")
+    if model_name == "BM":
+        embedding = all_mp_net_model.encode([text])[0].tolist()
+    else:
+        embedding = bge_small_model.encode([text])[0].tolist()
+    logger.info("Finished encoding dense text")
+    return embedding
+
+@lru_cache(maxsize=1000)
+def encode_splade_cached(text: str) -> SparseVector:
+    logger.info(f"Encoding SPLADE sparse vector: {text}")
+    inputs = SPLADE_TOKENIZER(text, return_tensors="pt", truncation=True)
+    with torch.no_grad():
+        outputs = SPLADE_MODEL(**inputs)
+
+    logits = outputs.logits[0]
+    relu_log = torch.log1p(torch.relu(logits))
+    nonzero = relu_log.nonzero(as_tuple=False)
+
+    if nonzero.shape[0] == 0:
+        logger.info("No non-zero values found in SPLADE output")
+        return SparseVector(indices=[], values=[])
+
+    vocab_indices = nonzero[:, 1]
+    values = relu_log[nonzero[:, 0], nonzero[:, 1]]
+
+    logger.info(f"SPLADE encoding complete with {len(vocab_indices)} dimensions")
+    return SparseVector(
+        indices=vocab_indices.cpu().numpy().tolist(),
+        values=values.cpu().numpy().tolist()
+    )
+
+# 🚀 Main endpoint
 @app.post("/get-embedding/")
 async def get_embedding(input: TextInput):
+    logger.info(f"Received request with model: {input.model_name}, texts: {input.text}")
+
+    model_key = input.model_name.upper()
+    if model_key in {"BM", "BG"}:
+        embeddings = [encode_dense_cached(model_key, t) for t in input.text]
+        logger.info(f"Returning dense embeddings for {len(embeddings)} texts")
+        return {"type": "dense", "embeddings": embeddings}
+    elif model_key == "SPLADE":
+        sparse_vecs = [encode_splade_cached(t).model_dump() for t in input.text]
+        logger.info(f"Returning sparse embeddings for {len(sparse_vecs)} texts")
+        return {"type": "sparse", "embeddings": sparse_vecs}
     else:
         embeddings = bge_small_model.encode(input.text)
         return {"embeddings": embeddings.tolist()}
```
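One caveat in `encode_splade_cached` above: `relu_log.nonzero` runs over the full `(sequence_length, vocab_size)` logit matrix, so the returned `indices` are vocabulary ids that can repeat, once per token position where they fire. SPLADE as usually described max-pools over token positions so that each vocabulary id contributes at most one weight. A hedged sketch of that variant, reusing the `SPLADE_MODEL` and `SPLADE_TOKENIZER` globals from app.py (the function name is ours, not part of the commit):

```python
import torch

def encode_splade_max_pooled(text: str) -> dict:
    # Sketch of the max-pooled SPLADE formulation: take log(1 + relu(logits)),
    # zero out padding positions via the attention mask (a no-op for a single
    # unpadded sequence, but safe), then max over token positions so each
    # vocabulary id appears at most once in the sparse vector.
    inputs = SPLADE_TOKENIZER(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = SPLADE_MODEL(**inputs).logits[0]        # (seq_len, vocab_size)
    mask = inputs["attention_mask"][0].unsqueeze(-1)     # (seq_len, 1)
    pooled, _ = torch.max(torch.log1p(torch.relu(logits)) * mask, dim=0)
    nz = pooled.nonzero(as_tuple=False).squeeze(-1)
    return {"indices": nz.tolist(), "values": pooled[nz].tolist()}
```

With the committed code, a sparse request is `{"text": ["hello world"], "model_name": "SPLADE"}` and the response has the shape `{"type": "sparse", "embeddings": [{"indices": [...], "values": [...]}]}`.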