Spaces:
Build error
Build error
File size: 733 Bytes
6d314be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
from torch.utils.data import Dataset
class ImagesDataset(Dataset):
def __init__(self, images: dict[torch.Tensor, list[str]] | list[torch.Tensor]):
if isinstance(images, list):
images = dict.fromkeys(images)
self.images = list(images)
self.names = list(images.values())
def __len__(self):
return len(self.images)
def __getitem__(self, index):
image = self.images[index]
if image.dtype is torch.uint8:
image = image / 255
names = self.names[index]
return image, names
def image_collate(batch):
    """Collate ``(image, names)`` samples into a batch.

    Image tensors (element 0 of each sample) are stacked along a new leading
    dimension, so they must share a shape. Each sample's element 1 is gathered
    as-is into a plain list, preserving order.

    Returns:
        A ``(stacked_images, names_per_sample)`` tuple.
    """
    stacked = torch.stack([sample[0] for sample in batch])
    per_sample_names = [sample[1] for sample in batch]
    return (stacked, per_sample_names)
|