mathpluscode committed · Commit 68ea82c · 1 Parent(s): 91c04bb

Add LAX panel

Files changed (1): app.py +245 -87
app.py CHANGED
@@ -7,36 +7,50 @@ from monai.transforms import Compose, ScaleIntensityd, SpatialPadd
 from cinema import ConvUNetR
 from pathlib import Path
 from cinema.examples.inference.segmentation_sax import (
-    plot_segmentations,
-    plot_volume_changes,
+    plot_segmentations as plot_segmentations_sax,
+    plot_volume_changes as plot_volume_changes_sax,
 )
+from cinema.examples.inference.segmentation_lax_4c import (
+    plot_segmentations as plot_segmentations_lax,
+    plot_volume_changes as plot_volume_changes_lax,
+    post_process as post_process_lax_segmentation,
+)
+from tqdm import tqdm
 import spaces
+import requests
 
 # cache directories
 cache_dir = Path("/tmp/.cinema")
 cache_dir.mkdir(parents=True, exist_ok=True)
 
 
+# set device and dtype
+dtype, device = torch.float32, torch.device("cpu")
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+    if torch.cuda.is_bf16_supported():
+        dtype = torch.bfloat16
+
+
+# Create the Gradio interface
+theme = gr.themes.Ocean(
+    primary_hue="red",
+    secondary_hue="purple",
+)
+
+
 @spaces.GPU
-def inferece(
+def segmentation_sax_inference(
     images: torch.Tensor,
     view: str,
     transform: Compose,
     model: ConvUNetR,
     progress=gr.Progress(),
 ) -> np.ndarray:
-    # set device and dtype
-    dtype, device = torch.float32, torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-        if torch.cuda.is_bf16_supported():
-            dtype = torch.bfloat16
-
-    # inference
     model.to(device)
     n_slices, n_frames = images.shape[-2:]
     labels_list = []
-    for t in range(0, n_frames):
+    for t in tqdm(range(0, n_frames), total=n_frames):
         progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
         batch = transform({view: torch.from_numpy(images[None, ..., t])})
         batch = {
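The hunk above hoists device and dtype selection to module scope so that both inference functions can share it, and guards autocast behind `torch.cuda.is_available()`. A minimal self-contained sketch of the same pattern, assuming only plain PyTorch (the `layer` module is a stand-in for the segmentation model, not part of the app):

```python
import torch

# pick bf16 on capable GPUs, plain fp32 otherwise
dtype, device = torch.float32, torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
    if torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16

layer = torch.nn.Linear(8, 4).to(device)  # stand-in for the model
x = torch.rand(1, 8, device=device)
with (
    torch.no_grad(),
    torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
):
    y = layer(x)
print(y.dtype)  # torch.bfloat16 under GPU autocast, torch.float32 on CPU
```

Because the autocast context is constructed for `"cuda"` but disabled when no GPU is present, the same code path degrades gracefully to fp32 on CPU.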
@@ -53,7 +67,7 @@ def inferece(
     return labels
 
 
-def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress()):
+def segmentation_sax(trained_dataset, seed, image_id, t_step, progress=gr.Progress()):
     # Fixed parameters
     view = "sax"
     split = "train" if image_id <= 100 else "test"
@@ -64,7 +78,7 @@ def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress(
     }[str(trained_dataset)]
 
     # Download and load model
-    progress(0, desc="Downloading model and data...")
+    progress(0, desc="Downloading model...")
     image_path = hf_hub_download(
         repo_id="mathpluscode/ACDC",
         repo_type="dataset",
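For reference, fetching one ACDC image the way this hunk does amounts to a single `hf_hub_download` call. A sketch, where the `filename` is a hypothetical illustration since the diff elides the actual path pattern:

```python
from pathlib import Path
from huggingface_hub import hf_hub_download

cache_dir = Path("/tmp/.cinema")
image_path = hf_hub_download(
    repo_id="mathpluscode/ACDC",
    repo_type="dataset",
    filename="patient150/patient150_sax_t.nii.gz",  # hypothetical path for illustration
    cache_dir=cache_dir,
)
print(image_path)  # resolved path inside the local cache
```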
@@ -79,7 +93,8 @@ def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress(
         cache_dir=cache_dir,
     )
 
-    # Load and process data
+    # Inference
+    progress(0, desc="Downloading data...")
     transform = Compose(
         [
             ScaleIntensityd(keys=view),
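The transform pipeline is dictionary-based: each MONAI transform is keyed by the view name, so the same `Compose` can address multiple views. A small sketch of that pattern, assuming a pad size of (192, 192, 16); the actual `SpatialPadd` arguments are not shown in this diff:

```python
import torch
from monai.transforms import Compose, ScaleIntensityd, SpatialPadd

view = "sax"
transform = Compose(
    [
        ScaleIntensityd(keys=view),  # rescale intensities to [0, 1]
        SpatialPadd(keys=view, spatial_size=(192, 192, 16)),  # assumed pad size
    ]
)

image = torch.rand(1, 180, 180, 13)  # (channel, x, y, n_slices) for one time frame
batch = transform({view: image})
print(batch[view].shape)  # torch.Size([1, 192, 192, 16])
```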
@@ -89,91 +104,234 @@ def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress(
 
     images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(image_path)))
     images = images[..., ::t_step]
-    labels = inferece(images, view, transform, model, progress)
+    labels = segmentation_sax_inference(images, view, transform, model, progress)
 
     progress(1, desc="Plotting results...")
-    fig1 = plot_segmentations(images, labels, t_step)
-    fig2 = plot_volume_changes(labels, t_step)
+    fig1 = plot_segmentations_sax(images, labels, t_step)
+    fig2 = plot_volume_changes_sax(labels, t_step)
 
     return fig1, fig2
 
 
-# Create the Gradio interface
-theme = gr.themes.Ocean(
-    primary_hue="red",
-    secondary_hue="purple",
-)
+def segmentation_sax_tab():
+    with gr.Blocks() as sax_interface:
+        gr.Markdown(
+            """
+            This page demonstrates the segmentation of cardiac structures in the Short-Axis (SAX) view.
+            Please adjust the settings on the right panels and click the button to run the inference.
+            """
+        )
+
+        with gr.Row():
+            with gr.Column(scale=4):
+                gr.Markdown("""
+                ## Description
+                ### Data
+
+                The available data is from ACDC. All images have been resampled to 1 mm × 1 mm × 10 mm and centre-cropped to 192 mm × 192 mm for each SAX slice.
+                Images 1-100 are from the training set, and images 101-150 are from the test set.
+
+                ### Model
+
+                The available models are finetuned on different datasets ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/), [M&Ms](https://www.ub.edu/mnms/), and [M&Ms2](https://www.ub.edu/mnms-2/)). For each dataset, there are 3 models finetuned with different seeds: 0, 1, 2.
+
+                ### Visualization
+
+                The left figure shows the segmentation of the ventricles and myocardium every n time steps across all SAX slices.
+                The right figure plots the ventricle and myocardium volumes across all inference time frames.
+                """)
+            with gr.Column(scale=3):
+                gr.Markdown("## Data Settings")
+                image_id = gr.Slider(
+                    minimum=1,
+                    maximum=150,
+                    step=1,
+                    label="Choose an ACDC image, ID is between 1 and 150",
+                    value=150,
+                )
+                t_step = gr.Slider(
+                    minimum=1,
+                    maximum=10,
+                    step=1,
+                    label="Choose the gap between time frames",
+                    value=2,
+                )
+            with gr.Column(scale=3):
+                gr.Markdown("## Model Setting")
+                trained_dataset = gr.Dropdown(
+                    choices=["ACDC", "M&MS", "M&MS2"],
+                    label="Choose which dataset the segmentation model was finetuned on",
+                    value="ACDC",
+                )
+                seed = gr.Slider(
+                    minimum=0,
+                    maximum=2,
+                    step=1,
+                    label="Choose which seed the finetuning used",
+                    value=0,
+                )
+                run_button = gr.Button("Run SAX segmentation inference", variant="primary")
+
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("## Ventricle and Myocardium Segmentation")
+                segmentation_plot = gr.Plot(show_label=False)
+            with gr.Column():
+                gr.Markdown("## Ejection Fraction Prediction")
+                volume_plot = gr.Plot(show_label=False)
+
+        run_button.click(
+            fn=segmentation_sax,
+            inputs=[trained_dataset, seed, image_id, t_step],
+            outputs=[segmentation_plot, volume_plot],
+        )
+    return sax_interface
+
+
+@spaces.GPU
+def segmentation_lax_inference(
+    images: torch.Tensor,
+    view: str,
+    transform: Compose,
+    model: ConvUNetR,
+    progress=gr.Progress(),
+) -> np.ndarray:
+    model.to(device)
+    n_frames = images.shape[-1]
+    labels_list = []
+    for t in tqdm(range(n_frames), total=n_frames):
+        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
+        batch = transform({view: torch.from_numpy(images[None, ..., 0, t])})
+        batch = {
+            k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()
+        }
+        with (
+            torch.no_grad(),
+            torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
+        ):
+            logits = model(batch)[view]  # (1, 4, x, y)
+        labels = torch.argmax(logits, dim=1)[0].detach().cpu().numpy()  # (x, y)
+
+        # the model sometimes hallucinates an additional right ventricle and myocardium
+        # keep the connected component that is closest to the left ventricle
+        labels = post_process_lax_segmentation(labels)
+        labels_list.append(labels)
+    labels = np.stack(labels_list, axis=-1)  # (x, y, t)
+    return labels
+
+
+def segmentation_lax(seed, image_id, progress=gr.Progress()):
+    # Fixed parameters
+    trained_dataset = "mnms2"
+    view = "lax_4c"
+
+    # Download and load model
+    progress(0, desc="Downloading model...")
+    image_url = f"https://raw.githubusercontent.com/mathpluscode/CineMA/main/cinema/examples/data/ukb/{image_id}/{image_id}_{view}.nii.gz"
+    image_path = cache_dir / f"{image_id}_{view}.nii.gz"
+    response = requests.get(image_url)
+    with open(image_path, "wb") as f:
+        f.write(response.content)
+
+    model = ConvUNetR.from_finetuned(
+        repo_id="mathpluscode/CineMA",
+        model_filename=f"finetuned/segmentation/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
+        config_filename=f"finetuned/segmentation/{trained_dataset}_{view}/config.yaml",
+        cache_dir=cache_dir,
+    )
+
+    # Inference
+    progress(0, desc="Downloading data...")
+    transform = ScaleIntensityd(keys=view)
+
+    images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(image_path)))
+    labels = segmentation_lax_inference(images, view, transform, model, progress)
+
+    progress(1, desc="Plotting results...")
+    fig1 = plot_segmentations_lax(images, labels)
+    fig2 = plot_volume_changes_lax(labels)
+
+    return fig1, fig2
+
+
+def segmentation_lax_tab():
+    with gr.Blocks() as lax_interface:
+        gr.Markdown(
+            """
+            This page demonstrates the segmentation of cardiac structures in the Long-Axis (LAX) view.
+            Please adjust the settings on the right panels and click the button to run the inference.
+            """
+        )
+
+        with gr.Row():
+            with gr.Column(scale=4):
+                gr.Markdown("""
+                ## Description
+                ### Data
+
+                There are four example images. All images have been resampled to 1 mm × 1 mm and centre-cropped.
+
+                ### Model
+
+                The available models are finetuned on [M&Ms2](https://www.ub.edu/mnms-2/). There are 3 models finetuned with different seeds: 0, 1, 2.
+
+                ### Visualization
+
+                The left figure shows the segmentation of the ventricles and myocardium across all time frames.
+                The right figure plots the ventricle and myocardium volumes across all inference time frames.
+                """)
+            with gr.Column(scale=3):
+                gr.Markdown("## Data Settings")
+                image_id = gr.Slider(
+                    minimum=1,
+                    maximum=4,
+                    step=1,
+                    label="Choose an image, ID is between 1 and 4",
+                    value=4,
+                )
+            with gr.Column(scale=3):
+                gr.Markdown("## Model Setting")
+                seed = gr.Slider(
+                    minimum=0,
+                    maximum=2,
+                    step=1,
+                    label="Choose which seed the finetuning used",
+                    value=0,
+                )
+                run_button = gr.Button("Run LAX segmentation inference", variant="primary")
+
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("## Ventricle and Myocardium Segmentation")
+                segmentation_plot = gr.Plot(show_label=False)
+            with gr.Column():
+                gr.Markdown("## Ejection Fraction Prediction")
+                volume_plot = gr.Plot(show_label=False)
+
+        run_button.click(
+            fn=segmentation_lax,
+            inputs=[seed, image_id],
+            outputs=[segmentation_plot, volume_plot],
+        )
+    return lax_interface
+
+
 with gr.Blocks(
     theme=theme, title="CineMA: A Foundation Model for Cine Cardiac MRI"
 ) as demo:
     gr.Markdown(
         """
-        # CineMA: A Foundation Model for Cine Cardiac MRI 🎥🫀
+        # CineMA: A Foundation Model for Cine Cardiac MRI 🎥🫀
 
-        Below is an example of ejection fraction prediction inference. For more examples, checkout our [GitHub](https://github.com/mathpluscode/CineMA).
-        """
+        This demo showcases the capabilities of CineMA in multiple tasks.
+        For more details, check out our [GitHub](https://github.com/mathpluscode/CineMA).
+        """
     )
 
-    with gr.Row():
-        with gr.Column(scale=4):
-            gr.Markdown("## Description")
-            gr.Markdown("""
-            Please adjust the settings on the right panels and click the button to run the inference.
-
-            ### Data
-
-            The available data is from ACDC. All images have been resampled to 1 mm × 1 mm × 10 mm and centre-cropped to 192 mm × 192 mm for each SAX slice.
-            Image 1 - 100 are from the training set, and image 101 - 150 are from the test set.
-
-            ### Model
-
-            The available models are finetuned on different datasets ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/), [M&Ms](https://www.ub.edu/mnms/), and [M&Ms2](https://www.ub.edu/mnms-2/)). For each dataset, there are 3 models finetuned on different seeds: 0, 1, 2. The default model is the one finetuned on ACDC dataset with seed 0.
-
-            ### Visualization
-
-            The left panel shows the segmentation of ventricles and myocardium every n time steps across all SAX slices.
-            The right panel plots the ventricle and mycoardium volumes across all inference time frames.
-            """)
-        with gr.Column(scale=3):
-            gr.Markdown("## Data Settings")
-            image_id = gr.Slider(
-                minimum=1,
-                maximum=150,
-                step=1,
-                label="Choose an ACDC image, ID is between 1 and 150",
-                value=150,
-            )
-            t_step = gr.Slider(
-                minimum=1,
-                maximum=10,
-                step=1,
-                label="Choose the gap between time frames",
-                value=2,
-            )
-        with gr.Column(scale=3):
-            gr.Markdown("## Model Setting")
-            trained_dataset = gr.Dropdown(
-                choices=["ACDC", "M&MS", "M&MS2"],
-                label="Choose which dataset the segmentation model was finetuned on",
-                value="ACDC",
-            )
-            seed = gr.Slider(
-                minimum=0,
-                maximum=2,
-                step=1,
-                label="Choose which seed the finetuning used",
-                value=0,
-            )
-            run_button = gr.Button("Run segmentation inference", variant="primary")
-
-    with gr.Row():
-        segmentation_plot = gr.Plot(label="Ventricle and Myocardium Segmentation")
-        volume_plot = gr.Plot(label="Ejection Fraction Prediction")
-
-    run_button.click(
-        fn=run_inference,
-        inputs=[trained_dataset, seed, image_id, t_step],
-        outputs=[segmentation_plot, volume_plot],
-    )
+    with gr.Tabs() as tabs:
+        with gr.TabItem("Segmentation in SAX View"):
+            segmentation_sax_tab()
+        with gr.TabItem("Segmentation in LAX View"):
+            segmentation_lax_tab()
 
  demo.launch()
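The new LAX path relies on `post_process_lax_segmentation` to drop the extra right-ventricle and myocardium blobs the model sometimes hallucinates, keeping the connected component closest to the left ventricle. Below is an illustrative sketch of that idea using `scipy.ndimage`; the real implementation lives in `cinema.examples.inference.segmentation_lax_4c` and may differ in detail, and the class ids (1 = right ventricle, 2 = left ventricle) are assumptions for the example:

```python
import numpy as np
from scipy import ndimage

def keep_component_closest_to_lv(labels: np.ndarray, class_id: int, lv_id: int = 2) -> np.ndarray:
    """Keep only the connected component of `class_id` nearest the left ventricle."""
    mask = labels == class_id
    components, n = ndimage.label(mask)
    if n <= 1:  # nothing hallucinated
        return labels
    lv_centroid = np.array(ndimage.center_of_mass(labels == lv_id))
    centroids = ndimage.center_of_mass(mask, components, range(1, n + 1))
    distances = [np.linalg.norm(np.array(c) - lv_centroid) for c in centroids]
    keep = 1 + int(np.argmin(distances))
    out = labels.copy()
    out[mask & (components != keep)] = 0  # erase the far-away duplicates
    return out

seg = np.zeros((8, 8), dtype=int)
seg[1:3, 1:3] = 2  # left ventricle
seg[1:3, 4:6] = 1  # right ventricle next to it
seg[6:8, 6:8] = 1  # spurious blob far away
print(np.unique(keep_component_closest_to_lv(seg, class_id=1)))  # far blob erased
```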