OpenMAP-T1 / src /utils /parcellation.py
西牧慧
update: parcellation
dcbe128
import numpy as np
import torch
from utils.functions import normalize
def parcellate(voxel, model, device, mode):
    """
    Run slice-wise inference over a voxel volume and return class probabilities.

    The volume is traversed along its first axis. For every slice, the slice and
    its two neighbours are stacked into a 3-channel image, passed through
    ``model``, and the softmax over the channel dimension is stored.

    Args:
        voxel (numpy.ndarray): The input voxel volume. Its first axis is the
            slicing axis and its shape is expected to match the stack size
            selected by ``mode``.
        model (torch.nn.Module): The network used for parcellation. It is switched
            to evaluation mode here but is not moved to ``device``. Its output for
            a ``(1, 3, H, W)`` input is expected to be ``(1, 142, H, W)``.
        device (torch.device): The device the input slices are moved to.
        mode (str): One of 'c', 's' or 'a'. Selects the stack dimensions:
            'c' -> (224, 192, 192), 's' and 'a' -> (192, 224, 192).

    Returns:
        torch.Tensor: CPU tensor of shape ``(stack[0], 142, stack[1], stack[2])``
        holding the softmax output for every slice.

    Raises:
        ValueError: If ``mode`` is not one of 'c', 's' or 'a'.
    """
    stack_sizes = {
        "c": (224, 192, 192),
        "s": (192, 224, 192),
        "a": (192, 224, 192),
    }
    # Fail early with a clear message instead of hitting an unbound `stack` later.
    if mode not in stack_sizes:
        raise ValueError(f"mode must be one of 'c', 's' or 'a', got {mode!r}")
    depth, height, width = stack_sizes[mode]

    # Set the model to evaluation mode
    model.eval()

    # Pad one slice (filled with the volume minimum) on both ends of the slicing
    # axis so the first and last slices also have two neighbours.
    voxel = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())

    # Disable gradient tracking for inference
    with torch.inference_mode():
        # One (142, H, W) probability map per slice
        box = torch.zeros(depth, 142, height, width)
        # Indices are shifted by one because of the padding above
        for i in range(1, depth + 1):
            # Stack three consecutive slices to form the 3-channel input image
            image = np.stack([voxel[i - 1], voxel[i], voxel[i + 1]])
            image = torch.tensor(image.reshape(1, 3, height, width)).to(device)
            # Forward pass, softmax over the class axis, result moved back to the CPU
            box[i - 1] = torch.softmax(model(image), 1).cpu()

    # `box` already has the output shape, so no further reshape is needed.
    return box
def parcellation(voxel, pnet_c, pnet_s, pnet_a, device):
    """
    Parcellate a 3D voxel volume into a label map.

    Only the axial network is currently evaluated. The coronal and sagittal
    passes are disabled, so ``pnet_c`` and ``pnet_s`` are accepted to keep the
    interface stable but are not used.

    Args:
        voxel: The input 3D voxel data. It is passed through ``normalize`` and the
            result is then handled with NumPy-style calls (three-axis
            ``transpose``), so a ``numpy.ndarray`` is expected -- TODO confirm
            against ``normalize``.
        pnet_c (torch.nn.Module): The model for the coronal view (currently unused).
        pnet_s (torch.nn.Module): The model for the sagittal view (currently unused).
        pnet_a (torch.nn.Module): The model for the axial view.
        device (torch.device): The device (CPU or GPU) to perform computations on.

    Returns:
        numpy.ndarray: The class index with the highest probability for every
        voxel, in the same axis order as the normalized input.
    """
    # Normalize the voxel data
    voxel = normalize(voxel)

    # NOTE: the three-view ensemble is disabled. It summed the softmax outputs of
    # all three views before the argmax, using:
    #   coronal : input voxel.transpose(1, 2, 0), mode "c", output .permute(1, 3, 0, 2)
    #   sagittal: input voxel,                    mode "s", output .permute(1, 0, 2, 3)
    #   axial   : as below
    # with torch.cuda.empty_cache() called after each pass.

    # Axial view: slice along the last axis of the normalized volume
    axial = voxel.transpose(2, 1, 0)
    # Move the class axis to the front and restore the input axis order
    out_a = parcellate(axial, pnet_a, device, "a").permute(1, 3, 2, 0)
    torch.cuda.empty_cache()

    # Get the final parcellated output by taking the argmax over the class axis
    return torch.argmax(out_a, 0).numpy()