"""CineMA app.py: Gradio demo for cine cardiac MRI segmentation inference."""
from pathlib import Path

import gradio as gr
import numpy as np
import SimpleITK as sitk  # noqa: N813
import spaces
import torch
from huggingface_hub import hf_hub_download
from monai.transforms import Compose, ScaleIntensityd, SpatialPadd

from cinema import ConvUNetR
from examples.inference.segmentation_sax import plot_segmentations, plot_volume_changes
# cache directories
cache_dir = Path("/tmp/.cinema")
cache_dir.mkdir(parents=True, exist_ok=True)


@spaces.GPU
def inference(
    images: np.ndarray,
    view: str,
    transform: Compose,
    model: ConvUNetR,
    progress=gr.Progress(),
) -> np.ndarray:
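    """Segment the ventricles and myocardium in every time frame and stack the label maps."""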
    # set device and dtype
    dtype, device = torch.float32, torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")
        if torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16

    # inference
    model.to(device)
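    # images arrives transposed from SimpleITK order, with shape (x, y, z, t)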
    n_slices, n_frames = images.shape[-2:]
    labels_list = []
    for t in range(n_frames):
        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
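        # the MONAI transform scales intensities and pads each frame to (192, 192, 16)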
        batch = transform({view: torch.from_numpy(images[None, ..., t])})
        batch = {
            k: v[None, ...].to(device=device, dtype=torch.float32)
            for k, v in batch.items()
        }
        with (
            torch.no_grad(),
            torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()),
        ):
            logits = model(batch)[view]
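        # take the argmax over classes and drop the padded slices beyond n_slices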
        labels_list.append(torch.argmax(logits, dim=1)[0, ..., :n_slices])
    labels = torch.stack(labels_list, dim=-1).detach().cpu().numpy()
    return labels


def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress()):
    # Fixed parameters
    view = "sax"
    split = "train" if image_id <= 100 else "test"
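    # map dropdown display names to folder names in the model repository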
    trained_dataset = {
        "ACDC": "acdc",
        "M&Ms": "mnms",
        "M&Ms2": "mnms2",
    }[trained_dataset]

    # Download and load model
    progress(0, desc="Downloading model and data...")
    image_path = hf_hub_download(
        repo_id="mathpluscode/ACDC",
        repo_type="dataset",
        filename=f"{split}/patient{image_id:03d}/patient{image_id:03d}_sax_t.nii.gz",
        cache_dir=cache_dir,
    )
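    # load the ConvUNetR weights finetuned for SAX segmentation with the chosen dataset and seed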
    model = ConvUNetR.from_finetuned(
        repo_id="mathpluscode/CineMA",
        model_filename=f"finetuned/segmentation/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
        config_filename=f"finetuned/segmentation/{trained_dataset}_{view}/config.yaml",
        cache_dir=cache_dir,
    )

    # Load and process data
    transform = Compose(
        [
            ScaleIntensityd(keys=view),
            SpatialPadd(keys=view, spatial_size=(192, 192, 16), method="end"),
        ]
    )
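    # SimpleITK arrays are (t, z, y, x); transpose to (x, y, z, t) and keep every t_step-th frame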
    images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(image_path)))
    images = images[..., ::t_step]
    labels = inference(images, view, transform, model, progress)

    progress(1, desc="Plotting results...")
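    # fig1 shows the predicted segmentations; fig2 plots the volume curves used for ejection fraction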
    fig1 = plot_segmentations(images, labels, t_step)
    fig2 = plot_volume_changes(labels, t_step)
    return fig1, fig2


# Create the Gradio interface
theme = gr.themes.Ocean(
    primary_hue="red",
    secondary_hue="purple",
)
with gr.Blocks(
    theme=theme, title="CineMA: A Foundation Model for Cine Cardiac MRI"
) as demo:
    gr.Markdown(
        """
        # CineMA: A Foundation Model for Cine Cardiac MRI 🎥🫀

        Below is an example of segmentation-based ejection fraction prediction. For more examples, check out our [GitHub](https://github.com/mathpluscode/CineMA).
        """
    )
    with gr.Row():
        with gr.Column(scale=4):
            gr.Markdown("## Description")
            gr.Markdown("""
            Please adjust the settings in the panels on the right, then click the button to run the inference.

            ### Data

            The available data is from ACDC. All images have been resampled to 1 mm × 1 mm × 10 mm and centre-cropped to 192 mm × 192 mm for each SAX slice.
            Images 1–100 are from the training set, and images 101–150 are from the test set.

            ### Model

            The available models were finetuned on different datasets ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/), [M&Ms](https://www.ub.edu/mnms/), and [M&Ms2](https://www.ub.edu/mnms-2/)). For each dataset, there are three models finetuned with different seeds: 0, 1, and 2. The default model is the one finetuned on the ACDC dataset with seed 0.

            ### Visualization

            The left panel shows the segmentation of the ventricles and myocardium every n time steps across all SAX slices.
            The right panel plots the ventricle and myocardium volumes across all inference time frames.
            """)
        with gr.Column(scale=3):
            gr.Markdown("## Data Settings")
            image_id = gr.Slider(
                minimum=1,
                maximum=150,
                step=1,
                label="Choose an ACDC image (ID between 1 and 150)",
                value=150,
            )
            t_step = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                label="Choose the gap between time frames",
                value=2,
            )
        with gr.Column(scale=3):
            gr.Markdown("## Model Settings")
            trained_dataset = gr.Dropdown(
                choices=["ACDC", "M&Ms", "M&Ms2"],
                label="Choose which dataset the segmentation model was finetuned on",
                value="ACDC",
            )
            seed = gr.Slider(
                minimum=0,
                maximum=2,
                step=1,
                label="Choose which seed the finetuning used",
                value=0,
            )
    run_button = gr.Button("Run segmentation inference", variant="primary")
    with gr.Row():
        segmentation_plot = gr.Plot(label="Ventricle and Myocardium Segmentation")
        volume_plot = gr.Plot(label="Ejection Fraction Prediction")
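    # clicking the button runs inference on the selected image and fills both plots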
    run_button.click(
        fn=run_inference,
        inputs=[trained_dataset, seed, image_id, t_step],
        outputs=[segmentation_plot, volume_plot],
    )

demo.launch()