mathpluscode commited on
Commit
b257e4f
Β·
1 Parent(s): 7f9c492

Add landmark tab

Browse files
Files changed (2) hide show
  1. app.py +218 -27
  2. requirements.txt +1 -1
app.py CHANGED
@@ -6,8 +6,13 @@ import requests
6
  import SimpleITK as sitk # noqa: N813
7
  import spaces
8
  import torch
9
- from cinema import CineMA, ConvUNetR
10
  from cinema.examples.cine_cmr import plot_cmr_views
 
 
 
 
 
11
  from cinema.examples.inference.mae import plot_mae_reconstruction, reconstruct_images
12
  from cinema.examples.inference.segmentation_lax_4c import (
13
  plot_segmentations as plot_segmentations_lax,
@@ -63,7 +68,7 @@ def cmr_tab():
63
  with gr.Blocks() as cmr_interface:
64
  gr.Markdown(
65
  """
66
- This page illustrates the spatial orientation of short-axis (SAX) and long-axis (LAX) views in 3D. Use the control panels on the right to select specific images and slices.
67
  """
68
  )
69
  with gr.Row():
@@ -222,7 +227,7 @@ def mae_tab():
222
  with gr.Blocks() as mae_interface:
223
  gr.Markdown(
224
  """
225
- This page illustrates the masking and reconstruction process of the masked autoencoder. The model was trained with mask ratio 0.75 over 74,000 studies.
226
  """
227
  )
228
  with gr.Row():
@@ -334,8 +339,7 @@ def segmentation_sax_tab():
334
  with gr.Blocks() as sax_interface:
335
  gr.Markdown(
336
  """
337
- This page demonstrates the segmentation of cardiac structures in the Short-Axis (SAX) view.
338
- Please adjust the settings on the right panels and click the button to run the inference.
339
  """
340
  )
341
 
@@ -345,17 +349,16 @@ def segmentation_sax_tab():
345
  ## Description
346
  ### Data
347
 
348
- The available data is from ACDC. All images have been resampled to 1 mm Γ— 1 mm Γ— 10 mm and centre-cropped to 192 mm Γ— 192 mm for each SAX slice.
349
- Image 101 - 150 are from the test set.
350
 
351
  ### Model
352
 
353
- The available models are finetuned on different datasets ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/), [M&Ms](https://www.ub.edu/mnms/), and [M&Ms2](https://www.ub.edu/mnms-2/)). For each dataset, there are 3 models finetuned on different seeds: 0, 1, 2.
354
 
355
- ### Visualization
356
 
357
- The left figure shows the segmentation of ventricles and myocardium every n time steps across all SAX slices.
358
- The right figure plots the ventricle and mycoardium volumes across all inference time frames.
359
  """)
360
  with gr.Column(scale=3):
361
  gr.Markdown("## Data Settings")
@@ -363,7 +366,7 @@ def segmentation_sax_tab():
363
  minimum=101,
364
  maximum=150,
365
  step=1,
366
- label="Choose an ACDC image, ID is between 101 and 150",
367
  value=101,
368
  )
369
  t_step = gr.Slider(
@@ -374,10 +377,10 @@ def segmentation_sax_tab():
374
  value=2,
375
  )
376
  with gr.Column(scale=3):
377
- gr.Markdown("## Model Setting")
378
  trained_dataset = gr.Dropdown(
379
  choices=["ACDC", "M&MS", "M&MS2"],
380
- label="Choose which dataset the segmentation model was finetuned on",
381
  value="ACDC",
382
  )
383
  seed = gr.Slider(
@@ -394,7 +397,7 @@ def segmentation_sax_tab():
394
  gr.Markdown("## Ventricle and Myocardium Segmentation")
395
  segmentation_plot = gr.Plot(show_label=False)
396
  with gr.Column():
397
- gr.Markdown("## Ejection Fraction Prediction")
398
  volume_plot = gr.Plot(show_label=False)
399
 
400
  run_button.click(
@@ -476,8 +479,7 @@ def segmentation_lax_tab():
476
  with gr.Blocks() as lax_interface:
477
  gr.Markdown(
478
  """
479
- This page demonstrates the segmentation of cardiac structures in the Long-Axis (LAX) view.
480
- Please adjust the settings on the right panels and click the button to run the inference.
481
  """
482
  )
483
 
@@ -487,16 +489,16 @@ def segmentation_lax_tab():
487
  ## Description
488
  ### Data
489
 
490
- There are four example samples. All images have been resampled to 1 mm Γ— 1 mm and centre-cropped.
491
 
492
  ### Model
493
 
494
- The available models are finetuned on [M&Ms2](https://www.ub.edu/mnms-2/). For each dataset, there are 3 models finetuned on different seeds: 0, 1, 2.
495
 
496
- ### Visualization
497
 
498
  The left figure shows the segmentation of ventricles and myocardium across all time frames.
499
- The right figure plots the ventricle and mycoardium volumes across all inference time frames.
500
  """)
501
  with gr.Column(scale=3):
502
  gr.Markdown("## Data Settings")
@@ -505,10 +507,10 @@ def segmentation_lax_tab():
505
  maximum=4,
506
  step=1,
507
  label="Choose an image, ID is between 1 and 4",
508
- value=4,
509
  )
510
  with gr.Column(scale=3):
511
- gr.Markdown("## Model Setting")
512
  seed = gr.Slider(
513
  minimum=0,
514
  maximum=2,
@@ -516,7 +518,7 @@ def segmentation_lax_tab():
516
  label="Choose which seed the finetuning used",
517
  value=0,
518
  )
519
- run_button = gr.Button("Run LAX segmentation inference", variant="primary")
520
 
521
  with gr.Row():
522
  with gr.Column():
@@ -534,6 +536,194 @@ def segmentation_lax_tab():
534
  return lax_interface
535
 
536
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
  with gr.Blocks(
538
  theme=theme, title="CineMA: A Foundation Model for Cine Cardiac MRI"
539
  ) as demo:
@@ -542,7 +732,7 @@ with gr.Blocks(
542
  # CineMA: A Foundation Model for Cine Cardiac MRI πŸŽ₯πŸ«€
543
 
544
  This demo showcases the capabilities of CineMA in multiple tasks.
545
- For more details, checkout our [GitHub](https://github.com/mathpluscode/CineMA).
546
  """
547
  )
548
 
@@ -553,7 +743,8 @@ with gr.Blocks(
553
  mae_tab()
554
  with gr.TabItem("Segmentation in SAX View"):
555
  segmentation_sax_tab()
556
- with gr.TabItem("Segmentation in LAX View"):
557
  segmentation_lax_tab()
558
-
 
559
  demo.launch()
 
6
  import SimpleITK as sitk # noqa: N813
7
  import spaces
8
  import torch
9
+ from cinema import CineMA, ConvUNetR, ConvViT, heatmap_soft_argmax
10
  from cinema.examples.cine_cmr import plot_cmr_views
11
+ from cinema.examples.inference.landmark_heatmap import (
12
+ plot_heatmaps,
13
+ plot_landmarks,
14
+ plot_lv,
15
+ )
16
  from cinema.examples.inference.mae import plot_mae_reconstruction, reconstruct_images
17
  from cinema.examples.inference.segmentation_lax_4c import (
18
  plot_segmentations as plot_segmentations_lax,
 
68
  with gr.Blocks() as cmr_interface:
69
  gr.Markdown(
70
  """
71
+ This page illustrates the spatial orientation of short-axis (SAX) and long-axis (LAX) views in 3D.
72
  """
73
  )
74
  with gr.Row():
 
227
  with gr.Blocks() as mae_interface:
228
  gr.Markdown(
229
  """
230
+ This page demonstrates the masking and reconstruction process of the masked autoencoder. The model was trained with a mask ratio of 0.75 over 74,000 studies.
231
  """
232
  )
233
  with gr.Row():
 
339
  with gr.Blocks() as sax_interface:
340
  gr.Markdown(
341
  """
342
+ This page demonstrates the segmentation of cardiac structures in the short-axis (SAX) view.
 
343
  """
344
  )
345
 
 
349
  ## Description
350
  ### Data
351
 
352
+ Images 101–150 are from the test set of [ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/).
 
353
 
354
  ### Model
355
 
356
+ The available models are finetuned on different datasets ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/), [M&Ms](https://www.ub.edu/mnms/), and [M&Ms2](https://www.ub.edu/mnms-2/)). For each dataset, there are three models finetuned with seeds: 0, 1, 2.
357
 
358
+ ### Visualisation
359
 
360
+ The left figure shows the segmentation of ventricles and myocardium at every n time step across all SAX slices.
361
+ The right figure shows the volumes across all time frames and estimates the ejection fraction (EF) for the left ventricle (LV) and right ventricle (RV).
362
  """)
363
  with gr.Column(scale=3):
364
  gr.Markdown("## Data Settings")
 
366
  minimum=101,
367
  maximum=150,
368
  step=1,
369
+ label="Choose an image, ID is between 101 and 150",
370
  value=101,
371
  )
372
  t_step = gr.Slider(
 
377
  value=2,
378
  )
379
  with gr.Column(scale=3):
380
+ gr.Markdown("## Model Settings")
381
  trained_dataset = gr.Dropdown(
382
  choices=["ACDC", "M&MS", "M&MS2"],
383
+ label="Choose which dataset the model was finetuned on",
384
  value="ACDC",
385
  )
386
  seed = gr.Slider(
 
397
  gr.Markdown("## Ventricle and Myocardium Segmentation")
398
  segmentation_plot = gr.Plot(show_label=False)
399
  with gr.Column():
400
+ gr.Markdown("## Volume Estimation")
401
  volume_plot = gr.Plot(show_label=False)
402
 
403
  run_button.click(
 
479
  with gr.Blocks() as lax_interface:
480
  gr.Markdown(
481
  """
482
+ This page demonstrates the segmentation of cardiac structures in the long-axis (LAX) four-chamber (4C) view.
 
483
  """
484
  )
485
 
 
489
  ## Description
490
  ### Data
491
 
492
+ There are four example images from the UK Biobank.
493
 
494
  ### Model
495
 
496
+ The available models are finetuned on [M&Ms2](https://www.ub.edu/mnms-2/). There are three models finetuned with seeds: 0, 1, 2.
497
 
498
+ ### Visualisation
499
 
500
  The left figure shows the segmentation of ventricles and myocardium across all time frames.
501
+ The right figure shows the volumes across all time frames and estimates the ejection fraction (EF).
502
  """)
503
  with gr.Column(scale=3):
504
  gr.Markdown("## Data Settings")
 
507
  maximum=4,
508
  step=1,
509
  label="Choose an image, ID is between 1 and 4",
510
+ value=1,
511
  )
512
  with gr.Column(scale=3):
513
+ gr.Markdown("## Model Settings")
514
  seed = gr.Slider(
515
  minimum=0,
516
  maximum=2,
 
518
  label="Choose which seed the finetuning used",
519
  value=0,
520
  )
521
+ run_button = gr.Button("Run LAX 4C segmentation inference", variant="primary")
522
 
523
  with gr.Row():
524
  with gr.Column():
 
536
  return lax_interface
537
 
538
 
539
+ @spaces.GPU
540
+ def landmark_heatmap_inference(
541
+ images: torch.Tensor,
542
+ view: str,
543
+ transform: Compose,
544
+ model: ConvUNetR,
545
+ progress=gr.Progress(),
546
+ ) -> tuple[np.ndarray, np.ndarray]:
547
+ model.to(device)
548
+
549
+ n_frames = images.shape[-1]
550
+ probs_list = []
551
+ coords_list = []
552
+ for t in tqdm(range(n_frames), total=n_frames):
553
+ progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
554
+ batch = transform({view: torch.from_numpy(images[None, ..., 0, t])})
555
+ batch = {
556
+ k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()
557
+ }
558
+ with (
559
+ torch.no_grad(),
560
+ torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
561
+ ):
562
+ logits = model(batch)[view] # (1, 3, x, y)
563
+ probs = torch.sigmoid(logits) # (1, 3, width, height)
564
+ probs_list.append(probs[0].detach().to(torch.float32).cpu().numpy())
565
+ coords = heatmap_soft_argmax(probs)[0].detach().to(torch.float32).cpu().numpy()
566
+ coords = [int(x) for x in coords]
567
+ coords_list.append(coords)
568
+ probs = np.stack(probs_list, axis=-1) # (3, x, y, t)
569
+ coords = np.stack(coords_list, axis=-1) # (6, t)
570
+ return probs, coords
571
+
572
+
573
+ @spaces.GPU
574
+ def landmark_coordinate_inference(
575
+ images: torch.Tensor,
576
+ view: str,
577
+ transform: Compose,
578
+ model: ConvViT,
579
+ progress=gr.Progress(),
580
+ ) -> np.ndarray:
581
+ model.to(device)
582
+
583
+ w, h, _, n_frames = images.shape
584
+ coords_list = []
585
+ for t in tqdm(range(n_frames), total=n_frames):
586
+ progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
587
+ batch = transform({view: torch.from_numpy(images[None, ..., 0, t])})
588
+ batch = {
589
+ k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()
590
+ }
591
+ with (
592
+ torch.no_grad(),
593
+ torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
594
+ ):
595
+ coords = model(batch)[0].detach().to(torch.float32).cpu().numpy() # (6,)
596
+ coords *= np.array([w, h, w, h, w, h])
597
+ coords = [int(x) for x in coords]
598
+ coords_list.append(coords)
599
+ coords = np.stack(coords_list, axis=-1) # (6, t)
600
+ return coords
601
+
602
+
603
+ def landmark(image_id, view, method, seed, progress=gr.Progress()):
604
+ view = "lax_2c" if view == "LAX 2C" else "lax_4c"
605
+ method = method.lower()
606
+
607
+ # Download and load model
608
+ progress(0, desc="Downloading model...")
609
+ if method == "heatmap":
610
+ model = ConvUNetR.from_finetuned(
611
+ repo_id="mathpluscode/CineMA",
612
+ model_filename=f"finetuned/landmark_{method}/{view}/{view}_{seed}.safetensors",
613
+ config_filename=f"finetuned/landmark_{method}/{view}/config.yaml",
614
+ cache_dir=cache_dir,
615
+ )
616
+ elif method == "coordinate":
617
+ model = ConvViT.from_finetuned(
618
+ repo_id="mathpluscode/CineMA",
619
+ model_filename=f"finetuned/landmark_{method}/{view}/{view}_{seed}.safetensors",
620
+ config_filename=f"finetuned/landmark_{method}/{view}/config.yaml",
621
+ cache_dir=cache_dir,
622
+ )
623
+ else:
624
+ raise ValueError(f"Invalid method: {method}")
625
+ model.eval()
626
+
627
+ # Inference
628
+ progress(0, desc="Downloading data...")
629
+ transform = ScaleIntensityd(keys=view)
630
+ images = np.transpose(
631
+ sitk.GetArrayFromImage(
632
+ load_nifti_from_github(f"ukb/{image_id}/{image_id}_{view}.nii.gz")
633
+ )
634
+ )
635
+
636
+ if method == "heatmap":
637
+ _, coords = landmark_heatmap_inference(images, view, transform, model, progress)
638
+ elif method == "coordinate":
639
+ coords = landmark_coordinate_inference(images, view, transform, model, progress)
640
+ else:
641
+ raise ValueError(f"Invalid method: {method}")
642
+
643
+ landmark_fig = plot_landmarks(images, coords)
644
+ lv_fig = plot_lv(coords)
645
+ return landmark_fig, lv_fig
646
+
647
+
648
+ def landmark_tab():
649
+ with gr.Blocks() as landmark_interface:
650
+ gr.Markdown(
651
+ """
652
+ This page demonstrates landmark localisation in the long-axis (LAX) two-chamber (2C) and four-chamber (4C) views
653
+ """
654
+ )
655
+
656
+ with gr.Row():
657
+ with gr.Column(scale=4):
658
+ gr.Markdown("""
659
+ ## Description
660
+ ### Data
661
+
662
+ There are four example images from the UK Biobank.
663
+
664
+ ### Model
665
+
666
+ The available models are finetuned on data from [Xue et al.](https://pubs.rsna.org/doi/10.1148/ryai.2021200197).
667
+ There are two types of landmark localisation models:
668
+
669
+ - **Heatmap**: predicts dense probability maps of landmarks
670
+ - **Coordinate**: predicts landmark coordinates directly
671
+
672
+ For each type, there are three models finetuned with seeds: 0, 1, 2.
673
+
674
+ ### Visualisation
675
+
676
+ The left figure shows the landmark positions across all time frames.
677
+ The right figure shows the length of the left ventricle across all time frames and the estimates of two metrics:
678
+ - Mitral annular plane systolic excursion (MAPSE)
679
+ - Global longitudinal shortening (GLS)
680
+ """)
681
+ with gr.Column(scale=3):
682
+ gr.Markdown("## Data Settings")
683
+ image_id = gr.Slider(
684
+ minimum=1,
685
+ maximum=4,
686
+ step=1,
687
+ label="Choose an image, ID is between 1 and 4",
688
+ value=1,
689
+ )
690
+ view = gr.Dropdown(
691
+ choices=["LAX 2C", "LAX 4C"],
692
+ label="Choose which view to localise the landmarks",
693
+ value="LAX 2C",
694
+ )
695
+ with gr.Column(scale=3):
696
+ gr.Markdown("## Model Settings")
697
+ method = gr.Dropdown(
698
+ choices=["Heatmap", "Coordinate"],
699
+ label="Choose which method to use",
700
+ value="Heatmap",
701
+ )
702
+ seed = gr.Slider(
703
+ minimum=0,
704
+ maximum=2,
705
+ step=1,
706
+ label="Choose which seed the finetuning used",
707
+ value=0,
708
+ )
709
+ run_button = gr.Button("Run landmark localisation inference", variant="primary")
710
+
711
+ with gr.Row():
712
+ with gr.Column():
713
+ gr.Markdown("## Landmark Localisation")
714
+ landmark_plot = gr.Plot(show_label=False)
715
+ with gr.Column():
716
+ gr.Markdown("## Left Ventricle Length Estimation")
717
+ lv_plot = gr.Plot(show_label=False)
718
+
719
+ run_button.click(
720
+ fn=landmark,
721
+ inputs=[image_id, view, method, seed],
722
+ outputs=[landmark_plot, lv_plot],
723
+ )
724
+ return landmark_interface
725
+
726
+
727
  with gr.Blocks(
728
  theme=theme, title="CineMA: A Foundation Model for Cine Cardiac MRI"
729
  ) as demo:
 
732
  # CineMA: A Foundation Model for Cine Cardiac MRI πŸŽ₯πŸ«€
733
 
734
  This demo showcases the capabilities of CineMA in multiple tasks.
735
+ For more details, check out our [GitHub](https://github.com/mathpluscode/CineMA).
736
  """
737
  )
738
 
 
743
  mae_tab()
744
  with gr.TabItem("Segmentation in SAX View"):
745
  segmentation_sax_tab()
746
+ with gr.TabItem("Segmentation in LAX 4C View"):
747
  segmentation_lax_tab()
748
+ with gr.TabItem("Landmark Localisation in LAX 2C/4C View"):
749
+ landmark_tab()
750
  demo.launch()
requirements.txt CHANGED
@@ -17,6 +17,6 @@ scikit-learn==1.6.1
17
  scipy==1.15.2
18
  spaces==0.36.0
19
  timm==1.0.15
20
- git+https://github.com/mathpluscode/CineMA@3ace4d79ee037f95e8767b35c7bc97d511f8b9c1#egg=cinema
21
  --extra-index-url https://download.pytorch.org/whl/cu113
22
  torch==2.5.1
 
17
  scipy==1.15.2
18
  spaces==0.36.0
19
  timm==1.0.15
20
+ git+https://github.com/mathpluscode/CineMA@edc0774baed3c2429a2a4ffaa36a3910f2780b2b#egg=cinema
21
  --extra-index-url https://download.pytorch.org/whl/cu113
22
  torch==2.5.1