File size: 624 Bytes
8a0c27f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from datasets import load_dataset


class PromptDataset:
    """
    Thin wrapper around a Hugging Face ``datasets`` dataset that holds prompts.

    Nothing is fetched on construction: call :meth:`load` first, then
    :meth:`get_prompts` to extract the prompt values from one split.
    """

    def __init__(self, dataset_name: str):
        """
        Remember which dataset to load later.

        Args:
            dataset_name: Name or path forwarded to ``datasets.load_dataset``.
        """

        self.dataset_name = dataset_name
        # Populated by load(); None is the "not loaded yet" sentinel that
        # get_prompts() checks before touching the data.
        self.dataset = None

    def load(self, **load_kwargs):
        """
        Load the dataset and cache it on the instance.

        Args:
            **load_kwargs: Extra keyword arguments forwarded unchanged to
                ``datasets.load_dataset`` (e.g. a config name or revision).
                Passing ``split=...`` makes ``load_dataset`` return a single
                split instead of a mapping of splits, which
                :meth:`get_prompts` does not expect.

        Returns:
            Whatever ``load_dataset`` returned; it is also stored in
            ``self.dataset``.
        """

        self.dataset = load_dataset(self.dataset_name, **load_kwargs)

        return self.dataset

    def get_prompts(self, split: str = "test", column: str = "Prompt"):
        """
        Return the values of one column from one split as a list.

        Args:
            split: Split to read from. Defaults to ``"test"``.
            column: Field holding the prompt text. Defaults to ``"Prompt"``.

        Returns:
            A list containing ``row[column]`` for every row of
            ``self.dataset[split]``, in iteration order.

        Raises:
            ValueError: If :meth:`load` has not been called yet.
            KeyError: Typically, if ``split`` or ``column`` is missing
                (raised by the underlying container, not by this method).
        """

        if self.dataset is None:
            raise ValueError("Dataset not loaded. Call the load() method first.")

        return [item[column] for item in self.dataset[split]]