Spaces:
Sleeping
Sleeping
File size: 3,516 Bytes
01f75cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import torch
from scipy.ndimage import binary_dilation
from utils.functions import normalize
def separate(voxel, model, device, mode):
    """
    Run a 2D model over a voxel volume slice by slice and collect class probabilities.

    Each slice is reshaped to ``(1, 1, H, W)``, passed through ``model``, and the
    result is softmax-normalised over dimension 1 before being stored.

    Args:
        voxel (list or numpy.ndarray): Iterable of 2D slices. Every slice must
            support ``reshape`` and hold exactly ``H * W`` elements for the chosen
            mode. At most ``stack[0]`` slices fit in the output buffer.
        model (torch.nn.Module): Network called on a ``(1, 1, H, W)`` tensor. Its
            softmaxed output must be assignable to a ``(3, H, W)`` slot.
        device (torch.device): Device on which the output buffer is allocated and
            to which each slice is moved before the forward pass.
        mode (str): ``'c'`` (coronal) uses a stack of ``(224, 192, 192)``;
            ``'a'`` (axial) uses ``(192, 224, 192)``.

    Returns:
        torch.Tensor: Tensor of shape ``(stack[0], 3, stack[1], stack[2])`` on
        ``device``. Entries for slice indices not present in ``voxel`` stay zero.

    Raises:
        ValueError: If ``mode`` is neither ``'c'`` nor ``'a'``.
    """
    if mode == "c":
        # Coronal: 224 slices of 192 x 192.
        stack = (224, 192, 192)
    elif mode == "a":
        # Axial: 192 slices of 224 x 192.
        stack = (192, 224, 192)
    else:
        # Fail early with a clear message instead of an UnboundLocalError later on.
        raise ValueError(f"mode must be 'c' or 'a', got {mode!r}")

    # Inference only: switch off dropout / batch-norm updates.
    model.eval()

    # inference_mode disables autograd tracking, so no explicit detach is needed.
    with torch.inference_mode():
        # Allocate the result buffer directly on the target device.
        output = torch.zeros(stack[0], 3, stack[1], stack[2], device=device)

        for i, v in enumerate(voxel):
            # Add batch and channel axes, then move the slice to the target device.
            image = torch.tensor(v.reshape(1, 1, stack[1], stack[2])).to(device)
            # Class probabilities over the channel dimension; the leading batch
            # axis of size 1 is dropped on assignment into the (3, H, W) slot.
            output[i] = torch.softmax(model(image), 1)

    return output
def hemisphere(voxel, hnet_c, hnet_a, device):
    """
    Build a dilated hemisphere label volume from two slice-wise network predictions.

    The normalised volume is sliced along two orientations, each orientation is
    scored by its own network via ``separate``, the two probability volumes are
    summed, and the per-voxel argmax gives labels 0, 1 or 2. Labels 1 and 2 are
    then each grown by five binary-dilation iterations.

    Args:
        voxel: Input volume passed to ``normalize``. The result is transposed
            with three axes (NumPy-style), so it is presumably array-like with a
            ``(192, 224, 192)`` layout -- TODO confirm against ``normalize``.
        hnet_c (torch.nn.Module): Network applied to the coronal slices.
        hnet_a (torch.nn.Module): Network applied to the transverse slices.
        device (torch.device): Device on which both networks are evaluated.

    Returns:
        numpy.ndarray: ``int16`` label volume holding 0, 1 or 2 after dilation.
    """
    normalized = normalize(voxel)

    # Score each orientation with its own network (slices run along axis 0).
    probs_coronal = separate(normalized.transpose(1, 2, 0), hnet_c, device, "c")
    probs_transverse = separate(normalized.transpose(2, 1, 0), hnet_a, device, "a")

    # Bring both outputs to a common (class, ...) layout before adding them.
    aligned_coronal = probs_coronal.permute(1, 3, 0, 2)
    aligned_transverse = probs_transverse.permute(1, 3, 2, 0)

    # Sum the two views and pick the most likely class per voxel; the summed
    # tensor is left unnamed so it can be released before the cache is cleared.
    labels = torch.argmax(aligned_coronal + aligned_transverse, 0).cpu().numpy()

    # Return unused cached GPU memory to the driver.
    torch.cuda.empty_cache()

    # Grow class 1, then stamp the original class-2 voxels back on top of it.
    grown_one = binary_dilation(labels == 1, iterations=5).astype("int16")
    grown_one[labels == 2] = 2

    # Grow class 2, then restore every voxel the previous step labelled 1.
    grown_two = binary_dilation(grown_one == 2, iterations=5).astype("int16") * 2
    grown_two[grown_one == 1] = 1

    return grown_two
|