English · medical · brain-data · mri
jesseab committed · Commit bef8312 · 1 Parent(s): 46a0841

Code changes

Files changed (3):
  1. README.md +45 -44
  2. inference_brain2vec.py +6 -7
  3. train_brain2vec.py +31 -42
README.md CHANGED
@@ -13,23 +13,29 @@ pretty_name: 3D Brain Structure MRI Autoencoder
 
 ## 🧠 Model Summary
 # brain2vec
-An autoencoder model for brain structure T1 MRIs based on [Brain Latent Progression](https://github.com/LemuelPuglisi/BrLP/tree/main). The autoencoder takes in a 3d MRI NIfTI file and compresses to 1200 latent dimensions before reconstructing the image. The loss functions for training the autoencoder are:
+An autoencoder model for brain structure T1 MRIs (forked from [Brain Latent Progression](https://github.com/LemuelPuglisi/BrLP/tree/main)). The autoencoder takes a 3D MRI NIfTI file and compresses it to 1200 latent dimensions before reconstructing the image. The loss functions for training the autoencoder are:
 - [L1Loss](https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html)
 - [KLDivergenceLoss](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html)
 - [PatchAdversarialLoss](https://docs.monai.io/en/stable/losses.html#patchadversarialloss)
 - [PerceptualLoss](https://docs.monai.io/en/stable/losses.html#perceptualloss)
 
 
-
 # Training data
 [Radiata brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure): 3066 scans from 2085 individuals in the 'train' split. Mean age = 45.1 +- 24.5, including 2847 scans from cognitively normal subjects and 219 scans from individuals with an Alzheimer's disease clinical diagnosis.
 
+
 # Example usage
 ```
 # get brain2vec model repository
 git clone https://huggingface.co/radiata-ai/brain2vec
 cd brain2vec
 
+# pull pre-trained model weights
+sudo apt-get update
+sudo apt install git-lfs
+git lfs install
+git lfs pull
+
 # set up virtual environment
 python3 -m venv venv_brain2vec
 source venv_brain2vec/bin/activate
@@ -38,54 +44,54 @@ source venv_brain2vec/bin/activate
 pip install -r requirements.txt
 
 # create the csv file inputs.csv listing the scan paths and other info
-# this script loads the radiata-ai/brain-structure dataset
+# this script loads the radiata-ai/brain-structure dataset from Hugging Face
 python create_csv.py
 
 mkdir ae_cache
 mkdir ae_output
 
-# install git lfs to pull large model weights
-sudo apt-get update
-sudo apt install git-lfs
-git lfs install
-git lfs pull
-
 # train the model
-nohup python brain2vec.py train \
-    --dataset_csv /home/ubuntu/brain2vec/inputs.csv \
+nohup python train_brain2vec.py \
+    --dataset_csv inputs.csv \
     --cache_dir ./ae_cache \
    --output_dir ./ae_output \
     --n_epochs 10 \
     > train_log.txt 2>&1 &
 
 # model inference
+# for a set of scans in inputs.csv
+python inference_brain2vec.py \
+    --checkpoint_path /path/to/model.pth \
+    --csv_input inputs.csv \
+    --output_dir ./ae_output \
+    --embeddings_filename ae_embeddings_all.npy
+
+# or for individual scans
 python inference_brain2vec.py \
     --checkpoint_path /path/to/model.pth \
     --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
-    --output_dir ./vae_inference_outputs \
-    --embeddings_filename pca_output/pca_embeddings_2.npy \
-    --save_recons
+    --output_dir ./ae_output \
+    --embeddings_filename ae_embeddings_2.npy
 ```
 
 # Methods
 Input scan image dimensions are 113x137x113, 1.5mm^3 resolution, aligned to MNI152 space (see [radiata-ai/brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure)).
 
-The image transform crops to 80 x 96 x 80, 2mm^3 resolution, and scales image intensity to range [0,1]. Images are flattened to 614400-length 1D vectors.
+The image transform crops to 80 x 96 x 80 at 2mm^3 resolution and scales image intensity to the range [0,1].
+
+The model was trained for 10 epochs with an effective batch size of 16 and a learning rate of 1e-4 (see references 1 and 2).
 
-10 epochs
-max_batch_size: int = 2,
-batch_size: int = 16,
-lr: float = 1e-4,
 
 # References
-Puglisi
-Pinaya
+1. Puglisi L, Alexander DC, Ravì D. Enhancing Spatiotemporal Disease Progression Models via Latent Diffusion and Prior Knowledge [Internet]. arXiv; 2024. Available from: http://arxiv.org/abs/2405.03328
+2. Pinaya WHL, Tudosiu PD, Dafflon J, da Costa PF, Fernandez V, Nachev P, et al. Brain Imaging Generation with Latent Diffusion Models [Internet]. arXiv; 2022. Available from: http://arxiv.org/abs/2209.07162
+
 
 # Citation
 ```
-@misc{Radiata-Brain2Vec,
+@misc{Radiata-Brain2vec,
   author = {Jesse Brown and Clayton Young},
-  title = {brain2vec_PCA: A VAE Model for Brain Structure T1 MRIs},
+  title = {Brain2vec: An Autoencoder Model for Brain Structure T1 MRIs},
   year = {2025},
   url = {https://huggingface.co/radiata-ai/brain2vec},
   note = {Version 1.0},
@@ -93,25 +99,20 @@ Pinaya
 }
 ```
 
+
 # License
-MIT License
-
-Copyright (c) 2025
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+### Apache License 2.0
+
+Copyright 2025 Jesse Brown
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at:
+
+[http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
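The updated Model Summary lists four losses that are combined into a single training objective. As a point of reference, here is a minimal sketch of how such a composite loss is typically assembled with PyTorch and recent MONAI releases; the weight values (`kl_w`, `adv_w`, `perc_w`) and the closed-form KL term are illustrative assumptions, not the repository's exact code.

```python
# Sketch only: how the four losses above are commonly combined for a
# KL-regularized adversarial autoencoder. Weights are assumed, not the
# values used to train brain2vec.
import torch
from monai.losses import PatchAdversarialLoss, PerceptualLoss

l1_loss = torch.nn.L1Loss()
adv_loss = PatchAdversarialLoss(criterion="least_squares")
perc_loss = PerceptualLoss(spatial_dims=3, network_type="squeeze", is_fake_3d=True)

def generator_loss(recon, image, z_mu, z_sigma, logits_fake,
                   kl_w=1e-7, adv_w=0.025, perc_w=0.001):
    # Closed-form KL divergence between N(z_mu, z_sigma^2) and N(0, 1),
    # as used for VAE-style latents (rather than torch.nn.KLDivLoss)
    kl = 0.5 * torch.mean(z_mu.pow(2) + z_sigma.pow(2)
                          - torch.log(z_sigma.pow(2)) - 1)
    rec = l1_loss(recon, image)                     # voxelwise reconstruction
    per = perc_loss(recon.float(), image.float())   # perceptual similarity
    # logits_fake comes from running the discriminator on recon
    gen = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
    return rec + kl_w * kl + perc_w * per + adv_w * gen
```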
 
 
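The Methods section states that the transform resamples to 2mm^3, crops to 80 x 96 x 80, and rescales intensity to [0,1]. A sketch of that pipeline using standard MONAI transforms follows; the specific transform names and their ordering are assumptions, since the repository's own pipeline lives in the training script.

```python
# Illustrative preprocessing sketch matching the Methods description;
# the repo's actual pipeline may differ.
from monai import transforms

preprocess = transforms.Compose([
    transforms.LoadImage(image_only=True, ensure_channel_first=True),
    transforms.Spacing(pixdim=(2.0, 2.0, 2.0)),            # 2mm^3 voxels
    transforms.ResizeWithPadOrCrop(spatial_size=(80, 96, 80)),
    transforms.ScaleIntensity(minv=0.0, maxv=1.0),         # intensities in [0, 1]
])

img = preprocess("/path/to/img1.nii.gz")  # -> tensor of shape (1, 80, 96, 80)
```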
 
 
 
 
inference_brain2vec.py CHANGED
@@ -143,10 +143,6 @@ def main() -> None:
         "--output_dir", type=str, default="./vae_inference_outputs",
         help="Directory to save reconstructions and latent parameters."
     )
-    parser.add_argument(
-        "--device", type=str, default="cpu",
-        help="Device to run inference on ('cpu', 'cuda', etc.)."
-    )
     # Two ways to supply images: multiple file paths or a CSV
     parser.add_argument(
         "--input_images", type=str, nargs="*",
@@ -172,10 +168,13 @@ def main() -> None:
 
     os.makedirs(args.output_dir, exist_ok=True)
 
-    # Load the model
+    # Pick the device automatically: use the GPU when one is available
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Load the model onto that device
     model = Brain2vec.from_pretrained(
         checkpoint_path=args.checkpoint_path,
-        device=args.device
+        device=device
     )
 
     # Gather image paths
@@ -199,7 +198,7 @@ def main() -> None:
             raise FileNotFoundError(f"Image not found: {img_path}")
 
         print(f"[INFO] Processing image {i}: {img_path}")
-        img_tensor = preprocess_mri(img_path, device=args.device)
+        img_tensor = preprocess_mri(img_path, device=device)
 
         with torch.no_grad():
             recon, z_mu, z_sigma = model.forward(img_tensor)
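After inference, the `--embeddings_filename` output is a NumPy array on disk. A quick sanity check is sketched below; the `(n_scans, 1200)` shape is inferred from the 1200 latent dimensions in the Model Summary, and the assumption that the file lands under `--output_dir` is mine, not read from the script.

```python
# Hypothetical sanity check of the embeddings saved by inference_brain2vec.py.
import numpy as np

emb = np.load("ae_output/ae_embeddings_2.npy")
print(emb.shape)  # expected: (2, 1200) for two input images, if one row per scan
print(emb.dtype)
```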
train_brain2vec.py CHANGED
@@ -9,10 +9,10 @@ a perceptual loss, and KL divergence regularization for robust latent
 representations.
 
 Example usage:
-    python train_brain2vec.py train \
-        --dataset_csv /path/to/dataset.csv \
-        --cache_dir /path/to/cache \
-        --output_dir /path/to/output_dir \
+    python train_brain2vec.py \
+        --dataset_csv inputs.csv \
+        --cache_dir ./ae_cache \
+        --output_dir ./ae_output \
         --n_epochs 10
 """
 
@@ -487,50 +487,39 @@ def train(
 
 def main():
     """
-    Main function to parse command-line arguments and execute training.
+    Main function to parse command-line arguments and run train().
     """
+    import argparse
+
     parser = argparse.ArgumentParser(description="brain2vec Training Script")
 
-    subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-commands: train or infer')
-
-    # Training Subparser
-    train_parser = subparsers.add_parser('train', help='Train the models.')
-    train_parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
-    train_parser.add_argument('--cache_dir', type=str, required=True, help='Directory for caching data.')
-    train_parser.add_argument('--output_dir', type=str, required=True, help='Directory to save model checkpoints.')
-    train_parser.add_argument('--aekl_ckpt', type=str, default=None, help='Path to the autoencoder checkpoint.')
-    train_parser.add_argument('--disc_ckpt', type=str, default=None, help='Path to the discriminator checkpoint.')
-    train_parser.add_argument('--num_workers', type=int, default=8, help='Number of data loader workers.')
-    train_parser.add_argument('--n_epochs', type=int, default=5, help='Number of training epochs.')
-    train_parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
-    train_parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
-    train_parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
-    train_parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
+    parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
+    parser.add_argument('--cache_dir', type=str, required=True, help='Directory for caching data.')
+    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save model checkpoints.')
+    parser.add_argument('--aekl_ckpt', type=str, default=None, help='Path to the autoencoder checkpoint.')
+    parser.add_argument('--disc_ckpt', type=str, default=None, help='Path to the discriminator checkpoint.')
+    parser.add_argument('--num_workers', type=int, default=8, help='Number of data loader workers.')
+    parser.add_argument('--n_epochs', type=int, default=5, help='Number of training epochs.')
+    parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
+    parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
+    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
+    parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
 
     args = parser.parse_args()
 
-    if args.command == 'train':
-        train(
-            dataset_csv=args.dataset_csv,
-            cache_dir=args.cache_dir,
-            output_dir=args.output_dir,
-            aekl_ckpt=args.aekl_ckpt,
-            disc_ckpt=args.disc_ckpt,
-            num_workers=args.num_workers,
-            n_epochs=args.n_epochs,
-            max_batch_size=args.max_batch_size,
-            batch_size=args.batch_size,
-            lr=args.lr,
-            aug_p=args.aug_p,
-        )
-    elif args.command == 'infer':
-        inference(
-            dataset_csv=args.dataset_csv,
-            aekl_ckpt=args.aekl_ckpt,
-            output_dir=args.output_dir,
-        )
-    else:
-        parser.print_help()
+    train(
+        dataset_csv=args.dataset_csv,
+        cache_dir=args.cache_dir,
+        output_dir=args.output_dir,
+        aekl_ckpt=args.aekl_ckpt,
+        disc_ckpt=args.disc_ckpt,
+        num_workers=args.num_workers,
+        n_epochs=args.n_epochs,
+        max_batch_size=args.max_batch_size,
+        batch_size=args.batch_size,
+        lr=args.lr,
+        aug_p=args.aug_p,
+    )
 
 
 if __name__ == '__main__':
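The `--max_batch_size` / `--batch_size` pair above encodes gradient accumulation: micro-batches of `max_batch_size` are accumulated until the effective `batch_size` is reached, which is how the README's "effective batch size of 16" fits on a GPU that only holds 2 volumes at a time. A minimal, self-contained sketch of that pattern follows; it uses a generic model and loss, not the repository's training loop.

```python
# Gradient-accumulation sketch: reach an effective batch size of 16
# with micro-batches of 2 by stepping the optimizer every 8 iterations.
# Generic illustration only; train_brain2vec.py's loop may differ.
import torch

def train_one_epoch(model, loader, optimizer, loss_fn,
                    batch_size=16, max_batch_size=2):
    accum_steps = batch_size // max_batch_size  # 8 micro-batches per step
    optimizer.zero_grad()
    for i, x in enumerate(loader):              # loader yields micro-batches of 2
        recon = model(x)
        loss = loss_fn(recon, x) / accum_steps  # scale so gradients average out
        loss.backward()                          # gradients accumulate in .grad
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```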