import os
from collections import OrderedDict

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from sam_diffsr.utils_sr.hparams import set_hparams, hparams
from sam_diffsr.utils_sr.matlab_resize import imresize
def get_img_data(img_PIL, hparams, sr_scale=4):
    """Turn a PIL image into the (LR, bicubic-upsampled LR) tensor pair the model expects."""
    img_lr = img_PIL.convert('RGB')
    img_lr = np.uint8(np.asarray(img_lr))

    # Crop the LR image so the SR output size is divisible by (sr_scale * 2).
    h, w, c = img_lr.shape
    h, w = h * sr_scale, w * sr_scale
    h = h - h % (sr_scale * 2)
    w = w - w % (sr_scale * 2)
    h_l = h // sr_scale
    w_l = w // sr_scale
    img_lr = img_lr[:h_l, :w_l]

    # Map to tensors in [-1, 1], matching the training-time normalization.
    to_tensor_norm = transforms.Compose([
        transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]
    img_lr, img_lr_up = [to_tensor_norm(x).float() for x in [img_lr, img_lr_up]]

    # Add a batch dimension: [C, H, W] -> [1, C, H, W].
    img_lr = torch.unsqueeze(img_lr, dim=0)
    img_lr_up = torch.unsqueeze(img_lr_up, dim=0)

    return img_lr, img_lr_up
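
# Resulting shapes for the default sr_scale=4 (illustration only):
#   img_lr:    [1, 3, H, W]        -- normalized low-resolution input
#   img_lr_up: [1, 3, 4*H, 4*W]    -- bicubic x4 upsample passed to the sampler
#                                     alongside the LR input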
def load_checkpoint(model, ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    print(f'loading checkpoint from: {ckpt_path}')
    stat_dict = checkpoint['state_dict']['model']

    new_state_dict = OrderedDict()
    for k, v in stat_dict.items():
        if k[:7] == 'module.':
            k = k[7:]  # strip the 'module.' prefix
        new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
    # model.cuda()
    del checkpoint
    torch.cuda.empty_cache()
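
# Why the prefix strip above: a model wrapped in torch.nn.DataParallel saves its
# state_dict with every key prefixed by 'module.', so the keys must be un-prefixed
# before loading into the bare (unwrapped) model.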
def model_init(ckpt_path):
    set_hparams()

    from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam
    trainer = SRDiffDf2k_sam()
    trainer.build_model()

    load_checkpoint(trainer.model, ckpt_path)

    # Input sizes vary per upload, so skip cuDNN autotuning.
    torch.backends.cudnn.benchmark = False

    return trainer
def image_infer(img_PIL):
    with torch.no_grad():
        trainer.model.eval()

        img_lr, img_lr_up = get_img_data(img_PIL, hparams, sr_scale=4)
        # img_lr = img_lr.to('cuda')
        # img_lr_up = img_lr_up.to('cuda')

        # Run the diffusion sampler conditioned on the LR input, then map the
        # [-1, 1] output back to a uint8 image.
        img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape)
        img_sr = img_sr.clamp(-1, 1)
        img_sr = trainer.tensor2img(img_sr)[0]
        img_sr = Image.fromarray(img_sr)

    return img_sr
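
# The commented-out .to('cuda') lines above are the CPU/GPU toggle. A minimal
# device-agnostic sketch (assuming a GPU is available to the Space) would be:
#
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   trainer.model.to(device)
#   img_lr, img_lr_up = img_lr.to(device), img_lr_up.to(device)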
root_path = os.path.dirname(__file__)

default_image = os.path.join(root_path, "images/0801x4.png")
print(default_image)

ckpt_path = os.path.join(root_path, 'sam_diffsr/weight/model_ckpt_steps_400000.ckpt')
trainer = model_init(ckpt_path)

demo = gr.Interface(image_infer, gr.Image(type="pil", value=default_image), "image",
                    # flagging_options=["blurry", "incorrect", "other"],
                    examples=[
                        os.path.join(root_path, "images/0801x4.png"),
                        os.path.join(root_path, "images/0804x4.png"),
                        os.path.join(root_path, "images/0809x4.png"),
                    ])
if __name__ == "__main__":
    demo.launch()
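
# On Spaces, launch() serves the app directly. When running locally, Gradio can
# also expose a temporary public URL via demo.launch(share=True).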