Added autoencoder model main files
- __pycache__/brlp_lite.cpython-310.pyc +0 -0
- autoencoder-ep-4.pth +3 -0
- brlp_lite.py +570 -0
- discriminator-ep-4.pth +3 -0
- inputs_local.csv +0 -0
- requirements.txt +18 -0
__pycache__/brlp_lite.cpython-310.pyc
ADDED
Binary file (18.2 kB).
autoencoder-ep-4.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d83a2a0ed04a16f4908e91c2c8aab3b20b4f9a763dd838600baba07e694c6b94
size 55126081
brlp_lite.py
ADDED
@@ -0,0 +1,570 @@
import os
from typing import Optional, Union
import pandas as pd
import argparse
import numpy as np
import warnings
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch.nn import L1Loss
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast
from torch.amp import GradScaler

from generative.networks.nets import (
    AutoencoderKL,
    PatchDiscriminator,
)
from generative.losses import PerceptualLoss, PatchAdversarialLoss
from monai.data import Dataset, PersistentDataset
from monai.transforms.transform import Transform
from monai import transforms
from monai.utils import set_determinism
from monai.data.meta_tensor import MetaTensor

from tqdm import tqdm
import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter

# chosen resolution
RESOLUTION = 1.5

# shape of the MNI152 (1mm^3) template
INPUT_SHAPE_1mm = (182, 218, 182)

# resampling the MNI152 to (1.5mm^3)
INPUT_SHAPE_1p5mm = (122, 146, 122)

# Adjusting the dimensions to be divisible by 8 (2^3, where 3 is the number of downsampling layers of the AE)
INPUT_SHAPE_AE = (120, 144, 120)

# Latent shape of the autoencoder
LATENT_SHAPE_AE = (3, 15, 18, 15)

def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
    """
    Load pretrained weights if available.

    Args:
        checkpoints_path (Optional[str]): path of the checkpoints
        network (nn.Module): the neural network to initialize

    Returns:
        nn.Module: the initialized neural network
    """
    if checkpoints_path is not None:
        assert os.path.exists(checkpoints_path), 'Invalid path'
        # Using context manager to allow MetaTensor
        with torch.serialization.safe_globals([MetaTensor]):
            # network.load_state_dict(torch.load(checkpoints_path))
            network.load_state_dict(torch.load(checkpoints_path, map_location='cpu'))
    return network


def init_autoencoder(checkpoints_path: Optional[str] = None) -> nn.Module:
    """
    Load the KL autoencoder (pretrained if `checkpoints_path` points to previous params).

    Args:
        checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.

    Returns:
        nn.Module: the KL autoencoder
    """
    autoencoder = AutoencoderKL(spatial_dims=3,
                                in_channels=1,
                                out_channels=1,
                                latent_channels=3,
                                num_channels=(64, 128, 128, 128),
                                num_res_blocks=2,
                                norm_num_groups=32,
                                norm_eps=1e-06,
                                attention_levels=(False, False, False, False),
                                with_decoder_nonlocal_attn=False,
                                with_encoder_nonlocal_attn=False)
    return load_if(checkpoints_path, autoencoder)


def init_patch_discriminator(checkpoints_path: Optional[str] = None) -> nn.Module:
    """
    Load the patch discriminator (pretrained if `checkpoints_path` points to previous params).

    Args:
        checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.

    Returns:
        nn.Module: the patch discriminator
    """
    patch_discriminator = PatchDiscriminator(spatial_dims=3,
                                             num_layers_d=3,
                                             num_channels=32,
                                             in_channels=1,
                                             out_channels=1)
    return load_if(checkpoints_path, patch_discriminator)

class KLDivergenceLoss:
    """
    A class for computing the Kullback-Leibler divergence loss.
    """

    def __call__(self, z_mu: Tensor, z_sigma: Tensor) -> Tensor:
        """
        Computes the KL divergence loss for the given parameters.

        Args:
            z_mu (Tensor): The mean of the distribution.
            z_sigma (Tensor): The standard deviation of the distribution.

        Returns:
            Tensor: The computed KL divergence loss, averaged over the batch size.
        """

        kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])
        return torch.sum(kl_loss) / kl_loss.shape[0]

class GradientAccumulation:
    """
    Implements gradient accumulation to facilitate training with larger
    effective batch sizes than what can be physically accommodated in memory.
    """

    def __init__(self,
                 actual_batch_size: int,
                 expect_batch_size: int,
                 loader_len: int,
                 optimizer: Optimizer,
                 grad_scaler: Optional[GradScaler] = None) -> None:
        """
        Initializes the GradientAccumulation instance with the necessary parameters for
        managing gradient accumulation.

        Args:
            actual_batch_size (int): The size of the mini-batches actually used in training.
            expect_batch_size (int): The desired (effective) batch size to simulate through gradient accumulation.
            loader_len (int): The length of the data loader, representing the total number of mini-batches.
            optimizer (Optimizer): The optimizer used for performing optimization steps.
            grad_scaler (Optional[GradScaler], optional): A GradScaler for mixed precision training. Defaults to None.

        Raises:
            AssertionError: If `expect_batch_size` is not divisible by `actual_batch_size`.
        """

        assert expect_batch_size % actual_batch_size == 0, \
            'expect_batch_size must be divisible by actual_batch_size'
        self.actual_batch_size = actual_batch_size
        self.expect_batch_size = expect_batch_size
        self.loader_len = loader_len
        self.optimizer = optimizer
        self.grad_scaler = grad_scaler

        # if the expected batch size is N=KM, and the actual batch size
        # is M, then we need to accumulate gradient from N / M = K optimization steps.
        self.steps_until_update = expect_batch_size / actual_batch_size

    def step(self, loss: Tensor, step: int) -> None:
        """
        Performs a backward pass for the given loss and potentially executes an optimization
        step if the conditions for gradient accumulation are met. The optimization step is taken
        only after a specified number of steps (defined by the expected batch size) or at the end
        of the dataset.

        Args:
            loss (Tensor): The loss value for the current forward pass.
            step (int): The current step (mini-batch index) within the epoch.
        """
        loss = loss / self.expect_batch_size

        if self.grad_scaler is not None:
            self.grad_scaler.scale(loss).backward()
        else:
            loss.backward()
        if (step + 1) % self.steps_until_update == 0 or (step + 1) == self.loader_len:
            if self.grad_scaler is not None:
                self.grad_scaler.step(self.optimizer)
                self.grad_scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)

class AverageLoss:
    """
    Utility class to track losses
    and metrics during training.
    """

    def __init__(self):
        self.losses_accumulator = {}

    def put(self, loss_key: str, loss_value: Union[int, float]) -> None:
        """
        Store value

        Args:
            loss_key (str): Metric name
            loss_value (int | float): Metric value to store
        """
        if loss_key not in self.losses_accumulator:
            self.losses_accumulator[loss_key] = []
        self.losses_accumulator[loss_key].append(loss_value)

    def pop_avg(self, loss_key: str) -> float:
        """
        Average the stored values of a given metric

        Args:
            loss_key (str): Metric name

        Returns:
            float: average of the stored values
        """
        if loss_key not in self.losses_accumulator:
            return None
        losses = self.losses_accumulator[loss_key]
        self.losses_accumulator[loss_key] = []
        return sum(losses) / len(losses)

    def to_tensorboard(self, writer: SummaryWriter, step: int):
        """
        Logs the average value of all the metrics stored
        into Tensorboard.

        Args:
            writer (SummaryWriter): Tensorboard writer
            step (int): Tensorboard logging global step
        """
        for metric_key in self.losses_accumulator.keys():
            writer.add_scalar(metric_key, self.pop_avg(metric_key), step)

def get_dataset_from_pd(df: pd.DataFrame, transforms_fn: Transform, cache_dir: Optional[str]) -> Union[Dataset, PersistentDataset]:
    """
    If `cache_dir` is defined, returns a `monai.data.PersistentDataset`.
    Otherwise, returns a simple `monai.data.Dataset`.

    Args:
        df (pd.DataFrame): Dataframe describing each image in the longitudinal dataset.
        transforms_fn (Transform): Set of transformations
        cache_dir (Optional[str]): Cache directory (ensure enough storage is available)

    Returns:
        Dataset|PersistentDataset: The dataset
    """
    assert cache_dir is None or os.path.exists(cache_dir), 'Invalid cache directory path'
    data = df.to_dict(orient='records')
    return Dataset(data=data, transform=transforms_fn) if cache_dir is None \
        else PersistentDataset(data=data, transform=transforms_fn, cache_dir=cache_dir)

def tb_display_reconstruction(writer, step, image, recon):
    """
    Display reconstruction in TensorBoard during AE training.
    """
    plt.style.use('dark_background')
    _, ax = plt.subplots(ncols=3, nrows=2, figsize=(7, 5))
    for _ax in ax.flatten(): _ax.set_axis_off()

    if len(image.shape) == 4: image = image.squeeze(0)
    if len(recon.shape) == 4: recon = recon.squeeze(0)

    ax[0, 0].set_title('original image', color='cyan')
    ax[0, 0].imshow(image[image.shape[0] // 2, :, :], cmap='gray')
    ax[0, 1].imshow(image[:, image.shape[1] // 2, :], cmap='gray')
    ax[0, 2].imshow(image[:, :, image.shape[2] // 2], cmap='gray')

    ax[1, 0].set_title('reconstructed image', color='magenta')
    ax[1, 0].imshow(recon[recon.shape[0] // 2, :, :], cmap='gray')
    ax[1, 1].imshow(recon[:, recon.shape[1] // 2, :], cmap='gray')
    ax[1, 2].imshow(recon[:, :, recon.shape[2] // 2], cmap='gray')

    plt.tight_layout()
    writer.add_figure('Reconstruction', plt.gcf(), global_step=step)

def set_environment(seed: int = 0) -> None:
    """
    Set deterministic behavior for reproducibility.

    Args:
        seed (int, optional): Seed value. Defaults to 0.
    """
    set_determinism(seed)

def train(
    dataset_csv: str,
    cache_dir: str,
    output_dir: str,
    aekl_ckpt: Optional[str] = None,
    disc_ckpt: Optional[str] = None,
    num_workers: int = 8,
    n_epochs: int = 5,
    max_batch_size: int = 2,
    batch_size: int = 16,
    lr: float = 1e-4,
    aug_p: float = 0.8,
    device: str = ('cuda' if torch.cuda.is_available() else
                   'cpu'),
) -> None:
    """
    Train the autoencoder and discriminator models.

    Args:
        dataset_csv (str): Path to the dataset CSV file.
        cache_dir (str): Directory for caching data.
        output_dir (str): Directory to save model checkpoints.
        aekl_ckpt (Optional[str], optional): Path to the autoencoder checkpoint. Defaults to None.
        disc_ckpt (Optional[str], optional): Path to the discriminator checkpoint. Defaults to None.
        num_workers (int, optional): Number of data loader workers. Defaults to 8.
        n_epochs (int, optional): Number of training epochs. Defaults to 5.
        max_batch_size (int, optional): Actual batch size per iteration. Defaults to 2.
        batch_size (int, optional): Expected (effective) batch size. Defaults to 16.
        lr (float, optional): Learning rate. Defaults to 1e-4.
        aug_p (float, optional): Augmentation probability. Defaults to 0.8.
        device (str, optional): Device to run the training on. Defaults to 'cuda' if available.
    """
    set_environment(0)

    transforms_fn = transforms.Compose([
        transforms.CopyItemsD(keys={'image_path'}, names=['image']),
        transforms.LoadImageD(image_only=True, keys=['image']),
        transforms.EnsureChannelFirstD(keys=['image']),
        transforms.SpacingD(pixdim=2, keys=['image']),
        transforms.ResizeWithPadOrCropD(spatial_size=(80, 96, 80), mode='minimum', keys=['image']),
        transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image'])
    ])

    dataset_df = pd.read_csv(dataset_csv)
    train_df = dataset_df[dataset_df.split == 'train']
    trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)

    train_loader = DataLoader(
        dataset=trainset,
        num_workers=num_workers,
        batch_size=max_batch_size,
        shuffle=True,
        persistent_workers=True,
        pin_memory=True
    )

    print('Device is %s' % (device))
    autoencoder = init_autoencoder(aekl_ckpt).to(device)
    discriminator = init_patch_discriminator(disc_ckpt).to(device)

    # Loss Weights
    adv_weight = 0.025
    perceptual_weight = 0.001
    kl_weight = 1e-7

    # Loss Functions
    l1_loss_fn = L1Loss()
    kl_loss_fn = KLDivergenceLoss()
    adv_loss_fn = PatchAdversarialLoss(criterion="least_squares")

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        perc_loss_fn = PerceptualLoss(
            spatial_dims=3,
            network_type="squeeze",
            is_fake_3d=True,
            fake_3d_ratio=0.2
        ).to(device)

    # Optimizers
    optimizer_g = torch.optim.Adam(autoencoder.parameters(), lr=lr)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr)

    # Gradient Accumulation
    gradacc_g = GradientAccumulation(
        actual_batch_size=max_batch_size,
        expect_batch_size=batch_size,
        loader_len=len(train_loader),
        optimizer=optimizer_g,
        grad_scaler=GradScaler()
    )

    gradacc_d = GradientAccumulation(
        actual_batch_size=max_batch_size,
        expect_batch_size=batch_size,
        loader_len=len(train_loader),
        optimizer=optimizer_d,
        grad_scaler=GradScaler()
    )

    # Logging
    avgloss = AverageLoss()
    writer = SummaryWriter()
    total_counter = 0

    for epoch in range(n_epochs):
        autoencoder.train()
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
        progress_bar.set_description(f'Epoch {epoch + 1}/{n_epochs}')

        for step, batch in progress_bar:
            # Generator Training
            with autocast(enabled=True):
                images = batch["image"].to(device)
                reconstruction, z_mu, z_sigma = autoencoder(images)

                logits_fake = discriminator(reconstruction.contiguous().float())[-1]

                rec_loss = l1_loss_fn(reconstruction.float(), images.float())
                kl_loss = kl_weight * kl_loss_fn(z_mu, z_sigma)
                per_loss = perceptual_weight * perc_loss_fn(reconstruction.float(), images.float())
                gen_loss = adv_weight * adv_loss_fn(logits_fake, target_is_real=True, for_discriminator=False)

                loss_g = rec_loss + kl_loss + per_loss + gen_loss

            gradacc_g.step(loss_g, step)

            # Discriminator Training
            with autocast(enabled=True):
                logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
                d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
                logits_real = discriminator(images.contiguous().detach())[-1]
                d_loss_real = adv_loss_fn(logits_real, target_is_real=True, for_discriminator=True)
                discriminator_loss = (d_loss_fake + d_loss_real) * 0.5
                loss_d = adv_weight * discriminator_loss

            gradacc_d.step(loss_d, step)

            # Logging
            avgloss.put('Generator/reconstruction_loss', rec_loss.item())
            avgloss.put('Generator/perceptual_loss', per_loss.item())
            avgloss.put('Generator/adversarial_loss', gen_loss.item())
            avgloss.put('Generator/kl_regularization', kl_loss.item())
            avgloss.put('Discriminator/adversarial_loss', loss_d.item())

            if total_counter % 10 == 0:
                step_log = total_counter // 10
                avgloss.to_tensorboard(writer, step_log)
                tb_display_reconstruction(
                    writer,
                    step_log,
                    images[0].detach().cpu(),
                    reconstruction[0].detach().cpu()
                )

            total_counter += 1

        # Save the model after each epoch.
        os.makedirs(output_dir, exist_ok=True)
        torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch + 1}.pth'))
        torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch + 1}.pth'))

    writer.close()
    print("Training completed and models saved.")

def inference(
    dataset_csv: str,
    aekl_ckpt: str,
    output_dir: str,
    device: str = ('cuda' if torch.cuda.is_available() else
                   'cpu'),
) -> None:
    """
    Perform inference to encode images into latent space.

    Args:
        dataset_csv (str): Path to the dataset CSV file.
        aekl_ckpt (str): Path to the autoencoder checkpoint.
        output_dir (str): Directory to save latent representations.
        device (str, optional): Device to run the inference on. Defaults to 'cuda' if available.
    """
    DEVICE = device

    autoencoder = init_autoencoder(aekl_ckpt).to(DEVICE).eval()

    transforms_fn = transforms.Compose([
        transforms.CopyItemsD(keys={'image_path'}, names=['image']),
        transforms.LoadImageD(image_only=True, keys=['image']),
        transforms.EnsureChannelFirstD(keys=['image']),
        transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
        transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
        transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image'])
    ])

    df = pd.read_csv(dataset_csv)

    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for image_path in tqdm(df.image_path, total=len(df)):
            destpath = os.path.join(
                output_dir,
                os.path.basename(image_path).replace('.nii.gz', '_latent.npz').replace('.nii', '_latent.npz')
            )
            if os.path.exists(destpath):
                continue
            mri_tensor = transforms_fn({'image_path': image_path})['image'].to(DEVICE)
            mri_latent, _ = autoencoder.encode(mri_tensor.unsqueeze(0))
            mri_latent = mri_latent.cpu().squeeze(0).numpy()
            np.savez_compressed(destpath, data=mri_latent)

    print("Inference completed and latent representations saved.")

def main():
    """
    Main function to parse command-line arguments and execute training or inference.
    """
    parser = argparse.ArgumentParser(description="BRLP Lite Training and Inference Script")

    subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-commands: train or infer')

    # Training Subparser
    train_parser = subparsers.add_parser('train', help='Train the models.')
    train_parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
    train_parser.add_argument('--cache_dir', type=str, required=True, help='Directory for caching data.')
    train_parser.add_argument('--output_dir', type=str, required=True, help='Directory to save model checkpoints.')
    train_parser.add_argument('--aekl_ckpt', type=str, default=None, help='Path to the autoencoder checkpoint.')
    train_parser.add_argument('--disc_ckpt', type=str, default=None, help='Path to the discriminator checkpoint.')
    train_parser.add_argument('--num_workers', type=int, default=8, help='Number of data loader workers.')
    train_parser.add_argument('--n_epochs', type=int, default=5, help='Number of training epochs.')
    train_parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
    train_parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
    train_parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
    train_parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')

    # Inference Subparser (registered as 'infer' to match the dispatch below)
    infer_parser = subparsers.add_parser('infer', help='Run inference to encode images.')
    infer_parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
    infer_parser.add_argument('--aekl_ckpt', type=str, required=True, help='Path to the autoencoder checkpoint.')
    infer_parser.add_argument('--output_dir', type=str, required=True, help='Directory to save latent representations.')

    args = parser.parse_args()

    if args.command == 'train':
        train(
            dataset_csv=args.dataset_csv,
            cache_dir=args.cache_dir,
            output_dir=args.output_dir,
            aekl_ckpt=args.aekl_ckpt,
            disc_ckpt=args.disc_ckpt,
            num_workers=args.num_workers,
            n_epochs=args.n_epochs,
            max_batch_size=args.max_batch_size,
            batch_size=args.batch_size,
            lr=args.lr,
            aug_p=args.aug_p,
        )
    elif args.command == 'infer':
        inference(
            dataset_csv=args.dataset_csv,
            aekl_ckpt=args.aekl_ckpt,
            output_dir=args.output_dir,
        )
    else:
        parser.print_help()


if __name__ == '__main__':
    main()
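
A minimal usage sketch, assuming the file above is importable as brlp_lite, that inputs_local.csv lists the MRI scans, and that the directory names and the warm start from this commit's epoch-4 checkpoints are placeholders rather than anything fixed by the script:

# Minimal usage sketch: paths, directory names, and the warm start from the shipped
# epoch-4 checkpoints are assumptions for illustration, not part of brlp_lite.py.
from brlp_lite import train, inference

# Train (or continue training) the KL autoencoder and patch discriminator.
# CLI equivalent: python brlp_lite.py train --dataset_csv inputs_local.csv ...
train(
    dataset_csv='inputs_local.csv',      # needs 'image_path' and 'split' columns
    cache_dir='./cache',                 # must already exist (checked by an assert)
    output_dir='./checkpoints',
    aekl_ckpt='autoencoder-ep-4.pth',    # optional warm start from this commit's weights
    disc_ckpt='discriminator-ep-4.pth',
    n_epochs=1,
    max_batch_size=2,                    # mini-batch actually held in memory
    batch_size=16,                       # effective batch via GradientAccumulation
)

# Encode every scan listed in the CSV into a compressed latent (.npz) file.
# CLI equivalent: python brlp_lite.py infer --dataset_csv inputs_local.csv ...
inference(
    dataset_csv='inputs_local.csv',
    aekl_ckpt='autoencoder-ep-4.pth',
    output_dir='./latents',              # created if missing
)
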
discriminator-ep-4.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:83d59c14472ce1cc582798762f7980361e0239e9f524b6b8b6861dab43fd664e
size 11098603
inputs_local.csv
ADDED
The diff for this file is too large to render.
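
Although the CSV diff is not rendered here, brlp_lite.py only relies on two of its columns: train() keeps the rows where split == 'train', and both train() and inference() read image_path. A hedged sketch of a compatible file follows; the paths and the non-train split label are invented for illustration:

# Illustrative only: paths and the 'val' label are invented; brlp_lite.py itself
# only checks for split == 'train' and reads the image_path column.
import pandas as pd

rows = [
    {'image_path': '/data/sub-001_T1w.nii.gz', 'split': 'train'},
    {'image_path': '/data/sub-002_T1w.nii.gz', 'split': 'train'},
    {'image_path': '/data/sub-003_T1w.nii.gz', 'split': 'val'},
]
pd.DataFrame(rows).to_csv('inputs_example.csv', index=False)
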
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# requirements.txt
|
2 |
+
|
3 |
+
# PyTorch (CUDA or CPU version). For GPU install, see PyTorch docs for the correct wheel.
|
4 |
+
torch>=1.12
|
5 |
+
|
6 |
+
# MONAI v1.2+ has the 'generative' subpackage with AutoencoderKL, PatchDiscriminator, etc.
|
7 |
+
monai>=1.2.0
|
8 |
+
monai-generative
|
9 |
+
|
10 |
+
# For perceptual losses in MONAI's generative module.
|
11 |
+
lpips
|
12 |
+
|
13 |
+
# Common Python libraries
|
14 |
+
pandas
|
15 |
+
numpy
|
16 |
+
tqdm
|
17 |
+
tensorboard
|
18 |
+
matplotlib
|
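
A quick way to confirm the stack above is installed; the names passed to importlib.metadata are assumed to match the pip distribution names listed in this file:

# Report installed versions of the packages from requirements.txt.
from importlib.metadata import version, PackageNotFoundError

for pkg in ('torch', 'monai', 'monai-generative', 'lpips',
            'pandas', 'numpy', 'tqdm', 'tensorboard', 'matplotlib'):
    try:
        print(f'{pkg}: {version(pkg)}')
    except PackageNotFoundError:
        print(f'{pkg}: not installed')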