OpenMAP-T1 / src /utils /stripping.py
西牧慧
update: parcellation
dcbe128
import numpy as np
import torch
from scipy import ndimage
from utils.functions import normalize, reimburse_conform
def strip(voxel, model, device):
    """
    Run a 2D slice-wise model over a 3D volume and return per-voxel probabilities.

    The volume is processed one slice at a time along its first axis: each
    slice is fed to ``model`` as a ``(1, 1, H, W)`` batch and the model's
    output is passed through a sigmoid.

    Args:
        voxel (numpy.ndarray): A 3D array of shape ``(D, H, W)``. The pipeline
            uses ``(256, 256, 256)`` volumes, but any shape whose slices the
            model accepts works.
        model (torch.nn.Module): Network applied to each slice. It receives a
            ``(1, 1, H, W)`` tensor and must return ``H * W`` logits
            (e.g. shape ``(1, 1, H, W)``).
        device (torch.device): Device on which the model lives. Input slices
            and the output tensor are placed there.

    Returns:
        torch.Tensor: Tensor of shape ``(D, H, W)`` (default dtype, float32
        unless changed) on ``device`` holding sigmoid probabilities.
    """
    model.eval()
    # inference_mode disables autograd tracking, so no .detach() is needed.
    with torch.inference_mode():
        # Allocate directly on the target device instead of CPU-then-copy.
        output = torch.zeros(voxel.shape, device=device)
        for i, v in enumerate(voxel):
            # Wrap the slice as a single-item, single-channel NCHW batch.
            image = torch.tensor(v.reshape(1, 1, *v.shape)).to(device)
            probs = torch.sigmoid(model(image))
            # Drop batch/channel dims explicitly rather than relying on
            # implicit leading-1 squeezing during tensor assignment.
            output[i] = probs.reshape(v.shape)
    return output
def stripping(output_dir, basename, voxel, odata, data, ssnet, device):
    """
    Skull-strip a volume with a three-view ensemble, then center and crop it.

    ``ssnet`` is applied slice-wise (via ``strip``) along each of the three
    axes of the normalized volume. The three probability maps are averaged and
    thresholded at 0.5 to form a binary brain mask. The mask is handed to
    ``reimburse_conform`` and multiplied into ``data``. The masked image is
    then rolled so that the mask's center of mass lands at voxel
    (128, 120, 128). Finally it is cropped by 32 / 16 / 32 voxels on each side
    of axes 0 / 1 / 2.

    Args:
        output_dir: Output directory, forwarded to ``reimburse_conform``.
        basename: Base file name, forwarded to ``reimburse_conform``.
        voxel (numpy.ndarray): 3D volume to segment; normalized here before
            inference. Must satisfy ``strip``'s input requirements.
        odata: Forwarded unchanged to ``reimburse_conform`` (presumably the
            original, un-conformed image -- TODO confirm with caller).
        data (nibabel.Nifti1Image): Image whose ``get_fdata()`` array is
            masked; must be broadcast-compatible with (normally the same shape
            as) ``voxel``.
        ssnet (torch.nn.Module): Skull-stripping network.
        device (torch.device): Device on which ``ssnet`` runs.

    Returns:
        tuple: ``(stripped, (xd, yd, zd), out_filename)`` where

        - stripped (numpy.ndarray): masked, rolled and cropped float32 image
          (shape ``(192, 224, 192)`` for a 256^3 input);
        - (xd, yd, zd) (tuple of int): roll applied along axes 0, 1 and 2;
        - out_filename: the value returned by ``reimburse_conform``.

    Raises:
        ValueError: If the predicted brain mask is empty. In that case no
            center of mass, and hence no centering shift, can be computed.
    """
    voxel = normalize(voxel)

    # Present the volume to the slice-wise model along each axis. Each
    # prediction is permuted back to the original axis order so the three
    # maps can be averaged voxel-by-voxel.
    coronal = voxel.transpose(1, 2, 0)
    sagittal = voxel
    axial = voxel.transpose(2, 1, 0)
    out_c = strip(coronal, ssnet, device).permute(2, 0, 1)
    out_s = strip(sagittal, ssnet, device)
    out_a = strip(axial, ssnet, device).permute(2, 1, 0)

    # Average the three views and binarize into a boolean mask.
    mask = (((out_c + out_s + out_a) / 3) > 0.5).cpu().numpy()

    stripped = data.get_fdata().astype("float32") * mask
    out_filename = reimburse_conform(output_dir, basename, "stripped", odata, data, mask)

    # center_of_mass of an all-zero mask is NaN, and int(NaN) would fail with
    # an unhelpful message; report the real cause instead.
    if not mask.any():
        raise ValueError(
            "skull stripping produced an empty brain mask; "
            "cannot compute centering shift"
        )

    # int() truncates toward zero; the resulting shifts are returned, so
    # callers can undo exactly what was applied.
    x, y, z = map(int, ndimage.center_of_mass(mask))
    xd = 128 - x
    yd = 120 - y
    zd = 128 - z

    # Roll (with wrap-around) so the brain's center of mass sits at
    # (128, 120, 128), then crop the borders.
    stripped = np.roll(stripped, (xd, yd, zd), axis=(0, 1, 2))
    stripped = stripped[32:-32, 16:-16, 32:-32]
    return stripped, (xd, yd, zd), out_filename