sasagema commited on
Commit
85c1145
·
1 Parent(s): 302a52d

Added files

Browse files
Files changed (7) hide show
  1. Dockerfile +16 -0
  2. promptSearchEngine.py +83 -0
  3. requirements.txt +0 -0
  4. run.py +46 -0
  5. run_local_ui.py +67 -0
  6. run_ui.py +33 -0
  7. vectorizer.py +24 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.9

# Create a non-root user with UID 1000 and run every later step as that user.
RUN useradd -m -u 1000 user
USER user
# Put the user's local bin directory on PATH (presumably where pip places
# console scripts such as uvicorn for a non-root install -- confirm).
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Install dependencies before copying the sources so this layer only depends
# on requirements.txt.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app
# Serve the `app` object defined in run.py on all interfaces, port 7860.
CMD ["uvicorn", "run:app", "--host", "0.0.0.0", "--port", "7860"]
promptSearchEngine.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Sequence, Tuple
2
+ import numpy as np
3
+ from vectorizer import Vectorizer
4
+
5
def cosine_similarity(
    query_vector: np.ndarray,
    corpus_vectors: np.ndarray
) -> np.ndarray:
    """Calculate cosine similarity between a query vector and corpus vectors.

    Args:
        query_vector: Vectorized prompt query of shape (D,) or (1, D).
        corpus_vectors: Vectorized prompt corpus of shape (N, D).

    Returns:
        The vector of shape (N,) with values in range [-1, 1], rounded to
        4 decimals, where 1 is max similarity, i.e. the two vectors point in
        the same direction. Entries whose corpus vector (or the query) has
        zero magnitude are reported as 0.0 instead of NaN.
    """
    # Accept both the documented (1, D) layout and an already-flattened (D,).
    query = np.asarray(query_vector).reshape(-1)
    corpus = np.asarray(corpus_vectors)

    dot_product = np.dot(corpus, query)
    magnitudes = np.linalg.norm(corpus, axis=1) * np.linalg.norm(query)

    # Divide only where the denominator is non-zero; the remaining slots keep
    # their initial 0.0 so no NaN/inf (or RuntimeWarning) is produced.
    cosine_sim = np.zeros(
        dot_product.shape, dtype=np.result_type(dot_product, magnitudes)
    )
    np.divide(dot_product, magnitudes, out=cosine_sim, where=magnitudes != 0)
    return np.around(cosine_sim, 4)
24
+
25
+
26
class PromptSearchEngine:
    """Similarity search over a fixed corpus of prompts."""

    def __init__(self, prompts: Sequence[str], model) -> None:
        """Initialize search engine by vectorizing prompt corpus.

        Vectorized prompt corpus is used to find the top n most similar
        prompts w.r.t. the user's input prompt.

        Args:
            prompts: The sequence of raw prompts from the dataset.
            model: The embedding model handed to ``Vectorizer``.
        """
        self.prompts = prompts
        self.vectorizer = Vectorizer(model)
        self.corpus_embeddings = self.vectorizer.transform(prompts)

    def most_similar(
        self,
        query: str,
        n: int = 5
    ) -> List[Tuple[float, str]]:
        """Return top n most similar prompts from corpus.

        The query is vectorized with the engine's Vectorizer and compared
        against the corpus embeddings with ``cosine_similarity``.

        Args:
            query: The raw query prompt input from the user.
            n: The number of similar prompts returned from the corpus.
                ``None`` returns the whole corpus ranked; zero or a negative
                value returns an empty list.

        Returns:
            The list of top n most similar prompts from the corpus along
            with similarity scores (as plain floats), best match first.
            Returned prompts are verbatim.
        """
        # Nothing to rank, or nothing requested: skip the embedding work.
        if len(self.prompts) == 0 or (n is not None and n <= 0):
            return []

        query_embedding = self.vectorizer.transform([query]).flatten()
        scores = np.asarray(
            cosine_similarity(query_embedding, self.corpus_embeddings)
        )

        # Stable descending order: tied scores keep their corpus order, which
        # matches sorted(..., reverse=True) on (score, prompt) pairs.
        ranked = np.argsort(-scores, kind="stable")[:n]
        return [(float(scores[i]), self.prompts[int(i)]) for i in ranked]

    def display_prompts(self, prompts):
        """Print the list of prompts with their similarity scores."""
        if not prompts:
            print("No prompts found.")
            return
        for line in self.stringify_prompts(prompts):
            print(line)

    def stringify_prompts(self, prompts):
        """Format (score, prompt) pairs as numbered, human-readable strings.

        Args:
            prompts: Iterable of (score, prompt) pairs, or a falsy value.

        Returns:
            One string per pair, numbered from 1; an empty list when
            ``prompts`` is empty or None.
        """
        if not prompts:
            return []
        return [
            f"{i}. {prompt} (Similarity: {score:.4f})"
            for i, (score, prompt) in enumerate(prompts, 1)
        ]
requirements.txt ADDED
Binary file (260 Bytes). View file
 
run.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
from typing import Optional

from datasets import load_dataset
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

from promptSearchEngine import PromptSearchEngine
from vectorizer import Vectorizer
8
+
9
# Name of the sentence-transformers model used to embed prompts and queries.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# Hugging Face dataset that supplies the prompt corpus.
DATASET = "Gustavosta/Stable-Diffusion-Prompts"



# Everything below runs once at import time, so all requests share the same
# model and the same pre-computed corpus embeddings.
model = SentenceTransformer(EMBEDDING_MODEL)
# Only the first 1% of the "test" split is loaded.
dataset = load_dataset(DATASET , split="test[:1%]")
# The "Prompt" column of the dataset is the searchable corpus.
promptSearchEngine = PromptSearchEngine(dataset["Prompt"], model)
17
+
18
class SearchRequest(BaseModel):
    """JSON body accepted by ``POST /search``."""

    # The raw prompt text to search for.
    query: str
    # Number of results to return; defaults to 5 when omitted.
    # Spelled Optional[int] rather than `int | None`: a PEP 604 union in a
    # class-body annotation is evaluated at import time and raises TypeError
    # on Python 3.9, which is the interpreter the Dockerfile pins.
    n: Optional[int] = 5
21
+
22
app = FastAPI()


@app.get("/")
async def root():
    """Return a short hint that directs callers to the /docs endpoint."""
    hint = {"message": 'GET /docs'}
    return hint
27
+
28
@app.get("/search")
async def search(q: str, n: int = 5):
    """Search the prompt corpus via query-string parameters.

    Args:
        q: The raw query prompt.
        n: How many similar prompts to return (default 5).

    Returns:
        A list of numbered "prompt (Similarity: score)" strings, or a
        ``{"message": "Enter query"}`` hint when ``q`` is empty or only
        whitespace.

    Raises:
        HTTPException: 400 when ``n`` is negative; 404 when the search
            yields no prompts.
    """
    # Blank query: keep the existing soft response instead of searching.
    if not q.strip():
        return {"message": "Enter query"}

    # A negative n would otherwise be used as a slice bound and silently drop
    # results from the end of the ranking.
    if n < 0:
        raise HTTPException(status_code=400, detail="n must be non-negative.")

    results = promptSearchEngine.most_similar(q, n)
    if not results:
        raise HTTPException(status_code=404, detail="No prompts found.")
    return promptSearchEngine.stringify_prompts(results)
38
+
39
+
40
@app.post("/search")
async def searchPost(request: SearchRequest):
    """Search the prompt corpus via a JSON request body.

    Args:
        request: Body carrying the ``query`` text and the result count ``n``.

    Returns:
        ``{"data": [...]}`` where each item has a ``similarity`` float and
        the matching ``prompt``.

    Raises:
        HTTPException: 400 when ``n`` is negative; 404 when the search
            yields no prompts.
    """
    n = request.n
    # A negative n would otherwise be used as a slice bound and silently drop
    # results from the end of the ranking. None is passed through unchanged.
    if n is not None and n < 0:
        raise HTTPException(status_code=400, detail="n must be non-negative.")

    results = promptSearchEngine.most_similar(request.query, n)
    if not results:
        raise HTTPException(status_code=404, detail="No prompts found.")

    # float() turns NumPy scalars into JSON-serialisable Python floats.
    formatted_results = [
        {"similarity": float(similarity), "prompt": prompt}
        for similarity, prompt in results
    ]
    return {"data": formatted_results}
run_local_ui.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from promptSearchEngine import PromptSearchEngine
3
+ from datasets import load_dataset
4
+ from sentence_transformers import SentenceTransformer
5
+ import streamlit as st
6
+
7
# Name of the sentence-transformers model loaded by load_model().
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# Hugging Face dataset loaded by load_dataSet().
DATASET = "Gustavosta/Stable-Diffusion-Prompts"
9
+
10
class SearchRequest(BaseModel):
    # Pydantic model with the search text and the number of results.
    # NOTE(review): not referenced anywhere else in the visible lines of this
    # script -- confirm whether it is still needed here.
    query: str
    n: int | None = 5
13
+
14
+ # model = SentenceTransformer("all-MiniLM-L6-v2")
15
+ # dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts" , split="test[:1%]")
16
+ # promptSearchEngine = PromptSearchEngine(dataset["Prompt"], model)
17
+
18
@st.cache_resource
def load_model():
    """Load the pretrained embedding model used for vectorizing.

    The @st.cache_resource annotation enables caching for Streamlit.
    """
    embedder = SentenceTransformer(EMBEDDING_MODEL)
    return embedder
24
+
25
@st.cache_resource
def load_dataSet():
    """Load the prompt dataset (first 1% of the "test" split).

    The @st.cache_resource annotation enables caching for Streamlit.
    """
    corpus = load_dataset(DATASET , split="test[:1%]")
    return corpus
31
+
32
@st.cache_resource
def load_searchEngine(prompts, _model):
    """Build the search engine, vectorizing the raw prompts from the dataset.

    The @st.cache_resource annotation enables caching for Streamlit.

    Args:
        prompts: The sequence of raw prompts from the dataset.
        _model: The model for vectorizing. (The leading underscore presumably
            tells Streamlit not to hash this argument -- confirm against the
            st.cache_resource documentation.)
    """
    engine = PromptSearchEngine(prompts, _model)
    return engine
41
+
42
# Build the model, dataset and search engine through the cached loaders.
model = load_model()
dataset = load_dataSet()
promptSearchEngine = load_searchEngine(dataset["Prompt"], model)


# Search form: a query text area, a result-count input and a submit button.
with st.form("search_form"):
    st.write("Prompt Search Engine")
    query = st.text_area("Prompt to search")
    number = st.number_input("Number of similar prompts", value = 5, min_value=0, max_value=100)
    submitted = st.form_submit_button("Submit")
    if submitted:
        # The search runs in-process against the locally built engine.
        result = promptSearchEngine.most_similar(query, number)
        st.dataframe(
            result,
            use_container_width=True,
            # NOTE(review): keys 1 and 2 presumably address the score and
            # prompt columns by position -- confirm they line up with the
            # (score, prompt) tuples returned by most_similar.
            column_config={
                1: st.column_config.NumberColumn(
                    "Similarity",
                    help="Range in [-1, 1] where 1 is max similarity, means that prompts are identical.",
                    format= "%.4f"
                ),
                2: st.column_config.TextColumn("Prompts", help="The simlar prompts"),
            },
        )
66
+
67
+
run_ui.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from promptSearchEngine import PromptSearchEngine
3
+ from datasets import load_dataset
4
+ from sentence_transformers import SentenceTransformer
5
+ import streamlit as st
6
+ import requests
7
+ import json
8
+
9
st.title('Prompt Search Engine')

# Search form that forwards the query to the search API over HTTP and renders
# the returned matches as a table.
with st.form("search_form"):
    st.write("Prompt Search Engine")
    query = st.text_area("Prompt to search")
    number = st.number_input("Number of similar prompts", value = 5, min_value=0, max_value=100)
    submitted = st.form_submit_button("Submit")
    if submitted:
        inputs = {"query": query, "n": number}
        try:
            # json= serialises the body and sets the JSON Content-Type header;
            # the timeout keeps the UI from hanging on an unresponsive API.
            response = requests.post(
                url="http://localhost:8000/search",
                json=inputs,
                timeout=30,
            )
        except requests.RequestException as exc:
            st.error(f"Search service is unreachable: {exc}")
        else:
            if response.ok:
                st.dataframe(
                    response.json()["data"],
                    use_container_width=True,
                    column_config={
                        "similarity": st.column_config.NumberColumn(
                            "Similarity",
                            help="Range in [-1, 1] where 1 is max similarity, means that prompts are identical.",
                            format= "%.4f"
                        ),
                        "prompt": st.column_config.TextColumn("Prompts", help="The simlar prompts"),
                    },
                )
            else:
                # Error responses (e.g. 404 when nothing matches) carry no
                # "data" key, so show the error instead of indexing into it.
                st.error(f"Search failed ({response.status_code}): {response.text}")
32
+
33
+
vectorizer.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+ import numpy as np
3
+
4
+
5
class Vectorizer:
    """Wrapper that turns prompts into vectors via an embedding model."""

    def __init__(self, model) -> None:
        """Initialize the vectorizer with a pre-trained embedding model.

        Args:
            model: The pre-trained embedding model to use for transforming
                prompts. It must provide an ``encode`` method.
        """
        self.model = model

    def transform(self, prompts: Sequence[str]) -> np.ndarray:
        """Transform texts into numerical vectors using the stored model.

        Args:
            prompts: The sequence of raw corpus prompts.

        Returns:
            The result of ``model.encode`` for the prompts (vectorized
            prompts as a numpy array). A progress bar is requested from the
            model while encoding.
        """
        return self.model.encode(prompts, show_progress_bar=True)