import os
import random
import shutil
import string
import zipfile
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import nibabel as nib
import nibabel.processing  # needed explicitly: "import nibabel" alone does not expose nib.processing.conform
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm as std_tqdm

tqdm = partial(std_tqdm, dynamic_ncols=True)

# Import required modules from our project
from utils.cropping import cropping
from utils.hemisphere import hemisphere
from utils.load_model import load_model
from utils.make_csv import make_csv
from utils.make_level import create_parcellated_images
from utils.parcellation import parcellation
from utils.postprocessing import postprocessing
from utils.preprocessing import preprocessing
from utils.stripping import stripping


def nii_to_image(voxel_path, label_path, output_dir, basename):
    """
    Convert two NIfTI files into 2D images for visualization.

    The voxel volume (the input MRI) is rendered as a grayscale image and the label
    volume (the segmentation) is rendered with the default color map.
    The middle slice along the first axis is used.
    """
    # Load the NIfTI volumes in canonical orientation and squeeze away extra dimensions
    vdata = nib.squeeze_image(nib.as_closest_canonical(nib.load(voxel_path)))
    ldata = nib.squeeze_image(nib.as_closest_canonical(nib.load(label_path)))
    voxel = vdata.get_fdata().astype("float32")
    label = ldata.get_fdata().astype("int16")

    # Choose the middle slice along the first dimension and rotate it for display
    slice_index = voxel.shape[0] // 2
    slice_voxel = np.rot90(voxel[slice_index, :, :])
    slice_label = np.rot90(label[slice_index, :, :])

    # Plot and save the input MRI image
    plt.figure(figsize=(5, 5))
    plt.imshow(slice_voxel, cmap="gray")
    plt.title("Input Image")
    plt.axis("off")
    input_png_path = os.path.join(os.path.dirname(output_dir), f"{basename}_input.png")
    plt.savefig(input_png_path, format="png", bbox_inches="tight", pad_inches=0)
    plt.close()  # release the figure so repeated requests do not leak memory

    # Plot and save the parcellation (segmentation) map image
    plt.figure(figsize=(5, 5))
    plt.imshow(slice_label)
    plt.title("Parcellation Result")
    plt.axis("off")
    parcellation_png_path = os.path.join(
        os.path.dirname(output_dir), f"{basename}_parcellation.png"
    )
    plt.savefig(parcellation_png_path, format="png", bbox_inches="tight", pad_inches=0)
    plt.close()

    return input_png_path, parcellation_png_path
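

# Example call to nii_to_image (the file names below are hypothetical, for illustration only):
#     nii_to_image("sub-01_T1w.nii", "sub-01_Type1_Level5.nii", "output/abc123/sub-01", "sub-01")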


def run_inference(input_file, only_face_cropping, only_skull_stripping):
    # Generate a random 10-character string to create a unique temporary directory
    random_string = "".join(random.choices(string.ascii_letters + string.digits, k=10))

    # Extract the base filename from the uploaded file (handle .nii and .nii.gz)
    basename = os.path.splitext(os.path.basename(input_file.name))[0]
    if basename.endswith(".nii"):
        basename = os.path.splitext(basename)[0]

    # Create an Options object (similar to argparse.Namespace)
    class Options:
        pass

    opt = Options()
    # Set the output directory uniquely with the random string and base filename
    opt.o = f"output/{random_string}/{basename}"
    opt.only_face_cropping = only_face_cropping
    opt.only_skull_stripping = only_skull_stripping

    # Device selection: prefer CUDA if available, otherwise MPS or CPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Load the pre-trained models from the fixed "model/" folder
    # cnet, ssnet, pnet_c, pnet_s, pnet_a, hnet_c, hnet_a = load_model("model/", device=device)
    cnet, ssnet, pnet_a, hnet_c, hnet_a = load_model("model/", device=device)
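    # Note: this variant of load_model returns five networks rather than the seven in the
    # commented-out line above; pnet_a is passed three times to parcellation() below,
    # presumably standing in for the per-plane models pnet_c / pnet_s / pnet_a.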

    # --- Processing Flow (based on the original parcellation.py) ---
    # 1. Load the input image, convert to canonical orientation, and remove extra dimensions
    odata = nib.squeeze_image(nib.as_closest_canonical(nib.load(input_file.name)))
    nii = nib.Nifti1Image(odata.get_fdata().astype(np.float32), affine=odata.affine)
    os.makedirs(os.path.join(opt.o, "original"), exist_ok=True)
    original_nii_path = os.path.join(opt.o, f"original/{basename}.nii")
    nib.save(nii, original_nii_path)
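    # This untouched copy lives under opt.o/original, so it is included in the results ZIP
    # that is built from opt.o further below.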

    # 2. Preprocess the image
    odata, data = preprocessing(input_file.name, opt.o, basename)

    # 3. Cropping
    cropped, out_filename = cropping(opt.o, basename, odata, data, cnet, device)

    if not only_face_cropping:
        # 4. Skull stripping
        stripped, shift, out_filename = stripping(
            opt.o, basename, cropped, odata, data, ssnet, device
        )

        if not only_skull_stripping:
            # 5. Parcellation
            parcellated = parcellation(stripped, pnet_a, pnet_a, pnet_a, device)
            # 6. Separate into hemispheres
            separated = hemisphere(stripped, hnet_c, hnet_a, device)
            # 7. Postprocessing
            output = postprocessing(parcellated, separated, shift, device)
            # 8. Create CSV with volume information, etc.
            df = make_csv(output, opt.o, basename)
            # 9. Create and save the parcellation result NIfTI file
            nii_out = nib.Nifti1Image(output.astype(np.uint16), affine=data.affine)
            header = odata.header
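            # conform() resamples the parcellation to the input's shape and voxel size
            # (taken from the original header); order=0 selects nearest-neighbour
            # interpolation so the integer label values are not blended.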
            nii_out = nib.processing.conform(
                nii_out,
                out_shape=(header["dim"][1], header["dim"][2], header["dim"][3]),
                voxel_size=(header["pixdim"][1], header["pixdim"][2], header["pixdim"][3]),
                order=0,
            )
            out_parcellated_dir = os.path.join(opt.o, "parcellated")
            os.makedirs(out_parcellated_dir, exist_ok=True)
            out_filename = os.path.join(out_parcellated_dir, f"{basename}_Type1_Level5.nii")
            nib.save(nii_out, out_filename)
            create_parcellated_images(output, opt.o, basename, odata, data)
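
    # If parcellation was skipped (only_face_cropping or only_skull_stripping), out_filename
    # still points at the cropped or skull-stripped volume, so that is what is visualized below.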

    # Zip the entire output directory into a single ZIP file
    zip_path = os.path.join(os.path.dirname(opt.o), f"{basename}_results.zip")
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for root, _, files in os.walk(opt.o):
            for file in files:
                file_path = os.path.join(root, file)
                # Store paths inside the archive relative to the output directory
                arcname = os.path.relpath(file_path, start=opt.o)
                zipf.write(file_path, arcname)

    # Convert the NIfTI files into visualization images (PNG)
    input_png_path, parcellation_png_path = nii_to_image(
        input_file.name, out_filename, opt.o, basename
    )

    # Cleanup: remove the temporary output directory. This must happen before the return;
    # the ZIP and the PNGs are written next to opt.o (not inside it), so they survive.
    shutil.rmtree(opt.o)

    # Return the ZIP file path and the two visualization images
    return zip_path, Image.open(input_png_path), Image.open(parcellation_png_path)


# Create the Gradio interface (the model folder input is not needed)
iface = gr.Interface(
    fn=run_inference,
    inputs=[
        gr.File(label="Input NIfTI File (.nii or .nii.gz)"),
        gr.Checkbox(label="Only Face Cropping", value=False),
        gr.Checkbox(label="Only Skull Stripping", value=False),
    ],
    outputs=[
        gr.File(label="Output Results ZIP File"),
        gr.Image(label="MRI Image (Original)"),
        gr.Image(label="Parcellation Map (Type1_Level5)"),
    ],
    title="OpenMAP-T1 Inference",
    description=(
        "The uploaded MRI image will be processed using OpenMAP-T1, and the parcellation "
        "results will be returned as a ZIP file along with visualization images."
    ),
)

if __name__ == "__main__":
    iface.launch()
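
# Note: inside a Hugging Face Space the default launch() call above is sufficient. For a
# local run the host and port can be set explicitly (optional, hypothetical values):
#     iface.launch(server_name="0.0.0.0", server_port=7860)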