# Tournesol-Saturday's picture
# Upload folder using huggingface_hub
# 4c1d50f verified
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import nibabel as nib
from skimage import morphology
import math
from scipy import ndimage
from medpy import metric
import h5py
from tqdm import tqdm
class ConvBlock(nn.Module):
    """Stack of 3x3x3 convolutions, each followed by optional normalization and ReLU.

    Args:
        n_stages: Number of conv -> (norm) -> ReLU stages to stack.
        n_filters_in: Input channels of the first stage.
        n_filters_out: Output channels of every stage (and input channels of
            every stage after the first).
        normalization: One of 'none', 'batchnorm', 'groupnorm', 'instancenorm'.

    Raises:
        AssertionError: If ``normalization`` is not one of the names above.
    """

    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ConvBlock, self).__init__()
        ops = []
        for i in range(n_stages):
            input_channel = n_filters_in if i == 0 else n_filters_out
            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                # Raised explicitly instead of `assert False` so the check is
                # not stripped when Python runs with -O.
                raise AssertionError(
                    "unsupported normalization: {!r}".format(normalization))
            ops.append(nn.ReLU(inplace=True))
        # Attribute name and layer order are kept so state_dict keys
        # ("conv.<index>.*") stay the same.
        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Apply the stacked stages to ``x`` and return the result."""
        return self.conv(x)
class DownsamplingConvBlock(nn.Module):
    """Strided 3D convolution (kernel size == stride, no padding) + optional norm + ReLU.

    Args:
        n_filters_in: Input channels.
        n_filters_out: Output channels.
        stride: Used both as the kernel size and as the stride of the convolution.
        normalization: One of 'none', 'batchnorm', 'groupnorm', 'instancenorm'.

    Raises:
        AssertionError: If ``normalization`` is not one of the names above.
    """

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()
        # The convolution is identical for every normalization choice, so it is
        # created once up front.
        ops = [nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)]
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            # Explicit raise so the validation is not stripped under -O.
            raise AssertionError(
                "unsupported normalization: {!r}".format(normalization))
        ops.append(nn.ReLU(inplace=True))
        # Same attribute name and layer order as before, so checkpoint keys match.
        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Return the downsampled feature map for ``x``."""
        return self.conv(x)
class UpsamplingDeconvBlock(nn.Module):
    """Transposed 3D convolution (kernel size == stride, no padding) + optional norm + ReLU.

    Args:
        n_filters_in: Input channels.
        n_filters_out: Output channels.
        stride: Used both as the kernel size and as the stride of the
            transposed convolution.
        normalization: One of 'none', 'batchnorm', 'groupnorm', 'instancenorm'.

    Raises:
        AssertionError: If ``normalization`` is not one of the names above.
    """

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()
        # The transposed convolution does not depend on the normalization
        # choice, so it is created once up front.
        ops = [nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)]
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            # Explicit raise so the validation is not stripped under -O.
            raise AssertionError(
                "unsupported normalization: {!r}".format(normalization))
        ops.append(nn.ReLU(inplace=True))
        # Same attribute name and layer order as before, so checkpoint keys match.
        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Return the upsampled feature map for ``x``."""
        return self.conv(x)
class Upsampling(nn.Module):
    """Trilinear upsampling followed by a 3x3x3 convolution, optional norm and ReLU.

    Args:
        n_filters_in: Input channels of the convolution.
        n_filters_out: Output channels of the convolution.
        stride: Scale factor passed to ``nn.Upsample``.
        normalization: One of 'none', 'batchnorm', 'groupnorm', 'instancenorm'.

    Raises:
        AssertionError: If ``normalization`` is not one of the names above.
    """

    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(Upsampling, self).__init__()
        ops = [
            nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False),
            nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1),
        ]
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            # Explicit raise so the validation is not stripped under -O.
            raise AssertionError(
                "unsupported normalization: {!r}".format(normalization))
        ops.append(nn.ReLU(inplace=True))
        # Same attribute name and layer order as before, so checkpoint keys match.
        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        """Return the upsampled and convolved feature map for ``x``."""
        return self.conv(x)
class ConnectNet(nn.Module):
    """Small 3D convolutional encoder/decoder ending in a sigmoid.

    The encoder halves the spatial size twice with max-pooling; the decoder
    doubles it twice with transposed convolutions.

    Args:
        in_channels: Channels of the input volume.
        out_channels: Channels of the sigmoid output.
        input_size: Accepted for interface compatibility; not read by this class.
    """

    def __init__(self, in_channels, out_channels, input_size):
        super().__init__()
        down_layers = [
            nn.Conv3d(in_channels, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2),
            nn.Conv3d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2),
        ]
        up_layers = [
            nn.ConvTranspose3d(64, 128, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose3d(128, out_channels, kernel_size=2, stride=2),
            nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*down_layers)
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode then decode ``x``; the output values lie in [0, 1]."""
        return self.decoder(self.encoder(x))
class VNet(nn.Module):
    """V-Net style 3D encoder/decoder with additive skip connections.

    The encoder has five resolution levels (four strided downsamplings); the
    decoder mirrors it with transposed convolutions and adds the encoder
    feature of the matching level after every upsampling.

    Args:
        n_channels: Channels of the input volume.
        n_classes: Channels of the output logits.
        n_filters: Channel width of the first level; doubled at each deeper level.
        normalization: Forwarded to every conv/down/up block.
        has_dropout: When True, ``nn.Dropout3d(p=0.5)`` is applied to the
            bottleneck feature and to the last decoder feature.
    """

    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet, self).__init__()
        self.has_dropout = has_dropout

        # Encoder path.
        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)

        # Decoder path.
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)

        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
        self.dropout = nn.Dropout3d(p=0.5, inplace=False)
        self.__init_weight()

    def encoder(self, input):
        """Run the contracting path and return ``[x1, x2, x3, x4, x5]``,
        the features of the five levels from shallowest to deepest."""
        x1 = self.block_one(input)
        x2 = self.block_two(self.block_one_dw(x1))
        x3 = self.block_three(self.block_two_dw(x2))
        x4 = self.block_four(self.block_three_dw(x3))
        x5 = self.block_five(self.block_four_dw(x4))
        if self.has_dropout:
            x5 = self.dropout(x5)
        return [x1, x2, x3, x4, x5]

    def decoder(self, features):
        """Run the expanding path on the five encoder features and return the logits."""
        x1, x2, x3, x4, x5 = features[:5]

        # Each step: upsample, add the encoder feature of that level, refine.
        x6 = self.block_six(self.block_five_up(x5) + x4)
        x7 = self.block_seven(self.block_six_up(x6) + x3)
        x8 = self.block_eight(self.block_seven_up(x7) + x2)
        x9 = self.block_nine(self.block_eight_up(x8) + x1)
        if self.has_dropout:
            x9 = self.dropout(x9)
        return self.out_conv(x9)

    def forward(self, input, turnoff_drop=False):
        """Return the segmentation logits for ``input``.

        Args:
            input: Input volume batch.
            turnoff_drop: When True, dropout is skipped for this call only;
                ``self.has_dropout`` is restored afterwards.
        """
        if not turnoff_drop:
            return self.decoder(self.encoder(input))

        saved_flag = self.has_dropout
        self.has_dropout = False
        try:
            return self.decoder(self.encoder(input))
        finally:
            # Restore even if the pass raises, so a failed call cannot leave
            # dropout permanently disabled.
            self.has_dropout = saved_flag

    def __init_weight(self):
        """Kaiming-normal init for conv weights; BatchNorm reset to weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
class VNet_roi(nn.Module):
    """V-Net style 3D encoder/decoder used for the ROI segmentation stage.

    Same layer layout as ``VNet`` (five encoder levels, additive skip
    connections in the decoder), but no custom weight initialization is run:
    layers keep PyTorch's default initialization.

    Args:
        n_channels: Channels of the input volume.
        n_classes: Channels of the output logits.
        n_filters: Channel width of the first level; doubled at each deeper level.
        normalization: Forwarded to every conv/down/up block.
        has_dropout: When True, ``nn.Dropout3d(p=0.5)`` is applied to the
            bottleneck feature and to the last decoder feature.
    """

    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet_roi, self).__init__()
        self.has_dropout = has_dropout

        # Encoder path.
        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)

        # Decoder path.
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)

        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

    def encoder(self, input):
        """Run the contracting path and return ``[x1, x2, x3, x4, x5]``,
        the features of the five levels from shallowest to deepest."""
        x1 = self.block_one(input)
        x2 = self.block_two(self.block_one_dw(x1))
        x3 = self.block_three(self.block_two_dw(x2))
        x4 = self.block_four(self.block_three_dw(x3))
        x5 = self.block_five(self.block_four_dw(x4))
        if self.has_dropout:
            x5 = self.dropout(x5)
        return [x1, x2, x3, x4, x5]

    def decoder(self, features):
        """Run the expanding path on the five encoder features and return the logits."""
        x1, x2, x3, x4, x5 = features[:5]

        # Each step: upsample, add the encoder feature of that level, refine.
        x6 = self.block_six(self.block_five_up(x5) + x4)
        x7 = self.block_seven(self.block_six_up(x6) + x3)
        x8 = self.block_eight(self.block_seven_up(x7) + x2)
        x9 = self.block_nine(self.block_eight_up(x8) + x1)
        if self.has_dropout:
            x9 = self.dropout(x9)
        return self.out_conv(x9)

    def forward(self, input, turnoff_drop=False):
        """Return the segmentation logits for ``input``.

        Args:
            input: Input volume batch.
            turnoff_drop: When True, dropout is skipped for this call only;
                ``self.has_dropout`` is restored afterwards.
        """
        if not turnoff_drop:
            return self.decoder(self.encoder(input))

        saved_flag = self.has_dropout
        self.has_dropout = False
        try:
            return self.decoder(self.encoder(input))
        finally:
            # Restore even if the pass raises, so a failed call cannot leave
            # dropout permanently disabled.
            self.has_dropout = saved_flag
class ResVNet(nn.Module):
    """V-Net decoder driven by a 3D ResNet-34 encoder.

    ``forward`` takes its five feature maps from ``self.resencoder`` and feeds
    them to the V-Net style decoder. The V-Net encoder blocks
    (``block_one`` .. ``block_five``) and ``branchs`` are still constructed,
    so they remain part of the state dict, but ``forward`` does not call them.

    Args:
        n_channels: Channels of the input volume (also used for the ResNet stem).
        n_classes: Channels of the output logits.
        n_filters: Channel width of the first decoder level. The ResNet
            encoder built here emits 16/32/64/128/256 channels, which matches
            the decoder only for ``n_filters=16``.
        normalization: Forwarded to every conv/down/up block.
        has_dropout: Controls whether ``self.dropout`` exists and whether the
            extra branch contains a ``Dropout3d`` layer.
    """

    def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False):
        super(ResVNet, self).__init__()
        # The encoder receives the raw input, so its stem must accept
        # ``n_channels`` channels (previously it was always built for 1).
        self.resencoder = resnet34(in_channel=n_channels)
        self.has_dropout = has_dropout

        # V-Net encoder blocks: kept for checkpoint compatibility.
        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)

        # Decoder path.
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        if has_dropout:
            self.dropout = nn.Dropout3d(p=0.5)

        # Single extra output branch; the layer order inside the Sequential is
        # kept so its state_dict indices do not change.
        self.branchs = nn.ModuleList()
        for i in range(1):
            if has_dropout:
                seq = nn.Sequential(
                    ConvBlock(1, n_filters, n_filters, normalization=normalization),
                    nn.Dropout3d(p=0.5),
                    nn.Conv3d(n_filters, n_classes, 1, padding=0)
                )
            else:
                seq = nn.Sequential(
                    ConvBlock(1, n_filters, n_filters, normalization=normalization),
                    nn.Conv3d(n_filters, n_classes, 1, padding=0)
                )
            self.branchs.append(seq)

    def encoder(self, input):
        """Run the V-Net contracting path and return ``[x1, x2, x3, x4, x5]``.

        Not called by ``forward``, which uses ``self.resencoder`` instead.
        """
        x1 = self.block_one(input)
        x2 = self.block_two(self.block_one_dw(x1))
        x3 = self.block_three(self.block_two_dw(x2))
        x4 = self.block_four(self.block_three_dw(x3))
        x5 = self.block_five(self.block_four_dw(x4))
        if self.has_dropout:
            x5 = self.dropout(x5)
        return [x1, x2, x3, x4, x5]

    def decoder(self, features):
        """Run the expanding path on five encoder features and return the logits.

        No dropout is applied in this decoder.
        """
        x1, x2, x3, x4, x5 = features[:5]

        # Each step: upsample, add the encoder feature of that level, refine.
        x6 = self.block_six(self.block_five_up(x5) + x4)
        x7 = self.block_seven(self.block_six_up(x6) + x3)
        x8 = self.block_eight(self.block_seven_up(x7) + x2)
        x9 = self.block_nine(self.block_eight_up(x8) + x1)
        return self.out_conv(x9)

    def forward(self, input, turnoff_drop=False):
        """Return the segmentation logits for ``input``.

        Args:
            input: Input volume batch.
            turnoff_drop: When True, ``self.has_dropout`` is set to False for
                the duration of this call and restored afterwards.
        """
        if not turnoff_drop:
            return self.decoder(self.resencoder(input))

        saved_flag = self.has_dropout
        self.has_dropout = False
        try:
            return self.decoder(self.resencoder(input))
        finally:
            # Restore even if the pass raises.
            self.has_dropout = saved_flag
# Names exported by ``from <this module> import *``: only the ResNet backbone
# class and its resnet34 factory are listed, so every other definition in this
# file must be imported by name.
__all__ = ['ResNet', 'resnet34']
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3x3 ``nn.Conv3d`` with padding 1.

    Args:
        in_planes: Input channels.
        out_planes: Output channels.
        stride: Convolution stride.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_kwargs)
def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    """Return ``conv3x3 -> InstanceNorm3d -> ReLU`` as an ``nn.Sequential``.

    Despite the ``bn`` in the name, the normalization layer is instance norm.
    """
    stages = [
        conv3x3(in_planes, out_planes, stride),
        nn.InstanceNorm3d(out_planes),
        nn.ReLU(),
    ]
    return nn.Sequential(*stages)
class BasicBlock(nn.Module):
    """Two-convolution residual block with instance normalization.

    Args:
        inplanes: Input channels.
        planes: Output channels of both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to form the shortcut.
        groups: Must be 1.
        base_width: Must be 64.
        dilation: Accepted for signature compatibility; not read by this block.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
    """

    # Channel multiplier of the block output relative to ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=-1):
        super().__init__()
        supported = groups == 1 and base_width == 64
        if not supported:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        self.conv1 = conv3x3(inplanes, planes, stride)
        # Attributes are named bn* but hold InstanceNorm3d layers.
        self.bn1 = nn.InstanceNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.InstanceNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        main += shortcut
        return self.relu(main)
class Bottleneck(nn.Module):
    """1x1x1 -> 3x3x3 -> 1x1x1 residual bottleneck with instance normalization.

    Args:
        inplanes: Input channels.
        planes: Base channel count; the block outputs ``planes * expansion``.
        stride: Stride of the middle 3x3x3 convolution.
        downsample: Optional module applied to the input to form the shortcut.
        groups: Groups of the middle convolution.
        base_width: Scales the inner width as ``planes * base_width / 64``.
        dilation: Dilation (and padding) of the middle convolution.
    """

    # Channel multiplier of the block output relative to ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super().__init__()
        inner = int(planes * (base_width / 64.)) * groups
        outer = planes * self.expansion

        # Attributes are named bn* but hold InstanceNorm3d layers.
        self.conv1 = nn.Conv3d(inplanes, inner, kernel_size=1, bias=False)
        self.bn1 = nn.InstanceNorm3d(inner)
        self.conv2 = nn.Conv3d(inner, inner, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, groups=groups, bias=False)
        self.bn2 = nn.InstanceNorm3d(inner)
        self.conv3 = nn.Conv3d(inner, outer, kernel_size=1, bias=False)
        self.bn3 = nn.InstanceNorm3d(outer)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        main += shortcut
        return self.relu(main)
class ResNet(nn.Module):
    """3D ResNet feature extractor that returns five feature maps.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, stride, downsample, groups,
            base_width, dilation)``.
        layers: Number of blocks in each of the four stages.
        in_channel: Channels of the input volume.
        width: Multiplier of the base channel count (16).
        groups: Forwarded to every block.
        width_per_group: Forwarded to every block as ``base_width``.
        mid_dim, low_dim, head_type: Accepted but not used to build any layer.
        avg_down: When True, shortcut projections downsample with average
            pooling followed by a stride-1 1x1x1 conv instead of a strided conv.
        deep_stem: When True, the stem is three 3x3x3 convs instead of one 7x7x7 conv.
        layer4_dilation: 1 for a stride-2 last stage, 2 for a stride-1 last
            stage with dilation 2.

    Raises:
        NotImplementedError: If ``layer4_dilation`` is neither 1 nor 2.
    """

    def __init__(self, block, layers, in_channel=1, width=1,
                 groups=1, width_per_group=64,
                 mid_dim=1024, low_dim=128,
                 avg_down=False, deep_stem=False,
                 head_type='mlp_head', layer4_dilation=1):
        super().__init__()
        self.avg_down = avg_down
        self.inplanes = 16 * width
        self.base = int(16 * width)
        self.groups = groups
        self.base_width = width_per_group
        # Computed as in the original code; the value is not used afterwards.
        mid_dim = self.base * 8 * block.expansion

        if not deep_stem:
            self.conv1 = nn.Conv3d(in_channel, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)
        else:
            # NOTE(review): this stem emits 64 channels while layer1 is built
            # for self.inplanes input channels - confirm before using deep_stem.
            self.conv1 = nn.Sequential(
                conv3x3_bn_relu(in_channel, 32, stride=2),
                conv3x3_bn_relu(32, 32, stride=1),
                conv3x3(32, 64, stride=1)
            )
        self.bn1 = nn.InstanceNorm3d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # Constructed but not applied in forward().
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        base = self.base
        self.layer1 = self._make_layer(block, base * 2, layers[0], stride=2)
        self.layer2 = self._make_layer(block, base * 4, layers[1], stride=2)
        self.layer3 = self._make_layer(block, base * 8, layers[2], stride=2)
        if layer4_dilation == 1:
            self.layer4 = self._make_layer(block, base * 16, layers[3], stride=2)
        elif layer4_dilation == 2:
            self.layer4 = self._make_layer(block, base * 16, layers[3], stride=1, dilation=2)
        else:
            raise NotImplementedError
        # Constructed but not applied in forward().
        self.avgpool = nn.AvgPool3d(7, stride=1)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``.

        Only the first block gets ``stride`` and the shortcut projection.
        """
        out_planes = planes * block.expansion

        shortcut = None
        needs_projection = stride != 1 or self.inplanes != out_planes
        if needs_projection and self.avg_down:
            shortcut = nn.Sequential(
                nn.AvgPool3d(kernel_size=stride, stride=stride),
                nn.Conv3d(self.inplanes, out_planes,
                          kernel_size=1, stride=1, bias=False),
                nn.InstanceNorm3d(out_planes),
            )
        elif needs_projection:
            shortcut = nn.Sequential(
                nn.Conv3d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.InstanceNorm3d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, shortcut, self.groups, self.base_width, dilation)]
        self.inplanes = out_planes
        for _ in range(blocks - 1):
            stage.append(block(self.inplanes, planes, groups=self.groups,
                               base_width=self.base_width, dilation=dilation))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return ``[stem, c2, c3, c4, c5]``: the stem output followed by the
        outputs of layer1..layer4."""
        stem = self.relu(self.bn1(self.conv1(x)))
        feats = [stem]
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats.append(stage(feats[-1]))
        return feats
def resnet34(**kwargs):
    """Build a ``ResNet`` of ``BasicBlock`` stages with depths 3, 4, 6 and 3.

    Keyword arguments are forwarded unchanged to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, **kwargs)
def label_rescale(image_label, w_ori, h_ori, z_ori, flag):
    """Resize a 3D label map to ``(w_ori, h_ori, z_ori)``.

    Args:
        image_label: 3D numpy label array.
        w_ori, h_ori, z_ori: Target size along each axis (cast to int).
        flag: 'trilinear' resizes the binary mask of each distinct label value
            with trilinear interpolation and writes the value wherever the
            resized mask exceeds 0.5; 'nearest' resizes the label array
            directly with nearest-neighbour interpolation.

    Returns:
        The resized label array. For any other ``flag`` the input is returned
        unchanged.
    """
    target_size = (int(w_ori), int(h_ori), int(z_ori))
    # Use the first GPU when one is present, otherwise fall back to the CPU so
    # the function also works on hosts without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if flag == 'trilinear':
        label_values = np.unique(image_label)
        rescaled = np.zeros(target_size)
        label_tensor = torch.from_numpy(image_label).to(device)
        for value in label_values:
            # Add batch and channel axes, as F.interpolate expects 5D input
            # for volumetric data.
            mask = (label_tensor == value.item()).float()[None, None]
            mask = F.interpolate(mask, size=target_size,
                                 mode='trilinear', align_corners=False)
            mask = mask[0, 0].cpu().numpy()
            # Values are visited in ascending order, so where resized masks
            # overlap the larger label value wins.
            rescaled[mask > 0.5] = value
        return rescaled

    if flag == 'nearest':
        label_tensor = torch.from_numpy(image_label).to(device)
        label_tensor = label_tensor[None, None].float()
        label_tensor = F.interpolate(label_tensor, size=target_size, mode='nearest')
        return label_tensor[0, 0].cpu().numpy()

    return image_label
def img_crop(image_bbox):
    """Compute a padded crop box around the non-zero voxels of a 3D mask.

    The box is the non-zero extent enlarged by a margin of 8 voxels on axis 0
    and 16 voxels on axes 1 and 2, clipped to the array bounds; each upper
    bound is then reduced so that the extent is a multiple of 16.

    Args:
        image_bbox: 3D numpy mask.

    Returns:
        ``(x_min, x_max, y_min, y_max, z_min, z_max)``. When the mask sums to
        zero, ``x_min`` is the sentinel -1 and the other values span the whole
        array.

    Raises:
        ValueError: If the mask sum is negative or NaN (previously this ended
            in an UnboundLocalError).
    """
    total = image_bbox.sum()
    if total == 0:
        return (-1, image_bbox.shape[0], 0, image_bbox.shape[1],
                0, image_bbox.shape[2])
    if not total > 0:
        raise ValueError("image_bbox must have a non-negative sum, got {}".format(total))

    # One scan of the volume instead of one per bound.
    nonzero = np.nonzero(image_bbox)
    margins = (8, 16, 16)

    bounds = []
    for axis in range(3):
        coords = nonzero[axis]
        low = coords.min() - margins[axis]
        high = coords.max() + margins[axis]
        if low < 0:
            low = 0
        if high > image_bbox.shape[axis]:
            high = image_bbox.shape[axis]
        # Shrink the upper bound so the extent is a multiple of 16.
        high -= (high - low) % 16
        bounds.extend((low, high))
    return tuple(bounds)
def roi_extraction(image, net_roi, ids):
    """Locate the foreground region with the ROI network and crop the image to it.

    The image is subsampled by 2 along every axis, segmented by ``net_roi``
    with a sliding window, and the resulting mask is rescaled to the full size
    and cleaned up (small-object removal, grey dilation, small-object removal,
    grey erosion) before the crop box is computed.

    Args:
        image: 3D numpy volume.
        net_roi: Network passed to ``roi_detection``.
        ids: Not read by this function.

    Returns:
        ``(cropped_image, x_min, x_max, y_min, y_max, z_min, z_max)`` when a
        foreground region is found. NOTE: when none is found, a single
        all-zero array of the input shape is returned instead of a tuple;
        callers must handle both shapes of result.
    """
    w, h, d = image.shape

    # Sliding-window settings of the ROI stage (original note: the input
    # spacing is 0.4 mm).
    print('---run the roi binary segmentation.')
    roi_stride_xy, roi_stride_z = 32, 16
    roi_patch_size = (112, 112, 80)
    subsampled = image[0:w:2, 0:h:2, 0:d:2]
    label_roi = roi_detection(net_roi, subsampled, roi_stride_xy, roi_stride_z, roi_patch_size)
    print(label_roi.shape, np.max(label_roi))

    # Back to full resolution, then morphological clean-up.
    label_roi = label_rescale(label_roi, w, h, d, 'trilinear')
    structure_size = (5, 5, 5)
    label_roi = morphology.remove_small_objects(
        label_roi.astype(bool), 5000, connectivity=3).astype(float)
    label_roi = ndimage.grey_dilation(label_roi, size=structure_size)
    label_roi = morphology.remove_small_objects(
        label_roi.astype(bool), 400000, connectivity=3).astype(float)
    label_roi = ndimage.grey_erosion(label_roi, size=structure_size)

    crop_box = img_crop(label_roi)
    x_min, x_max, y_min, y_max, z_min, z_max = crop_box
    if x_min == -1:
        # img_crop signals "no foreground" with x_min == -1.
        return np.zeros((w, h, d))

    image = image[x_min:x_max, y_min:y_max, z_min:z_max]
    print("image shape(after roi): ", image.shape)
    return image, x_min, x_max, y_min, y_max, z_min, z_max
def roi_detection(net, image, stride_xy, stride_z, patch_size):
    """Sliding-window two-class segmentation of a 3D volume.

    The volume is zero-padded (symmetrically) up to ``patch_size`` where it is
    smaller, covered with overlapping patches, and the softmax outputs of
    ``net`` are averaged per voxel before taking the arg-max.

    Args:
        net: Callable mapping a ``(1, 1, X, Y, Z)`` float tensor to
            ``(1, 2, X, Y, Z)`` logits.
        image: 3D numpy volume.
        stride_xy: Window stride along the first two axes.
        stride_z: Window stride along the third axis.
        patch_size: ``(X, Y, Z)`` window size.

    Returns:
        Integer label map with the same shape as the (unpadded) input.
    """
    w, h, d = image.shape

    # Pad only the axes that are smaller than the patch.
    w_pad = max(patch_size[0] - w, 0)
    h_pad = max(patch_size[1] - h, 0)
    d_pad = max(patch_size[2] - d, 0)
    add_pad = bool(w_pad or h_pad or d_pad)
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)],
                       mode='constant', constant_values=0)
    ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1

    score_map = np.zeros((2,) + image.shape, dtype=np.float32)
    cnt = np.zeros(image.shape, dtype=np.float32)
    # First GPU when available, CPU otherwise.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for ix in range(sx):
        # The last window is shifted back so it ends at the volume border.
        xs = min(stride_xy * ix, ww - patch_size[0])
        for iy in range(sy):
            ys = min(stride_xy * iy, hh - patch_size[1])
            for iz in range(sz):
                zs = min(stride_z * iz, dd - patch_size[2])
                window = (slice(xs, xs + patch_size[0]),
                          slice(ys, ys + patch_size[1]),
                          slice(zs, zs + patch_size[2]))
                patch = image[window][np.newaxis, np.newaxis].astype(np.float32)
                patch = torch.from_numpy(patch).to(device)
                with torch.no_grad():
                    probs = F.softmax(net(patch), dim=1)
                score_map[(slice(None),) + window] += probs[0].cpu().numpy()
                cnt[window] += 1

    # Average the overlapping predictions, then pick the most likely class.
    score_map /= cnt[np.newaxis]
    label_map = np.argmax(score_map, axis=0)
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map
def test_single_case_array(model_array, image=None, stride_xy=None, stride_z=None, patch_size=None, num_classes=1):
    """Sliding-window ensemble inference on a 3D volume.

    The volume is zero-padded (symmetrically) up to ``patch_size`` where it is
    smaller and covered with overlapping patches. For every patch the softmax
    outputs of all models are averaged; overlapping patch results are then
    averaged per voxel.

    Args:
        model_array: Sequence of callables mapping a ``(1, 1, X, Y, Z)`` float
            tensor to ``(1, num_classes, X, Y, Z)`` logits.
        image: 3D numpy volume.
        stride_xy: Window stride along the first two axes.
        stride_z: Window stride along the third axis.
        patch_size: ``(X, Y, Z)`` window size.
        num_classes: Number of channels of the score map.

    Returns:
        ``(label_map, score_map)``: the per-voxel arg-max and the averaged
        class probabilities, both cropped back to the input shape.
    """
    w, h, d = image.shape

    # Pad only the axes that are smaller than the patch.
    w_pad = max(patch_size[0] - w, 0)
    h_pad = max(patch_size[1] - h, 0)
    d_pad = max(patch_size[2] - d, 0)
    add_pad = bool(w_pad or h_pad or d_pad)
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)],
                       mode='constant', constant_values=0)
    ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1

    score_map = np.zeros((num_classes,) + image.shape, dtype=np.float32)
    cnt = np.zeros(image.shape, dtype=np.float32)
    # GPU when available, CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_models = len(model_array)

    for ix in range(sx):
        # The last window is shifted back so it ends at the volume border.
        xs = min(stride_xy * ix, ww - patch_size[0])
        for iy in range(sy):
            ys = min(stride_xy * iy, hh - patch_size[1])
            for iz in range(sz):
                zs = min(stride_z * iz, dd - patch_size[2])
                window = (slice(xs, xs + patch_size[0]),
                          slice(ys, ys + patch_size[1]),
                          slice(zs, zs + patch_size[2]))
                patch = image[window][np.newaxis, np.newaxis].astype(np.float32)
                patch = torch.from_numpy(patch).to(device)

                # Start from zero for every patch. The accumulator used to be
                # the `y` loop variable, so it began at the loop index and then
                # carried the previous patch's scores into the next one.
                patch_scores = 0.0
                with torch.no_grad():
                    for model in model_array:
                        probs = F.softmax(model(patch), dim=1)
                        patch_scores = patch_scores + probs[0].cpu().numpy()
                patch_scores = patch_scores / n_models

                score_map[(slice(None),) + window] += patch_scores
                cnt[window] += 1

    # Average the overlapping predictions, then pick the most likely class.
    score_map = score_map / cnt[np.newaxis]
    label_map = np.argmax(score_map, axis=0)
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map, score_map


# The name starts with "test_", but this is an inference helper: tell pytest
# not to collect it as a test case.
test_single_case_array.__test__ = False
def calculate_metric_percase(pred, gt):
    """Compute four medpy binary segmentation metrics for one case.

    Args:
        pred: Predicted binary mask.
        gt: Reference binary mask.

    Returns:
        ``(dice, jc, hd, asd)`` as returned by ``metric.binary.dc``, ``jc``,
        ``hd95`` and ``asd`` respectively.
    """
    binary_metrics = metric.binary
    dice_score = binary_metrics.dc(pred, gt)
    jaccard = binary_metrics.jc(pred, gt)
    hausdorff95 = binary_metrics.hd95(pred, gt)
    avg_surface_dist = binary_metrics.asd(pred, gt)
    return dice_score, jaccard, hausdorff95, avg_surface_dist
class RailNetSystem(nn.Module, PyTorchModelHubMixin):
    """Two-stage segmentation pipeline: ROI localisation followed by an ensemble.

    ``net_roi`` (a ``VNet_roi``) finds the region of interest; four models
    (two ``VNet`` and two ``ResVNet``) then segment the cropped region with
    sliding-window inference and their softmax outputs are averaged.

    Args:
        n_channels: Input channels passed to every network.
        n_classes: Output channels passed to every network.
        normalization: Normalization name passed to every network.
    """

    def __init__(self, n_channels: int, n_classes: int, normalization: str):
        super().__init__()
        self.num_classes = 2
        # GPU when available, CPU otherwise, so construction does not fail on
        # hosts without CUDA.
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net_roi = VNet_roi(n_channels=n_channels, n_classes=n_classes,
                                normalization=normalization, has_dropout=False).to(self._device)
        # Kept as a plain Python list (not nn.ModuleList): the ensemble members
        # are therefore not registered sub-modules and do not appear in this
        # module's state_dict; their weights are loaded by load_weights().
        self.model_array = []
        for i in range(4):
            architecture = VNet if i < 2 else ResVNet
            model = architecture(n_channels=n_channels, n_classes=n_classes,
                                 normalization=normalization, has_dropout=True).to(self._device)
            self.model_array.append(model)

    def load_weights(self, weight_dir="."):
        """Load all checkpoints from ``<weight_dir>/model weights`` and switch
        every network to eval mode."""
        checkpoint_dir = os.path.join(weight_dir, "model weights")

        self.net_roi.load_state_dict(torch.load(os.path.join(checkpoint_dir, "roi_best_model.pth"),
                                                map_location=self._device, weights_only=True))
        self.net_roi.eval()

        model_files = [
            "rail_0_iter_7995_best.pth",
            "rail_1_iter_7995_best.pth",
            "rail_2_iter_7995_best.pth",
            "rail_3_iter_7995_best.pth",
        ]
        for model, file in zip(self.model_array, model_files):
            model.load_state_dict(torch.load(os.path.join(checkpoint_dir, file),
                                             map_location=self._device, weights_only=True))
            model.eval()

    def forward(self, image, label, save_path="./output", name="case"):
        """Segment one volume, save input and prediction as NIfTI, and score it.

        Args:
            image: 3D numpy volume.
            label: Reference segmentation; read with ``label[:]``.
            save_path: Output directory, created if missing.
            name: File-name prefix for the saved volumes.

        Returns:
            ``(new_prediction, dice, jc, hd, asd)``. When the prediction has no
            foreground voxel (including the case where the ROI stage finds
            nothing), ``dice`` and ``jc`` are 0.0 and ``hd`` / ``asd`` are NaN,
            since surface distances are undefined for an empty mask.
        """
        os.makedirs(save_path, exist_ok=True)
        nib.save(nib.Nifti1Image(image.astype(np.float32), np.eye(4)),
                 os.path.join(save_path, f"{name}_img.nii.gz"))

        w, h, d = image.shape
        new_prediction = np.zeros((w, h, d))

        roi = roi_extraction(image, self.net_roi, name)
        # roi_extraction returns a bare all-zero array (not a 7-tuple) when no
        # foreground is found; unpacking it directly used to crash here.
        if not isinstance(roi, np.ndarray):
            image, x_min, x_max, y_min, y_max, z_min, z_max = roi
            prediction, _ = test_single_case_array(self.model_array, image, stride_xy=64, stride_z=32,
                                                   patch_size=(112, 112, 80), num_classes=self.num_classes)
            prediction = morphology.remove_small_objects(prediction.astype(bool), 3000,
                                                         connectivity=3).astype(float)
            # Paste the cropped prediction back at its position in the full volume.
            new_prediction[x_min:x_max, y_min:y_max, z_min:z_max] = prediction

        if new_prediction.any():
            dice, jc, hd, asd = calculate_metric_percase(new_prediction, label[:])
        else:
            dice, jc, hd, asd = 0.0, 0.0, float("nan"), float("nan")

        nib.save(nib.Nifti1Image(new_prediction.astype(np.float32), np.eye(4)),
                 os.path.join(save_path, f"{name}_pred.nii.gz"))
        return new_prediction, dice, jc, hd, asd