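"""Indexed random-access dataset stored in a single HDF5 file.

IndexedDatasetBuilder writes one HDF5 group per item number;
IndexedDataset reads items back lazily, opening the file on first
access so that each consumer process (e.g. a DataLoader worker)
gets its own h5py handle.
"""
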
import pathlib
from collections import deque

import h5py
import torch
import numpy as np


class IndexedDataset:
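    """Read-only view over a {prefix}.data HDF5 file written by
    IndexedDatasetBuilder.

    Items come back as dicts mapping each stored key to a torch tensor,
    or to a Python scalar for zero-dimensional datasets. If num_cache > 0,
    the most recently returned items are kept in a small deque cache.
    """
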
    def __init__(self, path, prefix, num_cache=0):
        super().__init__()
        self.path = pathlib.Path(path) / f'{prefix}.data'
        if not self.path.exists():
            raise FileNotFoundError(f'IndexedDataset not found: {self.path}')
        self.dset = None  # opened lazily in __getitem__ / __len__
        self.cache = deque(maxlen=num_cache)
        self.num_cache = num_cache

    def check_index(self, i):
        if i < 0 or i >= len(self.dset):
            raise IndexError('index out of range')

    def __del__(self):
        # getattr: __init__ may have raised before self.dset was assigned.
        dset = getattr(self, 'dset', None)
        if dset is not None:
            dset.close()

    def __getitem__(self, i):
        if self.dset is None:
            self.dset = h5py.File(self.path, 'r')
        self.check_index(i)
        # A linear scan is fine here: the cache is intentionally tiny.
        if self.num_cache > 0:
            for c in self.cache:
                if c[0] == i:
                    return c[1]
        # Zero-dimensional datasets come back as Python scalars,
        # everything else as torch tensors.
        item = {
            k: v[()].item() if v.shape == () else torch.from_numpy(v[()])
            for k, v in self.dset[str(i)].items()
        }
        if self.num_cache > 0:
            self.cache.appendleft((i, item))
        return item

    def __len__(self):
        if self.dset is None:
            self.dset = h5py.File(self.path, 'r')
        return len(self.dset)


class IndexedDatasetBuilder:
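    """Writes items into a single {prefix}.data HDF5 file, one group per
    item number.

    allowed_attr, if given, whitelists which keys of each item are stored.
    With auto_increment=True (the default) item numbers are assigned
    sequentially; with auto_increment=False the caller passes item_no to
    add_item explicitly, e.g. to write items out of order.
    """
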
    def __init__(self, path, prefix, allowed_attr=None, auto_increment=True):
        self.path = pathlib.Path(path) / f'{prefix}.data'
        self.prefix = prefix
        self.dset = h5py.File(self.path, 'w')
        self.counter = 0
        self.auto_increment = auto_increment
        if allowed_attr is not None:
            self.allowed_attr = set(allowed_attr)
        else:
            self.allowed_attr = None

    def add_item(self, item, item_no=None):
        if (self.auto_increment and item_no is not None) \
                or (not self.auto_increment and item_no is None):
            raise ValueError('pass item_no if and only if auto_increment is False')
        if self.allowed_attr is not None:
            item = {
                k: item[k]
                for k in self.allowed_attr
                if k in item
            }
        if self.auto_increment:
            item_no = self.counter
            self.counter += 1
        for k, v in item.items():
            if v is None:  # missing attributes are simply skipped
                continue
            self.dset.create_dataset(f'{item_no}/{k}', data=v)
        return item_no

    def finalize(self):
        self.dset.close()

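
# A minimal sketch, not part of the original file, of feeding IndexedDataset
# to torch.utils.data.DataLoader. The lazy open in __getitem__ lets each
# worker process create its own h5py handle. Caveat: __len__ also opens a
# handle in the parent process, so with fork-based workers it can be safer
# to reset `ds.dset = None` before iterating. `collate_fn=list` is purely
# illustrative; a real pipeline would batch and pad the tensors itself.
def example_dataloader(path, prefix, batch_size=4, num_workers=2):
    from torch.utils.data import DataLoader
    ds = IndexedDataset(path, prefix)
    return DataLoader(ds, batch_size=batch_size,
                      num_workers=num_workers, collate_fn=list)
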

if __name__ == "__main__":
    import random
    from tqdm import tqdm

    # The builder opens its file in 'w' mode, so make sure the target
    # directory exists first.
    ds_path = pathlib.Path('./checkpoints/indexed_ds_example')
    ds_path.mkdir(parents=True, exist_ok=True)
    size = 100
    items = [{"a": np.random.normal(size=[10000, 10]),
              "b": np.random.normal(size=[10000, 10])} for _ in range(size)]
    builder = IndexedDatasetBuilder(ds_path, 'example')
    for i in tqdm(range(size)):
        builder.add_item(items[i])
    builder.finalize()
    # Random reads to verify the round trip.
    ds = IndexedDataset(ds_path, 'example')
    for _ in tqdm(range(10000)):
        idx = random.randint(0, size - 1)
        assert (ds[idx]['a'] == items[idx]['a']).all()
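
    # Not in the original file: with num_cache > 0, a repeated read of the
    # same index is served from the deque cache, i.e. the very same object.
    ds_cached = IndexedDataset(ds_path, 'example', num_cache=8)
    assert ds_cached[0] is ds_cached[0]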