Spaces:
Running
on
Zero
Running
on
Zero
Commit
Β·
7ca35bb
1
Parent(s):
515dbe1
Display gifs
Browse files- app.py +101 -57
- requirements.txt +1 -1
app.py
CHANGED
@@ -213,15 +213,18 @@ def mae(image_id, mask_ratio, progress=gr.Progress()):
|
|
213 |
batch, reconstructed_dict, masks_dict = mae_inference(
|
214 |
batch, transform, model, mask_ratio
|
215 |
)
|
216 |
-
progress(1, desc="Plotting
|
217 |
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
|
|
|
|
|
|
225 |
|
226 |
|
227 |
def mae_tab():
|
@@ -229,14 +232,16 @@ def mae_tab():
|
|
229 |
gr.Markdown(
|
230 |
"""
|
231 |
This page demonstrates the masking and reconstruction process of the masked autoencoder. The model was trained with a mask ratio of 0.75 over 74,000 studies.
|
232 |
-
|
233 |
-
Visualisation may take a few seconds as we download the model weights, process the data, and render the plots.
|
234 |
"""
|
235 |
)
|
236 |
with gr.Row():
|
237 |
with gr.Column(scale=5):
|
238 |
gr.Markdown("## Reconstruction")
|
239 |
-
plot = gr.
|
|
|
|
|
|
|
|
|
240 |
with gr.Column(scale=3):
|
241 |
gr.Markdown("## Data Settings")
|
242 |
image_id = gr.Slider(
|
@@ -254,6 +259,7 @@ def mae_tab():
|
|
254 |
value=0.75,
|
255 |
)
|
256 |
run_button = gr.Button("Run masked autoencoder", variant="primary")
|
|
|
257 |
run_button.click(
|
258 |
fn=mae,
|
259 |
inputs=[image_id, mask_ratio],
|
@@ -331,13 +337,27 @@ def segmentation_sax(trained_dataset, seed, image_id, t_step, progress=gr.Progre
|
|
331 |
images = images[..., ::t_step]
|
332 |
labels = segmentation_sax_inference(images, view, transform, model, progress)
|
333 |
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
|
342 |
|
343 |
def segmentation_sax_tab():
|
@@ -345,8 +365,6 @@ def segmentation_sax_tab():
|
|
345 |
gr.Markdown(
|
346 |
"""
|
347 |
This page demonstrates the segmentation of cardiac structures in the short-axis (SAX) view.
|
348 |
-
|
349 |
-
Visualisation may take dozens of seconds to update as we download model checkpoints, process multiple time frames sequentially, and generate the final plots.
|
350 |
"""
|
351 |
)
|
352 |
|
@@ -381,7 +399,7 @@ def segmentation_sax_tab():
|
|
381 |
maximum=10,
|
382 |
step=1,
|
383 |
label="Choose the gap between time frames",
|
384 |
-
value=
|
385 |
)
|
386 |
with gr.Column(scale=3):
|
387 |
gr.Markdown("## Model Settings")
|
@@ -401,16 +419,22 @@ def segmentation_sax_tab():
|
|
401 |
|
402 |
with gr.Row():
|
403 |
with gr.Column():
|
404 |
-
gr.
|
405 |
-
|
|
|
|
|
|
|
406 |
with gr.Column():
|
407 |
-
gr.
|
408 |
-
|
|
|
|
|
|
|
409 |
|
410 |
run_button.click(
|
411 |
fn=segmentation_sax,
|
412 |
inputs=[trained_dataset, seed, image_id, t_step],
|
413 |
-
outputs=[
|
414 |
)
|
415 |
return sax_interface
|
416 |
|
@@ -475,13 +499,17 @@ def segmentation_lax(seed, image_id, progress=gr.Progress()):
|
|
475 |
)
|
476 |
labels = segmentation_lax_inference(images, view, transform, model, progress)
|
477 |
|
478 |
-
progress(1, desc="Plotting
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
|
|
|
|
|
|
|
|
485 |
|
486 |
|
487 |
def segmentation_lax_tab():
|
@@ -489,8 +517,6 @@ def segmentation_lax_tab():
|
|
489 |
gr.Markdown(
|
490 |
"""
|
491 |
This page demonstrates the segmentation of cardiac structures in the long-axis (LAX) four-chamber (4C) view.
|
492 |
-
|
493 |
-
Visualisation may take a few seconds to update as we download model checkpoints, process multiple time frames, and generate the final plots.
|
494 |
"""
|
495 |
)
|
496 |
|
@@ -533,16 +559,22 @@ def segmentation_lax_tab():
|
|
533 |
|
534 |
with gr.Row():
|
535 |
with gr.Column():
|
536 |
-
gr.
|
537 |
-
|
|
|
|
|
|
|
538 |
with gr.Column():
|
539 |
-
gr.
|
540 |
-
|
|
|
|
|
|
|
541 |
|
542 |
run_button.click(
|
543 |
fn=segmentation_lax,
|
544 |
inputs=[seed, image_id],
|
545 |
-
outputs=[
|
546 |
)
|
547 |
return lax_interface
|
548 |
|
@@ -651,12 +683,19 @@ def landmark(image_id, view, method, seed, progress=gr.Progress()):
|
|
651 |
else:
|
652 |
raise ValueError(f"Invalid method: {method}")
|
653 |
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
|
658 |
-
|
659 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
660 |
|
661 |
|
662 |
def landmark_tab():
|
@@ -664,8 +703,6 @@ def landmark_tab():
|
|
664 |
gr.Markdown(
|
665 |
"""
|
666 |
This page demonstrates landmark localisation in the long-axis (LAX) two-chamber (2C) and four-chamber (4C) views.
|
667 |
-
|
668 |
-
Visualisation may take a few seconds to update as we download model checkpoints, process multiple time frames, and generate the final plots.
|
669 |
"""
|
670 |
)
|
671 |
|
@@ -679,7 +716,7 @@ def landmark_tab():
|
|
679 |
|
680 |
### Model
|
681 |
|
682 |
-
The available models are finetuned on data from [Xue et al.](https://pubs.rsna.org/doi/10.1148/ryai.2021200197)
|
683 |
There are two types of landmark localisation models:
|
684 |
|
685 |
- **Heatmap**: predicts dense probability maps of landmarks
|
@@ -726,16 +763,22 @@ def landmark_tab():
|
|
726 |
|
727 |
with gr.Row():
|
728 |
with gr.Column():
|
729 |
-
gr.
|
730 |
-
|
|
|
|
|
|
|
731 |
with gr.Column():
|
732 |
-
gr.
|
733 |
-
|
|
|
|
|
|
|
734 |
|
735 |
run_button.click(
|
736 |
fn=landmark,
|
737 |
inputs=[image_id, view, method, seed],
|
738 |
-
outputs=[
|
739 |
)
|
740 |
return landmark_interface
|
741 |
|
@@ -747,8 +790,9 @@ with gr.Blocks(
|
|
747 |
"""
|
748 |
# CineMA: A Foundation Model for Cine Cardiac MRI π₯π«
|
749 |
|
750 |
-
The following demos showcase the capabilities of CineMA in multiple tasks
|
751 |
-
|
|
|
752 |
"""
|
753 |
)
|
754 |
|
@@ -763,4 +807,4 @@ with gr.Blocks(
|
|
763 |
segmentation_lax_tab()
|
764 |
with gr.TabItem("π Landmark Localisation in LAX 2C/4C View", id="landmark"):
|
765 |
landmark_tab()
|
766 |
-
demo.launch()
|
|
|
213 |
batch, reconstructed_dict, masks_dict = mae_inference(
|
214 |
batch, transform, model, mask_ratio
|
215 |
)
|
216 |
+
progress(1, desc="Inference finished. Plotting ...")
|
217 |
|
218 |
+
# (y, x, z) -> (x, y, z)
|
219 |
+
batch["sax"] = np.transpose(batch["sax"], (1, 0, 2))
|
220 |
+
reconstructed_dict["sax"] = np.transpose(reconstructed_dict["sax"], (1, 0, 2))
|
221 |
+
masks_dict["sax"] = np.transpose(masks_dict["sax"], (1, 0, 2))
|
222 |
+
|
223 |
+
# Plot MAE reconstruction and save to file
|
224 |
+
mae_path = cache_dir / f"mae_image{image_id}_mask{mask_ratio:.2f}.png"
|
225 |
+
plot_mae_reconstruction(batch, reconstructed_dict, masks_dict, mae_path)
|
226 |
+
|
227 |
+
return str(mae_path)
|
228 |
|
229 |
|
230 |
def mae_tab():
|
|
|
232 |
gr.Markdown(
|
233 |
"""
|
234 |
This page demonstrates the masking and reconstruction process of the masked autoencoder. The model was trained with a mask ratio of 0.75 over 74,000 studies.
|
|
|
|
|
235 |
"""
|
236 |
)
|
237 |
with gr.Row():
|
238 |
with gr.Column(scale=5):
|
239 |
gr.Markdown("## Reconstruction")
|
240 |
+
plot = gr.Image(
|
241 |
+
show_label=False,
|
242 |
+
type="filepath",
|
243 |
+
label="Masked Autoencoder Reconstruction",
|
244 |
+
)
|
245 |
with gr.Column(scale=3):
|
246 |
gr.Markdown("## Data Settings")
|
247 |
image_id = gr.Slider(
|
|
|
259 |
value=0.75,
|
260 |
)
|
261 |
run_button = gr.Button("Run masked autoencoder", variant="primary")
|
262 |
+
|
263 |
run_button.click(
|
264 |
fn=mae,
|
265 |
inputs=[image_id, mask_ratio],
|
|
|
337 |
images = images[..., ::t_step]
|
338 |
labels = segmentation_sax_inference(images, view, transform, model, progress)
|
339 |
|
340 |
+
# (y, x, z, t) -> (x, y, z, t)
|
341 |
+
images = np.transpose(images, (1, 0, 2, 3))
|
342 |
+
labels = np.transpose(labels, (1, 0, 2, 3))
|
343 |
+
|
344 |
+
progress(1, desc="Inference finished. Plotting ...")
|
345 |
+
|
346 |
+
# Create file paths for saving plots
|
347 |
+
seg_path = (
|
348 |
+
cache_dir
|
349 |
+
/ f"sax_segmentation_{trained_dataset}_image{image_id}_seed{seed}_tstep{t_step}.gif"
|
350 |
+
)
|
351 |
+
vol_path = (
|
352 |
+
cache_dir
|
353 |
+
/ f"sax_volume_{trained_dataset}_image{image_id}_seed{seed}_tstep{t_step}.png"
|
354 |
+
)
|
355 |
+
|
356 |
+
# Plot segmentations and volume changes with file paths
|
357 |
+
plot_segmentations_sax(images, labels, seg_path)
|
358 |
+
plot_volume_changes_sax(labels, t_step, vol_path)
|
359 |
+
|
360 |
+
return (str(seg_path), str(vol_path))
|
361 |
|
362 |
|
363 |
def segmentation_sax_tab():
|
|
|
365 |
gr.Markdown(
|
366 |
"""
|
367 |
This page demonstrates the segmentation of cardiac structures in the short-axis (SAX) view.
|
|
|
|
|
368 |
"""
|
369 |
)
|
370 |
|
|
|
399 |
maximum=10,
|
400 |
step=1,
|
401 |
label="Choose the gap between time frames",
|
402 |
+
value=3,
|
403 |
)
|
404 |
with gr.Column(scale=3):
|
405 |
gr.Markdown("## Model Settings")
|
|
|
419 |
|
420 |
with gr.Row():
|
421 |
with gr.Column():
|
422 |
+
segmentation_gif = gr.Image(
|
423 |
+
show_label=True,
|
424 |
+
type="filepath",
|
425 |
+
label="Ventricle and Myocardium Segmentation",
|
426 |
+
)
|
427 |
with gr.Column():
|
428 |
+
volume_plot = gr.Image(
|
429 |
+
show_label=True,
|
430 |
+
type="filepath",
|
431 |
+
label="Ejection Fraction Estimation",
|
432 |
+
)
|
433 |
|
434 |
run_button.click(
|
435 |
fn=segmentation_sax,
|
436 |
inputs=[trained_dataset, seed, image_id, t_step],
|
437 |
+
outputs=[segmentation_gif, volume_plot],
|
438 |
)
|
439 |
return sax_interface
|
440 |
|
|
|
499 |
)
|
500 |
labels = segmentation_lax_inference(images, view, transform, model, progress)
|
501 |
|
502 |
+
progress(1, desc="Inference finished. Plotting ...")
|
503 |
+
|
504 |
+
# Plot segmentations and save as GIF
|
505 |
+
seg_path = cache_dir / f"lax_segmentation_image{image_id}_seed{seed}.gif"
|
506 |
+
plot_segmentations_lax(images, labels, seg_path)
|
507 |
+
|
508 |
+
# Plot volume changes and save as figure
|
509 |
+
vol_path = cache_dir / f"lax_volume_image{image_id}_seed{seed}.png"
|
510 |
+
plot_volume_changes_lax(labels, vol_path)
|
511 |
+
|
512 |
+
return (str(seg_path), str(vol_path))
|
513 |
|
514 |
|
515 |
def segmentation_lax_tab():
|
|
|
517 |
gr.Markdown(
|
518 |
"""
|
519 |
This page demonstrates the segmentation of cardiac structures in the long-axis (LAX) four-chamber (4C) view.
|
|
|
|
|
520 |
"""
|
521 |
)
|
522 |
|
|
|
559 |
|
560 |
with gr.Row():
|
561 |
with gr.Column():
|
562 |
+
segmentation_gif = gr.Image(
|
563 |
+
show_label=True,
|
564 |
+
type="filepath",
|
565 |
+
label="Ventricle and Myocardium Segmentation",
|
566 |
+
)
|
567 |
with gr.Column():
|
568 |
+
volume_plot = gr.Image(
|
569 |
+
show_label=True,
|
570 |
+
type="filepath",
|
571 |
+
label="Ejection Fraction Prediction",
|
572 |
+
)
|
573 |
|
574 |
run_button.click(
|
575 |
fn=segmentation_lax,
|
576 |
inputs=[seed, image_id],
|
577 |
+
outputs=[segmentation_gif, volume_plot],
|
578 |
)
|
579 |
return lax_interface
|
580 |
|
|
|
683 |
else:
|
684 |
raise ValueError(f"Invalid method: {method}")
|
685 |
|
686 |
+
progress(1, desc="Inference finished. Plotting ...")
|
687 |
+
|
688 |
+
# Plot landmarks in GIF
|
689 |
+
landmark_path = (
|
690 |
+
cache_dir / f"landmark_{view}_image{image_id}_{method}_seed{seed}.gif"
|
691 |
+
)
|
692 |
+
plot_landmarks(images, coords, landmark_path)
|
693 |
+
|
694 |
+
# Plot LV change in PNG
|
695 |
+
lv_path = cache_dir / f"lv_{view}_image{image_id}_{method}_seed{seed}.png"
|
696 |
+
plot_lv(coords, lv_path)
|
697 |
+
|
698 |
+
return (str(landmark_path), str(lv_path))
|
699 |
|
700 |
|
701 |
def landmark_tab():
|
|
|
703 |
gr.Markdown(
|
704 |
"""
|
705 |
This page demonstrates landmark localisation in the long-axis (LAX) two-chamber (2C) and four-chamber (4C) views.
|
|
|
|
|
706 |
"""
|
707 |
)
|
708 |
|
|
|
716 |
|
717 |
### Model
|
718 |
|
719 |
+
The available models are finetuned on data from [Xue et al.](https://pubs.rsna.org/doi/10.1148/ryai.2021200197)
|
720 |
There are two types of landmark localisation models:
|
721 |
|
722 |
- **Heatmap**: predicts dense probability maps of landmarks
|
|
|
763 |
|
764 |
with gr.Row():
|
765 |
with gr.Column():
|
766 |
+
landmark_gif = gr.Image(
|
767 |
+
show_label=True,
|
768 |
+
type="filepath",
|
769 |
+
label="Landmark Localisation",
|
770 |
+
)
|
771 |
with gr.Column():
|
772 |
+
lv_plot = gr.Image(
|
773 |
+
show_label=True,
|
774 |
+
type="filepath",
|
775 |
+
label="Left Ventricle Length Estimation",
|
776 |
+
)
|
777 |
|
778 |
run_button.click(
|
779 |
fn=landmark,
|
780 |
inputs=[image_id, view, method, seed],
|
781 |
+
outputs=[landmark_gif, lv_plot],
|
782 |
)
|
783 |
return landmark_interface
|
784 |
|
|
|
790 |
"""
|
791 |
# CineMA: A Foundation Model for Cine Cardiac MRI π₯π«
|
792 |
|
793 |
+
π The following demos showcase the capabilities of CineMA in multiple tasks.<br>
|
794 |
+
β±οΈ The examples may take 10-60 seconds to download data and model, perform inference, and render plots.<br>
|
795 |
+
π For more details, check out our [GitHub](https://github.com/mathpluscode/CineMA).
|
796 |
"""
|
797 |
)
|
798 |
|
|
|
807 |
segmentation_lax_tab()
|
808 |
with gr.TabItem("π Landmark Localisation in LAX 2C/4C View", id="landmark"):
|
809 |
landmark_tab()
|
810 |
+
demo.launch(allowed_paths=[cache_dir])
|
requirements.txt
CHANGED
@@ -17,6 +17,6 @@ scikit-learn==1.6.1
|
|
17 |
scipy==1.15.2
|
18 |
spaces==0.36.0
|
19 |
timm==1.0.15
|
20 |
-
git+https://github.com/mathpluscode/CineMA@
|
21 |
--extra-index-url https://download.pytorch.org/whl/cu113
|
22 |
torch==2.5.1
|
|
|
17 |
scipy==1.15.2
|
18 |
spaces==0.36.0
|
19 |
timm==1.0.15
|
20 |
+
git+https://github.com/mathpluscode/CineMA@7e86ffc7ddf06ad7283915ee143ed808c0f59576#egg=cinema
|
21 |
--extra-index-url https://download.pytorch.org/whl/cu113
|
22 |
torch==2.5.1
|