|
# Named hyperparameter presets. Each preset defined in this module registers
# itself here, and parse_args_and_update_hparams looks presets up by name.
HPARAMS_REGISTRY = dict()
|
|
|
|
|
class Hyperparams(dict):
    """Dictionary whose items can also be read and written as attributes.

    Reading an attribute that is not a key yields ``None`` instead of raising,
    so optional hyperparameters can be probed with ``H.name``. Assigning an
    attribute stores the value as a dictionary item.
    """

    def __getattr__(self, attr):
        """Return ``self[attr]``, or ``None`` when the key is absent.

        Python only calls this after normal attribute lookup has failed.

        Raises:
            AttributeError: if ``attr`` is a missing double-underscore name.
                Protocol code (copy, pickle, ...) probes for optional special
                methods with ``getattr``; answering ``None`` would make it
                try to use a non-callable hook.
        """
        try:
            return self[attr]
        except KeyError:
            if attr.startswith('__') and attr.endswith('__'):
                raise AttributeError(attr) from None
            return None

    def __setattr__(self, attr, value):
        """Store ``value`` under the key ``attr``."""
        self[attr] = value
|
|
|
|
|
# Preset registered under the name 'cifar10'; it is also the base that the
# 'imagenet32' preset copies from.
cifar10 = Hyperparams(
    width=384,
    lr=0.0002,
    zdim=16,
    wd=0.01,
    dec_blocks="1x1,4m1,4x2,8m4,8x5,16m8,16x10,32m16,32x21",
    enc_blocks="32x11,32d2,16x6,16d2,8x6,8d2,4x3,4d4,1x3",
    warmup_iters=100,
    dataset='cifar10',
    n_batch=16,
    ema_rate=0.9999,
)
HPARAMS_REGISTRY['cifar10'] = cifar10
|
|
|
|
|
# Preset registered under the name 'imagenet32': starts from a copy of the
# cifar10 values, then overrides some of them and adds new keys.
i32 = Hyperparams(
    cifar10,
    dataset='imagenet32',
    ema_rate=0.999,
    dec_blocks="1x2,4m1,4x4,8m4,8x9,16m8,16x19,32m16,32x40",
    enc_blocks="32x15,32d2,16x9,16d2,8x8,8d2,4x6,4d4,1x6",
    width=512,
    n_batch=8,
    lr=0.00015,
    grad_clip=200.,
    skip_threshold=300.,
    epochs_per_eval=1,
    epochs_per_eval_save=1,
)
HPARAMS_REGISTRY['imagenet32'] = i32
|
|
|
# Preset registered under the name 'imagenet64': starts from a copy of the
# imagenet32 values and overrides the entries listed below.
i64 = Hyperparams(
    i32,
    n_batch=4,
    grad_clip=220.0,
    skip_threshold=380.0,
    dataset='imagenet64',
    dec_blocks="1x2,4m1,4x3,8m4,8x7,16m8,16x15,32m16,32x31,64m32,64x12",
    enc_blocks="64x11,64d2,32x20,32d2,16x9,16d2,8x8,8d2,4x7,4d4,1x5",
)
HPARAMS_REGISTRY['imagenet64'] = i64
|
|
|
# Preset for the 'ffhq_256' dataset: starts from a copy of the imagenet64
# values, then overrides some of them and adds new keys.
# NOTE: the registry key is 'ffhq256' (no underscore), unlike the dataset name.
ffhq_256 = Hyperparams(
    i64,
    n_batch=1,
    lr=0.00015,
    dataset='ffhq_256',
    epochs_per_eval=1,
    epochs_per_eval_save=1,
    num_images_visualize=2,
    num_variables_visualize=3,
    num_temperatures_visualize=1,
    dec_blocks="1x2,4m1,4x3,8m4,8x4,16m8,16x9,32m16,32x21,64m32,64x13,128m64,128x7,256m128",
    enc_blocks="256x3,256d2,128x8,128d2,64x12,64d2,32x17,32d2,16x7,16d2,8x5,8d2,4x5,4d4,1x4",
    no_bias_above=64,
    grad_clip=130.,
    skip_threshold=180.,
)
HPARAMS_REGISTRY['ffhq256'] = ffhq_256
|
|
|
|
|
# Preset for the 'ffhq_1024' dataset: starts from a copy of the ffhq_256
# values, then overrides some of them and adds new keys.
# NOTE: the registry key is 'ffhq1024' (no underscore), unlike the dataset name.
ffhq1024 = Hyperparams(
    ffhq_256,
    dataset='ffhq_1024',
    data_root='./ffhq_images1024x1024',
    epochs_per_eval=1,
    epochs_per_eval_save=1,
    num_images_visualize=1,
    iters_per_images=25000,
    num_variables_visualize=0,
    num_temperatures_visualize=4,
    grad_clip=360.,
    skip_threshold=500.,
    num_mixtures=2,
    width=16,
    lr=0.00007,
    dec_blocks="1x2,4m1,4x3,8m4,8x4,16m8,16x9,32m16,32x20,64m32,64x14,128m64,128x7,256m128,256x2,512m256,1024m512",
    enc_blocks="1024x1,1024d2,512x3,512d2,256x5,256d2,128x7,128d2,64x10,64d2,32x14,32d2,16x7,16d2,8x5,8d2,4x5,4d4,1x4",
    custom_width_str="512:32,256:64,128:512,64:512,32:512,16:512,8:512,4:512,1:512",
)
HPARAMS_REGISTRY['ffhq1024'] = ffhq1024
|
|
|
|
|
def parse_args_and_update_hparams(H, parser, s=None):
    """Parse arguments, apply the requested presets, and store the result in H.

    The arguments are parsed once to find the comma-separated preset names in
    ``hparam_sets``. Each named preset from ``HPARAMS_REGISTRY`` is installed
    as parser defaults, in the order given, so a later preset overrides an
    earlier one. The arguments are then parsed a second time, so values given
    explicitly in ``s`` take precedence over preset values.

    Args:
        H: mapping updated in place with the final parsed values.
        parser: an ``argparse.ArgumentParser``; its defaults are modified.
        s: argument list passed to ``parser.parse_args``; ``None`` makes
            argparse read the process command line.

    Raises:
        KeyError: if a requested preset name is not in ``HPARAMS_REGISTRY``.
        ValueError: if a preset contains a key the parser does not define.
    """
    args = parser.parse_args(s)
    valid_args = set(vars(args))
    # The preset flag may have no default, so the value is None when it is
    # omitted; treat that (and a parser without the flag) as "no presets".
    requested = getattr(args, 'hparam_sets', None) or ''
    hparam_sets = [x for x in requested.split(',') if x]
    for hp_set in hparam_sets:
        hps = HPARAMS_REGISTRY[hp_set]
        # Reject preset keys the parser does not know before installing them.
        for k in hps:
            if k not in valid_args:
                raise ValueError(f"{k} not in default args")
        parser.set_defaults(**hps)
    # Parse again so the preset defaults take effect.
    H.update(parser.parse_args(s).__dict__)
|
|
|
|
|
def add_vae_arguments(parser):
    """Register every command-line option of this module on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object).

    Returns:
        The same ``parser`` object, to allow call chaining.
    """

    def option(flag, kind, default):
        # Typed option that takes a value.
        parser.add_argument(flag, type=kind, default=default)

    def switch(flag):
        # Boolean option: False unless the flag is present.
        parser.add_argument(flag, action="store_true")

    # Run setup and paths.
    option('--seed', int, 0)
    option('--port', int, 29500)
    option('--save_dir', str, './saved_models')
    option('--data_root', str, './')

    # Run description, preset selection and restore paths.
    option('--desc', str, 'test')
    # Registered without an explicit default, so it stays None when omitted.
    parser.add_argument('--hparam_sets', '--hps', type=str)
    option('--restore_path', str, None)
    option('--restore_ema_path', str, None)
    option('--restore_log_path', str, None)
    option('--restore_optimizer_path', str, None)
    option('--dataset', str, 'cifar10')

    option('--ema_rate', float, 0.999)

    # Architecture.
    option('--enc_blocks', str, None)
    option('--dec_blocks', str, None)
    option('--zdim', int, 16)
    option('--width', int, 512)
    option('--custom_width_str', str, '')
    option('--bottleneck_multiple', float, 0.25)

    option('--no_bias_above', int, 64)
    switch('--scale_encblock')

    switch('--test_eval')
    option('--warmup_iters', float, 0)

    # Optimisation.
    option('--num_mixtures', int, 10)
    option('--grad_clip', float, 200.0)
    option('--skip_threshold', float, 400.0)
    option('--lr', float, 0.00015)
    option('--lr_prior', float, 0.00015)
    option('--wd', float, 0.0)
    option('--wd_prior', float, 0.0)
    option('--num_epochs', int, 10000)
    option('--n_batch', int, 32)
    option('--adam_beta1', float, 0.9)
    option('--adam_beta2', float, 0.9)

    option('--temperature', float, 1.0)

    # Logging, checkpointing, evaluation and visualisation cadence.
    option('--iters_per_ckpt', int, 25000)
    option('--iters_per_print', int, 1000)
    option('--iters_per_save', int, 10000)
    option('--iters_per_images', int, 10000)
    option('--epochs_per_eval', int, 10)
    option('--epochs_per_probe', int, None)
    option('--epochs_per_eval_save', int, 20)
    option('--num_images_visualize', int, 8)
    option('--num_variables_visualize', int, 6)
    option('--num_temperatures_visualize', int, 3)
    return parser
|
|