Code updates
Browse files- .DS_Store +0 -0
- __pycache__/brain2vec.cpython-310.pyc +0 -0
- model.py → inference_brain2vec.py +140 -11
- requirements.txt +8 -4
- brain2vec.py → train_brain2vec.py +21 -114
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
__pycache__/brain2vec.cpython-310.pyc
DELETED
Binary file (18.8 kB)
|
|
model.py → inference_brain2vec.py
RENAMED
@@ -1,9 +1,34 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import torch
|
6 |
import torch.nn as nn
|
|
|
7 |
from monai.transforms import (
|
8 |
Compose,
|
9 |
CopyItemsD,
|
@@ -14,12 +39,12 @@ from monai.transforms import (
|
|
14 |
ScaleIntensityD,
|
15 |
)
|
16 |
from generative.networks.nets import AutoencoderKL
|
|
|
|
|
17 |
|
18 |
-
# Constants for your typical config
|
19 |
RESOLUTION = 2
|
20 |
INPUT_SHAPE_AE = (80, 96, 80)
|
21 |
|
22 |
-
# Define the exact transform pipeline for input MRI
|
23 |
transforms_fn = Compose([
|
24 |
CopyItemsD(keys={'image_path'}, names=['image']),
|
25 |
LoadImageD(image_only=True, keys=['image']),
|
@@ -29,15 +54,23 @@ transforms_fn = Compose([
|
|
29 |
ScaleIntensityD(minv=0, maxv=1, keys=['image']),
|
30 |
])
|
31 |
|
|
|
32 |
def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
|
33 |
"""
|
34 |
Preprocess an MRI using MONAI transforms to produce
|
35 |
-
a 5D tensor (batch=1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
"""
|
37 |
data_dict = {"image_path": image_path}
|
38 |
output_dict = transforms_fn(data_dict)
|
39 |
image_tensor = output_dict["image"] # shape: (1, D, H, W)
|
40 |
-
image_tensor = image_tensor.unsqueeze(0) # => (
|
41 |
return image_tensor.to(device)
|
42 |
|
43 |
|
@@ -63,11 +96,11 @@ class Brain2vec(AutoencoderKL):
|
|
63 |
Otherwise, return an uninitialized model.
|
64 |
|
65 |
Args:
|
66 |
-
checkpoint_path (Optional[str]):
|
67 |
device (str): "cpu", "cuda", "mps", etc.
|
68 |
|
69 |
Returns:
|
70 |
-
nn.Module:
|
71 |
"""
|
72 |
model = Brain2vec(
|
73 |
spatial_dims=3,
|
@@ -90,5 +123,101 @@ class Brain2vec(AutoencoderKL):
|
|
90 |
model.load_state_dict(state_dict)
|
91 |
|
92 |
model.to(device)
|
93 |
-
model.eval()
|
94 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""
|
4 |
+
inference_brain2vec.py
|
5 |
+
|
6 |
+
Loads a pretrained Brain2vec VAE (AutoencoderKL) model and performs inference
|
7 |
+
on one or more MRI images, generating reconstructions and latent parameters
|
8 |
+
(z_mu, z_sigma).
|
9 |
|
10 |
+
Example usage:
|
11 |
+
|
12 |
+
# 1) Multiple file paths
|
13 |
+
python inference_brain2vec.py \
|
14 |
+
--checkpoint_path /path/to/autoencoder_checkpoint.pth \
|
15 |
+
--input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
|
16 |
+
--output_dir ./vae_inference_outputs \
|
17 |
+
--device cuda
|
18 |
+
|
19 |
+
# 2) Use a CSV containing image paths
|
20 |
+
python inference_brain2vec.py \
|
21 |
+
--checkpoint_path /path/to/autoencoder_checkpoint.pth \
|
22 |
+
--csv_input /path/to/images.csv \
|
23 |
+
--output_dir ./vae_inference_outputs
|
24 |
+
"""
|
25 |
+
|
26 |
+
import os
|
27 |
+
import argparse
|
28 |
+
import numpy as np
|
29 |
import torch
|
30 |
import torch.nn as nn
|
31 |
+
from typing import Optional
|
32 |
from monai.transforms import (
|
33 |
Compose,
|
34 |
CopyItemsD,
|
|
|
39 |
ScaleIntensityD,
|
40 |
)
|
41 |
from generative.networks.nets import AutoencoderKL
|
42 |
+
import pandas as pd
|
43 |
+
|
44 |
|
|
|
45 |
RESOLUTION = 2
|
46 |
INPUT_SHAPE_AE = (80, 96, 80)
|
47 |
|
|
|
48 |
transforms_fn = Compose([
|
49 |
CopyItemsD(keys={'image_path'}, names=['image']),
|
50 |
LoadImageD(image_only=True, keys=['image']),
|
|
|
54 |
ScaleIntensityD(minv=0, maxv=1, keys=['image']),
|
55 |
])
|
56 |
|
57 |
+
|
58 |
def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
|
59 |
"""
|
60 |
Preprocess an MRI using MONAI transforms to produce
|
61 |
+
a 5D tensor (batch=1, channel=1, D, H, W) for inference.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
image_path (str): Path to the MRI (e.g. .nii.gz).
|
65 |
+
device (str): Device to place the tensor on.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
torch.Tensor: Shape (1, 1, D, H, W).
|
69 |
"""
|
70 |
data_dict = {"image_path": image_path}
|
71 |
output_dict = transforms_fn(data_dict)
|
72 |
image_tensor = output_dict["image"] # shape: (1, D, H, W)
|
73 |
+
image_tensor = image_tensor.unsqueeze(0) # => (1, 1, D, H, W)
|
74 |
return image_tensor.to(device)
|
75 |
|
76 |
|
|
|
96 |
Otherwise, return an uninitialized model.
|
97 |
|
98 |
Args:
|
99 |
+
checkpoint_path (Optional[str]): Path to a .pth checkpoint file.
|
100 |
device (str): "cpu", "cuda", "mps", etc.
|
101 |
|
102 |
Returns:
|
103 |
+
nn.Module: The loaded Brain2vec model on the chosen device.
|
104 |
"""
|
105 |
model = Brain2vec(
|
106 |
spatial_dims=3,
|
|
|
123 |
model.load_state_dict(state_dict)
|
124 |
|
125 |
model.to(device)
|
126 |
+
model.eval()
|
127 |
+
return model
|
128 |
+
|
129 |
+
|
130 |
+
def main() -> None:
|
131 |
+
"""
|
132 |
+
Main function to parse command-line arguments and run inference
|
133 |
+
with a pretrained Brain2vec model.
|
134 |
+
"""
|
135 |
+
parser = argparse.ArgumentParser(
|
136 |
+
description="Inference script for a Brain2vec (VAE) model."
|
137 |
+
)
|
138 |
+
parser.add_argument(
|
139 |
+
"--checkpoint_path", type=str, required=True,
|
140 |
+
help="Path to the .pth checkpoint of the pretrained Brain2vec model."
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--output_dir", type=str, default="./vae_inference_outputs",
|
144 |
+
help="Directory to save reconstructions and latent parameters."
|
145 |
+
)
|
146 |
+
parser.add_argument(
|
147 |
+
"--device", type=str, default="cpu",
|
148 |
+
help="Device to run inference on ('cpu', 'cuda', etc.)."
|
149 |
+
)
|
150 |
+
# Two ways to supply images: multiple file paths or a CSV
|
151 |
+
parser.add_argument(
|
152 |
+
"--input_images", type=str, nargs="*",
|
153 |
+
help="One or more MRI file paths (e.g. .nii.gz)."
|
154 |
+
)
|
155 |
+
parser.add_argument(
|
156 |
+
"--csv_input", type=str,
|
157 |
+
help="Path to a CSV file with an 'image_path' column."
|
158 |
+
)
|
159 |
+
args = parser.parse_args()
|
160 |
+
|
161 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
162 |
+
|
163 |
+
# Load the model
|
164 |
+
model = Brain2vec.from_pretrained(
|
165 |
+
checkpoint_path=args.checkpoint_path,
|
166 |
+
device=args.device
|
167 |
+
)
|
168 |
+
|
169 |
+
# Gather image paths
|
170 |
+
if args.csv_input:
|
171 |
+
df = pd.read_csv(args.csv_input)
|
172 |
+
if "image_path" not in df.columns:
|
173 |
+
raise ValueError("CSV must contain a column named 'image_path'.")
|
174 |
+
image_paths = df["image_path"].tolist()
|
175 |
+
else:
|
176 |
+
if not args.input_images:
|
177 |
+
raise ValueError("Must provide either --csv_input or --input_images.")
|
178 |
+
image_paths = args.input_images
|
179 |
+
|
180 |
+
# Lists for stacking latent parameters later
|
181 |
+
all_z_mu = []
|
182 |
+
all_z_sigma = []
|
183 |
+
|
184 |
+
# Inference on each image
|
185 |
+
for i, img_path in enumerate(image_paths):
|
186 |
+
if not os.path.exists(img_path):
|
187 |
+
raise FileNotFoundError(f"Image not found: {img_path}")
|
188 |
+
|
189 |
+
print(f"[INFO] Processing image {i}: {img_path}")
|
190 |
+
img_tensor = preprocess_mri(img_path, device=args.device)
|
191 |
+
|
192 |
+
with torch.no_grad():
|
193 |
+
recon, z_mu, z_sigma = model.forward(img_tensor)
|
194 |
+
|
195 |
+
# Convert to NumPy
|
196 |
+
recon_np = recon.detach().cpu().numpy() # shape: (1, 1, D, H, W)
|
197 |
+
z_mu_np = z_mu.detach().cpu().numpy() # shape: (1, latent_channels, ...)
|
198 |
+
z_sigma_np = z_sigma.detach().cpu().numpy()
|
199 |
+
|
200 |
+
# Save each reconstruction (per image) as .npy
|
201 |
+
recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
|
202 |
+
np.save(recon_path, recon_np)
|
203 |
+
print(f"[INFO] Saved reconstruction to {recon_path}")
|
204 |
+
|
205 |
+
# Store latent parameters for optional combined saving
|
206 |
+
all_z_mu.append(z_mu_np)
|
207 |
+
all_z_sigma.append(z_sigma_np)
|
208 |
+
|
209 |
+
# Combine latent parameters from all images and save
|
210 |
+
stacked_mu = np.concatenate(all_z_mu, axis=0) # e.g., shape (N, latent_channels, ...)
|
211 |
+
stacked_sigma = np.concatenate(all_z_sigma, axis=0) # e.g., shape (N, latent_channels, ...)
|
212 |
+
|
213 |
+
mu_path = os.path.join(args.output_dir, "all_z_mu.npy")
|
214 |
+
sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")
|
215 |
+
np.save(mu_path, stacked_mu)
|
216 |
+
np.save(sigma_path, stacked_sigma)
|
217 |
+
|
218 |
+
print(f"[INFO] Saved z_mu of shape {stacked_mu.shape} to {mu_path}")
|
219 |
+
print(f"[INFO] Saved z_sigma of shape {stacked_sigma.shape} to {sigma_path}")
|
220 |
+
|
221 |
+
|
222 |
+
if __name__ == "__main__":
|
223 |
+
main()
|
requirements.txt
CHANGED
@@ -1,12 +1,15 @@
|
|
1 |
# requirements.txt
|
2 |
|
3 |
-
# PyTorch (CUDA or CPU version).
|
4 |
torch>=1.12
|
5 |
|
6 |
-
# MONAI
|
7 |
-
monai-weekly
|
8 |
monai-generative
|
9 |
|
|
|
|
|
|
|
|
|
10 |
# For perceptual losses in MONAI's generative module.
|
11 |
lpips
|
12 |
|
@@ -17,4 +20,5 @@ nibabel
|
|
17 |
tqdm
|
18 |
tensorboard
|
19 |
matplotlib
|
20 |
-
datasets
|
|
|
|
1 |
# requirements.txt
|
2 |
|
3 |
+
# PyTorch (CUDA or CPU version).
|
4 |
torch>=1.12
|
5 |
|
6 |
+
# Install MONAI Generative first
|
|
|
7 |
monai-generative
|
8 |
|
9 |
+
# Now force reinstall MONAI Weekly so its (newer) MONAI version takes precedence
|
10 |
+
--force-reinstall
|
11 |
+
monai-weekly
|
12 |
+
|
13 |
# For perceptual losses in MONAI's generative module.
|
14 |
lpips
|
15 |
|
|
|
20 |
tqdm
|
21 |
tensorboard
|
22 |
matplotlib
|
23 |
+
datasets
|
24 |
+
scikit-learn
|
brain2vec.py → train_brain2vec.py
RENAMED
@@ -1,35 +1,20 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
-
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
-
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
-
# SOFTWARE.
|
22 |
-
|
23 |
-
# Forked from: https://github.com/LemuelPuglisi/BrLP
|
24 |
-
|
25 |
-
# @inproceedings{puglisi2024enhancing,
|
26 |
-
# title={Enhancing spatiotemporal disease progression models via latent diffusion and prior knowledge},
|
27 |
-
# author={Puglisi, Lemuel and Alexander, Daniel C and Rav{\`\i}, Daniele},
|
28 |
-
# booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
|
29 |
-
# pages={173--183},
|
30 |
-
# year={2024},
|
31 |
-
# organization={Springer}
|
32 |
-
# }
|
33 |
|
34 |
import os
|
35 |
os.environ["PYTORCH_WEIGHTS_ONLY"] = "False"
|
@@ -37,7 +22,6 @@ from typing import Optional, Union
|
|
37 |
import pandas as pd
|
38 |
import argparse
|
39 |
import numpy as np
|
40 |
-
|
41 |
import warnings
|
42 |
import torch
|
43 |
import torch.nn as nn
|
@@ -47,7 +31,6 @@ from torch.nn import L1Loss
|
|
47 |
from torch.utils.data import DataLoader
|
48 |
from torch.amp import autocast
|
49 |
from torch.amp import GradScaler
|
50 |
-
|
51 |
from generative.networks.nets import (
|
52 |
AutoencoderKL,
|
53 |
PatchDiscriminator,
|
@@ -65,13 +48,11 @@ torch.serialization.add_safe_globals([_reconstruct])
|
|
65 |
torch.serialization.add_safe_globals([MetaTensor])
|
66 |
torch.serialization.add_safe_globals([ndarray])
|
67 |
torch.serialization.add_safe_globals([dtype])
|
68 |
-
|
69 |
from tqdm import tqdm
|
70 |
import matplotlib.pyplot as plt
|
71 |
-
|
72 |
from torch.utils.tensorboard import SummaryWriter
|
73 |
|
74 |
-
#
|
75 |
RESOLUTION = 2
|
76 |
|
77 |
# shape of the MNI152 (1mm^3) template
|
@@ -101,10 +82,7 @@ def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
|
|
101 |
"""
|
102 |
if checkpoints_path is not None:
|
103 |
assert os.path.exists(checkpoints_path), 'Invalid path'
|
104 |
-
# Using context manager to allow MetaTensor
|
105 |
-
#with torch.serialization.safe_globals([MetaTensor]):
|
106 |
network.load_state_dict(torch.load(checkpoints_path))
|
107 |
-
#network.load_state_dict(torch.load(checkpoints_path, map_location='cpu'))
|
108 |
return network
|
109 |
|
110 |
|
@@ -140,7 +118,7 @@ def init_patch_discriminator(checkpoints_path: Optional[str] = None) -> nn.Modul
|
|
140 |
checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.
|
141 |
|
142 |
Returns:
|
143 |
-
nn.Module: the
|
144 |
"""
|
145 |
patch_discriminator = PatchDiscriminator(spatial_dims=3,
|
146 |
num_layers_d=3,
|
@@ -387,22 +365,6 @@ def train(
|
|
387 |
train_df = dataset_df[dataset_df.split == 'train']
|
388 |
trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
|
389 |
|
390 |
-
print(f"[DEBUG] Using cache_dir={cache_dir}")
|
391 |
-
print(f"[DEBUG] trainset length={len(trainset)}")
|
392 |
-
|
393 |
-
try:
|
394 |
-
sample_debug = trainset[0] # Force a transform on the first record
|
395 |
-
print("[DEBUG] Successfully loaded sample 0 from trainset.")
|
396 |
-
except Exception as e:
|
397 |
-
print("[DEBUG] Error loading sample 0:", e)
|
398 |
-
|
399 |
-
import glob
|
400 |
-
|
401 |
-
hashfiles = glob.glob(os.path.join(cache_dir, "*.pt"))
|
402 |
-
print(f"[DEBUG] Found {len(hashfiles)} cached .pt files in {cache_dir}")
|
403 |
-
if hashfiles:
|
404 |
-
print("[DEBUG] Example cache file:", hashfiles[0])
|
405 |
-
|
406 |
train_loader = DataLoader(
|
407 |
dataset=trainset,
|
408 |
num_workers=num_workers,
|
@@ -523,60 +485,11 @@ def train(
|
|
523 |
print("Training completed and models saved.")
|
524 |
|
525 |
|
526 |
-
def inference(
|
527 |
-
dataset_csv: str,
|
528 |
-
aekl_ckpt: str,
|
529 |
-
output_dir: str,
|
530 |
-
device: str = ('cuda' if torch.cuda.is_available() else
|
531 |
-
'cpu'),
|
532 |
-
) -> None:
|
533 |
-
"""
|
534 |
-
Perform inference to encode images into latent space.
|
535 |
-
|
536 |
-
Args:
|
537 |
-
dataset_csv (str): Path to the dataset CSV file.
|
538 |
-
aekl_ckpt (str): Path to the autoencoder checkpoint.
|
539 |
-
output_dir (str): Directory to save latent representations.
|
540 |
-
device (str, optional): Device to run the inference on. Defaults to 'cuda' if available.
|
541 |
-
"""
|
542 |
-
DEVICE = device
|
543 |
-
|
544 |
-
autoencoder = init_autoencoder(aekl_ckpt).to(DEVICE).eval()
|
545 |
-
|
546 |
-
transforms_fn = transforms.Compose([
|
547 |
-
transforms.CopyItemsD(keys={'image_path'}, names=['image']),
|
548 |
-
transforms.LoadImageD(image_only=True, keys=['image']),
|
549 |
-
transforms.EnsureChannelFirstD(keys=['image']),
|
550 |
-
transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
|
551 |
-
transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
|
552 |
-
transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image'])
|
553 |
-
])
|
554 |
-
|
555 |
-
df = pd.read_csv(dataset_csv)
|
556 |
-
|
557 |
-
os.makedirs(output_dir, exist_ok=True)
|
558 |
-
|
559 |
-
with torch.no_grad():
|
560 |
-
for image_path in tqdm(df.image_path, total=len(df)):
|
561 |
-
destpath = os.path.join(
|
562 |
-
output_dir,
|
563 |
-
os.path.basename(image_path).replace('.nii.gz', '_embeddings.npz').replace('.nii', '_embeddings.npz')
|
564 |
-
)
|
565 |
-
if os.path.exists(destpath):
|
566 |
-
continue
|
567 |
-
mri_tensor = transforms_fn({'image_path': image_path})['image'].to(DEVICE)
|
568 |
-
mri_latent, _ = autoencoder.encode(mri_tensor.unsqueeze(0))
|
569 |
-
mri_latent = mri_latent.cpu().squeeze(0).numpy()
|
570 |
-
np.savez_compressed(destpath, data=mri_latent)
|
571 |
-
|
572 |
-
print("Inference completed and latent representations saved.")
|
573 |
-
|
574 |
-
|
575 |
def main():
|
576 |
"""
|
577 |
-
Main function to parse command-line arguments and execute training
|
578 |
"""
|
579 |
-
parser = argparse.ArgumentParser(description="brain2vec Training
|
580 |
|
581 |
subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-commands: train or infer')
|
582 |
|
@@ -594,12 +507,6 @@ def main():
|
|
594 |
train_parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
|
595 |
train_parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
|
596 |
|
597 |
-
# Inference Subparser
|
598 |
-
infer_parser = subparsers.add_parser('inference', help='Run inference to encode images.')
|
599 |
-
infer_parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
|
600 |
-
infer_parser.add_argument('--aekl_ckpt', type=str, required=True, help='Path to the autoencoder checkpoint.')
|
601 |
-
infer_parser.add_argument('--output_dir', type=str, required=True, help='Directory to save latent representations.')
|
602 |
-
|
603 |
args = parser.parse_args()
|
604 |
|
605 |
if args.command == 'train':
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""
|
4 |
+
train_brain2vec.py
|
5 |
+
|
6 |
+
Trains a 3D VAE-based Brain2Vec model using MONAI. This script implements
|
7 |
+
autoencoder training with adversarial loss (via a patch discriminator),
|
8 |
+
a perceptual loss, and KL divergence regularization for robust latent
|
9 |
+
representations.
|
10 |
+
|
11 |
+
Example usage:
|
12 |
+
python train_brain2vec.py train \
|
13 |
+
--dataset_csv /path/to/dataset.csv \
|
14 |
+
--cache_dir /path/to/cache \
|
15 |
+
--output_dir /path/to/output_dir \
|
16 |
+
--n_epochs 10
|
17 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
import os
|
20 |
os.environ["PYTORCH_WEIGHTS_ONLY"] = "False"
|
|
|
22 |
import pandas as pd
|
23 |
import argparse
|
24 |
import numpy as np
|
|
|
25 |
import warnings
|
26 |
import torch
|
27 |
import torch.nn as nn
|
|
|
31 |
from torch.utils.data import DataLoader
|
32 |
from torch.amp import autocast
|
33 |
from torch.amp import GradScaler
|
|
|
34 |
from generative.networks.nets import (
|
35 |
AutoencoderKL,
|
36 |
PatchDiscriminator,
|
|
|
48 |
torch.serialization.add_safe_globals([MetaTensor])
|
49 |
torch.serialization.add_safe_globals([ndarray])
|
50 |
torch.serialization.add_safe_globals([dtype])
|
|
|
51 |
from tqdm import tqdm
|
52 |
import matplotlib.pyplot as plt
|
|
|
53 |
from torch.utils.tensorboard import SummaryWriter
|
54 |
|
55 |
+
# voxel resolution
|
56 |
RESOLUTION = 2
|
57 |
|
58 |
# shape of the MNI152 (1mm^3) template
|
|
|
82 |
"""
|
83 |
if checkpoints_path is not None:
|
84 |
assert os.path.exists(checkpoints_path), 'Invalid path'
|
|
|
|
|
85 |
network.load_state_dict(torch.load(checkpoints_path))
|
|
|
86 |
return network
|
87 |
|
88 |
|
|
|
118 |
checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.
|
119 |
|
120 |
Returns:
|
121 |
+
nn.Module: the patch discriminator
|
122 |
"""
|
123 |
patch_discriminator = PatchDiscriminator(spatial_dims=3,
|
124 |
num_layers_d=3,
|
|
|
365 |
train_df = dataset_df[dataset_df.split == 'train']
|
366 |
trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
|
367 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
train_loader = DataLoader(
|
369 |
dataset=trainset,
|
370 |
num_workers=num_workers,
|
|
|
485 |
print("Training completed and models saved.")
|
486 |
|
487 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
488 |
def main():
|
489 |
"""
|
490 |
+
Main function to parse command-line arguments and execute training.
|
491 |
"""
|
492 |
+
parser = argparse.ArgumentParser(description="brain2vec Training Script")
|
493 |
|
494 |
subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-commands: train or infer')
|
495 |
|
|
|
507 |
train_parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
|
508 |
train_parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
|
509 |
|
|
|
|
|
|
|
|
|
|
|
|
|
510 |
args = parser.parse_args()
|
511 |
|
512 |
if args.command == 'train':
|