|
import os |
|
|
|
from .dataset import IterableImageDataset, ImageDataset |
|
|
|
|
|
def _search_split(root, split):
    """Return the path under ``root`` that corresponds to ``split``, if any.

    Any slice suffix is discarded before the lookup: everything from the
    first ``[`` onward, e.g. ``train[:10%]`` -> ``train``.

    Candidates are tried in order:
      1. ``root/<split_name>``
      2. ``root/val``, only when ``split_name`` is ``'validation'``

    The first candidate for which ``os.path.exists`` is true is returned.
    Note that this check accepts a file as well as a directory. If no
    candidate exists, ``root`` is returned unchanged.
    """
    base_name = split.partition('[')[0]

    candidates = [base_name]
    if base_name == 'validation':
        # Many image-folder layouts use the short name 'val' for validation data.
        candidates.append('val')

    for candidate in candidates:
        path = os.path.join(root, candidate)
        if os.path.exists(path):
            return path
    return root
|
|
|
|
|
def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
    """Construct a dataset object for ``name`` located at ``root``.

    Args:
        name: Dataset/parser identifier. It is lower-cased, then passed to the
            dataset class as ``parser``. A name starting with ``'tfds'``
            selects ``IterableImageDataset``; any other name selects
            ``ImageDataset``.
        root: Dataset location.
        split: Split name. On the tfds path it is forwarded as ``split``. On
            the other path it is used only to look for a split sub-folder
            inside ``root``.
        search_split: When true and ``root`` is a directory, replace ``root``
            with the matching split sub-folder found by ``_search_split``
            (non-tfds path only).
        is_training: Forwarded only on the tfds path.
        batch_size: Forwarded only on the tfds path.
        **kwargs: Extra keyword arguments passed to the dataset constructor.
            On the non-tfds path a ``'repeats'`` entry is removed first.

    Returns:
        The constructed ``IterableImageDataset`` or ``ImageDataset``.
    """
    parser_name = name.lower()

    if not parser_name.startswith('tfds'):
        # Drop 'repeats' before calling ImageDataset, presumably because that
        # class does not accept it (TODO confirm against .dataset).
        kwargs.pop('repeats', 0)
        dataset_root = root
        if search_split and os.path.isdir(dataset_root):
            dataset_root = _search_split(dataset_root, split)
        return ImageDataset(dataset_root, parser=parser_name, **kwargs)

    return IterableImageDataset(
        root,
        parser=parser_name,
        split=split,
        is_training=is_training,
        batch_size=batch_size,
        **kwargs,
    )
|
|