import pathlib
from collections import deque

import h5py
import numpy as np
import torch


class IndexedDataset:
    """Read-only random access to an HDF5 file written by IndexedDatasetBuilder."""

    def __init__(self, path, prefix, num_cache=0):
        super().__init__()
        self.path = pathlib.Path(path) / f'{prefix}.data'
        if not self.path.exists():
            raise FileNotFoundError(f'IndexedDataset not found: {self.path}')
        # The file is opened lazily (see __getitem__/__len__) so that each
        # DataLoader worker process opens its own handle instead of sharing one.
        self.dset = None
        # Bounded cache of the most recently read items.
        self.cache = deque(maxlen=num_cache)
        self.num_cache = num_cache

    def check_index(self, i):
        if i < 0 or i >= len(self.dset):
            raise IndexError('index out of range')

    def __del__(self):
        if self.dset is not None:
            self.dset.close()

    def __getitem__(self, i):
        if self.dset is None:
            self.dset = h5py.File(self.path, 'r')
        self.check_index(i)
        if self.num_cache > 0:
            for c in self.cache:
                if c[0] == i:
                    return c[1]
        # Scalars come back as plain Python values, arrays as torch tensors.
        item = {k: v[()].item() if v.shape == () else torch.from_numpy(v[()])
                for k, v in self.dset[str(i)].items()}
        if self.num_cache > 0:
            self.cache.appendleft((i, item))
        return item

    def __len__(self):
        if self.dset is None:
            self.dset = h5py.File(self.path, 'r')
        return len(self.dset)


class IndexedDatasetBuilder:
    """Writes items (dicts of arrays/scalars) into '{prefix}.data' under `path`."""

    def __init__(self, path, prefix, allowed_attr=None, auto_increment=True):
        self.path = pathlib.Path(path) / f'{prefix}.data'
        self.prefix = prefix
        self.dset = h5py.File(self.path, 'w')
        self.counter = 0
        self.auto_increment = auto_increment
        if allowed_attr is not None:
            self.allowed_attr = set(allowed_attr)
        else:
            self.allowed_attr = None

    def add_item(self, item, item_no=None):
        if (self.auto_increment and item_no is not None) \
                or (not self.auto_increment and item_no is None):
            raise ValueError('item_no must be omitted when auto_increment is on, '
                             'and provided when it is off')
        if self.allowed_attr is not None:
            # Keep only the whitelisted keys.
            item = {
                k: item[k]
                for k in self.allowed_attr
                if k in item
            }
        if self.auto_increment:
            item_no = self.counter
            self.counter += 1
        for k, v in item.items():
            if v is None:
                continue
            self.dset.create_dataset(f'{item_no}/{k}', data=v)
        return item_no

    def finalize(self):
        self.dset.close()


if __name__ == "__main__":
    import random
    from tqdm import tqdm

    ds_path = './checkpoints/indexed_ds_example'
    # h5py.File(..., 'w') does not create parent directories, so make them first.
    pathlib.Path(ds_path).mkdir(parents=True, exist_ok=True)
    size = 100
    items = [{"a": np.random.normal(size=[10000, 10]),
              "b": np.random.normal(size=[10000, 10])} for _ in range(size)]
    builder = IndexedDatasetBuilder(ds_path, 'example')
    for i in tqdm(range(size)):
        builder.add_item(items[i])
    builder.finalize()
    ds = IndexedDataset(ds_path, 'example')
    for i in tqdm(range(10000)):
        idx = random.randint(0, size - 1)
        assert (ds[idx]['a'] == items[idx]['a']).all()
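
    # Illustrative sketch (not part of the module above): the explicit-numbering
    # path with allowed_attr filtering. Keys outside allowed_attr ('b' here) are
    # silently dropped by add_item; 'example_manual' is an arbitrary demo prefix.
    builder2 = IndexedDatasetBuilder(ds_path, 'example_manual',
                                     allowed_attr=['a'], auto_increment=False)
    for i in tqdm(range(size)):
        builder2.add_item(items[i], item_no=i)
    builder2.finalize()
    ds2 = IndexedDataset(ds_path, 'example_manual')
    assert 'b' not in ds2[0]
    assert (ds2[0]['a'] == items[0]['a']).all()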
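
    # Minimal multi-worker loading sketch (assumed usage, not part of this
    # module; assumes the default fork start method on Linux). HDF5 handles are
    # not fork-safe, so the hypothetical wrapper below defers constructing
    # IndexedDataset until the first __getitem__ call inside each DataLoader
    # worker, and takes the length explicitly to avoid opening the file in the
    # parent process.
    from torch.utils.data import DataLoader, Dataset

    class WrappedDataset(Dataset):
        def __init__(self, path, prefix, length):
            self.path, self.prefix, self.length = path, prefix, length
            self.ds = None  # built lazily, once per worker process

        def __len__(self):
            return self.length

        def __getitem__(self, i):
            if self.ds is None:
                self.ds = IndexedDataset(self.path, self.prefix)
            return self.ds[i]

    loader = DataLoader(WrappedDataset(ds_path, 'example', size),
                        batch_size=4, num_workers=2)
    for batch in loader:
        # torch's default collate stacks each key across the batch dimension
        assert batch['a'].shape == (4, 10000, 10)
        break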