|
from fastapi import FastAPI, Query |
|
from pydantic import BaseModel |
|
from typing import List |
|
from simcse import SimCSE |
|
import os |
|
|
|
app = FastAPI()

# Sentence corpus to index. The path is relative to the current working
# directory, so the server must be started from the directory containing
# ./static/.
sentence_path = os.path.join("./static/", "model_names.txt")

# Two SimCSE embedders built from the same checkpoint over the same corpus.
# NOTE(review): they differ only in the index backend selected below, so
# /search merges two views of one corpus and the model weights are held in
# memory twice -- confirm this duplication is intended.
embedder0 = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased", device="cpu")
embedder1 = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased", device="cpu")

# The second parameter of SimCSE.build_index is ``use_faiss``; it is passed by
# keyword so the choice of index backend is explicit. embedder0 uses the
# brute-force index. embedder1 requests a FAISS index (SimCSE is documented
# to fall back to brute force when faiss is unavailable -- verify for the
# pinned version).
embedder0.build_index(sentence_path, use_faiss=False)
embedder1.build_index(sentence_path, use_faiss=True)
|
|
|
|
|
class SearchResult(BaseModel):
    """One hit in the list returned by the ``/search`` endpoint."""

    # A sentence from the indexed corpus that matched the prompt.
    sentence: str
    # Similarity score that the embedder reported for this sentence.
    score: float
|
|
|
def _merge_unique(*result_lists):
    """Merge ``(sentence, score)`` hit lists into response-ready dicts.

    All hits are ordered by descending score, and each sentence is kept only
    once, at its highest score. Because ``list.sort`` is stable, equal scores
    keep the hit from the earliest list.

    Args:
        *result_lists: Iterables of ``(sentence, score)`` pairs.

    Returns:
        A list of ``{"sentence": ..., "score": float}`` dicts, best first.
    """
    hits = [hit for results in result_lists for hit in results]
    hits.sort(key=lambda hit: hit[1], reverse=True)

    seen = set()
    merged = []
    for sentence, score in hits:
        if sentence in seen:
            # A higher-scoring copy of this sentence was already kept.
            continue
        seen.add(sentence)
        # Scores may arrive as NumPy scalars (assumption about SimCSE's return
        # type). Cast them so the response model and JSON encoder always see a
        # plain Python float.
        merged.append({"sentence": sentence, "score": float(score)})
    return merged


@app.get("/search", response_model=List[SearchResult])
def search(prompt: str = Query(..., description="Input text prompt")):
    """Return corpus sentences similar to *prompt*, best match first.

    Each embedder is asked for at most 5 hits with a similarity threshold of
    0.6. The two hit lists are then merged, sorted by score, and de-duplicated
    by sentence.

    Args:
        prompt: Free-text query, taken from the ``prompt`` query parameter.

    Returns:
        A list of ``{"sentence", "score"}`` dicts, validated by FastAPI
        against ``SearchResult``.
    """
    results0 = embedder0.search(prompt, top_k=5, threshold=0.6)
    results1 = embedder1.search(prompt, top_k=5, threshold=0.6)
    return _merge_unique(results0, results1)
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "demo:app" import string. With the
    # string, uvicorn imports this file a second time as module "demo", which
    # re-runs the module-level SimCSE model loading and index building. It
    # also breaks if the file is not named demo.py. The import-string form is
    # only required for reload/workers, and neither is used here.
    uvicorn.run(app, host="0.0.0.0", port=10001)