OpenMAP-T1 / src /parcellation.py
西牧慧
add: first commit
01f75cf
raw
history blame
5.88 kB
import os
import tempfile
from functools import partial
import gradio as gr
import nibabel as nib
import numpy as np
import torch
from tqdm import tqdm as std_tqdm
tqdm = partial(std_tqdm, dynamic_ncols=True)
# 必要なモジュールのインポート
from utils.cropping import cropping
from utils.functions import reimburse_conform
from utils.hemisphere import hemisphere
from utils.load_model import load_model
from utils.make_csv import make_csv
from utils.make_level import create_parcellated_images
from utils.parcellation import parcellation
from utils.postprocessing import postprocessing
from utils.preprocessing import preprocessing
from utils.stripping import stripping
# モデルは起動時に一度読み込む
MODELS = {}
def load_all_models(device):
model_folder = "model/" # 学習済みモデルは model/ に配置している前提
try:
# load_model 内部が argparse.Namespace 等を前提の場合、必要な属性が使われるなら適宜オブジェクトを用意してください
# ここでは opt を使わないように必要最低限の修正例を示しています。
models = load_model(model_folder=model_folder, opt=None, device=device)
return models # (cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a)
except Exception as e:
print("Error during model loading:", e)
return None
def run_inference(input_file, only_face_cropping, only_skull_stripping):
# 一時ディレクトリの作成
tmp_input_dir = tempfile.mkdtemp()
tmp_output_dir = tempfile.mkdtemp()
# アップロードされたファイルを一時フォルダに保存
basename = os.path.splitext(os.path.basename(input_file.name))[0]
input_path = os.path.join(tmp_input_dir, f"{basename}.nii")
with open(input_path, "wb") as f:
f.write(input_file.read())
# Gradio用のオプションオブジェクト(オリジナルコードの argparse.Namespace 相当)
class Options:
pass
opt = Options()
opt.i = tmp_input_dir # 入力ディレクトリを指定
opt.o = tmp_output_dir # 出力ディレクトリを指定
# ここでは学習済みモデルフォルダを固定で "model/" に設定
opt.m = "model/"
opt.only_face_cropping = only_face_cropping
opt.only_skull_stripping = only_skull_stripping
# デバイス選択
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
print(f"Using device: {device}")
# 初回のみモデルを読み込み(キャッシュ)
global MODELS
if not MODELS:
MODELS = load_all_models(device)
if MODELS is None:
return "モデルの読み込みに失敗しました。"
cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a = MODELS
# --- 以下、元のparcellation.py の処理フローに準じた処理 ---
# 1. 入力画像の読み込み・キャノニカル化・squeeze
odata = nib.squeeze_image(nib.as_closest_canonical(nib.load(input_path)))
nii = nib.Nifti1Image(odata.get_fdata().astype(np.float32), affine=odata.affine)
os.makedirs(os.path.join(tmp_output_dir, "original"), exist_ok=True)
nib.save(nii, os.path.join(tmp_output_dir, f"original/{basename}.nii"))
# 2. 前処理
odata, data = preprocessing(input_path, tmp_output_dir, basename)
# 3. クロッピング
cropped = cropping(tmp_output_dir, basename, odata, data, cnet, device)
if only_face_cropping:
return os.path.join(tmp_output_dir, f"{basename}_cropped.nii")
# 4. スキルストリッピング
stripped, shift = stripping(tmp_output_dir, basename, cropped, odata, data, ssnet, device)
if only_skull_stripping:
return os.path.join(tmp_output_dir, f"{basename}_stripped.nii")
# 5. パーセレーション
parcellated = parcellation(stripped, pnet_c, pnet_s, pnet_a, device)
# 6. 両半球に分離
separated = hemisphere(stripped, hnet_c, hnet_a, device)
# 7. 後処理
output = postprocessing(parcellated, separated, shift, device)
# 8. CSV作成(体積情報等)
df = make_csv(output, tmp_output_dir, basename)
# 9. パーセル結果のNIfTI作成と保存
nii_out = nib.Nifti1Image(output.astype(np.uint16), affine=data.affine)
header = odata.header
nii_out = nib.processing.conform(
nii_out,
out_shape=(header["dim"][1], header["dim"][2], header["dim"][3]),
voxel_size=(header["pixdim"][1], header["pixdim"][2], header["pixdim"][3]),
order=0,
)
out_parcellated_dir = os.path.join(tmp_output_dir, "parcellated")
os.makedirs(out_parcellated_dir, exist_ok=True)
out_filename = os.path.join(out_parcellated_dir, f"{basename}_Type1_Level5.nii")
nib.save(nii_out, out_filename)
create_parcellated_images(output, tmp_output_dir, basename, odata, data)
# 必要に応じて、一時ファイルの削除等の処理を追加してください
# 最終結果のNIfTIファイルのパスを返す
return out_filename
# Gradioインターフェースの作成(モデルフォルダの入力は不要)
iface = gr.Interface(
fn=run_inference,
inputs=[
gr.File(label="Input NIfTI File (.nii or .nii.gz)"),
gr.Checkbox(label="Only Face Cropping", value=False),
gr.Checkbox(label="Only Skull Stripping", value=False),
],
outputs=gr.File(label="Output Parcellated NIfTI File"),
title="OpenMAP-T1 Inference",
description="学習済みモデルは model/ に配置されています。\nアップロードされたMRI画像に対してOpenMAP-T1の処理を行い、パーセル結果を返します。",
)
if __name__ == "__main__":
iface.launch()