|
import argparse |
|
|
|
HPARAMS_REGISTRY = {} |
|
|
|
|
|
class Hparams: |
|
def update(self, dict): |
|
for k, v in dict.items(): |
|
setattr(self, k, v) |
|
|
|
brset = Hparams() |
|
brset.lr = 1e-3 |
|
brset.bs = 16 |
|
brset.wd = 0.01 |
|
brset.z_dim = 16 |
|
brset.input_res = 384 |
|
brset.pad = 9 |
|
brset.hflip = 0.5 |
|
|
|
brset.input_channels = 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
brset.enc_arch = "384b1d4,96b3d4,24b11d2,12b7d2,6b3d6,1b2" |
|
brset.dec_arch = "1b2,6b4,12b8,24b12,96b4,384b2" |
|
brset.widths = [32, 64, 128, 160, 192, 512] |
|
|
|
|
|
|
|
brset.bias_max_res = 64 |
|
brset.bottleneck = 4 |
|
brset.parents_x = ['patient_age', 'patient_sex', 'DR_ICDR'] |
|
brset.context_norm = "[-1,1]" |
|
brset.context_dim = 7 |
|
brset.n_classes = 5 |
|
brset.concat_pa = True |
|
HPARAMS_REGISTRY["brset"] = brset |
|
|
|
|
|
morphomnist = Hparams() |
|
morphomnist.lr = 1e-3 |
|
morphomnist.bs = 32 |
|
morphomnist.wd = 0.01 |
|
morphomnist.z_dim = 16 |
|
morphomnist.input_res = 32 |
|
morphomnist.pad = 4 |
|
morphomnist.enc_arch = "32b3d2,16b3d2,8b3d2,4b3d4,1b4" |
|
morphomnist.dec_arch = "1b4,4b4,8b4,16b4,32b4" |
|
morphomnist.widths = [16, 32, 64, 128, 256] |
|
morphomnist.parents_x = ["thickness", "intensity", "digit"] |
|
morphomnist.concat_pa = True |
|
morphomnist.context_norm = "[-1,1]" |
|
morphomnist.context_dim = 12 |
|
HPARAMS_REGISTRY["morphomnist"] = morphomnist |
|
|
|
|
|
cmnist = Hparams() |
|
cmnist.lr = 1e-3 |
|
cmnist.bs = 32 |
|
cmnist.wd = 0.01 |
|
cmnist.z_dim = 16 |
|
cmnist.input_res = 32 |
|
cmnist.input_channels = 3 |
|
cmnist.pad = 4 |
|
cmnist.enc_arch = "32b3d2,16b3d2,8b3d2,4b3d4,1b4" |
|
cmnist.dec_arch = "1b4,4b4,8b4,16b4,32b4" |
|
cmnist.widths = [16, 32, 64, 128, 256] |
|
cmnist.parents_x = ["digit", "colour"] |
|
cmnist.context_dim = 20 |
|
HPARAMS_REGISTRY["cmnist"] = cmnist |
|
|
|
|
|
ukbb64 = Hparams() |
|
ukbb64.lr = 1e-3 |
|
ukbb64.bs = 32 |
|
ukbb64.wd = 0.1 |
|
ukbb64.z_dim = 16 |
|
ukbb64.input_res = 64 |
|
ukbb64.pad = 3 |
|
ukbb64.enc_arch = "64b3d2,32b31d2,16b15d2,8b7d2,4b3d4,1b2" |
|
ukbb64.dec_arch = "1b2,4b4,8b8,16b16,32b32,64b4" |
|
ukbb64.widths = [32, 64, 128, 256, 512, 1024] |
|
HPARAMS_REGISTRY["ukbb64"] = ukbb64 |
|
|
|
|
|
ukbb192 = Hparams() |
|
ukbb192.update(ukbb64.__dict__) |
|
ukbb192.input_res = 384 |
|
ukbb192.pad = 9 |
|
ukbb192.enc_arch = "384b2d2,192b2d2,96b3d2,48b7d2,24b11d2,12b7d2,6b3d6,1b2" |
|
ukbb192.dec_arch = "1b2,6b4,12b8,24b12,48b8,96b4,192b2,384b2" |
|
ukbb192.widths = [32, 64, 96, 128, 160, 192, 512, 1024] |
|
HPARAMS_REGISTRY["ukbb192"] = ukbb192 |
|
|
|
|
|
mimic192 = Hparams() |
|
mimic192.lr = 1e-3 |
|
mimic192.bs = 16 |
|
mimic192.wd = 0.1 |
|
mimic192.z_dim = 16 |
|
mimic192.input_res = 192 |
|
mimic192.pad = 9 |
|
mimic192.enc_arch = "192b1d2,96b3d2,48b7d2,24b11d2,12b7d2,6b3d6,1b2" |
|
mimic192.dec_arch = "1b2,6b4,12b8,24b12,48b8,96b4,192b2" |
|
mimic192.widths = [32, 64, 96, 128, 160, 192, 512] |
|
HPARAMS_REGISTRY["mimic192"] = mimic192 |
|
|
|
mimic384 = Hparams() |
|
mimic384.lr = 1e-3 |
|
mimic384.bs = 16 |
|
mimic384.wd = 0.1 |
|
mimic384.z_dim = 16 |
|
mimic384.input_res = 384 |
|
mimic384.pad = 9 |
|
mimic384.enc_arch = "384b1d2,192b1d2,96b3d2,48b7d2,24b11d2,12b7d2,6b3d6,1b2" |
|
mimic384.dec_arch = "1b2,6b4,12b8,24b12,48b8,96b4,192b2,384b2" |
|
mimic384.widths = [32, 64, 96, 128, 160, 192, 512,1024] |
|
HPARAMS_REGISTRY["mimic384"] = mimic384 |
|
|
|
def setup_hparams(parser: argparse.ArgumentParser) -> Hparams: |
|
hparams = Hparams() |
|
args = parser.parse_known_args()[0] |
|
valid_args = set(args.__dict__.keys()) |
|
hparams_dict = HPARAMS_REGISTRY[args.hps].__dict__ |
|
for k in hparams_dict.keys(): |
|
if k not in valid_args: |
|
raise ValueError(f"{k} not in default args") |
|
parser.set_defaults(**hparams_dict) |
|
hparams.update(parser.parse_known_args()[0].__dict__) |
|
return hparams |
|
|
|
|
|
def add_arguments(parser: argparse.ArgumentParser): |
|
parser.add_argument("--exp_name", help="Experiment name.", type=str, default="") |
|
parser.add_argument( |
|
"--data_dir", help="Data directory to load form.", type=str, default="" |
|
) |
|
parser.add_argument("--hps", help="hyperparam set.", type=str, default="ukbb64") |
|
parser.add_argument( |
|
"--resume", help="Path to load checkpoint.", type=str, default="" |
|
) |
|
parser.add_argument("--seed", help="Set random seed.", type=int, default=7) |
|
parser.add_argument( |
|
"--deterministic", |
|
help="Toggle cudNN determinism.", |
|
action="store_true", |
|
default=False, |
|
) |
|
|
|
parser.add_argument("--epochs", help="Training epochs.", type=int, default=5000) |
|
parser.add_argument("--bs", help="Batch size.", type=int, default=32) |
|
parser.add_argument("--lr", help="Learning rate.", type=float, default=1e-3) |
|
parser.add_argument( |
|
"--lr_warmup_steps", help="lr warmup steps.", type=int, default=100 |
|
) |
|
parser.add_argument("--wd", help="Weight decay penalty.", type=float, default=0.01) |
|
parser.add_argument( |
|
"--betas", |
|
help="Adam beta parameters.", |
|
nargs="+", |
|
type=float, |
|
default=[0.9, 0.9], |
|
) |
|
parser.add_argument( |
|
"--ema_rate", help="Exp. moving avg. model rate.", type=float, default=0.999 |
|
) |
|
parser.add_argument( |
|
"--input_res", help="Input image crop resolution.", type=int, default=64 |
|
) |
|
parser.add_argument( |
|
"--input_channels", help="Input image num channels.", type=int, default=1 |
|
) |
|
parser.add_argument("--pad", help="Input padding.", type=int, default=3) |
|
parser.add_argument( |
|
"--hflip", help="Horizontal flip prob.", type=float, default=0.5 |
|
) |
|
parser.add_argument( |
|
"--grad_clip", help="Gradient clipping value.", type=float, default=350 |
|
) |
|
parser.add_argument( |
|
"--grad_skip", help="Skip update grad norm threshold.", type=float, default=500 |
|
) |
|
parser.add_argument( |
|
"--accu_steps", help="Gradient accumulation steps.", type=int, default=1 |
|
) |
|
parser.add_argument( |
|
"--beta", help="Max KL beta penalty weight.", type=float, default=1.0 |
|
) |
|
parser.add_argument( |
|
"--beta_warmup_steps", help="KL beta penalty warmup steps.", type=int, default=0 |
|
) |
|
parser.add_argument( |
|
"--kl_free_bits", help="KL min free bits constraint.", type=float, default=0.0 |
|
) |
|
parser.add_argument( |
|
"--viz_freq", help="Steps per visualisation.", type=int, default=10000 |
|
) |
|
parser.add_argument( |
|
"--eval_freq", help="Train epochs per validation.", type=int, default=5 |
|
) |
|
parser.add_argument( |
|
"--n_classes", help="Number of classes for DR ICDR.", type=int, default=10 |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--vae", |
|
help="VAE model: simple/hierarchical.", |
|
type=str, |
|
default="hierarchical", |
|
) |
|
parser.add_argument( |
|
"--enc_arch", |
|
help="Encoder architecture config.", |
|
type=str, |
|
default="64b1d2,32b1d2,16b1d2,8b1d8,1b2", |
|
) |
|
parser.add_argument( |
|
"--dec_arch", |
|
help="Decoder architecture config.", |
|
type=str, |
|
default="1b2,8b2,16b2,32b2,64b2", |
|
) |
|
parser.add_argument( |
|
"--cond_prior", |
|
help="Use a conditional prior.", |
|
action="store_true", |
|
default=False, |
|
) |
|
parser.add_argument( |
|
"--widths", |
|
help="Number of channels.", |
|
nargs="+", |
|
type=int, |
|
default=[16, 32, 48, 64, 128], |
|
) |
|
parser.add_argument( |
|
"--bottleneck", help="Bottleneck width factor.", type=int, default=4 |
|
) |
|
parser.add_argument( |
|
"--z_dim", help="Numver of latent channel dims.", type=int, default=16 |
|
) |
|
parser.add_argument( |
|
"--z_max_res", |
|
help="Max resolution of stochastic z layers.", |
|
type=int, |
|
default=192, |
|
) |
|
parser.add_argument( |
|
"--bias_max_res", |
|
help="Learned bias param max resolution.", |
|
type=int, |
|
default=64, |
|
) |
|
parser.add_argument( |
|
"--x_like", |
|
help="x likelihood: {fixed/shared/diag}_{gauss/dgauss}.", |
|
type=str, |
|
default="diag_dgauss", |
|
) |
|
parser.add_argument( |
|
"--std_init", |
|
help="Initial std for x scale. 0 is random.", |
|
type=float, |
|
default=0.0, |
|
) |
|
parser.add_argument( |
|
"--parents_x", |
|
help="Parents of x to condition on.", |
|
nargs="+", |
|
default=["mri_seq", "brain_volume", "ventricle_volume", "sex"], |
|
) |
|
parser.add_argument( |
|
"--concat_pa", |
|
help="Whether to concatenate parents_x.", |
|
action="store_true", |
|
default=False, |
|
) |
|
parser.add_argument( |
|
"--context_dim", |
|
help="Num context variables conditioned on.", |
|
type=int, |
|
default=4, |
|
) |
|
parser.add_argument( |
|
"--context_norm", |
|
help='Conditioning normalisation {"[-1,1]"/"[0,1]"/log_standard}.', |
|
type=str, |
|
default="log_standard", |
|
) |
|
parser.add_argument( |
|
"--q_correction", |
|
help="Use posterior correction.", |
|
action="store_true", |
|
default=False, |
|
) |
|
return parser |
|
|