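"""FastAPI embedding service: dense embeddings via SentenceTransformers
(BAAI/bge-small-en-v1.5 and all-mpnet-base-v2) and sparse embeddings via
SPLADE (naver/splade-cocondenser-ensembledistil), with per-text LRU caching."""
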
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
from typing import List
import torch
from functools import lru_cache
import logging
from datetime import datetime

# 🔧 Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 🚀 Initialize FastAPI app
app = FastAPI()
logger.info("Starting FastAPI application")

# 🔌 Load SentenceTransformer models
logger.info("Loading BGE small model...")
bge_small_model = SentenceTransformer('BAAI/bge-small-en-v1.5', device="cpu")
logger.info("Loaded BGE small model")

logger.info("Loading All-MPNet model...")
all_mp_net_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device="cpu")
logger.info("Loaded All-MPNet model")

# 🔌 Load SPLADE model
logger.info("Loading SPLADE model...")
SPLADE_MODEL = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-ensembledistil", trust_remote_code=True)
SPLADE_TOKENIZER = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
SPLADE_MODEL.eval()
logger.info("Loaded SPLADE model")

# 📦 Request and response models
class TextInput(BaseModel):
    text: List[str]
    model_name: str
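
# Illustrative request body for /get-embedding/ (example values are ours; the
# field names come from TextInput):
#   {"text": ["hello world", "second sentence"], "model_name": "SPLADE"}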

class SparseVector(BaseModel):
    indices: List[int]
    values: List[float]

# 🧠 LRU-cached encoders (cache keys are the exact argument strings)
@lru_cache(maxsize=1000)
def encode_dense_cached(model_name: str, text: str):
    logger.info(f"Encoding dense text with model {model_name}: {text}")
    # "BM" selects all-mpnet-base-v2; any other key uses bge-small-en-v1.5.
    if model_name == "BM":
        embedding = all_mp_net_model.encode([text])[0].tolist()
    else:
        embedding = bge_small_model.encode([text])[0].tolist()
    logger.info("Finished encoding dense text")
    return embedding

@lru_cache(maxsize=1000)
def encode_splade_cached(text: str) -> SparseVector:
    logger.info(f"Encoding SPLADE sparse vector: {text}")
    inputs = SPLADE_TOKENIZER(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = SPLADE_MODEL(**inputs)

    # SPLADE pooling: log-saturated ReLU over the MLM logits, max-pooled across
    # the sequence dimension so each vocabulary id appears at most once.
    logits = outputs.logits[0]                                    # [seq_len, vocab_size]
    weights = torch.log1p(torch.relu(logits)).max(dim=0).values  # [vocab_size]

    nonzero = weights.nonzero(as_tuple=False).squeeze(-1)
    if nonzero.numel() == 0:
        logger.info("No non-zero values found in SPLADE output")
        return SparseVector(indices=[], values=[])

    logger.info(f"SPLADE encoding complete with {nonzero.numel()} dimensions")
    return SparseVector(
        indices=nonzero.cpu().tolist(),
        values=weights[nonzero].cpu().tolist()
    )
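
# Illustrative debugging helper (a sketch, not used by any endpoint; the name
# explain_sparse_vector is ours): maps a SparseVector's vocabulary ids back to
# wordpiece tokens to show which terms SPLADE activated for a given text.
def explain_sparse_vector(vec: SparseVector, top_k: int = 10) -> List[tuple]:
    # Sort (index, weight) pairs by weight, highest first, and keep the top_k.
    pairs = sorted(zip(vec.indices, vec.values), key=lambda p: p[1], reverse=True)[:top_k]
    tokens = SPLADE_TOKENIZER.convert_ids_to_tokens([idx for idx, _ in pairs])
    return list(zip(tokens, [w for _, w in pairs]))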

# 🚀 Main endpoint
@app.post("/get-embedding/")
async def get_embedding(payload: TextInput):
    logger.info(f"Received request with model: {payload.model_name}, texts: {payload.text}")

    model_key = payload.model_name.upper()
    if model_key in {"BM", "BG"}:
        embeddings = [encode_dense_cached(model_key, t) for t in payload.text]
        logger.info(f"Returning dense embeddings for {len(embeddings)} texts")
        return {"type": "dense", "embeddings": embeddings}
    elif model_key == "SPLADE":
        sparse_vecs = [encode_splade_cached(t).model_dump() for t in payload.text]
        logger.info(f"Returning sparse embeddings for {len(sparse_vecs)} texts")
        return {"type": "sparse", "embeddings": sparse_vecs}
    else:
        # Unrecognized model names fall back to the BGE small model.
        logger.warning(f"Unknown model name {payload.model_name!r}; falling back to BGE small")
        embeddings = bge_small_model.encode(payload.text)
        return {"type": "dense", "embeddings": embeddings.tolist()}


@app.get("/status")
async def status():
    logger.info(f"Status API: Server is up and running at {datetime.now()}")
    return {"status": "Server is up and running"}
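
# Example invocation (a sketch; assumes this file is saved as main.py and served
# with `uvicorn main:app --port 8000`; the module name and port are assumptions):
#
#   curl -X POST http://localhost:8000/get-embedding/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": ["hello world"], "model_name": "BG"}'
#
# Expected response shape:
#   {"type": "dense", "embeddings": [[0.01, -0.02, ...]]}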