File size: 941 Bytes
8a0c27f
 
 
63a3ec4
8a0c27f
 
 
 
 
 
 
 
 
63a3ec4
8a0c27f
 
 
 
 
 
 
63a3ec4
 
 
 
8a0c27f
63a3ec4
 
 
 
 
 
8a0c27f
63a3ec4
 
 
 
8a0c27f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import dill
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# from ..core.search_engine import PromptSearchEngine


class Query(BaseModel):
    """Request body accepted by the ``/search/`` endpoint.

    ``prompt`` is required; ``n`` is optional and defaults to 5.
    """

    # Text to search for; passed as the first argument to ``most_similar``.
    prompt: str
    # Passed as the second argument to ``most_similar``; presumably the
    # number of results to return — confirm against the engine.
    n: int = 5


app = FastAPI()

# Deserialize the prebuilt search engine once, at import time, so every
# request handled by this process reuses the same instance.
# NOTE(review): dill (like pickle) can execute arbitrary code while loading,
# so engine.pickle must come from a trusted source. The path is resolved
# against the process's current working directory, not this module's folder.
with open('./engine.pickle', mode='rb') as pickle_handle:
    serialized_engine = pickle_handle.read()

prompt_search_engine = dill.loads(serialized_engine)


@app.post("/search/")
async def search(query: Query):
    """
    Return the prompts most similar to ``query.prompt``.

    Args:
        query: Request body carrying the search ``prompt`` and ``n``, which is
            forwarded to ``prompt_search_engine.most_similar``.

    Returns:
        A list of ``{"score": float, "description": ...}`` dicts, one per
        ``(score, description)`` pair produced by ``most_similar``, in the
        order the engine yields them.

    Raises:
        HTTPException: status 500, with the error message as ``detail``, if
            input validation, the search engine, or result formatting raises.
    """
    try:
        # Pydantic validates the body before this runs; these checks are a
        # defensive second line in case the model is constructed elsewhere.
        if not isinstance(query.prompt, str):
            raise ValueError("Prompt must be a string")

        if not isinstance(query.n, int):
            raise ValueError("n must be an integer")

        results = prompt_search_engine.most_similar(query.prompt, query.n)
        # float() normalizes scores (e.g. numpy scalars) into JSON-serializable
        # Python floats.
        return [
            {"score": float(score), "description": desc}
            for score, desc in results
        ]

    except Exception as e:
        # Chain the original exception so server-side tracebacks keep the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e