Spaces:
Running
Running
import argparse | |
from dataclasses import ( | |
asdict, | |
dataclass, | |
) | |
import functools | |
import random | |
from textwrap import dedent, indent | |
import json | |
from pathlib import Path | |
# from toolz import curry | |
from typing import Dict, List, Optional, Sequence, Tuple, Union | |
import toml | |
import voluptuous | |
from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema | |
from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset | |
import logging | |
logger = logging.getLogger(__name__) | |
logging.basicConfig(level=logging.INFO) | |
@dataclass
class BaseDatasetParams:
    """Options shared by image and video datasets, together with their defaults.

    The class must be a dataclass: ``BlueprintGenerator.generate_params_by_fallbacks``
    reads the defaults through ``asdict(param_klass())`` and then builds an
    instance with keyword arguments, and
    ``generate_dataset_group_by_blueprint`` expands the instance with
    ``asdict`` into the dataset constructor call.
    """

    resolution: Tuple[int, int] = (960, 544)
    enable_bucket: bool = False
    bucket_no_upscale: bool = False
    caption_extension: Optional[str] = None
    batch_size: int = 1
    num_repeats: int = 1
    cache_directory: Optional[str] = None
    # Only supplied through argparse (see ConfigSanitizer.ARGPARSE_SPECIFIC_SCHEMA).
    debug_dataset: bool = False
@dataclass
class ImageDatasetParams(BaseDatasetParams):
    """Shared dataset options plus ``image_directory`` and ``image_jsonl_file``.

    Declared as a dataclass so that it can be built from keyword arguments and
    converted with ``asdict`` like the other parameter classes.
    """

    image_directory: Optional[str] = None
    image_jsonl_file: Optional[str] = None
@dataclass
class VideoDatasetParams(BaseDatasetParams):
    """Shared dataset options plus the video-specific ones.

    Declared as a dataclass so that it can be built from keyword arguments and
    converted with ``asdict`` like the other parameter classes.
    """

    video_directory: Optional[str] = None
    video_jsonl_file: Optional[str] = None
    # A tuple keeps the default immutable, so it is safe as a dataclass default.
    target_frames: Sequence[int] = (1,)
    frame_extraction: Optional[str] = "head"
    frame_stride: Optional[int] = 1
    frame_sample: Optional[int] = 1
@dataclass
class DatasetBlueprint:
    """Resolved description of one dataset.

    ``BlueprintGenerator.generate`` constructs it positionally as
    ``DatasetBlueprint(is_image_dataset, params)``, which relies on the
    dataclass-generated ``__init__``.
    """

    # True selects ImageDataset, False selects VideoDataset when the group is built.
    is_image_dataset: bool
    params: Union[ImageDatasetParams, VideoDatasetParams]
@dataclass
class DatasetGroupBlueprint:
    """Ordered collection of dataset blueprints.

    Constructed positionally by ``BlueprintGenerator.generate``, so it needs
    the dataclass-generated ``__init__``.
    """

    datasets: Sequence[DatasetBlueprint]
@dataclass
class Blueprint:
    """Top-level result of ``BlueprintGenerator.generate``.

    Constructed positionally as ``Blueprint(dataset_group_blueprint)``, so it
    needs the dataclass-generated ``__init__``.
    """

    dataset_group: DatasetGroupBlueprint
class ConfigSanitizer:
    """Validate the user dataset config and the parsed command-line namespace.

    Validation is done with voluptuous schemas. The class-level dictionaries
    describe the accepted keys; ``__init__`` combines them into the validators
    used by ``sanitize_user_config`` and ``sanitize_argparse_namespace``.
    """

    @staticmethod
    def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
        """Require ``value`` to be exactly two ``klass`` items and return it as a tuple."""
        Schema(ExactSequence([klass, klass]))(value)
        return tuple(value)

    @staticmethod
    def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
        """Accept a single ``klass`` value or a pair of them and always return a pair.

        A scalar ``v`` becomes ``(v, v)``; a two-item sequence is returned as a
        tuple. Anything else is rejected by the first schema check.
        """
        Schema(Any(klass, ExactSequence([klass, klass])))(value)
        try:
            Schema(klass)(value)
        except voluptuous.Invalid:
            # Not a scalar, so it has to be the two-item form.
            return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
        return (value, value)

    # datasets schema
    # These keys are accepted both per dataset and in the "general" section.
    DATASET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "batch_size": int,
        "num_repeats": int,
        # ``__func__`` unwraps the staticmethod object, which is what the name
        # refers to while the class body is still being executed.
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
        "enable_bucket": bool,
        "bucket_no_upscale": bool,
    }
    IMAGE_DATASET_DISTINCT_SCHEMA = {
        "image_directory": str,
        "image_jsonl_file": str,
        "cache_directory": str,
    }
    VIDEO_DATASET_DISTINCT_SCHEMA = {
        "video_directory": str,
        "video_jsonl_file": str,
        "target_frames": [int],
        "frame_extraction": str,
        "frame_stride": int,
        "frame_sample": int,
        "cache_directory": str,
    }

    # options handled by argparse but not handled by user config
    ARGPARSE_SPECIFIC_SCHEMA = {
        "debug_dataset": bool,
    }

    def __init__(self) -> None:
        """Build the per-dataset, general, user-config and argparse validators."""
        self.image_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.IMAGE_DATASET_DISTINCT_SCHEMA,
        )
        self.video_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.VIDEO_DATASET_DISTINCT_SCHEMA,
        )

        def validate_flex_dataset(dataset_config: dict):
            # The presence of "target_frames" is what marks an entry as a video dataset.
            if "target_frames" in dataset_config:
                return Schema(self.video_dataset_schema)(dataset_config)
            return Schema(self.image_dataset_schema)(dataset_config)

        self.dataset_schema = validate_flex_dataset

        self.general_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
        )
        self.user_config_validator = Schema(
            {
                "general": self.general_schema,
                "datasets": [self.dataset_schema],
            }
        )

        self.argparse_schema = self.__merge_dict(
            self.ARGPARSE_SPECIFIC_SCHEMA,
        )
        self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)

    def sanitize_user_config(self, user_config: dict) -> dict:
        """Return the validated user config.

        Raises:
            MultipleInvalid: re-raised after logging when validation fails.
        """
        try:
            return self.user_config_validator(user_config)
        except MultipleInvalid:
            # TODO: clarify the error message
            logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
            raise

    # NOTE: In nature, argument parser result is not needed to be sanitize
    #   However this will help us to detect program bug
    def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
        """Return the validated argparse namespace.

        Raises:
            MultipleInvalid: re-raised after logging when validation fails.
        """
        try:
            return self.argparse_config_validator(argparse_namespace)
        except MultipleInvalid:
            # XXX: this should be a bug
            logger.error(
                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
            )
            raise

    # NOTE: value would be overwritten by latter dict if there is already the same key
    @staticmethod
    def __merge_dict(*dict_list: dict) -> dict:
        """Merge the given dicts left to right into a new dict."""
        merged = {}
        for schema in dict_list:
            merged.update(schema)
        return merged
class BlueprintGenerator:
    """Turn a user config and parsed command-line arguments into a ``Blueprint``."""

    # Maps a parameter field name to the config option name to look up, for
    # fields whose names differ. Empty means every field uses its own name.
    BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}

    def __init__(self, sanitizer: ConfigSanitizer):
        """Store the sanitizer used to validate inputs in ``generate``."""
        self.sanitizer = sanitizer

    # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
    def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
        """Validate the inputs and resolve the parameters of every dataset.

        Each parameter is looked up, in order, in the dataset entry, the
        "general" section, the argparse values and ``runtime_params``; the
        first value that is not None wins, otherwise the dataclass default is
        used.

        Raises:
            MultipleInvalid: propagated from the sanitizer when validation fails.
        """
        sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
        sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)

        # Unset command-line options must not shadow later fallbacks.
        argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
        general_config = sanitized_user_config.get("general", {})

        dataset_blueprints = []
        for dataset_config in sanitized_user_config.get("datasets", []):
            # An entry is a video dataset exactly when it specifies "target_frames".
            is_image_dataset = "target_frames" not in dataset_config
            dataset_params_klass = ImageDatasetParams if is_image_dataset else VideoDatasetParams

            params = self.generate_params_by_fallbacks(
                dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
            )
            dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))

        dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)

        return Blueprint(dataset_group_blueprint)

    @staticmethod
    def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
        """Build a ``param_klass`` instance, resolving each field through ``fallbacks``.

        ``param_klass`` must be a dataclass that can be constructed without
        arguments; its defaults are used for fields that no fallback provides.
        """
        name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
        search_value = BlueprintGenerator.search_value
        default_params = asdict(param_klass())

        params = {
            name: search_value(name_map.get(name, name), fallbacks, default_value)
            for name, default_value in default_params.items()
        }

        return param_klass(**params)

    @staticmethod
    def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
        """Return the first non-None value for ``key`` among ``fallbacks``, else ``default_value``."""
        for cand in fallbacks:
            value = cand.get(key)
            if value is not None:
                return value

        return default_value
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
    """Instantiate the datasets described by the blueprint and group them.

    Args:
        dataset_group_blueprint: blueprints of the datasets to create.
        training: if True the datasets are prepared for training
            (``prepare_for_training`` is called on each); otherwise the group
            is meant for caching.

    Returns:
        A ``DatasetGroup`` wrapping the created datasets.

    Raises:
        ValueError: if two datasets report the same ``cache_directory``.
    """
    datasets: List[Union[ImageDataset, VideoDataset]] = []

    for dataset_blueprint in dataset_group_blueprint.datasets:
        dataset_klass = ImageDataset if dataset_blueprint.is_image_dataset else VideoDataset
        datasets.append(dataset_klass(**asdict(dataset_blueprint.params)))

    # assertion
    cache_directories = [dataset.cache_directory for dataset in datasets]
    if len(set(cache_directories)) != len(cache_directories):
        raise ValueError(
            "cache directory should be unique for each dataset (note that cache directory is image/video directory if not specified)"
            + " / cache directory は各データセットごとに異なる必要があります(指定されていない場合はimage/video directoryが使われるので注意)"
        )

    # print info
    # Lines are assembled explicitly (no dedent/indent on an interpolated
    # f-string), so an unusual value cannot disturb the layout.
    sections: List[str] = []
    for i, dataset in enumerate(datasets):
        is_image_dataset = isinstance(dataset, ImageDataset)
        lines = [
            f"[Dataset {i}]",
            f"  is_image_dataset: {is_image_dataset}",
            f"  resolution: {dataset.resolution}",
            f"  batch_size: {dataset.batch_size}",
            f"  num_repeats: {dataset.num_repeats}",
            f'  caption_extension: "{dataset.caption_extension}"',
            f"  enable_bucket: {dataset.enable_bucket}",
            f"  bucket_no_upscale: {dataset.bucket_no_upscale}",
            f'  cache_directory: "{dataset.cache_directory}"',
            f"  debug_dataset: {dataset.debug_dataset}",
        ]
        if is_image_dataset:
            lines += [
                f'  image_directory: "{dataset.image_directory}"',
                f'  image_jsonl_file: "{dataset.image_jsonl_file}"',
            ]
        else:
            lines += [
                f'  video_directory: "{dataset.video_directory}"',
                f'  video_jsonl_file: "{dataset.video_jsonl_file}"',
                f"  target_frames: {dataset.target_frames}",
                f"  frame_extraction: {dataset.frame_extraction}",
                f"  frame_stride: {dataset.frame_stride}",
                f"  frame_sample: {dataset.frame_sample}",
            ]
        # Each dataset section ends with a blank line.
        sections.append("\n".join(lines) + "\n\n")
    logger.info("".join(sections))

    # make buckets first because it determines the length of dataset
    # and set the same seed for all datasets
    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
    for dataset in datasets:
        dataset.set_seed(seed)
        if training:
            dataset.prepare_for_training()

    return DatasetGroup(datasets)
def load_user_config(file: str) -> dict:
    """Load a dataset config from a ``.json`` or ``.toml`` file.

    The format is chosen from the file name, case-insensitively.

    Args:
        file: path of the config file.

    Returns:
        The parsed config as a dict.

    Raises:
        ValueError: if the path is not an existing file, the extension is not
            supported, or the parsed top-level value is not a dict.
        Exception: any parser error is logged and re-raised unchanged.
    """
    config_path = Path(file)

    if not config_path.is_file():
        raise ValueError(f"file not found / ファイルが見つかりません: {config_path}")

    lowered_name = config_path.name.lower()
    if lowered_name.endswith(".json"):
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        except Exception:
            logger.error(
                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {config_path}"
            )
            raise
    elif lowered_name.endswith(".toml"):
        try:
            config = toml.load(config_path)
        except Exception:
            logger.error(
                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {config_path}"
            )
            raise
    else:
        raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {config_path}")

    # A JSON document may legally be a list or a scalar; honour the declared
    # return type instead of handing a non-dict to the callers.
    if not isinstance(config, dict):
        raise ValueError(
            f"config file must contain a table/object at the top level / 設定ファイルの最上位は辞書である必要があります: {config_path}"
        )

    return config
# Manual config check: load a dataset config file, validate it, build the
# blueprint and instantiate the dataset group, logging every intermediate step.
if __name__ == "__main__":
    # First pass picks up only the positional config path; everything else is
    # handed to the second parser.
    config_parser = argparse.ArgumentParser()
    config_parser.add_argument("dataset_config")
    config_args, remain = config_parser.parse_known_args()

    option_parser = argparse.ArgumentParser()
    option_parser.add_argument("--debug_dataset", action="store_true")
    argparse_namespace = option_parser.parse_args(remain)

    logger.info("[argparse_namespace]")
    logger.info(str(vars(argparse_namespace)))

    user_config = load_user_config(config_args.dataset_config)

    logger.info("")
    logger.info("[user_config]")
    logger.info(str(user_config))

    sanitizer = ConfigSanitizer()
    sanitized_user_config = sanitizer.sanitize_user_config(user_config)

    logger.info("")
    logger.info("[sanitized_user_config]")
    logger.info(str(sanitized_user_config))

    generator = BlueprintGenerator(sanitizer)
    blueprint = generator.generate(user_config, argparse_namespace)

    logger.info("")
    logger.info("[blueprint]")
    logger.info(str(blueprint))

    dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)