File size: 4,954 Bytes
207ef6f
 
 
 
 
 
 
 
 
 
6dadf7f
207ef6f
fb5bc28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207ef6f
3f587ea
 
 
 
 
 
 
 
207ef6f
a8adc67
207ef6f
 
5accddc
207ef6f
 
3f587ea
a8adc67
207ef6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb5bc28
 
 
207ef6f
fb5bc28
207ef6f
fb5bc28
207ef6f
 
 
 
 
fb5bc28
 
a8adc67
207ef6f
 
 
 
 
fb5bc28
3f587ea
 
 
fb5bc28
67a8194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
968d5c4
67a8194
 
 
 
 
207ef6f
 
 
 
 
 
 
6bb9743
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import functools
import threading
from types import SimpleNamespace

import gradio as gr
import numpy as np
import torch
from PIL import Image

from asp.models.cpt_model import CPTModel
from asp.util.general_utils import parse_args
from asp.util.io_utils import download_weights

def preprocess_image(img):
    """Convert an image into a normalized ``[1, 3, H, W]`` float32 tensor.

    Args:
        img: A PIL image or any array-like accepted by ``np.array``, shaped
            ``[H, W]`` or ``[H, W, C]``. Values are expected in ``[0, 255]``
            (the scaling below assumes 8-bit data). Grayscale (2-D or
            ``C == 1``) and grayscale+alpha (``C == 2``) inputs have their
            first channel replicated to 3 channels; RGBA (``C == 4``) inputs
            have their alpha channel dropped.

    Returns:
        A ``torch.float32`` tensor of shape ``[1, 3, H, W]`` with values
        mapped linearly from ``[0, 255]`` to ``[-1.0, 1.0]``.

    Raises:
        ValueError: If the array is not 2-D/3-D, or has a channel count
            other than 1, 2, 3 or 4.
    """
    img_array = np.array(img)  # [H, W] or [H, W, C]

    if img_array.ndim == 2:
        img_array = img_array[:, :, np.newaxis]  # [H, W, 1]
    elif img_array.ndim != 3:
        raise ValueError(
            f"Expected a 2-D or 3-D image array, got shape {img_array.shape}"
        )

    num_channels = img_array.shape[2]
    if num_channels in (1, 2):
        # Grayscale, optionally with alpha: replicate the luminance channel
        # so the result always has the 3 channels the generator expects.
        img_array = np.repeat(img_array[:, :, :1], 3, axis=2)
    elif num_channels == 4:
        img_array = img_array[:, :, :3]  # RGBA: discard the alpha channel
    elif num_channels != 3:
        raise ValueError(f"Unsupported number of channels: {num_channels}")

    img_array = np.transpose(img_array, (2, 0, 1)).astype(np.float32)  # [3, H, W]
    img_array = img_array / 255.0 * 2.0 - 1.0  # scale to [-1.0, 1.0]

    return torch.from_numpy(img_array).unsqueeze(0)  # [1, 3, H, W]

def postprocess_tensor(tensor):
    """Convert a generator output tensor back into a PIL image.

    This inverts the scaling done by ``preprocess_image``: values are clamped
    to ``[-1, 1]`` and mapped linearly back to ``[0, 255]``.

    Args:
        tensor: Image tensor of shape ``[1, C, H, W]`` (any device/dtype).

    Returns:
        A ``PIL.Image.Image`` built from the ``uint8`` pixels (RGB for
        ``C == 3``, single-channel for ``C == 1``).
    """
    chw = tensor.squeeze(0).detach().cpu()
    chw = chw.clamp(-1.0, 1.0).float().numpy()
    hwc = (np.transpose(chw, (1, 2, 0)) + 1.0) / 2.0 * 255.0

    # Round instead of truncating: float32 round-off in the [-1, 1] mapping
    # can land a hair below an integer (e.g. 0.99999994 for input level 1),
    # which a bare astype() would floor to the wrong 8-bit level.
    pixels = np.rint(hwc).astype(np.uint8)

    if pixels.shape[2] == 1:
        # PIL cannot build an image from an [H, W, 1] array; use 2-D grayscale.
        pixels = pixels[:, :, 0]

    return Image.fromarray(pixels)

# Dropdown label -> checkpoint sub-folder (relative to the checkpoints dir)
# holding that stain's pretrained generator.
_STAIN2FOLDER_NAME = {
    "HER2 (Human Epidermal growth factor Receptor 2)": "ASP_pretrained/MIST_her2_lambda_linear",
    "ER (Estrogen Receptor)"                         : "ASP_pretrained/MIST_er_lambda_linear",
    "Ki67 (Antigen KI-67)"                           : "ASP_pretrained/MIST_ki67_lambda_linear",
    "PR (Progesterone Receptor)"                     : "ASP_pretrained/MIST_pr_lambda_linear",
}

# Cached models are shared between requests; this lock serializes the
# set_input -> test -> get_current_visuals sequence so concurrent requests
# cannot read each other's results from the shared model object.
_INFERENCE_LOCK = threading.Lock()


@functools.lru_cache(maxsize=len(_STAIN2FOLDER_NAME))
def _load_model(folder_name):
    """Build and set up the CPT model for one checkpoint folder.

    Cached per folder, so ``CPTModel`` construction and ``setup`` (which,
    judging by ``checkpoints_dir``/``epoch``, loads the saved weights) run
    once per stain per process instead of on every request.

    Args:
        folder_name: Checkpoint sub-folder under ``checkpoints_dir``.

    Returns:
        ``(model, opt)``: the model in eval mode and the options it was
        built with.
    """
    opt = SimpleNamespace(
        gpu_ids=None,
        isTrain=False,
        checkpoints_dir="../../checkpoints",
        name=folder_name,
        preprocess="scale_width_and_crop",
        nce_layers="0,4,8,12,16",
        nce_idt=False,
        input_nc=3,
        output_nc=3,
        ngf=64,
        netG="resnet_6blocks",
        normG="instance",
        no_dropout=True,
        init_type="xavier",
        init_gain=0.02,
        no_antialias=False,
        no_antialias_up=False,
        weight_norm="spectral",
        netF="mlp_sample",
        netF_nc=256,
        no_flip=True,
        load_size=1024,
        crop_size=1024,
        direction="AtoB",
        flip_equivariance=False,
        epoch="latest",
        verbose=True
    )
    model = CPTModel(opt)
    model.setup(opt)
    model.parallelize()
    model.eval()
    return model, opt


def convert_he2ihc(output_stain, input_he_image_path):
    """Translate an H&E image into the selected IHC stain.

    Args:
        output_stain: Dropdown label selecting the target stain; must be a
            key of ``_STAIN2FOLDER_NAME``.
        input_he_image_path: Filesystem path of the uploaded H&E image (the
            Gradio input component uses ``type="filepath"``).

    Returns:
        A PIL image of the generated IHC stain, resized back to the input
        image's original size.

    Raises:
        gr.Error: If no stain is selected or no image was provided, so the
            message is shown in the UI instead of a raw traceback.
    """
    if output_stain not in _STAIN2FOLDER_NAME:
        raise gr.Error("Please select an output stain.")
    if not input_he_image_path:
        raise gr.Error("Please upload an input H&E image.")

    model, opt = _load_model(_STAIN2FOLDER_NAME[output_stain])

    # The context manager closes the file handle Image.open keeps open;
    # convert() loads the pixels and returns an independent image.
    with Image.open(input_he_image_path) as source_img:
        input_img = source_img.convert("RGB")
    original_img_size = input_img.size

    # The generator runs at the configured load size; the output is resized
    # back to the original dimensions below.
    input_img = input_img.resize((opt.load_size, opt.load_size))
    input_tensor = preprocess_image(input_img)

    with _INFERENCE_LOCK:
        # The H&E tensor is also supplied as the B domain; presumably B is
        # only needed for training - confirm against CPTModel.set_input.
        model.set_input({
            "A": input_tensor,
            "A_paths": input_he_image_path,
            "B": input_tensor,
            "B_paths": input_he_image_path,
        })
        model.test()
        visuals = model.get_current_visuals()
        output_img = postprocess_tensor(visuals['fake_B'])

    return output_img.resize(original_img_size)

def main():
    """Fetch the pretrained generators, then build and launch the Gradio demo.

    Each stain's generator weights are fetched via ``download_weights`` into
    the checkpoint folder ``convert_he2ihc`` reads from. The UI is a
    stain dropdown plus input/output image components wired into a
    ``gr.Interface``, followed by one example pair.
    """
    # (remote file id, local destination) for each stain's generator.
    pretrained_weights = (
        ("1N_HOGU7FO4u-S1OD-bumZGyevYeucT4Q", "../../checkpoints/ASP_pretrained/MIST_her2_lambda_linear/latest_net_G.pth"),
        ("1j6xu8MAOVUaZuV4O5CqsBfMtH6-droys", "../../checkpoints/ASP_pretrained/MIST_er_lambda_linear/latest_net_G.pth"),
        ("10STHMS-GMkHMOJp_cJ44T66rRwKlUZyr", "../../checkpoints/ASP_pretrained/MIST_ki67_lambda_linear/latest_net_G.pth"),
        ("1APIrm3kqtPhhAIcU7pvfIcYpMjpsIlQ9", "../../checkpoints/ASP_pretrained/MIST_pr_lambda_linear/latest_net_G.pth"),
    )
    for file_id, destination in pretrained_weights:
        download_weights(file_id, destination)

    stain_choices = [
        "HER2 (Human Epidermal growth factor Receptor 2)",
        "ER (Estrogen Receptor)",
        "Ki67 (Antigen KI-67)",
        "PR (Progesterone Receptor)",
    ]

    with gr.Blocks() as demo:
        # Components are created in the Blocks context first, then handed to
        # the Interface and to the Examples below.
        stain_dropdown = gr.Dropdown(choices=stain_choices, label="Output Stain")
        he_image = gr.Image(type="filepath", label="Input H&E Image")
        ihc_image = gr.Image(label="Output IHC Image")

        gr.Interface(
            fn=convert_he2ihc,
            inputs=[stain_dropdown, he_image],
            outputs=ihc_image,
            title="H&E-to-IHC Stain Translation",
            description="<h2>Stain your H&E (Hematoxylin and Eosin) images into IHC (ImmunoHistoChemistry) images automatically thanks to AI!</h2>",
            theme="ParityError/Interstellar",
        )

        # One example pair: fills both the input and the expected output.
        gr.Examples(
            examples=[["assets/he.jpg", "assets/ihc.jpg"]],
            inputs=[he_image, ihc_image],
            examples_per_page=1,
        )

    demo.launch()

if __name__ == "__main__":
    # parse_args(main) is project-provided; it presumably derives CLI options
    # from main's signature (main currently takes none) - confirm in
    # asp.util.general_utils.
    main(**vars(parse_args(main)))

# Usage: python app.py