Spaces:
Sleeping
Sleeping
import os | |
import tempfile | |
from functools import partial | |
import gradio as gr | |
import nibabel as nib | |
import numpy as np | |
import torch | |
from tqdm import tqdm as std_tqdm | |
tqdm = partial(std_tqdm, dynamic_ncols=True) | |
# 必要なモジュールのインポート | |
from utils.cropping import cropping | |
from utils.functions import reimburse_conform | |
from utils.hemisphere import hemisphere | |
from utils.load_model import load_model | |
from utils.make_csv import make_csv | |
from utils.make_level import create_parcellated_images | |
from utils.parcellation import parcellation | |
from utils.postprocessing import postprocessing | |
from utils.preprocessing import preprocessing | |
from utils.stripping import stripping | |
# モデルは起動時に一度読み込む | |
MODELS = {} | |
def load_all_models(device): | |
model_folder = "model/" # 学習済みモデルは model/ に配置している前提 | |
try: | |
# load_model 内部が argparse.Namespace 等を前提の場合、必要な属性が使われるなら適宜オブジェクトを用意してください | |
# ここでは opt を使わないように必要最低限の修正例を示しています。 | |
models = load_model(model_folder=model_folder, opt=None, device=device) | |
return models # (cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a) | |
except Exception as e: | |
print("Error during model loading:", e) | |
return None | |
def run_inference(input_file, only_face_cropping, only_skull_stripping): | |
# 一時ディレクトリの作成 | |
tmp_input_dir = tempfile.mkdtemp() | |
tmp_output_dir = tempfile.mkdtemp() | |
# アップロードされたファイルを一時フォルダに保存 | |
basename = os.path.splitext(os.path.basename(input_file.name))[0] | |
input_path = os.path.join(tmp_input_dir, f"{basename}.nii") | |
with open(input_path, "wb") as f: | |
f.write(input_file.read()) | |
# Gradio用のオプションオブジェクト(オリジナルコードの argparse.Namespace 相当) | |
class Options: | |
pass | |
opt = Options() | |
opt.i = tmp_input_dir # 入力ディレクトリを指定 | |
opt.o = tmp_output_dir # 出力ディレクトリを指定 | |
# ここでは学習済みモデルフォルダを固定で "model/" に設定 | |
opt.m = "model/" | |
opt.only_face_cropping = only_face_cropping | |
opt.only_skull_stripping = only_skull_stripping | |
# デバイス選択 | |
if torch.cuda.is_available(): | |
device = torch.device("cuda") | |
elif torch.backends.mps.is_available(): | |
device = torch.device("mps") | |
else: | |
device = torch.device("cpu") | |
print(f"Using device: {device}") | |
# 初回のみモデルを読み込み(キャッシュ) | |
global MODELS | |
if not MODELS: | |
MODELS = load_all_models(device) | |
if MODELS is None: | |
return "モデルの読み込みに失敗しました。" | |
cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a = MODELS | |
# --- 以下、元のparcellation.py の処理フローに準じた処理 --- | |
# 1. 入力画像の読み込み・キャノニカル化・squeeze | |
odata = nib.squeeze_image(nib.as_closest_canonical(nib.load(input_path))) | |
nii = nib.Nifti1Image(odata.get_fdata().astype(np.float32), affine=odata.affine) | |
os.makedirs(os.path.join(tmp_output_dir, "original"), exist_ok=True) | |
nib.save(nii, os.path.join(tmp_output_dir, f"original/{basename}.nii")) | |
# 2. 前処理 | |
odata, data = preprocessing(input_path, tmp_output_dir, basename) | |
# 3. クロッピング | |
cropped = cropping(tmp_output_dir, basename, odata, data, cnet, device) | |
if only_face_cropping: | |
return os.path.join(tmp_output_dir, f"{basename}_cropped.nii") | |
# 4. スキルストリッピング | |
stripped, shift = stripping(tmp_output_dir, basename, cropped, odata, data, ssnet, device) | |
if only_skull_stripping: | |
return os.path.join(tmp_output_dir, f"{basename}_stripped.nii") | |
# 5. パーセレーション | |
parcellated = parcellation(stripped, pnet_c, pnet_s, pnet_a, device) | |
# 6. 両半球に分離 | |
separated = hemisphere(stripped, hnet_c, hnet_a, device) | |
# 7. 後処理 | |
output = postprocessing(parcellated, separated, shift, device) | |
# 8. CSV作成(体積情報等) | |
df = make_csv(output, tmp_output_dir, basename) | |
# 9. パーセル結果のNIfTI作成と保存 | |
nii_out = nib.Nifti1Image(output.astype(np.uint16), affine=data.affine) | |
header = odata.header | |
nii_out = nib.processing.conform( | |
nii_out, | |
out_shape=(header["dim"][1], header["dim"][2], header["dim"][3]), | |
voxel_size=(header["pixdim"][1], header["pixdim"][2], header["pixdim"][3]), | |
order=0, | |
) | |
out_parcellated_dir = os.path.join(tmp_output_dir, "parcellated") | |
os.makedirs(out_parcellated_dir, exist_ok=True) | |
out_filename = os.path.join(out_parcellated_dir, f"{basename}_Type1_Level5.nii") | |
nib.save(nii_out, out_filename) | |
create_parcellated_images(output, tmp_output_dir, basename, odata, data) | |
# 必要に応じて、一時ファイルの削除等の処理を追加してください | |
# 最終結果のNIfTIファイルのパスを返す | |
return out_filename | |
# Gradioインターフェースの作成(モデルフォルダの入力は不要) | |
iface = gr.Interface( | |
fn=run_inference, | |
inputs=[ | |
gr.File(label="Input NIfTI File (.nii or .nii.gz)"), | |
gr.Checkbox(label="Only Face Cropping", value=False), | |
gr.Checkbox(label="Only Skull Stripping", value=False), | |
], | |
outputs=gr.File(label="Output Parcellated NIfTI File"), | |
title="OpenMAP-T1 Inference", | |
description="学習済みモデルは model/ に配置されています。\nアップロードされたMRI画像に対してOpenMAP-T1の処理を行い、パーセル結果を返します。", | |
) | |
if __name__ == "__main__": | |
iface.launch() | |