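"""Gradio demo for CineMA, a foundation model for cine cardiac MRI.

The app exposes five tabs: cine CMR view orientation, masked autoencoder
reconstruction, segmentation in SAX and LAX 4C views, and landmark
localisation in LAX 2C/4C views.
"""
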
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import SimpleITK as sitk  # noqa: N813
import spaces
import torch
from cinema import CineMA, ConvUNetR, ConvViT, heatmap_soft_argmax
from cinema.examples.cine_cmr import plot_cmr_views
from cinema.examples.inference.landmark_heatmap import (
    plot_heatmap_and_landmarks,
    plot_landmarks,
    plot_lv,
)
from cinema.examples.inference.mae import plot_mae_reconstruction, reconstruct_images
from cinema.examples.inference.segmentation_lax_4c import (
    plot_segmentations as plot_segmentations_lax,
)
from cinema.examples.inference.segmentation_lax_4c import (
    plot_volume_changes as plot_volume_changes_lax,
)
from cinema.examples.inference.segmentation_lax_4c import (
    post_process as post_process_lax_segmentation,
)
from cinema.examples.inference.segmentation_sax import (
    plot_segmentations as plot_segmentations_sax,
)
from cinema.examples.inference.segmentation_sax import (
    plot_volume_changes as plot_volume_changes_sax,
)
from huggingface_hub import hf_hub_download
from monai.transforms import Compose, ScaleIntensityd, SpatialPadd
from tqdm import tqdm

# cache directory for downloaded data, model weights, and rendered plots
cache_dir = Path(__file__).parent
cache_dir.mkdir(parents=True, exist_ok=True)

# set device and dtype
dtype, device = torch.float32, torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
    if torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
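
# NOTE: on a ZeroGPU Space, the `spaces` package patches torch at import time,
# so torch.cuda.is_available() can return True here even though a GPU is only
# attached while a @spaces.GPU-decorated handler is running.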

# Gradio theme for the demo
theme = gr.themes.Ocean(
    primary_hue="red",
    secondary_hue="purple",
)


def load_nifti_from_github(name: str) -> sitk.Image:
    """Download an example NIfTI image from the CineMA repository and cache it locally."""
    path = cache_dir / name
    if not path.exists():
        image_url = f"https://raw.githubusercontent.com/mathpluscode/CineMA/main/cinema/examples/data/{name}"
        response = requests.get(image_url, timeout=60)
        response.raise_for_status()
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            f.write(response.content)
    return sitk.ReadImage(str(path))
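
# e.g. load_nifti_from_github("ukb/2/2_sax.nii.gz") downloads the SAX stack of
# example image 2 on first use and reads it from the local cache afterwards.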


def cmr_tab():
    with gr.Blocks() as cmr_interface:
        gr.Markdown(
            """
            This page illustrates the spatial orientation of short-axis (SAX) and long-axis (LAX) views in 3D.
            """
        )
        with gr.Row():
            with gr.Column(scale=5):
                gr.Markdown("## Views")
                cmr_plot = gr.Plot(show_label=False)
            with gr.Column(scale=3):
                gr.Markdown("## Data Settings")
                image_id = gr.Slider(
                    minimum=1,
                    maximum=4,
                    step=1,
                    label="Choose an image",
                    value=2,
                )
                # placeholder for the slice slider, updated dynamically per image
                slice_idx = gr.Slider(
                    minimum=0,
                    maximum=8,
                    step=1,
                    label="SAX slice to visualize",
                    value=1,
                )

        def get_num_slices(image_id):
            sax_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_sax.nii.gz")
            return sax_image.GetSize()[2]

        def update_slice_slider(image_id):
            num_slices = get_num_slices(image_id)
            return gr.update(maximum=num_slices - 1, value=1, visible=True)

        def fn(image_id, slice_idx):
            lax_2c_image = load_nifti_from_github(
                f"ukb/{image_id}/{image_id}_lax_2c.nii.gz"
            )
            lax_3c_image = load_nifti_from_github(
                f"ukb/{image_id}/{image_id}_lax_3c.nii.gz"
            )
            lax_4c_image = load_nifti_from_github(
                f"ukb/{image_id}/{image_id}_lax_4c.nii.gz"
            )
            sax_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_sax.nii.gz")
            fig = plot_cmr_views(
                lax_2c_image,
                lax_3c_image,
                lax_4c_image,
                sax_image,
                t_to_show=4,
                depth_to_show=slice_idx,
            )
            fig.update_layout(height=600)
            return fig

        # when the image changes, update the slice slider and the plot
        gr.on(
            fn=lambda image_id: [update_slice_slider(image_id), fn(image_id, 1)],
            inputs=[image_id],
            outputs=[slice_idx, cmr_plot],
        )
        # when the slice changes, update the plot
        slice_idx.change(
            fn=fn,
            inputs=[image_id, slice_idx],
            outputs=[cmr_plot],
        )
    return cmr_interface


def mae_inference(
    batch: dict[str, torch.Tensor],
    transform: Compose,
    model: CineMA,
    mask_ratio: float,
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray], dict[str, np.ndarray]]:
    model.to(device)
    sax_slices = batch["sax"].shape[-1]
    batch = transform(batch)
    # add a batch dimension and move tensors to the inference device
    batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
    with (
        torch.no_grad(),
        torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
    ):
        _, pred_dict, enc_mask_dict, _ = model(batch, enc_mask_ratio=mask_ratio)
    grid_size_dict = {
        k: v.patch_embed.grid_size for k, v in model.enc_down_dict.items()
    }
    reconstructed_dict, masks_dict = reconstruct_images(
        batch,
        pred_dict,
        enc_mask_dict,
        model.dec_patch_size_dict,
        grid_size_dict,
        sax_slices,
    )
    batch = {
        k: v.detach().to(torch.float32).cpu().numpy()[0, 0]
        for k, v in batch.items()
    }
    # drop the padded SAX slices
    batch["sax"] = batch["sax"][..., :sax_slices]
    return batch, reconstructed_dict, masks_dict


@spaces.GPU  # ZeroGPU: attach a GPU while this handler runs
def mae(image_id, mask_ratio, progress=gr.Progress()):
    # file path for caching the MAE reconstruction plot
    mae_path = cache_dir / f"mae_image{image_id}_mask{mask_ratio * 100:.0f}.png"

    # return the cached result if it exists
    if mae_path.exists():
        progress(1, desc="Loading cached result...")
        return str(mae_path)

    t = 4  # which time frame to use
    progress(0, desc="Downloading model...")
    model = CineMA.from_pretrained()
    model.eval()

    progress(0, desc="Downloading data...")
    lax_2c_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_lax_2c.nii.gz")
    lax_3c_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_lax_3c.nii.gz")
    lax_4c_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_lax_4c.nii.gz")
    sax_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_sax.nii.gz")
    transform = Compose(
        [
            ScaleIntensityd(keys=("sax", "lax_2c", "lax_3c", "lax_4c")),
            SpatialPadd(keys="sax", spatial_size=(192, 192, 16), method="end"),
            SpatialPadd(
                keys=("lax_2c", "lax_3c", "lax_4c"),
                spatial_size=(256, 256),
                method="end",
            ),
        ]
    )
    lax_2c_image_np = np.transpose(sitk.GetArrayFromImage(lax_2c_image))
    lax_3c_image_np = np.transpose(sitk.GetArrayFromImage(lax_3c_image))
    lax_4c_image_np = np.transpose(sitk.GetArrayFromImage(lax_4c_image))
    sax_image_np = np.transpose(sitk.GetArrayFromImage(sax_image))
    image_dict = {
        "sax": sax_image_np[None, ..., t],
        "lax_2c": lax_2c_image_np[None, ..., 0, t],
        "lax_3c": lax_3c_image_np[None, ..., 0, t],
        "lax_4c": lax_4c_image_np[None, ..., 0, t],
    }
    batch = {k: torch.from_numpy(v) for k, v in image_dict.items()}

    progress(0.5, desc="Running inference...")
    batch, reconstructed_dict, masks_dict = mae_inference(
        batch, transform, model, mask_ratio
    )

    progress(1, desc="Inference finished. Plotting ...")
    # (y, x, z) -> (x, y, z)
    batch["sax"] = np.transpose(batch["sax"], (1, 0, 2))
    reconstructed_dict["sax"] = np.transpose(reconstructed_dict["sax"], (1, 0, 2))
    masks_dict["sax"] = np.transpose(masks_dict["sax"], (1, 0, 2))

    # plot the MAE reconstruction and save it to file
    plot_mae_reconstruction(batch, reconstructed_dict, masks_dict, mae_path)
    return str(mae_path)


def mae_tab():
    with gr.Blocks() as mae_interface:
        gr.Markdown(
            """
            This page demonstrates the masking and reconstruction process. The model was trained with a mask ratio of 0.75. Click the button below to launch the inference. ⬇️
            """
        )
        run_button = gr.Button("Launch reconstruction", variant="primary")
        with gr.Row():
            with gr.Column(scale=5):
                gr.Markdown("## Reconstruction")
                plot = gr.Image(
                    show_label=False,
                    type="filepath",
                    label="Masked Autoencoder Reconstruction",
                )
            with gr.Column(scale=5):
                gr.Markdown("## Data Settings")
                image_id = gr.Slider(
                    minimum=1,
                    maximum=4,
                    step=1,
                    label="Choose an image",
                    value=2,
                )
                mask_ratio = gr.Slider(
                    minimum=0.05,
                    maximum=1,
                    step=0.05,
                    label="Mask ratio",
                    value=0.75,
                )
        run_button.click(
            fn=mae,
            inputs=[image_id, mask_ratio],
            outputs=[plot],
        )
    return mae_interface


def segmentation_sax_inference(
    images: np.ndarray,
    view: str,
    transform: Compose,
    model: ConvUNetR,
    progress: gr.Progress,
) -> np.ndarray:
    model.to(device)
    n_slices, n_frames = images.shape[-2:]
    labels_list = []
    for t in tqdm(range(n_frames), total=n_frames):
        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
        batch = transform({view: torch.from_numpy(images[None, ..., t])})
        batch = {
            k: v[None, ...].to(device=device, dtype=torch.float32)
            for k, v in batch.items()
        }
        with (
            torch.no_grad(),
            torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
        ):
            logits = model(batch)[view]
        # argmax over the class channel, then drop the padded SAX slices
        labels_list.append(torch.argmax(logits, dim=1)[0, ..., :n_slices])
    labels = torch.stack(labels_list, dim=-1).detach().to(torch.float32).cpu().numpy()
    return labels


@spaces.GPU  # ZeroGPU: attach a GPU while this handler runs
def segmentation_sax(trained_dataset, seed, image_id, t_step, progress=gr.Progress()):
    # file paths for caching the plots
    seg_path = (
        cache_dir
        / f"sax_segmentation_{trained_dataset}_image{image_id}_seed{seed}_tstep{t_step}.gif"
    )
    vol_path = (
        cache_dir
        / f"sax_volume_{trained_dataset}_image{image_id}_seed{seed}_tstep{t_step}.png"
    )

    # return the cached results if they exist
    if seg_path.exists() and vol_path.exists():
        progress(1, desc="Loading cached results...")
        return (str(seg_path), str(vol_path))

    # fixed parameters
    view = "sax"
    split = "train" if image_id <= 100 else "test"
    trained_dataset = {
        "ACDC": "acdc",
        "M&MS": "mnms",
        "M&MS2": "mnms2",
    }[str(trained_dataset)]

    # download the example image and the finetuned model
    progress(0, desc="Downloading model...")
    image_path = hf_hub_download(
        repo_id="mathpluscode/ACDC",
        repo_type="dataset",
        filename=f"{split}/patient{image_id:03d}/patient{image_id:03d}_sax_t.nii.gz",
        cache_dir=cache_dir,
    )
    model = ConvUNetR.from_finetuned(
        repo_id="mathpluscode/CineMA",
        model_filename=f"finetuned/segmentation/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
        config_filename=f"finetuned/segmentation/{trained_dataset}_{view}/config.yaml",
        cache_dir=cache_dir,
    )
    model.eval()

    # inference
    progress(0, desc="Downloading data...")
    transform = Compose(
        [
            ScaleIntensityd(keys=view),
            SpatialPadd(keys=view, spatial_size=(192, 192, 16), method="end"),
        ]
    )
    images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(image_path)))
    # subsample time frames by the chosen step
    images = images[..., ::t_step]
    labels = segmentation_sax_inference(images, view, transform, model, progress)
    # (y, x, z, t) -> (x, y, z, t)
    images = np.transpose(images, (1, 0, 2, 3))
    labels = np.transpose(labels, (1, 0, 2, 3))

    progress(1, desc="Inference finished. Plotting ...")
    # plot segmentations and volume changes, saving to the cache paths
    plot_segmentations_sax(images, labels, t_step, seg_path)
    plot_volume_changes_sax(labels, t_step, vol_path)
    return (str(seg_path), str(vol_path))


def segmentation_sax_tab():
    with gr.Blocks() as sax_interface:
        gr.Markdown(
            """
            This page demonstrates the segmentation of cardiac structures in the short-axis (SAX) view. Click the button below to launch the inference. ⬇️
            """
        )
        run_button = gr.Button("Launch segmentation inference", variant="primary")
        with gr.Row():
            with gr.Column(scale=4):
                gr.Markdown("""
                ## Description

                ### Data

                Images 101–150 are from the test set of [ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/).

                ### Model

                The available models are finetuned on different datasets ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/), [M&Ms](https://www.ub.edu/mnms/), and [M&Ms2](https://www.ub.edu/mnms-2/)). For each dataset, there are three models finetuned with seeds 0, 1, and 2.
                """)
            with gr.Column(scale=6):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Data Settings")
                        image_id = gr.Slider(
                            minimum=101,
                            maximum=150,
                            step=1,
                            label="Choose an image",
                            value=150,
                        )
                        t_step = gr.Slider(
                            minimum=1,
                            maximum=10,
                            step=1,
                            label="Choose the gap between time frames",
                            value=3,
                        )
                    with gr.Column(scale=1):
                        gr.Markdown("## Model Settings")
                        trained_dataset = gr.Dropdown(
                            choices=["ACDC", "M&MS", "M&MS2"],
                            label="Choose which dataset the model was finetuned on",
                            value="ACDC",
                        )
                        seed = gr.Slider(
                            minimum=0,
                            maximum=2,
                            step=1,
                            label="Choose which seed the finetuning used",
                            value=1,
                        )
        # visualisation description block
        gr.Markdown("""
        ## Visualisation

        The left figure shows the segmentation at every chosen time step across all SAX slices.
        The right figure shows the volumes across time frames and estimates the ejection fraction (EF) for the left ventricle (LV) and right ventricle (RV).
        """)
        with gr.Row():
            with gr.Column():
                segmentation_gif = gr.Image(
                    show_label=True,
                    type="filepath",
                    label="Ventricle and Myocardium Segmentation",
                )
            with gr.Column():
                volume_plot = gr.Image(
                    show_label=True,
                    type="filepath",
                    label="Ejection Fraction Estimation",
                )
        run_button.click(
            fn=segmentation_sax,
            inputs=[trained_dataset, seed, image_id, t_step],
            outputs=[segmentation_gif, volume_plot],
        )
    return sax_interface


def segmentation_lax_inference(
    images: np.ndarray,
    view: str,
    transform: Compose,
    model: ConvUNetR,
    progress: gr.Progress,
) -> np.ndarray:
    model.to(device)
    n_frames = images.shape[-1]
    labels_list = []
    for t in tqdm(range(n_frames), total=n_frames):
        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
        batch = transform({view: torch.from_numpy(images[None, ..., 0, t])})
        batch = {
            k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()
        }
        with (
            torch.no_grad(),
            torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
        ):
            logits = model(batch)[view]  # (1, 4, x, y)
        labels = (
            torch.argmax(logits, dim=1)[0].detach().to(torch.float32).cpu().numpy()
        )  # (x, y)
        # the model sometimes hallucinates an extra right ventricle and myocardium,
        # so keep only the connected component closest to the left ventricle
        labels = post_process_lax_segmentation(labels)
        labels_list.append(labels)
    labels = np.stack(labels_list, axis=-1)  # (x, y, t)
    return labels


@spaces.GPU  # ZeroGPU: attach a GPU while this handler runs
def segmentation_lax(seed, image_id, progress=gr.Progress()):
    # file paths for caching the plots
    seg_path = cache_dir / f"lax_segmentation_image{image_id}_seed{seed}.gif"
    vol_path = cache_dir / f"lax_volume_image{image_id}_seed{seed}.png"

    # return the cached results if they exist
    if seg_path.exists() and vol_path.exists():
        progress(1, desc="Loading cached results...")
        return (str(seg_path), str(vol_path))

    # fixed parameters
    trained_dataset = "mnms2"
    view = "lax_4c"

    # download and load the finetuned model
    progress(0, desc="Downloading model...")
    model = ConvUNetR.from_finetuned(
        repo_id="mathpluscode/CineMA",
        model_filename=f"finetuned/segmentation/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
        config_filename=f"finetuned/segmentation/{trained_dataset}_{view}/config.yaml",
        cache_dir=cache_dir,
    )
    model.eval()

    # inference
    progress(0, desc="Downloading data...")
    transform = ScaleIntensityd(keys=view)
    images = np.transpose(
        sitk.GetArrayFromImage(
            load_nifti_from_github(f"ukb/{image_id}/{image_id}_{view}.nii.gz")
        )
    )
    labels = segmentation_lax_inference(images, view, transform, model, progress)

    progress(1, desc="Inference finished. Plotting ...")
    # plot the segmentations and save them as a GIF
    plot_segmentations_lax(images, labels, seg_path)
    # plot the volume changes and save the figure
    plot_volume_changes_lax(labels, vol_path)
    return (str(seg_path), str(vol_path))


def segmentation_lax_tab():
    with gr.Blocks() as lax_interface:
        gr.Markdown(
            """
            This page demonstrates the segmentation of cardiac structures in the long-axis (LAX) four-chamber (4C) view. Click the button below to launch the inference. ⬇️
            """
        )
        run_button = gr.Button("Launch segmentation inference", variant="primary")
        with gr.Row():
            with gr.Column(scale=4):
                gr.Markdown("""
                ## Description

                ### Data

                There are four example images from the UK Biobank. The models never saw these images during supervised training.

                ### Model

                The available models are finetuned on [M&Ms2](https://www.ub.edu/mnms-2/). There are three models finetuned with seeds 0, 1, and 2.
                """)
            with gr.Column(scale=6):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Data Settings")
                        image_id = gr.Slider(
                            minimum=1,
                            maximum=4,
                            step=1,
                            label="Choose an image",
                            value=2,
                        )
                    with gr.Column(scale=1):
                        gr.Markdown("## Model Settings")
                        seed = gr.Slider(
                            minimum=0,
                            maximum=2,
                            step=1,
                            label="Choose which seed the finetuning used",
                            value=1,
                        )
        # visualisation description block
        gr.Markdown("""
        ## Visualisation

        The left figure shows the segmentation across time frames.
        The right figure shows the volumes across time frames and estimates the ejection fraction (EF).
        """)
        with gr.Row():
            with gr.Column():
                segmentation_gif = gr.Image(
                    show_label=True,
                    type="filepath",
                    label="Ventricle and Myocardium Segmentation",
                )
            with gr.Column():
                volume_plot = gr.Image(
                    show_label=True,
                    type="filepath",
                    label="Ejection Fraction Estimation",
                )
        run_button.click(
            fn=segmentation_lax,
            inputs=[seed, image_id],
            outputs=[segmentation_gif, volume_plot],
        )
    return lax_interface


def landmark_heatmap_inference(
    images: np.ndarray,
    view: str,
    transform: Compose,
    model: ConvUNetR,
    progress: gr.Progress,
) -> tuple[np.ndarray, np.ndarray]:
    model.to(device)
    n_frames = images.shape[-1]
    probs_list = []
    coords_list = []
    for t in tqdm(range(n_frames), total=n_frames):
        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
        batch = transform({view: torch.from_numpy(images[None, ..., 0, t])})
        batch = {
            k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()
        }
        with (
            torch.no_grad(),
            torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
        ):
            logits = model(batch)[view]  # (1, 3, x, y), one heatmap per landmark
        probs = torch.sigmoid(logits)  # (1, 3, x, y)
        probs_list.append(probs[0].detach().to(torch.float32).cpu().numpy())
        # soft-argmax turns each heatmap into an (x, y) coordinate pair
        coords = heatmap_soft_argmax(probs)[0].detach().to(torch.float32).cpu().numpy()
        coords = [int(x) for x in coords]
        coords_list.append(coords)
    probs = np.stack(probs_list, axis=-1)  # (3, x, y, t)
    coords = np.stack(coords_list, axis=-1)  # (6, t)
    return probs, coords


def landmark_coordinate_inference(
    images: np.ndarray,
    view: str,
    transform: Compose,
    model: ConvViT,
    progress: gr.Progress,
) -> np.ndarray:
    model.to(device)
    w, h, _, n_frames = images.shape
    coords_list = []
    for t in tqdm(range(n_frames), total=n_frames):
        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
        batch = transform({view: torch.from_numpy(images[None, ..., 0, t])})
        batch = {
            k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()
        }
        with (
            torch.no_grad(),
            torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
        ):
            coords = model(batch)[0].detach().to(torch.float32).cpu().numpy()  # (6,)
        # scale normalised (x, y) pairs back to pixel coordinates
        coords *= np.array([w, h, w, h, w, h])
        coords = [int(x) for x in coords]
        coords_list.append(coords)
    coords = np.stack(coords_list, axis=-1)  # (6, t)
    return coords


@spaces.GPU  # ZeroGPU: attach a GPU while this handler runs
def landmark(image_id, view, method, seed, progress=gr.Progress()):
    view = "lax_2c" if view == "LAX 2C" else "lax_4c"
    method = method.lower()

    # file paths for caching the plots
    landmark_path = (
        cache_dir / f"landmark_{view}_image{image_id}_{method}_seed{seed}.gif"
    )
    lv_path = cache_dir / f"lv_{view}_image{image_id}_{method}_seed{seed}.png"

    # return the cached results if they exist
    if landmark_path.exists() and lv_path.exists():
        progress(1, desc="Loading cached results...")
        return (str(landmark_path), str(lv_path))

    # download and load the finetuned model
    progress(0, desc="Downloading model...")
    if method == "heatmap":
        model = ConvUNetR.from_finetuned(
            repo_id="mathpluscode/CineMA",
            model_filename=f"finetuned/landmark_{method}/{view}/{view}_{seed}.safetensors",
            config_filename=f"finetuned/landmark_{method}/{view}/config.yaml",
            cache_dir=cache_dir,
        )
    elif method == "coordinate":
        model = ConvViT.from_finetuned(
            repo_id="mathpluscode/CineMA",
            model_filename=f"finetuned/landmark_{method}/{view}/{view}_{seed}.safetensors",
            config_filename=f"finetuned/landmark_{method}/{view}/config.yaml",
            cache_dir=cache_dir,
        )
    else:
        raise ValueError(f"Invalid method: {method}")
    model.eval()

    # inference
    progress(0, desc="Downloading data...")
    transform = ScaleIntensityd(keys=view)
    images = np.transpose(
        sitk.GetArrayFromImage(
            load_nifti_from_github(f"ukb/{image_id}/{image_id}_{view}.nii.gz")
        )
    )
    if method == "heatmap":
        probs, coords = landmark_heatmap_inference(
            images, view, transform, model, progress
        )
        progress(1, desc="Inference finished. Plotting ...")
        plot_heatmap_and_landmarks(images, probs, coords, landmark_path)
    elif method == "coordinate":
        coords = landmark_coordinate_inference(images, view, transform, model, progress)
        progress(1, desc="Inference finished. Plotting ...")
        plot_landmarks(images, coords, landmark_path)
    else:
        raise ValueError(f"Invalid method: {method}")

    # plot the left-ventricle length changes as a PNG
    plot_lv(coords, lv_path)
    return (str(landmark_path), str(lv_path))


def landmark_tab():
    with gr.Blocks() as landmark_interface:
        gr.Markdown(
            """
            This page demonstrates landmark localisation in the long-axis (LAX) two-chamber (2C) and four-chamber (4C) views. Click the button below to launch the inference. ⬇️
            """
        )
        run_button = gr.Button(
            "Launch landmark localisation inference", variant="primary"
        )
        with gr.Row():
            with gr.Column(scale=4):
                gr.Markdown("""
                ## Description

                ### Data

                There are four example images from the UK Biobank. The models never saw these images during supervised training.

                ### Model

                The available models are finetuned on data from [Xue et al.](https://pubs.rsna.org/doi/10.1148/ryai.2021200197).
                There are two types of landmark localisation models:

                - **Heatmap**: predicts dense probability maps of landmarks (more accurate)
                - **Coordinate**: predicts landmark coordinates directly

                For each type, there are three models finetuned with seeds 0, 1, and 2.
                """)
            with gr.Column(scale=6):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Data Settings")
                        image_id = gr.Slider(
                            minimum=1,
                            maximum=4,
                            step=1,
                            label="Choose an image",
                            value=2,
                        )
                        view = gr.Dropdown(
                            choices=["LAX 2C", "LAX 4C"],
                            label="Choose which view to localise the landmarks",
                            value="LAX 2C",
                        )
                    with gr.Column(scale=1):
                        gr.Markdown("## Model Settings")
                        method = gr.Dropdown(
                            choices=["Heatmap", "Coordinate"],
                            label="Choose which method to use",
                            value="Heatmap",
                        )
                        seed = gr.Slider(
                            minimum=0,
                            maximum=2,
                            step=1,
                            label="Choose which seed the finetuning used",
                            value=1,
                        )
        # visualisation description block
        gr.Markdown("""
        ## Visualisation

        The left figure shows the landmark positions across time frames.
        The right figure shows the length of the left ventricle across time frames and estimates mitral annular plane systolic excursion (MAPSE) and global longitudinal shortening (GLS).
        """)
        with gr.Row():
            with gr.Column():
                landmark_gif = gr.Image(
                    show_label=True,
                    type="filepath",
                    label="Landmark Localisation",
                )
            with gr.Column():
                lv_plot = gr.Image(
                    show_label=True,
                    type="filepath",
                    label="Left Ventricle Length Estimation",
                )
        run_button.click(
            fn=landmark,
            inputs=[image_id, view, method, seed],
            outputs=[landmark_gif, lv_plot],
        )
    return landmark_interface


# assemble the full demo
with gr.Blocks(
    theme=theme, title="CineMA: A Foundation Model for Cine Cardiac MRI"
) as demo:
    gr.Markdown(
        """
        # CineMA: A Foundation Model for Cine Cardiac MRI 🎥🫀

        👇 The following demonstrations showcase the capabilities of CineMA in multiple tasks. Click the button to launch the inference.<br>
        ⏱️ The examples may take 10-60 seconds, if not cached, to download data and models, perform inference, and render plots.<br>
        📖 For more details, check out our [manuscript](https://arxiv.org/abs/2506.00679) and [GitHub repository](https://github.com/mathpluscode/CineMA).
        """
    )
    with gr.Tabs(selected="sax_seg") as tabs:
        with gr.TabItem("🖼️ Cine CMR Views", id="cmr"):
            cmr_tab()
        with gr.TabItem("🧩 Masked Autoencoder", id="mae"):
            mae_tab()
        with gr.TabItem("✂️ Segmentation in SAX View", id="sax_seg"):
            segmentation_sax_tab()
        with gr.TabItem("✂️ Segmentation in LAX 4C View", id="lax_seg"):
            segmentation_lax_tab()
        with gr.TabItem("📍 Landmark Localisation in LAX 2C/4C View", id="landmark"):
            landmark_tab()

demo.launch(allowed_paths=[str(cache_dir)])