Yeefei committed on
Commit
5022bc3
·
verified ·
1 Parent(s): 5ebd65a

Upload hps.py

Browse files
Files changed (1) hide show
  1. hps.py +300 -0
hps.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
# Registry of named hyperparameter presets: preset name -> preset object.
# Presets are added at import time and looked up by the value of ``--hps``.
HPARAMS_REGISTRY = {}
4
+
5
+
6
class Hparams:
    """Plain attribute container for one set of hyperparameters.

    Values live as ordinary instance attributes, so ``vars(hp)`` (i.e.
    ``hp.__dict__``) is the complete ``name -> value`` mapping of the set.
    """

    def update(self, dict):
        """Set one attribute on this object for every entry of ``dict``.

        Args:
            dict: Mapping of attribute names to values. Attributes that
                already exist under the same name are overwritten. The
                parameter name shadows the builtin; it is kept unchanged so
                existing keyword callers keep working.
        """
        for name, value in dict.items():
            setattr(self, name, value)

    def __repr__(self):
        """Return ``Hparams(name=value, ...)`` listing every stored attribute."""
        fields = ", ".join(f"{k}={v!r}" for k, v in vars(self).items())
        return f"{type(self).__name__}({fields})"
10
+
11
# Hyperparameter preset registered under the name "brset".
brset = Hparams()
brset.lr = 1e-3
brset.bs = 16
brset.wd = 0.01
brset.z_dim = 16
brset.input_res = 384  # alternative: 192
brset.pad = 9
brset.hflip = 0.5

brset.input_channels = 3
# Reading the architecture strings (comma-separated stages such as "32b3d2"):
#   * The leading number is never used; it is only a placeholder documenting
#     the expected spatial dimension of the output.
#   * "b" is the number of convolutional blocks, so "32b3d2" means 3
#     convolutional blocks.
#   * "d" adds a downsampling layer (a projection layer plus a 2D average
#     pooling layer), so "32b3d2" appends a 2D average pooling block with a
#     stride and kernel size of 2 after the 3 convolutional blocks.
# The widths are the number of channels of each convolutional block.
#
# Alternative for 384 (requires more memory):
# brset.enc_arch = "384b1d4,96b3d2,48b7d2,24b11d2,12b7d2,6b3d6,1b2"
# brset.dec_arch = "1b2,6b4,12b8,24b12,48b8,96b4,384b2"
brset.enc_arch = "384b1d4,96b3d4,24b11d2,12b7d2,6b3d6,1b2"  # for 384
brset.dec_arch = "1b2,6b4,12b8,24b12,96b4,384b2"  # for 384
brset.widths = [32, 64, 128, 160, 192, 512]  # for 384
# Alternative for 192:
# brset.enc_arch = "192b1d2,96b3d2,48b7d2,24b11d2,12b7d2,6b3d6,1b2"
# brset.dec_arch = "1b2,6b4,12b8,24b12,48b8,96b4,192b2"
# brset.widths = [32, 64, 96, 128, 160, 192, 512]
brset.bias_max_res = 64  # max resolution of the bias parameter
# Number of channels of the bottleneck layer in a block = width / bottleneck.
brset.bottleneck = 4
brset.parents_x = ['patient_age', 'patient_sex', 'DR_ICDR']
brset.context_norm = "[-1,1]"
# Number of context variables: 7 = age (1, continuous) + sex (1, binary)
# + DR_ICDR (5, one-hot encoded).
brset.context_dim = 7
brset.n_classes = 5
brset.concat_pa = True
HPARAMS_REGISTRY["brset"] = brset
41
+
42
+
43
# Hyperparameter preset registered under the name "morphomnist".
morphomnist = Hparams()
morphomnist.lr = 1e-3
morphomnist.bs = 32
morphomnist.wd = 0.01
morphomnist.z_dim = 16
morphomnist.input_res = 32
morphomnist.pad = 4
morphomnist.enc_arch = "32b3d2,16b3d2,8b3d2,4b3d4,1b4"
morphomnist.dec_arch = "1b4,4b4,8b4,16b4,32b4"
morphomnist.widths = [16, 32, 64, 128, 256]
morphomnist.parents_x = ["thickness", "intensity", "digit"]
morphomnist.concat_pa = True
morphomnist.context_norm = "[-1,1]"
morphomnist.context_dim = 12
HPARAMS_REGISTRY["morphomnist"] = morphomnist
58
+
59
+
60
# Hyperparameter preset registered under the name "cmnist".
cmnist = Hparams()
cmnist.lr = 1e-3
cmnist.bs = 32
cmnist.wd = 0.01
cmnist.z_dim = 16
cmnist.input_res = 32
cmnist.input_channels = 3
cmnist.pad = 4
cmnist.enc_arch = "32b3d2,16b3d2,8b3d2,4b3d4,1b4"
cmnist.dec_arch = "1b4,4b4,8b4,16b4,32b4"
cmnist.widths = [16, 32, 64, 128, 256]
cmnist.parents_x = ["digit", "colour"]
cmnist.context_dim = 20
HPARAMS_REGISTRY["cmnist"] = cmnist
74
+
75
+
76
# Hyperparameter preset registered under the name "ukbb64" (64 px input).
ukbb64 = Hparams()
ukbb64.lr = 1e-3
ukbb64.bs = 32
ukbb64.wd = 0.1
ukbb64.z_dim = 16
ukbb64.input_res = 64
ukbb64.pad = 3
ukbb64.enc_arch = "64b3d2,32b31d2,16b15d2,8b7d2,4b3d4,1b2"
ukbb64.dec_arch = "1b2,4b4,8b8,16b16,32b32,64b4"
ukbb64.widths = [32, 64, 128, 256, 512, 1024]
HPARAMS_REGISTRY["ukbb64"] = ukbb64
87
+
88
+
89
# Hyperparameter preset registered under the name "ukbb192".
# Starts as a copy of every ukbb64 attribute, then overrides the
# resolution-dependent ones below.
# NOTE(review): despite the "192" in the name, this preset sets a 384 px
# input resolution and a 384-based architecture -- confirm whether the name
# or the values are the intended ones.
ukbb192 = Hparams()
ukbb192.update(ukbb64.__dict__)
ukbb192.input_res = 384
ukbb192.pad = 9
ukbb192.enc_arch = "384b2d2,192b2d2,96b3d2,48b7d2,24b11d2,12b7d2,6b3d6,1b2"
ukbb192.dec_arch = "1b2,6b4,12b8,24b12,48b8,96b4,192b2,384b2"
ukbb192.widths = [32, 64, 96, 128, 160, 192, 512, 1024]
HPARAMS_REGISTRY["ukbb192"] = ukbb192
97
+
98
+
99
# Hyperparameter preset registered under the name "mimic192" (192 px input).
mimic192 = Hparams()
mimic192.lr = 1e-3
mimic192.bs = 16
mimic192.wd = 0.1
mimic192.z_dim = 16
mimic192.input_res = 192
mimic192.pad = 9
mimic192.enc_arch = "192b1d2,96b3d2,48b7d2,24b11d2,12b7d2,6b3d6,1b2"
mimic192.dec_arch = "1b2,6b4,12b8,24b12,48b8,96b4,192b2"
mimic192.widths = [32, 64, 96, 128, 160, 192, 512]
HPARAMS_REGISTRY["mimic192"] = mimic192
110
+
111
# Hyperparameter preset registered under the name "mimic384" (384 px input).
mimic384 = Hparams()
mimic384.lr = 1e-3
mimic384.bs = 16
mimic384.wd = 0.1
mimic384.z_dim = 16
mimic384.input_res = 384
mimic384.pad = 9
mimic384.enc_arch = "384b1d2,192b1d2,96b3d2,48b7d2,24b11d2,12b7d2,6b3d6,1b2"
mimic384.dec_arch = "1b2,6b4,12b8,24b12,48b8,96b4,192b2,384b2"
mimic384.widths = [32, 64, 96, 128, 160, 192, 512, 1024]
HPARAMS_REGISTRY["mimic384"] = mimic384
122
+
123
def setup_hparams(parser: argparse.ArgumentParser) -> Hparams:
    """Resolve the final hyperparameters from a preset plus the command line.

    The command line is parsed once to read ``--hps``, the matching preset
    from ``HPARAMS_REGISTRY`` is installed as the parser's defaults, and the
    command line is parsed again so that the result contains the preset
    values together with every other parser argument.

    Args:
        parser: Parser that already declares ``--hps`` and one argument for
            every attribute used by the selected preset.

    Returns:
        A new ``Hparams`` holding every parsed argument as an attribute.

    Raises:
        KeyError: If the value of ``--hps`` is not a registered preset name.
        ValueError: If the preset defines an attribute the parser does not
            declare.
    """
    # Unknown extra CLI arguments are tolerated (parse_known_args).
    args = parser.parse_known_args()[0]
    valid_args = set(vars(args))

    if args.hps not in HPARAMS_REGISTRY:
        raise KeyError(
            f"unknown hyperparameter set {args.hps!r}; "
            f"available: {sorted(HPARAMS_REGISTRY)}"
        )
    preset = vars(HPARAMS_REGISTRY[args.hps])

    # Reject preset attributes the parser does not know about, so a typo in
    # a preset cannot be silently ignored.
    for k in preset:
        if k not in valid_args:
            raise ValueError(f"{k} not in default args")

    # Installing the preset as parser defaults and re-parsing yields the
    # combined result of preset values and parsed arguments.
    parser.set_defaults(**preset)
    hparams = Hparams()
    hparams.update(vars(parser.parse_known_args()[0]))
    return hparams
134
+
135
+
136
def add_arguments(parser: argparse.ArgumentParser):
    """Register every experiment, training and model option on ``parser``.

    Args:
        parser: The argument parser to extend in place.

    Returns:
        The same ``parser`` object, to allow call chaining.
    """
    # experiment
    parser.add_argument("--exp_name", help="Experiment name.", type=str, default="")
    parser.add_argument(
        "--data_dir", help="Data directory to load from.", type=str, default=""
    )
    parser.add_argument("--hps", help="hyperparam set.", type=str, default="ukbb64")
    parser.add_argument(
        "--resume", help="Path to load checkpoint.", type=str, default=""
    )
    parser.add_argument("--seed", help="Set random seed.", type=int, default=7)
    parser.add_argument(
        "--deterministic",
        help="Toggle cuDNN determinism.",
        action="store_true",
        default=False,
    )
    # training
    parser.add_argument("--epochs", help="Training epochs.", type=int, default=5000)
    parser.add_argument("--bs", help="Batch size.", type=int, default=32)
    parser.add_argument("--lr", help="Learning rate.", type=float, default=1e-3)
    parser.add_argument(
        "--lr_warmup_steps", help="lr warmup steps.", type=int, default=100
    )
    parser.add_argument("--wd", help="Weight decay penalty.", type=float, default=0.01)
    parser.add_argument(
        "--betas",
        help="Adam beta parameters.",
        nargs="+",
        type=float,
        default=[0.9, 0.9],
    )
    parser.add_argument(
        "--ema_rate", help="Exp. moving avg. model rate.", type=float, default=0.999
    )
    parser.add_argument(
        "--input_res", help="Input image crop resolution.", type=int, default=64
    )
    parser.add_argument(
        "--input_channels", help="Input image num channels.", type=int, default=1
    )
    parser.add_argument("--pad", help="Input padding.", type=int, default=3)
    parser.add_argument(
        "--hflip", help="Horizontal flip prob.", type=float, default=0.5
    )
    parser.add_argument(
        "--grad_clip", help="Gradient clipping value.", type=float, default=350
    )
    parser.add_argument(
        "--grad_skip", help="Skip update grad norm threshold.", type=float, default=500
    )
    parser.add_argument(
        "--accu_steps", help="Gradient accumulation steps.", type=int, default=1
    )
    parser.add_argument(
        "--beta", help="Max KL beta penalty weight.", type=float, default=1.0
    )
    parser.add_argument(
        "--beta_warmup_steps", help="KL beta penalty warmup steps.", type=int, default=0
    )
    parser.add_argument(
        "--kl_free_bits", help="KL min free bits constraint.", type=float, default=0.0
    )
    parser.add_argument(
        "--viz_freq", help="Steps per visualisation.", type=int, default=10000
    )
    parser.add_argument(
        "--eval_freq", help="Train epochs per validation.", type=int, default=5
    )
    parser.add_argument(
        "--n_classes", help="Number of classes for DR ICDR.", type=int, default=10
    )

    # model
    parser.add_argument(
        "--vae",
        help="VAE model: simple/hierarchical.",
        type=str,
        default="hierarchical",
    )
    parser.add_argument(
        "--enc_arch",
        help="Encoder architecture config.",
        type=str,
        default="64b1d2,32b1d2,16b1d2,8b1d8,1b2",
    )
    parser.add_argument(
        "--dec_arch",
        help="Decoder architecture config.",
        type=str,
        default="1b2,8b2,16b2,32b2,64b2",
    )
    parser.add_argument(
        "--cond_prior",
        help="Use a conditional prior.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--widths",
        help="Number of channels.",
        nargs="+",
        type=int,
        default=[16, 32, 48, 64, 128],
    )
    parser.add_argument(
        "--bottleneck", help="Bottleneck width factor.", type=int, default=4
    )
    parser.add_argument(
        "--z_dim", help="Number of latent channel dims.", type=int, default=16
    )
    parser.add_argument(
        "--z_max_res",
        help="Max resolution of stochastic z layers.",
        type=int,
        default=192,
    )
    parser.add_argument(
        "--bias_max_res",
        help="Learned bias param max resolution.",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--x_like",
        help="x likelihood: {fixed/shared/diag}_{gauss/dgauss}.",
        type=str,
        default="diag_dgauss",
    )
    parser.add_argument(
        "--std_init",
        help="Initial std for x scale. 0 is random.",
        type=float,
        default=0.0,
    )
    parser.add_argument(
        "--parents_x",
        help="Parents of x to condition on.",
        nargs="+",
        default=["mri_seq", "brain_volume", "ventricle_volume", "sex"],
    )
    parser.add_argument(
        "--concat_pa",
        help="Whether to concatenate parents_x.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--context_dim",
        help="Num context variables conditioned on.",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--context_norm",
        help='Conditioning normalisation {"[-1,1]"/"[0,1]"/log_standard}.',
        type=str,
        default="log_standard",
    )
    parser.add_argument(
        "--q_correction",
        help="Use posterior correction.",
        action="store_true",
        default=False,
    )
    return parser