西牧慧 committed
Commit dcbe128 · 1 Parent(s): fe59750

update: parcellation

README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: OpenMAP T1
+title: OpenMAP-T1
 emoji: 📚
 colorFrom: gray
 colorTo: red
requirements.txt CHANGED
@@ -4,9 +4,12 @@ anyio==4.9.0
 certifi==2025.1.31
 charset-normalizer==3.4.1
 click==8.1.8
+contourpy==1.3.1
+cycler==0.12.1
 fastapi==0.115.12
 ffmpy==0.5.0
 filelock==3.18.0
+fonttools==4.57.0
 fsspec==2025.3.2
 gradio==5.25.0
 gradio_client==1.8.0
@@ -17,8 +20,10 @@ httpx==0.28.1
 huggingface-hub==0.30.2
 idna==3.10
 Jinja2==3.1.6
+kiwisolver==1.4.8
 markdown-it-py==3.0.0
 MarkupSafe==3.0.2
+matplotlib==3.10.1
 mdurl==0.1.2
 mpmath==1.3.0
 networkx==3.4.2
@@ -32,6 +37,7 @@ pydantic==2.11.3
 pydantic_core==2.33.1
 pydub==0.25.1
 Pygments==2.19.1
+pyparsing==3.2.3
 python-dateutil==2.9.0.post0
 python-multipart==0.0.20
 pytz==2025.2
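
The additions here are matplotlib and its transitive dependencies (contourpy, cycler, fonttools, kiwisolver, pyparsing), which back the PNG slice previews introduced in src/parcellation.py below. A minimal sketch of that rendering pattern, assuming a local sample.nii (a hypothetical path):

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

img = nib.load("sample.nii")  # hypothetical input volume
mid = np.rot90(img.get_fdata()[img.shape[0] // 2])  # middle slice, rotated for display
plt.imshow(mid, cmap="gray")
plt.axis("off")
plt.savefig("sample_slice.png", bbox_inches="tight", pad_inches=0)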
src/parcellation.py CHANGED
@@ -1,18 +1,22 @@
 import os
-import tempfile
+import random
+import shutil
+import string
+import zipfile
 from functools import partial
 
 import gradio as gr
+import matplotlib.pyplot as plt
 import nibabel as nib
 import numpy as np
 import torch
+from PIL import Image
 from tqdm import tqdm as std_tqdm
 
 tqdm = partial(std_tqdm, dynamic_ncols=True)
 
-# Import the required modules
+# Import required modules from our project
 from utils.cropping import cropping
-from utils.functions import reimburse_conform
 from utils.hemisphere import hemisphere
 from utils.load_model import load_model
 from utils.make_csv import make_csv
@@ -22,46 +26,66 @@ from utils.postprocessing import postprocessing
 from utils.preprocessing import preprocessing
 from utils.stripping import stripping
 
-# Load the models once at startup
-MODELS = {}
-
-
-def load_all_models(device):
-    model_folder = "model/"  # assumes the pre-trained models are placed in model/
-    try:
-        # If load_model internally expects an argparse.Namespace or the like, supply an object with the required attributes.
-        # Shown here is the minimal change needed to avoid using opt.
-        models = load_model(model_folder=model_folder, opt=None, device=device)
-        return models  # (cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a)
-    except Exception as e:
-        print("Error during model loading:", e)
-        return None
+
+def nii_to_image(voxel_path, label_path, output_dir, basename):
+    """
+    Converts two NIfTI files into 2D images for visualization.
+    The voxel (input MRI) is shown as a grayscale image and the label (segmentation)
+    is shown using a default color map.
+    A middle slice is chosen by default.
+    """
+    # Load the NIfTI volumes and squeeze to remove extra dimensions
+    vdata = nib.squeeze_image(nib.as_closest_canonical(nib.load(voxel_path)))
+    ldata = nib.squeeze_image(nib.as_closest_canonical(nib.load(label_path)))
+    voxel = vdata.get_fdata().astype("float32")
+    label = ldata.get_fdata().astype("int16")
+
+    # Choose the middle slice along the first dimension and rotate for display
+    slice_index = voxel.shape[0] // 2
+    slice_voxel = np.rot90(voxel[slice_index, :, :])
+    slice_label = np.rot90(label[slice_index, :, :])
+
+    # Plot and save the input MRI image
+    plt.figure(figsize=(5, 5))
+    plt.imshow(slice_voxel, cmap="gray")
+    plt.title("Input Image")
+    plt.axis("off")
+    input_png_path = os.path.join(os.path.dirname(output_dir), f"{basename}_input.png")
+    plt.savefig(input_png_path, format="png", bbox_inches="tight", pad_inches=0)
+
+    # Plot and save the parcellation (segmentation) map image
+    plt.figure(figsize=(5, 5))
+    plt.imshow(slice_label)
+    plt.title("Parcellation Result")
+    plt.axis("off")
+    parcellation_png_path = os.path.join(
+        os.path.dirname(output_dir), f"{basename}_parcellation.png"
+    )
+    plt.savefig(parcellation_png_path, format="png", bbox_inches="tight", pad_inches=0)
+
+    return input_png_path, parcellation_png_path
 
 
 def run_inference(input_file, only_face_cropping, only_skull_stripping):
-    # Create temporary directories
-    tmp_input_dir = tempfile.mkdtemp()
-    tmp_output_dir = tempfile.mkdtemp()
+    # Generate a random 10-character string to create a unique temporary directory
+    random_string = "".join(random.choices(string.ascii_letters + string.digits, k=10))
 
-    # Save the uploaded file into a temporary folder
+    # Extract the base filename from the uploaded file (handle .nii and .nii.gz)
     basename = os.path.splitext(os.path.basename(input_file.name))[0]
-    input_path = os.path.join(tmp_input_dir, f"{basename}.nii")
-    with open(input_path, "wb") as f:
-        f.write(input_file.read())
+    if basename.endswith(".nii"):
+        basename = os.path.splitext(basename)[0]
 
-    # Options object for Gradio (equivalent to argparse.Namespace in the original code)
+    # Create an Options object (similar to argparse.Namespace)
     class Options:
         pass
 
    opt = Options()
-    opt.i = tmp_input_dir  # input directory
-    opt.o = tmp_output_dir  # output directory
-    # The pre-trained model folder is fixed to "model/" here
-    opt.m = "model/"
+    # Set the output directory uniquely with the random string and base filename
+    opt.o = f"output/{random_string}/{basename}"
     opt.only_face_cropping = only_face_cropping
     opt.only_skull_stripping = only_skull_stripping
 
-    # Device selection
+    # Device selection: prefer CUDA if available, otherwise MPS or CPU
     if torch.cuda.is_available():
         device = torch.device("cuda")
     elif torch.backends.mps.is_available():
@@ -70,74 +94,98 @@ def run_inference(input_file, only_face_cropping, only_skull_stripping):
         device = torch.device("cpu")
     print(f"Using device: {device}")
 
-    # Load the models only on the first call (cached)
-    global MODELS
-    if not MODELS:
-        MODELS = load_all_models(device)
-        if MODELS is None:
-            return "Failed to load the models."
-    cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a = MODELS
-
-    # --- Processing flow based on the original parcellation.py ---
-    # 1. Load the input image, convert to canonical orientation, and squeeze
-    odata = nib.squeeze_image(nib.as_closest_canonical(nib.load(input_path)))
+    # Load the pre-trained models from the fixed "model/" folder
+    # cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a = load_model("model/", device=device)
+    cnet, ssnet, pnet_a, hnet_c, hnet_a = load_model("model/", device=device)
+
+    # --- Processing Flow (based on the original parcellation.py) ---
+    # 1. Load the input image, convert to canonical orientation, and remove extra dimensions
+    odata = nib.squeeze_image(nib.as_closest_canonical(nib.load(input_file.name)))
     nii = nib.Nifti1Image(odata.get_fdata().astype(np.float32), affine=odata.affine)
-    os.makedirs(os.path.join(tmp_output_dir, "original"), exist_ok=True)
-    nib.save(nii, os.path.join(tmp_output_dir, f"original/{basename}.nii"))
+    os.makedirs(os.path.join(opt.o, "original"), exist_ok=True)
+    original_nii_path = os.path.join(opt.o, f"original/{basename}.nii")
+    nib.save(nii, original_nii_path)
 
-    # 2. Preprocessing
-    odata, data = preprocessing(input_path, tmp_output_dir, basename)
+    # 2. Preprocess the image
+    odata, data = preprocessing(input_file.name, opt.o, basename)
 
-    # 3. Cropping
-    cropped = cropping(tmp_output_dir, basename, odata, data, cnet, device)
+    # 3. Cropping
+    cropped, out_filename = cropping(opt.o, basename, odata, data, cnet, device)
     if only_face_cropping:
-        return os.path.join(tmp_output_dir, f"{basename}_cropped.nii")
-
-    # 4. Skull stripping
-    stripped, shift = stripping(tmp_output_dir, basename, cropped, odata, data, ssnet, device)
-    if only_skull_stripping:
-        return os.path.join(tmp_output_dir, f"{basename}_stripped.nii")
-
-    # 5. Parcellation
-    parcellated = parcellation(stripped, pnet_c, pnet_s, pnet_a, device)
-    # 6. Separate into both hemispheres
-    separated = hemisphere(stripped, hnet_c, hnet_a, device)
-    # 7. Postprocessing
-    output = postprocessing(parcellated, separated, shift, device)
-    # 8. Create CSV (volume information, etc.)
-    df = make_csv(output, tmp_output_dir, basename)
-    # 9. Create and save the parcellation result NIfTI
-    nii_out = nib.Nifti1Image(output.astype(np.uint16), affine=data.affine)
-    header = odata.header
-    nii_out = nib.processing.conform(
-        nii_out,
-        out_shape=(header["dim"][1], header["dim"][2], header["dim"][3]),
-        voxel_size=(header["pixdim"][1], header["pixdim"][2], header["pixdim"][3]),
-        order=0,
+        pass
+
+    else:
+        # 4. Skull stripping
+        stripped, shift, out_filename = stripping(
+            opt.o, basename, cropped, odata, data, ssnet, device
+        )
+        if only_skull_stripping:
+            pass
+        else:
+            # 5. Parcellation
+            parcellated = parcellation(stripped, pnet_a, pnet_a, pnet_a, device)
+            # 6. Separate into hemispheres
+            separated = hemisphere(stripped, hnet_c, hnet_a, device)
+            # 7. Postprocessing
+            output = postprocessing(parcellated, separated, shift, device)
+            # 8. Create CSV with volume information, etc.
+            df = make_csv(output, opt.o, basename)
+            # 9. Create and save the parcellation result NIfTI file
+            nii_out = nib.Nifti1Image(output.astype(np.uint16), affine=data.affine)
+            header = odata.header
+            nii_out = nib.processing.conform(
+                nii_out,
+                out_shape=(header["dim"][1], header["dim"][2], header["dim"][3]),
+                voxel_size=(header["pixdim"][1], header["pixdim"][2], header["pixdim"][3]),
+                order=0,
+            )
+            out_parcellated_dir = os.path.join(opt.o, "parcellated")
+            os.makedirs(out_parcellated_dir, exist_ok=True)
+            out_filename = os.path.join(out_parcellated_dir, f"{basename}_Type1_Level5.nii")
+            nib.save(nii_out, out_filename)
+            create_parcellated_images(output, opt.o, basename, odata, data)
+
+    # Zip the entire output directory into a ZIP file
+    zip_path = os.path.join(os.path.dirname(opt.o), f"{basename}_results.zip")
+    with zipfile.ZipFile(zip_path, "w") as zipf:
+        for root, _, files in os.walk(opt.o):
+            for file in files:
+                file_path = os.path.join(root, file)
+                # Adjust the path within the zip archive
+                arcname = os.path.relpath(file_path, start=opt.o)
+                zipf.write(file_path, arcname)
+
+    # Convert the NIfTI files into visualization images (PNG)
+    input_png_path, parcellation_png_path = nii_to_image(
+        input_file.name, out_filename, opt.o, basename
    )
-    out_parcellated_dir = os.path.join(tmp_output_dir, "parcellated")
-    os.makedirs(out_parcellated_dir, exist_ok=True)
-    out_filename = os.path.join(out_parcellated_dir, f"{basename}_Type1_Level5.nii")
-    nib.save(nii_out, out_filename)
-    create_parcellated_images(output, tmp_output_dir, basename, odata, data)
 
-    # Add cleanup of temporary files here if needed
+    # *** Cleanup: Remove the temporary output directory ***
+    # Note: This is performed before returning. It is not possible to execute code after the return statement.
+    shutil.rmtree(opt.o)
 
-    # Return the path of the final NIfTI result
-    return out_filename
+    # Return the ZIP file path and the two visualization images
+    return zip_path, Image.open(input_png_path), Image.open(parcellation_png_path)
 
 
-# Create the Gradio interface (no model folder input is needed)
+# Create the Gradio interface (the model folder input is not needed)
 iface = gr.Interface(
     fn=run_inference,
     inputs=[
-        gr.Image(label="Input NIfTI File (.nii or .nii.gz)"),
+        gr.File(label="Input NIfTI File (.nii or .nii.gz)"),
         gr.Checkbox(label="Only Face Cropping", value=False),
         gr.Checkbox(label="Only Skull Stripping", value=False),
     ],
-    outputs=gr.Image(label="Output Parcellated NIfTI File"),
+    outputs=[
+        gr.File(label="Output Results ZIP File"),
+        gr.Image(label="MRI Image (Original)"),
+        gr.Image(label="Parcellation Map (Type1_Level5)"),
+    ],
     title="OpenMAP-T1 Inference",
-    description="The pre-trained models are placed in model/.\nOpenMAP-T1 is applied to the uploaded MRI image and the parcellation result is returned.",
+    description=(
+        "The uploaded MRI image will be processed using OpenMAP-T1, and the parcellation "
+        "results will be returned as a ZIP file along with visualization images."
    ),
 )
 
 if __name__ == "__main__":
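
Two points in the rewritten run_inference are easy to miss: each run writes under its own output/{random_string}/{basename} directory, and that directory is zipped into its parent and then deleted, so only the ZIP and the two PNGs (also written to the parent) survive the call. A self-contained sketch of that lifecycle, with hypothetical helper names:

import os
import random
import shutil
import string
import zipfile

def make_run_dir(basename, root="output"):
    # Unique per-run directory, mirroring opt.o above
    suffix = "".join(random.choices(string.ascii_letters + string.digits, k=10))
    run_dir = os.path.join(root, suffix, basename)
    os.makedirs(run_dir, exist_ok=True)
    return run_dir

def zip_and_clean(run_dir, basename):
    # The ZIP lives one level above run_dir, so it survives the cleanup
    zip_path = os.path.join(os.path.dirname(run_dir), f"{basename}_results.zip")
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for root_dir, _, files in os.walk(run_dir):
            for name in files:
                path = os.path.join(root_dir, name)
                zipf.write(path, os.path.relpath(path, start=run_dir))
    shutil.rmtree(run_dir)
    return zip_path

print(zip_and_clean(make_run_dir("subj01"), "subj01"))  # output/<rand>/subj01_results.zip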
src/utils/cropping.py CHANGED
@@ -70,6 +70,6 @@ def cropping(output_dir, basename, odata, data, cnet, device):
     out_e = closing(out_e)
     cropped = data.get_fdata().astype("float32") * out_e
 
-    reimburse_conform(output_dir, basename, "cropped", odata, data, out_e)
+    out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, out_e)
 
-    return cropped
+    return cropped, out_filename
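
Since cropping() now hands back the path written by reimburse_conform along with the cropped volume, every caller unpacks two values; the updated call site in run_inference above reads:

cropped, out_filename = cropping(opt.o, basename, odata, data, cnet, device)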
src/utils/functions.py CHANGED
@@ -12,6 +12,7 @@ def normalize(voxel):
     voxel = (voxel * 2) - 1
     return voxel.astype("float32")
 
+
 def reimburse_conform(output_dir, basename, suffix, odata, data, output):
     nii = nib.Nifti1Image(output.astype(np.uint16), affine=data.affine)
     header = odata.header
@@ -23,8 +24,8 @@ def reimburse_conform(output_dir, basename, suffix, odata, data, output):
     )
     os.makedirs(os.path.join(output_dir, f"{suffix}"), exist_ok=True)
     nib.save(nii, os.path.join(output_dir, f"{suffix}/{basename}_{suffix}_mask.nii"))
-
+
     result = odata.get_fdata().astype("float32") * nii.get_fdata().astype("int16")
     nii = nib.Nifti1Image(result.astype(np.float32), affine=odata.affine)
     nib.save(nii, os.path.join(output_dir, f"{suffix}/{basename}_{suffix}.nii"))
-    return
+    return os.path.join(output_dir, f"{suffix}/{basename}_{suffix}_mask.nii")
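
Note that the returned path points at the *_mask.nii file saved mid-function, not at the masked image {basename}_{suffix}.nii saved just before the return. An illustrative sketch of the new return contract (directory and subject names hypothetical):

import os

def returned_path(output_dir, basename, suffix):
    # Mirrors the value reimburse_conform now returns
    return os.path.join(output_dir, f"{suffix}/{basename}_{suffix}_mask.nii")

print(returned_path("out", "subj01", "cropped"))  # out/cropped/subj01_cropped_mask.nii
# The masked image itself is saved alongside as out/cropped/subj01_cropped.nii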
src/utils/load_model.py CHANGED
@@ -4,11 +4,8 @@ import torch
 
 from utils.network import UNet
 
-CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
-PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "..", ".."))
 
-
-def load_model(opt, device):
+def load_model(model_dir, device):
     """
     This function loads multiple pre-trained models and sets them to evaluation mode.
     The models loaded are:
@@ -27,10 +24,6 @@ def load_model(opt, device):
     Returns:
         tuple: A tuple containing all the loaded models.
     """
-    if os.path.isabs(opt.m):
-        model_dir = opt.m
-    else:
-        model_dir = os.path.join(PROJECT_ROOT, opt.m)
 
     # Load CNet model
     cnet = UNet(1, 1)
@@ -47,20 +40,20 @@ def load_model(opt, device):
     ssnet.eval()
 
     # Load PNet coronal model
-    pnet_c = UNet(3, 142)
-    pnet_c.load_state_dict(
-        torch.load(os.path.join(model_dir, "PNet", "coronal.pth"), weights_only=True)
-    )
-    pnet_c.to(device)
-    pnet_c.eval()
+    # pnet_c = UNet(3, 142)
+    # pnet_c.load_state_dict(
+    #     torch.load(os.path.join(model_dir, "PNet", "coronal.pth"), weights_only=True)
+    # )
+    # pnet_c.to(device)
+    # pnet_c.eval()
 
     # Load PNet sagittal model
-    pnet_s = UNet(3, 142)
-    pnet_s.load_state_dict(
-        torch.load(os.path.join(model_dir, "PNet", "sagittal.pth"), weights_only=True)
-    )
-    pnet_s.to(device)
-    pnet_s.eval()
+    # pnet_s = UNet(3, 142)
+    # pnet_s.load_state_dict(
+    #     torch.load(os.path.join(model_dir, "PNet", "sagittal.pth"), weights_only=True)
+    # )
+    # pnet_s.to(device)
+    # pnet_s.eval()
 
     # Load PNet axial model
     pnet_a = UNet(3, 142)
@@ -87,4 +80,5 @@ def load_model(opt, device):
     hnet_a.eval()
 
     # Return all loaded models
-    return cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a
+    # return cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a
+    return cnet, ssnet, pnet_a, hnet_c, hnet_a
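
load_model now takes the model directory directly instead of an opt namespace, and returns five networks rather than seven. A minimal sketch of the new call shape, assuming the project's utils package is importable and the weights sit in model/:

import torch

from utils.load_model import load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# PNet coronal/sagittal are no longer loaded, so only five models come back
cnet, ssnet, pnet_a, hnet_c, hnet_a = load_model("model/", device=device)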
src/utils/parcellation.py CHANGED
@@ -74,24 +74,24 @@ def parcellation(voxel, pnet_c, pnet_s, pnet_a, device):
     sagittal = voxel
     axial = voxel.transpose(2, 1, 0)
 
-    # Perform parcellation for the coronal view
-    out_c = parcellate(coronal, pnet_c, device, "c").permute(1, 3, 0, 2)
-    torch.cuda.empty_cache()
+    # # Perform parcellation for the coronal view
+    # out_c = parcellate(coronal, pnet_c, device, "c").permute(1, 3, 0, 2)
+    # torch.cuda.empty_cache()
 
-    # Perform parcellation for the sagittal view
-    out_s = parcellate(sagittal, pnet_s, device, "s").permute(1, 0, 2, 3)
-    torch.cuda.empty_cache()
+    # # Perform parcellation for the sagittal view
+    # out_s = parcellate(sagittal, pnet_s, device, "s").permute(1, 0, 2, 3)
+    # torch.cuda.empty_cache()
 
-    # Combine the results from coronal and sagittal views
-    out_e = out_c + out_s
-    del out_c, out_s
+    # # Combine the results from coronal and sagittal views
+    # out_e = out_c + out_s
+    # del out_c, out_s
 
     # Perform parcellation for the axial view
     out_a = parcellate(axial, pnet_a, device, "a").permute(1, 3, 2, 0)
     torch.cuda.empty_cache()
 
     # Combine the results from all views
-    out_e = out_e + out_a
+    out_e = out_a  # out_e + out_a
     del out_a
 
     # Get the final parcellated output by taking the argmax
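
With the coronal and sagittal branches commented out, the ensemble collapses to the axial scores alone: the original code summed the three per-view score volumes before the argmax, while this commit argmaxes out_a directly. A toy sketch of the difference (shapes illustrative):

import torch

# Hypothetical per-view score volumes of shape (num_labels, X, Y, Z)
out_c = torch.rand(142, 8, 8, 8)
out_s = torch.rand(142, 8, 8, 8)
out_a = torch.rand(142, 8, 8, 8)

three_view = (out_c + out_s + out_a).argmax(dim=0)  # original three-view ensemble
axial_only = out_a.argmax(dim=0)                    # behavior after this commit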
src/utils/stripping.py CHANGED
@@ -82,7 +82,7 @@ def stripping(output_dir, basename, voxel, odata, data, ssnet, device):
     # Multiply the original data by the thresholded output to get the stripped brain image
     stripped = data.get_fdata().astype("float32") * out_e
 
-    reimburse_conform(output_dir, basename, "stripped", odata, data, out_e)
+    out_filename = reimburse_conform(output_dir, basename, "stripped", odata, data, out_e)
 
     # Calculate the center of mass of the stripped brain image
     x, y, z = map(int, ndimage.center_of_mass(out_e))
@@ -99,4 +99,4 @@ def stripping(output_dir, basename, voxel, odata, data, ssnet, device):
     stripped = stripped[32:-32, 16:-16, 32:-32]
 
     # Return the stripped brain image and the shifts applied
-    return stripped, (xd, yd, zd)
+    return stripped, (xd, yd, zd), out_filename
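
stripping() likewise gains a third return value, so its call site in run_inference above unpacks the saved path alongside the shift tuple:

stripped, shift, out_filename = stripping(
    opt.o, basename, cropped, odata, data, ssnet, device
)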