Code changes
- README.md +45 -44
- inference_brain2vec.py +6 -7
- train_brain2vec.py +31 -42
README.md
CHANGED

````diff
@@ -13,23 +13,29 @@ pretty_name: 3D Brain Structure MRI Autoencoder
 
 ## 🧠 Model Summary
 # brain2vec
-An autoencoder model for brain structure T1 MRIs
+An autoencoder model for brain structure T1 MRIs (forked from [Brain Latent Progression](https://github.com/LemuelPuglisi/BrLP/tree/main)). The autoencoder takes a 3D MRI NIfTI file and compresses it to 1200 latent dimensions before reconstructing the image. The loss functions for training the autoencoder are:
 - [L1Loss](https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html)
 - [KLDivergenceLoss](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html)
 - [PatchAdversarialLoss](https://docs.monai.io/en/stable/losses.html#patchadversarialloss)
 - [PerceptualLoss](https://docs.monai.io/en/stable/losses.html#perceptualloss)
 
 
-
 # Training data
 [Radiata brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure): 3066 scans from 2085 individuals in the 'train' split. Mean age = 45.1 +- 24.5, including 2847 scans from cognitively normal subjects and 219 scans from individuals with an Alzheimer's disease clinical diagnosis.
 
+
 # Example usage
 ```
 # get brain2vec model repository
 git clone https://huggingface.co/radiata-ai/brain2vec
 cd brain2vec
 
+# pull pre-trained model weights
+sudo apt-get update
+sudo apt install git-lfs
+git lfs install
+git lfs pull
+
 # set up virtual environment
 python3 -m venv venv_brain2vec
 source venv_brain2vec/bin/activate
@@ -38,54 +44,54 @@ source venv_brain2vec/bin/activate
 pip install -r requirements.txt
 
 # create the csv file inputs.csv listing the scan paths and other info
-# this script loads the radiata-ai/brain-structure dataset
+# this script loads the radiata-ai/brain-structure dataset from Hugging Face
 python create_csv.py
 
 mkdir ae_cache
 mkdir ae_output
 
-# install git lfs to pull large model weights
-sudo apt-get update
-sudo apt install git-lfs
-git lfs install
-git lfs pull
-
 # train the model
-nohup python
-    --dataset_csv
+nohup python train_brain2vec.py \
+    --dataset_csv inputs.csv \
     --cache_dir ./ae_cache \
     --output_dir ./ae_output \
     --n_epochs 10 \
     > train_log.txt 2>&1 &
 
 # model inference
+# for a set of scans in inputs.csv
+python inference_brain2vec.py \
+    --checkpoint_path /path/to/model.pth \
+    --csv_input inputs.csv \
+    --output_dir ./ae_output \
+    --embeddings_filename ae_embeddings_all.npy
+
+# or for individual scans
 python inference_brain2vec.py \
     --checkpoint_path /path/to/model.pth \
     --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
-    --output_dir ./
-    --embeddings_filename
-    --save_recons
+    --output_dir ./ae_output \
+    --embeddings_filename ae_embeddings_2.npy
 ```
 
 # Methods
 Input scan image dimensions are 113x137x113, 1.5mm^3 resolution, aligned to MNI152 space (see [radiata-ai/brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure)).
 
-The image transform crops to 80 x 96 x 80, 2mm^3 resolution, and scales image intensity to range [0,1].
+The image transform crops to 80 x 96 x 80, 2mm^3 resolution, and scales image intensity to the range [0,1].
 
-10 epochs
-max_batch_size: int = 2,
-batch_size: int = 16,
-lr: float = 1e-4,
+The model was trained for 10 epochs with an effective batch size of 16 and a learning rate of 1e-4 (see references 1 and 2).
 
 # References
-Puglisi
-Pinaya
+1. Puglisi L, Alexander DC, Ravì D. Enhancing Spatiotemporal Disease Progression Models via Latent Diffusion and Prior Knowledge [Internet]. arXiv; 2024. Available from: http://arxiv.org/abs/2405.03328
+2. Pinaya WHL, Tudosiu PD, Dafflon J, Costa PF da, Fernandez V, Nachev P, et al. Brain Imaging Generation with Latent Diffusion Models [Internet]. arXiv; 2022. Available from: http://arxiv.org/abs/2209.07162
 
 # Citation
 ```
-@misc{Radiata-
+@misc{Radiata-Brain2vec,
   author = {Jesse Brown and Clayton Young},
-  title = {
+  title = {Brain2vec: An Autoencoder Model for Brain Structure T1 MRIs},
   year = {2025},
   url = {https://huggingface.co/radiata-ai/brain2vec},
   note = {Version 1.0},
@@ -93,25 +99,20 @@ Pinaya
 }
 ```
 
+
 # License
-
-
-Copyright
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+### Apache License 2.0
+
+Copyright 2025 Jesse Brown
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at:
+
+[http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
````
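The Model Summary above lists four losses that are combined into a single generator objective during training. Below is a minimal sketch of how they can be weighted and summed with MONAI's loss classes; the weights `kl_weight`, `adv_weight`, and `perc_weight` are illustrative assumptions, not the repository's tuned values:

```python
# Minimal sketch of combining the four README losses (weights are illustrative).
import torch
from monai.losses import PatchAdversarialLoss, PerceptualLoss

l1_loss = torch.nn.L1Loss()
adv_loss = PatchAdversarialLoss(criterion="least_squares")
perceptual_loss = PerceptualLoss(spatial_dims=3, network_type="squeeze", is_fake_3d=True)

def autoencoder_loss(recon, image, z_mu, z_sigma, logits_fake,
                     kl_weight=1e-7, adv_weight=0.025, perc_weight=0.001):
    # Reconstruction term: voxelwise L1 between input and reconstruction
    rec = l1_loss(recon.float(), image.float())
    # KL divergence of the posterior N(z_mu, z_sigma) from the N(0, I) prior
    kl = 0.5 * torch.mean(
        z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2) + 1e-8) - 1
    )
    # Adversarial term: fool the patch discriminator
    adv = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
    # Perceptual term: feature-space similarity
    perc = perceptual_loss(recon.float(), image.float())
    return rec + kl_weight * kl + adv_weight * adv + perc_weight * perc
```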
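For programmatic use, the same inference path can be driven from Python instead of the CLI. A minimal sketch, assuming `Brain2vec` and `preprocess_mri` are importable from `inference_brain2vec.py` as that script's own code suggests:

```python
# Sketch: load the pretrained autoencoder and embed one scan.
import torch
from inference_brain2vec import Brain2vec, preprocess_mri

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Brain2vec.from_pretrained(
    checkpoint_path="/path/to/model.pth",  # same checkpoint as the CLI example
    device=device,
)

img = preprocess_mri("/path/to/img1.nii.gz", device=device)
with torch.no_grad():
    recon, z_mu, z_sigma = model.forward(img)

embedding = z_mu.flatten()  # 1200-dim latent vector per the Model Summary
```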
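The Methods section describes a crop-and-rescale transform. One way to express that preprocessing with MONAI's dictionary transforms is sketched below; this is an assumption about the pipeline's shape, not necessarily the repository's exact transform chain:

```python
# Sketch: Methods-style preprocessing (2mm^3 resample, 80x96x80 crop, [0,1] scaling).
from monai import transforms

preprocess = transforms.Compose([
    transforms.LoadImaged(keys=["image"]),                        # read the NIfTI file
    transforms.EnsureChannelFirstd(keys=["image"]),               # -> (C, H, W, D)
    transforms.Spacingd(keys=["image"], pixdim=(2.0, 2.0, 2.0)),  # resample to 2mm^3
    transforms.ResizeWithPadOrCropd(keys=["image"], spatial_size=(80, 96, 80)),
    transforms.ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0),
])

sample = preprocess({"image": "/path/to/img1.nii.gz"})
print(sample["image"].shape)  # (1, 80, 96, 80)
```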
inference_brain2vec.py
CHANGED

````diff
@@ -143,10 +143,6 @@ def main() -> None:
         "--output_dir", type=str, default="./vae_inference_outputs",
         help="Directory to save reconstructions and latent parameters."
     )
-    parser.add_argument(
-        "--device", type=str, default="cpu",
-        help="Device to run inference on ('cpu', 'cuda', etc.)."
-    )
     # Two ways to supply images: multiple file paths or a CSV
     parser.add_argument(
         "--input_images", type=str, nargs="*",
@@ -172,10 +168,13 @@ def main() -> None:
 
     os.makedirs(args.output_dir, exist_ok=True)
 
-    #
+    # Select the device automatically: use the GPU when one is available
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Pass the selected device to the model
     model = Brain2vec.from_pretrained(
         checkpoint_path=args.checkpoint_path,
-        device=args.device
+        device=device
     )
 
     # Gather image paths
@@ -199,7 +198,7 @@ def main() -> None:
             raise FileNotFoundError(f"Image not found: {img_path}")
 
         print(f"[INFO] Processing image {i}: {img_path}")
-        img_tensor = preprocess_mri(img_path, device=args.device)
+        img_tensor = preprocess_mri(img_path, device=device)
 
         with torch.no_grad():
             recon, z_mu, z_sigma = model.forward(img_tensor)
````
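With the `--device` flag gone, inference now selects the GPU automatically whenever `torch.cuda.is_available()` returns true. To force CPU inference, hide the GPUs from PyTorch instead, e.g. `CUDA_VISIBLE_DEVICES="" python inference_brain2vec.py ...`.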
train_brain2vec.py
CHANGED

````diff
@@ -9,10 +9,10 @@ a perceptual loss, and KL divergence regularization for robust latent
 representations.
 
 Example usage:
-    python train_brain2vec.py
-        --dataset_csv
-        --cache_dir
-        --output_dir
+    python train_brain2vec.py \
+        --dataset_csv inputs.csv \
+        --cache_dir ./ae_cache \
+        --output_dir ./ae_output \
         --n_epochs 10
 """
 
@@ -487,50 +487,39 @@ def train(
 
 def main():
     """
-    Main function to parse command-line arguments and
+    Main function to parse command-line arguments and run train().
     """
+    import argparse
+
     parser = argparse.ArgumentParser(description="brain2vec Training Script")
 
-    train_parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
-    train_parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
-    train_parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
-    train_parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
+    parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
+    parser.add_argument('--cache_dir', type=str, required=True, help='Directory for caching data.')
+    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save model checkpoints.')
+    parser.add_argument('--aekl_ckpt', type=str, default=None, help='Path to the autoencoder checkpoint.')
+    parser.add_argument('--disc_ckpt', type=str, default=None, help='Path to the discriminator checkpoint.')
+    parser.add_argument('--num_workers', type=int, default=8, help='Number of data loader workers.')
+    parser.add_argument('--n_epochs', type=int, default=5, help='Number of training epochs.')
+    parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
+    parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
+    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
+    parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
 
     args = parser.parse_args()
 
-    )
-    elif args.command == 'infer':
-        inference(
-            dataset_csv=args.dataset_csv,
-            aekl_ckpt=args.aekl_ckpt,
-            output_dir=args.output_dir,
-        )
-    else:
-        parser.print_help()
+    train(
+        dataset_csv=args.dataset_csv,
+        cache_dir=args.cache_dir,
+        output_dir=args.output_dir,
+        aekl_ckpt=args.aekl_ckpt,
+        disc_ckpt=args.disc_ckpt,
+        num_workers=args.num_workers,
+        n_epochs=args.n_epochs,
+        max_batch_size=args.max_batch_size,
+        batch_size=args.batch_size,
+        lr=args.lr,
+        aug_p=args.aug_p,
+    )
 
 
 if __name__ == '__main__':
````
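The `--max_batch_size`/`--batch_size` pair implies gradient accumulation: the loader yields mini-batches of `max_batch_size`, and gradients accumulate until the effective `batch_size` is reached (8 accumulation steps with the defaults of 2 and 16). A minimal sketch of that pattern, with a generic `model` and `loss_fn` standing in for the script's actual training loop:

```python
# Illustrative gradient-accumulation loop; not the script's actual implementation.
import torch

def train_epoch(model, loader, optimizer, loss_fn,
                batch_size: int = 16, max_batch_size: int = 2) -> None:
    steps = max(1, batch_size // max_batch_size)  # 8 with the defaults
    optimizer.zero_grad()
    for i, images in enumerate(loader):  # each mini-batch holds max_batch_size scans
        recon = model(images)
        # Scale so the accumulated gradient matches one large batch
        loss = loss_fn(recon, images) / steps
        loss.backward()
        if (i + 1) % steps == 0:  # step once per effective batch
            optimizer.step()
            optimizer.zero_grad()
```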