# Spaces: Build error
import torch
from torch.utils.data import Dataset
class ImagesDataset(Dataset): | |
def __init__(self, images: dict[torch.Tensor, list[str]] | list[torch.Tensor]): | |
if isinstance(images, list): | |
images = dict.fromkeys(images) | |
self.images = list(images) | |
self.names = list(images.values()) | |
def __len__(self): | |
return len(self.images) | |
def __getitem__(self, index): | |
image = self.images[index] | |
if image.dtype is torch.uint8: | |
image = image / 255 | |
names = self.names[index] | |
return image, names | |
def image_collate(batch):
    """Collate ``(image, names)`` samples into a batched tensor and a name list.

    Intended as a ``collate_fn``: element 0 of each sample is stacked along a
    new leading dimension via ``torch.stack``; element 1 of each sample is
    gathered, in order, into a plain Python list.
    """
    tensors = [sample[0] for sample in batch]
    stacked = torch.stack(tensors)
    # Names may be arbitrary objects (lists, None), so they are not stacked.
    labels = [sample[1] for sample in batch]
    return stacked, labels