# Spaces: Runtime error
import json

from datasets import load_dataset
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

from promptSearchEngine import PromptSearchEngine
from vectorizer import Vectorizer
# Name of the SentenceTransformer model handed to the search engine.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# Dataset identifier passed to `load_dataset`; its "Prompt" column is searched.
DATASET = "Gustavosta/Stable-Diffusion-Prompts"

# NOTE: the three statements below execute at import time, so importing this
# module constructs the model, loads the dataset and builds the engine.
model = SentenceTransformer(EMBEDDING_MODEL)
# Only the first 1% of the "train" split is requested.
dataset = load_dataset(DATASET, split="train[:1%]")
# Shared engine instance used by the request handlers in this module.
promptSearchEngine = PromptSearchEngine(dataset["Prompt"], model)
class SearchRequest(BaseModel):
    """Request model accepted by the `searchPost` handler."""

    # Text forwarded to the search engine as the query.
    query: str
    # Requested number of results; defaults to 5 and may also be None.
    n: int | None = 5
# FastAPI application object for this module.
app = FastAPI()
async def root() -> dict[str, str]:
    """Return a static message that points callers at `GET /docs`.

    Returns:
        A dict with a single ``"message"`` key.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file -- confirm it is registered on `app`.
    return {"message": "GET /docs"}
async def search(q: str, n: int = 5):
    """Look up the prompts most similar to ``q``.

    Args:
        q: Query text. An empty or whitespace-only value short-circuits
            with a hint message instead of running a search.
        n: Number of results requested from the engine (default 5).

    Returns:
        ``{"message": "Enter query"}`` for a blank query; otherwise the
        value produced by ``promptSearchEngine.stringify_prompts`` for the
        matches.

    Raises:
        HTTPException: 404 when the engine returns no results.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file -- confirm it is registered on `app`.

    # Guard clause: `not q.strip()` covers both "" and whitespace-only input.
    if not q.strip():
        return {"message": "Enter query"}

    results = promptSearchEngine.most_similar(q, n)
    if not results:
        raise HTTPException(status_code=404, detail="No prompts found.")
    return promptSearchEngine.stringify_prompts(results)
async def searchPost(request: SearchRequest):
    """Search for prompts similar to ``request.query`` and return scored hits.

    Args:
        request: Body carrying the query text and the requested result
            count ``n``.

    Returns:
        ``{"data": [...]}`` where each item has a float ``"similarity"`` and
        its ``"prompt"``.

    Raises:
        HTTPException: 404 when the engine returns no results.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file -- confirm it is registered on `app`.

    # `n` is declared `int | None`, so an explicit null can reach this point.
    # Fall back to the schema default (5) rather than forwarding None as a
    # result count. NOTE(review): confirm `most_similar` was not relying on
    # None to mean "no limit".
    limit = 5 if request.n is None else request.n

    results = promptSearchEngine.most_similar(request.query, limit)
    if not results:
        raise HTTPException(status_code=404, detail="No prompts found.")

    # Results are unpacked as (similarity, prompt) pairs; float() turns the
    # score into a plain Python float for the response.
    formatted_results = [
        {"similarity": float(similarity), "prompt": prompt}
        for similarity, prompt in results
    ]
    return {"data": formatted_results}