Spaces:
Sleeping
Sleeping
File size: 1,747 Bytes
c42fe7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import os
import pickle
import torch
from torch.utils.data import Dataset
from utils.hparams import hparams
from utils.indexed_datasets import IndexedDataset
class BaseDataset(Dataset):
    """
    Base class for datasets backed by a pickled metadata file and an
    ``IndexedDataset``.

    1. *sizes*:
        clipped length if "max_frames" is set;
    2. *num_frames*:
        unclipped length.

    NOTE(review): in this base class both ``num_frames`` and ``size`` read
    from ``self.sizes``; any clipping would have to be applied by a subclass
    or by whatever produced the metadata -- confirm against callers.

    Subclasses should define:
    1. *collater*:
        take the longest data, pad other data to the same length;
    2. *__getitem__*:
        the index function.
    """

    def __init__(self, prefix, size_key='lengths', preload=False):
        """
        Load ``<prefix>.meta`` and open the matching indexed dataset.

        Both are looked up under ``hparams['binary_data_dir']``.

        Args:
            prefix: Base name shared by the metadata file and the indexed
                dataset.
            size_key: Key of the metadata entry stored as ``self.sizes``.
            preload: If true, read every item into an in-memory list up
                front and drop the ``_indexed_ds`` attribute; otherwise items
                are fetched from the ``IndexedDataset`` on each access.
        """
        super().__init__()
        self.prefix = prefix
        self.data_dir = hparams['binary_data_dir']
        meta_path = os.path.join(self.data_dir, f'{self.prefix}.meta')
        # SECURITY: unpickling can execute arbitrary code; only point
        # 'binary_data_dir' at data produced by a trusted pipeline.
        with open(meta_path, 'rb') as f:
            self.metadata = pickle.load(f)
        self.sizes = self.metadata[size_key]
        self._indexed_ds = IndexedDataset(self.data_dir, self.prefix)
        if preload:
            # Index explicitly rather than iterating: only __getitem__ and
            # __len__ are known to be supported by IndexedDataset here.
            self.indexed_ds = [self._indexed_ds[i] for i in range(len(self._indexed_ds))]
            del self._indexed_ds
        else:
            self.indexed_ds = self._indexed_ds

    def __getitem__(self, index):
        """Return the stored item at ``index`` as a dict with an added ``'_idx'`` key.

        The stored item's keys are unpacked after ``'_idx'``, so a stored
        ``'_idx'`` key would take precedence over the positional index.
        """
        return {'_idx': index, **self.indexed_ds[index]}

    def __len__(self):
        """Return the number of entries in ``self.sizes``."""
        return len(self.sizes)

    def num_frames(self, index):
        """Return ``self.sizes[index]``.

        NOTE(review): the class docstring calls this the unclipped length,
        but the value comes from the same ``sizes`` list as ``size``.
        """
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def collater(self, samples):
        """Build the common part of a batch from a list of samples.

        Args:
            samples: Sequence of dicts as returned by ``__getitem__``; each
                must contain an ``'_idx'`` entry.

        Returns:
            A dict with ``'size'`` (the number of samples) and ``'indices'``
            (a ``torch.LongTensor`` of the samples' ``'_idx'`` values, in
            order).
        """
        return {
            'size': len(samples),
            'indices': torch.LongTensor([s['_idx'] for s in samples])
        }
|