# HE-to-IHC / app.py
# antoinedelplace
# Change png to jpg image file
# 968d5c4
import gradio as gr
import numpy as np
import torch
from types import SimpleNamespace
from PIL import Image
from asp.models.cpt_model import CPTModel
from asp.util.general_utils import parse_args
from asp.util.io_utils import download_weights
def preprocess_image(img):
    """Convert an image into a normalized, batched float tensor.

    Grayscale input (2-D, or 3-D with a single channel) is replicated to
    three channels, and a trailing alpha channel (2- or 4-channel input) is
    discarded, so that 1-, 2-, 3- and 4-channel images all yield three
    channels. Arrays with any other channel count are passed through with
    their channels unchanged.

    Args:
        img: Anything ``np.array`` can turn into an ``[H, W]`` or
            ``[H, W, C]`` array with values in ``[0, 255]`` (for example a
            PIL image).

    Returns:
        A ``torch.float32`` tensor of shape ``[1, C, H, W]`` with values
        scaled from ``[0, 255]`` to ``[-1.0, 1.0]``.
    """
    pixels = np.array(img)
    if pixels.ndim == 2:
        # Give grayscale input an explicit channel axis so that every case
        # below can be handled as [H, W, C].
        pixels = pixels[:, :, np.newaxis]

    channel_count = pixels.shape[2]
    if channel_count in (2, 4):
        # The last channel is treated as alpha (gray+alpha or RGBA) and dropped.
        pixels = pixels[:, :, :channel_count - 1]
    if pixels.shape[2] == 1:
        # Replicate the single luminance channel into three identical ones.
        pixels = np.repeat(pixels, 3, axis=2)

    # [H, W, C] -> [C, H, W], then map [0, 255] onto [-1.0, 1.0].
    pixels = np.transpose(pixels, (2, 0, 1)).astype(np.float32)
    pixels = pixels / 255.0 * 2.0 - 1.0

    # Add the leading batch dimension: [1, C, H, W].
    return torch.from_numpy(pixels).unsqueeze(0)
def postprocess_tensor(tensor):
    """Turn a batched ``[1, C, H, W]`` tensor in ``[-1, 1]`` into a PIL image.

    The leading batch dimension is squeezed away, values are clamped to
    ``[-1.0, 1.0]``, rescaled to ``[0, 255]`` and cast to ``uint8`` (the cast
    truncates the fractional part) before being wrapped in a PIL image.
    """
    # Detach from the autograd graph and move to host memory before NumPy.
    chw = tensor.squeeze(0).detach().cpu()
    chw = chw.clamp(-1.0, 1.0).float().numpy()

    # [C, H, W] -> [H, W, C], then map [-1, 1] onto [0, 255].
    hwc = np.transpose(chw, (1, 2, 0))
    scaled = (hwc + 1) / 2.0 * 255.0

    return Image.fromarray(scaled.astype(np.uint8))
def convert_he2ihc(output_stain, input_he_image_path):
    """Translate an H&E image file into an IHC-stained image.

    Args:
        output_stain: Label of the target stain; must be one of the keys of
            ``stain2folder_name`` below.
        input_he_image_path: Path of the H&E image file to translate.

    Returns:
        A PIL image of the generated stain, resized back to the size of the
        input image.

    Raises:
        gr.Error: If no known stain is selected or no input image is given,
            so that the UI shows a readable message instead of a bare
            ``KeyError`` or a failure inside ``Image.open``.
    """
    stain2folder_name = {
        "HER2 (Human Epidermal growth factor Receptor 2)": "ASP_pretrained/MIST_her2_lambda_linear",
        "ER (Estrogen Receptor)": "ASP_pretrained/MIST_er_lambda_linear",
        "Ki67 (Antigen KI-67)": "ASP_pretrained/MIST_ki67_lambda_linear",
        "PR (Progesterone Receptor)": "ASP_pretrained/MIST_pr_lambda_linear",
    }

    # Validate the inputs up front: the dropdown has no default value, so the
    # stain can arrive as None, and the image is None until one is provided.
    if output_stain not in stain2folder_name:
        raise gr.Error("Please select an output stain.")
    if not input_he_image_path:
        raise gr.Error("Please provide an input H&E image.")

    # Close the source file deterministically; convert() returns a new,
    # fully loaded image that outlives the file handle.
    with Image.open(input_he_image_path) as source_img:
        input_img = source_img.convert('RGB')
    original_img_size = input_img.size

    # Single source of truth for the square size fed to the network.
    model_img_size = 1024

    opt = SimpleNamespace(
        gpu_ids=None,
        isTrain=False,
        checkpoints_dir="../../checkpoints",
        name=stain2folder_name[output_stain],
        preprocess="scale_width_and_crop",
        nce_layers="0,4,8,12,16",
        nce_idt=False,
        input_nc=3,
        output_nc=3,
        ngf=64,
        netG="resnet_6blocks",
        normG="instance",
        no_dropout=True,
        init_type="xavier",
        init_gain=0.02,
        no_antialias=False,
        no_antialias_up=False,
        weight_norm="spectral",
        netF="mlp_sample",
        netF_nc=256,
        no_flip=True,
        load_size=model_img_size,
        crop_size=model_img_size,
        direction="AtoB",
        flip_equivariance=False,
        epoch="latest",
        verbose=True,
    )

    # NOTE(review): the model is rebuilt and set up on every call; caching it
    # per stain would avoid repeated loading - confirm CPTModel can be reused
    # across requests before doing so.
    model = CPTModel(opt)
    model.setup(opt)
    model.parallelize()
    model.eval()

    input_img = input_img.resize((model_img_size, model_img_size))
    input_tensor = preprocess_image(input_img)

    # "B" is fed the same tensor as "A"; presumably set_input expects both
    # domains even at inference time - TODO confirm.
    model.set_input({
        "A": input_tensor,
        "A_paths": input_he_image_path,
        "B": input_tensor,
        "B_paths": input_he_image_path,
    })
    model.test()

    visuals = model.get_current_visuals()
    output_img = postprocess_tensor(visuals['fake_B'])

    # Restore the original resolution of the uploaded image.
    output_img = output_img.resize(original_img_size)
    print("np.shape(output_img)", np.shape(output_img))

    return output_img
def main():
    """Download the four pretrained generator weights, then build and launch the Gradio demo."""
    # (identifier, destination path) pairs handed to download_weights, one per
    # supported stain. The identifiers look like remote file ids - TODO confirm.
    weight_files = (
        ("1N_HOGU7FO4u-S1OD-bumZGyevYeucT4Q", "../../checkpoints/ASP_pretrained/MIST_her2_lambda_linear/latest_net_G.pth"),
        ("1j6xu8MAOVUaZuV4O5CqsBfMtH6-droys", "../../checkpoints/ASP_pretrained/MIST_er_lambda_linear/latest_net_G.pth"),
        ("10STHMS-GMkHMOJp_cJ44T66rRwKlUZyr", "../../checkpoints/ASP_pretrained/MIST_ki67_lambda_linear/latest_net_G.pth"),
        ("1APIrm3kqtPhhAIcU7pvfIcYpMjpsIlQ9", "../../checkpoints/ASP_pretrained/MIST_pr_lambda_linear/latest_net_G.pth"),
    )
    for weight_id, target_path in weight_files:
        download_weights(weight_id, target_path)

    # Labels offered in the dropdown; convert_he2ihc looks the selection up by
    # this exact text.
    stain_choices = [
        "HER2 (Human Epidermal growth factor Receptor 2)",
        "ER (Estrogen Receptor)",
        "Ki67 (Antigen KI-67)",
        "PR (Progesterone Receptor)",
    ]

    with gr.Blocks() as demo:
        stain_selector = gr.Dropdown(choices=stain_choices, label="Output Stain")
        he_input = gr.Image(type="filepath", label="Input H&E Image")
        ihc_output = gr.Image(label="Output IHC Image")

        gr.Interface(
            fn=convert_he2ihc,
            inputs=[stain_selector, he_input],
            outputs=ihc_output,
            title="H&E-to-IHC Stain Translation",
            description="<h2>Stain your H&E (Hematoxylin and Eosin) images into IHC (ImmunoHistoChemistry) images automatically thanks to AI!</h2>",
            theme="ParityError/Interstellar",
        )

        # One example row that fills both the input and the output image.
        example_rows = [["assets/he.jpg", "assets/ihc.jpg"]]
        gr.Examples(
            examples=example_rows,
            inputs=[he_input, ihc_output],
            examples_per_page=1,
        )

    demo.launch()
if __name__ == "__main__":
    # Usage: python app.py
    # parse_args is handed main itself and its result is expanded into
    # keyword arguments for main.
    main(**vars(parse_args(main)))