# Hugging Face Space status: Running on Zero (ZeroGPU)
import numpy as np | |
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
import SimpleITK as sitk # noqa: N813 | |
import torch | |
from monai.transforms import Compose, ScaleIntensityd, SpatialPadd | |
from cinema import ConvUNetR | |
from pathlib import Path | |
from examples.inference.segmentation_sax import plot_segmentations, plot_volume_changes | |
import spaces | |
# Local cache directory handed to hf_hub_download and ConvUNetR.from_finetuned
# for downloaded files; created up front, including any missing parents.
cache_dir = Path("/tmp") / ".cinema"
cache_dir.mkdir(exist_ok=True, parents=True)
def inferece(
    images: np.ndarray,
    view: str,
    transform: Compose,
    model: ConvUNetR,
    progress=gr.Progress(),
) -> np.ndarray:
    """Segment every time frame of ``images`` with ``model`` and stack the label maps.

    The (misspelled) name is kept unchanged so existing callers keep working.

    Args:
        images: NumPy array with time frames on the last axis and slices on the
            second-to-last axis. Each frame ``images[..., t]`` is given a leading
            channel axis and converted with ``torch.from_numpy``.
        view: Key under which a frame is passed to ``transform`` and ``model``, and
            under which the logits are read back from the model output.
        transform: Dict-based transform applied to ``{view: frame}``. Cropping below
            assumes any padding it does is appended at the end of each axis
            (as with ``SpatialPadd(method="end")``).
        model: Network called with a dict of batched tensors; it returns a dict of
            logits keyed by ``view``, with classes on axis 1.
        progress: Progress callable, invoked once per frame.

    Returns:
        Integer label array with frames stacked on the last axis. Each frame is
        cropped back to the spatial extent of ``images`` so labels align with the input.
    """
    # Prefer CUDA when present, and run autocast in bf16 where the GPU supports it.
    dtype, device = torch.float32, torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")
        if torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16

    n_frames = images.shape[-1]
    if n_frames == 0:
        # torch.stack rejects an empty list; return an empty label volume instead.
        return np.zeros(images.shape, dtype=np.int64)

    model.to(device)
    # Inference only: disable training-mode behaviour (dropout, norm statistics updates).
    model.eval()

    # Index that drops the batch axis and undoes end-padding on every spatial axis,
    # so each predicted frame has the same spatial shape as ``images[..., t]``.
    crop = (0, *(slice(0, size) for size in images.shape[:-1]))

    frame_labels = []
    for t in range(n_frames):
        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
        batch = transform({view: torch.from_numpy(images[None, ..., t])})
        # Add a batch axis; inputs stay float32 and autocast handles the reduced precision.
        batch = {k: v[None, ...].to(device=device, dtype=torch.float32) for k, v in batch.items()}
        with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=device.type == "cuda"):
            logits = model(batch)[view]
        frame_labels.append(torch.argmax(logits, dim=1)[crop])

    return torch.stack(frame_labels, dim=-1).detach().cpu().numpy()
# ZeroGPU only attaches a GPU while a @spaces.GPU-decorated function runs;
# outside ZeroGPU environments the decorator has no effect.
@spaces.GPU
def run_inference(trained_dataset, seed, image_id, t_step, progress=gr.Progress()):
    """Segment one ACDC short-axis cine image with a finetuned CineMA model and plot it.

    Args:
        trained_dataset: Dropdown label of the finetuning dataset: "ACDC", "M&MS" or "M&MS2".
        seed: Finetuning seed that selects the checkpoint file.
        image_id: ACDC patient ID. IDs up to 100 are read from the "train" split and
            larger IDs from the "test" split.
        t_step: Keep every ``t_step``-th time frame.
        progress: Gradio progress tracker, also forwarded to ``inferece``.

    Returns:
        ``(segmentation_figure, volume_figure)`` from ``plot_segmentations`` and
        ``plot_volume_changes``.

    Raises:
        KeyError: If ``trained_dataset`` is not one of the known dropdown labels.
    """
    # Slider values may arrive as floats (e.g. 3.0). The ``:03d`` format spec, the
    # checkpoint filename and the frame-slice step below all require ints.
    seed, image_id, t_step = int(seed), int(image_id), int(t_step)

    # Fixed parameters
    view = "sax"
    split = "train" if image_id <= 100 else "test"
    trained_dataset = {
        "ACDC": "acdc",
        "M&MS": "mnms",
        "M&MS2": "mnms2",
    }[str(trained_dataset)]

    # Download data and load the finetuned model
    progress(0, desc="Downloading model and data...")
    image_path = hf_hub_download(
        repo_id="mathpluscode/ACDC",
        repo_type="dataset",
        filename=f"{split}/patient{image_id:03d}/patient{image_id:03d}_sax_t.nii.gz",
        cache_dir=cache_dir,
    )
    model_dir = f"finetuned/segmentation/{trained_dataset}_{view}"
    model = ConvUNetR.from_finetuned(
        repo_id="mathpluscode/CineMA",
        model_filename=f"{model_dir}/{trained_dataset}_{view}_{seed}.safetensors",
        config_filename=f"{model_dir}/config.yaml",
        cache_dir=cache_dir,
    )

    # Intensity scaling, then pad each frame to the network input size
    transform = Compose(
        [
            ScaleIntensityd(keys=view),
            SpatialPadd(keys=view, spatial_size=(192, 192, 16), method="end"),
        ]
    )
    # SimpleITK arrays use reversed axis order; transposing puts slices and time
    # frames on the last two axes, as ``inferece`` expects.
    images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(image_path)))
    images = images[..., ::t_step]
    labels = inferece(images, view, transform, model, progress)

    progress(1, desc="Plotting results...")
    fig1 = plot_segmentations(images, labels, t_step)
    fig2 = plot_volume_changes(labels, t_step)
    return fig1, fig2
# Create the Gradio interface
theme = gr.themes.Ocean(
    primary_hue="red",
    secondary_hue="purple",
)
with gr.Blocks(
    theme=theme, title="CineMA: A Foundation Model for Cine Cardiac MRI"
) as demo:
    gr.Markdown(
        """
# CineMA: A Foundation Model for Cine Cardiac MRI 🎥🫀
Below is an example of ventricle and myocardium segmentation inference. For more examples, check out our [GitHub](https://github.com/mathpluscode/CineMA).
"""
    )
    with gr.Row():
        # Gradio expects integer column scales (relative widths); 4:3:3 keeps the
        # original 0.4:0.3:0.3 proportions.
        with gr.Column(scale=4):
            gr.Markdown("## Description")
            gr.Markdown(
                """
Please adjust the settings in the panels on the right and click the button to run the inference.
### Data
The available data is from ACDC. All images have been resampled to 1 mm × 1 mm × 10 mm and centre-cropped to 192 mm × 192 mm for each SAX slice.
Images 1–100 are from the training set, and images 101–150 are from the test set.
### Model
The available models are finetuned on different datasets ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/), [M&Ms](https://www.ub.edu/mnms/), and [M&Ms2](https://www.ub.edu/mnms-2/)). For each dataset, there are 3 models finetuned on different seeds: 0, 1, 2. The default model is the one finetuned on the ACDC dataset with seed 0.
### Visualization
The left panel shows the segmentation of ventricles and myocardium every n time steps across all SAX slices.
The right panel plots the ventricle and myocardium volumes across all inference time frames.
"""
            )
        with gr.Column(scale=3):
            gr.Markdown("## Data Settings")
            image_id = gr.Slider(
                minimum=1,
                maximum=150,
                step=1,
                label="Choose an ACDC image, ID is between 1 and 150",
                value=150,
            )
            t_step = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                label="Choose the gap between time frames",
                value=2,
            )
        with gr.Column(scale=3):
            gr.Markdown("## Model Setting")
            trained_dataset = gr.Dropdown(
                choices=["ACDC", "M&MS", "M&MS2"],
                label="Choose which dataset the segmentation model was finetuned on",
                value="ACDC",
            )
            seed = gr.Slider(
                minimum=0,
                maximum=2,
                step=1,
                label="Choose which seed the finetuning used",
                value=0,
            )
    run_button = gr.Button("Run segmentation inference", variant="primary")
    with gr.Row():
        segmentation_plot = gr.Plot(label="Ventricle and Myocardium Segmentation")
        volume_plot = gr.Plot(label="Ventricle and Myocardium Volumes")
    # Inputs are passed positionally and must match run_inference's parameter order.
    run_button.click(
        fn=run_inference,
        inputs=[trained_dataset, seed, image_id, t_step],
        outputs=[segmentation_plot, volume_plot],
    )
demo.launch()