# File size: 2,292 Bytes
# 85c1145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from pydantic import BaseModel
from promptSearchEngine import PromptSearchEngine
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import streamlit as st

EMBEDDING_MODEL = "all-MiniLM-L6-v2"
DATASET = "Gustavosta/Stable-Diffusion-Prompts"

class SearchRequest(BaseModel):
    """Pydantic model holding a search query and a result count.

    NOTE(review): not referenced anywhere in the visible part of this
    file; presumably kept for an API-style entry point -- confirm before
    removing.
    """

    # Free-text prompt to search for (required field).
    query: str
    # Number of results to return; defaults to 5 and may also be None.
    n: int | None = 5
    
# model = SentenceTransformer("all-MiniLM-L6-v2")
# dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts" , split="test[:1%]")
# promptSearchEngine = PromptSearchEngine(dataset["Prompt"], model)

@st.cache_resource
def load_model():
    """Create the sentence-embedding model used for vectorizing.

    The model name is taken from ``EMBEDDING_MODEL``. The
    ``@st.cache_resource`` decorator lets Streamlit cache the returned
    object instead of rebuilding it on every call.

    Returns:
        The ``SentenceTransformer`` instance.
    """
    encoder = SentenceTransformer(EMBEDDING_MODEL)
    return encoder

@st.cache_resource
def load_dataSet():
    """Load the prompt dataset identified by ``DATASET``.

    Only the ``"test[:1%]"`` split slice is requested. The
    ``@st.cache_resource`` decorator lets Streamlit cache the returned
    object instead of loading it again on every call.

    Returns:
        Whatever ``load_dataset`` returns for that split.
    """
    split_spec = "test[:1%]"
    prompts_dataset = load_dataset(DATASET, split=split_spec)
    return prompts_dataset

@st.cache_resource
def load_searchEngine(prompts, _model):
    """Construct the search engine from raw prompts and an embedding model.

    The ``@st.cache_resource`` decorator lets Streamlit cache the returned
    engine instead of rebuilding it on every call.

    Args:
        prompts: The sequence of raw prompts from the dataset.
        _model: The model for vectorizing. The leading underscore follows
            Streamlit's caching convention for arguments that should not
            be hashed when computing the cache key.

    Returns:
        The ``PromptSearchEngine`` built from ``prompts`` and ``_model``.
    """
    engine = PromptSearchEngine(prompts, _model)
    return engine

# Module-level resources used by the search form. Each loader is decorated
# with @st.cache_resource, so these calls go through Streamlit's cache.
model = load_model()
dataset = load_dataSet()
# Only the "Prompt" column of the dataset is handed to the search engine.
promptSearchEngine = load_searchEngine(dataset["Prompt"], model)


# Search form: collects a query and a result count, then shows the matches
# returned by the search engine in a table.
with st.form("search_form"):
    st.write("Prompt Search Engine")
    query = st.text_area("Prompt to search")
    number = st.number_input(
        "Number of similar prompts",
        min_value=0,
        max_value=100,
        value=5,
    )
    submitted = st.form_submit_button("Submit")

    # Run the search only after the submit button has been pressed.
    if submitted:
        result = promptSearchEngine.most_similar(query, number)

        similarity_column = st.column_config.NumberColumn(
            "Similarity",
            help="Range in [-1, 1] where 1 is max similarity, means that prompts are identical.",
            format="%.4f",
        )
        prompt_column = st.column_config.TextColumn(
            "Prompts",
            help="The simlar prompts",
        )

        # NOTE(review): the integer keys presumably address the displayed
        # columns by position -- confirm against the shape of `result`.
        st.dataframe(
            result,
            use_container_width=True,
            column_config={
                1: similarity_column,
                2: prompt_column,
            },
        )