Tournesol-Saturday committed on
Commit 4c1d50f · verified · 1 Parent(s): 86f6c59

Upload folder using huggingface_hub

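The commit message points at the huggingface_hub upload API. A minimal sketch of the kind of call that produces such a commit (the repo id and folder path are assumptions, not taken from the commit itself):

```python
from huggingface_hub import upload_folder

# Pushes a local directory to the Hub in a single commit; large binaries
# (*.h5, *.pth, *.safetensors) are stored as Git LFS pointers automatically.
upload_folder(
    repo_id="Tournesol-Saturday/RailNet",  # hypothetical repo id
    folder_path=".",
    commit_message="Upload folder using huggingface_hub",
)
```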
example_input_file/CBCT_01.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:19680f9f0b220ee20b55731f386276d3c00496ae861ed3ecedc759d6ec2b8a42
+ size 716802048
example_input_file/CBCT_02.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ed421541ed92ab7c203aa8edbfde743d139fef000264a8279353a08087f5735f
+ size 716802048
example_input_file/CBCT_03.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4beb84cb31420c30aeb992345192b8bc8e798542f502b4d7d48c6b1060bda8ae
+ size 716802048
example_input_file/CBCT_04.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ecfd0c239381a72acecf48de3f309c4a79267f27cfbce00b4f126a8b35fc7a81
+ size 839682048
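Each of these .h5 entries is a Git LFS pointer (spec version, sha256 oid, payload size in bytes); the ~717–840 MB volumes themselves live in LFS storage. A minimal sketch of fetching one example with huggingface_hub, assuming a hypothetical repo id:

```python
from huggingface_hub import hf_hub_download

# Resolves the LFS pointer and downloads the actual volume to the local cache.
path = hf_hub_download(
    repo_id="Tournesol-Saturday/RailNet",  # hypothetical repo id
    filename="example_input_file/CBCT_01.h5",
)
print(path)
```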
gradio_app.py ADDED
@@ -0,0 +1,117 @@
+ import os
+ import time
+ import h5py
+ import numpy as np
+ import gradio as gr
+ import plotly.graph_objects as go
+ from railnet_model import RailNetSystem
+
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ model = RailNetSystem.from_pretrained(".").cuda()
+ model.load_weights(".")
+
+ def wait_for_stable_file(file_path, timeout=5, check_interval=0.2):
+     # poll the file size until it stops changing, i.e. the upload has settled
+     start_time = time.time()
+     last_size = -1
+     while time.time() - start_time < timeout:
+         current_size = os.path.getsize(file_path)
+         if current_size == last_size:
+             return True
+         last_size = current_size
+         time.sleep(check_interval)
+     return False
+
+ def process_cbct_file(h5_file, save_dir="./output"):
+     if not wait_for_stable_file(h5_file.name):
+         raise RuntimeError("File upload has not completed or the file is unstable; please try again.")
+
+     try:
+         with h5py.File(h5_file.name, "r") as f:
+             if "image" not in f or "label" not in f:
+                 raise KeyError("The file is missing the 'image' or 'label' dataset")
+             image = f["image"][:]
+             label = f["label"][:]
+     except Exception as e:
+         raise RuntimeError(f"Failed to read the .h5 file: {str(e)}")
+
+     name = os.path.basename(h5_file.name).replace(".h5", "")
+     pred, dice, jc, hd, asd = model(image, label, save_dir, name)
+     return pred, f"Dice: {dice:.4f}, Jaccard: {jc:.4f}, 95HD: {hd:.2f}, ASD: {asd:.2f}"
+
+ def render_plotly_volume(pred, x_eye=1.25, y_eye=1.25, z_eye=1.25):
+     downsample_factor = 2
+     pred_ds = pred[::downsample_factor, ::downsample_factor, ::downsample_factor]
+
+     fig = go.Figure(data=go.Volume(
+         x=np.repeat(np.arange(pred_ds.shape[0]), pred_ds.shape[1] * pred_ds.shape[2]),
+         y=np.tile(np.repeat(np.arange(pred_ds.shape[1]), pred_ds.shape[2]), pred_ds.shape[0]),
+         z=np.tile(np.arange(pred_ds.shape[2]), pred_ds.shape[0] * pred_ds.shape[1]),
+         value=pred_ds.flatten(),
+         isomin=0.5,
+         isomax=1.0,
+         opacity=0.1,
+         surface_count=1,
+         colorscale=[[0, 'rgb(255, 0, 0)'], [1, 'rgb(255, 0, 0)']],
+         showscale=False
+     ))
+
+     fig.update_layout(
+         scene=dict(
+             xaxis=dict(visible=False),
+             yaxis=dict(visible=False),
+             zaxis=dict(visible=False),
+             camera=dict(eye=dict(x=x_eye, y=y_eye, z=z_eye))
+         ),
+         margin=dict(l=0, r=0, b=0, t=0)
+     )
+     return fig
+
+ def clear_all():
+     return None, "", None
+
+ with gr.Blocks() as demo:
+     gr.Markdown("<div style='text-align: center; font-size: 28px; font-weight: bold;'>🦷 Demo of RailNet: A CBCT Tooth Segmentation System</div>")
+     gr.Markdown("<div style='text-align: center; font-size: 20px'>✅ Steps: Upload a CBCT example file (.h5) → Automatic inference and metrics display → View the 3D segmentation result (drag to rotate, scroll wheel to zoom)</div>")
+
+     gr.Markdown("<div style='height: 20px;'></div>")
+     gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>📂 Step 1: Upload the .h5 example file containing both 'image' and 'label' datasets</div>")
+     file_input = gr.File()
+     with gr.Row():
+         clear_btn = gr.Button("Clear", variant="secondary")
+         submit_btn = gr.Button("Submit", variant="primary")
+
+     gr.Markdown("<div style='height: 20px;'></div>")
+     gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>📊 Step 2: Metrics (Dice, Jaccard, 95HD, ASD)</div>")
+     result_text = gr.Textbox()
+     hidden_pred = gr.State(value=None)
+
+     gr.Markdown("<div style='height: 20px;'></div>")
+     gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>👁️ Step 3: 3D Visualisation</div>")
+     plot_output = gr.Plot()
+
+     def handle_upload(h5_file):
+         pred, metrics = process_cbct_file(h5_file)
+         fig = render_plotly_volume(pred)
+         return metrics, pred, fig
+
+     submit_btn.click(
+         fn=handle_upload,
+         inputs=[file_input],
+         outputs=[result_text, hidden_pred, plot_output]
+     )
+
+     def update_view(pred, x_eye, y_eye, z_eye):
+         if pred is None:
+             return gr.update()
+         return render_plotly_volume(pred, x_eye, y_eye, z_eye)
+
+     clear_btn.click(
+         fn=clear_all,
+         inputs=[],
+         outputs=[file_input, result_text, plot_output]
+     )
+
+ demo.launch()
+
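To exercise the app without the bundled examples, a compatible input can be synthesized. A minimal sketch, assuming only what gradio_app.py checks for: two 3D datasets named 'image' and 'label' in one .h5 file (shape and dtype below are illustrative, not taken from the repo):

```python
import h5py
import numpy as np

# Toy volume in the layout process_cbct_file() expects.
image = np.random.rand(128, 128, 64).astype(np.float32)  # illustrative CBCT volume
label = (image > 0.95).astype(np.uint8)                   # illustrative binary mask

with h5py.File("toy_case.h5", "w") as f:
    f.create_dataset("image", data=image)
    f.create_dataset("label", data=label)
```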
model weights/rail_0_iter_7995_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9085f5b56d7598efa06a17c09fe4a922feca7039faea329f3b5d3dba4e7479d3
+ size 37890229
model weights/rail_1_iter_7995_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:66a4b9ba3314ff6421db1d2ddebdccc6aecd76477c599decbe2da82eeede2857
+ size 37890229
model weights/rail_2_iter_7995_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef122d5b0770a92f81c49597daf9205255698f947f11fb39d371cfb1b9fb4359
+ size 101339695
model weights/rail_3_iter_7995_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:01e9210704169266994ebb5f468b323b262c9cbce968b35dc9a74665ae0bf85d
+ size 101339695
model weights/roi_best_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:629e02002b680e44bffab99bc5d469502045a24d0347e561f471753738a3ae57
+ size 37889184
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:76d4d74def12d6df1e6e1a4ff1fc92818eba6485bd11129eed89f9f97e018c94
+ oid sha256:20c9bfd0d3bdbea2411f8c92139f7774f658e4a5d72d2277adac7e2ef368a38a
  size 37839024
railnet_model.py ADDED
@@ -0,0 +1,924 @@
+ import os
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin
+
+ import numpy as np
+ import nibabel as nib
+ from skimage import morphology
+
+ import math
+ from scipy import ndimage
+ from medpy import metric
+ import h5py
+ from tqdm import tqdm
+
+
+ class ConvBlock(nn.Module):
+     def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
+         super(ConvBlock, self).__init__()
+
+         ops = []
+         for i in range(n_stages):
+             if i == 0:
+                 input_channel = n_filters_in
+             else:
+                 input_channel = n_filters_out
+
+             ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
+             if normalization == 'batchnorm':
+                 ops.append(nn.BatchNorm3d(n_filters_out))
+             elif normalization == 'groupnorm':
+                 ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
+             elif normalization == 'instancenorm':
+                 ops.append(nn.InstanceNorm3d(n_filters_out))
+             elif normalization != 'none':
+                 assert False, 'unsupported normalization: ' + normalization
+             ops.append(nn.ReLU(inplace=True))
+
+         self.conv = nn.Sequential(*ops)
+
+     def forward(self, x):
+         x = self.conv(x)
+         return x
+
+
+ class DownsamplingConvBlock(nn.Module):
+     def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
+         super(DownsamplingConvBlock, self).__init__()
+
+         ops = []
+         if normalization != 'none':
+             ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
+             if normalization == 'batchnorm':
+                 ops.append(nn.BatchNorm3d(n_filters_out))
+             elif normalization == 'groupnorm':
+                 ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
+             elif normalization == 'instancenorm':
+                 ops.append(nn.InstanceNorm3d(n_filters_out))
+             else:
+                 assert False, 'unsupported normalization: ' + normalization
+         else:
+             ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
+
+         ops.append(nn.ReLU(inplace=True))
+
+         self.conv = nn.Sequential(*ops)
+
+     def forward(self, x):
+         x = self.conv(x)
+         return x
+
+
+ class UpsamplingDeconvBlock(nn.Module):
+     def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
+         super(UpsamplingDeconvBlock, self).__init__()
+
+         ops = []
+         if normalization != 'none':
+             ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
+             if normalization == 'batchnorm':
+                 ops.append(nn.BatchNorm3d(n_filters_out))
+             elif normalization == 'groupnorm':
+                 ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
+             elif normalization == 'instancenorm':
+                 ops.append(nn.InstanceNorm3d(n_filters_out))
+             else:
+                 assert False, 'unsupported normalization: ' + normalization
+         else:
+             ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
+
+         ops.append(nn.ReLU(inplace=True))
+
+         self.conv = nn.Sequential(*ops)
+
+     def forward(self, x):
+         x = self.conv(x)
+         return x
+
+
+ class Upsampling(nn.Module):
+     def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
+         super(Upsampling, self).__init__()
+
+         ops = []
+         ops.append(nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False))
+         ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
+         if normalization == 'batchnorm':
+             ops.append(nn.BatchNorm3d(n_filters_out))
+         elif normalization == 'groupnorm':
+             ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
+         elif normalization == 'instancenorm':
+             ops.append(nn.InstanceNorm3d(n_filters_out))
+         elif normalization != 'none':
+             assert False, 'unsupported normalization: ' + normalization
+         ops.append(nn.ReLU(inplace=True))
+
+         self.conv = nn.Sequential(*ops)
+
+     def forward(self, x):
+         x = self.conv(x)
+         return x
+
+
+ class ConnectNet(nn.Module):
+     # small 3D convolutional autoencoder (not referenced by RailNetSystem below)
+     def __init__(self, in_channels, out_channels, input_size):
+         super(ConnectNet, self).__init__()
+         self.encoder = nn.Sequential(
+             nn.Conv3d(in_channels, 128, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.MaxPool3d(kernel_size=2, stride=2),
+             nn.Conv3d(128, 64, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.MaxPool3d(kernel_size=2, stride=2)
+         )
+
+         self.decoder = nn.Sequential(
+             nn.ConvTranspose3d(64, 128, kernel_size=2, stride=2),
+             nn.ReLU(),
+             nn.ConvTranspose3d(128, out_channels, kernel_size=2, stride=2),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         encoded = self.encoder(x)
+         decoded = self.decoder(encoded)
+         return decoded
+
+
+ class VNet(nn.Module):
+     def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
+         super(VNet, self).__init__()
+         self.has_dropout = has_dropout
+
+         self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
+         self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
+
+         self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
+         self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
+
+         self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
+         self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
+
+         self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
+         self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
+
+         self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
+         self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
+
+         self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
+         self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
+
+         self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
+         self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
+
+         self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
+         self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
+
+         self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
+         self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
+
+         self.dropout = nn.Dropout3d(p=0.5, inplace=False)
+
+         self.__init_weight()
+
+     def encoder(self, input):
+         x1 = self.block_one(input)
+         x1_dw = self.block_one_dw(x1)
+
+         x2 = self.block_two(x1_dw)
+         x2_dw = self.block_two_dw(x2)
+
+         x3 = self.block_three(x2_dw)
+         x3_dw = self.block_three_dw(x3)
+
+         x4 = self.block_four(x3_dw)
+         x4_dw = self.block_four_dw(x4)
+
+         x5 = self.block_five(x4_dw)
+         if self.has_dropout:
+             x5 = self.dropout(x5)
+
+         res = [x1, x2, x3, x4, x5]
+
+         return res
+
+     def decoder(self, features):
+         x1 = features[0]
+         x2 = features[1]
+         x3 = features[2]
+         x4 = features[3]
+         x5 = features[4]
+
+         x5_up = self.block_five_up(x5)
+         x5_up = x5_up + x4
+
+         x6 = self.block_six(x5_up)
+         x6_up = self.block_six_up(x6)
+         x6_up = x6_up + x3
+
+         x7 = self.block_seven(x6_up)
+         x7_up = self.block_seven_up(x7)
+         x7_up = x7_up + x2
+
+         x8 = self.block_eight(x7_up)
+         x8_up = self.block_eight_up(x8)
+         x8_up = x8_up + x1
+         x9 = self.block_nine(x8_up)
+         if self.has_dropout:
+             x9 = self.dropout(x9)
+         out = self.out_conv(x9)
+         return out
+
+     def forward(self, input, turnoff_drop=False):
+         if turnoff_drop:
+             has_dropout = self.has_dropout
+             self.has_dropout = False
+         features = self.encoder(input)
+         out = self.decoder(features)
+         if turnoff_drop:
+             self.has_dropout = has_dropout
+         return out
+
+     def __init_weight(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
+                 torch.nn.init.kaiming_normal_(m.weight)
+             elif isinstance(m, nn.BatchNorm3d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+
+ class VNet_roi(nn.Module):
+     def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
+         super(VNet_roi, self).__init__()
+         self.has_dropout = has_dropout
+
+         self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
+         self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
+
+         self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
+         self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
+
+         self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
+         self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
+
+         self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
+         self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
+
+         self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
+         self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
+
+         self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
+         self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
+
+         self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
+         self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
+
+         self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
+         self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
+
+         self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
+         self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
+
+         self.dropout = nn.Dropout3d(p=0.5, inplace=False)
+         # self.__init_weight()
+
+     def encoder(self, input):
+         x1 = self.block_one(input)
+         x1_dw = self.block_one_dw(x1)
+
+         x2 = self.block_two(x1_dw)
+         x2_dw = self.block_two_dw(x2)
+
+         x3 = self.block_three(x2_dw)
+         x3_dw = self.block_three_dw(x3)
+
+         x4 = self.block_four(x3_dw)
+         x4_dw = self.block_four_dw(x4)
+
+         x5 = self.block_five(x4_dw)
+         # x5 = F.dropout3d(x5, p=0.5, training=True)
+         if self.has_dropout:
+             x5 = self.dropout(x5)
+
+         res = [x1, x2, x3, x4, x5]
+
+         return res
+
+     def decoder(self, features):
+         x1 = features[0]
+         x2 = features[1]
+         x3 = features[2]
+         x4 = features[3]
+         x5 = features[4]
+
+         x5_up = self.block_five_up(x5)
+         x5_up = x5_up + x4
+
+         x6 = self.block_six(x5_up)
+         x6_up = self.block_six_up(x6)
+         x6_up = x6_up + x3
+
+         x7 = self.block_seven(x6_up)
+         x7_up = self.block_seven_up(x7)
+         x7_up = x7_up + x2
+
+         x8 = self.block_eight(x7_up)
+         x8_up = self.block_eight_up(x8)
+         x8_up = x8_up + x1
+         x9 = self.block_nine(x8_up)
+         # x9 = F.dropout3d(x9, p=0.5, training=True)
+         if self.has_dropout:
+             x9 = self.dropout(x9)
+         out = self.out_conv(x9)
+         return out
+
+     def forward(self, input, turnoff_drop=False):
+         if turnoff_drop:
+             has_dropout = self.has_dropout
+             self.has_dropout = False
+         features = self.encoder(input)
+         out = self.decoder(features)
+         if turnoff_drop:
+             self.has_dropout = has_dropout
+         return out
+
+
+ class ResVNet(nn.Module):
+     def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False):
+         super(ResVNet, self).__init__()
+         self.resencoder = resnet34()
+         self.has_dropout = has_dropout
+
+         self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
+         self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
+
+         self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
+         self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
+
+         self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
+         self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
+
+         self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
+         self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
+
+         self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
+         self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
+
+         self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
+         self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
+
+         self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
+         self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
+
+         self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
+         self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
+
+         self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
+         self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
+
+         if has_dropout:
+             self.dropout = nn.Dropout3d(p=0.5)
+         self.branchs = nn.ModuleList()
+         for i in range(1):
+             if has_dropout:
+                 seq = nn.Sequential(
+                     ConvBlock(1, n_filters, n_filters, normalization=normalization),
+                     nn.Dropout3d(p=0.5),
+                     nn.Conv3d(n_filters, n_classes, 1, padding=0)
+                 )
+             else:
+                 seq = nn.Sequential(
+                     ConvBlock(1, n_filters, n_filters, normalization=normalization),
+                     nn.Conv3d(n_filters, n_classes, 1, padding=0)
+                 )
+             self.branchs.append(seq)
+
+     def encoder(self, input):
+         x1 = self.block_one(input)
+         x1_dw = self.block_one_dw(x1)
+
+         x2 = self.block_two(x1_dw)
+         x2_dw = self.block_two_dw(x2)
+
+         x3 = self.block_three(x2_dw)
+         x3_dw = self.block_three_dw(x3)
+
+         x4 = self.block_four(x3_dw)
+         x4_dw = self.block_four_dw(x4)
+
+         x5 = self.block_five(x4_dw)
+
+         if self.has_dropout:
+             x5 = self.dropout(x5)
+
+         res = [x1, x2, x3, x4, x5]
+
+         return res
+
+     def decoder(self, features):
+         x1 = features[0]
+         x2 = features[1]
+         x3 = features[2]
+         x4 = features[3]
+         x5 = features[4]
+
+         x5_up = self.block_five_up(x5)
+         x5_up = x5_up + x4
+
+         x6 = self.block_six(x5_up)
+         x6_up = self.block_six_up(x6)
+         x6_up = x6_up + x3
+
+         x7 = self.block_seven(x6_up)
+         x7_up = self.block_seven_up(x7)
+         x7_up = x7_up + x2
+
+         x8 = self.block_eight(x7_up)
+         x8_up = self.block_eight_up(x8)
+         x8_up = x8_up + x1
+
+         x9 = self.block_nine(x8_up)
+
+         out = self.out_conv(x9)
+
+         return out
+
+     def forward(self, input, turnoff_drop=False):
+         if turnoff_drop:
+             has_dropout = self.has_dropout
+             self.has_dropout = False
+         features = self.resencoder(input)
+         out = self.decoder(features)
+         if turnoff_drop:
+             self.has_dropout = has_dropout
+         return out
+
+
+ __all__ = ['ResNet', 'resnet34']
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+ def conv3x3_bn_relu(in_planes, out_planes, stride=1):
+     return nn.Sequential(
+         conv3x3(in_planes, out_planes, stride),
+         nn.InstanceNorm3d(out_planes),
+         nn.ReLU()
+     )
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None,
+                  groups=1, base_width=64, dilation=-1):
+         super(BasicBlock, self).__init__()
+         if groups != 1 or base_width != 64:
+             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = nn.InstanceNorm3d(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = nn.InstanceNorm3d(planes)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None,
+                  groups=1, base_width=64, dilation=1):
+         super(Bottleneck, self).__init__()
+         width = int(planes * (base_width / 64.)) * groups
+         self.conv1 = nn.Conv3d(inplanes, width, kernel_size=1, bias=False)
+         self.bn1 = nn.InstanceNorm3d(width)
+         self.conv2 = nn.Conv3d(width, width, kernel_size=3, stride=stride, dilation=dilation,
+                                padding=dilation, groups=groups, bias=False)
+         self.bn2 = nn.InstanceNorm3d(width)
+         self.conv3 = nn.Conv3d(width, planes * self.expansion, kernel_size=1, bias=False)
+         self.bn3 = nn.InstanceNorm3d(planes * self.expansion)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class ResNet(nn.Module):
+
+     def __init__(self, block, layers, in_channel=1, width=1,
+                  groups=1, width_per_group=64,
+                  mid_dim=1024, low_dim=128,
+                  avg_down=False, deep_stem=False,
+                  head_type='mlp_head', layer4_dilation=1):
+         super(ResNet, self).__init__()
+         self.avg_down = avg_down
+         self.inplanes = 16 * width
+         self.base = int(16 * width)
+         self.groups = groups
+         self.base_width = width_per_group
+
+         mid_dim = self.base * 8 * block.expansion
+
+         if deep_stem:
+             self.conv1 = nn.Sequential(
+                 conv3x3_bn_relu(in_channel, 32, stride=2),
+                 conv3x3_bn_relu(32, 32, stride=1),
+                 conv3x3(32, 64, stride=1)
+             )
+         else:
+             self.conv1 = nn.Conv3d(in_channel, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)
+
+         self.bn1 = nn.InstanceNorm3d(self.inplanes)
+         self.relu = nn.ReLU(inplace=True)
+
+         self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, self.base * 2, layers[0], stride=2)
+         self.layer2 = self._make_layer(block, self.base * 4, layers[1], stride=2)
+         self.layer3 = self._make_layer(block, self.base * 8, layers[2], stride=2)
+         if layer4_dilation == 1:
+             self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=2)
+         elif layer4_dilation == 2:
+             self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=1, dilation=2)
+         else:
+             raise NotImplementedError
+         self.avgpool = nn.AvgPool3d(7, stride=1)
+
+     def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             if self.avg_down:
+                 downsample = nn.Sequential(
+                     nn.AvgPool3d(kernel_size=stride, stride=stride),
+                     nn.Conv3d(self.inplanes, planes * block.expansion,
+                               kernel_size=1, stride=1, bias=False),
+                     nn.InstanceNorm3d(planes * block.expansion),
+                 )
+             else:
+                 downsample = nn.Sequential(
+                     nn.Conv3d(self.inplanes, planes * block.expansion,
+                               kernel_size=1, stride=stride, bias=False),
+                     nn.InstanceNorm3d(planes * block.expansion),
+                 )
+
+         layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, dilation)]
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=dilation))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         # c2 = self.maxpool(x)
+         c2 = self.layer1(x)
+         c3 = self.layer2(c2)
+         c4 = self.layer3(c3)
+         c5 = self.layer4(c4)
+
+         return [x, c2, c3, c4, c5]
+
+
+ def resnet34(**kwargs):
+     return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+
+
+ def label_rescale(image_label, w_ori, h_ori, z_ori, flag):
+     w_ori, h_ori, z_ori = int(w_ori), int(h_ori), int(z_ori)
+     # resize label map (int)
+     if flag == 'trilinear':
+         teeth_ids = np.unique(image_label)
+         image_label_ori = np.zeros((w_ori, h_ori, z_ori))
+         image_label = torch.from_numpy(image_label).cuda(0)
+         for label_id in range(len(teeth_ids)):
+             image_label_bn = (image_label == teeth_ids[label_id]).float()
+             image_label_bn = image_label_bn[None, None, :, :, :]
+             image_label_bn = torch.nn.functional.interpolate(image_label_bn, size=(w_ori, h_ori, z_ori),
+                                                              mode='trilinear', align_corners=False)
+             image_label_bn = image_label_bn[0, 0, :, :, :]
+             image_label_bn = image_label_bn.cpu().data.numpy()
+             image_label_ori[image_label_bn > 0.5] = teeth_ids[label_id]
+         image_label = image_label_ori
+
+     if flag == 'nearest':
+         image_label = torch.from_numpy(image_label).cuda(0)
+         image_label = image_label[None, None, :, :, :].float()
+         image_label = torch.nn.functional.interpolate(image_label, size=(w_ori, h_ori, z_ori), mode='nearest')
+         image_label = image_label[0, 0, :, :, :].cpu().data.numpy()
+     return image_label
+
+
+ def img_crop(image_bbox):
+     if image_bbox.sum() > 0:
+
+         x_min = np.nonzero(image_bbox)[0].min() - 8
+         x_max = np.nonzero(image_bbox)[0].max() + 8
+
+         y_min = np.nonzero(image_bbox)[1].min() - 16
+         y_max = np.nonzero(image_bbox)[1].max() + 16
+
+         z_min = np.nonzero(image_bbox)[2].min() - 16
+         z_max = np.nonzero(image_bbox)[2].max() + 16
+
+         if x_min < 0:
+             x_min = 0
+         if y_min < 0:
+             y_min = 0
+         if z_min < 0:
+             z_min = 0
+         if x_max > image_bbox.shape[0]:
+             x_max = image_bbox.shape[0]
+         if y_max > image_bbox.shape[1]:
+             y_max = image_bbox.shape[1]
+         if z_max > image_bbox.shape[2]:
+             z_max = image_bbox.shape[2]
+
+         if (x_max - x_min) % 16 != 0:
+             x_max -= (x_max - x_min) % 16
+         if (y_max - y_min) % 16 != 0:
+             y_max -= (y_max - y_min) % 16
+         if (z_max - z_min) % 16 != 0:
+             z_max -= (z_max - z_min) % 16
+
+     if image_bbox.sum() == 0:
+         x_min, x_max, y_min, y_max, z_min, z_max = -1, image_bbox.shape[0], 0, image_bbox.shape[1], 0, image_bbox.shape[2]
+     return x_min, x_max, y_min, y_max, z_min, z_max
+
+
+ def roi_extraction(image, net_roi, ids):
+     w, h, d = image.shape
+     # roi binary segmentation parameters; the input spacing is 0.4 mm
+     print('---run the roi binary segmentation.')
+
+     stride_xy = 32
+     stride_z = 16
+     patch_size_roi_stage = (112, 112, 80)
+
+     # coarse stage runs on a 2x-downsampled volume
+     label_roi = roi_detection(net_roi, image[0:w:2, 0:h:2, 0:d:2], stride_xy, stride_z,
+                               patch_size_roi_stage)  # e.g. (400,400,200)
+     print(label_roi.shape, np.max(label_roi))
+     label_roi = label_rescale(label_roi, w, h, d, 'trilinear')  # e.g. (800,800,400)
+
+     label_roi = morphology.remove_small_objects(label_roi.astype(bool), 5000, connectivity=3).astype(float)
+
+     label_roi = ndimage.grey_dilation(label_roi, size=(5, 5, 5))
+
+     label_roi = morphology.remove_small_objects(label_roi.astype(bool), 400000, connectivity=3).astype(float)
+
+     label_roi = ndimage.grey_erosion(label_roi, size=(5, 5, 5))
+
+     # crop image
+     x_min, x_max, y_min, y_max, z_min, z_max = img_crop(label_roi)
+     if x_min == -1:  # no foreground detected; note this early return is a bare label volume
+         whole_label = np.zeros((w, h, d))
+         return whole_label
+     image = image[x_min:x_max, y_min:y_max, z_min:z_max]
+     print("image shape(after roi): ", image.shape)
+
+     return image, x_min, x_max, y_min, y_max, z_min, z_max
+
+
+ def roi_detection(net, image, stride_xy, stride_z, patch_size):
+     w, h, d = image.shape  # e.g. (400,400,200)
+
+     # if the image is smaller than patch_size, pad it
+     add_pad = False
+     if w < patch_size[0]:
+         w_pad = patch_size[0] - w
+         add_pad = True
+     else:
+         w_pad = 0
+     if h < patch_size[1]:
+         h_pad = patch_size[1] - h
+         add_pad = True
+     else:
+         h_pad = 0
+     if d < patch_size[2]:
+         d_pad = patch_size[2] - d
+         add_pad = True
+     else:
+         d_pad = 0
+     wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
+     hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
+     dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
+     if add_pad:
+         image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant',
+                        constant_values=0)
+     ww, hh, dd = image.shape
+
+     sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
+     sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
+     sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
+     score_map = np.zeros((2,) + image.shape).astype(np.float32)
+     cnt = np.zeros(image.shape).astype(np.float32)
+     count = 0
+     for x in range(0, sx):
+         xs = min(stride_xy * x, ww - patch_size[0])
+         for y in range(0, sy):
+             ys = min(stride_xy * y, hh - patch_size[1])
+             for z in range(0, sz):
+                 zs = min(stride_z * z, dd - patch_size[2])
+                 test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
+                 test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
+                 test_patch = torch.from_numpy(test_patch).cuda(0)
+                 with torch.no_grad():
+                     y1 = net(test_patch)  # (1,2,112,112,80)
+                     y = F.softmax(y1, dim=1)
+                 y = y.cpu().data.numpy()
+                 y = y[0, :, :, :, :]
+                 score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
+                     = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y
+                 cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
+                     = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1
+                 count = count + 1
+     score_map = score_map / np.expand_dims(cnt, axis=0)
+
+     label_map = np.argmax(score_map, axis=0)  # binary 0/1 map
+     if add_pad:
+         label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
+         score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
+     return label_map
+
+
+ def test_single_case_array(model_array, image=None, stride_xy=None, stride_z=None, patch_size=None, num_classes=1):
+     w, h, d = image.shape
+
+     # if the image is smaller than patch_size, pad it
+     add_pad = False
+     if w < patch_size[0]:
+         w_pad = patch_size[0] - w
+         add_pad = True
+     else:
+         w_pad = 0
+     if h < patch_size[1]:
+         h_pad = patch_size[1] - h
+         add_pad = True
+     else:
+         h_pad = 0
+     if d < patch_size[2]:
+         d_pad = patch_size[2] - d
+         add_pad = True
+     else:
+         d_pad = 0
+     wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
+     hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
+     dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
+     if add_pad:
+         image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)
+
+     ww, hh, dd = image.shape
+
+     sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
+     sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
+     sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
+     score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
+     cnt = np.zeros(image.shape).astype(np.float32)
+
+     for x in range(0, sx):
+         xs = min(stride_xy * x, ww - patch_size[0])
+         for y in range(0, sy):
+             ys = min(stride_xy * y, hh - patch_size[1])
+             for z in range(0, sz):
+                 zs = min(stride_z * z, dd - patch_size[2])
+                 test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
+                 test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
+                 test_patch = torch.from_numpy(test_patch).cuda()
+                 y_prob = 0.0  # average the softmax probabilities of the model ensemble
+                 for model in model_array:
+                     with torch.no_grad():
+                         output = model(test_patch)
+                     y_temp = F.softmax(output, dim=1)
+                     y_temp = y_temp.cpu().data.numpy()
+                     y_prob += y_temp[0, :, :, :, :]
+                 y_prob /= len(model_array)
+                 score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
+                     = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y_prob
+                 cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
+                     = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1
+     score_map = score_map / np.expand_dims(cnt, axis=0)
+
+     label_map = np.argmax(score_map, axis=0)
+     if add_pad:
+         label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
+         score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
+     return label_map, score_map
+
+
+ def calculate_metric_percase(pred, gt):
+     dice = metric.binary.dc(pred, gt)
+     jc = metric.binary.jc(pred, gt)
+     hd = metric.binary.hd95(pred, gt)
+     asd = metric.binary.asd(pred, gt)
+
+     return dice, jc, hd, asd
+
+
+ class RailNetSystem(nn.Module, PyTorchModelHubMixin):
+     # two-stage pipeline: ROI localization (VNet_roi) followed by a
+     # four-model ensemble (2x VNet + 2x ResVNet) on the cropped volume
+     def __init__(self, n_channels: int, n_classes: int, normalization: str):
+         super().__init__()
+
+         self.num_classes = 2
+
+         self.net_roi = VNet_roi(n_channels=n_channels, n_classes=n_classes, normalization=normalization, has_dropout=False).cuda()
+
+         self.model_array = []
+         for i in range(4):
+             if i < 2:
+                 model = VNet(n_channels=n_channels, n_classes=n_classes, normalization=normalization, has_dropout=True).cuda()
+             else:
+                 model = ResVNet(n_channels=n_channels, n_classes=n_classes, normalization=normalization, has_dropout=True).cuda()
+             self.model_array.append(model)
+
+     def load_weights(self, weight_dir="."):
+         self.net_roi.load_state_dict(torch.load(os.path.join(weight_dir, "model weights", "roi_best_model.pth"), map_location="cuda", weights_only=True))
+         self.net_roi.eval()
+
+         model_files = [
+             "rail_0_iter_7995_best.pth",
+             "rail_1_iter_7995_best.pth",
+             "rail_2_iter_7995_best.pth",
+             "rail_3_iter_7995_best.pth",
+         ]
+         for i, file in enumerate(model_files):
+             self.model_array[i].load_state_dict(torch.load(os.path.join(weight_dir, "model weights", file), map_location="cuda", weights_only=True))
+             self.model_array[i].eval()
+
+     def forward(self, image, label, save_path="./output", name="case"):
+         if not os.path.exists(save_path):
+             os.makedirs(save_path)
+         nib.save(nib.Nifti1Image(image.astype(np.float32), np.eye(4)), os.path.join(save_path, f"{name}_img.nii.gz"))
+
+         w, h, d = image.shape
+
+         image, x_min, x_max, y_min, y_max, z_min, z_max = roi_extraction(image, self.net_roi, name)
+
+         prediction, _ = test_single_case_array(self.model_array, image, stride_xy=64, stride_z=32, patch_size=(112, 112, 80), num_classes=self.num_classes)
+
+         prediction = morphology.remove_small_objects(prediction.astype(bool), 3000, connectivity=3).astype(float)
+
+         # paste the ROI prediction back into the full-size volume
+         new_prediction = np.zeros((w, h, d))
+         new_prediction[x_min:x_max, y_min:y_max, z_min:z_max] = prediction
+
+         dice, jc, hd, asd = calculate_metric_percase(new_prediction, label[:])
+
+         nib.save(nib.Nifti1Image(new_prediction.astype(np.float32), np.eye(4)), os.path.join(save_path, f"{name}_pred.nii.gz"))
+
+         return new_prediction, dice, jc, hd, asd
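Both roi_detection and test_single_case_array share the same sliding-window scheme: tile the volume with overlapping patches, accumulate softmax scores, and normalize by a per-voxel visit count. A small worked sketch of the start-index arithmetic (the numbers are illustrative; the fine stage uses patch (112, 112, 80) with stride_xy=64 and stride_z=32):

```python
import math

ww, patch, stride = 200, 112, 64            # illustrative extent along one axis

sx = math.ceil((ww - patch) / stride) + 1   # number of tiles: ceil(88/64) + 1 = 3
starts = [min(stride * x, ww - patch) for x in range(sx)]
print(starts)  # [0, 64, 88] -- the last start is clamped so the tile stays in bounds
```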