|
import argparse |
|
|
|
|
|
|
|
|
|
|
|
|
|
def accelerate_parser():
    """Build a parent parser holding Accelerate / training-runtime options.

    The parser is created with ``add_help=False`` so it can be merged into
    another ``argparse.ArgumentParser`` through ``parents=[...]``.

    Options registered (all on one flat parser):
        * device: ``-cpu/--use_cpu``
        * gradient accumulation: ``-cumgard/--gradient-accumulate-step``
        * batch splitting: ``--split-batches``
        * mixed precision: ``-mprec/--mixed-precision``
        * loss-scaling knobs (presumably forwarded to a gradient scaler --
          confirm with caller): ``--init-scale``, ``--growth-factor``,
          ``--backoff-factor``, ``--growth-interval``
        * gradient norm: ``-gard_norm/--max_grad_norm``
        * experiment tracking: ``--use-wandb``, ``--project-name``,
          ``--entity``, ``--tensorboard``
        * step intervals: ``--log_step``, ``--eval_step``, ``--save_step``
        * config overrides: ``-update``

    Returns:
        argparse.ArgumentParser: the populated parent parser.
    """
    parser = argparse.ArgumentParser(add_help=False)

    # Device selection.
    parser.add_argument("-cpu", "--use_cpu", action="store_true",
                        help="Whether or not disable cuda")

    # Gradient accumulation / batch handling.
    parser.add_argument("-cumgard", "--gradient-accumulate-step",
                        type=int, default=1)
    parser.add_argument("--split-batches", action="store_true",
                        help="Whether or not the accelerator should split the batches "
                             "yielded by the dataloaders across the devices.")

    # Mixed precision. Adjacent literals are concatenated verbatim, so each
    # fragment must carry its own separating space.
    parser.add_argument("-mprec", "--mixed-precision",
                        type=str, default='no', choices=['no', 'fp16', 'bf16'],
                        help="Whether to use mixed precision. Choose "
                             "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
                             "and an Nvidia Ampere GPU.")
    parser.add_argument("--init-scale",
                        type=float, default=65536.0,
                        help="Default value: `2.**16 = 65536.0`. "
                             "For ImageNet experiments, '2.**20 = 1048576.0' was a good default value; "
                             "for the others: `2.**17 = 131072.0`.")
    parser.add_argument("--growth-factor", type=float, default=2.0)
    parser.add_argument("--backoff-factor", type=float, default=0.5)
    parser.add_argument("--growth-interval", type=int, default=2000)

    # Gradient norm. argparse does not pass non-string defaults through
    # `type`, so the default is written as a float literal to match what the
    # CLI produces. NOTE(review): a negative value presumably disables
    # clipping -- confirm at the call site.
    parser.add_argument("-gard_norm", "--max_grad_norm", type=float, default=-1.0)

    # Experiment tracking.
    parser.add_argument("--use-wandb", action="store_true")
    parser.add_argument("--project-name", type=str, default="SketchGeneration")
    parser.add_argument("--entity", type=str, default="ximinng")
    parser.add_argument("--tensorboard", action="store_true")

    # Step intervals.
    parser.add_argument("-log_step", "--log_step", default=1000, type=int,
                        help="can be use to control log.")
    parser.add_argument("-eval_step", "--eval_step", default=10, type=int,
                        help="can be use to calculate some metrics.")
    parser.add_argument("-save_step", "--save_step", default=10, type=int,
                        help="can be use to control saving checkpoint.")

    # Free-form override string for config-file hyper-parameters
    # (single-dash long option kept for backward compatibility).
    parser.add_argument("-update",
                        type=str, default="sds.warmup=1000",
                        help="modified hyper-parameters of config file. ")

    return parser
|
|
|
|
|
def ema_parser():
    """Return a parent parser (``add_help=False``) with EMA-model options.

    Registers the ``--ema`` switch followed by ``--ema_decay``,
    ``--ema_update_after_step`` and ``--ema_update_every``.
    """
    ema_args = argparse.ArgumentParser(add_help=False)
    ema_args.add_argument('--ema', action='store_true', help='enable EMA model')

    # (flag, converter, default) for the numeric EMA settings.
    numeric_specs = (
        ("--ema_decay", float, 0.9999),
        ("--ema_update_after_step", int, 100),
        ("--ema_update_every", int, 10),
    )
    for flag, converter, fallback in numeric_specs:
        ema_args.add_argument(flag, type=converter, default=fallback)

    return ema_args
|
|
|
|
|
def base_data_parser():
    """Return a parent parser (``add_help=False``) with data-loading options.

    Options: ``-spl/--split`` (one of train/val/test/all), ``-j/--num_workers``,
    and the boolean switches ``--shuffle`` and ``--drop_last``.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-spl", "--split",
                        default='test', type=str,
                        choices=['train', 'val', 'test', 'all'],
                        help="which part of the data set, 'all' means combine training and test sets.")
    parser.add_argument("-j", "--num_workers",
                        default=6, type=int,
                        help="how many subprocesses to use for data loading.")
    # The two flags below previously reused the --num_workers help text,
    # which described them incorrectly.
    parser.add_argument("--shuffle",
                        action='store_true',
                        help="shuffle the data set when loading.")
    parser.add_argument("--drop_last",
                        action='store_true',
                        help="drop the last incomplete batch when loading.")
    return parser
|
|
|
|
|
def base_training_parser():
    """Return a parent parser (``add_help=False``) with basic training options.

    Options: ``-tbz/--train_batch_size`` (int, default 32) and
    ``-wd/--weight_decay`` (float, default 0.0).
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-tbz", "--train_batch_size",
                        default=32, type=int,
                        help="how many images to sample during training.")
    # argparse only applies `type` to string defaults, so use a float literal
    # to keep the default's type consistent with CLI-supplied values.
    parser.add_argument("-wd", "--weight_decay", default=0.0, type=float)
    return parser
|
|
|
|
|
def base_sampling_parser():
    """Return a parent parser (``add_help=False``) with sampling/eval options.

    Options: ``-vbz/--valid_batch_size`` (default 1), ``-ts/--total_samples``
    (default 2000) and ``-ns/--num_samples`` (default 4), all integers.
    """
    sampling = argparse.ArgumentParser(add_help=False)

    sampling.add_argument(
        "-vbz", "--valid_batch_size",
        type=int,
        default=1,
        help="how many images to sample during evaluation",
    )
    sampling.add_argument(
        "-ts", "--total_samples",
        type=int,
        default=2000,
        help="the total number of samples, can be used to calculate FID.",
    )

    num_samples_help = (
        "number of samples taken at a time, "
        "can be used to repeatedly induce samples from a generation model "
        "from a fixed guided information, "
        "eg: `one latent to ns samples` (1 latent to 5 photo generation) "
    )
    sampling.add_argument(
        "-ns", "--num_samples",
        type=int,
        default=4,
        help=num_samples_help,
    )

    return sampling
|
|