krstakis committed on
Commit
b652e4e
·
1 Parent(s): f526df9

fixing everything.....

Browse files
Dockerfile CHANGED
@@ -1,63 +1,92 @@
1
- # First stage: Build and load the dataset
2
- FROM python:3.9 as core
 
3
 
4
- # Set the working directory
5
- WORKDIR /app
6
 
7
- # Copy and install core requirements
8
  COPY ./core/requirements.txt ./core/requirements.txt
9
- RUN pip install -r ./core/requirements.txt
10
-
11
- # Copy the core files
12
- COPY ./core ./core
13
 
14
- # Set the PYTHONPATH to include the /app directory
15
- ENV PYTHONPATH="/app"
 
16
 
17
- # Set the HF_HOME to a writable directory
18
- ENV HF_HOME="/app/cache"
19
 
20
- # Create the cache directory and set correct permissions
21
- RUN mkdir -p /app/cache && chmod -R 777 /app/cache
22
-
23
- # Run the initialization script to load and serialize the dataset
24
- RUN python ./core/initialization.py
25
-
26
- # Second stage: Set up the API
27
  FROM python:3.9
 
28
 
29
- # Set the working directory
30
- WORKDIR /app
31
-
32
- # Copy and install core and API requirements
33
- COPY ./core/requirements.txt ./core/requirements.txt
34
- COPY ./api/requirements.txt ./api/requirements.txt
35
- RUN pip install -r ./core/requirements.txt
36
- RUN pip install -r ./api/requirements.txt
37
-
38
- # Copy the API files
39
- COPY ./api ./api
40
-
41
- # Copy the core files to the second stage to ensure search_engine is available
42
  COPY ./core ./core
 
43
 
44
- # Copy the serialized engine from the first stage to the API directory
45
- COPY --from=core /app/core/engine.pickle /app/api/engine.pickle
46
-
47
- # Set the PYTHONPATH to include the /app directory
48
- ENV PYTHONPATH="/app"
49
-
50
- # Set the HF_HOME to a writable directory
51
- ENV HF_HOME="/app/cache"
52
-
53
- # Create the cache directory and set correct permissions
54
- RUN mkdir -p /app/cache && chmod -R 777 /app/cache
55
 
56
- # Expose the API port
57
  EXPOSE 7860
58
-
59
- # Run the service manager
60
  ENTRYPOINT ["python", "api/service_manager.py"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # FROM python:3.9 as core
62
  #
63
  # COPY ./core/requirements.txt ./requirements.txt
 
1
+ FROM python:3.9 AS install
2
+ RUN apt-get update
3
+ RUN apt-get install -y --no-install-recommends build-essential gcc
4
 
5
+ COPY ./api/requirements.txt ./api/requirements.txt
6
+ RUN pip install --user -r ./api/requirements.txt
7
 
 
8
  COPY ./core/requirements.txt ./core/requirements.txt
9
+ RUN pip install --user -r ./core/requirements.txt
 
 
 
10
 
11
+ ##################################################################
12
+ FROM python:3.9 AS setup
13
+ COPY --from=install /root/.local /root/.local
14
 
15
+ COPY ./core .
16
+ RUN python ./initialization.py
17
 
18
+ ##################################################################
 
 
 
 
 
 
19
  FROM python:3.9
20
+ COPY --from=install /root/.local /root/.local
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  COPY ./core ./core
23
+ COPY ./api ./api
24
 
25
+ COPY --from=setup /engine.pickle /engine.pickle
 
 
 
 
 
 
 
 
 
 
26
 
 
27
  EXPOSE 7860
 
 
28
  ENTRYPOINT ["python", "api/service_manager.py"]
29
+
30
+ # # First stage: Build and load the dataset
31
+ # FROM python:3.9 as core
32
+ #
33
+ # # Set the working directory
34
+ # WORKDIR /app
35
+ #
36
+ # # Copy and install core requirements
37
+ # COPY ./core/requirements.txt ./core/requirements.txt
38
+ # RUN pip install -r ./core/requirements.txt
39
+ #
40
+ # # Copy the core files
41
+ # COPY ./core ./core
42
+ #
43
+ # # Set the PYTHONPATH to include the /app directory
44
+ # ENV PYTHONPATH="/app"
45
+ #
46
+ # # Set the HF_HOME to a writable directory
47
+ # ENV HF_HOME="/app/cache"
48
+ #
49
+ # # Create the cache directory and set correct permissions
50
+ # RUN mkdir -p /app/cache && chmod -R 777 /app/cache
51
+ #
52
+ # # Run the initialization script to load and serialize the dataset
53
+ # RUN python ./core/initialization.py
54
+ #
55
+ # # Second stage: Set up the API
56
+ # FROM python:3.9
57
+ #
58
+ # # Set the working directory
59
+ # WORKDIR /app
60
+ #
61
+ # # Copy and install core and API requirements
62
+ # COPY ./core/requirements.txt ./core/requirements.txt
63
+ # COPY ./api/requirements.txt ./api/requirements.txt
64
+ # RUN pip install -r ./core/requirements.txt
65
+ # RUN pip install -r ./api/requirements.txt
66
+ #
67
+ # # Copy the API files
68
+ # COPY ./api ./api
69
+ #
70
+ # # Copy the core files to the second stage to ensure search_engine is available
71
+ # COPY ./core ./core
72
+ #
73
+ # # Copy the serialized engine from the first stage to the API directory
74
+ # COPY --from=core /app/core/engine.pickle /app/api/engine.pickle
75
+ #
76
+ # # Set the PYTHONPATH to include the /app directory
77
+ # ENV PYTHONPATH="/app"
78
+ #
79
+ # # Set the HF_HOME to a writable directory
80
+ # ENV HF_HOME="/app/cache"
81
+ #
82
+ # # Create the cache directory and set correct permissions
83
+ # RUN mkdir -p /app/cache && chmod -R 777 /app/cache
84
+ #
85
+ # # Expose the API port
86
+ # EXPOSE 7860
87
+ #
88
+ # # Run the service manager
89
+ # ENTRYPOINT ["python", "api/service_manager.py"]
90
  # FROM python:3.9 as core
91
  #
92
  # COPY ./core/requirements.txt ./requirements.txt
api/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (175 Bytes). View file
 
api/__pycache__/service_manager.cpython-39.pyc ADDED
Binary file (414 Bytes). View file
 
api/__pycache__/web_server.cpython-39.pyc ADDED
Binary file (1.33 kB). View file
 
api/service_manager.py CHANGED
@@ -1,12 +1,11 @@
1
- from api.web_server import app
2
  import uvicorn
3
 
 
 
4
 
5
  def run():
6
  """
7
- TODO
8
  """
9
 
10
- uvicorn.run(app, host="0.0.0.0", port=8000)
11
-
12
-
 
 
1
  import uvicorn
2
 
3
+ from .web_server import app
4
+
5
 
6
  def run():
7
  """
8
+ Start the FastAPI web server using Uvicorn.
9
  """
10
 
11
+ uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
api/web_server.py CHANGED
@@ -1,7 +1,8 @@
1
  import dill
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
- from core.search_engine import PromptSearchEngine
 
5
 
6
 
7
  class Query(BaseModel):
@@ -11,7 +12,7 @@ class Query(BaseModel):
11
 
12
  app = FastAPI()
13
 
14
- with open('api/engine.pickle', 'rb') as file:
15
  serialized_engine = file.read()
16
 
17
  prompt_search_engine = dill.loads(serialized_engine)
@@ -20,7 +21,22 @@ prompt_search_engine = dill.loads(serialized_engine)
20
  @app.post("/search/")
21
  async def search(query: Query):
22
  """
23
- TODO
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  """
25
 
26
  try:
@@ -31,7 +47,9 @@ async def search(query: Query):
31
  raise ValueError("Prompt must be a string")
32
 
33
  results = prompt_search_engine.most_similar(query.prompt, query.n)
34
- formatted_results = [{"score": float(score), "description": desc} for score, desc in results]
 
 
35
 
36
  return formatted_results
37
 
 
1
  import dill
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
+
5
+ # from ..core.search_engine import PromptSearchEngine
6
 
7
 
8
  class Query(BaseModel):
 
12
 
13
  app = FastAPI()
14
 
15
+ with open("./engine.pickle", "rb") as file:
16
  serialized_engine = file.read()
17
 
18
  prompt_search_engine = dill.loads(serialized_engine)
 
21
  @app.post("/search/")
22
  async def search(query: Query):
23
  """
24
+ Find the most similar prompts to a given query prompt using the pre-trained PromptSearchEngine.
25
+
26
+ This endpoint accepts a query prompt and returns a specified number of the most similar prompts
27
+ from the corpus. It performs the following steps:
28
+ 1. Validates the input types.
29
+ 2. Uses the pre-loaded PromptSearchEngine to find the most similar prompts.
30
+ 3. Formats the results into a list of dictionaries containing the similarity score and prompt text.
31
+
32
+ Args:
33
+ query (Query): The query model containing the prompt text and the number of similar prompts to return.
34
+
35
+ Returns:
36
+ List[Dict[str, Union[float, str]]]: A list of dictionaries where each dictionary contains the similarity score and the corresponding prompt.
37
+
38
+ Raises:
39
+ HTTPException: If an error occurs during the processing of the query, an HTTP 500 error is raised with the error details.
40
  """
41
 
42
  try:
 
47
  raise ValueError("Prompt must be a string")
48
 
49
  results = prompt_search_engine.most_similar(query.prompt, query.n)
50
+ formatted_results = [
51
+ {"score": float(score), "description": desc} for score, desc in results
52
+ ]
53
 
54
  return formatted_results
55
 
core/initialization.py CHANGED
@@ -1,11 +1,19 @@
1
  import dill
2
- from core.search_engine import PromptSearchEngine
3
- from core.data.dataset import PromptDataset
 
4
 
5
 
6
  def run():
7
  """
8
- TODO
 
 
 
 
 
 
 
9
  """
10
 
11
  prompt_dataset = PromptDataset("Gustavosta/Stable-Diffusion-Prompts")
@@ -15,7 +23,5 @@ def run():
15
 
16
  serialized_engine = dill.dumps(engine)
17
 
18
- with open("core/engine.pickle", "wb") as file:
19
  file.write(serialized_engine)
20
-
21
- run()
 
1
  import dill
2
+
3
+ from .data.dataset import PromptDataset
4
+ from .search_engine import PromptSearchEngine
5
 
6
 
7
  def run():
8
  """
9
+ Initialize the PromptSearchEngine with prompts from the specified dataset,
10
+ serialize the engine, and save it to a file.
11
+
12
+ This function performs the following steps:
13
+ 1. Loads a dataset of prompts using the PromptDataset class.
14
+ 2. Initializes the PromptSearchEngine with the loaded prompts.
15
+ 3. Serializes the PromptSearchEngine instance using dill.
16
+ 4. Saves the serialized engine to a file named 'engine.pickle'.
17
  """
18
 
19
  prompt_dataset = PromptDataset("Gustavosta/Stable-Diffusion-Prompts")
 
23
 
24
  serialized_engine = dill.dumps(engine)
25
 
26
+ with open("engine.pickle", "wb") as file:
27
  file.write(serialized_engine)
 
 
core/search_engine.py CHANGED
@@ -1,35 +1,50 @@
1
  from typing import List, Sequence, Tuple
2
- import numpy as np
3
  import faiss
 
 
4
  from .vectorizer import Vectorizer
5
 
6
 
7
- class PromptSearchEngine(object):
8
  """
9
- TODO
 
10
  """
11
 
12
  def __init__(self, prompts: Sequence[str]) -> None:
13
  """
14
- TODO
 
 
 
15
  """
16
 
17
  self.vectorizer = Vectorizer()
18
  self.corpus_vectors = self.vectorizer.transform(prompts)
19
  self.corpus = prompts
20
 
21
- self.corpus_vectors = self.corpus_vectors / np.linalg.norm(self.corpus_vectors, axis=1, keepdims=True)
 
 
22
 
23
  d = self.corpus_vectors.shape[1]
24
  self.index = faiss.IndexFlatIP(d)
25
- self.index.add(self.corpus_vectors.astype('float32'))
26
 
27
  def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
28
  """
29
- TODO
 
 
 
 
 
 
 
30
  """
31
 
32
- query_vector = self.vectorizer.transform([query]).astype('float32')
33
  query_vector = query_vector / np.linalg.norm(query_vector)
34
  distances, indices = self.index.search(query_vector, n)
35
 
 
1
  from typing import List, Sequence, Tuple
2
+
3
  import faiss
4
+ import numpy as np
5
+
6
  from .vectorizer import Vectorizer
7
 
8
 
9
+ class PromptSearchEngine:
10
  """
11
+ The PromptSearchEngine is responsible for finding the most similar prompts to a given query
12
+ by leveraging vectorized representations of the prompts and a similarity search index.
13
  """
14
 
15
  def __init__(self, prompts: Sequence[str]) -> None:
16
  """
17
+ Initialize the PromptSearchEngine with a sequence of prompts.
18
+
19
+ Args:
20
+ prompts (Sequence[str]): The sequence of raw corpus prompts to be indexed for similarity search.
21
  """
22
 
23
  self.vectorizer = Vectorizer()
24
  self.corpus_vectors = self.vectorizer.transform(prompts)
25
  self.corpus = prompts
26
 
27
+ self.corpus_vectors = self.corpus_vectors / np.linalg.norm(
28
+ self.corpus_vectors, axis=1, keepdims=True
29
+ )
30
 
31
  d = self.corpus_vectors.shape[1]
32
  self.index = faiss.IndexFlatIP(d)
33
+ self.index.add(self.corpus_vectors.astype("float32"))
34
 
35
  def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
36
  """
37
+ Find the most similar prompts to a given query.
38
+
39
+ Args:
40
+ query (str): The query prompt to search for similar prompts.
41
+ n (int, optional): The number of similar prompts to retrieve. Defaults to 5.
42
+
43
+ Returns:
44
+ List[Tuple[float, str]]: A list of tuples containing the similarity score and the corresponding prompt.
45
  """
46
 
47
+ query_vector = self.vectorizer.transform([query]).astype("float32")
48
  query_vector = query_vector / np.linalg.norm(query_vector)
49
  distances, indices = self.index.search(query_vector, n)
50
 
core/vectorizer.py CHANGED
@@ -1,16 +1,23 @@
1
  from typing import Sequence
 
2
  import numpy as np
3
  from sentence_transformers import SentenceTransformer
4
 
5
 
6
- class Vectorizer(object):
7
  """
8
- TODO
 
 
9
  """
10
 
11
- def __init__(self, model_name: str = 'all-MiniLM-L6-v2') -> None:
12
  """
13
  Initialize the vectorizer with a pre-trained embedding model.
 
 
 
 
14
  """
15
 
16
  self.model = SentenceTransformer(model_name)
@@ -18,17 +25,37 @@ class Vectorizer(object):
18
  def transform(self, prompts: Sequence[str]) -> np.ndarray:
19
  """
20
  Transform texts into numerical vectors using the specified model.
 
 
 
 
 
 
 
21
  """
22
 
23
  return self.model.encode(list(prompts))
24
 
25
  @staticmethod
26
- def cosine_similarity(query_vector: np.ndarray, corpus_vectors: np.ndarray) -> np.ndarray:
 
 
27
  """
28
- Calculate cosine similarity between prompt vectors.
 
 
 
 
 
 
 
 
 
29
  """
30
-
31
  query_norm = query_vector / np.linalg.norm(query_vector)
32
- corpus_norms = corpus_vectors / np.linalg.norm(corpus_vectors, axis=1, keepdims=True)
 
 
33
 
34
  return np.dot(corpus_norms, query_norm.T).flatten()
 
1
  from typing import Sequence
2
+
3
  import numpy as np
4
  from sentence_transformers import SentenceTransformer
5
 
6
 
7
+ class Vectorizer:
8
  """
9
+ The Vectorizer's role is to transform textual prompts into numerical vectors that can be
10
+ compared in a high-dimensional space. This transformation allows the system to quantify the
11
+ similarity between different prompts effectively.
12
  """
13
 
14
+ def __init__(self, model_name: str = "all-MiniLM-L6-v2") -> None:
15
  """
16
  Initialize the vectorizer with a pre-trained embedding model.
17
+
18
+ Args:
19
+ model_name (str): The pre-trained embedding model to use for transforming prompts.
20
+ This can be any model that provides a method to convert texts into numerical vectors.
21
  """
22
 
23
  self.model = SentenceTransformer(model_name)
 
25
  def transform(self, prompts: Sequence[str]) -> np.ndarray:
26
  """
27
  Transform texts into numerical vectors using the specified model.
28
+
29
+ Args:
30
+ prompts (Sequence[str]): The sequence of raw corpus prompts to be transformed.
31
+
32
+ Returns:
33
+ np.ndarray: A numpy array containing the vectorized prompts. Each row corresponds to the
34
+ vector representation of a prompt.
35
  """
36
 
37
  return self.model.encode(list(prompts))
38
 
39
  @staticmethod
40
+ def cosine_similarity(
41
+ query_vector: np.ndarray, corpus_vectors: np.ndarray
42
+ ) -> np.ndarray:
43
  """
44
+ Calculate cosine similarity between a query vector and a set of corpus vectors.
45
+
46
+ Args:
47
+ query_vector (np.ndarray): A numpy array representing the vector of the query prompt.
48
+ corpus_vectors (np.ndarray): A numpy array representing the vectors of the corpus prompts.
49
+ Each row corresponds to the vector representation of a corpus prompt.
50
+
51
+ Returns:
52
+ np.ndarray: A numpy array containing the cosine similarity scores between the query vector and each
53
+ of the corpus vectors.
54
  """
55
+
56
  query_norm = query_vector / np.linalg.norm(query_vector)
57
+ corpus_norms = corpus_vectors / np.linalg.norm(
58
+ corpus_vectors, axis=1, keepdims=True
59
+ )
60
 
61
  return np.dot(corpus_norms, query_norm.T).flatten()
requirements.txt CHANGED
@@ -1,8 +1,10 @@
 
1
  datasets==2.20.0
2
  dill==0.3.8
3
  faiss_cpu==1.8.0.post1
4
  fastapi==0.111.1
 
5
  pydantic==2.8.2
6
  sentence_transformers==3.0.1
 
7
  uvicorn==0.30.3
8
- numpy==1.23.5
 
1
+ Requests==2.32.3
2
  datasets==2.20.0
3
  dill==0.3.8
4
  faiss_cpu==1.8.0.post1
5
  fastapi==0.111.1
6
+ numpy==1.23.5
7
  pydantic==2.8.2
8
  sentence_transformers==3.0.1
9
+ streamlit==1.36.0
10
  uvicorn==0.30.3
 
ui/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # This file is intentionally left empty
ui/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Requests==2.32.3
2
+ streamlit==1.36.0
ui/streamlit_app.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This Streamlit application interfaces with a FastAPI backend to search for prompts based on user input.
3
+
4
+ The user can enter a prompt and specify the number of similar results they want to retrieve.
5
+ The application then sends a request to the FastAPI endpoint and displays the results.
6
+ """
7
+
8
+ import requests
9
+ import streamlit as st
10
+
11
+ API_URL = "http://localhost:8000/search/"
12
+
13
+ st.title("Prompt Search Engine")
14
+
15
+ prompt = st.text_input("Enter a prompt")
16
+ n = st.slider("Number of results", 1, 20, 5)
17
+
18
+ if st.button("Search"):
19
+ response = requests.post(API_URL, json={"prompt": prompt, "n": n})
20
+ if response.status_code == 200:
21
+ results = response.json()
22
+ for result in results:
23
+ score = result["score"]
24
+ result_prompt = result["description"]
25
+ st.write(f"Score: {score:.4f}, Prompt: {result_prompt}")
26
+ else:
27
+ st.error("Error: Could not retrieve results")