mathpluscode committed on
Commit
7ca35bb
·
1 Parent(s): 515dbe1

Display gifs

Browse files
Files changed (2) hide show
  1. app.py +101 -57
  2. requirements.txt +1 -1
app.py CHANGED
@@ -213,15 +213,18 @@ def mae(image_id, mask_ratio, progress=gr.Progress()):
213
  batch, reconstructed_dict, masks_dict = mae_inference(
214
  batch, transform, model, mask_ratio
215
  )
216
- progress(1, desc="Plotting results...")
217
 
218
- fig = plot_mae_reconstruction(
219
- batch,
220
- reconstructed_dict,
221
- masks_dict,
222
- )
223
- plt.close(fig)
224
- return fig
 
 
 
225
 
226
 
227
  def mae_tab():
@@ -229,14 +232,16 @@ def mae_tab():
229
  gr.Markdown(
230
  """
231
  This page demonstrates the masking and reconstruction process of the masked autoencoder. The model was trained with a mask ratio of 0.75 over 74,000 studies.
232
-
233
- Visualisation may take a few seconds as we download the model weights, process the data, and render the plots.
234
  """
235
  )
236
  with gr.Row():
237
  with gr.Column(scale=5):
238
  gr.Markdown("## Reconstruction")
239
- plot = gr.Plot(show_label=False)
 
 
 
 
240
  with gr.Column(scale=3):
241
  gr.Markdown("## Data Settings")
242
  image_id = gr.Slider(
@@ -254,6 +259,7 @@ def mae_tab():
254
  value=0.75,
255
  )
256
  run_button = gr.Button("Run masked autoencoder", variant="primary")
 
257
  run_button.click(
258
  fn=mae,
259
  inputs=[image_id, mask_ratio],
@@ -331,13 +337,27 @@ def segmentation_sax(trained_dataset, seed, image_id, t_step, progress=gr.Progre
331
  images = images[..., ::t_step]
332
  labels = segmentation_sax_inference(images, view, transform, model, progress)
333
 
334
- progress(1, desc="Plotting results...")
335
- fig1 = plot_segmentations_sax(images, labels, t_step)
336
- fig2 = plot_volume_changes_sax(labels, t_step)
337
- result = (fig1, fig2)
338
- plt.close(fig1)
339
- plt.close(fig2)
340
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
 
343
  def segmentation_sax_tab():
@@ -345,8 +365,6 @@ def segmentation_sax_tab():
345
  gr.Markdown(
346
  """
347
  This page demonstrates the segmentation of cardiac structures in the short-axis (SAX) view.
348
-
349
- Visualisation may take dozens of seconds to update as we download model checkpoints, process multiple time frames sequentially, and generate the final plots.
350
  """
351
  )
352
 
@@ -381,7 +399,7 @@ def segmentation_sax_tab():
381
  maximum=10,
382
  step=1,
383
  label="Choose the gap between time frames",
384
- value=2,
385
  )
386
  with gr.Column(scale=3):
387
  gr.Markdown("## Model Settings")
@@ -401,16 +419,22 @@ def segmentation_sax_tab():
401
 
402
  with gr.Row():
403
  with gr.Column():
404
- gr.Markdown("## Ventricle and Myocardium Segmentation")
405
- segmentation_plot = gr.Plot(show_label=False)
 
 
 
406
  with gr.Column():
407
- gr.Markdown("## Volume Estimation")
408
- volume_plot = gr.Plot(show_label=False)
 
 
 
409
 
410
  run_button.click(
411
  fn=segmentation_sax,
412
  inputs=[trained_dataset, seed, image_id, t_step],
413
- outputs=[segmentation_plot, volume_plot],
414
  )
415
  return sax_interface
416
 
@@ -475,13 +499,17 @@ def segmentation_lax(seed, image_id, progress=gr.Progress()):
475
  )
476
  labels = segmentation_lax_inference(images, view, transform, model, progress)
477
 
478
- progress(1, desc="Plotting results...")
479
- fig1 = plot_segmentations_lax(images, labels)
480
- fig2 = plot_volume_changes_lax(labels)
481
- result = (fig1, fig2)
482
- plt.close(fig1)
483
- plt.close(fig2)
484
- return result
 
 
 
 
485
 
486
 
487
  def segmentation_lax_tab():
@@ -489,8 +517,6 @@ def segmentation_lax_tab():
489
  gr.Markdown(
490
  """
491
  This page demonstrates the segmentation of cardiac structures in the long-axis (LAX) four-chamber (4C) view.
492
-
493
- Visualisation may take a few seconds to update as we download model checkpoints, process multiple time frames, and generate the final plots.
494
  """
495
  )
496
 
@@ -533,16 +559,22 @@ def segmentation_lax_tab():
533
 
534
  with gr.Row():
535
  with gr.Column():
536
- gr.Markdown("## Ventricle and Myocardium Segmentation")
537
- segmentation_plot = gr.Plot(show_label=False)
 
 
 
538
  with gr.Column():
539
- gr.Markdown("## Ejection Fraction Prediction")
540
- volume_plot = gr.Plot(show_label=False)
 
 
 
541
 
542
  run_button.click(
543
  fn=segmentation_lax,
544
  inputs=[seed, image_id],
545
- outputs=[segmentation_plot, volume_plot],
546
  )
547
  return lax_interface
548
 
@@ -651,12 +683,19 @@ def landmark(image_id, view, method, seed, progress=gr.Progress()):
651
  else:
652
  raise ValueError(f"Invalid method: {method}")
653
 
654
- landmark_fig = plot_landmarks(images, coords)
655
- lv_fig = plot_lv(coords)
656
- result = (landmark_fig, lv_fig)
657
- plt.close(landmark_fig)
658
- plt.close(lv_fig)
659
- return result
 
 
 
 
 
 
 
660
 
661
 
662
  def landmark_tab():
@@ -664,8 +703,6 @@ def landmark_tab():
664
  gr.Markdown(
665
  """
666
  This page demonstrates landmark localisation in the long-axis (LAX) two-chamber (2C) and four-chamber (4C) views.
667
-
668
- Visualisation may take a few seconds to update as we download model checkpoints, process multiple time frames, and generate the final plots.
669
  """
670
  )
671
 
@@ -679,7 +716,7 @@ def landmark_tab():
679
 
680
  ### Model
681
 
682
- The available models are finetuned on data from [Xue et al.](https://pubs.rsna.org/doi/10.1148/ryai.2021200197).
683
  There are two types of landmark localisation models:
684
 
685
  - **Heatmap**: predicts dense probability maps of landmarks
@@ -726,16 +763,22 @@ def landmark_tab():
726
 
727
  with gr.Row():
728
  with gr.Column():
729
- gr.Markdown("## Landmark Localisation")
730
- landmark_plot = gr.Plot(show_label=False)
 
 
 
731
  with gr.Column():
732
- gr.Markdown("## Left Ventricle Length Estimation")
733
- lv_plot = gr.Plot(show_label=False)
 
 
 
734
 
735
  run_button.click(
736
  fn=landmark,
737
  inputs=[image_id, view, method, seed],
738
- outputs=[landmark_plot, lv_plot],
739
  )
740
  return landmark_interface
741
 
@@ -747,8 +790,9 @@ with gr.Blocks(
747
  """
748
  # CineMA: A Foundation Model for Cine Cardiac MRI πŸŽ₯πŸ«€
749
 
750
- The following demos showcase the capabilities of CineMA in multiple tasks.
751
- For more details, check out our [GitHub](https://github.com/mathpluscode/CineMA).
 
752
  """
753
  )
754
 
@@ -763,4 +807,4 @@ with gr.Blocks(
763
  segmentation_lax_tab()
764
  with gr.TabItem("πŸ“ Landmark Localisation in LAX 2C/4C View", id="landmark"):
765
  landmark_tab()
766
- demo.launch()
 
213
  batch, reconstructed_dict, masks_dict = mae_inference(
214
  batch, transform, model, mask_ratio
215
  )
216
+ progress(1, desc="Inference finished. Plotting ...")
217
 
218
+ # (y, x, z) -> (x, y, z)
219
+ batch["sax"] = np.transpose(batch["sax"], (1, 0, 2))
220
+ reconstructed_dict["sax"] = np.transpose(reconstructed_dict["sax"], (1, 0, 2))
221
+ masks_dict["sax"] = np.transpose(masks_dict["sax"], (1, 0, 2))
222
+
223
+ # Plot MAE reconstruction and save to file
224
+ mae_path = cache_dir / f"mae_image{image_id}_mask{mask_ratio:.2f}.png"
225
+ plot_mae_reconstruction(batch, reconstructed_dict, masks_dict, mae_path)
226
+
227
+ return str(mae_path)
228
 
229
 
230
  def mae_tab():
 
232
  gr.Markdown(
233
  """
234
  This page demonstrates the masking and reconstruction process of the masked autoencoder. The model was trained with a mask ratio of 0.75 over 74,000 studies.
 
 
235
  """
236
  )
237
  with gr.Row():
238
  with gr.Column(scale=5):
239
  gr.Markdown("## Reconstruction")
240
+ plot = gr.Image(
241
+ show_label=False,
242
+ type="filepath",
243
+ label="Masked Autoencoder Reconstruction",
244
+ )
245
  with gr.Column(scale=3):
246
  gr.Markdown("## Data Settings")
247
  image_id = gr.Slider(
 
259
  value=0.75,
260
  )
261
  run_button = gr.Button("Run masked autoencoder", variant="primary")
262
+
263
  run_button.click(
264
  fn=mae,
265
  inputs=[image_id, mask_ratio],
 
337
  images = images[..., ::t_step]
338
  labels = segmentation_sax_inference(images, view, transform, model, progress)
339
 
340
+ # (y, x, z, t) -> (x, y, z, t)
341
+ images = np.transpose(images, (1, 0, 2, 3))
342
+ labels = np.transpose(labels, (1, 0, 2, 3))
343
+
344
+ progress(1, desc="Inference finished. Plotting ...")
345
+
346
+ # Create file paths for saving plots
347
+ seg_path = (
348
+ cache_dir
349
+ / f"sax_segmentation_{trained_dataset}_image{image_id}_seed{seed}_tstep{t_step}.gif"
350
+ )
351
+ vol_path = (
352
+ cache_dir
353
+ / f"sax_volume_{trained_dataset}_image{image_id}_seed{seed}_tstep{t_step}.png"
354
+ )
355
+
356
+ # Plot segmentations and volume changes with file paths
357
+ plot_segmentations_sax(images, labels, seg_path)
358
+ plot_volume_changes_sax(labels, t_step, vol_path)
359
+
360
+ return (str(seg_path), str(vol_path))
361
 
362
 
363
  def segmentation_sax_tab():
 
365
  gr.Markdown(
366
  """
367
  This page demonstrates the segmentation of cardiac structures in the short-axis (SAX) view.
 
 
368
  """
369
  )
370
 
 
399
  maximum=10,
400
  step=1,
401
  label="Choose the gap between time frames",
402
+ value=3,
403
  )
404
  with gr.Column(scale=3):
405
  gr.Markdown("## Model Settings")
 
419
 
420
  with gr.Row():
421
  with gr.Column():
422
+ segmentation_gif = gr.Image(
423
+ show_label=True,
424
+ type="filepath",
425
+ label="Ventricle and Myocardium Segmentation",
426
+ )
427
  with gr.Column():
428
+ volume_plot = gr.Image(
429
+ show_label=True,
430
+ type="filepath",
431
+ label="Ejection Fraction Estimation",
432
+ )
433
 
434
  run_button.click(
435
  fn=segmentation_sax,
436
  inputs=[trained_dataset, seed, image_id, t_step],
437
+ outputs=[segmentation_gif, volume_plot],
438
  )
439
  return sax_interface
440
 
 
499
  )
500
  labels = segmentation_lax_inference(images, view, transform, model, progress)
501
 
502
+ progress(1, desc="Inference finished. Plotting ...")
503
+
504
+ # Plot segmentations and save as GIF
505
+ seg_path = cache_dir / f"lax_segmentation_image{image_id}_seed{seed}.gif"
506
+ plot_segmentations_lax(images, labels, seg_path)
507
+
508
+ # Plot volume changes and save as figure
509
+ vol_path = cache_dir / f"lax_volume_image{image_id}_seed{seed}.png"
510
+ plot_volume_changes_lax(labels, vol_path)
511
+
512
+ return (str(seg_path), str(vol_path))
513
 
514
 
515
  def segmentation_lax_tab():
 
517
  gr.Markdown(
518
  """
519
  This page demonstrates the segmentation of cardiac structures in the long-axis (LAX) four-chamber (4C) view.
 
 
520
  """
521
  )
522
 
 
559
 
560
  with gr.Row():
561
  with gr.Column():
562
+ segmentation_gif = gr.Image(
563
+ show_label=True,
564
+ type="filepath",
565
+ label="Ventricle and Myocardium Segmentation",
566
+ )
567
  with gr.Column():
568
+ volume_plot = gr.Image(
569
+ show_label=True,
570
+ type="filepath",
571
+ label="Ejection Fraction Prediction",
572
+ )
573
 
574
  run_button.click(
575
  fn=segmentation_lax,
576
  inputs=[seed, image_id],
577
+ outputs=[segmentation_gif, volume_plot],
578
  )
579
  return lax_interface
580
 
 
683
  else:
684
  raise ValueError(f"Invalid method: {method}")
685
 
686
+ progress(1, desc="Inference finished. Plotting ...")
687
+
688
+ # Plot landmarks in GIF
689
+ landmark_path = (
690
+ cache_dir / f"landmark_{view}_image{image_id}_{method}_seed{seed}.gif"
691
+ )
692
+ plot_landmarks(images, coords, landmark_path)
693
+
694
+ # Plot LV change in PNG
695
+ lv_path = cache_dir / f"lv_{view}_image{image_id}_{method}_seed{seed}.png"
696
+ plot_lv(coords, lv_path)
697
+
698
+ return (str(landmark_path), str(lv_path))
699
 
700
 
701
  def landmark_tab():
 
703
  gr.Markdown(
704
  """
705
  This page demonstrates landmark localisation in the long-axis (LAX) two-chamber (2C) and four-chamber (4C) views.
 
 
706
  """
707
  )
708
 
 
716
 
717
  ### Model
718
 
719
+ The available models are finetuned on data from [Xue et al.](https://pubs.rsna.org/doi/10.1148/ryai.2021200197)
720
  There are two types of landmark localisation models:
721
 
722
  - **Heatmap**: predicts dense probability maps of landmarks
 
763
 
764
  with gr.Row():
765
  with gr.Column():
766
+ landmark_gif = gr.Image(
767
+ show_label=True,
768
+ type="filepath",
769
+ label="Landmark Localisation",
770
+ )
771
  with gr.Column():
772
+ lv_plot = gr.Image(
773
+ show_label=True,
774
+ type="filepath",
775
+ label="Left Ventricle Length Estimation",
776
+ )
777
 
778
  run_button.click(
779
  fn=landmark,
780
  inputs=[image_id, view, method, seed],
781
+ outputs=[landmark_gif, lv_plot],
782
  )
783
  return landmark_interface
784
 
 
790
  """
791
  # CineMA: A Foundation Model for Cine Cardiac MRI πŸŽ₯πŸ«€
792
 
793
+ πŸš€ The following demos showcase the capabilities of CineMA in multiple tasks.<br>
794
+ ⏱️ The examples may take 10-60 seconds to download data and model, perform inference, and render plots.<br>
795
+ πŸ”— For more details, check out our [GitHub](https://github.com/mathpluscode/CineMA).
796
  """
797
  )
798
 
 
807
  segmentation_lax_tab()
808
  with gr.TabItem("πŸ“ Landmark Localisation in LAX 2C/4C View", id="landmark"):
809
  landmark_tab()
810
+ demo.launch(allowed_paths=[cache_dir])
requirements.txt CHANGED
@@ -17,6 +17,6 @@ scikit-learn==1.6.1
17
  scipy==1.15.2
18
  spaces==0.36.0
19
  timm==1.0.15
20
- git+https://github.com/mathpluscode/CineMA@af1958f51e475d3d6658132c6a680d4fc4a10cac#egg=cinema
21
  --extra-index-url https://download.pytorch.org/whl/cu113
22
  torch==2.5.1
 
17
  scipy==1.15.2
18
  spaces==0.36.0
19
  timm==1.0.15
20
+ git+https://github.com/mathpluscode/CineMA@7e86ffc7ddf06ad7283915ee143ed808c0f59576#egg=cinema
21
  --extra-index-url https://download.pytorch.org/whl/cu113
22
  torch==2.5.1