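"""Gradio demo for SAM-DiffSR 4x single-image super-resolution.

Builds the SRDiff model, loads a pretrained checkpoint, and serves an
interactive image-upscaling interface. Inference runs on the CPU by default;
the commented-out .cuda() calls mark where GPU placement would go.
"""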
import os
from collections import OrderedDict

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from sam_diffsr.utils_sr.hparams import set_hparams, hparams
from sam_diffsr.utils_sr.matlab_resize import imresize


def get_img_data(img_PIL, hparams, sr_scale=4):
    """Convert a PIL image into a normalized LR tensor and its bicubic upsample."""
    img_lr = img_PIL.convert('RGB')
    img_lr = np.uint8(np.asarray(img_lr))

    # Crop the LR input so the HR output dimensions are multiples of
    # sr_scale * 2 and the LR crop stays an integral number of pixels.
    h, w, c = img_lr.shape
    h, w = h * sr_scale, w * sr_scale
    h = h - h % (sr_scale * 2)
    w = w - w % (sr_scale * 2)
    h_l = h // sr_scale
    w_l = w // sr_scale

    img_lr = img_lr[:h_l, :w_l]

    # uint8 [0, 255] -> float tensor normalized to [-1, 1].
    to_tensor_norm = transforms.Compose([
        transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # np.float [H, W, C]; the /256 scaling is kept as in the original preprocessing.
    img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])
    img_lr, img_lr_up = [to_tensor_norm(x).float() for x in [img_lr, img_lr_up]]

    # Add a batch dimension: [1, C, H, W].
    img_lr = torch.unsqueeze(img_lr, dim=0)
    img_lr_up = torch.unsqueeze(img_lr_up, dim=0)

    return img_lr, img_lr_up
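
# Illustrative shapes (an assumed example, with sr_scale = 4 both as the
# argument and in hparams): a 121x101 RGB input is cropped to 120x100, so
# get_img_data returns img_lr with shape [1, 3, 120, 100] and img_lr_up
# with shape [1, 3, 480, 400].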


def load_checkpoint(model, ckpt_path):
    """Load model weights from a checkpoint, stripping any DataParallel prefix."""
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    print(f'loading checkpoint from: {ckpt_path}')
    state_dict = checkpoint['state_dict']['model']

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k[:7] == 'module.':
            k = k[7:]  # strip the 'module.' prefix added by DataParallel
        new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
    # model.cuda()
    del checkpoint
    torch.cuda.empty_cache()


def model_init(ckpt_path):
    """Build the SRDiff trainer, load the pretrained weights, and return it."""
    set_hparams()

    from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam

    trainer = SRDiffDf2k_sam()
    trainer.build_model()
    load_checkpoint(trainer.model, ckpt_path)

    torch.backends.cudnn.benchmark = False

    return trainer


def image_infer(img_PIL):
    """Run diffusion super-resolution on a PIL image and return the SR result."""
    with torch.no_grad():
        trainer.model.eval()
        img_lr, img_lr_up = get_img_data(img_PIL, hparams, sr_scale=4)

        # img_lr = img_lr.to('cuda')
        # img_lr_up = img_lr_up.to('cuda')

        img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape)

        # Map the sampled tensor from [-1, 1] back to a uint8 PIL image.
        img_sr = img_sr.clamp(-1, 1)
        img_sr = trainer.tensor2img(img_sr)[0]
        img_sr = Image.fromarray(img_sr)

    return img_sr
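
# Minimal sketch for optional GPU inference (an assumption, not part of the
# original script, which keeps its .cuda() calls commented out):
#
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   trainer.model.to(device)
#   # ...and inside image_infer, move the inputs to the same device:
#   img_lr, img_lr_up = img_lr.to(device), img_lr_up.to(device)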


root_path = os.path.dirname(__file__)

default_img = os.path.join(root_path, "images/0801x4.png")
print(f'default example image: {default_img}')

ckpt_path = os.path.join(root_path, 'sam_diffsr/weight/model_ckpt_steps_400000.ckpt')
trainer = model_init(ckpt_path)

demo = gr.Interface(image_infer, gr.Image(type="pil", value=default_img), "image",
                    # flagging_options=["blurry", "incorrect", "other"],
                    examples=[
                        os.path.join(root_path, "images/0801x4.png"),
                        os.path.join(root_path, "images/0804x4.png"),
                        os.path.join(root_path, "images/0809x4.png"),
                    ]
                    )

if __name__ == "__main__":
    demo.launch()
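
# Usage: `python app.py` starts a local Gradio server and prints its URL.
# `demo.launch(share=True)` would additionally create a temporary public
# link (a standard Gradio option, not used by this script).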