Spaces:
Sleeping
Sleeping
import os | |
import pickle | |
import torch | |
from torch.utils.data import Dataset | |
from utils.hparams import hparams | |
from utils.indexed_datasets import IndexedDataset | |
class BaseDataset(Dataset): | |
""" | |
Base class for datasets. | |
1. *sizes*: | |
clipped length if "max_frames" is set; | |
2. *num_frames*: | |
unclipped length. | |
Subclasses should define: | |
1. *collate*: | |
take the longest data, pad other data to the same length; | |
2. *__getitem__*: | |
the index function. | |
""" | |
def __init__(self, prefix, size_key='lengths', preload=False): | |
super().__init__() | |
self.prefix = prefix | |
self.data_dir = hparams['binary_data_dir'] | |
with open(os.path.join(self.data_dir, f'{self.prefix}.meta'), 'rb') as f: | |
self.metadata = pickle.load(f) | |
self.sizes = self.metadata[size_key] | |
self._indexed_ds = IndexedDataset(self.data_dir, self.prefix) | |
if preload: | |
self.indexed_ds = [self._indexed_ds[i] for i in range(len(self._indexed_ds))] | |
del self._indexed_ds | |
else: | |
self.indexed_ds = self._indexed_ds | |
def __getitem__(self, index): | |
return {'_idx': index, **self.indexed_ds[index]} | |
def __len__(self): | |
return len(self.sizes) | |
def num_frames(self, index): | |
return self.sizes[index] | |
def size(self, index): | |
"""Return an example's size as a float or tuple. This value is used when | |
filtering a dataset with ``--max-positions``.""" | |
return self.sizes[index] | |
def collater(self, samples): | |
return { | |
'size': len(samples), | |
'indices': torch.LongTensor([s['_idx'] for s in samples]) | |
} | |