Spaces:
Sleeping
Sleeping
西牧慧
commited on
Commit
·
dcbe128
1
Parent(s):
fe59750
update: parcellation
Browse files- README.md +1 -1
- requirements.txt +6 -0
- src/parcellation.py +128 -80
- src/utils/cropping.py +2 -2
- src/utils/functions.py +3 -2
- src/utils/load_model.py +15 -21
- src/utils/parcellation.py +10 -10
- src/utils/stripping.py +2 -2
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title: OpenMAP
|
3 |
emoji: 📚
|
4 |
colorFrom: gray
|
5 |
colorTo: red
|
|
|
1 |
---
|
2 |
+
title: OpenMAP-T1
|
3 |
emoji: 📚
|
4 |
colorFrom: gray
|
5 |
colorTo: red
|
requirements.txt
CHANGED
@@ -4,9 +4,12 @@ anyio==4.9.0
|
|
4 |
certifi==2025.1.31
|
5 |
charset-normalizer==3.4.1
|
6 |
click==8.1.8
|
|
|
|
|
7 |
fastapi==0.115.12
|
8 |
ffmpy==0.5.0
|
9 |
filelock==3.18.0
|
|
|
10 |
fsspec==2025.3.2
|
11 |
gradio==5.25.0
|
12 |
gradio_client==1.8.0
|
@@ -17,8 +20,10 @@ httpx==0.28.1
|
|
17 |
huggingface-hub==0.30.2
|
18 |
idna==3.10
|
19 |
Jinja2==3.1.6
|
|
|
20 |
markdown-it-py==3.0.0
|
21 |
MarkupSafe==3.0.2
|
|
|
22 |
mdurl==0.1.2
|
23 |
mpmath==1.3.0
|
24 |
networkx==3.4.2
|
@@ -32,6 +37,7 @@ pydantic==2.11.3
|
|
32 |
pydantic_core==2.33.1
|
33 |
pydub==0.25.1
|
34 |
Pygments==2.19.1
|
|
|
35 |
python-dateutil==2.9.0.post0
|
36 |
python-multipart==0.0.20
|
37 |
pytz==2025.2
|
|
|
4 |
certifi==2025.1.31
|
5 |
charset-normalizer==3.4.1
|
6 |
click==8.1.8
|
7 |
+
contourpy==1.3.1
|
8 |
+
cycler==0.12.1
|
9 |
fastapi==0.115.12
|
10 |
ffmpy==0.5.0
|
11 |
filelock==3.18.0
|
12 |
+
fonttools==4.57.0
|
13 |
fsspec==2025.3.2
|
14 |
gradio==5.25.0
|
15 |
gradio_client==1.8.0
|
|
|
20 |
huggingface-hub==0.30.2
|
21 |
idna==3.10
|
22 |
Jinja2==3.1.6
|
23 |
+
kiwisolver==1.4.8
|
24 |
markdown-it-py==3.0.0
|
25 |
MarkupSafe==3.0.2
|
26 |
+
matplotlib==3.10.1
|
27 |
mdurl==0.1.2
|
28 |
mpmath==1.3.0
|
29 |
networkx==3.4.2
|
|
|
37 |
pydantic_core==2.33.1
|
38 |
pydub==0.25.1
|
39 |
Pygments==2.19.1
|
40 |
+
pyparsing==3.2.3
|
41 |
python-dateutil==2.9.0.post0
|
42 |
python-multipart==0.0.20
|
43 |
pytz==2025.2
|
src/parcellation.py
CHANGED
@@ -1,18 +1,22 @@
|
|
1 |
import os
|
2 |
-
import
|
|
|
|
|
|
|
3 |
from functools import partial
|
4 |
|
5 |
import gradio as gr
|
|
|
6 |
import nibabel as nib
|
7 |
import numpy as np
|
8 |
import torch
|
|
|
9 |
from tqdm import tqdm as std_tqdm
|
10 |
|
11 |
tqdm = partial(std_tqdm, dynamic_ncols=True)
|
12 |
|
13 |
-
#
|
14 |
from utils.cropping import cropping
|
15 |
-
from utils.functions import reimburse_conform
|
16 |
from utils.hemisphere import hemisphere
|
17 |
from utils.load_model import load_model
|
18 |
from utils.make_csv import make_csv
|
@@ -22,46 +26,66 @@ from utils.postprocessing import postprocessing
|
|
22 |
from utils.preprocessing import preprocessing
|
23 |
from utils.stripping import stripping
|
24 |
|
25 |
-
# モデルは起動時に一度読み込む
|
26 |
-
MODELS = {}
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
|
30 |
-
model_folder = "model/" # 学習済みモデルは model/ に配置している前提
|
31 |
-
try:
|
32 |
-
# load_model 内部が argparse.Namespace 等を前提の場合、必要な属性が使われるなら適宜オブジェクトを用意してください
|
33 |
-
# ここでは opt を使わないように必要最低限の修正例を示しています。
|
34 |
-
models = load_model(model_folder=model_folder, opt=None, device=device)
|
35 |
-
return models # (cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a)
|
36 |
-
except Exception as e:
|
37 |
-
print("Error during model loading:", e)
|
38 |
-
return None
|
39 |
|
40 |
|
41 |
def run_inference(input_file, only_face_cropping, only_skull_stripping):
|
42 |
-
#
|
43 |
-
|
44 |
-
tmp_output_dir = tempfile.mkdtemp()
|
45 |
|
46 |
-
#
|
47 |
basename = os.path.splitext(os.path.basename(input_file.name))[0]
|
48 |
-
|
49 |
-
|
50 |
-
f.write(input_file.read())
|
51 |
|
52 |
-
#
|
53 |
class Options:
|
54 |
pass
|
55 |
|
56 |
opt = Options()
|
57 |
-
|
58 |
-
opt.o =
|
59 |
-
# ここでは学習済みモデルフォルダを固定で "model/" に設定
|
60 |
-
opt.m = "model/"
|
61 |
opt.only_face_cropping = only_face_cropping
|
62 |
opt.only_skull_stripping = only_skull_stripping
|
63 |
|
64 |
-
#
|
65 |
if torch.cuda.is_available():
|
66 |
device = torch.device("cuda")
|
67 |
elif torch.backends.mps.is_available():
|
@@ -70,74 +94,98 @@ def run_inference(input_file, only_face_cropping, only_skull_stripping):
|
|
70 |
device = torch.device("cpu")
|
71 |
print(f"Using device: {device}")
|
72 |
|
73 |
-
#
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
# --- 以下、元のparcellation.py の処理フローに準じた処理 ---
|
82 |
-
# 1. 入力画像の読み込み・キャノニカル化・squeeze
|
83 |
-
odata = nib.squeeze_image(nib.as_closest_canonical(nib.load(input_path)))
|
84 |
nii = nib.Nifti1Image(odata.get_fdata().astype(np.float32), affine=odata.affine)
|
85 |
-
os.makedirs(os.path.join(
|
86 |
-
|
|
|
87 |
|
88 |
-
# 2.
|
89 |
-
odata, data = preprocessing(
|
90 |
|
91 |
-
# 3.
|
92 |
-
cropped = cropping(
|
93 |
if only_face_cropping:
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
)
|
118 |
-
out_parcellated_dir = os.path.join(tmp_output_dir, "parcellated")
|
119 |
-
os.makedirs(out_parcellated_dir, exist_ok=True)
|
120 |
-
out_filename = os.path.join(out_parcellated_dir, f"{basename}_Type1_Level5.nii")
|
121 |
-
nib.save(nii_out, out_filename)
|
122 |
-
create_parcellated_images(output, tmp_output_dir, basename, odata, data)
|
123 |
|
124 |
-
#
|
|
|
|
|
125 |
|
126 |
-
#
|
127 |
-
return
|
128 |
|
129 |
|
130 |
-
# Gradio
|
131 |
iface = gr.Interface(
|
132 |
fn=run_inference,
|
133 |
inputs=[
|
134 |
-
gr.
|
135 |
gr.Checkbox(label="Only Face Cropping", value=False),
|
136 |
gr.Checkbox(label="Only Skull Stripping", value=False),
|
137 |
],
|
138 |
-
outputs=
|
|
|
|
|
|
|
|
|
139 |
title="OpenMAP-T1 Inference",
|
140 |
-
description=
|
|
|
|
|
|
|
141 |
)
|
142 |
|
143 |
if __name__ == "__main__":
|
|
|
1 |
import os
|
2 |
+
import random
|
3 |
+
import shutil
|
4 |
+
import string
|
5 |
+
import zipfile
|
6 |
from functools import partial
|
7 |
|
8 |
import gradio as gr
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
import nibabel as nib
|
11 |
import numpy as np
|
12 |
import torch
|
13 |
+
from PIL import Image
|
14 |
from tqdm import tqdm as std_tqdm
|
15 |
|
16 |
tqdm = partial(std_tqdm, dynamic_ncols=True)
|
17 |
|
18 |
+
# Import required modules from our project
|
19 |
from utils.cropping import cropping
|
|
|
20 |
from utils.hemisphere import hemisphere
|
21 |
from utils.load_model import load_model
|
22 |
from utils.make_csv import make_csv
|
|
|
26 |
from utils.preprocessing import preprocessing
|
27 |
from utils.stripping import stripping
|
28 |
|
|
|
|
|
29 |
|
30 |
+
def nii_to_image(voxel_path, label_path, output_dir, basename):
|
31 |
+
"""
|
32 |
+
Converts two NIfTI files into 2D images for visualization.
|
33 |
+
The voxel (input MRI) is shown as a grayscale image and the label (segmentation)
|
34 |
+
is shown using a default color map.
|
35 |
+
A middle slice is chosen by default.
|
36 |
+
"""
|
37 |
+
# Load the NIfTI volumes and squeeze to remove extra dimensions
|
38 |
+
vdata = nib.squeeze_image(nib.as_closest_canonical(nib.load(voxel_path)))
|
39 |
+
ldata = nib.squeeze_image(nib.as_closest_canonical(nib.load(label_path)))
|
40 |
+
voxel = vdata.get_fdata().astype("float32")
|
41 |
+
label = ldata.get_fdata().astype("int16")
|
42 |
+
|
43 |
+
# Choose the middle slice along the first dimension and rotate for display
|
44 |
+
slice_index = voxel.shape[0] // 2
|
45 |
+
slice_voxel = np.rot90(voxel[slice_index, :, :])
|
46 |
+
slice_label = np.rot90(label[slice_index, :, :])
|
47 |
+
|
48 |
+
# Plot and save the input MRI image
|
49 |
+
plt.figure(figsize=(5, 5))
|
50 |
+
plt.imshow(slice_voxel, cmap="gray")
|
51 |
+
plt.title("Input Image")
|
52 |
+
plt.axis("off")
|
53 |
+
input_png_path = os.path.join(os.path.dirname(output_dir), f"{basename}_input.png")
|
54 |
+
plt.savefig(input_png_path, format="png", bbox_inches="tight", pad_inches=0)
|
55 |
+
|
56 |
+
# Plot and save the parcellation (segmentation) map image
|
57 |
+
plt.figure(figsize=(5, 5))
|
58 |
+
plt.imshow(slice_label)
|
59 |
+
plt.title("Parcellation Result")
|
60 |
+
plt.axis("off")
|
61 |
+
parcellation_png_path = os.path.join(
|
62 |
+
os.path.dirname(output_dir), f"{basename}_parcellation.png"
|
63 |
+
)
|
64 |
+
plt.savefig(parcellation_png_path, format="png", bbox_inches="tight", pad_inches=0)
|
65 |
|
66 |
+
return input_png_path, parcellation_png_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
|
69 |
def run_inference(input_file, only_face_cropping, only_skull_stripping):
|
70 |
+
# Generate a random 10-character string to create a unique temporary directory
|
71 |
+
random_string = "".join(random.choices(string.ascii_letters + string.digits, k=10))
|
|
|
72 |
|
73 |
+
# Extract the base filename from the uploaded file (handle .nii and .nii.gz)
|
74 |
basename = os.path.splitext(os.path.basename(input_file.name))[0]
|
75 |
+
if basename.endswith(".nii"):
|
76 |
+
basename = os.path.splitext(basename)[0]
|
|
|
77 |
|
78 |
+
# Create an Options object (similar to argparse.Namespace)
|
79 |
class Options:
|
80 |
pass
|
81 |
|
82 |
opt = Options()
|
83 |
+
# Set the output directory uniquely with the random string and base filename
|
84 |
+
opt.o = f"output/{random_string}/{basename}"
|
|
|
|
|
85 |
opt.only_face_cropping = only_face_cropping
|
86 |
opt.only_skull_stripping = only_skull_stripping
|
87 |
|
88 |
+
# Device selection: prefer CUDA if available, otherwise MPS or CPU
|
89 |
if torch.cuda.is_available():
|
90 |
device = torch.device("cuda")
|
91 |
elif torch.backends.mps.is_available():
|
|
|
94 |
device = torch.device("cpu")
|
95 |
print(f"Using device: {device}")
|
96 |
|
97 |
+
# Load the pre-trained models from the fixed "model/" folder
|
98 |
+
# cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a = load_model("model/", device=device)
|
99 |
+
cnet, ssnet, pnet_a, hnet_c, hnet_a = load_model("model/", device=device)
|
100 |
+
|
101 |
+
# --- Processing Flow (based on the original parcellation.py) ---
|
102 |
+
# 1. Load the input image, convert to canonical orientation, and remove extra dimensions
|
103 |
+
odata = nib.squeeze_image(nib.as_closest_canonical(nib.load(input_file.name)))
|
|
|
|
|
|
|
|
|
104 |
nii = nib.Nifti1Image(odata.get_fdata().astype(np.float32), affine=odata.affine)
|
105 |
+
os.makedirs(os.path.join(opt.o, "original"), exist_ok=True)
|
106 |
+
original_nii_path = os.path.join(opt.o, f"original/{basename}.nii")
|
107 |
+
nib.save(nii, original_nii_path)
|
108 |
|
109 |
+
# 2. Preprocess the image
|
110 |
+
odata, data = preprocessing(input_file.name, opt.o, basename)
|
111 |
|
112 |
+
# 3. Cropping
|
113 |
+
cropped, out_filename = cropping(opt.o, basename, odata, data, cnet, device)
|
114 |
if only_face_cropping:
|
115 |
+
pass
|
116 |
+
|
117 |
+
else:
|
118 |
+
# 4. Skull stripping
|
119 |
+
stripped, shift, out_filename = stripping(
|
120 |
+
opt.o, basename, cropped, odata, data, ssnet, device
|
121 |
+
)
|
122 |
+
if only_skull_stripping:
|
123 |
+
pass
|
124 |
+
else:
|
125 |
+
# 5. Parcellation
|
126 |
+
parcellated = parcellation(stripped, pnet_a, pnet_a, pnet_a, device)
|
127 |
+
# 6. Separate into hemispheres
|
128 |
+
separated = hemisphere(stripped, hnet_c, hnet_a, device)
|
129 |
+
# 7. Postprocessing
|
130 |
+
output = postprocessing(parcellated, separated, shift, device)
|
131 |
+
# 8. Create CSV with volume information, etc.
|
132 |
+
df = make_csv(output, opt.o, basename)
|
133 |
+
# 9. Create and save the parcellation result NIfTI file
|
134 |
+
nii_out = nib.Nifti1Image(output.astype(np.uint16), affine=data.affine)
|
135 |
+
header = odata.header
|
136 |
+
nii_out = nib.processing.conform(
|
137 |
+
nii_out,
|
138 |
+
out_shape=(header["dim"][1], header["dim"][2], header["dim"][3]),
|
139 |
+
voxel_size=(header["pixdim"][1], header["pixdim"][2], header["pixdim"][3]),
|
140 |
+
order=0,
|
141 |
+
)
|
142 |
+
out_parcellated_dir = os.path.join(opt.o, "parcellated")
|
143 |
+
os.makedirs(out_parcellated_dir, exist_ok=True)
|
144 |
+
out_filename = os.path.join(out_parcellated_dir, f"{basename}_Type1_Level5.nii")
|
145 |
+
nib.save(nii_out, out_filename)
|
146 |
+
create_parcellated_images(output, opt.o, basename, odata, data)
|
147 |
+
|
148 |
+
# Zip the entire output directory into a ZIP file
|
149 |
+
zip_path = os.path.join(os.path.dirname(opt.o), f"{basename}_results.zip")
|
150 |
+
with zipfile.ZipFile(zip_path, "w") as zipf:
|
151 |
+
for root, _, files in os.walk(opt.o):
|
152 |
+
for file in files:
|
153 |
+
file_path = os.path.join(root, file)
|
154 |
+
# Adjust the path within the zip archive
|
155 |
+
arcname = os.path.relpath(file_path, start=opt.o)
|
156 |
+
zipf.write(file_path, arcname)
|
157 |
+
|
158 |
+
# Convert the NIfTI files into visualization images (PNG)
|
159 |
+
input_png_path, parcellation_png_path = nii_to_image(
|
160 |
+
input_file.name, out_filename, opt.o, basename
|
161 |
)
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
+
# *** Cleanup: Remove the temporary output directory ***
|
164 |
+
# Note: This is performed before returning. It is not possible to execute code after the return statement.
|
165 |
+
shutil.rmtree(opt.o)
|
166 |
|
167 |
+
# Return the ZIP file path and the two visualization images
|
168 |
+
return zip_path, Image.open(input_png_path), Image.open(parcellation_png_path)
|
169 |
|
170 |
|
171 |
+
# Create the Gradio interface (the model folder input is not needed)
|
172 |
iface = gr.Interface(
|
173 |
fn=run_inference,
|
174 |
inputs=[
|
175 |
+
gr.File(label="Input NIfTI File (.nii or .nii.gz)"),
|
176 |
gr.Checkbox(label="Only Face Cropping", value=False),
|
177 |
gr.Checkbox(label="Only Skull Stripping", value=False),
|
178 |
],
|
179 |
+
outputs=[
|
180 |
+
gr.File(label="Output Results ZIP File"),
|
181 |
+
gr.Image(label="MRI Image (Original)"),
|
182 |
+
gr.Image(label="Parcellation Map (Type1_Level5)"),
|
183 |
+
],
|
184 |
title="OpenMAP-T1 Inference",
|
185 |
+
description=(
|
186 |
+
"The uploaded MRI image will be processed using OpenMAP-T1, and the parcellation "
|
187 |
+
"results will be returned as a ZIP file along with visualization images."
|
188 |
+
),
|
189 |
)
|
190 |
|
191 |
if __name__ == "__main__":
|
src/utils/cropping.py
CHANGED
@@ -70,6 +70,6 @@ def cropping(output_dir, basename, odata, data, cnet, device):
|
|
70 |
out_e = closing(out_e)
|
71 |
cropped = data.get_fdata().astype("float32") * out_e
|
72 |
|
73 |
-
reimburse_conform(output_dir, basename, "cropped", odata, data, out_e)
|
74 |
|
75 |
-
return cropped
|
|
|
70 |
out_e = closing(out_e)
|
71 |
cropped = data.get_fdata().astype("float32") * out_e
|
72 |
|
73 |
+
out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, out_e)
|
74 |
|
75 |
+
return cropped, out_filename
|
src/utils/functions.py
CHANGED
@@ -12,6 +12,7 @@ def normalize(voxel):
|
|
12 |
voxel = (voxel * 2) - 1
|
13 |
return voxel.astype("float32")
|
14 |
|
|
|
15 |
def reimburse_conform(output_dir, basename, suffix, odata, data, output):
|
16 |
nii = nib.Nifti1Image(output.astype(np.uint16), affine=data.affine)
|
17 |
header = odata.header
|
@@ -23,8 +24,8 @@ def reimburse_conform(output_dir, basename, suffix, odata, data, output):
|
|
23 |
)
|
24 |
os.makedirs(os.path.join(output_dir, f"{suffix}"), exist_ok=True)
|
25 |
nib.save(nii, os.path.join(output_dir, f"{suffix}/{basename}_{suffix}_mask.nii"))
|
26 |
-
|
27 |
result = odata.get_fdata().astype("float32") * nii.get_fdata().astype("int16")
|
28 |
nii = nib.Nifti1Image(result.astype(np.float32), affine=odata.affine)
|
29 |
nib.save(nii, os.path.join(output_dir, f"{suffix}/{basename}_{suffix}.nii"))
|
30 |
-
return
|
|
|
12 |
voxel = (voxel * 2) - 1
|
13 |
return voxel.astype("float32")
|
14 |
|
15 |
+
|
16 |
def reimburse_conform(output_dir, basename, suffix, odata, data, output):
|
17 |
nii = nib.Nifti1Image(output.astype(np.uint16), affine=data.affine)
|
18 |
header = odata.header
|
|
|
24 |
)
|
25 |
os.makedirs(os.path.join(output_dir, f"{suffix}"), exist_ok=True)
|
26 |
nib.save(nii, os.path.join(output_dir, f"{suffix}/{basename}_{suffix}_mask.nii"))
|
27 |
+
|
28 |
result = odata.get_fdata().astype("float32") * nii.get_fdata().astype("int16")
|
29 |
nii = nib.Nifti1Image(result.astype(np.float32), affine=odata.affine)
|
30 |
nib.save(nii, os.path.join(output_dir, f"{suffix}/{basename}_{suffix}.nii"))
|
31 |
+
return os.path.join(output_dir, f"{suffix}/{basename}_{suffix}_mask.nii")
|
src/utils/load_model.py
CHANGED
@@ -4,11 +4,8 @@ import torch
|
|
4 |
|
5 |
from utils.network import UNet
|
6 |
|
7 |
-
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
8 |
-
PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "..", ".."))
|
9 |
|
10 |
-
|
11 |
-
def load_model(opt, device):
|
12 |
"""
|
13 |
This function loads multiple pre-trained models and sets them to evaluation mode.
|
14 |
The models loaded are:
|
@@ -27,10 +24,6 @@ def load_model(opt, device):
|
|
27 |
Returns:
|
28 |
tuple: A tuple containing all the loaded models.
|
29 |
"""
|
30 |
-
if os.path.isabs(opt.m):
|
31 |
-
model_dir = opt.m
|
32 |
-
else:
|
33 |
-
model_dir = os.path.join(PROJECT_ROOT, opt.m)
|
34 |
|
35 |
# Load CNet model
|
36 |
cnet = UNet(1, 1)
|
@@ -47,20 +40,20 @@ def load_model(opt, device):
|
|
47 |
ssnet.eval()
|
48 |
|
49 |
# Load PNet coronal model
|
50 |
-
pnet_c = UNet(3, 142)
|
51 |
-
pnet_c.load_state_dict(
|
52 |
-
|
53 |
-
)
|
54 |
-
pnet_c.to(device)
|
55 |
-
pnet_c.eval()
|
56 |
|
57 |
# Load PNet sagittal model
|
58 |
-
pnet_s = UNet(3, 142)
|
59 |
-
pnet_s.load_state_dict(
|
60 |
-
|
61 |
-
)
|
62 |
-
pnet_s.to(device)
|
63 |
-
pnet_s.eval()
|
64 |
|
65 |
# Load PNet axial model
|
66 |
pnet_a = UNet(3, 142)
|
@@ -87,4 +80,5 @@ def load_model(opt, device):
|
|
87 |
hnet_a.eval()
|
88 |
|
89 |
# Return all loaded models
|
90 |
-
return cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a
|
|
|
|
4 |
|
5 |
from utils.network import UNet
|
6 |
|
|
|
|
|
7 |
|
8 |
+
def load_model(model_dir, device):
|
|
|
9 |
"""
|
10 |
This function loads multiple pre-trained models and sets them to evaluation mode.
|
11 |
The models loaded are:
|
|
|
24 |
Returns:
|
25 |
tuple: A tuple containing all the loaded models.
|
26 |
"""
|
|
|
|
|
|
|
|
|
27 |
|
28 |
# Load CNet model
|
29 |
cnet = UNet(1, 1)
|
|
|
40 |
ssnet.eval()
|
41 |
|
42 |
# Load PNet coronal model
|
43 |
+
# pnet_c = UNet(3, 142)
|
44 |
+
# pnet_c.load_state_dict(
|
45 |
+
# torch.load(os.path.join(model_dir, "PNet", "coronal.pth"), weights_only=True)
|
46 |
+
# )
|
47 |
+
# pnet_c.to(device)
|
48 |
+
# pnet_c.eval()
|
49 |
|
50 |
# Load PNet sagittal model
|
51 |
+
# pnet_s = UNet(3, 142)
|
52 |
+
# pnet_s.load_state_dict(
|
53 |
+
# torch.load(os.path.join(model_dir, "PNet", "sagittal.pth"), weights_only=True)
|
54 |
+
# )
|
55 |
+
# pnet_s.to(device)
|
56 |
+
# pnet_s.eval()
|
57 |
|
58 |
# Load PNet axial model
|
59 |
pnet_a = UNet(3, 142)
|
|
|
80 |
hnet_a.eval()
|
81 |
|
82 |
# Return all loaded models
|
83 |
+
# return cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a
|
84 |
+
return cnet, ssnet, pnet_a, hnet_c, hnet_a
|
src/utils/parcellation.py
CHANGED
@@ -74,24 +74,24 @@ def parcellation(voxel, pnet_c, pnet_s, pnet_a, device):
|
|
74 |
sagittal = voxel
|
75 |
axial = voxel.transpose(2, 1, 0)
|
76 |
|
77 |
-
# Perform parcellation for the coronal view
|
78 |
-
out_c = parcellate(coronal, pnet_c, device, "c").permute(1, 3, 0, 2)
|
79 |
-
torch.cuda.empty_cache()
|
80 |
|
81 |
-
# Perform parcellation for the sagittal view
|
82 |
-
out_s = parcellate(sagittal, pnet_s, device, "s").permute(1, 0, 2, 3)
|
83 |
-
torch.cuda.empty_cache()
|
84 |
|
85 |
-
# Combine the results from coronal and sagittal views
|
86 |
-
out_e = out_c + out_s
|
87 |
-
del out_c, out_s
|
88 |
|
89 |
# Perform parcellation for the axial view
|
90 |
out_a = parcellate(axial, pnet_a, device, "a").permute(1, 3, 2, 0)
|
91 |
torch.cuda.empty_cache()
|
92 |
|
93 |
# Combine the results from all views
|
94 |
-
out_e = out_e + out_a
|
95 |
del out_a
|
96 |
|
97 |
# Get the final parcellated output by taking the argmax
|
|
|
74 |
sagittal = voxel
|
75 |
axial = voxel.transpose(2, 1, 0)
|
76 |
|
77 |
+
# # Perform parcellation for the coronal view
|
78 |
+
# out_c = parcellate(coronal, pnet_c, device, "c").permute(1, 3, 0, 2)
|
79 |
+
# torch.cuda.empty_cache()
|
80 |
|
81 |
+
# # Perform parcellation for the sagittal view
|
82 |
+
# out_s = parcellate(sagittal, pnet_s, device, "s").permute(1, 0, 2, 3)
|
83 |
+
# torch.cuda.empty_cache()
|
84 |
|
85 |
+
# # Combine the results from coronal and sagittal views
|
86 |
+
# out_e = out_c + out_s
|
87 |
+
# del out_c, out_s
|
88 |
|
89 |
# Perform parcellation for the axial view
|
90 |
out_a = parcellate(axial, pnet_a, device, "a").permute(1, 3, 2, 0)
|
91 |
torch.cuda.empty_cache()
|
92 |
|
93 |
# Combine the results from all views
|
94 |
+
out_e = out_a # out_e + out_a
|
95 |
del out_a
|
96 |
|
97 |
# Get the final parcellated output by taking the argmax
|
src/utils/stripping.py
CHANGED
@@ -82,7 +82,7 @@ def stripping(output_dir, basename, voxel, odata, data, ssnet, device):
|
|
82 |
# Multiply the original data by the thresholded output to get the stripped brain image
|
83 |
stripped = data.get_fdata().astype("float32") * out_e
|
84 |
|
85 |
-
reimburse_conform(output_dir, basename, "stripped", odata, data, out_e)
|
86 |
|
87 |
# Calculate the center of mass of the stripped brain image
|
88 |
x, y, z = map(int, ndimage.center_of_mass(out_e))
|
@@ -99,4 +99,4 @@ def stripping(output_dir, basename, voxel, odata, data, ssnet, device):
|
|
99 |
stripped = stripped[32:-32, 16:-16, 32:-32]
|
100 |
|
101 |
# Return the stripped brain image and the shifts applied
|
102 |
-
return stripped, (xd, yd, zd)
|
|
|
82 |
# Multiply the original data by the thresholded output to get the stripped brain image
|
83 |
stripped = data.get_fdata().astype("float32") * out_e
|
84 |
|
85 |
+
out_filename = reimburse_conform(output_dir, basename, "stripped", odata, data, out_e)
|
86 |
|
87 |
# Calculate the center of mass of the stripped brain image
|
88 |
x, y, z = map(int, ndimage.center_of_mass(out_e))
|
|
|
99 |
stripped = stripped[32:-32, 16:-16, 32:-32]
|
100 |
|
101 |
# Return the stripped brain image and the shifts applied
|
102 |
+
return stripped, (xd, yd, zd), out_filename
|