import argparse


#################################################################################
#                            practical argparse utils                           #
#################################################################################

def accelerate_parser():
    parser = argparse.ArgumentParser(add_help=False)

    # Device
    parser.add_argument("-cpu", "--use_cpu", action="store_true",
                        help="Whether or not disable cuda")

    # Gradient Accumulation
    parser.add_argument("-cumgard", "--gradient-accumulate-step",
                        type=int, default=1)
    parser.add_argument("--split-batches", action="store_true",
                        help="Whether or not the accelerator should split the batches "
                             "yielded by the dataloaders across the devices.")

    # Nvidia-Apex and GradScaler
    parser.add_argument("-mprec", "--mixed-precision",
                        type=str, default='no', choices=['no', 'fp16', 'bf16'],
                        help="Whether to use mixed precision. Choose"
                             "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
                             "and an Nvidia Ampere GPU.")
    parser.add_argument("--init-scale",
                        type=float, default=65536.0,
                        help="Default value: `2.**16 = 65536.0` ,"
                             "For ImageNet experiments, '2.**20 = 1048576.0' was a good default value."
                             "the others: `2.**17 = 131072.0` ")
    parser.add_argument("--growth-factor", type=float, default=2.0)
    parser.add_argument("--backoff-factor", type=float, default=0.5)
    parser.add_argument("--growth-interval", type=int, default=2000)

    # Gradient Normalization
    parser.add_argument("-gard_norm", "--max_grad_norm", type=float, default=-1)

    # Trackers
    parser.add_argument("--use-wandb", action="store_true")
    parser.add_argument("--project-name", type=str, default="SketchGeneration")
    parser.add_argument("--entity", type=str, default="ximinng")
    parser.add_argument("--tensorboard", action="store_true")

    # timing
    parser.add_argument("-log_step", "--log_step", default=1000, type=int,
                        help="can be use to control log.")
    parser.add_argument("-eval_step", "--eval_step", default=10, type=int,
                        help="can be use to calculate some metrics.")
    parser.add_argument("-save_step", "--save_step", default=10, type=int,
                        help="can be use to control saving checkpoint.")

    # update configuration interface
    # example: python main.py -c main.yaml -update "nnet.depth=16 batch_size=16"
    parser.add_argument("-update",
                        type=str, default="sds.warmup=1000",
                        help="modified hyper-parameters of config file. ")
    return parser
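
# A minimal usage sketch (an assumption, not part of the original utilities):
# the flags registered by `accelerate_parser` roughly mirror the constructor of
# Hugging Face `accelerate.Accelerator`, e.g.:
#
#   from accelerate import Accelerator
#
#   args = argparse.ArgumentParser(parents=[accelerate_parser()]).parse_args()
#   accelerator = Accelerator(
#       cpu=args.use_cpu,
#       split_batches=args.split_batches,
#       mixed_precision=args.mixed_precision,
#       gradient_accumulation_steps=args.gradient_accumulate_step,
#   )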


def ema_parser():
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--ema', action='store_true',
                        help='enable an exponential moving average (EMA) of the model weights.')
    parser.add_argument("--ema_decay", type=float, default=0.9999,
                        help="EMA decay rate.")
    parser.add_argument("--ema_update_after_step", type=int, default=100,
                        help="start updating the EMA weights after this many steps.")
    parser.add_argument("--ema_update_every", type=int, default=10,
                        help="update the EMA weights every N steps.")
    return parser
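
# A minimal sketch (an assumption): these flags match the constructor of the
# third-party `ema_pytorch.EMA` wrapper, e.g.:
#
#   from ema_pytorch import EMA
#
#   ema_model = EMA(
#       model,
#       beta=args.ema_decay,
#       update_after_step=args.ema_update_after_step,
#       update_every=args.ema_update_every,
#   )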


def base_data_parser():
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-spl", "--split",
                        default='test', type=str,
                        choices=['train', 'val', 'test', 'all'],
                        help="which part of the data set, 'all' means combine training and test sets.")
    parser.add_argument("-j", "--num_workers",
                        default=6, type=int,
                        help="how many subprocesses to use for data loading.")
    parser.add_argument("--shuffle",
                        action='store_true',
                        help="how many subprocesses to use for data loading.")
    parser.add_argument("--drop_last",
                        action='store_true',
                        help="how many subprocesses to use for data loading.")
    return parser
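
# A minimal sketch (an assumption): these flags map directly onto
# `torch.utils.data.DataLoader`, e.g.:
#
#   from torch.utils.data import DataLoader
#
#   loader = DataLoader(
#       dataset,
#       batch_size=args.train_batch_size,
#       shuffle=args.shuffle,
#       num_workers=args.num_workers,
#       drop_last=args.drop_last,
#   )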


def base_training_parser():
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-tbz", "--train_batch_size",
                        default=32, type=int,
                        help="how many images to sample during training.")
    parser.add_argument("-wd", "--weight_decay", default=0, type=float)
    return parser


def base_sampling_parser():
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-vbz", "--valid_batch_size",
                        default=1, type=int,
                        help="how many images to sample during evaluation")
    parser.add_argument("-ts", "--total_samples",
                        default=2000, type=int,
                        help="the total number of samples, can be used to calculate FID.")
    parser.add_argument("-ns", "--num_samples",
                        default=4, type=int,
                        help="number of samples taken at a time, "
                             "can be used to repeatedly induce samples from a generation model "
                             "from a fixed guided information, "
                             "eg: `one latent to ns samples` (1 latent to 5 photo generation) ")
    return parser
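

#################################################################################
#                          usage sketch (illustrative)                          #
#################################################################################

# The parsers above are built with `add_help=False` so they can be composed as
# `parents` of a top-level ArgumentParser. The entry point below is a minimal
# demo (the description string is an assumption, not part of the project).

if __name__ == "__main__":
    combined = argparse.ArgumentParser(
        description="demo: compose the practical argparse utils",
        parents=[
            accelerate_parser(),
            ema_parser(),
            base_data_parser(),
            base_training_parser(),
            base_sampling_parser(),
        ],
    )
    print(combined.parse_args([]))  # prints the namespace of default values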