# NOTE: web-page header residue (not code), preserved here as a comment:
#   MarkMoHR's picture / added code / 7aefe45
import argparse
#################################################################################
# practical argparse utils #
#################################################################################
def accelerate_parser():
    """Build the argument parser for device, precision, tracking and timing options.

    The parser is created with ``add_help=False`` so it can be passed to another
    ``argparse.ArgumentParser`` through ``parents=[...]`` without registering a
    second ``-h/--help`` option.

    Returns:
        argparse.ArgumentParser: parser holding the options declared below.
    """
    parser = argparse.ArgumentParser(add_help=False)

    # Device
    parser.add_argument("-cpu", "--use_cpu", action="store_true",
                        help="Whether or not to disable cuda")

    # Gradient Accumulation
    parser.add_argument("-cumgard", "--gradient-accumulate-step",
                        type=int, default=1)
    parser.add_argument("--split-batches", action="store_true",
                        help="Whether or not the accelerator should split the batches "
                             "yielded by the dataloaders across the devices.")

    # Nvidia-Apex and GradScaler
    # NOTE: adjacent string literals are joined verbatim, so every fragment
    # except the last must end with a space.
    parser.add_argument("-mprec", "--mixed-precision",
                        type=str, default='no', choices=['no', 'fp16', 'bf16'],
                        help="Whether to use mixed precision. Choose "
                             "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
                             "and an Nvidia Ampere GPU.")
    parser.add_argument("--init-scale",
                        type=float, default=65536.0,
                        help="Default value: `2.**16 = 65536.0`. "
                             "For ImageNet experiments, `2.**20 = 1048576.0` was a good default value. "
                             "Others: `2.**17 = 131072.0`.")
    parser.add_argument("--growth-factor", type=float, default=2.0)
    parser.add_argument("--backoff-factor", type=float, default=0.5)
    parser.add_argument("--growth-interval", type=int, default=2000)

    # Gradient Normalization
    parser.add_argument("-gard_norm", "--max_grad_norm", type=float, default=-1)

    # Trackers
    parser.add_argument("--use-wandb", action="store_true")
    parser.add_argument("--project-name", type=str, default="SketchGeneration")
    parser.add_argument("--entity", type=str, default="ximinng")
    parser.add_argument("--tensorboard", action="store_true")

    # timing
    parser.add_argument("-log_step", "--log_step", default=1000, type=int,
                        help="can be used to control logging.")
    parser.add_argument("-eval_step", "--eval_step", default=10, type=int,
                        help="can be used to calculate some metrics.")
    parser.add_argument("-save_step", "--save_step", default=10, type=int,
                        help="can be used to control saving checkpoints.")

    # update configuration interface
    # example: python main.py -c main.yaml -update "nnet.depth=16 batch_size=16"
    parser.add_argument("-update",
                        type=str, default="sds.warmup=1000",
                        help="modified hyper-parameters of config file. ")

    return parser
def ema_parser():
    """Build the argument parser for the EMA (exponential moving average) options.

    Created with ``add_help=False`` so it can be combined with other parsers
    through ``parents=[...]``.

    Returns:
        argparse.ArgumentParser: parser exposing ``--ema``, ``--ema_decay``,
        ``--ema_update_after_step`` and ``--ema_update_every``.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--ema', action='store_true', help='enable EMA model')
    parser.add_argument("--ema_decay", type=float, default=0.9999)
    parser.add_argument("--ema_update_after_step", type=int, default=100)
    parser.add_argument("--ema_update_every", type=int, default=10)
    return parser
def base_data_parser():
    """Build the argument parser for dataset split and data-loading options.

    Created with ``add_help=False`` so it can be combined with other parsers
    through ``parents=[...]``.

    Returns:
        argparse.ArgumentParser: parser exposing ``--split``, ``--num_workers``,
        ``--shuffle`` and ``--drop_last``.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-spl", "--split",
                        default='test', type=str,
                        choices=['train', 'val', 'test', 'all'],
                        help="which part of the data set, 'all' means combine training and test sets.")
    parser.add_argument("-j", "--num_workers",
                        default=6, type=int,
                        help="how many subprocesses to use for data loading.")
    # The two flags below previously reused the --num_workers help text.
    # NOTE(review): wording assumes these are forwarded to a data loader's
    # `shuffle` / `drop_last` options — confirm against the caller.
    parser.add_argument("--shuffle",
                        action='store_true',
                        help="whether to shuffle the data.")
    parser.add_argument("--drop_last",
                        action='store_true',
                        help="whether to drop the last incomplete batch.")
    return parser
def base_training_parser():
    """Build the argument parser for basic training hyper-parameters.

    Created with ``add_help=False`` so it can be combined with other parsers
    through ``parents=[...]``.

    Returns:
        argparse.ArgumentParser: parser exposing ``--train_batch_size`` and
        ``--weight_decay``.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-tbz", "--train_batch_size",
                        default=32, type=int,
                        help="how many images to sample during training.")
    parser.add_argument("-wd", "--weight_decay", default=0, type=float)
    return parser
def base_sampling_parser():
    """Build the argument parser for evaluation / sampling options.

    Created with ``add_help=False`` so it can be combined with other parsers
    through ``parents=[...]``.

    Returns:
        argparse.ArgumentParser: parser exposing ``--valid_batch_size``,
        ``--total_samples`` and ``--num_samples``.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-vbz", "--valid_batch_size",
                        default=1, type=int,
                        help="how many images to sample during evaluation")
    parser.add_argument("-ts", "--total_samples",
                        default=2000, type=int,
                        help="the total number of samples, can be used to calculate FID.")
    parser.add_argument("-ns", "--num_samples",
                        default=4, type=int,
                        help="number of samples taken at a time, "
                             "can be used to repeatedly induce samples from a generation model "
                             "from a fixed guided information, "
                             "eg: `one latent to ns samples` (1 latent to 5 photo generation) ")
    return parser