mathpluscode committed · Commit 7f9c492 · Parent(s): c69a91e

Add MAE tab

Files changed (4):
  1. .pre-commit-config.yaml +1 -1
  2. README.md +6 -6
  3. app.py +157 -24
  4. requirements.txt +1 -1
.pre-commit-config.yaml CHANGED
@@ -27,7 +27,7 @@ repos:
     hooks:
       # run the linter
       - id: ruff
-        args: [--fix]
+        args: ["--fix", "--select=I"]
       # run the formatter
       - id: ruff-format
   - repo: https://github.com/pre-commit/mirrors-prettier
README.md CHANGED
@@ -1,11 +1,11 @@
 ---
 title: CineMA
 tags:
-- medical
-- cardiac
-- MRI
-- foundation model
-- MAE
+- medical
+- cardiac
+- MRI
+- foundation model
+- MAE
 emoji: 🚀
 colorFrom: red
 colorTo: purple
@@ -23,4 +23,4 @@ thumbnail: >-
 # CineMA: A Foundation Model for Cine Cardiac MRI
 
 This is a demo of CineMA, a foundation model for cine cardiac MRI. For more details, checkout our
-[GitHub](https://github.com/mathpluscode/CineMA).
+[GitHub](https://github.com/mathpluscode/CineMA).
app.py CHANGED
@@ -1,24 +1,32 @@
-import numpy as np
+from pathlib import Path
+
 import gradio as gr
-from huggingface_hub import hf_hub_download
+import numpy as np
+import requests
 import SimpleITK as sitk  # noqa: N813
+import spaces
 import torch
-from monai.transforms import Compose, ScaleIntensityd, SpatialPadd
-from cinema import ConvUNetR
-from pathlib import Path
-from cinema.examples.inference.segmentation_sax import (
-    plot_segmentations as plot_segmentations_sax,
-    plot_volume_changes as plot_volume_changes_sax,
-)
+from cinema import CineMA, ConvUNetR
+from cinema.examples.cine_cmr import plot_cmr_views
+from cinema.examples.inference.mae import plot_mae_reconstruction, reconstruct_images
 from cinema.examples.inference.segmentation_lax_4c import (
     plot_segmentations as plot_segmentations_lax,
+)
+from cinema.examples.inference.segmentation_lax_4c import (
     plot_volume_changes as plot_volume_changes_lax,
+)
+from cinema.examples.inference.segmentation_lax_4c import (
     post_process as post_process_lax_segmentation,
 )
-from cinema.examples.cine_cmr import plot_cmr_views
+from cinema.examples.inference.segmentation_sax import (
+    plot_segmentations as plot_segmentations_sax,
+)
+from cinema.examples.inference.segmentation_sax import (
+    plot_volume_changes as plot_volume_changes_sax,
+)
+from huggingface_hub import hf_hub_download
+from monai.transforms import Compose, ScaleIntensityd, SpatialPadd
 from tqdm import tqdm
-import spaces
-import requests
 
 # cache directories
 cache_dir = Path("/tmp/.cinema")
@@ -52,18 +60,17 @@ def load_nifti_from_github(name: str) -> sitk.Image:
 
 
 def cmr_tab():
-    with gr.Blocks() as sax_interface:
+    with gr.Blocks() as cmr_interface:
         gr.Markdown(
             """
-            This page demonstrates the geometry of SAX and LAX views in 3D spaces.
-            Please adjust the settings on the right panels to select images and slices.
+            This page illustrates the spatial orientation of short-axis (SAX) and long-axis (LAX) views in 3D. Use the control panels on the right to select specific images and slices.
             """
         )
        with gr.Row():
-            with gr.Column(scale=3):
+            with gr.Column(scale=5):
                 gr.Markdown("## Views")
                 cmr_plot = gr.Plot(show_label=False)
-            with gr.Column(scale=1):
+            with gr.Column(scale=3):
                 gr.Markdown("## Data Settings")
                 image_id = gr.Slider(
                     minimum=1,
@@ -125,7 +132,127 @@ def cmr_tab():
                     outputs=[cmr_plot],
                 )
 
-    return sax_interface
+    return cmr_interface
+
+
+@spaces.GPU
+def mae_inference(
+    batch: dict[str, torch.Tensor],
+    transform: Compose,
+    model: CineMA,
+    mask_ratio: float,
+) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray], dict[str, np.ndarray]]:
+    model.to(device)
+    sax_slices = batch["sax"].shape[-1]
+    batch = transform(batch)
+    batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
+    with (
+        torch.no_grad(),
+        torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
+    ):
+        _, pred_dict, enc_mask_dict, _ = model(batch, enc_mask_ratio=mask_ratio)
+    grid_size_dict = {
+        k: v.patch_embed.grid_size for k, v in model.enc_down_dict.items()
+    }
+    reconstructed_dict, masks_dict = reconstruct_images(
+        batch,
+        pred_dict,
+        enc_mask_dict,
+        model.dec_patch_size_dict,
+        grid_size_dict,
+        sax_slices,
+    )
+    batch = {
+        k: v.detach().to(torch.float32).cpu().numpy()[0, 0]
+        for k, v in batch.items()
+    }
+    batch["sax"] = batch["sax"][..., :sax_slices]
+    return batch, reconstructed_dict, masks_dict
+
+
+def mae(image_id, mask_ratio, progress=gr.Progress()):
+    t = 4  # which time frame to use
+    progress(0, desc="Downloading model...")
+    model = CineMA.from_pretrained()
+    model.eval()
+
+    progress(0, desc="Downloading data...")
+    lax_2c_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_lax_2c.nii.gz")
+    lax_3c_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_lax_3c.nii.gz")
+    lax_4c_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_lax_4c.nii.gz")
+    sax_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_sax.nii.gz")
+    transform = Compose(
+        [
+            ScaleIntensityd(keys=("sax", "lax_2c", "lax_3c", "lax_4c")),
+            SpatialPadd(keys="sax", spatial_size=(192, 192, 16), method="end"),
+            SpatialPadd(
+                keys=("lax_2c", "lax_3c", "lax_4c"),
+                spatial_size=(256, 256),
+                method="end",
+            ),
+        ]
+    )
+    lax_2c_image_np = np.transpose(sitk.GetArrayFromImage(lax_2c_image))
+    lax_3c_image_np = np.transpose(sitk.GetArrayFromImage(lax_3c_image))
+    lax_4c_image_np = np.transpose(sitk.GetArrayFromImage(lax_4c_image))
+    sax_image_np = np.transpose(sitk.GetArrayFromImage(sax_image))
+    image_dict = {
+        "sax": sax_image_np[None, ..., t],
+        "lax_2c": lax_2c_image_np[None, ..., 0, t],
+        "lax_3c": lax_3c_image_np[None, ..., 0, t],
+        "lax_4c": lax_4c_image_np[None, ..., 0, t],
+    }
+    batch = {k: torch.from_numpy(v) for k, v in image_dict.items()}
+
+    progress(0.5, desc="Running inference...")
+    batch, reconstructed_dict, masks_dict = mae_inference(
+        batch, transform, model, mask_ratio
+    )
+    progress(1, desc="Plotting results...")
+
+    fig = plot_mae_reconstruction(
+        batch,
+        reconstructed_dict,
+        masks_dict,
+    )
+    return fig
+
+
+def mae_tab():
+    with gr.Blocks() as mae_interface:
+        gr.Markdown(
+            """
+            This page illustrates the masking and reconstruction process of the masked autoencoder. The model was trained with mask ratio 0.75 over 74,000 studies.
+            """
+        )
+        with gr.Row():
+            with gr.Column(scale=5):
+                gr.Markdown("## Reconstruction")
+                plot = gr.Plot(show_label=False)
+            with gr.Column(scale=3):
+                gr.Markdown("## Data Settings")
+                image_id = gr.Slider(
+                    minimum=1,
+                    maximum=4,
+                    step=1,
+                    label="Choose an image, ID is between 1 and 4",
+                    value=1,
+                )
+                mask_ratio = gr.Slider(
+                    minimum=0.05,
+                    maximum=1,
+                    step=0.05,
+                    label="Mask ratio",
+                    value=0.75,
+                )
+                run_button = gr.Button("Run Masked Autoencoder", variant="primary")
+                run_button.click(
+                    fn=mae,
+                    inputs=[image_id, mask_ratio],
+                    outputs=[plot],
+                )
+
+    return mae_interface
 
 
 @spaces.GPU
@@ -152,7 +279,7 @@ def segmentation_sax_inference(
     ):
         logits = model(batch)[view]
         labels_list.append(torch.argmax(logits, dim=1)[0, ..., :n_slices])
-    labels = torch.stack(labels_list, dim=-1).detach().cpu().numpy()
+    labels = torch.stack(labels_list, dim=-1).detach().to(torch.float32).cpu().numpy()
     return labels
 
 
@@ -181,6 +308,7 @@ def segmentation_sax(trained_dataset, seed, image_id, t_step, progress=gr.Progress()):
         config_filename=f"finetuned/segmentation/{trained_dataset}_{view}/config.yaml",
         cache_dir=cache_dir,
     )
+    model.eval()
 
     # Inference
     progress(0, desc="Downloading data...")
@@ -218,7 +346,7 @@ def segmentation_sax_tab():
    ### Data

    The available data is from ACDC. All images have been resampled to 1 mm × 1 mm × 10 mm and centre-cropped to 192 mm × 192 mm for each SAX slice.
-   Image 1 - 100 are from the training set, and image 101 - 150 are from the test set.
+   Image 101 - 150 are from the test set.

    ### Model

@@ -232,11 +360,11 @@ def segmentation_sax_tab():
            with gr.Column(scale=3):
                gr.Markdown("## Data Settings")
                image_id = gr.Slider(
-                    minimum=1,
+                    minimum=101,
                     maximum=150,
                     step=1,
-                    label="Choose an ACDC image, ID is between 1 and 150",
-                    value=150,
+                    label="Choose an ACDC image, ID is between 101 and 150",
+                    value=101,
                 )
                 t_step = gr.Slider(
                     minimum=1,
@@ -299,7 +427,9 @@ def segmentation_lax_inference(
         torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
     ):
         logits = model(batch)[view]  # (1, 4, x, y)
-    labels = torch.argmax(logits, dim=1)[0].detach().cpu().numpy()  # (x, y)
+    labels = (
+        torch.argmax(logits, dim=1)[0].detach().to(torch.float32).cpu().numpy()
+    )  # (x, y)
 
     # the model seems to hallucinate an additional right ventricle and myocardium sometimes
     # find the connected component that is closest to left ventricle
@@ -322,6 +452,7 @@ def segmentation_lax(seed, image_id, progress=gr.Progress()):
         config_filename=f"finetuned/segmentation/{trained_dataset}_{view}/config.yaml",
         cache_dir=cache_dir,
     )
+    model.eval()
 
     # Inference
     progress(0, desc="Downloading data...")
@@ -418,6 +549,8 @@ with gr.Blocks(
     with gr.Tabs() as tabs:
         with gr.TabItem("Cine CMR Views"):
             cmr_tab()
+        with gr.TabItem("Masked Autoencoder"):
+            mae_tab()
         with gr.TabItem("Segmentation in SAX View"):
             segmentation_sax_tab()
         with gr.TabItem("Segmentation in LAX View"):
requirements.txt CHANGED
@@ -17,6 +17,6 @@ scikit-learn==1.6.1
 scipy==1.15.2
 spaces==0.36.0
 timm==1.0.15
-git+https://github.com/mathpluscode/CineMA@1ff0e2220676aeff34988614e458588fc8150473#egg=cinema
+git+https://github.com/mathpluscode/CineMA@3ace4d79ee037f95e8767b35c7bc97d511f8b9c1#egg=cinema
 --extra-index-url https://download.pytorch.org/whl/cu113
 torch==2.5.1