import glob
import os
import re
import subprocess
from collections import OrderedDict

import lpips
import numpy as np
import torch
import torch.distributed as dist
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

from .matlab_resize import imresize


def reduce_tensors(metrics):
    """All-reduce tensors in a (possibly nested) metrics dict and average them over workers."""
    new_metrics = {}
    for k, v in metrics.items():
        if isinstance(v, torch.Tensor):
            dist.all_reduce(v)
            v = v / dist.get_world_size()
        if isinstance(v, dict):
            v = reduce_tensors(v)
        new_metrics[k] = v
    return new_metrics


def tensors_to_scalars(tensors):
    """Recursively convert tensors (possibly nested in dicts/lists) to Python scalars."""
    if isinstance(tensors, torch.Tensor):
        return tensors.item()
    elif isinstance(tensors, dict):
        return {k: tensors_to_scalars(v) for k, v in tensors.items()}
    elif isinstance(tensors, list):
        return [tensors_to_scalars(v) for v in tensors]
    else:
        return tensors


def tensors_to_np(tensors):
    """Recursively convert tensors (possibly nested in dicts/lists) to numpy arrays."""
    if isinstance(tensors, dict):
        new_np = {}
        for k, v in tensors.items():
            if isinstance(v, torch.Tensor):
                v = v.cpu().numpy()
            if isinstance(v, dict):
                v = tensors_to_np(v)
            new_np[k] = v
    elif isinstance(tensors, list):
        new_np = []
        for v in tensors:
            if isinstance(v, torch.Tensor):
                v = v.cpu().numpy()
            if isinstance(v, dict):
                v = tensors_to_np(v)
            new_np.append(v)
    elif isinstance(tensors, torch.Tensor):
        new_np = tensors.cpu().numpy()
    else:
        raise TypeError(f'tensors_to_np does not support type {type(tensors)}.')
    return new_np


def move_to_cpu(tensors):
    ret = {}
    for k, v in tensors.items():
        if isinstance(v, torch.Tensor):
            v = v.cpu()
        if isinstance(v, dict):
            v = move_to_cpu(v)
        ret[k] = v
    return ret


def move_to_cuda(batch, gpu_id=0):
    # Base case: the object can be moved directly via `.cuda()` or `.to()`.
    if callable(getattr(batch, 'cuda', None)):
        return batch.cuda(gpu_id, non_blocking=True)
    elif callable(getattr(batch, 'to', None)):
        return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
    # Recursive cases: containers are traversed element by element.
    elif isinstance(batch, list):
        for i, x in enumerate(batch):
            batch[i] = move_to_cuda(x, gpu_id)
        return batch
    elif isinstance(batch, tuple):
        batch = list(batch)
        for i, x in enumerate(batch):
            batch[i] = move_to_cuda(x, gpu_id)
        return tuple(batch)
    elif isinstance(batch, dict):
        for k, v in batch.items():
            batch[k] = move_to_cuda(v, gpu_id)
        return batch
    # Anything else (ints, strings, ...) is returned unchanged.
    return batch


def get_last_checkpoint(work_dir, steps=None):
    checkpoint = None
    last_ckpt_path = None
    ckpt_paths = get_all_ckpts(work_dir, steps)
    if len(ckpt_paths) > 0:
        last_ckpt_path = ckpt_paths[0]
        checkpoint = torch.load(last_ckpt_path, map_location='cpu')
    return checkpoint, last_ckpt_path


def get_all_ckpts(work_dir, steps=None):
    """Return checkpoint paths in `work_dir`, newest (highest step count) first."""
    if steps is None:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
    return sorted(glob.glob(ckpt_path_pattern),
                  key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))


def load_checkpoint(model, optimizer, work_dir, steps=None):
    checkpoint, last_ckpt_path = get_last_checkpoint(work_dir, steps)
    print(f'Loading checkpoint from: {last_ckpt_path}')
    if checkpoint is not None:
        state_dict = checkpoint['state_dict']['model']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith('module.'):
                k = k[7:]  # strip the DataParallel/DDP `module.` prefix
            new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
        model.cuda()
        optimizer.load_state_dict(checkpoint['optimizer_states'][0])
        training_step = checkpoint['global_step']
        del checkpoint
        torch.cuda.empty_cache()
    else:
        training_step = 0
        model.cuda()
    return training_step
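
# Usage sketch: resuming a run and moving nested batches to the GPU. The names
# `MyModel`, `train_loader`, and `exp_dir` below are placeholders for
# illustration, not part of this module.
#
#     model = MyModel()
#     optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
#     training_step = load_checkpoint(model, optimizer, exp_dir)  # 0 if no ckpt found
#     for batch in train_loader:
#         batch = move_to_cuda(batch)  # handles tensors nested in dicts/lists/tuples
#         ...
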
def save_checkpoint(model, optimizer, work_dir, global_step, num_ckpt_keep):
    ckpt_path = f'{work_dir}/model_ckpt_steps_{global_step}.ckpt'
    print(f'Step@{global_step}: saving model to {ckpt_path}')
    checkpoint = {
        'global_step': global_step,
        'optimizer_states': [optimizer.state_dict()],
        'state_dict': {'model': model.state_dict()},
    }
    torch.save(checkpoint, ckpt_path, _use_new_zipfile_serialization=False)
    # Keep only the `num_ckpt_keep` most recent checkpoints.
    for old_ckpt in get_all_ckpts(work_dir)[num_ckpt_keep:]:
        remove_file(old_ckpt)
        print(f'Deleted ckpt: {os.path.basename(old_ckpt)}')


def remove_file(*fns):
    for f in fns:
        subprocess.check_call(f'rm -rf "{f}"', shell=True)


def plot_img(img):
    img = img.data.cpu().numpy()
    return np.clip(img, 0, 1)


def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        base_dir = ckpt_base_dir
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)
    if checkpoint is not None:
        state_dict = checkpoint["state_dict"]
        # Two layouts are supported: a flat dict with `<model_name>.` key
        # prefixes, or a nested dict keyed by `model_name`.
        if len([k for k in state_dict.keys() if '.' in k]) > 0:
            state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items()
                          if k.startswith(f'{model_name}.')}
        else:
            state_dict = state_dict[model_name]
        if not strict:
            # Drop parameters whose shapes no longer match the current model.
            cur_model_state_dict = cur_model.state_dict()
            unmatched_keys = []
            for key, param in state_dict.items():
                if key in cur_model_state_dict:
                    new_param = cur_model_state_dict[key]
                    if new_param.shape != param.shape:
                        unmatched_keys.append(key)
                        print("| Unmatched keys: ", key, new_param.shape, param.shape)
            for key in unmatched_keys:
                del state_dict[key]
        cur_model.load_state_dict(state_dict, strict=strict)
        print(f"| load '{model_name}' from '{ckpt_path}'.")
    else:
        e_msg = f"| ckpt not found in {base_dir}."
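
# Usage sketch: the save/load round trip. Checkpoints land at
# `<work_dir>/model_ckpt_steps_<step>.ckpt`, and `get_last_checkpoint` picks
# the one with the highest step count. `work_dir` is a placeholder path.
#
#     save_checkpoint(model, optimizer, work_dir, global_step=1000, num_ckpt_keep=3)
#     ckpt, path = get_last_checkpoint(work_dir)  # loads the steps_1000 file here
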
        if force:
            assert False, e_msg
        else:
            print(e_msg)


class Measure:
    def __init__(self, net='alex'):
        self.model = lpips.LPIPS(net=net)

    def measure(self, imgA, imgB, img_lr, sr_scale):
        """
        Args:
            imgA: [C, H, W] uint8 array or torch.FloatTensor in [-1, 1] (SR result)
            imgB: [C, H, W] uint8 array or torch.FloatTensor in [-1, 1] (HR reference)
            img_lr: [C, H, W] uint8 array or torch.FloatTensor in [-1, 1] (LR input)
            sr_scale: super-resolution scale factor

        Returns:
            dict with 'psnr', 'ssim', 'lpips' and 'lr_psnr' as Python floats
        """
        if isinstance(imgA, torch.Tensor):
            # Map [-1, 1] float tensors to uint8 arrays in [0, 255].
            imgA = np.round((imgA.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
            imgB = np.round((imgB.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
            img_lr = np.round((img_lr.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
        imgA = imgA.transpose(1, 2, 0)
        imgA_lr = imresize(imgA, 1 / sr_scale)
        imgB = imgB.transpose(1, 2, 0)
        img_lr = img_lr.transpose(1, 2, 0)
        res = {
            'psnr': self.psnr(imgA, imgB),
            'ssim': self.ssim(imgA, imgB),
            'lpips': self.lpips(imgA, imgB),
            # PSNR between the downsampled SR result and the LR input: a
            # consistency check on the super-resolution output.
            'lr_psnr': self.psnr(imgA_lr, img_lr),
        }
        return {k: float(v) for k, v in res.items()}

    def lpips(self, imgA, imgB):
        device = next(self.model.parameters()).device
        tA = t(imgA).to(device)
        tB = t(imgB).to(device)
        return self.model.forward(tA, tB).item()

    def ssim(self, imgA, imgB):
        # `ssim` here resolves to skimage's structural_similarity, not this method.
        score, _ = ssim(imgA, imgB, full=True, channel_axis=2, data_range=255)
        return score

    def psnr(self, imgA, imgB):
        return psnr(imgA, imgB, data_range=255)


def t(img):
    """Convert an [H, W, C] uint8 image to a [1, C, H, W] float tensor in [-1, 1]."""
    def to_4d(img):
        assert len(img.shape) == 3
        img_new = np.expand_dims(img, axis=0)
        assert len(img_new.shape) == 4
        return img_new

    def to_CHW(img):
        return np.transpose(img, [2, 0, 1])

    def to_tensor(img):
        return torch.Tensor(img)

    return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1
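

if __name__ == '__main__':
    # Smoke-test sketch: score a random SR/HR pair. Run as a module
    # (e.g. `python -m <package>.utils`) because of the relative import above.
    # The first Measure() call downloads the LPIPS weights; metric values on
    # random noise are meaningless and only demonstrate the expected shapes
    # and dtypes. `sr_scale=4` is an arbitrary choice for this demo.
    rng = np.random.default_rng(0)
    hr = rng.integers(0, 256, size=(3, 64, 64), dtype=np.uint8)
    sr = rng.integers(0, 256, size=(3, 64, 64), dtype=np.uint8)
    lr = imresize(hr.transpose(1, 2, 0), 1 / 4).transpose(2, 0, 1)  # [C, 16, 16]
    print(Measure(net='alex').measure(sr, hr, lr, sr_scale=4))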