# SAM-DiffSR / app.py
# Author: Traly · commit a089d4c ("init") · 3.01 kB
import importlib
from collections import OrderedDict
from pathlib import Path
import gradio as gr
import os
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from sam_diffsr.utils_sr.hparams import set_hparams, hparams
from sam_diffsr.utils_sr.matlab_resize import imresize
def get_img_data(img_PIL, hparams, sr_scale=4):
    """Turn a PIL image into the low-resolution tensors the SR model samples from.

    Args:
        img_PIL: input ``PIL.Image``; it is converted to RGB first.
        hparams: hyper-parameter mapping; only ``hparams['sr_scale']`` is read,
            as the factor ``imresize`` uses to pre-upsample the input.
        sr_scale: scale factor used when computing the crop size.

    Returns:
        ``(img_lr, img_lr_up)``: the cropped low-resolution image and its
        ``imresize``-upsampled counterpart. Each is a float tensor normalised
        with mean 0.5 / std 0.5 per channel and has a leading batch dim of 1.
    """
    rgb = np.uint8(np.asarray(img_PIL.convert('RGB')))
    lr_h, lr_w = rgb.shape[:2]

    # Size the (virtual) high-res canvas, trim it to a multiple of
    # 2 * sr_scale, then map back to low-res pixels. For a positive integer
    # sr_scale this simply drops a trailing odd row / column.
    hr_h = lr_h * sr_scale
    hr_w = lr_w * sr_scale
    hr_h -= hr_h % (sr_scale * 2)
    hr_w -= hr_w % (sr_scale * 2)
    rgb = rgb[:hr_h // sr_scale, :hr_w // sr_scale]

    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # NOTE(review): divides by 256 rather than 255 before resizing; kept
    # unchanged, presumably to match the training-time preprocessing.
    upsampled = imresize(rgb / 256, hparams['sr_scale'])  # float array [H, W, C]

    # ToTensor rescales the uint8 crop to [0, 1] but leaves the float
    # upsampled array's values as they are; both then get a batch axis.
    img_lr, img_lr_up = [
        normalize(arr).float().unsqueeze(0) for arr in (rgb, upsampled)
    ]
    return img_lr, img_lr_up
def load_checkpoint(model, ckpt_path):
    """Load model weights from a training checkpoint into ``model`` in place.

    The weights are read from ``checkpoint['state_dict']['model']``. A leading
    ``module.`` prefix on any key is stripped first; such prefixes typically
    come from training under a ``DataParallel``/DDP wrapper. The weights are
    then loaded with ``model.load_state_dict`` (strict by default).

    Args:
        model: ``torch.nn.Module`` that receives the weights.
        ckpt_path: path to the checkpoint file.

    Raises:
        KeyError: if the checkpoint has no ``['state_dict']['model']`` entry.
        RuntimeError: from ``load_state_dict`` on missing/unexpected keys or
            shape mismatches.
    """
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    # map_location='cpu' lets the checkpoint load on machines without a GPU.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    print(f'loading checkpoint from: {ckpt_path}')

    prefix = 'module.'
    state_dict = OrderedDict(
        # Remove the `module.` prefix so keys match the unwrapped model.
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint['state_dict']['model'].items()
    )
    model.load_state_dict(state_dict)

    # The raw checkpoint is no longer needed; release it before clearing the
    # CUDA allocator cache (a no-op when CUDA was never initialised).
    del checkpoint
    torch.cuda.empty_cache()
def model_init(ckpt_path):
    """Build the SAM-DiffSR trainer and load inference weights into its model.

    Args:
        ckpt_path: checkpoint path handed to ``load_checkpoint``.

    Returns:
        The ``SRDiffDf2k_sam`` trainer instance whose ``model`` now holds the
        loaded weights.
    """
    # Populate the global hparams first; the task class is only imported and
    # instantiated afterwards (presumably it reads hparams -- confirm).
    set_hparams()
    from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam

    sr_trainer = SRDiffDf2k_sam()
    sr_trainer.build_model()
    load_checkpoint(sr_trainer.model, ckpt_path)
    torch.backends.cudnn.benchmark = False
    return sr_trainer
def image_infer(img_PIL):
    """Super-resolve one image with the module-level ``trainer`` (Gradio callback).

    Args:
        img_PIL: ``PIL.Image`` from the Gradio input component, or ``None``
            when no image was supplied.

    Returns:
        The super-resolved image as a ``PIL.Image``.

    Raises:
        ValueError: if ``img_PIL`` is ``None``.
    """
    if img_PIL is None:
        # Fail with a clear message instead of an AttributeError deep inside
        # preprocessing when the input component is empty.
        raise ValueError('No input image provided.')

    with torch.no_grad():
        trainer.model.eval()
        img_lr, img_lr_up = get_img_data(img_PIL, hparams, sr_scale=4)
        # Inputs stay on the CPU (GPU transfers were disabled in the original
        # code); move them to the model's device if CUDA inference is enabled.
        img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape)
        # Clamp to [-1, 1], presumably the same normalised range that
        # get_img_data produces, before converting back to an image array.
        img_sr = trainer.tensor2img(img_sr.clamp(-1, 1))[0]
    return Image.fromarray(img_sr)
# Resolve bundled assets relative to this file so the app works from any
# working directory.
root_path = os.path.dirname(__file__)

# Sample inputs shipped with the app; the first also pre-fills the input.
example_paths = [
    os.path.join(root_path, rel_path)
    for rel_path in ("images/0801x4.png", "images/0804x4.png", "images/0809x4.png")
]
cheetah = example_paths[0]
print(cheetah)

ckpt_path = os.path.join(root_path, 'sam_diffsr/weight/model_ckpt_steps_400000.ckpt')
# Built at import time: image_infer reads this module-level name on each call.
trainer = model_init(ckpt_path)

demo = gr.Interface(
    image_infer,
    gr.Image(type="pil", value=cheetah),
    "image",
    examples=example_paths,
)

if __name__ == "__main__":
    demo.launch()