File size: 3,056 Bytes
193c713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os.path

import torch

from sam_diffsr.models_sr.diffsr_modules import Unet, RRDBNet
from sam_diffsr.models_sr.diffusion import GaussianDiffusion
from sam_diffsr.tasks.trainer import Trainer
from sam_diffsr.utils_sr.hparams import hparams
from sam_diffsr.utils_sr.utils import load_ckpt


class SRDiffTrainer(Trainer):
    """Trainer that builds and trains a ``GaussianDiffusion`` SR model.

    The diffusion model uses a ``Unet`` as ``denoise_fn`` and, when
    ``hparams['use_rrdb']`` is set, an ``RRDBNet`` passed as ``rrdb_net``.
    All configuration is read from the global ``hparams`` dict.
    """

    def build_model(self):
        """Create the diffusion model and store it on ``self.model``.

        Reads ``hidden_size``, ``unet_dim_mults`` (a ``'|'``-separated string
        of ints such as ``'1|2|4'``), ``rrdb_num_feat``, ``use_rrdb``,
        ``rrdb_num_block``, ``rrdb_ckpt``, ``timesteps`` and ``loss_type``
        from ``hparams``. Also resets ``self.global_step`` to 0.

        Returns:
            The ``GaussianDiffusion`` instance (same object as ``self.model``).
        """
        hidden_size = hparams['hidden_size']
        dim_mults = [int(x) for x in hparams['unet_dim_mults'].split('|')]
        denoise_fn = Unet(
                hidden_size, out_dim=3, cond_dim=hparams['rrdb_num_feat'], dim_mults=dim_mults)
        if hparams['use_rrdb']:
            rrdb = RRDBNet(3, 3, hparams['rrdb_num_feat'], hparams['rrdb_num_block'],
                           hparams['rrdb_num_feat'] // 2)
            # Pretrained RRDB weights are optional: they are loaded only when a
            # non-empty checkpoint path is configured and exists on disk;
            # otherwise the freshly constructed RRDBNet is used as-is.
            if hparams['rrdb_ckpt'] != '' and os.path.exists(hparams['rrdb_ckpt']):
                load_ckpt(rrdb, hparams['rrdb_ckpt'])
        else:
            rrdb = None
        self.model = GaussianDiffusion(
                denoise_fn=denoise_fn,
                rrdb_net=rrdb,
                timesteps=hparams['timesteps'],
                loss_type=hparams['loss_type']
        )
        self.global_step = 0
        return self.model

    def sample_and_test(self, sample):
        """Run diffusion sampling on a batch and accumulate quality metrics.

        Args:
            sample: batch dict providing ``'img_hr'``, ``'img_lr'`` and
                ``'img_lr_up'`` (indexed per sample along the first axis).

        Returns:
            ``(img_sr, rrdb_out, ret)`` where ``img_sr`` and ``rrdb_out`` come
            from ``self.model.sample`` and ``ret`` maps every key in
            ``self.metric_keys`` to a value starting at 0. ``'psnr'``,
            ``'ssim'``, ``'lpips'`` and ``'lr_psnr'`` hold **sums** over the
            batch (not averages), and ``'n_samples'`` is the number of samples
            summed, so callers can average across batches.
        """
        ret = {k: 0 for k in self.metric_keys}
        ret['n_samples'] = 0
        img_hr = sample['img_hr']
        img_lr = sample['img_lr']
        img_lr_up = sample['img_lr_up']
        img_sr, rrdb_out = self.model.sample(img_lr, img_lr_up, img_hr.shape)
        for b in range(img_sr.shape[0]):
            s = self.measure.measure(img_sr[b], img_hr[b], img_lr[b], hparams['sr_scale'])
            for metric in ('psnr', 'ssim', 'lpips', 'lr_psnr'):
                ret[metric] += s[metric]
            ret['n_samples'] += 1
        return img_sr, rrdb_out, ret

    def build_optimizer(self, model):
        """Create an Adam optimizer (``lr=hparams['lr']``) for ``model``.

        When ``hparams['fix_rrdb']`` is true, the RRDB is meant to stay frozen,
        so every parameter whose name contains ``'rrdb'`` is left out of the
        optimizer. Otherwise all parameters, RRDB included, are optimized.
        """
        named_params = list(model.named_parameters())
        if hparams['fix_rrdb']:
            # Frozen RRDB: keep its weights out of the optimizer entirely so
            # they can never be updated. (This condition was previously
            # inverted, which froze the RRDB exactly when it should train.)
            named_params = [(name, p) for name, p in named_params if 'rrdb' not in name]
        params = [p for _, p in named_params]
        return torch.optim.Adam(params, lr=hparams['lr'])

    def build_scheduler(self, optimizer):
        """Create the learning-rate scheduler for ``optimizer``.

        If ``hparams`` has a ``'scheduler'`` entry, its ``'type'`` selects the
        scheduler. Only ``'cosine'`` is supported: ``CosineAnnealingLR`` with
        ``T_max=hparams['max_updates']`` and ``eta_min`` taken from the
        scheduler config. Without a ``'scheduler'`` entry, ``StepLR`` halves
        the learning rate every ``hparams['decay_steps']`` scheduler steps.

        Raises:
            ValueError: if ``hparams['scheduler']['type']`` is not ``'cosine'``.
        """
        if 'scheduler' not in hparams:
            return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)

        scheduler_config = hparams['scheduler']
        scheduler_type = scheduler_config['type']
        if scheduler_type == 'cosine':
            return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, hparams['max_updates'],
                                                              eta_min=scheduler_config['eta_min'])
        # Any other type is a configuration error; fail loudly instead of
        # falling through with no scheduler bound.
        raise ValueError(f"Unsupported scheduler type: {scheduler_type!r}")

    def training_step(self, batch):
        """Compute the training losses for one batch.

        Args:
            batch: dict providing ``'img_hr'``, ``'img_lr'`` and ``'img_lr_up'``.

        Returns:
            ``(losses, total_loss)``: the dict of named losses returned as the
            first element of ``self.model(...)``, and the sum of its values.
        """
        img_hr = batch['img_hr']
        img_lr = batch['img_lr']
        img_lr_up = batch['img_lr_up']
        losses, _, _ = self.model(img_hr, img_lr, img_lr_up)
        total_loss = sum(losses.values())
        return losses, total_loss