File size: 733 Bytes
6d314be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from torch.utils.data import Dataset


class ImagesDataset(Dataset):
    def __init__(self, images: dict[torch.Tensor, list[str]] | list[torch.Tensor]):
        if isinstance(images, list):
            images = dict.fromkeys(images)

        self.images = list(images)
        self.names = list(images.values())

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]

        if image.dtype is torch.uint8:
            image = image / 255

        names = self.names[index]
        return image, names


def image_collate(batch):
    """Collate ``(image, names)`` samples into a single batch.

    The image tensors are stacked along a new leading dimension, so they must
    all share the same shape. The names are gathered into a plain list (left
    unstacked) because they may be ``None`` or variable-length lists.

    Returns:
        A ``(stacked_images, names_list)`` tuple.
    """
    pixel_batch = torch.stack([sample[0] for sample in batch])
    name_batch = [sample[1] for sample in batch]
    return (pixel_batch, name_batch)