|
import os |
|
|
|
from .parser_image_folder import ParserImageFolder |
|
from .parser_image_tar import ParserImageTar |
|
from .parser_image_in_tar import ParserImageInTar |
|
|
|
|
|
def create_parser(name, root, split='train', **kwargs):
    """Create a dataset parser from a (optionally prefixed) dataset name.

    ``name`` is lower-cased and split on its FIRST ``/`` only. The part before
    the slash is treated as a backend prefix, and everything after it is kept
    intact as the dataset name. For example, ``'tfds/foo/bar'`` gives prefix
    ``'tfds'`` and name ``'foo/bar'``. A name without ``/`` has an empty prefix.

    Args:
        name: Dataset name, optionally prefixed (e.g. ``'tfds/<dataset>'``).
        root: Dataset root path. For non-``tfds`` prefixes it must exist on
            disk; it may be a ``.tar`` file or a directory.
        split: Split name. Only forwarded to the TFDS parser; the image
            folder/tar parsers do not receive it.
        **kwargs: Extra keyword arguments forwarded to the chosen parser. For
            the TFDS parser, ``shuffle`` is popped and passed explicitly
            (default ``False``).

    Returns:
        A ``ParserTfds`` when the prefix is ``'tfds'``. Otherwise a
        ``ParserImageInTar`` when ``root`` is a file with a ``.tar``
        extension (case-insensitive). Otherwise a ``ParserImageFolder``.

    Raises:
        FileNotFoundError: If the prefix is not ``'tfds'`` and ``root`` does
            not exist.
    """
    # Split on the first '/' only so that multi-component names such as
    # 'tfds/<dataset>/<config>' keep everything after the prefix.
    parts = name.lower().split('/', 1)
    prefix = parts[0] if len(parts) > 1 else ''
    name = parts[-1]

    if prefix == 'tfds':
        # Imported lazily so the TFDS parser module (and whatever it depends
        # on) is only loaded when a 'tfds/' dataset is actually requested.
        from .parser_tfds import ParserTfds
        return ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)

    # Any other prefix (or none) falls back to the local image parsers.
    # Explicit raise rather than `assert`, because asserts are stripped
    # under `python -O`.
    if not os.path.exists(root):
        raise FileNotFoundError(f'Dataset root does not exist: {root}')

    # A .tar file root is read in place; anything else is treated as an
    # image folder. The extension comparison is case-insensitive.
    if os.path.isfile(root) and os.path.splitext(root)[1].lower() == '.tar':
        return ParserImageInTar(root, **kwargs)
    return ParserImageFolder(root, **kwargs)
|
|