File size: 1,851 Bytes
8a0c27f
 
 
19b2dc7
51a3a8c
 
8a0c27f
 
 
 
 
 
 
b652e4e
8a0c27f
 
 
 
 
 
 
63a3ec4
b652e4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63a3ec4
 
8a0c27f
63a3ec4
 
 
 
 
 
8a0c27f
b652e4e
 
 
63a3ec4
 
 
8a0c27f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import dill
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from core.search_engine import PromptSearchEngine


class Query(BaseModel):
    """Request body accepted by the ``/search/`` endpoint.

    Attributes:
        prompt: Text whose most similar prompts should be looked up.
        n: How many similar prompts to return (defaults to 5).
    """

    prompt: str  # query text forwarded to the search engine
    n: int = 5  # result count passed straight to most_similar


app = FastAPI()

# Deserialize the pre-built search engine once, at import time, so every
# request handled by this process reuses the same instance.
# NOTE(review): the path is resolved against the current working directory,
# not this module's directory — the server must be started from the folder
# that contains engine.pickle.
# NOTE(review): dill (like pickle) can execute arbitrary code while loading;
# only ever load an engine.pickle produced by a trusted build step.
with open("./engine.pickle", mode="rb") as file:
    serialized_engine: bytes = file.read()

prompt_search_engine = dill.loads(serialized_engine)


@app.post("/search/")
async def search(query: Query):
    """
    Find the most similar prompts to a given query prompt using the pre-loaded PromptSearchEngine.

    This endpoint accepts a query prompt and returns a specified number of the most similar prompts
    from the corpus. It performs the following steps:
    1. Validates the input types.
    2. Uses the pre-loaded PromptSearchEngine to find the most similar prompts.
    3. Formats the results into a list of dictionaries containing the similarity score and prompt text.

    Args:
        query (Query): The query model containing the prompt text and the number of similar prompts to return.

    Returns:
        List[Dict[str, Union[float, str]]]: One dictionary per result, with the similarity score
        under ``"score"`` (cast to ``float``) and the matching prompt under ``"description"``.

    Raises:
        HTTPException: 400 if ``query.prompt`` is not a string or ``query.n`` is not an integer;
            500 if the search engine raises while processing the query (the engine's error text
            is returned as the detail).
    """
    # FastAPI validates the body against ``Query`` before this handler runs, so
    # these checks are a defensive second line. A wrong type is the client's
    # fault, hence 400 rather than the generic 500 used for engine failures.
    if not isinstance(query.prompt, str):
        raise HTTPException(status_code=400, detail="Prompt must be a string")

    if not isinstance(query.n, int):
        raise HTTPException(status_code=400, detail="n must be an integer")

    try:
        # NOTE(review): most_similar is a synchronous call inside an async
        # handler, so it blocks the event loop while it runs. If searches are
        # slow, consider a plain ``def`` endpoint or run_in_threadpool.
        results = prompt_search_engine.most_similar(query.prompt, query.n)
        # Scores are cast to float so numeric types from the engine (e.g.
        # numpy scalars) serialize cleanly to JSON.
        return [
            {"score": float(score), "description": desc} for score, desc in results
        ]
    except Exception as e:
        # Chain the cause so the original traceback is preserved in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e