import glob
import os
import re
import subprocess
from collections import OrderedDict

import lpips
import numpy as np
import torch
import torch.distributed as dist
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

from .matlab_resize import imresize


def reduce_tensors(metrics):
    """Average tensor-valued metrics across all distributed workers, recursing into nested dicts."""
    new_metrics = {}
    for k, v in metrics.items():
        if isinstance(v, torch.Tensor):
            dist.all_reduce(v)
            v = v / dist.get_world_size()
        if isinstance(v, dict):
            v = reduce_tensors(v)
        new_metrics[k] = v
    return new_metrics


def tensors_to_scalars(tensors):
    if isinstance(tensors, torch.Tensor):
        return tensors.item()
    elif isinstance(tensors, dict):
        return {k: tensors_to_scalars(v) for k, v in tensors.items()}
    elif isinstance(tensors, list):
        return [tensors_to_scalars(v) for v in tensors]
    else:
        return tensors
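

# A minimal usage sketch for the two helpers above (hedged: assumes a
# torch.distributed process group has already been initialised, and the
# metric names are made up for illustration):
#
#     metrics = {'total_loss': torch.tensor(0.5, device='cuda'),
#                'sub': {'l1': torch.tensor(0.2, device='cuda')}}
#     metrics = reduce_tensors(metrics)        # averaged across workers
#     scalars = tensors_to_scalars(metrics)    # plain Python floats, same nesting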


def tensors_to_np(tensors):
    # Recursively convert tensors nested in dicts/lists to numpy arrays.
    if isinstance(tensors, torch.Tensor):
        return tensors.cpu().numpy()
    if isinstance(tensors, dict):
        return {k: tensors_to_np(v) if isinstance(v, (torch.Tensor, dict)) else v
                for k, v in tensors.items()}
    if isinstance(tensors, list):
        return [tensors_to_np(v) if isinstance(v, (torch.Tensor, dict)) else v
                for v in tensors]
    raise TypeError(f'tensors_to_np does not support type {type(tensors)}.')


def move_to_cpu(tensors):
    ret = {}
    for k, v in tensors.items():
        if isinstance(v, torch.Tensor):
            v = v.cpu()
        if isinstance(v, dict):
            v = move_to_cpu(v)
        ret[k] = v
    return ret


def move_to_cuda(batch, gpu_id=0):
    # base case: object can be directly moved using `cuda` or `to`
    if callable(getattr(batch, 'cuda', None)):
        return batch.cuda(gpu_id, non_blocking=True)
    elif callable(getattr(batch, 'to', None)):
        return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
    elif isinstance(batch, list):
        for i, x in enumerate(batch):
            batch[i] = move_to_cuda(x, gpu_id)
        return batch
    elif isinstance(batch, tuple):
        batch = list(batch)
        for i, x in enumerate(batch):
            batch[i] = move_to_cuda(x, gpu_id)
        return tuple(batch)
    elif isinstance(batch, dict):
        for k, v in batch.items():
            batch[k] = move_to_cuda(v, gpu_id)
        return batch
    # non-tensor leaves (e.g. strings, numbers) are returned unchanged
    return batch
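

# Illustrative sketch of move_to_cuda on a nested batch (the batch layout is
# made up; non-tensor entries such as the name list pass through unchanged):
#
#     batch = {'img_hr': torch.randn(4, 3, 128, 128),
#              'img_lr': torch.randn(4, 3, 32, 32),
#              'item_name': ['a', 'b', 'c', 'd']}
#     batch = move_to_cuda(batch, gpu_id=0)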


def get_last_checkpoint(work_dir, steps=None):
    checkpoint = None
    last_ckpt_path = None
    ckpt_paths = get_all_ckpts(work_dir, steps)
    if len(ckpt_paths) > 0:
        last_ckpt_path = ckpt_paths[0]
        checkpoint = torch.load(last_ckpt_path, map_location='cpu')
    return checkpoint, last_ckpt_path


def get_all_ckpts(work_dir, steps=None):
    if steps is None:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
    # sort by step count, newest first; raw string avoids invalid-escape warnings
    return sorted(glob.glob(ckpt_path_pattern),
                  key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))


def load_checkpoint(model, optimizer, work_dir, steps=None):
    checkpoint, last_ckpt_path = get_last_checkpoint(work_dir, steps)
    print(f'loading checkpoint from: {last_ckpt_path}')
    if checkpoint is not None:
        state_dict = checkpoint['state_dict']['model']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k[:7] == 'module.':
                k = k[7:]  # strip the 'module.' prefix added by DataParallel/DDP
            new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
        model.cuda()
        optimizer.load_state_dict(checkpoint['optimizer_states'][0])
        training_step = checkpoint['global_step']
        del checkpoint
        torch.cuda.empty_cache()
    else:
        training_step = 0
        model.cuda()
    return training_step


def save_checkpoint(model, optimizer, work_dir, global_step, num_ckpt_keep):
    ckpt_path = f'{work_dir}/model_ckpt_steps_{global_step}.ckpt'
    print(f'Step@{global_step}: saving model to {ckpt_path}')
    checkpoint = {
        'global_step': global_step,
        'optimizer_states': [optimizer.state_dict()],
        'state_dict': {'model': model.state_dict()},
    }
    torch.save(checkpoint, ckpt_path, _use_new_zipfile_serialization=False)
    # keep only the `num_ckpt_keep` most recent checkpoints
    for old_ckpt in get_all_ckpts(work_dir)[num_ckpt_keep:]:
        remove_file(old_ckpt)
        print(f'Delete ckpt: {os.path.basename(old_ckpt)}')
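

# Hedged sketch of a resume-capable training loop built on the two helpers
# above (model, optimizer, work_dir, max_steps and save_interval are
# placeholders, not defined in this module):
#
#     training_step = load_checkpoint(model, optimizer, work_dir)
#     while training_step < max_steps:
#         ...  # forward / backward / optimizer.step()
#         training_step += 1
#         if training_step % save_interval == 0:
#             save_checkpoint(model, optimizer, work_dir,
#                             training_step, num_ckpt_keep=3)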


def remove_file(*fns):
    for f in fns:
        # shell out so files and directories are removed alike, without erroring on missing paths
        subprocess.check_call(f'rm -rf "{f}"', shell=True)


def plot_img(img):
    img = img.detach().cpu().numpy()
    return np.clip(img, 0, 1)


def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        base_dir = ckpt_base_dir
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)
    if checkpoint is not None:
        state_dict = checkpoint["state_dict"]
        if any('.' in k for k in state_dict.keys()):
            # flat state dict: keep only keys under `{model_name}.` and strip the prefix
            state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items()
                          if k.startswith(f'{model_name}.')}
        else:
            state_dict = state_dict[model_name]
        if not strict:
            # drop parameters whose shapes do not match the current model
            cur_model_state_dict = cur_model.state_dict()
            unmatched_keys = []
            for key, param in state_dict.items():
                if key in cur_model_state_dict:
                    new_param = cur_model_state_dict[key]
                    if new_param.shape != param.shape:
                        unmatched_keys.append(key)
                        print("| Unmatched keys: ", key, new_param.shape, param.shape)
            for key in unmatched_keys:
                del state_dict[key]
        cur_model.load_state_dict(state_dict, strict=strict)
        print(f"| load '{model_name}' from '{ckpt_path}'.")
    else:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            raise FileNotFoundError(e_msg)
        else:
            print(e_msg)
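

# load_ckpt accepts either a concrete .ckpt file or a work directory; a
# hedged example (the path and model are placeholders):
#
#     load_ckpt(my_model, 'checkpoints/sr_exp', model_name='model',
#               force=False, strict=False)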


class Measure:
    def __init__(self, net='alex'):
        self.model = lpips.LPIPS(net=net)

    def measure(self, imgA, imgB, img_lr, sr_scale):
        """
        Args:
            imgA: [C, H, W] uint8 or torch.FloatTensor in [-1, 1]
            imgB: [C, H, W] uint8 or torch.FloatTensor in [-1, 1]
            img_lr: [C, H, W] uint8 or torch.FloatTensor in [-1, 1]
            sr_scale: super-resolution scale factor
        Returns: dict of metrics
        """
        if isinstance(imgA, torch.Tensor):
            # map [-1, 1] float tensors to uint8 in [0, 255]
            imgA = np.round((imgA.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
            imgB = np.round((imgB.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
            img_lr = np.round((img_lr.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
        imgA = imgA.transpose(1, 2, 0)
        imgA_lr = imresize(imgA, 1 / sr_scale)
        imgB = imgB.transpose(1, 2, 0)
        img_lr = img_lr.transpose(1, 2, 0)
        psnr = self.psnr(imgA, imgB)
        ssim = self.ssim(imgA, imgB)
        lpips = self.lpips(imgA, imgB)
        lr_psnr = self.psnr(imgA_lr, img_lr)
        res = {'psnr': psnr, 'ssim': ssim, 'lpips': lpips, 'lr_psnr': lr_psnr}
        return {k: float(v) for k, v in res.items()}

    def lpips(self, imgA, imgB):
        device = next(self.model.parameters()).device
        tA = t(imgA).to(device)
        tB = t(imgB).to(device)
        dist01 = self.model.forward(tA, tB).item()
        return dist01

    def ssim(self, imgA, imgB):
        return ssim(imgA, imgB, channel_axis=2, data_range=255)

    def psnr(self, imgA, imgB):
        return psnr(imgA, imgB, data_range=255)


def t(img):
    # HWC uint8 image -> [1, C, H, W] float tensor in [-1, 1]
    def to_4d(img):
        assert len(img.shape) == 3
        img_new = np.expand_dims(img, axis=0)
        assert len(img_new.shape) == 4
        return img_new

    def to_CHW(img):
        return np.transpose(img, [2, 0, 1])

    def to_tensor(img):
        return torch.Tensor(img)

    return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1
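

# A hedged smoke test for Measure: random uint8 images stand in for real SR
# outputs, so the numbers are meaningless beyond checking the plumbing. Note
# the relative import at the top means this module must be run as part of its
# package (python -m <package>.<this_module>).
if __name__ == '__main__':
    sr_scale = 4
    m = Measure(net='alex')
    img_sr = np.random.randint(0, 256, (3, 128, 128), dtype=np.uint8)
    img_hr = np.random.randint(0, 256, (3, 128, 128), dtype=np.uint8)
    img_lr = np.random.randint(0, 256, (3, 128 // sr_scale, 128 // sr_scale), dtype=np.uint8)
    print(m.measure(img_sr, img_hr, img_lr, sr_scale))  # {'psnr': ..., 'ssim': ..., 'lpips': ..., 'lr_psnr': ...}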