# OpenMAP-T1 / src/utils/cropping.py
# Last commit by: 西牧慧
# Commit message: update: parcellation
# Commit: dcbe128
import numpy as np
import torch
from scipy.ndimage import binary_closing
from utils.functions import normalize, reimburse_conform
def crop(voxel, model, device):
    """
    Run a 2D model slice-by-slice over a 3D volume and stack the sigmoid outputs.

    Each slice ``voxel[i]`` is fed to ``model`` as a ``(1, 1, H, W)`` tensor and
    ``sigmoid(model(slice))`` is written into slice ``i`` of the result.

    Args:
        voxel (numpy.ndarray): Input volume of shape (N, H, W), sliced along
            axis 0. The cropping pipeline passes (256, 256, 256).
        model (torch.nn.Module): 2D network. This call switches it to eval
            mode. Assumed to return a single channel with the same spatial
            size as its input (TODO confirm for other architectures).
        device (torch.device): Device on which inference runs and on which
            the result is allocated.

    Returns:
        torch.Tensor: Float tensor of shape (N, H, W) on ``device`` holding
        the per-voxel sigmoid outputs, i.e. (256, 256, 256) for conformed input.
    """
    model.eval()
    n_slices, height, width = voxel.shape
    with torch.inference_mode():
        # Allocate straight on the target device rather than building the
        # volume on the CPU and copying it over.
        output = torch.zeros(n_slices, height, width, device=device)
        for i, v in enumerate(voxel):
            # Add batch and channel axes: (H, W) -> (1, 1, H, W).
            image = torch.tensor(v.reshape(1, 1, height, width), device=device)
            # inference_mode already disables autograd tracking, so no
            # .detach() is needed; reshape drops the batch/channel axes.
            output[i] = torch.sigmoid(model(image)).reshape(height, width)
    return output
def closing(voxel, iterations=3):
    """
    Apply a 3D binary morphological closing to a voxel mask.

    Closing is a dilation followed by an erosion. It fills small holes and
    gaps in the mask. A full 3x3x3 (26-connected) structuring element is used.

    Parameters:
        voxel (numpy.ndarray): 3D array; non-zero voxels are treated as
            foreground.
        iterations (int): Number of dilation steps, followed by the same
            number of erosion steps. Defaults to 3, the value previously
            hard-coded here.

    Returns:
        numpy.ndarray: Boolean array of the same shape holding the closed mask.
    """
    selem = np.ones((3, 3, 3), dtype=bool)
    # NOTE(review): scipy's binary_closing erodes with border_value=0, so
    # foreground lying within `iterations` voxels of the array edge can be
    # eroded away. This is harmless as long as the mask stays clear of the
    # volume border.
    return binary_closing(voxel, structure=selem, iterations=iterations)
def cropping(output_dir, basename, odata, data, cnet, device):
    """
    Mask an image with a binary mask predicted by ``cnet`` and save that mask.

    The intensities are normalized, and ``cnet`` is run slice-by-slice along
    two different axes of the volume. The two probability maps are averaged
    and thresholded at 0.5, and the resulting mask is morphologically closed.
    The raw image data is then multiplied by that mask.

    Args:
        output_dir: Output directory, forwarded to ``reimburse_conform``.
        basename: Base name for the output, forwarded to ``reimburse_conform``.
        odata: Forwarded to ``reimburse_conform``. Presumably the original
            (pre-conform) image; TODO confirm.
        data (nibabel.Nifti1Image): Image to mask. Assumed to be a conformed
            (256, 256, 256) volume, because the model is applied to
            256x256 slices.
        cnet (torch.nn.Module): Network that predicts the mask.
        device (torch.device): Device (CPU or GPU) on which ``cnet`` is run.

    Returns:
        tuple: ``(cropped, out_filename)``, where ``cropped`` is a float32
        numpy array equal to the image data multiplied by the binary mask, and
        ``out_filename`` is whatever ``reimburse_conform`` returns for the
        saved "cropped" mask.
    """
    # Fetch the image data once and reuse it for normalization and masking.
    fdata = data.get_fdata()
    voxel = normalize(fdata.astype("float32"))
    coronal = voxel.transpose(1, 2, 0)
    sagittal = voxel
    # permute(2, 0, 1) undoes transpose(1, 2, 0), so both predictions share
    # the axis order of `voxel`.
    out_c = crop(coronal, cnet, device).permute(2, 0, 1)
    out_s = crop(sagittal, cnet, device)
    # Average the two orientations' probabilities, then threshold at 0.5.
    mask = (((out_c + out_s) / 2) > 0.5).cpu().numpy()
    mask = closing(mask)
    cropped = fdata.astype("float32") * mask
    out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, mask)
    return cropped, out_filename