Spaces:
Build error
Build error
""" | |
Copyright (C) 2019 NVIDIA Corporation. All rights reserved. | |
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). | |
""" | |
import importlib | |
import torch.utils.data | |
from .base_dataset import BaseDataset | |
def find_dataset_using_name(dataset_name): | |
# Given the option --dataset [datasetname], | |
# the file "datasets/datasetname_dataset.py" | |
# will be imported. | |
dataset_filename = "sean_codes.data." + dataset_name + "_dataset" | |
datasetlib = importlib.import_module(dataset_filename) | |
# In the file, the class called DatasetNameDataset() will | |
# be instantiated. It has to be a subclass of BaseDataset, | |
# and it is case-insensitive. | |
dataset = None | |
target_dataset_name = dataset_name.replace('_', '') + 'dataset' | |
for name, cls in datasetlib.__dict__.items(): | |
if name.lower() == target_dataset_name.lower() \ | |
and issubclass(cls, BaseDataset): | |
dataset = cls | |
if dataset is None: | |
raise ValueError("In %s.py, there should be a subclass of BaseDataset " | |
"with class name that matches %s in lowercase." % | |
(dataset_filename, target_dataset_name)) | |
return dataset | |
def get_option_setter(dataset_name): | |
dataset_class = find_dataset_using_name(dataset_name) | |
return dataset_class.modify_commandline_options | |
def create_dataloader(opt): | |
dataset = find_dataset_using_name(opt.dataset_mode) | |
instance = dataset() | |
instance.initialize(opt) | |
print("dataset [%s] of size %d was created" % | |
(type(instance).__name__, len(instance))) | |
dataloader = torch.utils.data.DataLoader( | |
instance, | |
batch_size=opt.batchSize, | |
shuffle=not opt.serial_batches, | |
num_workers=min(opt.batchSize, int(opt.nThreads)), | |
drop_last=opt.isTrain, | |
pin_memory=True | |
) | |
return dataloader | |