mathpluscode commited on
Commit
ebd9a25
·
1 Parent(s): d505973
Files changed (1) hide show
  1. app.py +5 -54
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import numpy as np
2
  import gradio as gr
3
  from huggingface_hub import hf_hub_download
4
- import matplotlib.pyplot as plt
5
  import SimpleITK as sitk # noqa: N813
6
  import torch
7
  from monai.transforms import Compose, ScaleIntensityd, SpatialPadd
8
  from cinema import ConvUNetR
9
  from pathlib import Path
 
10
  import spaces
11
 
12
  # cache directories
@@ -25,7 +25,6 @@ def inferece(
25
  # set device and dtype
26
  dtype, device = torch.float32, torch.device("cpu")
27
  if torch.cuda.is_available():
28
- torch.cuda.empty_cache()
29
  device = torch.device("cuda")
30
  if torch.cuda.is_bf16_supported():
31
  dtype = torch.bfloat16
@@ -81,13 +80,7 @@ def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress(
81
  transform = Compose(
82
  [
83
  ScaleIntensityd(keys=view),
84
- SpatialPadd(
85
- keys=view,
86
- spatial_size=(192, 192, 16),
87
- method="end",
88
- lazy=True,
89
- allow_missing_keys=True,
90
- ),
91
  ]
92
  )
93
 
@@ -96,50 +89,8 @@ def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress(
96
  labels = inferece(images, view, transform, model, progress)
97
 
98
  progress(1, desc="Plotting results...")
99
- # Create segmentation visualization
100
- n_slices, n_frames = labels.shape[-2:]
101
- fig1, axs = plt.subplots(n_frames, n_slices, figsize=(n_slices, n_frames), dpi=300)
102
- for t in range(n_frames):
103
- for z in range(n_slices):
104
- axs[t, z].imshow(images[..., z, t], cmap="gray")
105
- axs[t, z].imshow(
106
- (labels[..., z, t, None] == 1)
107
- * np.array([108 / 255, 142 / 255, 191 / 255, 0.6])
108
- )
109
- axs[t, z].imshow(
110
- (labels[..., z, t, None] == 2)
111
- * np.array([214 / 255, 182 / 255, 86 / 255, 0.6])
112
- )
113
- axs[t, z].imshow(
114
- (labels[..., z, t, None] == 3)
115
- * np.array([130 / 255, 179 / 255, 102 / 255, 0.6])
116
- )
117
- axs[t, z].set_xticks([])
118
- axs[t, z].set_yticks([])
119
- if z == 0:
120
- axs[t, z].set_ylabel(f"t = {t * t_step}")
121
- fig1.suptitle(f"Subject {image_id} in {split} split")
122
- axs[0, n_slices // 2].set_title("SAX Slices")
123
- fig1.tight_layout()
124
- plt.subplots_adjust(wspace=0, hspace=0)
125
-
126
- # Create volume plot
127
- xs = np.arange(n_frames) * t_step
128
- rv_volumes = np.sum(labels == 1, axis=(0, 1, 2)) * 10 / 1000
129
- myo_volumes = np.sum(labels == 2, axis=(0, 1, 2)) * 10 / 1000
130
- lv_volumes = np.sum(labels == 3, axis=(0, 1, 2)) * 10 / 1000
131
- lvef = (max(lv_volumes) - min(lv_volumes)) / max(lv_volumes) * 100
132
- rvef = (max(rv_volumes) - min(rv_volumes)) / max(rv_volumes) * 100
133
-
134
- fig2, ax = plt.subplots(figsize=(4, 4), dpi=120)
135
- ax.plot(xs, rv_volumes, color="#6C8EBF", label="RV")
136
- ax.plot(xs, myo_volumes, color="#D6B656", label="MYO")
137
- ax.plot(xs, lv_volumes, color="#82B366", label="LV")
138
- ax.set_xlabel("Frame")
139
- ax.set_ylabel("Volume (ml)")
140
- ax.set_title(f"LVEF = {lvef:.2f}%, RVEF = {rvef:.2f}%")
141
- ax.legend(loc="lower right")
142
- fig2.tight_layout()
143
 
144
  return fig1, fig2
145
 
@@ -187,7 +138,7 @@ with gr.Blocks(
187
  maximum=150,
188
  step=1,
189
  label="Choose an ACDC image, ID is between 1 and 150",
190
- value=1,
191
  )
192
  t_step = gr.Slider(
193
  minimum=1,
 
1
  import numpy as np
2
  import gradio as gr
3
  from huggingface_hub import hf_hub_download
 
4
  import SimpleITK as sitk # noqa: N813
5
  import torch
6
  from monai.transforms import Compose, ScaleIntensityd, SpatialPadd
7
  from cinema import ConvUNetR
8
  from pathlib import Path
9
+ from examples.inference.segmentation_sax import plot_segmentations, plot_volume_changes
10
  import spaces
11
 
12
  # cache directories
 
25
  # set device and dtype
26
  dtype, device = torch.float32, torch.device("cpu")
27
  if torch.cuda.is_available():
 
28
  device = torch.device("cuda")
29
  if torch.cuda.is_bf16_supported():
30
  dtype = torch.bfloat16
 
80
  transform = Compose(
81
  [
82
  ScaleIntensityd(keys=view),
83
+ SpatialPadd(keys=view, spatial_size=(192, 192, 16), method="end"),
 
 
 
 
 
 
84
  ]
85
  )
86
 
 
89
  labels = inferece(images, view, transform, model, progress)
90
 
91
  progress(1, desc="Plotting results...")
92
+ fig1 = plot_segmentations(images, labels, t_step)
93
+ fig2 = plot_volume_changes(labels, t_step)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  return fig1, fig2
96
 
 
138
  maximum=150,
139
  step=1,
140
  label="Choose an ACDC image, ID is between 1 and 150",
141
+ value=150,
142
  )
143
  t_step = gr.Slider(
144
  minimum=1,