sasagema's picture
Change dataset size
4e223f2
import json
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from promptSearchEngine import PromptSearchEngine
from vectorizer import Vectorizer
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
DATASET = "Gustavosta/Stable-Diffusion-Prompts"
model = SentenceTransformer(EMBEDDING_MODEL)
dataset = load_dataset(DATASET , split="train[:1%]")
promptSearchEngine = PromptSearchEngine(dataset["Prompt"], model)
class SearchRequest(BaseModel):
query: str
n: int | None = 5
app = FastAPI()
@app.get("/")
async def root():
return {"message": 'GET /docs'}
@app.get("/search")
async def search(q: str, n: int = 5):
results = []
if q.isspace() or q =="":
return {"message": "Enter query"}
else:
results = promptSearchEngine.most_similar(q, n)
if not results:
raise HTTPException(status_code=404, detail="No prompts found.")
return promptSearchEngine.stringify_prompts(results)
@app.post("/search")
async def searchPost(request: SearchRequest):
results = promptSearchEngine.most_similar(request.query, request.n)
if not results:
raise HTTPException(status_code=404, detail="No prompts found.")
formatted_results = [{"similarity": float(similarity), "prompt": prompt } for similarity, prompt in results]
return { "data" : formatted_results }