Spaces:
Sleeping
Sleeping
import torch | |
from scipy.ndimage import binary_dilation | |
from utils.functions import normalize | |
def separate(voxel, model, device, mode):
    """
    Run a 2D model over a volume slice by slice and collect the softmax maps.

    Args:
        voxel (list or numpy.ndarray): Iterable of 2D slices. Every slice must
            provide ``reshape`` and hold exactly ``stack[1] * stack[2]``
            elements; at most ``stack[0]`` slices can be stored.
        model (torch.nn.Module): Network applied to one ``(1, 1, H, W)`` slice
            at a time. Its output must fit a ``(3, H, W)`` slot of the result.
        device (torch.device): Device on which the slices and the result live.
            The model is not moved here, so it must already be on ``device``.
        mode (str): ``'c'`` (coronal) uses a ``(224, 192, 192)`` stack,
            ``'a'`` (axial) uses a ``(192, 224, 192)`` stack.

    Returns:
        torch.Tensor: Tensor of shape ``(stack[0], 3, stack[1], stack[2])`` on
        ``device`` holding the softmax over dim 1 of the model output for each
        slice. Slots with no corresponding input slice stay zero.

    Raises:
        ValueError: If ``mode`` is neither ``'c'`` nor ``'a'``.
    """
    # (number of slices, slice height, slice width) per supported view.
    stack_by_mode = {
        "c": (224, 192, 192),
        "a": (192, 224, 192),
    }
    # Fail early with a clear message: previously an unknown mode left `stack`
    # unassigned and surfaced later as an UnboundLocalError.
    if mode not in stack_by_mode:
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'c' (coronal) or 'a' (axial)."
        )
    stack = stack_by_mode[mode]

    # Switch the model to evaluation mode (side effect on the caller's model).
    model.eval()

    # inference_mode disables autograd tracking, so no detach() is needed below.
    with torch.inference_mode():
        # Allocate the result directly on the target device.
        output = torch.zeros(stack[0], 3, stack[1], stack[2], device=device)
        for i, v in enumerate(voxel):
            # Add batch and channel axes, then move the slice to the device.
            image = torch.tensor(v.reshape(1, 1, stack[1], stack[2])).to(device)
            # Per-class probabilities along the channel axis.
            output[i] = torch.softmax(model(image), 1)

    return output
def hemisphere(voxel, hnet_c, hnet_a, device):
    """
    Segment a volume into hemisphere labels with two networks and dilate them.

    The volume is normalized, sliced along two orientations, run through one
    network per orientation, and the two probability volumes are summed before
    taking the per-voxel argmax. Labels 1 and 2 are then dilated in turn.

    Args:
        voxel (numpy.ndarray): 3D input volume. It is re-ordered with the
            NumPy-style ``transpose(axes)`` call, so the value returned by
            ``normalize`` must support that signature.
            NOTE(review): the fixed stack sizes in ``separate`` imply a
            (192, 224, 192) volume - confirm against the caller.
        hnet_c (torch.nn.Module): Network applied to the coronal slices.
        hnet_a (torch.nn.Module): Network applied to the transverse (axial)
            slices.
        device (torch.device): Device the networks run on.

    Returns:
        numpy.ndarray: ``int16`` label volume with values 0, 1 and 2
        (presumably background and the two hemispheres - confirm label
        meaning with the model authors).
    """
    voxel = normalize(voxel)

    # Re-order the axes so that iterating the first axis yields 2D slices in
    # the orientation each network expects.
    coronal = voxel.transpose(1, 2, 0)
    transverse = voxel.transpose(2, 1, 0)

    # Permute both outputs to a common (class, ...) layout so they can be added.
    out_c = separate(coronal, hnet_c, device, "c").permute(1, 3, 0, 2)
    out_a = separate(transverse, hnet_a, device, "a").permute(1, 3, 2, 0)

    # Sum the two probability volumes and pick the best class per voxel.
    labels = torch.argmax(out_c + out_a, 0).cpu().numpy()

    # Drop the large device tensors before clearing the cache; empty_cache()
    # can only release memory that is no longer referenced.
    del out_c, out_a
    torch.cuda.empty_cache()

    # Grow class 1, then restore the original class-2 voxels on top of it.
    grown_1 = binary_dilation(labels == 1, iterations=5).astype("int16")
    grown_1[labels == 2] = 2

    # Grow class 2, then let the (grown) class-1 voxels take precedence.
    grown_2 = binary_dilation(grown_1 == 2, iterations=5).astype("int16") * 2
    grown_2[grown_1 == 1] = 1

    return grown_2