File size: 1,425 Bytes
85c1145
 
 
 
 
 
 
 
 
 
 
 
 
 
4e223f2
85c1145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import json
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from promptSearchEngine import PromptSearchEngine
from vectorizer import Vectorizer
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Hugging Face dataset whose "Prompt" column becomes the searchable corpus.
DATASET = "Gustavosta/Stable-Diffusion-Prompts"
# sentence-transformers checkpoint handed to the search engine for embedding.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"

# One-time initialisation at import: every request handler below reuses this
# single engine instance. Only the first 1% of the "train" split is loaded —
# NOTE(review): presumably to keep start-up fast; confirm this is intentional.
model = SentenceTransformer(EMBEDDING_MODEL)
dataset = load_dataset(path=DATASET, split="train[:1%]")
promptSearchEngine = PromptSearchEngine(dataset["Prompt"], model)

class SearchRequest(BaseModel):
    """JSON body accepted by ``POST /search``."""

    # Free-text query matched against the stored prompts.
    query: str
    # Number of results to return; defaults to 5. The annotation also admits an
    # explicit JSON null, so code consuming this model must handle ``None``.
    n: int | None = 5

# The FastAPI application object; the @app.get / @app.post decorators in this
# module register the HTTP endpoints on it.
app = FastAPI()

@app.get("/")
async def root():
    """Landing endpoint: point clients at the auto-generated ``/docs`` page."""
    hint = "GET /docs"
    return {"message": hint}

@app.get("/search")
async def search(q: str, n: int = 5):
    """Return the ``n`` stored prompts most similar to the query ``q``.

    A blank (empty or whitespace-only) query is not searched; the caller gets
    ``{"message": "Enter query"}`` instead. Otherwise the matches are returned
    in whatever form ``PromptSearchEngine.stringify_prompts`` produces.

    Raises:
        HTTPException: 404 when the engine reports no matches.
    """
    # Guard clause: nothing meaningful to embed for a blank query.
    if not q.strip():
        return {"message": "Enter query"}

    matches = promptSearchEngine.most_similar(q, n)
    if not matches:
        raise HTTPException(status_code=404, detail="No prompts found.")
    return promptSearchEngine.stringify_prompts(matches)


@app.post("/search")
async def searchPost(request: SearchRequest):
    """Return the prompts most similar to ``request.query`` as JSON.

    Response shape: ``{"data": [{"similarity": float, "prompt": ...}, ...]}``.
    A blank (empty or whitespace-only) query is not searched and yields
    ``{"message": "Enter query"}``, matching ``GET /search``.

    Raises:
        HTTPException: 404 when the engine reports no matches.
    """
    # Mirror GET /search: don't embed and search an empty query.
    if not request.query.strip():
        return {"message": "Enter query"}

    # The request schema admits an explicit null for ``n``; fall back to the
    # same default GET /search uses instead of passing None to the engine.
    n = request.n if request.n is not None else 5

    results = promptSearchEngine.most_similar(request.query, n)
    if not results:
        raise HTTPException(status_code=404, detail="No prompts found.")

    # Each result is unpacked as a (similarity, prompt) pair; the score is
    # coerced to a plain float so it serialises cleanly to JSON.
    formatted_results = [
        {"similarity": float(similarity), "prompt": prompt}
        for similarity, prompt in results
    ]
    return {"data": formatted_results}