mathpluscode commited on
Commit
078793a
·
1 Parent(s): 68ea82c

Add view panel

Browse files
Files changed (1) hide show
  1. app.py +96 -7
app.py CHANGED
@@ -15,6 +15,7 @@ from cinema.examples.inference.segmentation_lax_4c import (
15
  plot_volume_changes as plot_volume_changes_lax,
16
  post_process as post_process_lax_segmentation,
17
  )
 
18
  from tqdm import tqdm
19
  import spaces
20
  import requests
@@ -39,6 +40,94 @@ theme = gr.themes.Ocean(
39
  )
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  @spaces.GPU
43
  def segmentation_sax_inference(
44
  images: torch.Tensor,
@@ -227,12 +316,6 @@ def segmentation_lax(seed, image_id, progress=gr.Progress()):
227
 
228
  # Download and load model
229
  progress(0, desc="Downloading model...")
230
- image_url = f"https://raw.githubusercontent.com/mathpluscode/CineMA/main/cinema/examples/data/ukb/{image_id}/{image_id}_{view}.nii.gz"
231
- image_path = cache_dir / f"{image_id}_{view}.nii.gz"
232
- response = requests.get(image_url)
233
- with open(image_path, "wb") as f:
234
- f.write(response.content)
235
-
236
  model = ConvUNetR.from_finetuned(
237
  repo_id="mathpluscode/CineMA",
238
  model_filename=f"finetuned/segmentation/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
@@ -244,7 +327,11 @@ def segmentation_lax(seed, image_id, progress=gr.Progress()):
244
  progress(0, desc="Downloading data...")
245
  transform = ScaleIntensityd(keys=view)
246
 
247
- images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(image_path)))
 
 
 
 
248
  labels = segmentation_lax_inference(images, view, transform, model, progress)
249
 
250
  progress(1, desc="Plotting results...")
@@ -329,6 +416,8 @@ with gr.Blocks(
329
  )
330
 
331
  with gr.Tabs() as tabs:
 
 
332
  with gr.TabItem("Segmentation in SAX View"):
333
  segmentation_sax_tab()
334
  with gr.TabItem("Segmentation in LAX View"):
 
15
  plot_volume_changes as plot_volume_changes_lax,
16
  post_process as post_process_lax_segmentation,
17
  )
18
+ from cinema.examples.cine_cmr import plot_cmr_views
19
  from tqdm import tqdm
20
  import spaces
21
  import requests
 
40
  )
41
 
42
 
43
+ def load_nifti_from_github(name: str) -> sitk.Image:
44
+ path = cache_dir / name
45
+ if not path.exists():
46
+ image_url = f"https://raw.githubusercontent.com/mathpluscode/CineMA/main/cinema/examples/data/{name}"
47
+ response = requests.get(image_url)
48
+ path.parent.mkdir(parents=True, exist_ok=True)
49
+ with open(path, "wb") as f:
50
+ f.write(response.content)
51
+ return sitk.ReadImage(path)
52
+
53
+
54
+ def cmr_tab():
55
+ with gr.Blocks() as sax_interface:
56
+ gr.Markdown(
57
+ """
58
+ This page demonstrates the geometry of SAX and LAX views in 3D spaces.
59
+ Please adjust the settings on the right panels to select images and slices.
60
+ """
61
+ )
62
+ with gr.Row():
63
+ with gr.Column(scale=3):
64
+ gr.Markdown("## Views")
65
+ cmr_plot = gr.Plot(show_label=False)
66
+ with gr.Column(scale=1):
67
+ gr.Markdown("## Data Settings")
68
+ image_id = gr.Slider(
69
+ minimum=1,
70
+ maximum=4,
71
+ step=1,
72
+ label="Choose an image, ID is between 1 and 4",
73
+ value=1,
74
+ )
75
+ # Placeholder for slice slider, will update dynamically
76
+ slice_idx = gr.Slider(
77
+ minimum=0,
78
+ maximum=8,
79
+ step=1,
80
+ label="SAX slice to visualize",
81
+ value=0,
82
+ )
83
+
84
+ def get_num_slices(image_id):
85
+ sax_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_sax.nii.gz")
86
+ return sax_image.GetSize()[2]
87
+
88
+ def update_slice_slider(image_id):
89
+ num_slices = get_num_slices(image_id)
90
+ return gr.update(maximum=num_slices - 1, value=0, visible=True)
91
+
92
+ def fn(image_id, slice_idx):
93
+ lax_2c_image = load_nifti_from_github(
94
+ f"ukb/{image_id}/{image_id}_lax_2c.nii.gz"
95
+ )
96
+ lax_3c_image = load_nifti_from_github(
97
+ f"ukb/{image_id}/{image_id}_lax_3c.nii.gz"
98
+ )
99
+ lax_4c_image = load_nifti_from_github(
100
+ f"ukb/{image_id}/{image_id}_lax_4c.nii.gz"
101
+ )
102
+ sax_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_sax.nii.gz")
103
+ fig = plot_cmr_views(
104
+ lax_2c_image,
105
+ lax_3c_image,
106
+ lax_4c_image,
107
+ sax_image,
108
+ t_to_show=4,
109
+ depth_to_show=slice_idx,
110
+ )
111
+ fig.update_layout(height=600)
112
+ return fig
113
+
114
+ # When image changes, update the slice slider and plot
115
+ gr.on(
116
+ fn=lambda image_id: [update_slice_slider(image_id), fn(image_id, 0)],
117
+ inputs=[image_id],
118
+ outputs=[slice_idx, cmr_plot],
119
+ )
120
+
121
+ # When slice changes, update the plot
122
+ slice_idx.change(
123
+ fn=fn,
124
+ inputs=[image_id, slice_idx],
125
+ outputs=[cmr_plot],
126
+ )
127
+
128
+ return sax_interface
129
+
130
+
131
  @spaces.GPU
132
  def segmentation_sax_inference(
133
  images: torch.Tensor,
 
316
 
317
  # Download and load model
318
  progress(0, desc="Downloading model...")
 
 
 
 
 
 
319
  model = ConvUNetR.from_finetuned(
320
  repo_id="mathpluscode/CineMA",
321
  model_filename=f"finetuned/segmentation/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
 
327
  progress(0, desc="Downloading data...")
328
  transform = ScaleIntensityd(keys=view)
329
 
330
+ images = np.transpose(
331
+ sitk.GetArrayFromImage(
332
+ load_nifti_from_github(f"ukb/{image_id}/{image_id}_{view}.nii.gz")
333
+ )
334
+ )
335
  labels = segmentation_lax_inference(images, view, transform, model, progress)
336
 
337
  progress(1, desc="Plotting results...")
 
416
  )
417
 
418
  with gr.Tabs() as tabs:
419
+ with gr.TabItem("Cine CMR Views"):
420
+ cmr_tab()
421
  with gr.TabItem("Segmentation in SAX View"):
422
  segmentation_sax_tab()
423
  with gr.TabItem("Segmentation in LAX View"):