import os
import random

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from sam_diffsr.tasks.srdiff import SRDiffTrainer
from sam_diffsr.utils_sr.dataset import SRDataSet
from sam_diffsr.utils_sr.hparams import hparams
from sam_diffsr.utils_sr.matlab_resize import imresize
class InferDataSet(Dataset):
    """Loads a directory of low-resolution images for inference."""

    def __init__(self, img_dir):
        super().__init__()
        self.img_path_list = [os.path.join(img_dir, img_name) for img_name in os.listdir(img_dir)]
        self.to_tensor_norm = transforms.Compose([
            transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __getitem__(self, index):
        sr_scale = hparams['sr_scale']
        img_path = self.img_path_list[index]
        img_name = os.path.basename(img_path)
        img_lr = Image.open(img_path).convert('RGB')
        img_lr = np.uint8(np.asarray(img_lr))
        # Crop the LR image so the upscaled size is a multiple of sr_scale * 2
        h, w, c = img_lr.shape
        h, w = h * sr_scale, w * sr_scale
        h = h - h % (sr_scale * 2)
        w = w - w % (sr_scale * 2)
        h_l = h // sr_scale
        w_l = w // sr_scale
        img_lr = img_lr[:h_l, :w_l]
        # Bicubic-upsampled LR image, used as the conditioning input
        img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]
        img_lr, img_lr_up = [self.to_tensor_norm(x).float() for x in [img_lr, img_lr_up]]
        return img_lr, img_lr_up, img_name

    def __len__(self):
        return len(self.img_path_list)
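

# A minimal usage sketch, not part of the original file: InferDataSet is
# typically wrapped in a DataLoader for batched inference. The directory
# path below is a hypothetical placeholder.
def _example_infer_loader(img_dir='demo/lr_images'):
    from torch.utils.data import DataLoader
    # batch_size=1 avoids collation issues, since input images may differ in size
    return DataLoader(InferDataSet(img_dir), batch_size=1, shuffle=False)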
class Df2kDataSet(SRDataSet):
    def __init__(self, prefix='train'):
        # The validation split is stored under the 'test' prefix
        if prefix == 'valid':
            _prefix = 'test'
        else:
            _prefix = prefix
        super().__init__(_prefix)
        self.patch_size = hparams['patch_size']
        self.patch_size_lr = hparams['patch_size'] // hparams['sr_scale']
        if prefix == 'valid':
            self.len = hparams['eval_batch_size'] * hparams['valid_steps']
        self.data_aug_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            # `resample` was renamed to `interpolation` in torchvision >= 0.9
            transforms.RandomRotation(20, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        ])
    def __getitem__(self, index):
        item = self._get_item(index)
        hparams = self.hparams
        sr_scale = hparams['sr_scale']

        img_hr = np.uint8(item['img'])
        img_lr = np.uint8(item['img_lr'])

        # TODO: clip for SRFlow
        # Crop both images so the HR size is a multiple of sr_scale * 2
        h, w, c = img_hr.shape
        h = h - h % (sr_scale * 2)
        w = w - w % (sr_scale * 2)
        h_l = h // sr_scale
        w_l = w // sr_scale
        img_hr = img_hr[:h, :w]
        img_lr = img_lr[:h_l, :w_l]

        # Random crop of aligned HR/LR patches (training only)
        if self.prefix == 'train':
            if self.data_augmentation and random.random() < 0.5:
                img_hr, img_lr = self.data_augment(img_hr, img_lr)
            # Snap the crop origin to the LR grid so the two patches stay aligned
            i = random.randint(0, h - self.patch_size) // sr_scale * sr_scale
            i_lr = i // sr_scale
            j = random.randint(0, w - self.patch_size) // sr_scale * sr_scale
            j_lr = j // sr_scale
            img_hr = img_hr[i:i + self.patch_size, j:j + self.patch_size]
            img_lr = img_lr[i_lr:i_lr + self.patch_size_lr, j_lr:j_lr + self.patch_size_lr]
        img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]
        img_hr, img_lr, img_lr_up = [self.to_tensor_norm(x).float() for x in [img_hr, img_lr, img_lr_up]]
        return {
            'img_hr': img_hr, 'img_lr': img_lr,
            'img_lr_up': img_lr_up, 'item_name': item['item_name'],
            'loc': np.array(item['loc']), 'loc_bdr': np.array(item['loc_bdr'])
        }
    def __len__(self):
        return self.len

    def data_augment(self, img_hr, img_lr):
        sr_scale = self.hparams['sr_scale']
        img_hr = Image.fromarray(img_hr)
        img_hr = self.data_aug_transforms(img_hr)
        img_hr = np.asarray(img_hr)  # np.uint8 [H, W, C]
        # Re-derive the LR image from the augmented HR image so the pair stays consistent
        img_lr = imresize(img_hr, 1 / sr_scale)
        return img_hr, img_lr
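

# A hedged sanity-check sketch (assumption, not in the original file): verifies
# the shape contract of one Df2kDataSet sample, i.e. img_lr is sr_scale times
# smaller than img_hr and img_lr_up matches img_hr spatially.
def _check_sample_shapes(sample, sr_scale=4):
    c, h, w = sample['img_hr'].shape
    assert sample['img_lr'].shape == (c, h // sr_scale, w // sr_scale)
    assert sample['img_lr_up'].shape == (c, h, w)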
class SRDiffDf2k(SRDiffTrainer):
    def __init__(self):
        super().__init__()
        self.dataset_cls = Df2kDataSet
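

# A hedged launch sketch (assumption): SRDiff-style trainers are usually
# instantiated after hparams have been set and started via a `start()` entry
# point; consult the repo's trainer module for the actual invocation.
if __name__ == '__main__':
    trainer = SRDiffDf2k()
    trainer.start()  # assumed SRDiffTrainer entry point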