File size: 2,559 Bytes
01f75cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcbe128
01f75cf
dcbe128
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import torch
from scipy.ndimage import binary_closing

from utils.functions import normalize, reimburse_conform


def crop(voxel, model, device):
    """
    Run ``model`` slice-by-slice over ``voxel`` and collect sigmoid probabilities.

    Each element of ``voxel`` along its first axis is treated as one 2D slice.
    The slice is fed to ``model`` as a (1, 1, 256, 256) tensor. The sigmoid of
    the model output is written into the matching slice of a 256x256x256
    output volume.

    Args:
        voxel (numpy.ndarray): Input volume of shape (N, 256, 256) with N <= 256.
            Output slices with index >= N are left at zero.
        model (torch.nn.Module): Network applied to each slice. As a side
            effect it is switched to eval mode.
        device (torch.device): Device on which the input slices and the output
            volume are placed.

    Returns:
        torch.Tensor: Probability volume of shape (256, 256, 256) on ``device``.
    """
    model.eval()
    with torch.inference_mode():
        # Allocate straight on the target device. This avoids building a
        # full-size tensor on the host and then copying it over.
        output = torch.zeros(256, 256, 256, device=device)
        for i, v in enumerate(voxel):
            image = torch.tensor(v.reshape(1, 1, 256, 256), device=device)
            # inference_mode already disables autograd tracking, so no
            # .detach() is needed. The per-slice model output (presumably
            # (1, 1, 256, 256) — TODO confirm) is copied into output[i].
            output[i] = torch.sigmoid(model(image))
    return output


def closing(voxel, iterations=3):
    """
    Perform a binary closing operation on a 3D voxel array.

    Closing is a dilation followed by an erosion, both using the same
    structuring element. Here that element is a full 3x3x3 cube
    (26-connectivity), and each step is repeated ``iterations`` times. The
    effect is to fill small holes and gaps in a binary mask.

    NOTE: scipy's default ``border_value=0`` treats everything outside the
    array as background. Foreground lying within ``iterations`` voxels of the
    array edge may therefore be trimmed.

    Parameters:
    voxel (numpy.ndarray): A 3D array interpreted as a binary mask
        (nonzero = True).
    iterations (int): Number of times the dilation and the erosion are each
        repeated. Defaults to 3, the value this function always used.

    Returns:
    numpy.ndarray: Boolean array of the same shape with the closing applied.
    """
    selem = np.ones((3, 3, 3), dtype="bool")
    return binary_closing(voxel, structure=selem, iterations=iterations)


def cropping(output_dir, basename, odata, data, cnet, device):
    """
    Mask an image volume with a network-predicted mask and save the mask.

    Steps:
    1. Normalize the volume.
    2. Pass it through ``cnet`` slice-by-slice along two different axes
       (via ``crop``).
    3. Bring both probability volumes back to the original axis order,
       average them, and threshold at 0.5.
    4. Smooth the result with a binary closing.
    5. Multiply the raw (un-normalized) intensities by the resulting mask.

    Args:
        output_dir: Forwarded to ``reimburse_conform``; presumably the output
            directory — TODO confirm.
        basename: Forwarded to ``reimburse_conform``; presumably the base name
            of the written file — TODO confirm.
        odata: Forwarded to ``reimburse_conform``; presumably the original
            (pre-conform) image — TODO confirm.
        data (nibabel.Nifti1Image): Input image. Its data is expected to be
            256x256x256, because ``crop`` works on 256x256 slices and returns
            a 256^3 volume.
        cnet (torch.nn.Module): Network applied to each 2D slice.
        device (torch.device): Device on which ``cnet`` is run.

    Returns:
        tuple: ``(cropped, out_filename)``.
            ``cropped`` is a float32 numpy.ndarray of the raw intensities
            multiplied by the boolean mask, so outside-mask voxels become 0.
            ``out_filename`` is whatever ``reimburse_conform`` returns.
    """
    voxel = normalize(data.get_fdata().astype("float32"))

    # Predict on slices taken along two different axes. transpose(1, 2, 0)
    # moves axis 0 last; permute(2, 0, 1) is its inverse. This restores the
    # original axis order so both predictions line up voxel-for-voxel.
    coronal = voxel.transpose(1, 2, 0)
    sagittal = voxel
    out_c = crop(coronal, cnet, device).permute(2, 0, 1)
    out_s = crop(sagittal, cnet, device)

    # Average the two probability maps and threshold into a boolean mask.
    out_e = ((out_c + out_s) / 2) > 0.5
    # Drop the two full-size probability volumes (possibly on the GPU) before
    # the CPU-only morphology step.
    del out_c, out_s
    out_e = out_e.cpu().numpy()
    out_e = closing(out_e)

    # Mask the raw intensities. get_fdata() is read again instead of reusing
    # `voxel`, because `voxel` holds normalized values.
    cropped = data.get_fdata().astype("float32") * out_e

    out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, out_e)

    return cropped, out_filename