English
medical
brain-data
mri
jesseab commited on
Commit
ac3730a
·
1 Parent(s): 3ae8863

Code updates

Browse files
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
__pycache__/brain2vec.cpython-310.pyc DELETED
Binary file (18.8 kB)
 
model.py → inference_brain2vec.py RENAMED
@@ -1,9 +1,34 @@
1
- # model.py
2
- import os
3
- from typing import Optional
 
 
 
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import torch
6
  import torch.nn as nn
 
7
  from monai.transforms import (
8
  Compose,
9
  CopyItemsD,
@@ -14,12 +39,12 @@ from monai.transforms import (
14
  ScaleIntensityD,
15
  )
16
  from generative.networks.nets import AutoencoderKL
 
 
17
 
18
- # Constants for your typical config
19
  RESOLUTION = 2
20
  INPUT_SHAPE_AE = (80, 96, 80)
21
 
22
- # Define the exact transform pipeline for input MRI
23
  transforms_fn = Compose([
24
  CopyItemsD(keys={'image_path'}, names=['image']),
25
  LoadImageD(image_only=True, keys=['image']),
@@ -29,15 +54,23 @@ transforms_fn = Compose([
29
  ScaleIntensityD(minv=0, maxv=1, keys=['image']),
30
  ])
31
 
 
32
  def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
33
  """
34
  Preprocess an MRI using MONAI transforms to produce
35
- a 5D tensor (batch=1, channels=1, D, H, W) for inference.
 
 
 
 
 
 
 
36
  """
37
  data_dict = {"image_path": image_path}
38
  output_dict = transforms_fn(data_dict)
39
  image_tensor = output_dict["image"] # shape: (1, D, H, W)
40
- image_tensor = image_tensor.unsqueeze(0) # => (batch=1, channel=1, D, H, W)
41
  return image_tensor.to(device)
42
 
43
 
@@ -63,11 +96,11 @@ class Brain2vec(AutoencoderKL):
63
  Otherwise, return an uninitialized model.
64
 
65
  Args:
66
- checkpoint_path (Optional[str]): path to a .pth checkpoint
67
  device (str): "cpu", "cuda", "mps", etc.
68
 
69
  Returns:
70
- nn.Module: the loaded Brain2vec model on the chosen device
71
  """
72
  model = Brain2vec(
73
  spatial_dims=3,
@@ -90,5 +123,101 @@ class Brain2vec(AutoencoderKL):
90
  model.load_state_dict(state_dict)
91
 
92
  model.to(device)
93
- model.eval() # ready for inference
94
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ inference_brain2vec.py
5
+
6
+ Loads a pretrained Brain2vec VAE (AutoencoderKL) model and performs inference
7
+ on one or more MRI images, generating reconstructions and latent parameters
8
+ (z_mu, z_sigma).
9
 
10
+ Example usage:
11
+
12
+ # 1) Multiple file paths
13
+ python inference_brain2vec.py \
14
+ --checkpoint_path /path/to/autoencoder_checkpoint.pth \
15
+ --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
16
+ --output_dir ./vae_inference_outputs \
17
+ --device cuda
18
+
19
+ # 2) Use a CSV containing image paths
20
+ python inference_brain2vec.py \
21
+ --checkpoint_path /path/to/autoencoder_checkpoint.pth \
22
+ --csv_input /path/to/images.csv \
23
+ --output_dir ./vae_inference_outputs
24
+ """
25
+
26
+ import os
27
+ import argparse
28
+ import numpy as np
29
  import torch
30
  import torch.nn as nn
31
+ from typing import Optional
32
  from monai.transforms import (
33
  Compose,
34
  CopyItemsD,
 
39
  ScaleIntensityD,
40
  )
41
  from generative.networks.nets import AutoencoderKL
42
+ import pandas as pd
43
+
44
 
 
45
  RESOLUTION = 2
46
  INPUT_SHAPE_AE = (80, 96, 80)
47
 
 
48
  transforms_fn = Compose([
49
  CopyItemsD(keys={'image_path'}, names=['image']),
50
  LoadImageD(image_only=True, keys=['image']),
 
54
  ScaleIntensityD(minv=0, maxv=1, keys=['image']),
55
  ])
56
 
57
+
58
  def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
59
  """
60
  Preprocess an MRI using MONAI transforms to produce
61
+ a 5D tensor (batch=1, channel=1, D, H, W) for inference.
62
+
63
+ Args:
64
+ image_path (str): Path to the MRI (e.g. .nii.gz).
65
+ device (str): Device to place the tensor on.
66
+
67
+ Returns:
68
+ torch.Tensor: Shape (1, 1, D, H, W).
69
  """
70
  data_dict = {"image_path": image_path}
71
  output_dict = transforms_fn(data_dict)
72
  image_tensor = output_dict["image"] # shape: (1, D, H, W)
73
+ image_tensor = image_tensor.unsqueeze(0) # => (1, 1, D, H, W)
74
  return image_tensor.to(device)
75
 
76
 
 
96
  Otherwise, return an uninitialized model.
97
 
98
  Args:
99
+ checkpoint_path (Optional[str]): Path to a .pth checkpoint file.
100
  device (str): "cpu", "cuda", "mps", etc.
101
 
102
  Returns:
103
+ nn.Module: The loaded Brain2vec model on the chosen device.
104
  """
105
  model = Brain2vec(
106
  spatial_dims=3,
 
123
  model.load_state_dict(state_dict)
124
 
125
  model.to(device)
126
+ model.eval()
127
+ return model
128
+
129
+
130
+ def main() -> None:
131
+ """
132
+ Main function to parse command-line arguments and run inference
133
+ with a pretrained Brain2vec model.
134
+ """
135
+ parser = argparse.ArgumentParser(
136
+ description="Inference script for a Brain2vec (VAE) model."
137
+ )
138
+ parser.add_argument(
139
+ "--checkpoint_path", type=str, required=True,
140
+ help="Path to the .pth checkpoint of the pretrained Brain2vec model."
141
+ )
142
+ parser.add_argument(
143
+ "--output_dir", type=str, default="./vae_inference_outputs",
144
+ help="Directory to save reconstructions and latent parameters."
145
+ )
146
+ parser.add_argument(
147
+ "--device", type=str, default="cpu",
148
+ help="Device to run inference on ('cpu', 'cuda', etc.)."
149
+ )
150
+ # Two ways to supply images: multiple file paths or a CSV
151
+ parser.add_argument(
152
+ "--input_images", type=str, nargs="*",
153
+ help="One or more MRI file paths (e.g. .nii.gz)."
154
+ )
155
+ parser.add_argument(
156
+ "--csv_input", type=str,
157
+ help="Path to a CSV file with an 'image_path' column."
158
+ )
159
+ args = parser.parse_args()
160
+
161
+ os.makedirs(args.output_dir, exist_ok=True)
162
+
163
+ # Load the model
164
+ model = Brain2vec.from_pretrained(
165
+ checkpoint_path=args.checkpoint_path,
166
+ device=args.device
167
+ )
168
+
169
+ # Gather image paths
170
+ if args.csv_input:
171
+ df = pd.read_csv(args.csv_input)
172
+ if "image_path" not in df.columns:
173
+ raise ValueError("CSV must contain a column named 'image_path'.")
174
+ image_paths = df["image_path"].tolist()
175
+ else:
176
+ if not args.input_images:
177
+ raise ValueError("Must provide either --csv_input or --input_images.")
178
+ image_paths = args.input_images
179
+
180
+ # Lists for stacking latent parameters later
181
+ all_z_mu = []
182
+ all_z_sigma = []
183
+
184
+ # Inference on each image
185
+ for i, img_path in enumerate(image_paths):
186
+ if not os.path.exists(img_path):
187
+ raise FileNotFoundError(f"Image not found: {img_path}")
188
+
189
+ print(f"[INFO] Processing image {i}: {img_path}")
190
+ img_tensor = preprocess_mri(img_path, device=args.device)
191
+
192
+ with torch.no_grad():
193
+ recon, z_mu, z_sigma = model.forward(img_tensor)
194
+
195
+ # Convert to NumPy
196
+ recon_np = recon.detach().cpu().numpy() # shape: (1, 1, D, H, W)
197
+ z_mu_np = z_mu.detach().cpu().numpy() # shape: (1, latent_channels, ...)
198
+ z_sigma_np = z_sigma.detach().cpu().numpy()
199
+
200
+ # Save each reconstruction (per image) as .npy
201
+ recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
202
+ np.save(recon_path, recon_np)
203
+ print(f"[INFO] Saved reconstruction to {recon_path}")
204
+
205
+ # Store latent parameters for optional combined saving
206
+ all_z_mu.append(z_mu_np)
207
+ all_z_sigma.append(z_sigma_np)
208
+
209
+ # Combine latent parameters from all images and save
210
+ stacked_mu = np.concatenate(all_z_mu, axis=0) # e.g., shape (N, latent_channels, ...)
211
+ stacked_sigma = np.concatenate(all_z_sigma, axis=0) # e.g., shape (N, latent_channels, ...)
212
+
213
+ mu_path = os.path.join(args.output_dir, "all_z_mu.npy")
214
+ sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")
215
+ np.save(mu_path, stacked_mu)
216
+ np.save(sigma_path, stacked_sigma)
217
+
218
+ print(f"[INFO] Saved z_mu of shape {stacked_mu.shape} to {mu_path}")
219
+ print(f"[INFO] Saved z_sigma of shape {stacked_sigma.shape} to {sigma_path}")
220
+
221
+
222
+ if __name__ == "__main__":
223
+ main()
requirements.txt CHANGED
@@ -1,12 +1,15 @@
1
  # requirements.txt
2
 
3
- # PyTorch (CUDA or CPU version). For GPU install, see PyTorch docs for the correct wheel.
4
  torch>=1.12
5
 
6
- # MONAI v1.2+ has the 'generative' subpackage with AutoencoderKL, PatchDiscriminator, etc.
7
- monai-weekly
8
  monai-generative
9
 
 
 
 
 
10
  # For perceptual losses in MONAI's generative module.
11
  lpips
12
 
@@ -17,4 +20,5 @@ nibabel
17
  tqdm
18
  tensorboard
19
  matplotlib
20
- datasets
 
 
1
  # requirements.txt
2
 
3
+ # PyTorch (CUDA or CPU version).
4
  torch>=1.12
5
 
6
+ # Install MONAI Generative first
 
7
  monai-generative
8
 
9
+ # Now force reinstall MONAI Weekly so its (newer) MONAI version takes precedence
10
+ --force-reinstall
11
+ monai-weekly
12
+
13
  # For perceptual losses in MONAI's generative module.
14
  lpips
15
 
 
20
  tqdm
21
  tensorboard
22
  matplotlib
23
+ datasets
24
+ scikit-learn
brain2vec.py → train_brain2vec.py RENAMED
@@ -1,35 +1,20 @@
1
- # MIT License
2
-
3
- # Copyright (c) 2025
4
-
5
- # Permission is hereby granted, free of charge, to any person obtaining a copy
6
- # of this software and associated documentation files (the "Software"), to deal
7
- # in the Software without restriction, including without limitation the rights
8
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- # copies of the Software, and to permit persons to whom the Software is
10
- # furnished to do so, subject to the following conditions:
11
-
12
- # The above copyright notice and this permission notice shall be included in all
13
- # copies or substantial portions of the Software.
14
-
15
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- # SOFTWARE.
22
-
23
- # Forked from: https://github.com/LemuelPuglisi/BrLP
24
-
25
- # @inproceedings{puglisi2024enhancing,
26
- # title={Enhancing spatiotemporal disease progression models via latent diffusion and prior knowledge},
27
- # author={Puglisi, Lemuel and Alexander, Daniel C and Rav{\`\i}, Daniele},
28
- # booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
29
- # pages={173--183},
30
- # year={2024},
31
- # organization={Springer}
32
- # }
33
 
34
  import os
35
  os.environ["PYTORCH_WEIGHTS_ONLY"] = "False"
@@ -37,7 +22,6 @@ from typing import Optional, Union
37
  import pandas as pd
38
  import argparse
39
  import numpy as np
40
-
41
  import warnings
42
  import torch
43
  import torch.nn as nn
@@ -47,7 +31,6 @@ from torch.nn import L1Loss
47
  from torch.utils.data import DataLoader
48
  from torch.amp import autocast
49
  from torch.amp import GradScaler
50
-
51
  from generative.networks.nets import (
52
  AutoencoderKL,
53
  PatchDiscriminator,
@@ -65,13 +48,11 @@ torch.serialization.add_safe_globals([_reconstruct])
65
  torch.serialization.add_safe_globals([MetaTensor])
66
  torch.serialization.add_safe_globals([ndarray])
67
  torch.serialization.add_safe_globals([dtype])
68
-
69
  from tqdm import tqdm
70
  import matplotlib.pyplot as plt
71
-
72
  from torch.utils.tensorboard import SummaryWriter
73
 
74
- # choosen resolution
75
  RESOLUTION = 2
76
 
77
  # shape of the MNI152 (1mm^3) template
@@ -101,10 +82,7 @@ def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
101
  """
102
  if checkpoints_path is not None:
103
  assert os.path.exists(checkpoints_path), 'Invalid path'
104
- # Using context manager to allow MetaTensor
105
- #with torch.serialization.safe_globals([MetaTensor]):
106
  network.load_state_dict(torch.load(checkpoints_path))
107
- #network.load_state_dict(torch.load(checkpoints_path, map_location='cpu'))
108
  return network
109
 
110
 
@@ -140,7 +118,7 @@ def init_patch_discriminator(checkpoints_path: Optional[str] = None) -> nn.Modul
140
  checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.
141
 
142
  Returns:
143
- nn.Module: the parch discriminator
144
  """
145
  patch_discriminator = PatchDiscriminator(spatial_dims=3,
146
  num_layers_d=3,
@@ -387,22 +365,6 @@ def train(
387
  train_df = dataset_df[dataset_df.split == 'train']
388
  trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
389
 
390
- print(f"[DEBUG] Using cache_dir={cache_dir}")
391
- print(f"[DEBUG] trainset length={len(trainset)}")
392
-
393
- try:
394
- sample_debug = trainset[0] # Force a transform on the first record
395
- print("[DEBUG] Successfully loaded sample 0 from trainset.")
396
- except Exception as e:
397
- print("[DEBUG] Error loading sample 0:", e)
398
-
399
- import glob
400
-
401
- hashfiles = glob.glob(os.path.join(cache_dir, "*.pt"))
402
- print(f"[DEBUG] Found {len(hashfiles)} cached .pt files in {cache_dir}")
403
- if hashfiles:
404
- print("[DEBUG] Example cache file:", hashfiles[0])
405
-
406
  train_loader = DataLoader(
407
  dataset=trainset,
408
  num_workers=num_workers,
@@ -523,60 +485,11 @@ def train(
523
  print("Training completed and models saved.")
524
 
525
 
526
- def inference(
527
- dataset_csv: str,
528
- aekl_ckpt: str,
529
- output_dir: str,
530
- device: str = ('cuda' if torch.cuda.is_available() else
531
- 'cpu'),
532
- ) -> None:
533
- """
534
- Perform inference to encode images into latent space.
535
-
536
- Args:
537
- dataset_csv (str): Path to the dataset CSV file.
538
- aekl_ckpt (str): Path to the autoencoder checkpoint.
539
- output_dir (str): Directory to save latent representations.
540
- device (str, optional): Device to run the inference on. Defaults to 'cuda' if available.
541
- """
542
- DEVICE = device
543
-
544
- autoencoder = init_autoencoder(aekl_ckpt).to(DEVICE).eval()
545
-
546
- transforms_fn = transforms.Compose([
547
- transforms.CopyItemsD(keys={'image_path'}, names=['image']),
548
- transforms.LoadImageD(image_only=True, keys=['image']),
549
- transforms.EnsureChannelFirstD(keys=['image']),
550
- transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
551
- transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
552
- transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image'])
553
- ])
554
-
555
- df = pd.read_csv(dataset_csv)
556
-
557
- os.makedirs(output_dir, exist_ok=True)
558
-
559
- with torch.no_grad():
560
- for image_path in tqdm(df.image_path, total=len(df)):
561
- destpath = os.path.join(
562
- output_dir,
563
- os.path.basename(image_path).replace('.nii.gz', '_embeddings.npz').replace('.nii', '_embeddings.npz')
564
- )
565
- if os.path.exists(destpath):
566
- continue
567
- mri_tensor = transforms_fn({'image_path': image_path})['image'].to(DEVICE)
568
- mri_latent, _ = autoencoder.encode(mri_tensor.unsqueeze(0))
569
- mri_latent = mri_latent.cpu().squeeze(0).numpy()
570
- np.savez_compressed(destpath, data=mri_latent)
571
-
572
- print("Inference completed and latent representations saved.")
573
-
574
-
575
  def main():
576
  """
577
- Main function to parse command-line arguments and execute training or inference.
578
  """
579
- parser = argparse.ArgumentParser(description="brain2vec Training and Inference Script")
580
 
581
  subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-commands: train or infer')
582
 
@@ -594,12 +507,6 @@ def main():
594
  train_parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
595
  train_parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
596
 
597
- # Inference Subparser
598
- infer_parser = subparsers.add_parser('inference', help='Run inference to encode images.')
599
- infer_parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
600
- infer_parser.add_argument('--aekl_ckpt', type=str, required=True, help='Path to the autoencoder checkpoint.')
601
- infer_parser.add_argument('--output_dir', type=str, required=True, help='Directory to save latent representations.')
602
-
603
  args = parser.parse_args()
604
 
605
  if args.command == 'train':
 
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ train_brain2vec.py
5
+
6
+ Trains a 3D VAE-based Brain2Vec model using MONAI. This script implements
7
+ autoencoder training with adversarial loss (via a patch discriminator),
8
+ a perceptual loss, and KL divergence regularization for robust latent
9
+ representations.
10
+
11
+ Example usage:
12
+ python train_brain2vec.py train \
13
+ --dataset_csv /path/to/dataset.csv \
14
+ --cache_dir /path/to/cache \
15
+ --output_dir /path/to/output_dir \
16
+ --n_epochs 10
17
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  import os
20
  os.environ["PYTORCH_WEIGHTS_ONLY"] = "False"
 
22
  import pandas as pd
23
  import argparse
24
  import numpy as np
 
25
  import warnings
26
  import torch
27
  import torch.nn as nn
 
31
  from torch.utils.data import DataLoader
32
  from torch.amp import autocast
33
  from torch.amp import GradScaler
 
34
  from generative.networks.nets import (
35
  AutoencoderKL,
36
  PatchDiscriminator,
 
48
  torch.serialization.add_safe_globals([MetaTensor])
49
  torch.serialization.add_safe_globals([ndarray])
50
  torch.serialization.add_safe_globals([dtype])
 
51
  from tqdm import tqdm
52
  import matplotlib.pyplot as plt
 
53
  from torch.utils.tensorboard import SummaryWriter
54
 
55
+ # voxel resolution
56
  RESOLUTION = 2
57
 
58
  # shape of the MNI152 (1mm^3) template
 
82
  """
83
  if checkpoints_path is not None:
84
  assert os.path.exists(checkpoints_path), 'Invalid path'
 
 
85
  network.load_state_dict(torch.load(checkpoints_path))
 
86
  return network
87
 
88
 
 
118
  checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.
119
 
120
  Returns:
121
+ nn.Module: the patch discriminator
122
  """
123
  patch_discriminator = PatchDiscriminator(spatial_dims=3,
124
  num_layers_d=3,
 
365
  train_df = dataset_df[dataset_df.split == 'train']
366
  trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  train_loader = DataLoader(
369
  dataset=trainset,
370
  num_workers=num_workers,
 
485
  print("Training completed and models saved.")
486
 
487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  def main():
489
  """
490
+ Main function to parse command-line arguments and execute training.
491
  """
492
+ parser = argparse.ArgumentParser(description="brain2vec Training Script")
493
 
494
  subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-commands: train or infer')
495
 
 
507
  train_parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
508
  train_parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
509
 
 
 
 
 
 
 
510
  args = parser.parse_args()
511
 
512
  if args.command == 'train':