import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import nibabel as nib
from skimage import morphology
import math
from scipy import ndimage
from medpy import metric
import h5py
from tqdm import tqdm


class ConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            if i == 0:
                input_channel = n_filters_in
            else:
                input_channel = n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class DownsamplingConvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class UpsamplingDeconvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class Upsampling(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(Upsampling, self).__init__()

        ops = []
        ops.append(nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False))
        ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x
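
# Illustrative sketch (added here for documentation; not part of the original
# pipeline): ConvBlock is shape-preserving (3x3x3 convs, padding=1), while the
# Downsampling*/Upsampling* blocks scale spatial size by `stride` (the conv
# kernel size equals the stride). A quick check on a random patch:
def _demo_block_shapes():
    x = torch.randn(1, 1, 16, 16, 16)
    conv = ConvBlock(2, 1, 8, normalization='instancenorm')
    down = DownsamplingConvBlock(8, 16, stride=2, normalization='instancenorm')
    up = UpsamplingDeconvBlock(16, 8, stride=2, normalization='instancenorm')
    y = conv(x)        # channels 1 -> 8, spatial size unchanged
    y_dw = down(y)     # spatial size halved
    y_up = up(y_dw)    # spatial size restored
    assert y.shape == (1, 8, 16, 16, 16)
    assert y_dw.shape == (1, 16, 8, 8, 8)
    assert y_up.shape == (1, 8, 16, 16, 16)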

class ConnectNet(nn.Module):
    def __init__(self, in_channels, out_channels, input_size):
        super(ConnectNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2),
            nn.Conv3d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(64, 128, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose3d(128, out_channels, kernel_size=2, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


class VNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

        self.__init_weight()

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)

        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]
        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x9 = self.block_nine(x8_up)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out

    def __init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
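
# Illustrative sketch: VNet downsamples four times (an overall factor of 16),
# and the additive skip connections in `decoder` require every spatial
# dimension to be a multiple of 16 -- which is why `img_crop` further down
# trims the ROI to multiples of 16. A hypothetical CPU smoke test:
def _demo_vnet_shapes():
    net = VNet(n_channels=1, n_classes=2, n_filters=16, normalization='batchnorm', has_dropout=True)
    x = torch.randn(1, 1, 32, 32, 32)    # all dims divisible by 16
    out = net(x, turnoff_drop=True)      # logits, one channel per class
    assert out.shape == (1, 2, 32, 32, 32)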

class VNet_roi(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet_roi, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)
        # self.__init_weight()

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        # x5 = F.dropout3d(x5, p=0.5, training=True)
        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]
        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x9 = self.block_nine(x8_up)
        # x9 = F.dropout3d(x9, p=0.5, training=True)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out
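
# Illustrative sketch: with `has_dropout` set, Dropout3d sits in the encoder
# and decoder paths of VNet/VNet_roi, so repeated forward passes in train()
# mode are stochastic (a Monte-Carlo-dropout-style pattern, hinted at by the
# commented-out F.dropout3d calls above), while `turnoff_drop=True` bypasses
# the dropout layers for one deterministic pass. Hypothetical usage:
def _demo_mc_dropout(net, x, n_samples=4):
    net.train()                                                   # keep Dropout3d active
    probs = [F.softmax(net(x), dim=1) for _ in range(n_samples)]  # stochastic passes
    mean_prob = torch.stack(probs).mean(dim=0)                    # MC average
    det_out = net(x, turnoff_drop=True)                           # dropout bypassed
    return mean_prob, det_out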

class ResVNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False):
        super(ResVNet, self).__init__()
        self.resencoder = resnet34()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        if has_dropout:
            self.dropout = nn.Dropout3d(p=0.5)

        self.branchs = nn.ModuleList()
        for i in range(1):
            if has_dropout:
                seq = nn.Sequential(
                    ConvBlock(1, n_filters, n_filters, normalization=normalization),
                    nn.Dropout3d(p=0.5),
                    nn.Conv3d(n_filters, n_classes, 1, padding=0)
                )
            else:
                seq = nn.Sequential(
                    ConvBlock(1, n_filters, n_filters, normalization=normalization),
                    nn.Conv3d(n_filters, n_classes, 1, padding=0)
                )
            self.branchs.append(seq)

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)

        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]
        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x9 = self.block_nine(x8_up)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.resencoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out


__all__ = ['ResNet', 'resnet34']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
        conv3x3(in_planes, out_planes, stride),
        nn.InstanceNorm3d(out_planes),
        nn.ReLU()
    )
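
# Illustrative sketch: ResVNet swaps the VNet encoder for the 3D resnet34
# defined below; its five feature maps match the channel widths (16, 32, 64,
# 128, 256 for n_filters=16) and strides (1, 2, 4, 8, 16) that `decoder`
# expects from the skip connections. Hypothetical shape check (resnet34 is
# defined further down, so the name only resolves at call time):
def _demo_resnet_encoder_shapes():
    enc = resnet34()  # BasicBlock, layers [3, 4, 6, 3]
    feats = enc(torch.randn(1, 1, 32, 32, 32))
    for f, ch, scale in zip(feats, (16, 32, 64, 128, 256), (1, 2, 4, 8, 16)):
        assert f.shape[1] == ch and f.shape[2] == 32 // scale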

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1):
        super(BasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.InstanceNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.InstanceNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1):
        super(Bottleneck, self).__init__()
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = nn.Conv3d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = nn.InstanceNorm3d(width)
        self.conv2 = nn.Conv3d(width, width, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, groups=groups, bias=False)
        self.bn2 = nn.InstanceNorm3d(width)
        self.conv3 = nn.Conv3d(width, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.InstanceNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, in_channel=1, width=1, groups=1, width_per_group=64,
                 mid_dim=1024, low_dim=128, avg_down=False, deep_stem=False,
                 head_type='mlp_head', layer4_dilation=1):
        super(ResNet, self).__init__()
        self.avg_down = avg_down
        self.inplanes = 16 * width
        self.base = int(16 * width)
        self.groups = groups
        self.base_width = width_per_group

        mid_dim = self.base * 8 * block.expansion

        if deep_stem:
            self.conv1 = nn.Sequential(
                conv3x3_bn_relu(in_channel, 32, stride=2),
                conv3x3_bn_relu(32, 32, stride=1),
                conv3x3(32, 64, stride=1)
            )
        else:
            self.conv1 = nn.Conv3d(in_channel, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)

        self.bn1 = nn.InstanceNorm3d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, self.base * 2, layers[0], stride=2)
        self.layer2 = self._make_layer(block, self.base * 4, layers[1], stride=2)
        self.layer3 = self._make_layer(block, self.base * 8, layers[2], stride=2)
        if layer4_dilation == 1:
            self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=2)
        elif layer4_dilation == 2:
            self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=1, dilation=2)
        else:
            raise NotImplementedError
        self.avgpool = nn.AvgPool3d(7, stride=1)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if self.avg_down:
                downsample = nn.Sequential(
                    nn.AvgPool3d(kernel_size=stride, stride=stride),
                    nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False),
                    nn.InstanceNorm3d(planes * block.expansion),
                )
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                    nn.InstanceNorm3d(planes * block.expansion),
                )

        layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, dilation)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # c2 = self.maxpool(x)
        c2 = self.layer1(x)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        return [x, c2, c3, c4, c5]


def resnet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def label_rescale(image_label, w_ori, h_ori, z_ori, flag):
    w_ori, h_ori, z_ori = int(w_ori), int(h_ori), int(z_ori)
    # resize label map (int)
    if flag == 'trilinear':
        teeth_ids = np.unique(image_label)
        image_label_ori = np.zeros((w_ori, h_ori, z_ori))
        image_label = torch.from_numpy(image_label).cuda(0)
        for label_id in range(len(teeth_ids)):
            image_label_bn = (image_label == teeth_ids[label_id]).float()
            image_label_bn = image_label_bn[None, None, :, :, :]
            image_label_bn = torch.nn.functional.interpolate(image_label_bn, size=(w_ori, h_ori, z_ori),
                                                             mode='trilinear', align_corners=False)
            image_label_bn = image_label_bn[0, 0, :, :, :]
            image_label_bn = image_label_bn.cpu().data.numpy()
            image_label_ori[image_label_bn > 0.5] = teeth_ids[label_id]
        image_label = image_label_ori

    if flag == 'nearest':
        image_label = torch.from_numpy(image_label).cuda(0)
        image_label = image_label[None, None, :, :, :].float()
        image_label = torch.nn.functional.interpolate(image_label, size=(w_ori, h_ori, z_ori), mode='nearest')
        image_label = image_label[0, 0, :, :, :].cpu().data.numpy()
    return image_label


def img_crop(image_bbox):
    if image_bbox.sum() > 0:
        x_min = np.nonzero(image_bbox)[0].min() - 8
        x_max = np.nonzero(image_bbox)[0].max() + 8
        y_min = np.nonzero(image_bbox)[1].min() - 16
        y_max = np.nonzero(image_bbox)[1].max() + 16
        z_min = np.nonzero(image_bbox)[2].min() - 16
        z_max = np.nonzero(image_bbox)[2].max() + 16

        if x_min < 0:
            x_min = 0
        if y_min < 0:
            y_min = 0
        if z_min < 0:
            z_min = 0
        if x_max > image_bbox.shape[0]:
            x_max = image_bbox.shape[0]
        if y_max > image_bbox.shape[1]:
            y_max = image_bbox.shape[1]
        if z_max > image_bbox.shape[2]:
            z_max = image_bbox.shape[2]

        if (x_max - x_min) % 16 != 0:
            x_max -= (x_max - x_min) % 16
        if (y_max - y_min) % 16 != 0:
            y_max -= (y_max - y_min) % 16
        if (z_max - z_min) % 16 != 0:
            z_max -= (z_max - z_min) % 16

    if image_bbox.sum() == 0:
        x_min, x_max, y_min, y_max, z_min, z_max = -1, image_bbox.shape[0], 0, image_bbox.shape[1], 0, image_bbox.shape[2]
    return x_min, x_max, y_min, y_max, z_min, z_max
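
# Illustrative sketch: img_crop pads the foreground bounding box with fixed
# margins (8 voxels in x, 16 in y/z), clamps it to the volume, then shrinks
# each side so every extent is a multiple of 16 (the VNet downsampling
# factor). A hypothetical check on a synthetic mask:
def _demo_img_crop():
    mask = np.zeros((128, 128, 128))
    mask[40:80, 40:80, 40:80] = 1
    x_min, x_max, y_min, y_max, z_min, z_max = img_crop(mask)
    assert (x_max - x_min) % 16 == 0
    assert (y_max - y_min) % 16 == 0
    assert (z_max - z_min) % 16 == 0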

def roi_extraction(image, net_roi, ids):
    w, h, d = image.shape
    # roi binary segmentation parameters, the input spacing is 0.4 mm
    print('---run the roi binary segmentation.')
    stride_xy = 32
    stride_z = 16
    patch_size_roi_stage = (112, 112, 80)
    label_roi = roi_detection(net_roi, image[0:w:2, 0:h:2, 0:d:2], stride_xy, stride_z,
                              patch_size_roi_stage)  # (400,400,200)
    print(label_roi.shape, np.max(label_roi))
    label_roi = label_rescale(label_roi, w, h, d, 'trilinear')  # (800,800,400)
    label_roi = morphology.remove_small_objects(label_roi.astype(bool), 5000, connectivity=3).astype(float)
    label_roi = ndimage.grey_dilation(label_roi, size=(5, 5, 5))
    label_roi = morphology.remove_small_objects(label_roi.astype(bool), 400000, connectivity=3).astype(float)
    label_roi = ndimage.grey_erosion(label_roi, size=(5, 5, 5))

    # crop image
    x_min, x_max, y_min, y_max, z_min, z_max = img_crop(label_roi)
    if x_min == -1:  # non-foreground label
        # NOTE: this early return hands back a bare zero mask instead of the
        # 7-tuple that callers unpack; it only fires when the ROI stage finds
        # no foreground at all.
        whole_label = np.zeros((w, h, d))
        return whole_label
    image = image[x_min:x_max, y_min:y_max, z_min:z_max]
    print("image shape(after roi): ", image.shape)
    return image, x_min, x_max, y_min, y_max, z_min, z_max


def roi_detection(net, image, stride_xy, stride_z, patch_size):
    w, h, d = image.shape  # (400,400,200)

    # if the size of image is less than patch_size, then padding it
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0] - w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1] - h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2] - d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)],
                       mode='constant', constant_values=0)
    ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1  # 2
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1  # 2
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1  # 2
    score_map = np.zeros((2,) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)
    count = 0
    for x in range(0, sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd - patch_size[2])
                test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda(0)
                with torch.no_grad():
                    y1 = net(test_patch)  # (1,2,256,256,160)
                    y = F.softmax(y1, dim=1)  # (1,2,256,256,160)
                y = y.cpu().data.numpy()
                y = y[0, :, :, :, :]  # (2,256,256,160)
                score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y  # (2,400,400,200)
                cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1  # (400,400,200)
                count = count + 1
    score_map = score_map / np.expand_dims(cnt, axis=0)
    label_map = np.argmax(score_map, axis=0)  # (400,400,200), 0/1
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map
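
# Illustrative sketch: the sliding-window loops above place
# ceil((dim - patch) / stride) + 1 windows per axis, with the last window
# clamped to the volume border via min(stride * i, dim - patch), so every
# voxel is covered at least once and overlaps are averaged through `cnt`.
# Hypothetical arithmetic for the ROI-stage defaults:
def _demo_window_counts(ww=400, hh=400, dd=200, patch=(112, 112, 80), stride_xy=32, stride_z=16):
    sx = math.ceil((ww - patch[0]) / stride_xy) + 1  # 10 windows along x
    sy = math.ceil((hh - patch[1]) / stride_xy) + 1  # 10 windows along y
    sz = math.ceil((dd - patch[2]) / stride_z) + 1   # 9 windows along z
    return sx, sy, sz                                # 10 * 10 * 9 = 900 patches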

def test_single_case_array(model_array, image=None, stride_xy=None, stride_z=None, patch_size=None, num_classes=1):
    w, h, d = image.shape

    # if the size of image is less than patch_size, then padding it
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0] - w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1] - h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2] - d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)],
                       mode='constant', constant_values=0)
    ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)

    for x in range(0, sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd - patch_size[2])
                test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda()
                # average the softmax predictions of the ensemble; `y` must be
                # (re)initialized for every patch before the accumulation loop
                y = np.zeros((num_classes, patch_size[0], patch_size[1], patch_size[2]), dtype=np.float32)
                with torch.no_grad():
                    for model in model_array:
                        output = model(test_patch)
                        y_temp = F.softmax(output, dim=1)
                        y_temp = y_temp.cpu().data.numpy()
                        y += y_temp[0, :, :, :, :]
                y /= len(model_array)
                score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y
                cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1
    score_map = score_map / np.expand_dims(cnt, axis=0)
    label_map = np.argmax(score_map, axis=0)
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map, score_map


def calculate_metric_percase(pred, gt):
    dice = metric.binary.dc(pred, gt)
    jc = metric.binary.jc(pred, gt)
    hd = metric.binary.hd95(pred, gt)
    asd = metric.binary.asd(pred, gt)
    return dice, jc, hd, asd


class RailNetSystem(nn.Module, PyTorchModelHubMixin):
    def __init__(self, n_channels: int, n_classes: int, normalization: str):
        super().__init__()
        self.num_classes = 2

        self.net_roi = VNet_roi(n_channels=n_channels, n_classes=n_classes,
                                normalization=normalization, has_dropout=False).cuda()

        self.model_array = []
        for i in range(4):
            if i < 2:
                model = VNet(n_channels=n_channels, n_classes=n_classes,
                             normalization=normalization, has_dropout=True).cuda()
            else:
                model = ResVNet(n_channels=n_channels, n_classes=n_classes,
                                normalization=normalization, has_dropout=True).cuda()
            self.model_array.append(model)

    def load_weights(self, weight_dir="."):
        self.net_roi.load_state_dict(torch.load(os.path.join(weight_dir, "model weights", "roi_best_model.pth"),
                                                map_location="cuda", weights_only=True))
        self.net_roi.eval()

        model_files = [
            "rail_0_iter_7995_best.pth",
            "rail_1_iter_7995_best.pth",
            "rail_2_iter_7995_best.pth",
            "rail_3_iter_7995_best.pth",
        ]
        for i, file in enumerate(model_files):
            self.model_array[i].load_state_dict(torch.load(os.path.join(weight_dir, "model weights", file),
                                                           map_location="cuda", weights_only=True))
            self.model_array[i].eval()

    def forward(self, image, label, save_path="./output", name="case"):
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        nib.save(nib.Nifti1Image(image.astype(np.float32), np.eye(4)),
                 os.path.join(save_path, f"{name}_img.nii.gz"))

        w, h, d = image.shape
        image, x_min, x_max, y_min, y_max, z_min, z_max = roi_extraction(image, self.net_roi, name)

        prediction, _ = test_single_case_array(self.model_array, image, stride_xy=64, stride_z=32,
                                               patch_size=(112, 112, 80), num_classes=self.num_classes)
        prediction = morphology.remove_small_objects(prediction.astype(bool), 3000, connectivity=3).astype(float)

        new_prediction = np.zeros((w, h, d))
        new_prediction[x_min:x_max, y_min:y_max, z_min:z_max] = prediction

        dice, jc, hd, asd = calculate_metric_percase(new_prediction, label[:])

        nib.save(nib.Nifti1Image(new_prediction.astype(np.float32), np.eye(4)),
                 os.path.join(save_path, f"{name}_pred.nii.gz"))
        return new_prediction, dice, jc, hd, asd
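
# Illustrative end-to-end sketch, assuming local checkpoint files under
# "./model weights" as expected by load_weights. The HDF5 file name and its
# "image"/"label" dataset keys are hypothetical, as is normalization="batchnorm";
# RailNetSystem also inherits `from_pretrained` from PyTorchModelHubMixin,
# which could load a published checkpoint instead. Requires a CUDA device.
if __name__ == "__main__":
    net = RailNetSystem(n_channels=1, n_classes=2, normalization="batchnorm")
    net.load_weights(weight_dir=".")
    with h5py.File("case_001.h5", "r") as f:  # hypothetical test volume
        image, label = f["image"][:], f["label"][:]
    pred, dice, jc, hd, asd = net(image, label, save_path="./output", name="case_001")
    print(f"dice={dice:.4f} jaccard={jc:.4f} hd95={hd:.2f} asd={asd:.2f}")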