File size: 4,441 Bytes
4c1d50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os

# Configure native runtimes *before* importing the libraries that may load
# them: the OpenMP runtime and the CUDA runtime read these variables when they
# initialise, so assigning them after `import numpy` / the model import can be
# too late to take effect.
# KMP_DUPLICATE_LIB_OK presumably works around the Intel OpenMP
# "duplicate runtime" abort seen when two copies of libiomp are loaded.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
# Expose only GPU 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import time  # noqa: E402
import h5py  # noqa: E402
import numpy as np  # noqa: E402
import gradio as gr  # noqa: E402
import plotly.graph_objects as go  # noqa: E402
from railnet_model import RailNetSystem  # noqa: E402

# Load the model once at import time; the request handlers reuse this global.
# NOTE(review): from_pretrained(".") is followed by a separate
# load_weights(".") call — confirm both steps are needed for RailNetSystem.
model = RailNetSystem.from_pretrained(".").cuda()
model.load_weights(".")

def wait_for_stable_file(file_path, timeout=5, check_interval=0.2):
    """Wait until *file_path* stops changing size, or until *timeout* elapses.

    The file counts as stable once two consecutive size samples taken
    *check_interval* seconds apart are equal. A path that does not exist
    (yet) or cannot be stat-ed is treated as "not stable" and polled again
    instead of raising.

    Args:
        file_path: Path of the file to watch.
        timeout: Maximum number of seconds to wait.
        check_interval: Seconds to sleep between size samples.

    Returns:
        True if a stable size was observed within *timeout*, otherwise False.
    """
    # monotonic() is immune to wall-clock adjustments (NTP sync, manual
    # changes) that could otherwise cut the timeout short or extend it.
    deadline = time.monotonic() + timeout
    last_size = None
    while time.monotonic() < deadline:
        try:
            current_size = os.path.getsize(file_path)
        except OSError:
            # Missing or temporarily inaccessible: reset and keep polling.
            current_size = None
        if current_size is not None and current_size == last_size:
            return True
        last_size = current_size
        time.sleep(check_interval)
    return False

def process_cbct_file(h5_file, save_dir="./output"):
    """Read an uploaded CBCT ``.h5`` file, run the model on it and format metrics.

    Args:
        h5_file: The upload as delivered by ``gr.File`` — either an object
            with a ``.name`` attribute holding the temp-file path (older
            Gradio) or the path itself as ``str``/``os.PathLike`` (Gradio's
            ``type="filepath"``).
        save_dir: Output directory forwarded to the model call.

    Returns:
        ``(pred, metrics_text)``: ``pred`` is the first value returned by the
        model; ``metrics_text`` reports Dice, Jaccard, 95HD and ASD.

    Raises:
        RuntimeError: if no file was given, the upload never stabilised, or
            the file cannot be read or lacks the ``image``/``label`` datasets.
    """
    if h5_file is None:
        raise RuntimeError("No file uploaded, please upload a .h5 file first.")
    file_path = (
        os.fspath(h5_file)
        if isinstance(h5_file, (str, os.PathLike))
        else h5_file.name
    )

    if not wait_for_stable_file(file_path):
        raise RuntimeError("File upload has not been completed or is unstable, please try again.")

    try:
        with h5py.File(file_path, "r") as f:
            if "image" not in f or "label" not in f:
                raise KeyError("The file is missing ‘image’ or ‘label’ value")
            image = f["image"][:]
            label = f["label"][:]
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise RuntimeError(f"Failed to read the .h5 file: {str(e)}") from e

    # Strip only a trailing ".h5"; str.replace would also delete ".h5"
    # appearing elsewhere in the name (e.g. "scan.h5v2.h5" -> "scanv2").
    base_name = os.path.basename(file_path)
    name = base_name[:-len(".h5")] if base_name.endswith(".h5") else base_name
    pred, dice, jc, hd, asd = model(image, label, save_dir, name)
    return pred, f"Dice: {dice:.4f}, Jaccard: {jc:.4f}, 95HD: {hd:.2f}, ASD: {asd:.2f}"

def render_plotly_volume(pred, x_eye=1.25, y_eye=1.25, z_eye=1.25, downsample_factor=2):
    """Build a Plotly ``Volume`` figure of a 3-D prediction array.

    The volume is drawn as a single red iso-surface with ``isomin=0.5`` and
    ``isomax=1.0`` (presumably a binary segmentation mask — confirm against
    the model output), with all axes hidden.

    Args:
        pred: 3-D array indexed as ``pred[x, y, z]``.
        x_eye, y_eye, z_eye: Camera eye position for the 3-D scene.
        downsample_factor: Keep every n-th voxel along each axis to reduce
            the number of points handed to Plotly. Must be >= 1.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: if ``downsample_factor`` is smaller than 1.
    """
    if downsample_factor < 1:
        raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")

    step = downsample_factor
    pred_ds = pred[::step, ::step, ::step]
    # Integer voxel coordinates laid out in C order, so they line up
    # element-for-element with pred_ds.flatten().
    xs, ys, zs = np.indices(pred_ds.shape)

    fig = go.Figure(data=go.Volume(
        x=xs.ravel(),
        y=ys.ravel(),
        z=zs.ravel(),
        value=pred_ds.flatten(),
        isomin=0.5,
        isomax=1.0,
        opacity=0.1,
        surface_count=1,
        colorscale=[[0, 'rgb(255, 0, 0)'], [1, 'rgb(255, 0, 0)']],
        showscale=False
    ))

    fig.update_layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            camera=dict(eye=dict(x=x_eye, y=y_eye, z=z_eye))
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )
    return fig

def clear_all():
    """Return reset values for (file input, metrics textbox, 3-D plot)."""
    empty_file = None
    empty_metrics = ""
    empty_plot = None
    return (empty_file, empty_metrics, empty_plot)

with gr.Blocks() as demo:
    gr.Markdown("<div style='text-align: center; font-size: 28px; font-weight: bold;'>🦷 Demo of RailNet: A CBCT Tooth Segmentation System</div>")
    gr.Markdown("<div style='text-align: center; font-size: 20px'>✅ Steps: Upload a CBCT example file (.h5)  →  Automatic inference and metrics display  →  View 3D segmentation result (Mouse drag and scroll wheel zooming)</div>")

    gr.Markdown("<div style='height: 20px;'></div>")
    gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>📂 Step 1: Upload the .h5 example file containing both ‘image’ and ‘label’ values</div>")
    file_input = gr.File()
    with gr.Row():
        clear_btn = gr.Button("清除", variant="secondary")  # label means "Clear"
        submit_btn = gr.Button("提交", variant="primary")  # label means "Submit"

    gr.Markdown("<div style='height: 20px;'></div>")
    gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>📊 Step 2: Metrics (Dice, Jaccard, 95HD, ASD)</div>")
    result_text = gr.Textbox()
    # Caches the last prediction between events so the figure can be
    # re-rendered without running the model again.
    hidden_pred = gr.State(value=None)

    gr.Markdown("<div style='height: 20px;'></div>")
    gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>👁️ Step 3: 3D Visualisation</div>")
    plot_output = gr.Plot()

    def handle_upload(h5_file):
        """Run inference on the upload and return (metrics, pred, figure)."""
        if h5_file is None:
            raise gr.Error("Please upload a .h5 file first.")
        try:
            pred, metrics = process_cbct_file(h5_file)
        except RuntimeError as err:
            # Gradio only shows the message of a gr.Error to the user; other
            # exceptions surface as a generic error, hiding the reason.
            raise gr.Error(str(err)) from err
        fig = render_plotly_volume(pred)
        return metrics, pred, fig

    submit_btn.click(
        fn=handle_upload,
        inputs=[file_input],
        outputs=[result_text, hidden_pred, plot_output]
    )

    def update_view(pred, x_eye, y_eye, z_eye):
        """Re-render the cached prediction from a new camera position.

        Leaves the plot unchanged (``gr.update()``) when no prediction exists.
        NOTE(review): not wired to any event in this file.
        """
        if pred is None:
            return gr.update()
        return render_plotly_volume(pred, x_eye, y_eye, z_eye)

    def handle_clear():
        """Reset the visible outputs and also drop the cached prediction."""
        file_value, text_value, plot_value = clear_all()
        return file_value, text_value, None, plot_value

    clear_btn.click(
        fn=handle_clear,
        inputs=[],
        outputs=[file_input, result_text, hidden_pred, plot_output]
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()