import torch
from scipy.ndimage import binary_dilation

from utils.functions import normalize


def separate(voxel, model, device, mode):
    """
    Runs a 2-D model slice by slice over a volume and stacks the softmax outputs.

    Args:
        voxel (numpy.ndarray): Volume whose first axis indexes slices. It must hold
            N slices, each reshapeable to (H, W) (see Returns for N, H, W).
        model (torch.nn.Module): Network that maps a (1, 1, H, W) slice to
            (1, 3, H, W) class scores.
        device (torch.device): The device (CPU or GPU) on which inference runs.
        mode (str): The slicing view, either 'c' for coronal or 'a' for axial.

    Returns:
        torch.Tensor: Per-slice class probabilities of shape (N, 3, H, W) on
            ``device``, where (N, H, W) is (224, 192, 192) for mode 'c' and
            (192, 224, 192) for mode 'a'.

    Raises:
        ValueError: If ``mode`` is neither 'c' nor 'a'.
    """
    if mode == "c":
        # Stack dimensions (slices, height, width) for coronal mode
        stack = (224, 192, 192)
    elif mode == "a":
        # Stack dimensions (slices, height, width) for axial mode
        stack = (192, 224, 192)
    else:
        raise ValueError(f"mode must be 'c' (coronal) or 'a' (axial), got {mode!r}")

    # Set the model to evaluation mode
    model.eval()

    # Disable gradient calculation for inference
    with torch.inference_mode():
        # Initialize an output tensor with the specified stack dimensions
        output = torch.zeros(stack[0], 3, stack[1], stack[2], device=device)

        # Iterate over each slice in the voxel data
        for i, v in enumerate(voxel):
            # Reshape the slice to (batch=1, channel=1, H, W) and create it on the device
            image = torch.tensor(v.reshape(1, 1, stack[1], stack[2]), device=device)
            # Forward pass, then softmax over the class dimension (dim 1);
            # no .detach() is needed because inference_mode disables autograd
            x_out = torch.softmax(model(image), dim=1)
            # Drop the batch dimension and store the (3, H, W) probabilities for slice i
            output[i] = x_out[0]

        # Return the processed output tensor
        return output
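

# Illustrative sketch (hypothetical helper, not part of the original pipeline).
# It checks with NumPy alone the axis bookkeeping that hemisphere() relies on.
# separate() stacks per-slice outputs as (slices, classes, rows, cols). Writing
# the input volume's axes as (d0, d1, d2), the coronal view
# voxel.transpose(1, 2, 0) yields a stack of shape (d1, C, d2, d0), which
# permute(1, 3, 0, 2) turns back into (C, d0, d1, d2); the axial view
# voxel.transpose(2, 1, 0) yields (d2, C, d1, d0), which permute(1, 3, 2, 0)
# restores the same way. With the sizes hard-coded in separate(), this implies
# an input volume of shape (192, 224, 192). The stand-in "model" below simply
# copies each slice into every class channel, an assumption made only so the
# round trip can be checked exactly.
def _axis_roundtrip_ok(shape=(4, 5, 6), num_classes=3):
    """Return True if both views map back to (num_classes, *shape) in voxel order."""
    import numpy as np  # local import keeps the sketch self-contained

    voxel = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    def fake_separate(view):
        # Mimic separate()'s layout: one (classes, rows, cols) block per slice.
        return np.stack([np.repeat(s[None], num_classes, axis=0) for s in view])

    # ndarray.transpose with explicit axes follows the same convention as torch.permute.
    out_c = fake_separate(voxel.transpose(1, 2, 0)).transpose(1, 3, 0, 2)
    out_a = fake_separate(voxel.transpose(2, 1, 0)).transpose(1, 3, 2, 0)
    expected = np.repeat(voxel[None], num_classes, axis=0)
    return bool(np.array_equal(out_c, expected) and np.array_equal(out_a, expected))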


def hemisphere(voxel, hnet_c, hnet_a, device):
    """
    Segments the two hemispheres by ensembling coronal and axial models, then dilates each label.

    Args:
        voxel (numpy.ndarray): The input volume. The stack sizes in separate()
            imply a shape of (192, 224, 192). It is normalized before inference.
        hnet_c (torch.nn.Module): The model applied to coronal slices.
        hnet_a (torch.nn.Module): The model applied to axial (transverse) slices.
        device (torch.device): The device to run the models on (e.g., 'cpu' or 'cuda').

    Returns:
        numpy.ndarray: An int16 label map with values 0, 1, and 2, in which
            classes 1 and 2 have each been dilated by 5 voxels.
    """
    # Normalize the voxel data
    voxel = normalize(voxel)

    # Transpose the voxel data for coronal and transverse views
    coronal = voxel.transpose(1, 2, 0)
    transverse = voxel.transpose(2, 1, 0)

    # Run each view's model, then permute both outputs back to (class, d0, d1, d2) in voxel axis order
    out_c = separate(coronal, hnet_c, device, "c").permute(1, 3, 0, 2)
    out_a = separate(transverse, hnet_a, device, "a").permute(1, 3, 2, 0)

    # Ensemble the two views by summing their per-class softmax probabilities
    out_e = out_c + out_a

    # Pick the most probable class per voxel (argmax over the class dimension) and move the result to NumPy
    out_e = torch.argmax(out_e, 0).cpu().numpy()

    # Release unused cached GPU memory held by PyTorch's caching allocator
    torch.cuda.empty_cache()

    # Pass 1: grow class 1 by 5 voxels; original class-2 voxels keep label 2
    dilated_mask_1 = binary_dilation(out_e == 1, iterations=5).astype("int16")
    dilated_mask_1[out_e == 2] = 2

    # Pass 2: grow class 2 by 5 voxels; voxels labeled 1 after pass 1 keep label 1,
    # so class 1 wins wherever the two grown regions overlap
    dilated_mask_2 = binary_dilation(dilated_mask_1 == 2, iterations=5).astype("int16") * 2
    dilated_mask_2[dilated_mask_1 == 1] = 1

    # Return the final dilated mask
    return dilated_mask_2
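

# Illustrative sketch (hypothetical helper, not part of the original pipeline).
# The two dilation passes at the end of hemisphere() are not symmetric: class 1
# is grown first but may not overwrite original class-2 voxels, and class 2 is
# then grown only into voxels not already labeled 1, so class 1 wins any
# background voxel both regions reach. This helper repeats the same two passes
# with a configurable iteration count (the original hard-codes 5) so the
# precedence can be checked on a tiny 1-D label array. It uses the same
# scipy.ndimage.binary_dilation call as the module, with its default
# structuring element.
def _dilate_labels(labels, iterations=5):
    """Apply hemisphere()'s two-pass dilation to an integer label array."""
    import numpy as np  # local imports keep the sketch self-contained
    from scipy.ndimage import binary_dilation

    labels = np.asarray(labels)
    # Pass 1: grow class 1, then restore the original class-2 voxels.
    mask_1 = binary_dilation(labels == 1, iterations=iterations).astype("int16")
    mask_1[labels == 2] = 2
    # Pass 2: grow class 2, then restore every voxel labeled 1 after pass 1.
    mask_2 = binary_dilation(mask_1 == 2, iterations=iterations).astype("int16") * 2
    mask_2[mask_1 == 1] = 1
    return mask_2

# For example, _dilate_labels([1, 0, 2], iterations=1) returns [1, 1, 2]:
# the middle voxel is reached by both classes and goes to class 1.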