# File size: 2,510 Bytes
# 193c713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import importlib
import os
import sys
from collections import OrderedDict
from pathlib import Path

# Make the grandparent directory of this file (presumably the repository root)
# importable and use it as the working directory, so project packages such as
# `tasks` and `utils_sr` resolve and relative paths are taken from there.
parent_path = Path(__file__).absolute().parent.parent
sys.path.append(os.path.abspath(parent_path))
os.chdir(parent_path)
print(f'>-------------> parent path {parent_path}')
print(f'>-------------> current work dir {os.getcwd()}')

# Redirect library caches into <parent_path>/cache. These variables must be set
# before the libraries that read them are imported.
cache_path = os.path.join(parent_path, 'cache')
os.environ["HF_DATASETS_CACHE"] = cache_path
os.environ["TRANSFORMERS_CACHE"] = cache_path
# PyTorch reads the upper-case TORCH_HOME variable; environment variable names
# are case-sensitive on POSIX, so a lower-case "torch_HOME" would be ignored.
os.environ["TORCH_HOME"] = cache_path

# Deliberately imported only after the sys.path and cache setup above: `tasks`
# is expected to be importable from the directory appended to sys.path
# (TODO confirm), and any library it pulls in should already see the cache
# variables.
from tasks.srdiff_df2k import InferDataSet  # noqa: E402

import torch
from PIL import Image
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from utils_sr.hparams import hparams, set_hparams


def load_ckpt(ckpt_path, model):
    """Load the ``checkpoint['state_dict']['model']`` weights into ``model``.

    The checkpoint is read onto the CPU first. Any key beginning with
    ``module.`` (the prefix typically left by ``DataParallel``-style wrappers)
    has that prefix removed before ``model.load_state_dict`` is called, and
    the model is then moved to the GPU with ``model.cuda()``.

    Args:
        ckpt_path: Path of the checkpoint file passed to ``torch.load``.
        model: Module receiving the weights; it is modified in place.
    """
    prefix = 'module.'
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    raw_weights = torch.load(ckpt_path, map_location='cpu')['state_dict']['model']

    # Drop the wrapper prefix so names match the unwrapped model's parameters.
    cleaned = OrderedDict(
        (name[len(prefix):] if name.startswith(prefix) else name, tensor)
        for name, tensor in raw_weights.items()
    )

    model.load_state_dict(cleaned)
    model.cuda()


def infer(trainer, ckpt_path, img_dir, save_dir):
    """Run super-resolution inference over the images in ``img_dir``.

    Builds the trainer's model, loads its weights via ``load_ckpt``, and
    saves each super-resolved result into ``save_dir`` under the file name
    the dataset reports for the corresponding input.

    Args:
        trainer: Object exposing ``build_model()``, a ``model`` attribute with
            ``eval()`` and ``sample()``, and a ``tensor2img()`` converter.
        ckpt_path: Checkpoint path handed to ``load_ckpt``.
        img_dir: Directory handed to ``InferDataSet``.
        save_dir: Output directory; it is not created here, so it must exist.
    """
    trainer.build_model()
    load_ckpt(ckpt_path, trainer.model)

    dataset = InferDataSet(img_dir)
    test_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)

    torch.backends.cudnn.benchmark = False

    with torch.no_grad():
        trainer.model.eval()
        for batch in tqdm(test_dataloader, total=len(test_dataloader)):
            # assumes each item is (lr image, upsampled lr image, file name) -- TODO confirm
            img_lr, img_lr_up, img_names = batch

            img_lr = img_lr.to('cuda')
            img_lr_up = img_lr_up.to('cuda')

            img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape)

            # Clamp to [-1, 1], presumably the value range tensor2img expects.
            img_sr = img_sr.clamp(-1, 1)

            # Write every image of the batch, not just the first one, so no
            # output is dropped when eval_batch_size > 1. Assumes tensor2img
            # returns one array per batch element, in batch order -- TODO confirm.
            for img, name in zip(trainer.tensor2img(img_sr), img_names):
                Image.fromarray(img).save(os.path.join(save_dir, name))


if __name__ == '__main__':
    # Populate the shared `hparams` mapping from the command line / config.
    set_hparams()

    # Input images, output directory and checkpoint, all taken from hparams.
    img_dir, save_dir, ckpt_path = (
        hparams[key] for key in ('img_dir', 'save_dir', 'ckpt_path')
    )

    # `trainer_cls` is a dotted path: everything before the last dot is the
    # module, the final component is the class to instantiate.
    pkg, _, cls_name = hparams["trainer_cls"].rpartition(".")
    trainer_module = importlib.import_module(pkg)
    trainer = getattr(trainer_module, cls_name)()

    os.makedirs(save_dir, exist_ok=True)

    infer(trainer, ckpt_path, img_dir, save_dir)