File size: 938 Bytes
8a0c27f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from datasets import load_dataset
from core.search_engine import PromptSearchEngine


class PromptDataset:
    """
    Thin wrapper around a dataset fetched with ``datasets.load_dataset``.

    The wrapper stores the dataset identifier, loads the dataset on demand
    via :meth:`load`, and exposes one column of one split as a plain Python
    list via :meth:`get_prompts`.

    Attributes:
        dataset_name: Identifier passed verbatim to ``load_dataset``.
        dataset: The object returned by ``load_dataset``, or ``None`` until
            :meth:`load` has been called.
    """

    def __init__(self, dataset_name: str):
        """
        Remember which dataset to load; nothing is fetched here.

        Args:
            dataset_name: Identifier handed to ``load_dataset`` when
                :meth:`load` is called.
        """

        self.dataset_name = dataset_name
        # Populated by load(); None doubles as the "not loaded yet" marker
        # that get_prompts() checks for.
        self.dataset = None

    def load(self):
        """
        Load the dataset named by ``self.dataset_name`` and keep a reference.

        Every call invokes ``load_dataset`` again and replaces the object
        stored on ``self.dataset``.

        Returns:
            The object returned by ``load_dataset``. ``get_prompts`` indexes
            it by split name, so it is presumably a mapping of split name to
            rows (verify against the ``datasets`` documentation).
        """

        self.dataset = load_dataset(self.dataset_name)

        return self.dataset

    def get_prompts(self, split: str = "test", column: str = "Prompt") -> list:
        """
        Collect the values of one column from one split of the loaded dataset.

        Args:
            split: Key of the split to read from ``self.dataset``.
                Defaults to ``"test"``.
            column: Key of the field to extract from every row of the split.
                Defaults to ``"Prompt"``.

        Returns:
            A list with one entry per row of the split, in iteration order.

        Raises:
            ValueError: If :meth:`load` has not been called yet.
            KeyError: If a plain-dict dataset or row lacks ``split`` or
                ``column``; other container types may raise differently.
        """

        if self.dataset is None:
            raise ValueError("Dataset not loaded. Call the load() method first.")

        return [item[column] for item in self.dataset[split]]


# if __name__ == "__main__":
#     dataset = PromptDataset("Gustavosta/Stable-Diffusion-Prompts")
#     dataset.load()
#     prompts = dataset.get_prompts()
#     engine = PromptSearchEngine(prompts)
#     result = engine.most_similar("dark")
#     print(result)