Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from scipy.ndimage import binary_closing | |
from utils.functions import normalize, reimburse_conform | |
def crop(voxel, model, device):
    """
    Run the cropping model slice by slice and collect its sigmoid outputs.

    Every item produced by iterating over ``voxel`` is reshaped to
    ``(1, 1, 256, 256)``, turned into a tensor on ``device`` and fed to
    ``model``. The sigmoid of the model output is written to the matching
    index of a zero-initialised ``(256, 256, 256)`` buffer.

    Args:
        voxel (numpy.ndarray): Stack of 256x256 slices, iterated along axis 0.
            The result buffer always holds 256 slices: with fewer input
            slices the remaining entries stay zero, and a slice index of
            256 or more falls outside the buffer.
        model (torch.nn.Module): Network applied to each slice. It is put
            into eval mode before use.
            NOTE(review): assumes the output for one slice can be stored as
            a single 256x256 map -- confirm against the model definition.
        device (torch.device): Device that holds the inputs and the result.

    Returns:
        torch.Tensor: Tensor of shape (256, 256, 256) located on ``device``.
    """
    side = 256  # fixed edge length of each slice and of the result buffer
    model.eval()
    with torch.inference_mode():
        probabilities = torch.zeros(side, side, side).to(device)
        for index, plane in enumerate(voxel):
            # Add batch and channel axes before handing the slice to the model.
            batch = torch.tensor(plane.reshape(1, 1, side, side)).to(device)
            logits = model(batch)
            probabilities[index] = torch.sigmoid(logits).detach()
        return probabilities.reshape(side, side, side)
def closing(voxel):
    """
    Apply a morphological binary closing to a voxel array.

    The closing is delegated to ``scipy.ndimage.binary_closing`` with a
    3x3x3 all-ones structuring element and three iterations.

    Parameters:
        voxel (numpy.ndarray): The 3D array to be closed.

    Returns:
        numpy.ndarray: The result of the binary closing operation.
    """
    cube = np.ones((3, 3, 3), dtype="bool")
    return binary_closing(voxel, structure=cube, iterations=3)
def cropping(output_dir, basename, odata, data, cnet, device):
    """
    Build a binary mask with the cropping network and apply it to the image.

    The voxel data is normalised and passed through ``cnet`` twice: once on
    a transposed (coronal) view and once as-is (sagittal). The coronal
    prediction is permuted back to the original axis order, the two
    predictions are averaged, thresholded at 0.5 and cleaned up with
    ``closing``. The mask is multiplied into the un-normalised voxel data and
    also handed to ``reimburse_conform``.

    Args:
        output_dir: Forwarded unchanged to ``reimburse_conform``.
            NOTE(review): presumably the directory results are written to --
            confirm against ``reimburse_conform``.
        basename: Forwarded unchanged to ``reimburse_conform``.
        odata: Forwarded unchanged to ``reimburse_conform``.
            NOTE(review): presumably the original (pre-conform) image --
            confirm against the caller.
        data (nibabel.Nifti1Image): The input medical imaging data in NIfTI
            format; its voxels are read with ``get_fdata()``.
        cnet (torch.nn.Module): The neural network model used for cropping.
        device (torch.device): The device (CPU or GPU) on which the model is run.

    Returns:
        tuple: ``(cropped, out_filename)`` where ``cropped`` is the float32
        voxel data multiplied by the binary mask and ``out_filename`` is the
        value returned by ``reimburse_conform``.
    """
    normalized = normalize(data.get_fdata().astype("float32"))

    # Predict along two slicing directions; permute(2, 0, 1) undoes the
    # transpose(1, 2, 0) so both predictions share the input axis order.
    coronal_view = normalized.transpose(1, 2, 0)
    coronal_pred = crop(coronal_view, cnet, device).permute(2, 0, 1)
    sagittal_pred = crop(normalized, cnet, device)

    # Average the two predictions, binarise, then close the mask.
    mean_pred = (coronal_pred + sagittal_pred) / 2
    mask = closing((mean_pred > 0.5).cpu().numpy())

    # The mask is applied to the raw intensities, not the normalised ones.
    cropped = data.get_fdata().astype("float32") * mask
    out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, mask)
    return cropped, out_filename