krstakis committed
Commit 8a0c27f · 1 Parent(s): a4bc2b4
Dockerfile ADDED
@@ -0,0 +1,30 @@
+ FROM python:3.9 AS core
+
+ COPY ./core/requirements.txt ./requirements.txt
+ RUN pip install -r requirements.txt
+
+ WORKDIR /app
+
+ # Copy core as a package so pickled class references (core.search_engine)
+ # resolve the same way at build time and at run time.
+ COPY ./core ./core
+ RUN python -m core.initialization
+
+ FROM python:3.9
+
+ COPY ./api/requirements.txt ./requirements.txt
+ RUN pip install -r requirements.txt
+
+ # Deserializing the engine imports core and its dependencies, so the
+ # runtime image needs them installed as well.
+ COPY ./core/requirements.txt ./core-requirements.txt
+ RUN pip install -r core-requirements.txt
+
+ WORKDIR /app
+ COPY ./api .
+ COPY --from=core /app/core ./core
+ COPY --from=core /app/engine.pickle ./engine.pickle
+
+ # service_manager.py runs uvicorn on port 8000, so expose that port.
+ EXPOSE 8000
+ ENTRYPOINT ["python", "service_manager.py"]
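The two-stage build above precomputes the engine in the core stage and ships only the serialized result alongside the API. Assuming a hypothetical image name of prompt-search, the service can be built and started with "docker build -t prompt-search ." followed by "docker run -p 8000:8000 prompt-search".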
api/__init__.py ADDED
@@ -0,0 +1 @@
+ # This file is intentionally left empty
api/requirements.txt ADDED
@@ -0,0 +1,5 @@
+ fastapi==0.111.1
+ pydantic==2.8.2
+ uvicorn==0.30.3
+ # Needed to deserialize the engine in web_server.py.
+ dill==0.3.8
api/service_manager.py ADDED
@@ -0,0 +1,15 @@
+ import uvicorn
+
+ from web_server import app
+
+
+ def run():
+     """
+     Start the API server with uvicorn on all interfaces, port 8000.
+     """
+
+     uvicorn.run(app, host="0.0.0.0", port=8000)
+
+
+ if __name__ == "__main__":
+     run()
api/web_server.py ADDED
@@ -0,0 +1,29 @@
+ import dill
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+
+ # Imported so dill can resolve the pickled PromptSearchEngine class.
+ from core.search_engine import PromptSearchEngine  # noqa: F401
+
+
+ class Query(BaseModel):
+     prompt: str
+     n: int = 5
+
+
+ app = FastAPI()
+
+ # The Dockerfile copies the serialized engine into the working directory.
+ with open('engine.pickle', 'rb') as file:
+     serialized_engine = file.read()
+
+ prompt_search_engine = dill.loads(serialized_engine)
+
+
+ @app.post("/search/")
+ async def search(query: Query):
+     try:
+         results = prompt_search_engine.most_similar(query.prompt, query.n)
+         return {"results": results}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
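With the container running, the endpoint can be exercised from the standard library alone; a minimal sketch, assuming the service is reachable on localhost:8000 (the prompt text is arbitrary):

    import json
    from urllib import request

    payload = json.dumps({"prompt": "dark fantasy castle", "n": 3}).encode("utf-8")
    req = request.Request(
        "http://localhost:8000/search/",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        print(json.load(resp))  # {"results": [[score, prompt], ...]}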
core/__init__.py ADDED
@@ -0,0 +1 @@
+ # This file is intentionally left empty
core/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (151 Bytes)
core/__pycache__/search_engine.cpython-39.pyc ADDED
Binary file (1.56 kB)
core/__pycache__/vectorizer.cpython-39.pyc ADDED
Binary file (1.48 kB)
core/data/__init__.py ADDED
@@ -0,0 +1 @@
+ # This file is intentionally left empty
core/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (156 Bytes)
core/data/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (1.21 kB)
core/data/dataset.py ADDED
@@ -0,0 +1,46 @@
+ from datasets import load_dataset
+
+
+ class PromptDataset:
+     """
+     Thin wrapper around a Hugging Face dataset that exposes its prompts.
+     """
+
+     def __init__(self, dataset_name: str):
+         """
+         Store the Hub name of the dataset to load.
+         """
+
+         self.dataset_name = dataset_name
+         self.dataset = None
+
+     def load(self):
+         """
+         Download the dataset from the Hugging Face Hub and keep a reference to it.
+         """
+
+         self.dataset = load_dataset(self.dataset_name)
+
+         return self.dataset
+
+     def get_prompts(self):
+         """
+         Return the prompt strings from the test split.
+         """
+
+         if self.dataset is None:
+             raise ValueError("Dataset not loaded. Call the load() method first.")
+
+         return [item['Prompt'] for item in self.dataset['test']]
+
+
+ if __name__ == "__main__":
+     # Smoke test; run from the repo root: python -m core.data.dataset
+     from core.search_engine import PromptSearchEngine
+
+     dataset = PromptDataset("Gustavosta/Stable-Diffusion-Prompts")
+     dataset.load()
+     prompts = dataset.get_prompts()
+     engine = PromptSearchEngine(prompts)
+     result = engine.most_similar("dark")
+     print(result)
core/initialization.py ADDED
@@ -0,0 +1,25 @@
+ import dill
+
+ from core.data.dataset import PromptDataset
+ from core.search_engine import PromptSearchEngine
+
+
+ def run():
+     """
+     Build the search engine over the Stable-Diffusion-Prompts dataset and
+     serialize it with dill to engine.pickle in the working directory.
+     """
+
+     prompt_dataset = PromptDataset("Gustavosta/Stable-Diffusion-Prompts")
+     prompt_dataset.load()
+     prompts = prompt_dataset.get_prompts()
+     engine = PromptSearchEngine(prompts)
+
+     serialized_engine = dill.dumps(engine)
+
+     with open("engine.pickle", "wb") as file:
+         file.write(serialized_engine)
+
+
+ if __name__ == "__main__":
+     run()
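Before baking engine.pickle into an image, it can be sanity-checked with a quick round-trip through dill, assuming the script above has just written the file to the current directory:

    import dill

    with open("engine.pickle", "rb") as file:
        engine = dill.loads(file.read())

    # Expect a list of (similarity, prompt) tuples.
    print(engine.most_similar("dark", n=3))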
core/requirements.txt ADDED
@@ -0,0 +1,4 @@
+ datasets==2.20.0
+ dill==0.3.8
+ faiss_cpu==1.8.0.post1
+ sentence_transformers==3.0.1
core/search_engine.py ADDED
@@ -0,0 +1,47 @@
+ from typing import List, Sequence, Tuple
+
+ import faiss
+ import numpy as np
+
+ from core.vectorizer import Vectorizer
+
+
+ class PromptSearchEngine:
+     """
+     Cosine-similarity search over a corpus of prompts, backed by FAISS.
+     """
+
+     def __init__(self, prompts: Sequence[str]) -> None:
+         """
+         Vectorize the corpus and build an inner-product index over the
+         L2-normalized vectors; for unit vectors, inner product equals
+         cosine similarity.
+         """
+
+         self.vectorizer = Vectorizer()
+         self.corpus_vectors = self.vectorizer.transform(prompts)
+         self.corpus = prompts
+
+         self.corpus_vectors = self.corpus_vectors / np.linalg.norm(self.corpus_vectors, axis=1, keepdims=True)
+
+         d = self.corpus_vectors.shape[1]
+         self.index = faiss.IndexFlatIP(d)
+         self.index.add(self.corpus_vectors.astype('float32'))
+
+     def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
+         """
+         Return the n corpus prompts most similar to the query as
+         (similarity, prompt) pairs, sorted by decreasing similarity.
+         """
+
+         query_vector = self.vectorizer.transform([query]).astype('float32')
+         query_vector = query_vector / np.linalg.norm(query_vector)
+         distances, indices = self.index.search(query_vector, n)
+
+         # Cast scores to plain floats so they are JSON-serializable, and
+         # skip the -1 padding FAISS returns when n exceeds the corpus size.
+         return [
+             (float(score), self.corpus[idx])
+             for score, idx in zip(distances[0], indices[0])
+             if idx != -1
+         ]
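The index stores unit-length vectors, so the inner-product search above ranks by cosine similarity. A two-vector numpy check of that identity, with made-up values:

    import numpy as np

    a = np.array([3.0, 4.0])
    b = np.array([1.0, 0.0])
    a_unit = a / np.linalg.norm(a)
    b_unit = b / np.linalg.norm(b)

    cosine = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    assert np.isclose(np.dot(a_unit, b_unit), cosine)  # both are 0.6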
core/vectorizer.py ADDED
@@ -0,0 +1,35 @@
+ from typing import Sequence
+
+ import numpy as np
+ from sentence_transformers import SentenceTransformer
+
+
+ class Vectorizer:
+     """
+     Embeds prompts as dense vectors with a SentenceTransformer model.
+     """
+
+     def __init__(self, model_name: str = 'all-MiniLM-L6-v2') -> None:
+         """
+         Initialize the vectorizer with a pre-trained embedding model.
+         """
+
+         self.model = SentenceTransformer(model_name)
+
+     def transform(self, prompts: Sequence[str]) -> np.ndarray:
+         """
+         Transform texts into numerical vectors using the specified model.
+         """
+
+         return self.model.encode(list(prompts))
+
+     @staticmethod
+     def cosine_similarity(query_vector: np.ndarray, corpus_vectors: np.ndarray) -> np.ndarray:
+         """
+         Calculate cosine similarity between a query vector and each corpus vector.
+         """
+
+         query_norm = query_vector / np.linalg.norm(query_vector)
+         corpus_norms = corpus_vectors / np.linalg.norm(corpus_vectors, axis=1, keepdims=True)
+
+         return np.dot(corpus_norms, query_norm.T).flatten()
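cosine_similarity is not used by the FAISS-backed engine and can serve as a reference implementation; a small sketch with toy vectors standing in for real embeddings (it assumes sentence_transformers is installed, since importing the module loads it):

    import numpy as np
    from core.vectorizer import Vectorizer

    # Toy stand-ins for real sentence embeddings.
    corpus_vectors = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    query_vector = np.array([[1.0, 0.0]])

    scores = Vectorizer.cosine_similarity(query_vector, corpus_vectors)
    print(scores)  # [1.0, 0.0, 0.7071...]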