Add app
- Dockerfile +26 -0
- api/__init__.py +1 -0
- api/requirements.txt +6 -0
- api/service_manager.py +13 -0
- api/web_server.py +26 -0
- core/__init__.py +1 -0
- core/__pycache__/__init__.cpython-39.pyc +0 -0
- core/__pycache__/search_engine.cpython-39.pyc +0 -0
- core/__pycache__/vectorizer.cpython-39.pyc +0 -0
- core/data/__init__.py +1 -0
- core/data/__pycache__/__init__.cpython-39.pyc +0 -0
- core/data/__pycache__/dataset.cpython-39.pyc +0 -0
- core/data/dataset.py +44 -0
- core/initialization.py +22 -0
- core/requirements.txt +4 -0
- core/search_engine.py +38 -0
- core/vectorizer.py +34 -0
Dockerfile
ADDED
@@ -0,0 +1,26 @@
+# Stage 1: build the search engine from the prompt corpus.
+FROM python:3.9 as core
+
+COPY ./core/requirements.txt ./requirements.txt
+RUN pip install -r requirements.txt
+
+WORKDIR /app
+
+COPY ./core .
+RUN python ./initialization.py
+
+# Stage 2: serve the pickled engine over HTTP.
+FROM python:3.9
+
+COPY ./api/requirements.txt ./requirements.txt
+RUN pip install -r requirements.txt
+
+WORKDIR /app
+COPY ./api .
+
+COPY --from=core /app/engine.pickle ./engine.pickle
+# dill resolves the pickled engine class by module name, so ship its modules too.
+COPY --from=core /app/search_engine.py /app/vectorizer.py ./
+
+EXPOSE 8000
+ENTRYPOINT ["python", "service_manager.py"]
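A quick way to sanity-check the two-stage build locally (the image tag prompt-search is a placeholder, not part of the commit):

    docker build -t prompt-search .
    docker run -p 8000:8000 prompt-search

The first stage pays the cost of downloading the dataset and building the FAISS index once, at image-build time; the serving stage only has to deserialize the result.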
api/__init__.py
ADDED
@@ -0,0 +1 @@
+# This file is intentionally left empty
api/requirements.txt
ADDED
@@ -0,0 +1,6 @@
+dill==0.3.8                   # assumed version; web_server.py loads the engine with dill
+faiss_cpu==1.8.0.post1        # needed to deserialize the FAISS index inside the engine
+fastapi==0.111.1
+pydantic==2.8.2
+sentence_transformers==3.0.1  # needed to deserialize the embedded vectorizer model
+uvicorn==0.30.3
api/service_manager.py
ADDED
@@ -0,0 +1,13 @@
+from web_server import app
+import uvicorn
+
+
+def run():
+    """
+    Start the FastAPI app with uvicorn.
+    """
+
+    uvicorn.run(app, host="0.0.0.0", port=8000)
+
+if __name__ == "__main__":
+    run()
api/web_server.py
ADDED
@@ -0,0 +1,26 @@
+import dill
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from search_engine import PromptSearchEngine  # noqa: F401 -- dill needs the class importable
+
+
+class Query(BaseModel):
+    prompt: str
+    n: int = 5
+
+
+app = FastAPI()
+
+with open('engine.pickle', 'rb') as file:  # copied into /app by the Dockerfile
+    serialized_engine = file.read()
+
+prompt_search_engine = dill.loads(serialized_engine)
+
+
+@app.post("/search/")
+async def search(query: Query):
+    try:
+        results = prompt_search_engine.most_similar(query.prompt, query.n)
+        return {"results": results}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
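A minimal client sketch for the /search/ endpoint, assuming the container is published on localhost:8000 (requests is a client-side dependency, not part of the image):

    import requests

    # The payload matches the Query model: a prompt plus the number of results.
    response = requests.post(
        "http://localhost:8000/search/",
        json={"prompt": "a dark castle on a hill", "n": 3},
    )
    response.raise_for_status()

    # Each result is a (similarity score, prompt) pair.
    for score, prompt in response.json()["results"]:
        print(f"{score:.3f}  {prompt}")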
core/__init__.py
ADDED
@@ -0,0 +1 @@
+# This file is intentionally left empty
core/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (151 Bytes).
core/__pycache__/search_engine.cpython-39.pyc
ADDED
Binary file (1.56 kB).
core/__pycache__/vectorizer.cpython-39.pyc
ADDED
Binary file (1.48 kB).
core/data/__init__.py
ADDED
@@ -0,0 +1 @@
+# This file is intentionally left empty
core/data/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (156 Bytes).
core/data/__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (1.21 kB).
core/data/dataset.py
ADDED
@@ -0,0 +1,44 @@
+from datasets import load_dataset
+
+
+class PromptDataset:
+    """
+    Thin wrapper around a Hugging Face dataset of prompts.
+    """
+
+    def __init__(self, dataset_name: str):
+        """
+        Store the name of the dataset to load.
+        """
+
+        self.dataset_name = dataset_name
+        self.dataset = None
+
+    def load(self):
+        """
+        Download the dataset from the Hugging Face Hub.
+        """
+
+        self.dataset = load_dataset(self.dataset_name)
+
+        return self.dataset
+
+    def get_prompts(self):
+        """
+        Return the prompt strings from the test split.
+        """
+
+        if self.dataset is None:
+            raise ValueError("Dataset not loaded. Call the load() method first.")
+
+        return [item['Prompt'] for item in self.dataset['test']]
+
+
+# Example usage:
+# from search_engine import PromptSearchEngine
+# dataset = PromptDataset("Gustavosta/Stable-Diffusion-Prompts")
+# dataset.load()
+# prompts = dataset.get_prompts()
+# engine = PromptSearchEngine(prompts)
+# result = engine.most_similar("dark")
+# print(result)
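get_prompts assumes the dataset exposes a test split with a Prompt column, which is what the code above relies on; a small sketch of what the loader sees (output is indicative, not verbatim):

    from datasets import load_dataset

    ds = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
    print(ds)             # DatasetDict with 'train' and 'test' splits
    print(ds["test"][0])  # a dict with a 'Prompt' key, e.g. {'Prompt': '...'}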
core/initialization.py
ADDED
@@ -0,0 +1,22 @@
+import dill
+from data.dataset import PromptDataset
+from search_engine import PromptSearchEngine  # core/ is copied flat into /app
+
+
+def run():
+    """
+    Build the prompt search engine and pickle it to engine.pickle.
+    """
+
+    prompt_dataset = PromptDataset("Gustavosta/Stable-Diffusion-Prompts")
+    prompt_dataset.load()
+    prompts = prompt_dataset.get_prompts()
+    engine = PromptSearchEngine(prompts)
+
+    serialized_engine = dill.dumps(engine)
+
+    with open("engine.pickle", "wb") as file:
+        file.write(serialized_engine)
+
+if __name__ == "__main__":
+    run()
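Worth noting: dill pickles an imported class like PromptSearchEngine by reference to its defining module, so whatever later calls dill.loads must be able to import search_engine (and, transitively, vectorizer, faiss, and sentence_transformers). That is why the Dockerfile copies those modules into the serving image. A round-trip check, assuming it runs from a directory where those modules are importable:

    import dill

    with open("engine.pickle", "rb") as f:
        engine = dill.loads(f.read())

    print(engine.most_similar("dark", n=3))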
core/requirements.txt
ADDED
@@ -0,0 +1,4 @@
+datasets==2.20.0
+dill==0.3.8  # assumed version; initialization.py pickles the engine with dill
+faiss_cpu==1.8.0.post1
+sentence_transformers==3.0.1
core/search_engine.py
ADDED
@@ -0,0 +1,38 @@
+from typing import List, Sequence, Tuple
+import numpy as np
+import faiss
+from vectorizer import Vectorizer  # core/ is copied flat into /app
+
+
+class PromptSearchEngine(object):
+    """
+    Nearest-neighbour search over a corpus of prompts.
+    """
+
+    def __init__(self, prompts: Sequence[str]) -> None:
+        """
+        Vectorize the corpus and index it with FAISS.
+        """
+
+        self.vectorizer = Vectorizer()
+        self.corpus_vectors = self.vectorizer.transform(prompts)
+        self.corpus = prompts
+
+        # Normalize so that inner product equals cosine similarity.
+        self.corpus_vectors = self.corpus_vectors / np.linalg.norm(self.corpus_vectors, axis=1, keepdims=True)
+
+        d = self.corpus_vectors.shape[1]
+        self.index = faiss.IndexFlatIP(d)
+        self.index.add(self.corpus_vectors.astype('float32'))
+
+    def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
+        """
+        Return the n corpus prompts most similar to the query, with scores.
+        """
+
+        query_vector = self.vectorizer.transform([query]).astype('float32')
+        query_vector = query_vector / np.linalg.norm(query_vector)
+        distances, indices = self.index.search(query_vector, n)
+
+        # Cast scores to plain floats so the results serialize cleanly to JSON.
+        return [(float(d), self.corpus[i]) for d, i in zip(distances[0], indices[0]) if i != -1]
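Because both corpus and query vectors are L2-normalized before hitting IndexFlatIP, the inner-product scores FAISS returns are exactly cosine similarities. A two-vector check of that identity:

    import numpy as np

    a = np.array([3.0, 4.0])
    b = np.array([4.0, 3.0])

    cosine = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))  # 24 / 25 = 0.96
    inner = np.dot(a / np.linalg.norm(a), b / np.linalg.norm(b))     # same value
    assert np.isclose(cosine, inner)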
core/vectorizer.py
ADDED
@@ -0,0 +1,34 @@
+from typing import Sequence
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+
+class Vectorizer(object):
+    """
+    Embeds texts with a pre-trained sentence-transformer model.
+    """
+
+    def __init__(self, model_name: str = 'all-MiniLM-L6-v2') -> None:
+        """
+        Initialize the vectorizer with a pre-trained embedding model.
+        """
+
+        self.model = SentenceTransformer(model_name)
+
+    def transform(self, prompts: Sequence[str]) -> np.ndarray:
+        """
+        Transform texts into numerical vectors using the specified model.
+        """
+
+        return self.model.encode(list(prompts))
+
+    @staticmethod
+    def cosine_similarity(query_vector: np.ndarray, corpus_vectors: np.ndarray) -> np.ndarray:
+        """
+        Calculate cosine similarity between prompt vectors.
+        """
+
+        query_norm = query_vector / np.linalg.norm(query_vector)
+        corpus_norms = corpus_vectors / np.linalg.norm(corpus_vectors, axis=1, keepdims=True)
+
+        return np.dot(corpus_norms, query_norm.T).flatten()
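A standalone usage sketch for the vectorizer (run from the core/ directory so the flat import resolves; the example strings are made up):

    from vectorizer import Vectorizer

    vectorizer = Vectorizer()
    corpus = vectorizer.transform(["a red fox in the snow", "a blue whale at dusk"])
    query = vectorizer.transform(["an orange fox"])[0]

    # Scores lie in [-1, 1]; the fox prompt should rank first.
    print(Vectorizer.cosine_similarity(query, corpus))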