import os import tempfile from functools import partial import gradio as gr import nibabel as nib import numpy as np import torch from tqdm import tqdm as std_tqdm tqdm = partial(std_tqdm, dynamic_ncols=True) # 必要なモジュールのインポート from utils.cropping import cropping from utils.functions import reimburse_conform from utils.hemisphere import hemisphere from utils.load_model import load_model from utils.make_csv import make_csv from utils.make_level import create_parcellated_images from utils.parcellation import parcellation from utils.postprocessing import postprocessing from utils.preprocessing import preprocessing from utils.stripping import stripping # モデルは起動時に一度読み込む MODELS = {} def load_all_models(device): model_folder = "model/" # 学習済みモデルは model/ に配置している前提 try: # load_model 内部が argparse.Namespace 等を前提の場合、必要な属性が使われるなら適宜オブジェクトを用意してください # ここでは opt を使わないように必要最低限の修正例を示しています。 models = load_model(model_folder=model_folder, opt=None, device=device) return models # (cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a) except Exception as e: print("Error during model loading:", e) return None def run_inference(input_file, only_face_cropping, only_skull_stripping): # 一時ディレクトリの作成 tmp_input_dir = tempfile.mkdtemp() tmp_output_dir = tempfile.mkdtemp() # アップロードされたファイルを一時フォルダに保存 basename = os.path.splitext(os.path.basename(input_file.name))[0] input_path = os.path.join(tmp_input_dir, f"{basename}.nii") with open(input_path, "wb") as f: f.write(input_file.read()) # Gradio用のオプションオブジェクト(オリジナルコードの argparse.Namespace 相当) class Options: pass opt = Options() opt.i = tmp_input_dir # 入力ディレクトリを指定 opt.o = tmp_output_dir # 出力ディレクトリを指定 # ここでは学習済みモデルフォルダを固定で "model/" に設定 opt.m = "model/" opt.only_face_cropping = only_face_cropping opt.only_skull_stripping = only_skull_stripping # デバイス選択 if torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available(): device = torch.device("mps") else: device = torch.device("cpu") print(f"Using device: {device}") # 初回のみモデルを読み込み(キャッシュ) global MODELS if not MODELS: MODELS = load_all_models(device) if MODELS is None: return "モデルの読み込みに失敗しました。" cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a = MODELS # --- 以下、元のparcellation.py の処理フローに準じた処理 --- # 1. 入力画像の読み込み・キャノニカル化・squeeze odata = nib.squeeze_image(nib.as_closest_canonical(nib.load(input_path))) nii = nib.Nifti1Image(odata.get_fdata().astype(np.float32), affine=odata.affine) os.makedirs(os.path.join(tmp_output_dir, "original"), exist_ok=True) nib.save(nii, os.path.join(tmp_output_dir, f"original/{basename}.nii")) # 2. 前処理 odata, data = preprocessing(input_path, tmp_output_dir, basename) # 3. クロッピング cropped = cropping(tmp_output_dir, basename, odata, data, cnet, device) if only_face_cropping: return os.path.join(tmp_output_dir, f"{basename}_cropped.nii") # 4. スキルストリッピング stripped, shift = stripping(tmp_output_dir, basename, cropped, odata, data, ssnet, device) if only_skull_stripping: return os.path.join(tmp_output_dir, f"{basename}_stripped.nii") # 5. パーセレーション parcellated = parcellation(stripped, pnet_c, pnet_s, pnet_a, device) # 6. 両半球に分離 separated = hemisphere(stripped, hnet_c, hnet_a, device) # 7. 後処理 output = postprocessing(parcellated, separated, shift, device) # 8. CSV作成(体積情報等) df = make_csv(output, tmp_output_dir, basename) # 9. パーセル結果のNIfTI作成と保存 nii_out = nib.Nifti1Image(output.astype(np.uint16), affine=data.affine) header = odata.header nii_out = nib.processing.conform( nii_out, out_shape=(header["dim"][1], header["dim"][2], header["dim"][3]), voxel_size=(header["pixdim"][1], header["pixdim"][2], header["pixdim"][3]), order=0, ) out_parcellated_dir = os.path.join(tmp_output_dir, "parcellated") os.makedirs(out_parcellated_dir, exist_ok=True) out_filename = os.path.join(out_parcellated_dir, f"{basename}_Type1_Level5.nii") nib.save(nii_out, out_filename) create_parcellated_images(output, tmp_output_dir, basename, odata, data) # 必要に応じて、一時ファイルの削除等の処理を追加してください # 最終結果のNIfTIファイルのパスを返す return out_filename # Gradioインターフェースの作成(モデルフォルダの入力は不要) iface = gr.Interface( fn=run_inference, inputs=[ gr.File(label="Input NIfTI File (.nii or .nii.gz)"), gr.Checkbox(label="Only Face Cropping", value=False), gr.Checkbox(label="Only Skull Stripping", value=False), ], outputs=gr.File(label="Output Parcellated NIfTI File"), title="OpenMAP-T1 Inference", description="学習済みモデルは model/ に配置されています。\nアップロードされたMRI画像に対してOpenMAP-T1の処理を行い、パーセル結果を返します。", ) if __name__ == "__main__": iface.launch()