import json
import random
from collections import defaultdict
from copy import deepcopy
from io import BytesIO

import torch
import torchvision.transforms as transforms
from PIL import Image, ImageOps

from .base_dataset import BaseDataset, check_filenames_in_zipdata


def clean_annotations(annotations):
    # Drop annotation fields that layout2img training never uses.
    for anno in annotations:
        anno.pop("segmentation", None)
        anno.pop("area", None)
        anno.pop("iscrowd", None)
        anno.pop("id", None)


def make_a_sentence(obj_names, clean=False):
    # Join object names into a comma-separated caption. With clean=True,
    # strip the 6-character "-other" suffix from COCO-Stuff names.
    if clean:
        obj_names = [name[:-6] if ("-other" in name) else name for name in obj_names]

    caption = ""
    tokens_positive = []  # character span of each name; computed but currently unused
    for obj_name in obj_names:
        start_len = len(caption)
        caption += obj_name
        end_len = len(caption)
        caption += ", "
        tokens_positive.append([[start_len, end_len]])
    caption = caption[:-2]  # drop the trailing ", "

    return caption

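# A doctest-style sketch of the expected output (hypothetical names):
#   >>> make_a_sentence(["dog", "sky-other"], clean=True)
#   'dog, sky'
#   >>> make_a_sentence(["dog", "sky-other"])
#   'dog, sky-other'
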
class LayoutDataset(BaseDataset):
    """
    Note: a similar result can be obtained with cd_dataset.CDDataset,
    since setting prob_real_caption=0 there makes that dataset use only
    detection annotations. However, CDDataset removes boxes rather than
    images.

    In layout2img work, raw images are simply resized to 256*256, so box
    sizes are pre-calculated and min_box_size is applied before
    min/max_boxes_per_image; images that violate these rules are removed
    entirely.

    These two filtering schemes yield different numbers of training/val
    images, so this dataset is intended for layout2img only.
    """
    def __init__(self,
                 image_root,
                 instances_json_path,
                 stuff_json_path,
                 category_embedding_path,
                 fake_caption_type='empty',
                 image_size=256,
                 max_samples=None,
                 min_box_size=0.02,
                 min_boxes_per_image=3,
                 max_boxes_per_image=8,
                 include_other=False,
                 random_flip=True):
        super().__init__(random_crop=None, random_flip=None, image_size=None)

        assert fake_caption_type in ['empty', 'made']
        self.image_root = image_root
        self.instances_json_path = instances_json_path
        self.stuff_json_path = stuff_json_path
        self.category_embedding_path = category_embedding_path
        self.fake_caption_type = fake_caption_type
        self.image_size = image_size
        self.max_samples = max_samples
        self.min_box_size = min_box_size
        self.min_boxes_per_image = min_boxes_per_image
        self.max_boxes_per_image = max_boxes_per_image
        self.include_other = include_other
        self.random_flip = random_flip

        # Resize to a square and rescale pixels from [0, 1] to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1),
        ])

        with open(instances_json_path, 'r') as f:
            instances_data = json.load(f)
        clean_annotations(instances_data["annotations"])
        self.instances_data = instances_data

        with open(stuff_json_path, 'r') as f:
            stuff_data = json.load(f)
        clean_annotations(stuff_data["annotations"])
        self.stuff_data = stuff_data

        # Pre-computed embedding vectors, one per category name.
        self.category_embeddings = torch.load(category_embedding_path)
        self.embedding_len = list(self.category_embeddings.values())[0].shape[0]

        # Map image ids to filenames and sizes. The instances and stuff
        # annotation files must describe the same set of images.
        self.image_ids = []
        self.image_id_to_filename = {}
        self.image_id_to_size = {}
        assert instances_data['images'] == stuff_data["images"]
        for image_data in instances_data['images']:
            image_id = image_data['id']
            filename = image_data['file_name']
            width = image_data['width']
            height = image_data['height']
            self.image_ids.append(image_id)
            self.image_id_to_filename[image_id] = filename
            self.image_id_to_size[image_id] = (width, height)

        # Collect thing/stuff category ids and an id -> name lookup.
        self.things_id_list = []
        self.stuff_id_list = []
        self.object_idx_to_name = {}
        for category_data in instances_data['categories']:
            self.things_id_list.append(category_data['id'])
            self.object_idx_to_name[category_data['id']] = category_data['name']
        for category_data in stuff_data['categories']:
            self.stuff_id_list.append(category_data['id'])
            self.object_idx_to_name[category_data['id']] = category_data['name']
        # Cover ids 0..183; ids without a category map to None.
        self.all_categories = [self.object_idx_to_name.get(k, None) for k in range(183 + 1)]

        # Bucket annotations per image, applying the box-size and
        # "other"-category filters in select_objects.
        self.image_id_to_objects = defaultdict(list)
        self.select_objects(instances_data['annotations'])
        self.select_objects(stuff_data['annotations'])

        # Keep only images whose surviving box count lies within
        # [min_boxes_per_image, max_boxes_per_image].
        new_image_ids = []
        for image_id in self.image_ids:
            num_objs = len(self.image_id_to_objects[image_id])
            if self.min_boxes_per_image <= num_objs <= self.max_boxes_per_image:
                new_image_ids.append(image_id)
        self.image_ids = new_image_ids

        # Sanity-check that every remaining image exists in the zip archive.
        all_filenames = [self.image_id_to_filename[idx] for idx in self.image_ids]
        check_filenames_in_zipdata(all_filenames, image_root)

    def select_objects(self, annotations):
        for object_anno in annotations:
            image_id = object_anno['image_id']
            _, _, w, h = object_anno['bbox']
            W, H = self.image_id_to_size[image_id]
            # min_box_size is a fraction of the image area, not pixels.
            box_area = (w * h) / (W * H)
            box_ok = box_area > self.min_box_size
            object_name = self.object_idx_to_name[object_anno['category_id']]
            other_ok = object_name != 'other' or self.include_other
            if box_ok and other_ok:
                self.image_id_to_objects[image_id].append(object_anno)

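    # For intuition: with min_box_size=0.02, a 50x40 box in a 640x480 image
    # covers 2000 / 307200 ≈ 0.0065 of the frame, so it is filtered out.
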
    def total_images(self):
        return len(self)

    def __getitem__(self, index):
        # Guard against accidentally huge padding tensors.
        if self.max_boxes_per_image > 99:
            assert False, "Are you sure you want this many boxes per image?"

        out = {}

        image_id = self.image_ids[index]
        out['id'] = image_id

        flip = self.random_flip and random.random() < 0.5

        # Read the image from the zip archive.
        filename = self.image_id_to_filename[image_id]
        zip_file = self.fetch_zipfile(self.image_root)
        image = Image.open(BytesIO(zip_file.read(filename))).convert('RGB')
        WW, HH = image.size
        if flip:
            image = ImageOps.mirror(image)
        out["image"] = self.transform(image)

        this_image_obj_annos = deepcopy(self.image_id_to_objects[image_id])

        # Pad boxes/masks/embeddings up to max_boxes_per_image; masks flags
        # which slots hold real objects.
        obj_names = []
        boxes = torch.zeros(self.max_boxes_per_image, 4)
        masks = torch.zeros(self.max_boxes_per_image)
        positive_embeddings = torch.zeros(self.max_boxes_per_image, self.embedding_len)
        for idx, object_anno in enumerate(this_image_obj_annos):
            obj_name = self.object_idx_to_name[object_anno['category_id']]
            obj_names.append(obj_name)
            # Convert xywh pixel coordinates to normalized x0y0x1y1.
            x, y, w, h = object_anno['bbox']
            x0 = x / WW
            y0 = y / HH
            x1 = (x + w) / WW
            y1 = (y + h) / HH
            if flip:
                x0, x1 = 1 - x1, 1 - x0
            boxes[idx] = torch.tensor([x0, y0, x1, y1])
            masks[idx] = 1
            positive_embeddings[idx] = self.category_embeddings[obj_name]

        if self.fake_caption_type == 'empty':
            caption = ""
        else:
            caption = make_a_sentence(obj_names, clean=True)

        out["caption"] = caption
        out["boxes"] = boxes
        out["masks"] = masks
        out["positive_embeddings"] = positive_embeddings

        return out

    def __len__(self):
        if self.max_samples is None:
            return len(self.image_ids)
        return min(len(self.image_ids), self.max_samples)
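

# Minimal usage sketch. The paths below are placeholders, not files shipped
# with this repo; category_embedding_path is assumed to point at a
# torch-saved dict mapping category names to embedding tensors.
if __name__ == "__main__":
    dataset = LayoutDataset(
        image_root="DATA/coco/train2017.zip",                            # hypothetical
        instances_json_path="DATA/coco/annotations/instances_train2017.json",
        stuff_json_path="DATA/coco/annotations/stuff_train2017.json",
        category_embedding_path="DATA/coco/category_embeddings.pth",     # hypothetical
        fake_caption_type='made',
    )
    sample = dataset[0]
    # boxes: (max_boxes_per_image, 4) in normalized x0y0x1y1 coordinates;
    # masks.sum() counts the slots that hold real objects.
    print(sample["caption"], sample["boxes"].shape, int(sample["masks"].sum()))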