Tags: English · medical · brain-data · mri
jesseab committed
Commit 20a693a · 1 Parent(s): 8e78bf8

Added autoencoder model main files

__pycache__/brlp_lite.cpython-310.pyc ADDED
Binary file (18.2 kB).
 
autoencoder-ep-4.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d83a2a0ed04a16f4908e91c2c8aab3b20b4f9a763dd838600baba07e694c6b94
+ size 55126081
brlp_lite.py ADDED
@@ -0,0 +1,570 @@
+ import os
+ from typing import Optional, Union
+ import pandas as pd
+ import argparse
+ import numpy as np
+ import warnings
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from torch.optim.optimizer import Optimizer
+ from torch.nn import L1Loss
+ from torch.utils.data import DataLoader
+ from torch.cuda.amp import autocast, GradScaler
+
+ from generative.networks.nets import (
+     AutoencoderKL,
+     PatchDiscriminator,
+ )
+ from generative.losses import PerceptualLoss, PatchAdversarialLoss
+ from monai.data import Dataset, PersistentDataset
+ from monai.transforms.transform import Transform
+ from monai import transforms
+ from monai.utils import set_determinism
+ from monai.data.meta_tensor import MetaTensor
+
+ from tqdm import tqdm
+ import matplotlib.pyplot as plt
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ # chosen resolution
+ RESOLUTION = 1.5
+
+ # shape of the MNI152 (1mm^3) template
+ INPUT_SHAPE_1mm = (182, 218, 182)
+
+ # resampling the MNI152 to (1.5mm^3)
+ INPUT_SHAPE_1p5mm = (122, 146, 122)
+
+ # Adjusting the dimensions to be divisible by 8 (2^3, where 3 is the number of downsampling layers of the AE)
+ INPUT_SHAPE_AE = (120, 144, 120)
+
+ # Latent shape of the autoencoder
+ LATENT_SHAPE_AE = (3, 15, 18, 15)
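+ # The latent spatial dims follow from INPUT_SHAPE_AE: 120/8 = 15, 144/8 = 18,
+ # 120/8 = 15; the leading 3 is the `latent_channels` used in `init_autoencoder`.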
+
+
+ def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
+     """
+     Load pretrained weights if available.
+
+     Args:
+         checkpoints_path (Optional[str]): path of the checkpoints
+         network (nn.Module): the neural network to initialize
+
+     Returns:
+         nn.Module: the initialized neural network
+     """
+     if checkpoints_path is not None:
+         assert os.path.exists(checkpoints_path), 'Invalid path'
+         # Use a context manager to allowlist MetaTensor while loading the checkpoint.
+         with torch.serialization.safe_globals([MetaTensor]):
+             network.load_state_dict(torch.load(checkpoints_path, map_location='cpu'))
+     return network
+
+
+ def init_autoencoder(checkpoints_path: Optional[str] = None) -> nn.Module:
+     """
+     Load the KL autoencoder (pretrained if `checkpoints_path` points to previous params).
+
+     Args:
+         checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.
+
+     Returns:
+         nn.Module: the KL autoencoder
+     """
+     autoencoder = AutoencoderKL(spatial_dims=3,
+                                 in_channels=1,
+                                 out_channels=1,
+                                 latent_channels=3,
+                                 num_channels=(64, 128, 128, 128),
+                                 num_res_blocks=2,
+                                 norm_num_groups=32,
+                                 norm_eps=1e-06,
+                                 attention_levels=(False, False, False, False),
+                                 with_decoder_nonlocal_attn=False,
+                                 with_encoder_nonlocal_attn=False)
+     return load_if(checkpoints_path, autoencoder)
+
+
+ def init_patch_discriminator(checkpoints_path: Optional[str] = None) -> nn.Module:
+     """
+     Load the patch discriminator (pretrained if `checkpoints_path` points to previous params).
+
+     Args:
+         checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.
+
+     Returns:
+         nn.Module: the patch discriminator
+     """
+     patch_discriminator = PatchDiscriminator(spatial_dims=3,
+                                              num_layers_d=3,
+                                              num_channels=32,
+                                              in_channels=1,
+                                              out_channels=1)
+     return load_if(checkpoints_path, patch_discriminator)
+
+
+ class KLDivergenceLoss:
+     """
+     A class for computing the Kullback-Leibler divergence loss.
+     """
+
+     def __call__(self, z_mu: Tensor, z_sigma: Tensor) -> Tensor:
+         """
+         Computes the KL divergence loss for the given parameters.
+
+         Args:
+             z_mu (Tensor): The mean of the distribution.
+             z_sigma (Tensor): The standard deviation of the distribution.
+
+         Returns:
+             Tensor: The computed KL divergence loss, averaged over the batch size.
+         """
+         kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])
+         return torch.sum(kl_loss) / kl_loss.shape[0]
+
+
+ class GradientAccumulation:
+     """
+     Implements gradient accumulation to facilitate training with larger
+     effective batch sizes than what can be physically accommodated in memory.
+     """
+
+     def __init__(self,
+                  actual_batch_size: int,
+                  expect_batch_size: int,
+                  loader_len: int,
+                  optimizer: Optimizer,
+                  grad_scaler: Optional[GradScaler] = None) -> None:
+         """
+         Initializes the GradientAccumulation instance with the necessary parameters for
+         managing gradient accumulation.
+
+         Args:
+             actual_batch_size (int): The size of the mini-batches actually used in training.
+             expect_batch_size (int): The desired (effective) batch size to simulate through gradient accumulation.
+             loader_len (int): The length of the data loader, representing the total number of mini-batches.
+             optimizer (Optimizer): The optimizer used for performing optimization steps.
+             grad_scaler (Optional[GradScaler], optional): A GradScaler for mixed precision training. Defaults to None.
+
+         Raises:
+             AssertionError: If `expect_batch_size` is not divisible by `actual_batch_size`.
+         """
+         assert expect_batch_size % actual_batch_size == 0, \
+             'expect_batch_size must be divisible by actual_batch_size'
+         self.actual_batch_size = actual_batch_size
+         self.expect_batch_size = expect_batch_size
+         self.loader_len = loader_len
+         self.optimizer = optimizer
+         self.grad_scaler = grad_scaler
+
+         # If the expected batch size is N = KM and the actual batch size is M,
+         # then we need to accumulate gradients over N / M = K optimization steps.
+         self.steps_until_update = expect_batch_size // actual_batch_size
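+         # With the CLI defaults below (batch_size=16, max_batch_size=2), this
+         # means one optimizer update every 16 // 2 = 8 mini-batches.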
+
+     def step(self, loss: Tensor, step: int) -> None:
+         """
+         Performs a backward pass for the given loss and potentially executes an optimization
+         step if the conditions for gradient accumulation are met. The optimization step is taken
+         only after a specified number of steps (defined by the expected batch size) or at the end
+         of the dataset.
+
+         Args:
+             loss (Tensor): The loss value for the current forward pass.
+             step (int): The current step (mini-batch index) within the epoch.
+         """
+         loss = loss / self.expect_batch_size
+
+         if self.grad_scaler is not None:
+             self.grad_scaler.scale(loss).backward()
+         else:
+             loss.backward()
+         if (step + 1) % self.steps_until_update == 0 or (step + 1) == self.loader_len:
+             if self.grad_scaler is not None:
+                 self.grad_scaler.step(self.optimizer)
+                 self.grad_scaler.update()
+             else:
+                 self.optimizer.step()
+             self.optimizer.zero_grad(set_to_none=True)
+
+
+ class AverageLoss:
+     """
+     Utility class to track losses
+     and metrics during training.
+     """
+
+     def __init__(self):
+         self.losses_accumulator = {}
+
+     def put(self, loss_key: str, loss_value: Union[int, float]) -> None:
+         """
+         Store a value.
+
+         Args:
+             loss_key (str): Metric name
+             loss_value (int | float): Metric value to store
+         """
+         if loss_key not in self.losses_accumulator:
+             self.losses_accumulator[loss_key] = []
+         self.losses_accumulator[loss_key].append(loss_value)
+
+     def pop_avg(self, loss_key: str) -> Optional[float]:
+         """
+         Average the stored values of a given metric.
+
+         Args:
+             loss_key (str): Metric name
+
+         Returns:
+             Optional[float]: average of the stored values (None for an unknown key)
+         """
+         if loss_key not in self.losses_accumulator:
+             return None
+         losses = self.losses_accumulator[loss_key]
+         self.losses_accumulator[loss_key] = []
+         return sum(losses) / len(losses)
+
+     def to_tensorboard(self, writer: SummaryWriter, step: int):
+         """
+         Logs the average value of all the stored metrics
+         to TensorBoard.
+
+         Args:
+             writer (SummaryWriter): Tensorboard writer
+             step (int): Tensorboard logging global step
+         """
+         for metric_key in self.losses_accumulator.keys():
+             writer.add_scalar(metric_key, self.pop_avg(metric_key), step)
+
+
+ def get_dataset_from_pd(df: pd.DataFrame, transforms_fn: Transform, cache_dir: Optional[str]) -> Union[Dataset, PersistentDataset]:
+     """
+     If `cache_dir` is defined, returns a `monai.data.PersistentDataset`.
+     Otherwise, returns a simple `monai.data.Dataset`.
+
+     Args:
+         df (pd.DataFrame): Dataframe describing each image in the longitudinal dataset.
+         transforms_fn (Transform): Set of transformations
+         cache_dir (Optional[str]): Cache directory (ensure enough storage is available)
+
+     Returns:
+         Dataset|PersistentDataset: The dataset
+     """
+     assert cache_dir is None or os.path.exists(cache_dir), 'Invalid cache directory path'
+     data = df.to_dict(orient='records')
+     return Dataset(data=data, transform=transforms_fn) if cache_dir is None \
+         else PersistentDataset(data=data, transform=transforms_fn, cache_dir=cache_dir)
+
+
+ def tb_display_reconstruction(writer, step, image, recon):
+     """
+     Display a reconstruction in TensorBoard during AE training.
+     """
+     plt.style.use('dark_background')
+     _, ax = plt.subplots(ncols=3, nrows=2, figsize=(7, 5))
+     for _ax in ax.flatten(): _ax.set_axis_off()
+
+     if len(image.shape) == 4: image = image.squeeze(0)
+     if len(recon.shape) == 4: recon = recon.squeeze(0)
+
+     ax[0, 0].set_title('original image', color='cyan')
+     ax[0, 0].imshow(image[image.shape[0] // 2, :, :], cmap='gray')
+     ax[0, 1].imshow(image[:, image.shape[1] // 2, :], cmap='gray')
+     ax[0, 2].imshow(image[:, :, image.shape[2] // 2], cmap='gray')
+
+     ax[1, 0].set_title('reconstructed image', color='magenta')
+     ax[1, 0].imshow(recon[recon.shape[0] // 2, :, :], cmap='gray')
+     ax[1, 1].imshow(recon[:, recon.shape[1] // 2, :], cmap='gray')
+     ax[1, 2].imshow(recon[:, :, recon.shape[2] // 2], cmap='gray')
+
+     plt.tight_layout()
+     writer.add_figure('Reconstruction', plt.gcf(), global_step=step)
+
+
+ def set_environment(seed: int = 0) -> None:
+     """
+     Set deterministic behavior for reproducibility.
+
+     Args:
+         seed (int, optional): Seed value. Defaults to 0.
+     """
+     set_determinism(seed)
+
+
+ def train(
+     dataset_csv: str,
+     cache_dir: str,
+     output_dir: str,
+     aekl_ckpt: Optional[str] = None,
+     disc_ckpt: Optional[str] = None,
+     num_workers: int = 8,
+     n_epochs: int = 5,
+     max_batch_size: int = 2,
+     batch_size: int = 16,
+     lr: float = 1e-4,
+     aug_p: float = 0.8,
+     device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
+ ) -> None:
+     """
+     Train the autoencoder and discriminator models.
+
+     Args:
+         dataset_csv (str): Path to the dataset CSV file.
+         cache_dir (str): Directory for caching data.
+         output_dir (str): Directory to save model checkpoints.
+         aekl_ckpt (Optional[str], optional): Path to the autoencoder checkpoint. Defaults to None.
+         disc_ckpt (Optional[str], optional): Path to the discriminator checkpoint. Defaults to None.
+         num_workers (int, optional): Number of data loader workers. Defaults to 8.
+         n_epochs (int, optional): Number of training epochs. Defaults to 5.
+         max_batch_size (int, optional): Actual batch size per iteration. Defaults to 2.
+         batch_size (int, optional): Expected (effective) batch size. Defaults to 16.
+         lr (float, optional): Learning rate. Defaults to 1e-4.
+         aug_p (float, optional): Augmentation probability (currently unused). Defaults to 0.8.
+         device (str, optional): Device to run the training on. Defaults to 'cuda' if available.
+     """
+     set_environment(0)
+
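+     # Note: training resamples to a 2 mm grid cropped/padded to (80, 96, 80),
+     # whereas inference() below encodes at RESOLUTION (1.5 mm) with INPUT_SHAPE_AE.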
+     transforms_fn = transforms.Compose([
+         transforms.CopyItemsD(keys={'image_path'}, names=['image']),
+         transforms.LoadImageD(image_only=True, keys=['image']),
+         transforms.EnsureChannelFirstD(keys=['image']),
+         transforms.SpacingD(pixdim=2, keys=['image']),
+         transforms.ResizeWithPadOrCropD(spatial_size=(80, 96, 80), mode='minimum', keys=['image']),
+         transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image'])
+     ])
+
+     dataset_df = pd.read_csv(dataset_csv)
+     train_df = dataset_df[dataset_df.split == 'train']
+     trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
+
+     train_loader = DataLoader(
+         dataset=trainset,
+         num_workers=num_workers,
+         batch_size=max_batch_size,
+         shuffle=True,
+         persistent_workers=True,
+         pin_memory=True
+     )
+
+     print(f'Device is {device}')
+     autoencoder = init_autoencoder(aekl_ckpt).to(device)
+     discriminator = init_patch_discriminator(disc_ckpt).to(device)
+
+     # Loss Weights
+     adv_weight = 0.025
+     perceptual_weight = 0.001
+     kl_weight = 1e-7
+
+     # Loss Functions
+     l1_loss_fn = L1Loss()
+     kl_loss_fn = KLDivergenceLoss()
+     adv_loss_fn = PatchAdversarialLoss(criterion="least_squares")
+
+     with warnings.catch_warnings():
+         warnings.simplefilter("ignore")
+         perc_loss_fn = PerceptualLoss(
+             spatial_dims=3,
+             network_type="squeeze",
+             is_fake_3d=True,
+             fake_3d_ratio=0.2
+         ).to(device)
+
+     # Optimizers
+     optimizer_g = torch.optim.Adam(autoencoder.parameters(), lr=lr)
+     optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr)
+
+     # Gradient Accumulation
+     gradacc_g = GradientAccumulation(
+         actual_batch_size=max_batch_size,
+         expect_batch_size=batch_size,
+         loader_len=len(train_loader),
+         optimizer=optimizer_g,
+         grad_scaler=GradScaler()
+     )
+
+     gradacc_d = GradientAccumulation(
+         actual_batch_size=max_batch_size,
+         expect_batch_size=batch_size,
+         loader_len=len(train_loader),
+         optimizer=optimizer_d,
+         grad_scaler=GradScaler()
+     )
+
+     # Logging
+     avgloss = AverageLoss()
+     writer = SummaryWriter()
+     total_counter = 0
+
+     for epoch in range(n_epochs):
+         autoencoder.train()
+         progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
+         progress_bar.set_description(f'Epoch {epoch + 1}/{n_epochs}')
+
+         for step, batch in progress_bar:
+             # Generator Training
+             with autocast(enabled=True):
+                 images = batch["image"].to(device)
+                 reconstruction, z_mu, z_sigma = autoencoder(images)
+
+                 logits_fake = discriminator(reconstruction.contiguous().float())[-1]
+
+                 rec_loss = l1_loss_fn(reconstruction.float(), images.float())
+                 kl_loss = kl_weight * kl_loss_fn(z_mu, z_sigma)
+                 per_loss = perceptual_weight * perc_loss_fn(reconstruction.float(), images.float())
+                 gen_loss = adv_weight * adv_loss_fn(logits_fake, target_is_real=True, for_discriminator=False)
+
+                 loss_g = rec_loss + kl_loss + per_loss + gen_loss
+
+             gradacc_g.step(loss_g, step)
+
+             # Discriminator Training
+             with autocast(enabled=True):
+                 logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
+                 d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
+                 logits_real = discriminator(images.contiguous().detach())[-1]
+                 d_loss_real = adv_loss_fn(logits_real, target_is_real=True, for_discriminator=True)
+                 discriminator_loss = (d_loss_fake + d_loss_real) * 0.5
+                 loss_d = adv_weight * discriminator_loss
+
+             gradacc_d.step(loss_d, step)
+
+             # Logging
+             avgloss.put('Generator/reconstruction_loss', rec_loss.item())
+             avgloss.put('Generator/perceptual_loss', per_loss.item())
+             avgloss.put('Generator/adversarial_loss', gen_loss.item())
+             avgloss.put('Generator/kl_regularization', kl_loss.item())
+             avgloss.put('Discriminator/adversarial_loss', loss_d.item())
+
+             if total_counter % 10 == 0:
+                 step_log = total_counter // 10
+                 avgloss.to_tensorboard(writer, step_log)
+                 tb_display_reconstruction(
+                     writer,
+                     step_log,
+                     images[0].detach().cpu(),
+                     reconstruction[0].detach().cpu()
+                 )
+
+             total_counter += 1
+
+         # Save the models after each epoch.
+         os.makedirs(output_dir, exist_ok=True)
+         torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch + 1}.pth'))
+         torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch + 1}.pth'))
+
+     writer.close()
+     print("Training completed and models saved.")
+
+
+ def inference(
+     dataset_csv: str,
+     aekl_ckpt: str,
+     output_dir: str,
+     device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
+ ) -> None:
+     """
+     Perform inference to encode images into latent space.
+
+     Args:
+         dataset_csv (str): Path to the dataset CSV file.
+         aekl_ckpt (str): Path to the autoencoder checkpoint.
+         output_dir (str): Directory to save latent representations.
+         device (str, optional): Device to run the inference on. Defaults to 'cuda' if available.
+     """
+     DEVICE = device
+
+     autoencoder = init_autoencoder(aekl_ckpt).to(DEVICE).eval()
+
+     transforms_fn = transforms.Compose([
+         transforms.CopyItemsD(keys={'image_path'}, names=['image']),
+         transforms.LoadImageD(image_only=True, keys=['image']),
+         transforms.EnsureChannelFirstD(keys=['image']),
+         transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
+         transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
+         transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image'])
+     ])
+
+     df = pd.read_csv(dataset_csv)
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     with torch.no_grad():
+         for image_path in tqdm(df.image_path, total=len(df)):
+             destpath = os.path.join(
+                 output_dir,
+                 os.path.basename(image_path).replace('.nii.gz', '_latent.npz').replace('.nii', '_latent.npz')
+             )
+             if os.path.exists(destpath):
+                 continue
+             mri_tensor = transforms_fn({'image_path': image_path})['image'].to(DEVICE)
+             mri_latent, _ = autoencoder.encode(mri_tensor.unsqueeze(0))
+             mri_latent = mri_latent.cpu().squeeze(0).numpy()
+             np.savez_compressed(destpath, data=mri_latent)
+
+     print("Inference completed and latent representations saved.")
+
+
+ def main():
+     """
+     Main function to parse command-line arguments and execute training or inference.
+     """
+     parser = argparse.ArgumentParser(description="BRLP Lite Training and Inference Script")
+
+     subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-commands: train or infer')
+
+     # Training Subparser
+     train_parser = subparsers.add_parser('train', help='Train the models.')
+     train_parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
+     train_parser.add_argument('--cache_dir', type=str, required=True, help='Directory for caching data.')
+     train_parser.add_argument('--output_dir', type=str, required=True, help='Directory to save model checkpoints.')
+     train_parser.add_argument('--aekl_ckpt', type=str, default=None, help='Path to the autoencoder checkpoint.')
+     train_parser.add_argument('--disc_ckpt', type=str, default=None, help='Path to the discriminator checkpoint.')
+     train_parser.add_argument('--num_workers', type=int, default=8, help='Number of data loader workers.')
+     train_parser.add_argument('--n_epochs', type=int, default=5, help='Number of training epochs.')
+     train_parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
+     train_parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
+     train_parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
+     train_parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
+
+     # Inference Subparser
+     infer_parser = subparsers.add_parser('infer', help='Run inference to encode images.')
+     infer_parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
+     infer_parser.add_argument('--aekl_ckpt', type=str, required=True, help='Path to the autoencoder checkpoint.')
+     infer_parser.add_argument('--output_dir', type=str, required=True, help='Directory to save latent representations.')
+
+     args = parser.parse_args()
+
+     if args.command == 'train':
+         train(
+             dataset_csv=args.dataset_csv,
+             cache_dir=args.cache_dir,
+             output_dir=args.output_dir,
+             aekl_ckpt=args.aekl_ckpt,
+             disc_ckpt=args.disc_ckpt,
+             num_workers=args.num_workers,
+             n_epochs=args.n_epochs,
+             max_batch_size=args.max_batch_size,
+             batch_size=args.batch_size,
+             lr=args.lr,
+             aug_p=args.aug_p,
+         )
+     elif args.command == 'infer':
+         inference(
+             dataset_csv=args.dataset_csv,
+             aekl_ckpt=args.aekl_ckpt,
+             output_dir=args.output_dir,
+         )
+     else:
+         parser.print_help()
+
+
+ if __name__ == '__main__':
+     main()
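
For context, here is a minimal, hypothetical sketch of how this module and the `autoencoder-ep-4.pth` checkpoint committed above could be used to encode a single scan; the scan path `sub-01_T1w.nii.gz` is a placeholder, and the preprocessing simply mirrors the `inference()` pipeline:

import torch
from monai import transforms
from brlp_lite import init_autoencoder, RESOLUTION, INPUT_SHAPE_AE, LATENT_SHAPE_AE

# Load the epoch-4 KL autoencoder checkpoint from this commit (stays on CPU here).
autoencoder = init_autoencoder('autoencoder-ep-4.pth').eval()

# Same preprocessing as inference(): resample to 1.5 mm, pad/crop to the AE input
# shape, and rescale intensities to [0, 1].
transforms_fn = transforms.Compose([
    transforms.CopyItemsD(keys={'image_path'}, names=['image']),
    transforms.LoadImageD(image_only=True, keys=['image']),
    transforms.EnsureChannelFirstD(keys=['image']),
    transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
    transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
    transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image']),
])

with torch.no_grad():
    # Placeholder path for an MNI-space T1w MRI volume.
    mri = transforms_fn({'image_path': 'sub-01_T1w.nii.gz'})['image']
    z_mu, z_sigma = autoencoder.encode(mri.unsqueeze(0))

# The latent matches LATENT_SHAPE_AE = (3, 15, 18, 15) after dropping the batch dim.
assert tuple(z_mu.squeeze(0).shape) == LATENT_SHAPE_AE
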
discriminator-ep-4.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:83d59c14472ce1cc582798762f7980361e0239e9f524b6b8b6861dab43fd664e
+ size 11098603
inputs_local.csv ADDED
The diff for this file is too large to render.
 
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ # requirements.txt
+
+ # PyTorch (CUDA or CPU build). For a GPU install, see the PyTorch docs for the correct wheel.
+ # Note: brlp_lite.py uses torch.serialization.safe_globals, which is only available in
+ # recent PyTorch releases, so prefer a current version over the 1.12 floor below.
+ torch>=1.12
+
+ # MONAI v1.2+ plus the separate MONAI Generative Models package (imported as
+ # 'generative') provide AutoencoderKL, PatchDiscriminator, the losses, etc.
+ monai>=1.2.0
+ monai-generative
+
+ # For perceptual losses in MONAI's generative module.
+ lpips
+
+ # Common Python libraries
+ pandas
+ numpy
+ tqdm
+ tensorboard
+ matplotlib