import os
import time
import h5py
import numpy as np
import gradio as gr
import plotly.graph_objects as go
from railnet_model import RailNetSystem
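
# Gradio demo for RailNet, a CBCT tooth segmentation system: the user uploads an
# .h5 file containing "image" and "label" datasets, the model runs inference and
# reports Dice / Jaccard / 95HD / ASD, and the prediction is shown as an
# interactive 3D Plotly volume.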
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
model = RailNetSystem.from_pretrained(".").cuda()
model.load_weights(".")

def wait_for_stable_file(file_path, timeout=5, check_interval=0.2):
    """Return True once the file size stops changing, i.e. the upload has settled."""
    start_time = time.time()
    last_size = -1
    while time.time() - start_time < timeout:
        current_size = os.path.getsize(file_path)
        if current_size == last_size:
            return True
        last_size = current_size
        time.sleep(check_interval)
    return False

def process_cbct_file(h5_file, save_dir="./output"):
    if not wait_for_stable_file(h5_file.name):
        raise RuntimeError("File upload has not been completed or is unstable, please try again.")
    try:
        with h5py.File(h5_file.name, "r") as f:
            if "image" not in f or "label" not in f:
                raise KeyError("The file is missing the 'image' or 'label' dataset")
            image = f["image"][:]
            label = f["label"][:]
    except Exception as e:
        raise RuntimeError(f"Failed to read the .h5 file: {str(e)}")
    name = os.path.basename(h5_file.name).replace(".h5", "")
    pred, dice, jc, hd, asd = model(image, label, save_dir, name)
    return pred, f"Dice: {dice:.4f}, Jaccard: {jc:.4f}, 95HD: {hd:.2f}, ASD: {asd:.2f}"

def render_plotly_volume(pred, x_eye=1.25, y_eye=1.25, z_eye=1.25):
    # Downsample the prediction to keep the Plotly volume responsive in the browser.
    downsample_factor = 2
    pred_ds = pred[::downsample_factor, ::downsample_factor, ::downsample_factor]
    # Build one (x, y, z) coordinate per voxel in C order so they align with pred_ds.flatten().
    fig = go.Figure(data=go.Volume(
        x=np.repeat(np.arange(pred_ds.shape[0]), pred_ds.shape[1] * pred_ds.shape[2]),
        y=np.tile(np.repeat(np.arange(pred_ds.shape[1]), pred_ds.shape[2]), pred_ds.shape[0]),
        z=np.tile(np.arange(pred_ds.shape[2]), pred_ds.shape[0] * pred_ds.shape[1]),
        value=pred_ds.flatten(),
        isomin=0.5,
        isomax=1.0,
        opacity=0.1,
        surface_count=1,
        colorscale=[[0, 'rgb(255, 0, 0)'], [1, 'rgb(255, 0, 0)']],
        showscale=False
    ))
    fig.update_layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            camera=dict(eye=dict(x=x_eye, y=y_eye, z=z_eye))
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )
    return fig
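
# A returned figure can also be inspected outside the app, for example with
# fig.write_html("pred.html") to save a standalone interactive page (assumes the
# caller already has a prediction array to pass to render_plotly_volume).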

def clear_all():
    # Reset the file input, the metrics textbox and the 3D plot.
    return None, "", None

with gr.Blocks() as demo:
    gr.Markdown("<div style='text-align: center; font-size: 28px; font-weight: bold;'>🦷 Demo of RailNet: A CBCT Tooth Segmentation System</div>")
    gr.Markdown("<div style='text-align: center; font-size: 20px'>✅ Steps: Upload a CBCT example file (.h5) → Automatic inference and metrics display → View the 3D segmentation result (drag to rotate, scroll to zoom)</div>")
    gr.Markdown("<div style='height: 20px;'></div>")
    gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>📂 Step 1: Upload an .h5 example file containing both 'image' and 'label' datasets</div>")
    file_input = gr.File()

    with gr.Row():
        clear_btn = gr.Button("Clear", variant="secondary")
        submit_btn = gr.Button("Submit", variant="primary")

    gr.Markdown("<div style='height: 20px;'></div>")
    gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>📊 Step 2: Metrics (Dice, Jaccard, 95HD, ASD)</div>")
    result_text = gr.Textbox()
    hidden_pred = gr.State(value=None)

    gr.Markdown("<div style='height: 20px;'></div>")
    gr.Markdown("<div style='font-size: 20px; font-weight: bold;'>👁️ Step 3: 3D Visualisation</div>")
    plot_output = gr.Plot()

    def handle_upload(h5_file):
        # Run inference on the uploaded file, then render the prediction as a 3D volume.
        pred, metrics = process_cbct_file(h5_file)
        fig = render_plotly_volume(pred)
        return metrics, pred, fig

    submit_btn.click(
        fn=handle_upload,
        inputs=[file_input],
        outputs=[result_text, hidden_pred, plot_output]
    )

    def update_view(pred, x_eye, y_eye, z_eye):
        # Re-render the stored prediction from a different camera position.
        if pred is None:
            return gr.update()
        return render_plotly_volume(pred, x_eye, y_eye, z_eye)

    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[file_input, result_text, plot_output]
    )

demo.launch()