Tournesol-Saturday commited on
Commit
cbe57e4
·
verified ·
1 Parent(s): b78f01b

Upload railnet_model.py

Browse files
Files changed (1) hide show
  1. railnet_model.py +975 -0
railnet_model.py ADDED
@@ -0,0 +1,975 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['KMP_DUPLICATE_LIB_OK']='True'
3
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
+ import numpy as np
11
+ import nibabel as nib
12
+ from skimage import morphology
13
+
14
+ import math
15
+ from scipy import ndimage
16
+ from medpy import metric
17
+
18
+ from huggingface_hub import hf_hub_download
19
+
20
+
21
+ class ConvBlock(nn.Module):
22
+ def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
23
+ super(ConvBlock, self).__init__()
24
+
25
+ ops = []
26
+ for i in range(n_stages):
27
+ if i == 0:
28
+ input_channel = n_filters_in
29
+ else:
30
+ input_channel = n_filters_out
31
+
32
+ ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
33
+ if normalization == 'batchnorm':
34
+ ops.append(nn.BatchNorm3d(n_filters_out))
35
+ elif normalization == 'groupnorm':
36
+ ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
37
+ elif normalization == 'instancenorm':
38
+ ops.append(nn.InstanceNorm3d(n_filters_out))
39
+ elif normalization != 'none':
40
+ assert False
41
+ ops.append(nn.ReLU(inplace=True))
42
+
43
+ self.conv = nn.Sequential(*ops)
44
+
45
+ def forward(self, x):
46
+ x = self.conv(x)
47
+ return x
48
+
49
+
50
+ class DownsamplingConvBlock(nn.Module):
51
+ def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
52
+ super(DownsamplingConvBlock, self).__init__()
53
+
54
+ ops = []
55
+ if normalization != 'none':
56
+ ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
57
+ if normalization == 'batchnorm':
58
+ ops.append(nn.BatchNorm3d(n_filters_out))
59
+ elif normalization == 'groupnorm':
60
+ ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
61
+ elif normalization == 'instancenorm':
62
+ ops.append(nn.InstanceNorm3d(n_filters_out))
63
+ else:
64
+ assert False
65
+ else:
66
+ ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
67
+
68
+ ops.append(nn.ReLU(inplace=True))
69
+
70
+ self.conv = nn.Sequential(*ops)
71
+
72
+ def forward(self, x):
73
+ x = self.conv(x)
74
+ return x
75
+
76
+
77
+ class UpsamplingDeconvBlock(nn.Module):
78
+ def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
79
+ super(UpsamplingDeconvBlock, self).__init__()
80
+
81
+ ops = []
82
+ if normalization != 'none':
83
+ ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
84
+ if normalization == 'batchnorm':
85
+ ops.append(nn.BatchNorm3d(n_filters_out))
86
+ elif normalization == 'groupnorm':
87
+ ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
88
+ elif normalization == 'instancenorm':
89
+ ops.append(nn.InstanceNorm3d(n_filters_out))
90
+ else:
91
+ assert False
92
+ else:
93
+ ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
94
+
95
+ ops.append(nn.ReLU(inplace=True))
96
+
97
+ self.conv = nn.Sequential(*ops)
98
+
99
+ def forward(self, x):
100
+ x = self.conv(x)
101
+ return x
102
+
103
+
104
+ class Upsampling(nn.Module):
105
+ def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
106
+ super(Upsampling, self).__init__()
107
+
108
+ ops = []
109
+ ops.append(nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False))
110
+ ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
111
+ if normalization == 'batchnorm':
112
+ ops.append(nn.BatchNorm3d(n_filters_out))
113
+ elif normalization == 'groupnorm':
114
+ ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
115
+ elif normalization == 'instancenorm':
116
+ ops.append(nn.InstanceNorm3d(n_filters_out))
117
+ elif normalization != 'none':
118
+ assert False
119
+ ops.append(nn.ReLU(inplace=True))
120
+
121
+ self.conv = nn.Sequential(*ops)
122
+
123
+ def forward(self, x):
124
+ x = self.conv(x)
125
+ return x
126
+
127
+
128
+ class ConnectNet(nn.Module):
129
+ def __init__(self, in_channels, out_channels, input_size):
130
+ super(ConnectNet, self).__init__()
131
+ self.encoder = nn.Sequential(
132
+ nn.Conv3d(in_channels, 128, kernel_size=3, stride=1, padding=1),
133
+ nn.ReLU(),
134
+ nn.MaxPool3d(kernel_size=2, stride=2),
135
+ nn.Conv3d(128, 64, kernel_size=3, stride=1, padding=1),
136
+ nn.ReLU(),
137
+ nn.MaxPool3d(kernel_size=2, stride=2)
138
+ )
139
+
140
+ self.decoder = nn.Sequential(
141
+ nn.ConvTranspose3d(64, 128, kernel_size=2, stride=2),
142
+ nn.ReLU(),
143
+ nn.ConvTranspose3d(128, out_channels, kernel_size=2, stride=2),
144
+ nn.Sigmoid()
145
+ )
146
+
147
+ def forward(self, x):
148
+ encoded = self.encoder(x)
149
+ decoded = self.decoder(encoded)
150
+ return decoded
151
+
152
+
153
+ class VNet(nn.Module):
154
+ def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
155
+ super(VNet, self).__init__()
156
+ self.has_dropout = has_dropout
157
+
158
+ self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
159
+ self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
160
+
161
+ self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
162
+ self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
163
+
164
+ self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
165
+ self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
166
+
167
+ self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
168
+ self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
169
+
170
+ self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
171
+ self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
172
+
173
+ self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
174
+ self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
175
+
176
+ self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
177
+ self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
178
+
179
+ self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
180
+ self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
181
+
182
+ self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
183
+ self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
184
+
185
+ self.dropout = nn.Dropout3d(p=0.5, inplace=False)
186
+
187
+ self.__init_weight()
188
+
189
+ def encoder(self, input):
190
+ x1 = self.block_one(input)
191
+ x1_dw = self.block_one_dw(x1)
192
+
193
+ x2 = self.block_two(x1_dw)
194
+ x2_dw = self.block_two_dw(x2)
195
+
196
+ x3 = self.block_three(x2_dw)
197
+ x3_dw = self.block_three_dw(x3)
198
+
199
+ x4 = self.block_four(x3_dw)
200
+ x4_dw = self.block_four_dw(x4)
201
+
202
+ x5 = self.block_five(x4_dw)
203
+ if self.has_dropout:
204
+ x5 = self.dropout(x5)
205
+
206
+ res = [x1, x2, x3, x4, x5]
207
+
208
+ return res
209
+
210
+ def decoder(self, features):
211
+ x1 = features[0]
212
+ x2 = features[1]
213
+ x3 = features[2]
214
+ x4 = features[3]
215
+ x5 = features[4]
216
+
217
+ x5_up = self.block_five_up(x5)
218
+ x5_up = x5_up + x4
219
+
220
+ x6 = self.block_six(x5_up)
221
+ x6_up = self.block_six_up(x6)
222
+ x6_up = x6_up + x3
223
+
224
+ x7 = self.block_seven(x6_up)
225
+ x7_up = self.block_seven_up(x7)
226
+ x7_up = x7_up + x2
227
+
228
+ x8 = self.block_eight(x7_up)
229
+ x8_up = self.block_eight_up(x8)
230
+ x8_up = x8_up + x1
231
+ x9 = self.block_nine(x8_up)
232
+ if self.has_dropout:
233
+ x9 = self.dropout(x9)
234
+ out = self.out_conv(x9)
235
+ return out
236
+
237
+ def forward(self, input, turnoff_drop=False):
238
+ if turnoff_drop:
239
+ has_dropout = self.has_dropout
240
+ self.has_dropout = False
241
+ features = self.encoder(input)
242
+ out = self.decoder(features)
243
+ if turnoff_drop:
244
+ self.has_dropout = has_dropout
245
+ return out
246
+
247
+ def __init_weight(self):
248
+ for m in self.modules():
249
+ if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
250
+ torch.nn.init.kaiming_normal_(m.weight)
251
+ elif isinstance(m, nn.BatchNorm3d):
252
+ m.weight.data.fill_(1)
253
+ m.bias.data.zero_()
254
+
255
+
256
+ class VNet_roi(nn.Module):
257
+ def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
258
+ super(VNet_roi, self).__init__()
259
+ self.has_dropout = has_dropout
260
+
261
+ self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
262
+ self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
263
+
264
+ self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
265
+ self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
266
+
267
+ self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
268
+ self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
269
+
270
+ self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
271
+ self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
272
+
273
+ self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
274
+ self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
275
+
276
+ self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
277
+ self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
278
+
279
+ self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
280
+ self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
281
+
282
+ self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
283
+ self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
284
+
285
+ self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
286
+ self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
287
+
288
+ self.dropout = nn.Dropout3d(p=0.5, inplace=False)
289
+ # self.__init_weight()
290
+
291
+ def encoder(self, input):
292
+ x1 = self.block_one(input)
293
+ x1_dw = self.block_one_dw(x1)
294
+
295
+ x2 = self.block_two(x1_dw)
296
+ x2_dw = self.block_two_dw(x2)
297
+
298
+ x3 = self.block_three(x2_dw)
299
+ x3_dw = self.block_three_dw(x3)
300
+
301
+ x4 = self.block_four(x3_dw)
302
+ x4_dw = self.block_four_dw(x4)
303
+
304
+ x5 = self.block_five(x4_dw)
305
+ # x5 = F.dropout3d(x5, p=0.5, training=True)
306
+ if self.has_dropout:
307
+ x5 = self.dropout(x5)
308
+
309
+ res = [x1, x2, x3, x4, x5]
310
+
311
+ return res
312
+
313
+ def decoder(self, features):
314
+ x1 = features[0]
315
+ x2 = features[1]
316
+ x3 = features[2]
317
+ x4 = features[3]
318
+ x5 = features[4]
319
+
320
+ x5_up = self.block_five_up(x5)
321
+ x5_up = x5_up + x4
322
+
323
+ x6 = self.block_six(x5_up)
324
+ x6_up = self.block_six_up(x6)
325
+ x6_up = x6_up + x3
326
+
327
+ x7 = self.block_seven(x6_up)
328
+ x7_up = self.block_seven_up(x7)
329
+ x7_up = x7_up + x2
330
+
331
+ x8 = self.block_eight(x7_up)
332
+ x8_up = self.block_eight_up(x8)
333
+ x8_up = x8_up + x1
334
+ x9 = self.block_nine(x8_up)
335
+ # x9 = F.dropout3d(x9, p=0.5, training=True)
336
+ if self.has_dropout:
337
+ x9 = self.dropout(x9)
338
+ out = self.out_conv(x9)
339
+ return out
340
+
341
+
342
+ def forward(self, input, turnoff_drop=False):
343
+ if turnoff_drop:
344
+ has_dropout = self.has_dropout
345
+ self.has_dropout = False
346
+ features = self.encoder(input)
347
+ out = self.decoder(features)
348
+ if turnoff_drop:
349
+ self.has_dropout = has_dropout
350
+ return out
351
+
352
+
353
+ class ResVNet(nn.Module):
354
+ def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False):
355
+ super(ResVNet, self).__init__()
356
+ self.resencoder = resnet34()
357
+ self.has_dropout = has_dropout
358
+
359
+ self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
360
+ self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
361
+
362
+ self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
363
+ self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
364
+
365
+ self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
366
+ self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
367
+
368
+ self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
369
+ self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
370
+
371
+ self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
372
+ self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
373
+
374
+ self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
375
+ self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
376
+
377
+ self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
378
+ self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
379
+
380
+ self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
381
+ self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
382
+
383
+
384
+ self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
385
+ self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
386
+
387
+
388
+ if has_dropout:
389
+ self.dropout = nn.Dropout3d(p=0.5)
390
+ self.branchs = nn.ModuleList()
391
+ for i in range(1):
392
+ if has_dropout:
393
+ seq = nn.Sequential(
394
+ ConvBlock(1, n_filters, n_filters, normalization=normalization),
395
+ nn.Dropout3d(p=0.5),
396
+ nn.Conv3d(n_filters, n_classes, 1, padding=0)
397
+ )
398
+ else:
399
+ seq = nn.Sequential(
400
+ ConvBlock(1, n_filters, n_filters, normalization=normalization),
401
+ nn.Conv3d(n_filters, n_classes, 1, padding=0)
402
+ )
403
+ self.branchs.append(seq)
404
+
405
+ def encoder(self, input):
406
+ x1 = self.block_one(input)
407
+ x1_dw = self.block_one_dw(x1)
408
+
409
+ x2 = self.block_two(x1_dw)
410
+ x2_dw = self.block_two_dw(x2)
411
+
412
+ x3 = self.block_three(x2_dw)
413
+ x3_dw = self.block_three_dw(x3)
414
+
415
+ x4 = self.block_four(x3_dw)
416
+ x4_dw = self.block_four_dw(x4)
417
+
418
+ x5 = self.block_five(x4_dw)
419
+
420
+ if self.has_dropout:
421
+ x5 = self.dropout(x5)
422
+
423
+ res = [x1, x2, x3, x4, x5]
424
+
425
+ return res
426
+
427
+ def decoder(self, features):
428
+ x1 = features[0]
429
+ x2 = features[1]
430
+ x3 = features[2]
431
+ x4 = features[3]
432
+ x5 = features[4]
433
+
434
+ x5_up = self.block_five_up(x5)
435
+ x5_up = x5_up + x4
436
+
437
+ x6 = self.block_six(x5_up)
438
+ x6_up = self.block_six_up(x6)
439
+ x6_up = x6_up + x3
440
+
441
+ x7 = self.block_seven(x6_up)
442
+ x7_up = self.block_seven_up(x7)
443
+ x7_up = x7_up + x2
444
+
445
+ x8 = self.block_eight(x7_up)
446
+ x8_up = self.block_eight_up(x8)
447
+ x8_up = x8_up + x1
448
+
449
+
450
+ x9 = self.block_nine(x8_up)
451
+
452
+ out = self.out_conv(x9)
453
+
454
+
455
+ return out
456
+
457
+ def forward(self, input, turnoff_drop=False):
458
+ if turnoff_drop:
459
+ has_dropout = self.has_dropout
460
+ self.has_dropout = False
461
+ features = self.resencoder(input)
462
+ out = self.decoder(features)
463
+ if turnoff_drop:
464
+ self.has_dropout = has_dropout
465
+ return out
466
+
467
+
468
+ __all__ = ['ResNet', 'resnet34']
469
+
470
+
471
+ def conv3x3(in_planes, out_planes, stride=1):
472
+ """3x3 convolution with padding"""
473
+ return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
474
+
475
+
476
+ def conv3x3_bn_relu(in_planes, out_planes, stride=1):
477
+ return nn.Sequential(
478
+ conv3x3(in_planes, out_planes, stride),
479
+ nn.InstanceNorm3d(out_planes),
480
+ nn.ReLU()
481
+ )
482
+
483
+
484
+ class BasicBlock(nn.Module):
485
+ expansion = 1
486
+
487
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
488
+ groups=1, base_width=64, dilation=-1):
489
+ super(BasicBlock, self).__init__()
490
+ if groups != 1 or base_width != 64:
491
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
492
+ self.conv1 = conv3x3(inplanes, planes, stride)
493
+ self.bn1 = nn.InstanceNorm3d(planes)
494
+ self.relu = nn.ReLU(inplace=True)
495
+ self.conv2 = conv3x3(planes, planes)
496
+ self.bn2 = nn.InstanceNorm3d(planes)
497
+ self.downsample = downsample
498
+ self.stride = stride
499
+
500
+ def forward(self, x):
501
+ residual = x
502
+
503
+ out = self.conv1(x)
504
+ out = self.bn1(out)
505
+ out = self.relu(out)
506
+
507
+ out = self.conv2(out)
508
+ out = self.bn2(out)
509
+
510
+ if self.downsample is not None:
511
+ residual = self.downsample(x)
512
+
513
+ out += residual
514
+ out = self.relu(out)
515
+
516
+ return out
517
+
518
+
519
+ class Bottleneck(nn.Module):
520
+ expansion = 4
521
+
522
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
523
+ groups=1, base_width=64, dilation=1):
524
+ super(Bottleneck, self).__init__()
525
+ width = int(planes * (base_width / 64.)) * groups
526
+ self.conv1 = nn.Conv3d(inplanes, width, kernel_size=1, bias=False)
527
+ self.bn1 = nn.InstanceNorm3d(width)
528
+ self.conv2 = nn.Conv3d(width, width, kernel_size=3, stride=stride, dilation=dilation,
529
+ padding=dilation, groups=groups, bias=False)
530
+ self.bn2 = nn.InstanceNorm3d(width)
531
+ self.conv3 = nn.Conv3d(width, planes * self.expansion, kernel_size=1, bias=False)
532
+ self.bn3 = nn.InstanceNorm3d(planes * self.expansion)
533
+ self.relu = nn.ReLU(inplace=True)
534
+ self.downsample = downsample
535
+ self.stride = stride
536
+
537
+ def forward(self, x):
538
+ residual = x
539
+
540
+ out = self.conv1(x)
541
+ out = self.bn1(out)
542
+ out = self.relu(out)
543
+
544
+ out = self.conv2(out)
545
+ out = self.bn2(out)
546
+ out = self.relu(out)
547
+
548
+ out = self.conv3(out)
549
+ out = self.bn3(out)
550
+
551
+ if self.downsample is not None:
552
+ residual = self.downsample(x)
553
+
554
+ out += residual
555
+ out = self.relu(out)
556
+
557
+ return out
558
+
559
+
560
+ class ResNet(nn.Module):
561
+
562
+ def __init__(self, block, layers, in_channel=1, width=1,
563
+ groups=1, width_per_group=64,
564
+ mid_dim=1024, low_dim=128,
565
+ avg_down=False, deep_stem=False,
566
+ head_type='mlp_head', layer4_dilation=1):
567
+ super(ResNet, self).__init__()
568
+ self.avg_down = avg_down
569
+ self.inplanes = 16 * width
570
+ self.base = int(16 * width)
571
+ self.groups = groups
572
+ self.base_width = width_per_group
573
+
574
+ mid_dim = self.base * 8 * block.expansion
575
+
576
+ if deep_stem:
577
+ self.conv1 = nn.Sequential(
578
+ conv3x3_bn_relu(in_channel, 32, stride=2),
579
+ conv3x3_bn_relu(32, 32, stride=1),
580
+ conv3x3(32, 64, stride=1)
581
+ )
582
+ else:
583
+ self.conv1 = nn.Conv3d(in_channel, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)
584
+
585
+ self.bn1 = nn.InstanceNorm3d(self.inplanes)
586
+ self.relu = nn.ReLU(inplace=True)
587
+
588
+ self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
589
+ self.layer1 = self._make_layer(block, self.base*2, layers[0],stride=2)
590
+ self.layer2 = self._make_layer(block, self.base * 4, layers[1], stride=2)
591
+ self.layer3 = self._make_layer(block, self.base * 8, layers[2], stride=2)
592
+ if layer4_dilation == 1:
593
+ self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=2)
594
+ elif layer4_dilation == 2:
595
+ self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=1, dilation=2)
596
+ else:
597
+ raise NotImplementedError
598
+ self.avgpool = nn.AvgPool3d(7, stride=1)
599
+
600
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
601
+ downsample = None
602
+ if stride != 1 or self.inplanes != planes * block.expansion:
603
+ if self.avg_down:
604
+ downsample = nn.Sequential(
605
+ nn.AvgPool3d(kernel_size=stride, stride=stride),
606
+ nn.Conv3d(self.inplanes, planes * block.expansion,
607
+ kernel_size=1, stride=1, bias=False),
608
+ nn.InstanceNorm3d(planes * block.expansion),
609
+ )
610
+ else:
611
+ downsample = nn.Sequential(
612
+ nn.Conv3d(self.inplanes, planes * block.expansion,
613
+ kernel_size=1, stride=stride, bias=False),
614
+ nn.InstanceNorm3d(planes * block.expansion),
615
+ )
616
+
617
+ layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, dilation)]
618
+ self.inplanes = planes * block.expansion
619
+ for _ in range(1, blocks):
620
+ layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=dilation))
621
+
622
+ return nn.Sequential(*layers)
623
+
624
+ def forward(self, x):
625
+ x = self.conv1(x)
626
+ x = self.bn1(x)
627
+ x = self.relu(x)
628
+ #c2 = self.maxpool(x)
629
+ c2 = self.layer1(x)
630
+ c3 = self.layer2(c2)
631
+ c4 = self.layer3(c3)
632
+ c5 = self.layer4(c4)
633
+
634
+
635
+ return [x,c2,c3,c4,c5]
636
+
637
+
638
+ def resnet34(**kwargs):
639
+ return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
640
+
641
+
642
+ def label_rescale(image_label, w_ori, h_ori, z_ori, flag):
643
+ w_ori, h_ori, z_ori = int(w_ori), int(h_ori), int(z_ori)
644
+ # resize label map (int)
645
+ if flag == 'trilinear':
646
+ teeth_ids = np.unique(image_label)
647
+ image_label_ori = np.zeros((w_ori, h_ori, z_ori))
648
+
649
+
650
+ image_label = torch.from_numpy(image_label).cuda(0)
651
+
652
+
653
+ # image_label = torch.from_numpy(image_label).to("cpu")
654
+ for label_id in range(len(teeth_ids)):
655
+ image_label_bn = (image_label == teeth_ids[label_id]).float()
656
+ image_label_bn = image_label_bn[None, None, :, :, :]
657
+ image_label_bn = torch.nn.functional.interpolate(image_label_bn, size=(w_ori, h_ori, z_ori),
658
+ mode='trilinear', align_corners=False)
659
+ image_label_bn = image_label_bn[0, 0, :, :, :]
660
+ image_label_bn = image_label_bn.cpu().data.numpy()
661
+ image_label_ori[image_label_bn > 0.5] = teeth_ids[label_id]
662
+ image_label = image_label_ori
663
+
664
+ if flag == 'nearest':
665
+
666
+
667
+ image_label = torch.from_numpy(image_label).cuda(0)
668
+
669
+
670
+ # image_label = torch.from_numpy(image_label).to("cpu")
671
+ image_label = image_label[None, None, :, :, :].float()
672
+ image_label = torch.nn.functional.interpolate(image_label, size=(w_ori, h_ori, z_ori), mode='nearest')
673
+ image_label = image_label[0, 0, :, :, :].cpu().data.numpy()
674
+ return image_label
675
+
676
+
677
+ def img_crop(image_bbox):
678
+ if image_bbox.sum() > 0:
679
+
680
+ x_min = np.nonzero(image_bbox)[0].min() - 8
681
+ x_max = np.nonzero(image_bbox)[0].max() + 8
682
+
683
+ y_min = np.nonzero(image_bbox)[1].min() - 16
684
+ y_max = np.nonzero(image_bbox)[1].max() + 16
685
+
686
+ z_min = np.nonzero(image_bbox)[2].min() - 16
687
+ z_max = np.nonzero(image_bbox)[2].max() + 16
688
+
689
+ if x_min < 0:
690
+ x_min = 0
691
+ if y_min < 0:
692
+ y_min = 0
693
+ if z_min < 0:
694
+ z_min = 0
695
+ if x_max > image_bbox.shape[0]:
696
+ x_max = image_bbox.shape[0]
697
+ if y_max > image_bbox.shape[1]:
698
+ y_max = image_bbox.shape[1]
699
+ if z_max > image_bbox.shape[2]:
700
+ z_max = image_bbox.shape[2]
701
+
702
+ if (x_max - x_min) % 16 != 0:
703
+ x_max -= (x_max - x_min) % 16
704
+ if (y_max - y_min) % 16 != 0:
705
+ y_max -= (y_max - y_min) % 16
706
+ if (z_max - z_min) % 16 != 0:
707
+ z_max -= (z_max - z_min) % 16
708
+
709
+ if image_bbox.sum() == 0:
710
+ x_min, x_max, y_min, y_max, z_min, z_max = -1, image_bbox.shape[0], 0, image_bbox.shape[1], 0, image_bbox.shape[
711
+ 2]
712
+ return x_min, x_max, y_min, y_max, z_min, z_max
713
+
714
+
715
+ def roi_extraction(image, net_roi, ids):
716
+ w, h, d = image.shape
717
+ # roi binary segmentation parameters, the input spacing is 0.4 mm
718
+ print('---run the roi binary segmentation.')
719
+
720
+ stride_xy = 32
721
+ stride_z = 16
722
+ patch_size_roi_stage = (112, 112, 80)
723
+
724
+ label_roi = roi_detection(net_roi, image[0:w:2, 0:h:2, 0:d:2], stride_xy, stride_z,
725
+ patch_size_roi_stage) # (400,400,200)
726
+ print(label_roi.shape, np.max(label_roi))
727
+ label_roi = label_rescale(label_roi, w, h, d, 'trilinear') # (800,800,400)
728
+
729
+ label_roi = morphology.remove_small_objects(label_roi.astype(bool), 5000, connectivity=3).astype(float)
730
+
731
+ label_roi = ndimage.grey_dilation(label_roi, size=(5, 5, 5))
732
+
733
+ label_roi = morphology.remove_small_objects(label_roi.astype(bool), 400000, connectivity=3).astype(
734
+ float)
735
+
736
+ label_roi = ndimage.grey_erosion(label_roi, size=(5, 5, 5))
737
+
738
+ # crop image
739
+ x_min, x_max, y_min, y_max, z_min, z_max = img_crop(label_roi)
740
+ if x_min == -1: # non-foreground label
741
+ whole_label = np.zeros((w, h, d))
742
+ return whole_label
743
+ image = image[x_min:x_max, y_min:y_max, z_min:z_max]
744
+ print("image shape(after roi): ", image.shape)
745
+
746
+ return image, x_min, x_max, y_min, y_max, z_min, z_max
747
+
748
+
749
+ def roi_detection(net, image, stride_xy, stride_z, patch_size):
750
+ w, h, d = image.shape # (400,400,200)
751
+
752
+ # if the size of image is less than patch_size, then padding it
753
+ add_pad = False
754
+ if w < patch_size[0]:
755
+ w_pad = patch_size[0] - w
756
+ add_pad = True
757
+ else:
758
+ w_pad = 0
759
+ if h < patch_size[1]:
760
+ h_pad = patch_size[1] - h
761
+ add_pad = True
762
+ else:
763
+ h_pad = 0
764
+ if d < patch_size[2]:
765
+ d_pad = patch_size[2] - d
766
+ add_pad = True
767
+ else:
768
+ d_pad = 0
769
+ wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
770
+ hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
771
+ dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
772
+ if add_pad:
773
+ image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant',
774
+ constant_values=0)
775
+ ww, hh, dd = image.shape
776
+
777
+ sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 # 2
778
+ sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 # 2
779
+ sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 # 2
780
+ score_map = np.zeros((2,) + image.shape).astype(np.float32)
781
+ cnt = np.zeros(image.shape).astype(np.float32)
782
+ count = 0
783
+ for x in range(0, sx):
784
+ xs = min(stride_xy * x, ww - patch_size[0])
785
+ for y in range(0, sy):
786
+ ys = min(stride_xy * y, hh - patch_size[1])
787
+ for z in range(0, sz):
788
+ zs = min(stride_z * z, dd - patch_size[2])
789
+ test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1],
790
+ zs:zs + patch_size[2]]
791
+ test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(
792
+ np.float32)
793
+
794
+
795
+ test_patch = torch.from_numpy(test_patch).cuda(0)
796
+
797
+
798
+ # test_patch = torch.from_numpy(test_patch).to("cpu")
799
+ with torch.no_grad():
800
+ y1 = net(test_patch) # (1,2,256,256,160)
801
+ y = F.softmax(y1, dim=1) # (1,2,256,256,160)
802
+ y = y.cpu().data.numpy()
803
+ y = y[0, :, :, :, :] # (2,256,256,160)
804
+ score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
805
+ = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1],
806
+ zs:zs + patch_size[2]] + y # (2,400,400,200)
807
+ cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
808
+ = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1 # (400,400,200)
809
+ count = count + 1
810
+ score_map = score_map / np.expand_dims(cnt, axis=0)
811
+
812
+ label_map = np.argmax(score_map, axis=0) # (400,400,200),0/1
813
+ if add_pad:
814
+ label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
815
+ score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
816
+ return label_map
817
+
818
+
819
+ def test_single_case_array(model_array, image=None, stride_xy=None, stride_z=None, patch_size=None, num_classes=1):
820
+ w, h, d = image.shape
821
+
822
+ # if the size of image is less than patch_size, then padding it
823
+ add_pad = False
824
+ if w < patch_size[0]:
825
+ w_pad = patch_size[0]-w
826
+ add_pad = True
827
+ else:
828
+ w_pad = 0
829
+ if h < patch_size[1]:
830
+ h_pad = patch_size[1]-h
831
+ add_pad = True
832
+ else:
833
+ h_pad = 0
834
+ if d < patch_size[2]:
835
+ d_pad = patch_size[2]-d
836
+ add_pad = True
837
+ else:
838
+ d_pad = 0
839
+ wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2
840
+ hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2
841
+ dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2
842
+ if add_pad:
843
+ image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)
844
+
845
+ ww,hh,dd = image.shape
846
+
847
+ sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
848
+ sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
849
+ sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
850
+ score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
851
+ cnt = np.zeros(image.shape).astype(np.float32)
852
+
853
+ for x in range(0, sx):
854
+ xs = min(stride_xy*x, ww-patch_size[0])
855
+ for y in range(0, sy):
856
+ ys = min(stride_xy * y,hh-patch_size[1])
857
+ for z in range(0, sz):
858
+ zs = min(stride_z * z, dd-patch_size[2])
859
+ test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]]
860
+ test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32)
861
+
862
+
863
+ test_patch = torch.from_numpy(test_patch).cuda()
864
+
865
+
866
+ # test_patch = torch.from_numpy(test_patch).to("cpu")
867
+ for model in model_array:
868
+ output = model(test_patch)
869
+ y_temp = F.softmax(output, dim=1)
870
+ y_temp = y_temp.cpu().data.numpy()
871
+ y += y_temp[0,:,:,:,:]
872
+ y /= len(model_array)
873
+ score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
874
+ = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y
875
+ cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
876
+ = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1
877
+ score_map = score_map/np.expand_dims(cnt,axis=0)
878
+
879
+ label_map = np.argmax(score_map, axis = 0)
880
+ if add_pad:
881
+ label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
882
+ score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
883
+ return label_map, score_map
884
+
885
+ def calculate_metric_percase(pred, gt):
886
+ dice = metric.binary.dc(pred, gt)
887
+ jc = metric.binary.jc(pred, gt)
888
+ hd = metric.binary.hd95(pred, gt)
889
+ asd = metric.binary.asd(pred, gt)
890
+
891
+ return dice, jc, hd, asd
892
+
893
+
894
+ class RailNetSystem(nn.Module, PyTorchModelHubMixin):
895
+ def __init__(self, n_channels: int, n_classes: int, normalization: str):
896
+ super().__init__()
897
+
898
+ self.num_classes = 2
899
+
900
+
901
+ self.net_roi = VNet_roi(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=False).cuda()
902
+
903
+
904
+ # self.net_roi = VNet_roi(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=False).to("cpu")
905
+
906
+ self.model_array = []
907
+ for i in range(4):
908
+ if i < 2:
909
+
910
+
911
+ model = VNet(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=True).cuda()
912
+
913
+
914
+ # model = VNet(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=True).to("cpu")
915
+ else:
916
+
917
+
918
+ model = ResVNet(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=True).cuda()
919
+
920
+
921
+ # model = ResVNet(n_channels = n_channels, n_classes = n_classes, normalization = normalization, has_dropout=True).to("cpu")
922
+ self.model_array.append(model)
923
+
924
+ def load_weights(self, weight_dir=".", from_hub=False, repo_id=None):
925
+ def load(file_name):
926
+ if from_hub:
927
+ return hf_hub_download(repo_id=repo_id, filename=f"model weights/{file_name}")
928
+ else:
929
+ return os.path.join(weight_dir, "model weights", file_name)
930
+
931
+
932
+ # self.net_roi.load_state_dict(torch.load(os.path.join(weight_dir, "model weights", "roi_best_model.pth"), map_location="cuda", weights_only=True))
933
+
934
+
935
+ # self.net_roi.load_state_dict(torch.load(os.path.join(weight_dir, "model weights", "roi_best_model.pth"), map_location="cpu", weights_only=True))
936
+ self.net_roi.load_state_dict(torch.load(load("roi_best_model.pth"), map_location="cuda", weights_only=True))
937
+ self.net_roi.eval()
938
+
939
+ model_files = [
940
+ "rail_0_iter_7995_best.pth",
941
+ "rail_1_iter_7995_best.pth",
942
+ "rail_2_iter_7995_best.pth",
943
+ "rail_3_iter_7995_best.pth",
944
+ ]
945
+ for i, file in enumerate(model_files):
946
+
947
+
948
+ # self.model_array[i].load_state_dict(torch.load(os.path.join(weight_dir, "model weights", file), map_location="cuda", weights_only=True))
949
+
950
+
951
+ # self.model_array[i].load_state_dict(torch.load(os.path.join(weight_dir, "model weights", file), map_location="cpu", weights_only=True))
952
+ self.model_array[i].load_state_dict(torch.load(load(file), map_location="cuda", weights_only=True))
953
+ self.model_array[i].eval()
954
+
955
+ def forward(self, image, label, save_path="./output", name="case"):
956
+ if not os.path.exists(save_path):
957
+ os.makedirs(save_path)
958
+ nib.save(nib.Nifti1Image(image.astype(np.float32), np.eye(4)), os.path.join(save_path, f"{name}_img.nii.gz"))
959
+
960
+ w, h, d = image.shape
961
+
962
+ image, x_min, x_max, y_min, y_max, z_min, z_max = roi_extraction(image, self.net_roi, name)
963
+
964
+ prediction, _ = test_single_case_array(self.model_array, image, stride_xy=64, stride_z=32, patch_size=(112, 112, 80), num_classes=self.num_classes)
965
+
966
+ prediction = morphology.remove_small_objects(prediction.astype(bool), 3000, connectivity=3).astype(float)
967
+
968
+ new_prediction = np.zeros((w, h, d))
969
+ new_prediction[x_min:x_max, y_min:y_max, z_min:z_max] = prediction
970
+
971
+ dice, jc, hd, asd = calculate_metric_percase(new_prediction, label[:])
972
+
973
+ nib.save(nib.Nifti1Image(new_prediction.astype(np.float32), np.eye(4)), os.path.join(save_path, f"{name}_pred.nii.gz"))
974
+
975
+ return new_prediction, dice, jc, hd, asd