Tournesol-Saturday commited on
Commit
24c0f42
·
verified ·
1 Parent(s): d746c58

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -183
app.py DELETED
@@ -1,183 +0,0 @@
1
- import os
2
- import time
3
- import h5py
4
- import numpy as np
5
- import gradio as gr
6
- import plotly.graph_objects as go
7
- from railnet_model import RailNetSystem
8
-
9
- from huggingface_hub import hf_hub_download
10
-
11
- os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
12
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
13
-
14
- # model = RailNetSystem.from_pretrained(".").cuda()
15
-
16
- model = RailNetSystem.from_pretrained("Tournesol-Saturday/railNet-tooth-segmentation-in-CBCT-image").cuda()
17
-
18
- model.load_weights(from_hub=True, repo_id="Tournesol-Saturday/railNet-tooth-segmentation-in-CBCT-image")
19
-
20
-
21
- # def wait_for_stable_file(file_path, timeout=5, check_interval=0.2):
22
- # start_time = time.time()
23
- # last_size = -1
24
- # while time.time() - start_time < timeout:
25
- # current_size = os.path.getsize(file_path)
26
- # if current_size == last_size:
27
- # return True
28
- # last_size = current_size
29
- # time.sleep(check_interval)
30
- # return False
31
-
32
- # def process_cbct_file(h5_file, save_dir="./output"):
33
- # if not wait_for_stable_file(h5_file.name):
34
- # raise RuntimeError("File upload has not been completed or is unstable, please try again.")
35
-
36
- # try:
37
- # with h5py.File(h5_file.name, "r") as f:
38
- # if "image" not in f or "label" not in f:
39
- # raise KeyError("The file is missing ‘image’ or ‘label’ value")
40
- # image = f["image"][:]
41
- # label = f["label"][:]
42
- # except Exception as e:
43
- # raise RuntimeError(f"Failed to read the .h5 file: {str(e)}")
44
-
45
- # name = os.path.basename(h5_file.name).replace(".h5", "")
46
- # pred, dice, jc, hd, asd = model(image, label, save_dir, name)
47
-
48
- # img_path = os.path.join(save_dir, f"{name}_img.nii.gz")
49
- # pred_path = os.path.join(save_dir, f"{name}_pred.nii.gz")
50
-
51
- # return pred, f"Dice: {dice:.4f}, Jaccard: {jc:.4f}, 95HD: {hd:.2f}, ASD: {asd:.2f}", img_path, pred_path
52
-
53
- def render_plotly_volume(pred, x_eye=1.25, y_eye=1.25, z_eye=1.25):
54
- downsample_factor = 2
55
- pred_ds = pred[::downsample_factor, ::downsample_factor, ::downsample_factor]
56
-
57
- fig = go.Figure(data=go.Volume(
58
- x=np.repeat(np.arange(pred_ds.shape[0]), pred_ds.shape[1] * pred_ds.shape[2]),
59
- y=np.tile(np.repeat(np.arange(pred_ds.shape[1]), pred_ds.shape[2]), pred_ds.shape[0]),
60
- z=np.tile(np.arange(pred_ds.shape[2]), pred_ds.shape[0] * pred_ds.shape[1]),
61
- value=pred_ds.flatten(),
62
- isomin=0.5,
63
- isomax=1.0,
64
- opacity=0.1,
65
- surface_count=1,
66
- colorscale=[[0, 'rgb(255, 0, 0)'], [1, 'rgb(255, 0, 0)']],
67
- showscale=False
68
- ))
69
-
70
- fig.update_layout(
71
- scene=dict(
72
- xaxis=dict(visible=False),
73
- yaxis=dict(visible=False),
74
- zaxis=dict(visible=False),
75
- camera=dict(eye=dict(x=x_eye, y=y_eye, z=z_eye))
76
- ),
77
- margin=dict(l=0, r=0, b=0, t=0)
78
- )
79
- return fig
80
-
81
-
82
- def handle_example(filename):
83
- repo_id = "Tournesol-Saturday/railNet-tooth-segmentation-in-CBCT-image"
84
- h5_path = hf_hub_download(repo_id=repo_id, filename=f"example_input_file/{filename}")
85
-
86
- with h5py.File(h5_path, "r") as f:
87
- image = f["image"][:]
88
- label = f["label"][:]
89
-
90
- name = filename.replace(".h5", "")
91
- pred, dice, jc, hd, asd = model(image, label, "./output", name)
92
-
93
- fig = render_plotly_volume(pred)
94
-
95
- img_path = f"./output/{name}_img.nii.gz"
96
- pred_path = f"./output/{name}_pred.nii.gz"
97
-
98
- metrics = f"Dice: {dice:.4f}, Jaccard: {jc:.4f}, 95HD: {hd:.2f}, ASD: {asd:.2f}"
99
-
100
- return metrics, pred, fig, img_path, pred_path
101
-
102
-
103
- def clear_all():
104
- return "", None, None, None, None
105
-
106
- with gr.Blocks() as demo:
107
- gr.HTML("<div style='text-align: center; font-size: 22px; font-weight: bold;'>🦷 Demo of RailNet: A CBCT Tooth Segmentation System</div>")
108
- gr.HTML("<div style='text-align: center; font-size: 15px'>✅ Steps: Select a CBCT example file (.h5) → Automatic inference and metrics display → View 3D segmentation result (Mouse drag and scroll wheel zooming)</div>")
109
-
110
- # gr.HTML("<div style='font-size: 15px; font-weight: bold;'>📂 Step 1: Upload the .h5 example file containing both ‘image’ and ‘label’ values</div>")
111
- gr.HTML("""
112
- <style>
113
- .code-style {
114
- font-family: monospace;
115
- background-color: #2f363d;
116
- color: #ffffff;
117
- padding: 2px 6px;
118
- border-radius: 4px;
119
- font-size: 90%;
120
- }
121
- </style>
122
-
123
- <div style='font-size: 15px; font-weight: bold;'>
124
- 📂 Step 1: Select a <span class='code-style'>.h5</span> example file from the <span class='code-style'>example_input_file</span> folder in our
125
- <a href='https://huggingface.co/Tournesol-Saturday/railNet-tooth-segmentation-in-CBCT-image' target='_blank' style='text-decoration: none; color: #1f6feb; font-weight: bold;'>
126
- Hugging Face model
127
- </a> repository.
128
- </div>
129
- """)
130
-
131
- # file_input = gr.File()
132
-
133
-
134
- example_files = ["CBCT_01.h5", "CBCT_02.h5", "CBCT_03.h5", "CBCT_04.h5"]
135
- dropdown = gr.Dropdown(choices=example_files, label="Example File", value=example_files[0])
136
-
137
-
138
- with gr.Row():
139
- clear_btn = gr.Button("清除", variant="secondary")
140
- submit_btn = gr.Button("提交", variant="primary")
141
-
142
- gr.HTML("<div style='font-size: 15px; font-weight: bold;'>📊 Step 2: Metrics (Dice, Jaccard, 95HD, ASD)</div>")
143
- result_text = gr.Textbox()
144
- hidden_pred = gr.State(value=None)
145
-
146
- gr.HTML("<div style='font-size: 15px; font-weight: bold;'>👁️ Step 3: 3D Visualisation</div>")
147
- plot_output = gr.Plot()
148
-
149
- hidden_img_file = gr.File(visible=False)
150
- hidden_pred_file = gr.File(visible=False)
151
-
152
- gr.HTML("<div style='font-size: 15px; font-weight: bold;'>⬇️ Step 4: Download <span class='code-style'>NIfTI</span> files for accurate 1:1 visualization using <span class='code-style'>ITK-SNAP</span> software</div>")
153
- with gr.Row():
154
- download_img_btn = gr.Button("Download Original Image")
155
- download_pred_btn = gr.Button("Download Segmentation Result")
156
-
157
- # def handle_upload(h5_file):
158
- # pred, metrics, img_path, pred_path = process_cbct_file(h5_file)
159
- # fig = render_plotly_volume(pred)
160
- # return metrics, pred, fig, img_path, pred_path
161
-
162
- submit_btn.click(
163
- fn=handle_example,
164
- inputs=[dropdown],
165
- outputs=[result_text, hidden_pred, plot_output, hidden_img_file, hidden_pred_file]
166
- )
167
-
168
- def update_view(pred, x_eye, y_eye, z_eye):
169
- if pred is None:
170
- return gr.update()
171
- return render_plotly_volume(pred, x_eye, y_eye, z_eye)
172
-
173
- clear_btn.click(
174
- fn=clear_all,
175
- inputs=[],
176
- outputs=[result_text, hidden_pred, plot_output, hidden_img_file, hidden_pred_file]
177
- )
178
-
179
- download_img_btn.click(fn=lambda f: f, inputs=[hidden_img_file], outputs=[hidden_img_file])
180
- download_pred_btn.click(fn=lambda f: f, inputs=[hidden_pred_file], outputs=[hidden_pred_file])
181
-
182
- demo.launch()
183
-