diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..baea4fab397c5fc584d20c8406f7d409ddc50016
--- /dev/null
+++ b/README.md
@@ -0,0 +1,21 @@
+
+---
+title: HEAT
+emoji: 📈
+colorFrom: indigo
+colorTo: yellow
+sdk: gradio
+sdk_version: 3.11.0
+app_file: app.py
+pinned: false
+license: apache-2.0
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a312bc5abfd2f432005727494d70979e7c89b86
--- /dev/null
+++ b/app.py
@@ -0,0 +1,33 @@
+'''
+Author: Egrt
+Date: 2022-01-13 13:34:10
+LastEditors: [egrt]
+LastEditTime: 2022-08-15 19:40:32
+FilePath: \MaskGAN\app.py
+'''
+from HEAT import HEAT
+import gradio as gr
+import os
+heat = HEAT()
+
+# --------模型推理---------- #
+def inference(img):
+ image_result = heat.detect_one_image(img)
+ return image_result
+
+# --------网页信息---------- #
+title = "HEAT"
+description = "HEAT: Holistic Edge Attention Transformer for Structured Reconstruction @Luuuu"
+article = "HEAT: Holistic Edge Attention Transformer for Structured Reconstruction | Github Repo
"
+example_img_dir = 'images/'
+example_img_name = os.listdir(example_img_dir)
+examples=[[os.path.join(example_img_dir, image_path)] for image_path in example_img_name if image_path.endswith(('.jpg','.jpeg', '.png'))]
+gr.Interface(
+ inference,
+ [gr.inputs.Image(type="pil", label="Input")],
+ gr.outputs.Image(type="pil", label="Output"),
+ title=title,
+ description=description,
+ article=article,
+ examples=examples
+ ).launch()
diff --git a/arguments.py b/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c4c7f253a4c3a844ca9a9cddce3deba411d27c7
--- /dev/null
+++ b/arguments.py
@@ -0,0 +1,33 @@
+import argparse
+
+
+def get_args_parser():
+ parser = argparse.ArgumentParser('Holistic edge attention transformer', add_help=False)
+ parser.add_argument('--exp_dataset', default='outdoor',
+ help='the dataset for experiments, outdoor/s3d_floorplan')
+ parser.add_argument('--lr', default=2e-4, type=float)
+ parser.add_argument('--batch_size', default=16, type=int)
+ parser.add_argument('--weight_decay', default=1e-5, type=float)
+ parser.add_argument('--epochs', default=800, type=int)
+ parser.add_argument('--lr_drop', default=600, type=int)
+ parser.add_argument('--clip_max_norm', default=0.1, type=float,
+ help='gradient clipping max norm')
+ parser.add_argument('--print_freq', default=40, type=int)
+ parser.add_argument('--output_dir', default='./checkpoints/ckpts_heat_outdoor_256',
+ help='path where to save, empty for no saving')
+ parser.add_argument('--resume', default='',
+ help='resume from checkpoint')
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+ help='start epoch')
+ parser.add_argument('--num_workers', default=4, type=int)
+ parser.add_argument('--image_size', default=256, type=int)
+ parser.add_argument('--max_corner_num', default=150, type=int,
+ help='the max number of corners allowed in the experiments')
+ parser.add_argument('--corner_to_edge_multiplier', default=3, type=int,
+ help='the max number of edges based on the number of corner candidates (assuming the '
+ 'average degree never greater than 6)')
+ parser.add_argument('--lambda_corner', default=0.05, type=float,
+ help='the weight of the corner prediction loss term')
+ parser.add_argument('--run_validation', action='store_true',
+ help='Whether run validation or not, default: False')
+ return parser
diff --git a/assets/img/pipeline.png b/assets/img/pipeline.png
new file mode 100644
index 0000000000000000000000000000000000000000..33fc362e7e89c0857fb021ba025052df56473c15
Binary files /dev/null and b/assets/img/pipeline.png differ
diff --git a/assets/img/problem_description.png b/assets/img/problem_description.png
new file mode 100644
index 0000000000000000000000000000000000000000..15b30e4831a7eead382ef100f4b52c1153ae5383
Binary files /dev/null and b/assets/img/problem_description.png differ
diff --git a/datasets/__init__.py b/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/datasets/corners.py b/datasets/corners.py
new file mode 100644
index 0000000000000000000000000000000000000000..b27288021d6b43502fcdfd29b0cbc601646aa85c
--- /dev/null
+++ b/datasets/corners.py
@@ -0,0 +1,183 @@
+import numpy as np
+from torch.utils.data import Dataset
+from scipy.ndimage import gaussian_filter
+import cv2
+
+mean = [0.485, 0.456, 0.406]
+std = [0.229, 0.224, 0.225]
+
+
+class CornersDataset(Dataset):
+ def __init__(self, image_size=256, inference=False):
+ super(CornersDataset, self).__init__()
+ self.image_size = image_size
+ self.inference = inference
+ self._data_names = []
+
+ def __len__(self):
+ return len(self._data_names)
+
+ def __getitem__(self, idx):
+ raise NotImplementedError
+
+ def process_data(self, data):
+ img = data['image']
+ corners = data['corners']
+ annot = data['annot']
+
+ # pre-process the image to use ImageNet-pretrained backbones
+ img = img.transpose((2, 0, 1))
+ raw_img = img.copy()
+ img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
+ img = img.astype(np.float32)
+
+ corners = np.array(corners)
+
+ all_data = {
+ "annot": annot,
+ "name": data['name'],
+ 'img': img,
+ 'annot_path': data['annot_path'],
+ 'img_path': data['img_path'],
+ 'det_path': data['det_path'],
+ 'raw_img': raw_img,
+ }
+
+ # corner labels for training
+ if not self.inference:
+ pixel_labels, gauss_labels = self.get_corner_labels(corners)
+ all_data['pixel_labels'] = pixel_labels
+ all_data['gauss_labels'] = gauss_labels
+
+ return all_data
+
+ def get_corner_labels(self, corners):
+ labels = np.zeros((self.image_size, self.image_size))
+ corners = corners.round()
+ xint, yint = corners[:, 0].astype(int), corners[:, 1].astype(int)
+ labels[yint, xint] = 1
+
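+ # smooth the one-hot corner map into a Gaussian heat map, normalized so its peak value is 1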
+ gauss_labels = gaussian_filter(labels, sigma=2)
+ gauss_labels = gauss_labels / gauss_labels.max()
+ return labels, gauss_labels
+
+ def resize_data(self, image, annot, det_corners):
+ new_image = cv2.resize(image, (self.image_size, self.image_size))
+ new_annot = {}
+ r = self.image_size / 256
+ for c, connections in annot.items():
+ new_c = tuple(np.array(c) * r)
+ new_connections = [other_c * r for other_c in connections]
+ new_annot[new_c] = new_connections
+ new_dets = det_corners * r
+ return new_image, new_annot, new_dets
+
+ def random_aug_annot(self, img, annot, det_corners=None):
+ # do random flipping
+ img, annot, det_corners = self.random_flip(img, annot, det_corners)
+
+ # prepare random augmentation parameters (only do random rotation for now)
+ theta = np.random.randint(0, 360) / 360 * np.pi * 2
+ r = self.image_size / 256
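+ # the rotation is expressed as an affine transform estimated from three point correspondences:
+ # the image center plus one reference point on each axis, paired with their rotated positions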
+ origin = [127 * r, 127 * r]
+ p1_new = [127 * r + 100 * np.sin(theta) * r, 127 * r - 100 * np.cos(theta) * r]
+ p2_new = [127 * r + 100 * np.cos(theta) * r, 127 * r + 100 * np.sin(theta) * r]
+ p1_old = [127 * r, 127 * r - 100 * r] # y_axis
+ p2_old = [127 * r + 100 * r, 127 * r] # x_axis
+ pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32)
+ pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32)
+ M_rot = cv2.getAffineTransform(pts1, pts2)
+
+ # Combine annotation corners and detection corners
+ all_corners = list(annot.keys())
+ if det_corners is not None:
+ for i in range(det_corners.shape[0]):
+ all_corners.append(tuple(det_corners[i]))
+ all_corners_ = np.array(all_corners)
+
+ # Do the corner transform within a big matrix transformation
+ corner_mapping = dict()
+ ones = np.ones([all_corners_.shape[0], 1])
+ all_corners_ = np.concatenate([all_corners_, ones], axis=-1)
+ aug_corners = np.matmul(M_rot, all_corners_.T).T
+
+ for idx, corner in enumerate(all_corners):
+ corner_mapping[corner] = aug_corners[idx]
+
+ # If the transformed geometry goes beyond image boundary, we simply re-do the augmentation
+ new_corners = np.array(list(corner_mapping.values()))
+ if new_corners.min() <= 0 or new_corners.max() >= (self.image_size - 1):
+ # return self.random_aug_annot(img, annot, det_corners)
+ return img, annot, None, det_corners
+
+ # build the new annot dict
+ aug_annot = dict()
+ for corner, connections in annot.items():
+ new_corner = corner_mapping[corner]
+ tuple_new_corner = tuple(new_corner)
+ aug_annot[tuple_new_corner] = list()
+ for to_corner in connections:
+ aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])
+
+ # Also transform the image correspondingly
+ rows, cols, ch = img.shape
+ new_img = cv2.warpAffine(img, M_rot, (cols, rows), borderValue=(255, 255, 255))
+
+ y_start = (new_img.shape[0] - self.image_size) // 2
+ x_start = (new_img.shape[1] - self.image_size) // 2
+ aug_img = new_img[y_start:y_start + self.image_size, x_start:x_start + self.image_size, :]
+
+ if det_corners is None:
+ return aug_img, aug_annot, corner_mapping, None
+ else:
+ aug_det_corners = list()
+ for corner in det_corners:
+ new_corner = corner_mapping[tuple(corner)]
+ aug_det_corners.append(new_corner)
+ aug_det_corners = np.array(aug_det_corners)
+ return aug_img, aug_annot, corner_mapping, aug_det_corners
+
+ def random_flip(self, img, annot, det_corners):
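+ # randomly apply one of: no flip, horizontal flip, vertical flip, or both,
+ # and remap all corner coordinates accordingly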
+ height, width, _ = img.shape
+ rand_int = np.random.randint(0, 4)
+ if rand_int == 0:
+ return img, annot, det_corners
+
+ all_corners = list(annot.keys())
+ if det_corners is not None:
+ for i in range(det_corners.shape[0]):
+ all_corners.append(tuple(det_corners[i]))
+ new_corners = np.array(all_corners)
+
+ if rand_int == 1:
+ img = img[:, ::-1, :]
+ new_corners[:, 0] = width - new_corners[:, 0]
+ elif rand_int == 2:
+ img = img[::-1, :, :]
+ new_corners[:, 1] = height - new_corners[:, 1]
+ else:
+ img = img[::-1, ::-1, :]
+ new_corners[:, 0] = width - new_corners[:, 0]
+ new_corners[:, 1] = height - new_corners[:, 1]
+
+ new_corners = np.clip(new_corners, 0, self.image_size - 1) # clip into [0, image_size - 1]
+ corner_mapping = dict()
+ for idx, corner in enumerate(all_corners):
+ corner_mapping[corner] = new_corners[idx]
+
+ aug_annot = dict()
+ for corner, connections in annot.items():
+ new_corner = corner_mapping[corner]
+ tuple_new_corner = tuple(new_corner)
+ aug_annot[tuple_new_corner] = list()
+ for to_corner in connections:
+ aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])
+
+ if det_corners is not None:
+ aug_det_corners = list()
+ for corner in det_corners:
+ new_corner = corner_mapping[tuple(corner)]
+ aug_det_corners.append(new_corner)
+ det_corners = np.array(aug_det_corners)
+
+ return img, aug_annot, det_corners
diff --git a/datasets/data_utils.py b/datasets/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2777cc59a8ced44bc86ae1415a7fe1c6ec415be
--- /dev/null
+++ b/datasets/data_utils.py
@@ -0,0 +1,57 @@
+from PIL import ImageFilter
+from torchvision import transforms
+import numpy as np
+from utils.nn_utils import positional_encoding_2d
+from torch.utils.data.dataloader import default_collate
+
+
+def RandomBlur(radius=2.):
+ blur = GaussianBlur(radius=radius)
+ full_transform = transforms.RandomApply([blur], p=.3)
+ return full_transform
+
+
+class ImageFilterTransform(object):
+
+ def __init__(self):
+ raise NotImplementedError
+
+ def __call__(self, img):
+ return img.filter(self.filter)
+
+
+class GaussianBlur(ImageFilterTransform):
+
+ def __init__(self, radius=2.):
+ self.filter = ImageFilter.GaussianBlur(radius=radius)
+
+
+def collate_fn(data):
+ batched_data = {}
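+ # fields with a variable size per sample (annot, rec_mat) are kept as python lists;
+ # everything else is stacked by the default collate, with float casting for the dense label maps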
+ for field in data[0].keys():
+ if field in ['annot', 'rec_mat']:
+ batch_values = [item[field] for item in data]
+ else:
+ batch_values = default_collate([d[field] for d in data])
+ if field in ['pixel_features', 'pixel_labels', 'gauss_labels']:
+ batch_values = batch_values.float()
+ batched_data[field] = batch_values
+
+ return batched_data
+
+
+def get_pixel_features(image_size, d_pe=128):
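+ # returns the per-pixel (x, y) coordinate grid and a d_pe-dim 2D positional encoding
+ # for every pixel, shaped [image_size, image_size, d_pe]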
+ all_pe = positional_encoding_2d(d_pe, image_size, image_size)
+ pixels_x = np.arange(0, image_size)
+ pixels_y = np.arange(0, image_size)
+
+ xv, yv = np.meshgrid(pixels_x, pixels_y)
+ all_pixels = list()
+ for i in range(xv.shape[0]):
+ pixs = np.stack([xv[i], yv[i]], axis=-1)
+ all_pixels.append(pixs)
+ pixels = np.stack(all_pixels, axis=0)
+
+ pixel_features = all_pe[:, pixels[:, :, 1], pixels[:, :, 0]]
+ pixel_features = pixel_features.permute(1, 2, 0)
+ return pixels, pixel_features
\ No newline at end of file
diff --git a/datasets/outdoor_buildings.py b/datasets/outdoor_buildings.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a39e018c8edc3e8587cdb856415454ee2558931
--- /dev/null
+++ b/datasets/outdoor_buildings.py
@@ -0,0 +1,183 @@
+import numpy as np
+from datasets.corners import CornersDataset
+import os
+import skimage
+import cv2
+from torchvision import transforms
+from PIL import Image
+from datasets.data_utils import RandomBlur
+
+class OutdoorBuildingDataset(CornersDataset):
+ def __init__(self, data_path, det_path, phase='train', image_size=256, rand_aug=True,
+ inference=False):
+ super(OutdoorBuildingDataset, self).__init__(image_size, inference)
+ self.data_path = data_path
+ self.det_path = det_path
+ self.phase = phase
+ self.rand_aug = rand_aug
+ self.image_size = image_size
+ self.inference = inference
+
+ blur_transform = RandomBlur()
+ self.train_transform = transforms.Compose([
+ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
+ transforms.RandomGrayscale(p=0.3),
+ blur_transform])
+
+ if phase == 'train':
+ datalistfile = os.path.join(data_path, 'train_list.txt')
+ self.training = True
+ else:
+ datalistfile = os.path.join(data_path, 'valid_list.txt')
+ self.training = False
+ with open(datalistfile, 'r') as f:
+ _data_names = f.readlines()
+ if phase == 'train':
+ self._data_names = _data_names
+ else:
+ # based on the data split rule from previous works
+ if phase == 'valid':
+ self._data_names = _data_names[:50]
+ elif phase == 'test':
+ self._data_names = _data_names[50:]
+ else:
+ raise ValueError('Invalid phase {}'.format(phase))
+
+ def __len__(self):
+ return len(self._data_names)
+
+ def __getitem__(self, idx):
+ data_name = self._data_names[idx][:-1]
+ annot_path = os.path.join(self.data_path, 'annot', data_name + '.npy')
+ annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()
+ det_path = os.path.join(self.det_path, data_name + '.npy')
+ det_corners = np.array(np.load(det_path, allow_pickle=True)) # [N, 2]
+ det_corners = det_corners[:, ::-1] # turn into x,y format
+
+ img_path = os.path.join(self.data_path, 'rgb', data_name + '.jpg')
+ rgb = cv2.imread(img_path)
+
+ if self.image_size != 256:
+ rgb, annot, det_corners = self.resize_data(rgb, annot, det_corners)
+
+ if self.rand_aug:
+ image, annot, corner_mapping, det_corners = self.random_aug_annot(rgb, annot, det_corners=det_corners)
+ else:
+ image = rgb
+ rec_mat = None
+
+ corners = np.array(list(annot.keys()))[:, [1, 0]]
+
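+ # outside inference, skip overly crowded samples (> 100 corners) by drawing another random index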
+ if not self.inference and len(corners) > 100:
+ new_idx = np.random.randint(0, len(self))
+ return self.__getitem__(new_idx)
+
+ if self.training:
+ # Add some randomness for g.t. corners (the std is 0 here, so the jitter is effectively disabled)
+ corners += np.random.normal(0, 0, size=corners.shape)
+ pil_img = Image.fromarray(image)
+ image = self.train_transform(pil_img)
+ image = np.array(image)
+ image = skimage.img_as_float(image)
+
+ # sort by the second value and then the first value, here the corners are in the format of (y, x)
+ sort_idx = np.lexsort(corners.T)
+ corners = corners[sort_idx]
+
+ corner_list = []
+ for corner_i in range(corners.shape[0]):
+ corner_list.append((corners[corner_i][1], corners[corner_i][0])) # to (x, y) format
+
+ raw_data = {
+ 'name': data_name,
+ 'corners': corner_list,
+ 'annot': annot,
+ 'image': image,
+ 'rec_mat': rec_mat,
+ 'annot_path': annot_path,
+ 'det_path': det_path,
+ 'img_path': img_path,
+ }
+
+ return self.process_data(raw_data)
+
+ def random_aug_annot(self, img, annot, det_corners=None):
+ # do random flipping
+ img, annot, det_corners = self.random_flip(img, annot, det_corners)
+
+ # prepare random augmentation parameters (only do random rotation for now)
+ theta = np.random.randint(0, 360) / 360 * np.pi * 2
+ r = self.image_size / 256
+ origin = [127 * r, 127 * r]
+ p1_new = [127 * r + 100 * np.sin(theta) * r, 127 * r - 100 * np.cos(theta) * r]
+ p2_new = [127 * r + 100 * np.cos(theta) * r, 127 * r + 100 * np.sin(theta) * r]
+ p1_old = [127 * r, 127 * r - 100 * r] # y_axis
+ p2_old = [127 * r + 100 * r, 127 * r] # x_axis
+ pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32)
+ pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32)
+ M_rot = cv2.getAffineTransform(pts1, pts2)
+
+ # Combine annotation corners and detection corners
+ all_corners = list(annot.keys())
+ if det_corners is not None:
+ for i in range(det_corners.shape[0]):
+ all_corners.append(tuple(det_corners[i]))
+ all_corners_ = np.array(all_corners)
+
+ # Do the corner transform within a big matrix transformation
+ corner_mapping = dict()
+ ones = np.ones([all_corners_.shape[0], 1])
+ all_corners_ = np.concatenate([all_corners_, ones], axis=-1)
+ aug_corners = np.matmul(M_rot, all_corners_.T).T
+
+ for idx, corner in enumerate(all_corners):
+ corner_mapping[corner] = aug_corners[idx]
+
+ # If the transformed geometry goes beyond image boundary, we simply re-do the augmentation
+ new_corners = np.array(list(corner_mapping.values()))
+ if new_corners.min() <= 0 or new_corners.max() >= (self.image_size - 1):
+ # return self.random_aug_annot(img, annot, det_corners)
+ return img, annot, None, det_corners
+
+ # build the new annot dict
+ aug_annot = dict()
+ for corner, connections in annot.items():
+ new_corner = corner_mapping[corner]
+ tuple_new_corner = tuple(new_corner)
+ aug_annot[tuple_new_corner] = list()
+ for to_corner in connections:
+ aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])
+
+ # Also transform the image correspondingly
+ rows, cols, ch = img.shape
+ new_img = cv2.warpAffine(img, M_rot, (cols, rows), borderValue=(255, 255, 255))
+
+ y_start = (new_img.shape[0] - self.image_size) // 2
+ x_start = (new_img.shape[1] - self.image_size) // 2
+ aug_img = new_img[y_start:y_start + self.image_size, x_start:x_start + self.image_size, :]
+
+ if det_corners is None:
+ return aug_img, aug_annot, corner_mapping, None
+ else:
+ aug_det_corners = list()
+ for corner in det_corners:
+ new_corner = corner_mapping[tuple(corner)]
+ aug_det_corners.append(new_corner)
+ aug_det_corners = np.array(aug_det_corners)
+ return aug_img, aug_annot, corner_mapping, aug_det_corners
+
+
+
+if __name__ == '__main__':
+ from torch.utils.data import DataLoader
+ from datasets.data_utils import collate_fn
+
+ DATAPATH = './data/cities_dataset'
+ DET_PATH = './data/det_final'
+ train_dataset = OutdoorBuildingDataset(DATAPATH, DET_PATH, phase='train')
+ train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0,
+ collate_fn=collate_fn)
+ for i, item in enumerate(train_dataloader):
+ import pdb; pdb.set_trace()
+ print(item)
diff --git a/datasets/s3d_floorplans.py b/datasets/s3d_floorplans.py
new file mode 100644
index 0000000000000000000000000000000000000000..a931cbcb4dc946fe493a37fc873579069d1dcc5d
--- /dev/null
+++ b/datasets/s3d_floorplans.py
@@ -0,0 +1,187 @@
+import numpy as np
+from datasets.corners import CornersDataset
+import os
+import skimage
+import cv2
+import itertools
+
+
+mean = [0.485, 0.456, 0.406]
+std = [0.229, 0.224, 0.225]
+
+all_combibations = dict()
+for length in range(2, 351):
+ ids = np.arange(length)
+ combs = np.array(list(itertools.combinations(ids, 2)))
+ all_combibations[length] = combs
+
+
+class S3DFloorplanDataset(CornersDataset):
+ def __init__(self, data_path, phase='train', image_size=256, rand_aug=True, inference=False):
+ super(S3DFloorplanDataset, self).__init__(image_size, inference)
+ self.data_path = data_path
+ self.phase = phase
+ self.rand_aug = rand_aug
+
+ if phase == 'train':
+ datalistfile = os.path.join(data_path, 'train_list.txt')
+ self.training = True
+ elif phase == 'valid':
+ datalistfile = os.path.join(data_path, 'valid_list.txt')
+ self.training = False
+ else:
+ datalistfile = os.path.join(data_path, 'test_list.txt')
+ self.training = False
+ with open(datalistfile, 'r') as f:
+ self._data_names = f.readlines()
+
+ def __len__(self):
+ return len(self._data_names)
+
+ def __getitem__(self, idx):
+ data_name = self._data_names[idx][:-1]
+ annot_path = os.path.join(self.data_path, 'annot', data_name + '.npy')
+ annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()
+
+ density_path = os.path.join(self.data_path, 'density', data_name + '.png')
+ normal_path = os.path.join(self.data_path, 'normals', data_name + '.png')
+
+ density = cv2.imread(density_path)
+ normal = cv2.imread(normal_path)
+ rgb = np.maximum(density, normal)
+
+ if self.image_size != 256:
+ rgb, annot, det_corners = self.resize_data(rgb, annot, None)
+
+ if self.rand_aug:
+ image, annot, _ = self.random_aug_annot(rgb, annot, det_corners=None)
+ else:
+ image = rgb
+ rec_mat = None
+
+ corners = np.array(list(annot.keys()))[:, [1, 0]]
+
+ if not self.inference and len(corners) > 150:
+ new_idx = np.random.randint(0, len(self))
+ return self.__getitem__(new_idx)
+
+ if self.training:
+ # Add some randomness for g.t. corners (the std is 0 here, so the jitter is effectively disabled)
+ corners += np.random.normal(0, 0, size=corners.shape)
+
+ image = skimage.img_as_float(image)
+
+ # sort by the second value and then the first value, here the corners are in the format of (y, x)
+ sort_idx = np.lexsort(corners.T)
+ corners = corners[sort_idx]
+
+ corner_list = []
+ for corner_i in range(corners.shape[0]):
+ corner_list.append((corners[corner_i][1], corners[corner_i][0])) # to (x, y) format
+
+ raw_data = {
+ 'name': data_name,
+ 'corners': corner_list,
+ 'annot': annot,
+ 'image': image,
+ 'rec_mat': rec_mat,
+ 'annot_path': annot_path,
+ 'img_path': density_path,
+ }
+
+ return self.process_data(raw_data)
+
+ def process_data(self, data):
+ img = data['image']
+ corners = data['corners']
+ annot = data['annot']
+
+ # pre-process the image to use ImageNet-pretrained backbones
+ img = img.transpose((2, 0, 1))
+ raw_img = img.copy()
+ img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
+ img = img.astype(np.float32)
+
+ corners = np.array(corners)
+
+ all_data = {
+ "annot": annot,
+ "name": data['name'],
+ 'img': img,
+ 'annot_path': data['annot_path'],
+ 'img_path': data['img_path'],
+ 'raw_img': raw_img,
+ }
+
+ # corner labels
+ if not self.inference:
+ pixel_labels, gauss_labels = self.get_corner_labels(corners)
+ all_data['pixel_labels'] = pixel_labels
+ all_data['gauss_labels'] = gauss_labels
+
+ return all_data
+
+ def random_aug_annot(self, img, annot, det_corners=None):
+ # do random flipping
+ img, annot, det_corners = self.random_flip(img, annot, det_corners)
+ # return img, annot, None
+
+ # prepare random augmentation parameters (only do random rotation for now)
+ theta = np.random.randint(0, 360) / 360 * np.pi * 2
+ r = self.image_size / 256
+ origin = [127 * r, 127 * r]
+ p1_new = [127 * r + 100 * np.sin(theta) * r, 127 * r - 100 * np.cos(theta) * r]
+ p2_new = [127 * r + 100 * np.cos(theta) * r, 127 * r + 100 * np.sin(theta) * r]
+ p1_old = [127 * r, 127 * r - 100 * r] # y_axis
+ p2_old = [127 * r + 100 * r, 127 * r] # x_axis
+ pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32)
+ pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32)
+ M_rot = cv2.getAffineTransform(pts1, pts2)
+
+ # Combine annotation corners and detection corners
+ all_corners = list(annot.keys())
+ if det_corners is not None:
+ for i in range(det_corners.shape[0]):
+ all_corners.append(tuple(det_corners[i]))
+ all_corners_ = np.array(all_corners)
+
+ # Do the per-corner transform
+ # Done in a big matrix transformation to save processing time.
+ corner_mapping = dict()
+ ones = np.ones([all_corners_.shape[0], 1])
+ all_corners_ = np.concatenate([all_corners_, ones], axis=-1)
+ aug_corners = np.matmul(M_rot, all_corners_.T).T
+
+ for idx, corner in enumerate(all_corners):
+ corner_mapping[corner] = aug_corners[idx]
+
+ # If the transformed geometry goes beyond image boundary, we simply re-do the augmentation
+ new_corners = np.array(list(corner_mapping.values()))
+ if new_corners.min() <= 0 or new_corners.max() >= (self.image_size - 1):
+ # return self.random_aug_annot(img, annot, det_corners)
+ return img, annot, None
+
+ # build the new annot dict
+ aug_annot = dict()
+ for corner, connections in annot.items():
+ new_corner = corner_mapping[corner]
+ tuple_new_corner = tuple(new_corner)
+ aug_annot[tuple_new_corner] = list()
+ for to_corner in connections:
+ aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])
+
+ # Also transform the image correspondingly
+ rows, cols, ch = img.shape
+ new_img = cv2.warpAffine(img, M_rot, (cols, rows), borderValue=(255, 255, 255))
+
+ y_start = (new_img.shape[0] - self.image_size) // 2
+ x_start = (new_img.shape[1] - self.image_size) // 2
+ aug_img = new_img[y_start:y_start + self.image_size, x_start:x_start + self.image_size, :]
+
+ return aug_img, aug_annot, None
+
+
+
+
+
+
diff --git a/images/test.jpg b/images/test.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a299d4d4a9763bf7b0748fcb0629509cfdb0d626
Binary files /dev/null and b/images/test.jpg differ
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d213a9889f374377eabdc94caebc7c84738c2ea
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,455 @@
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from datasets.outdoor_buildings import OutdoorBuildingDataset
+from datasets.s3d_floorplans import S3DFloorplanDataset
+from datasets.data_utils import collate_fn, get_pixel_features
+from models.resnet import ResNetBackbone
+from models.corner_models import HeatCorner
+from models.edge_models import HeatEdge
+from models.corner_to_edge import get_infer_edge_pairs
+from utils.geometry_utils import corner_eval
+import numpy as np
+import cv2
+import os
+import scipy.ndimage as filters
+import matplotlib.pyplot as plt
+from metrics.get_metric import compute_metrics, get_recall_and_precision
+import skimage
+import argparse
+
+
+def visualize_cond_generation(positive_pixels, confs, image, save_path, gt_corners=None, prec=None, recall=None,
+ image_masks=None, edges=None, edge_confs=None):
+ image = image.copy() # get a new copy of the original image
+ if confs is not None:
+ viz_confs = confs
+
+ if edges is not None:
+ preds = positive_pixels.astype(int)
+ c_degrees = dict()
+ for edge_i, edge_pair in enumerate(edges):
+ conf = (edge_confs[edge_i] * 2) - 1
+ cv2.line(image, tuple(preds[edge_pair[0]]), tuple(preds[edge_pair[1]]), (255 * conf, 255 * conf, 0), 2)
+ c_degrees[edge_pair[0]] = c_degrees.setdefault(edge_pair[0], 0) + 1
+ c_degrees[edge_pair[1]] = c_degrees.setdefault(edge_pair[1], 0) + 1
+
+ for idx, c in enumerate(positive_pixels):
+ if edges is not None and idx not in c_degrees:
+ continue
+ if confs is None:
+ cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
+ else:
+ cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255 * viz_confs[idx]), -1)
+ # if edges is not None:
+ # cv2.putText(image, '{}'.format(c_degrees[idx]), (int(c[0]), int(c[1] - 5)), cv2.FONT_HERSHEY_SIMPLEX,
+ # 0.5, (255, 0, 0), 1, cv2.LINE_AA)
+
+ if gt_corners is not None:
+ for c in gt_corners:
+ cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 255, 0), -1)
+
+ if image_masks is not None:
+ mask_ids = np.where(image_masks == 1)[0]
+ for mask_id in mask_ids:
+ y_idx = mask_id // 64
+ x_idx = (mask_id - y_idx * 64)
+ x_coord = x_idx * 4
+ y_coord = y_idx * 4
+ cv2.rectangle(image, (x_coord, y_coord), (x_coord + 3, y_coord + 3), (127, 127, 0), thickness=-1)
+
+ # if confs is not None:
+ # cv2.putText(image, 'max conf: {:.3f}'.format(confs.max()), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
+ # 0.5, (255, 255, 0), 1, cv2.LINE_AA)
+ if prec is not None:
+ if isinstance(prec, tuple):
+ cv2.putText(image, 'edge p={:.2f}, edge r={:.2f}'.format(prec[0], recall[0]), (20, 20),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.5, (255, 255, 0), 1, cv2.LINE_AA)
+ cv2.putText(image, 'region p={:.2f}, region r={:.2f}'.format(prec[1], recall[1]), (20, 40),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.5, (255, 255, 0), 1, cv2.LINE_AA)
+ else:
+ cv2.putText(image, 'prec={:.2f}, recall={:.2f}'.format(prec, recall), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
+ 0.5, (255, 255, 0), 1, cv2.LINE_AA)
+ cv2.imwrite(save_path, image)
+
+
+def corner_nms(preds, confs, image_size):
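+ # non-maximum suppression on the corner heat map: keep only local maxima within a 5x5 neighborhood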
+ data = np.zeros([image_size, image_size])
+ neighborhood_size = 5
+ threshold = 0
+
+ for i in range(len(preds)):
+ data[preds[i, 1], preds[i, 0]] = confs[i]
+
+ data_max = filters.maximum_filter(data, neighborhood_size)
+ maxima = (data == data_max)
+ data_min = filters.minimum_filter(data, neighborhood_size)
+ diff = ((data_max - data_min) > threshold)
+ maxima[diff == 0] = 0
+
+ results = np.where(maxima > 0)
+ filtered_preds = np.stack([results[1], results[0]], axis=-1)
+
+ new_confs = list()
+ for i, pred in enumerate(filtered_preds):
+ new_confs.append(data[pred[1], pred[0]])
+ new_confs = np.array(new_confs)
+
+ return filtered_preds, new_confs
+
+
+def main(dataset, ckpt_path, image_size, viz_base, save_base, infer_times):
+ ckpt = torch.load(ckpt_path)
+ print('Load from ckpts of epoch {}'.format(ckpt['epoch']))
+ ckpt_args = ckpt['args']
+ if dataset == 'outdoor':
+ data_path = './data/outdoor/cities_dataset'
+ det_path = './data/outdoor/det_final'
+ test_dataset = OutdoorBuildingDataset(data_path, det_path, phase='test', image_size=image_size, rand_aug=False,
+ inference=True)
+ elif dataset == 's3d_floorplan':
+ data_path = './data/s3d_floorplan'
+ test_dataset = S3DFloorplanDataset(data_path, phase='test', rand_aug=False, inference=True)
+ else:
+ raise ValueError('Unknown dataset type: {}'.format(dataset))
+
+ test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0,
+ collate_fn=collate_fn)
+
+ backbone = ResNetBackbone()
+ strides = backbone.strides
+ num_channels = backbone.num_channels
+ backbone = nn.DataParallel(backbone)
+ backbone = backbone.cuda()
+ backbone.eval()
+ corner_model = HeatCorner(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
+ backbone_num_channels=num_channels)
+ corner_model = nn.DataParallel(corner_model)
+ corner_model = corner_model.cuda()
+ corner_model.eval()
+
+ edge_model = HeatEdge(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
+ backbone_num_channels=num_channels)
+ edge_model = nn.DataParallel(edge_model)
+ edge_model = edge_model.cuda()
+ edge_model.eval()
+
+ backbone.load_state_dict(ckpt['backbone'])
+ corner_model.load_state_dict(ckpt['corner_model'])
+ edge_model.load_state_dict(ckpt['edge_model'])
+ print('Loaded saved model from {}'.format(ckpt_path))
+
+ if not os.path.exists(viz_base):
+ os.makedirs(viz_base)
+ if not os.path.exists(save_base):
+ os.makedirs(save_base)
+
+ all_prec = list()
+ all_recall = list()
+
+ corner_tp = 0.0
+ corner_fp = 0.0
+ corner_length = 0.0
+ edge_tp = 0.0
+ edge_fp = 0.0
+ edge_length = 0.0
+ region_tp = 0.0
+ region_fp = 0.0
+ region_length = 0.0
+
+ # get the positional encodings for all pixels
+ pixels, pixel_features = get_pixel_features(image_size=image_size)
+
+ for data_i, data in enumerate(test_dataloader):
+ image = data['img'].cuda()
+ img_path = data['img_path'][0]
+ annot_path = data['annot_path'][0]
+ annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()
+
+ with torch.no_grad():
+ pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np = get_results(image, annot, backbone,
+ corner_model,
+ edge_model,
+ pixels, pixel_features,
+ ckpt_args, infer_times,
+ corner_thresh=0.01,
+ image_size=image_size)
+
+ # viz_image = cv2.imread(img_path)
+ positive_pixels = np.array(list(annot.keys())).round()
+
+ viz_image = data['raw_img'][0].cpu().numpy().transpose(1, 2, 0)
+ viz_image = (viz_image * 255).astype(np.uint8)
+
+ # visualize G.T.
+ gt_path = os.path.join(viz_base, '{}_gt.png'.format(data_i))
+ visualize_cond_generation(positive_pixels, None, viz_image, gt_path, gt_corners=None, image_masks=None)
+
+ if len(pred_corners) > 0:
+ prec, recall = corner_eval(positive_pixels, pred_corners)
+ else:
+ prec = recall = 0
+ all_prec.append(prec)
+ all_recall.append(recall)
+
+ if pred_confs.shape[0] == 0:
+ pred_confs = None
+
+ if image_size != 256:
+ pred_corners_viz = pred_corners * (image_size / 256)
+ else:
+ pred_corners_viz = pred_corners
+ recon_path = os.path.join(viz_base, '{}_pred_corner.png'.format(data_i))
+ visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, recon_path, gt_corners=None, prec=prec,
+ recall=recall)
+
+ pred_corners, pred_confs, pos_edges = postprocess_preds(pred_corners, pred_confs, pos_edges)
+
+ pred_data = {
+ 'corners': pred_corners,
+ 'edges': pos_edges,
+ }
+
+ if dataset == 's3d_floorplan':
+ save_filename = os.path.basename(annot_path)
+ save_npy_path = os.path.join(save_base, save_filename)
+ np.save(save_npy_path, pred_data)
+ else:
+ save_results = {
+ 'corners': pred_corners,
+ 'edges': pos_edges,
+ 'image_path': img_path,
+ }
+ save_path = os.path.join(save_base, '{}_results.npy'.format(data_i))
+ np.save(save_path, save_results)
+
+ gt_data = convert_annot(annot)
+
+ score = compute_metrics(gt_data, pred_data)
+
+ edge_recall, edge_prec = get_recall_and_precision(score['edge_tp'], score['edge_fp'], score['edge_length'])
+ region_recall, region_prec = get_recall_and_precision(score['region_tp'], score['region_fp'],
+ score['region_length'])
+ er_recall = (edge_recall, region_recall)
+ er_prec = (edge_prec, region_prec)
+
+ if image_size != 256:
+ pred_corners_viz = pred_corners * (image_size / 256)
+ else:
+ pred_corners_viz = pred_corners
+ recon_path = os.path.join(viz_base, '{}_pred_edge.png'.format(data_i))
+ visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, recon_path, gt_corners=None, prec=er_prec,
+ recall=er_recall, edges=pos_edges, edge_confs=edge_confs)
+ corner_tp += score['corner_tp']
+ corner_fp += score['corner_fp']
+ corner_length += score['corner_length']
+ edge_tp += score['edge_tp']
+ edge_fp += score['edge_fp']
+ edge_length += score['edge_length']
+ region_tp += score['region_tp']
+ region_fp += score['region_fp']
+ region_length += score['region_length']
+
+ print('Finish inference for sample No.{}'.format(data_i))
+ avg_prec = np.array(all_prec).mean()
+ avg_recall = np.array(all_recall).mean()
+
+ recall, precision = get_recall_and_precision(corner_tp, corner_fp, corner_length)
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
+ print('corners - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
+
+ # edge
+ recall, precision = get_recall_and_precision(edge_tp, edge_fp, edge_length)
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
+ print('edges - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
+
+ # region
+ recall, precision = get_recall_and_precision(region_tp, region_fp, region_length)
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
+ print('regions - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
+
+ print('Avg prec: {}, Avg recall: {}'.format(avg_prec, avg_recall))
+
+
+def get_results(image, annot, backbone, corner_model, edge_model, pixels, pixel_features,
+ args, infer_times, corner_thresh=0.5, image_size=256):
+ image_feats, feat_mask, all_image_feats = backbone(image)
+ pixel_features = pixel_features.unsqueeze(0).repeat(image.shape[0], 1, 1, 1)
+ preds_s1 = corner_model(image_feats, feat_mask, pixel_features, pixels, all_image_feats)
+
+ c_outputs = preds_s1
+ # get predicted corners
+ c_outputs_np = c_outputs[0].detach().cpu().numpy()
+ pos_indices = np.where(c_outputs_np >= corner_thresh)
+ pred_corners = pixels[pos_indices]
+ pred_confs = c_outputs_np[pos_indices]
+ pred_corners, pred_confs = corner_nms(pred_corners, pred_confs, image_size=c_outputs.shape[1])
+
+ pred_corners, pred_confs, edge_coords, edge_mask, edge_ids = get_infer_edge_pairs(pred_corners, pred_confs)
+
+ corner_nums = torch.tensor([len(pred_corners)]).to(image.device)
+ max_candidates = torch.stack([corner_nums.max() * args.corner_to_edge_multiplier] * len(corner_nums), dim=0)
+
+ all_pos_ids = set()
+ all_edge_confs = dict()
+
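+ # iterative inference: gt_values marks each candidate edge as positive (1), negative (0) or
+ # undecided (2); confident predictions from one pass are fed back as labels for the next pass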
+ for tt in range(infer_times):
+ if tt == 0:
+ gt_values = torch.zeros_like(edge_mask).long()
+ gt_values[:, :] = 2
+
+ # run the edge model
+ s1_logits, s2_logits_hb, s2_logits_rel, selected_ids, s2_mask, s2_gt_values = edge_model(image_feats, feat_mask,
+ pixel_features,
+ edge_coords, edge_mask,
+ gt_values, corner_nums,
+ max_candidates,
+ True)
+ # do_inference=True)
+
+ num_total = s1_logits.shape[2]
+ num_selected = selected_ids.shape[1]
+ num_filtered = num_total - num_selected
+
+ s1_preds = s1_logits.squeeze().softmax(0)
+ s2_preds_rel = s2_logits_rel.squeeze().softmax(0)
+ s2_preds_hb = s2_logits_hb.squeeze().softmax(0)
+ s1_preds_np = s1_preds[1, :].detach().cpu().numpy()
+ s2_preds_rel_np = s2_preds_rel[1, :].detach().cpu().numpy()
+ s2_preds_hb_np = s2_preds_hb[1, :].detach().cpu().numpy()
+
+ selected_ids = selected_ids.squeeze().detach().cpu().numpy()
+ if tt != infer_times - 1:
+ s2_preds_np = s2_preds_hb_np
+
+ pos_edge_ids = np.where(s2_preds_np >= 0.9)
+ neg_edge_ids = np.where(s2_preds_np <= 0.01)
+ for pos_id in pos_edge_ids[0]:
+ actual_id = selected_ids[pos_id]
+ if gt_values[0, actual_id] != 2:
+ continue
+ all_pos_ids.add(actual_id)
+ all_edge_confs[actual_id] = s2_preds_np[pos_id]
+ gt_values[0, actual_id] = 1
+ for neg_id in neg_edge_ids[0]:
+ actual_id = selected_ids[neg_id]
+ if gt_values[0, actual_id] != 2:
+ continue
+ gt_values[0, actual_id] = 0
+ num_to_pred = (gt_values == 2).sum()
+ if num_to_pred <= num_filtered:
+ break
+ else:
+ s2_preds_np = s2_preds_hb_np
+
+ pos_edge_ids = np.where(s2_preds_np >= 0.5)
+ for pos_id in pos_edge_ids[0]:
+ actual_id = selected_ids[pos_id]
+ if bool(s2_mask[0][pos_id]) or gt_values[0, actual_id] != 2: # skip padded candidates and already-decided edges
+ continue
+ all_pos_ids.add(actual_id)
+ all_edge_confs[actual_id] = s2_preds_np[pos_id]
+
+ # print('Inference time {}'.format(tt+1))
+ pos_edge_ids = list(all_pos_ids)
+ edge_confs = [all_edge_confs[idx] for idx in pos_edge_ids]
+ pos_edges = edge_ids[pos_edge_ids].cpu().numpy()
+ edge_confs = np.array(edge_confs)
+
+ if image_size != 256:
+ pred_corners = pred_corners / (image_size / 256)
+
+ return pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np
+
+
+def postprocess_preds(corners, confs, edges):
+ corner_degrees = dict()
+ for edge_i, edge_pair in enumerate(edges):
+ corner_degrees[edge_pair[0]] = corner_degrees.setdefault(edge_pair[0], 0) + 1
+ corner_degrees[edge_pair[1]] = corner_degrees.setdefault(edge_pair[1], 0) + 1
+ good_ids = [i for i in range(len(corners)) if i in corner_degrees]
+ if len(good_ids) == len(corners):
+ return corners, confs, edges
+ else:
+ good_corners = corners[good_ids]
+ good_confs = confs[good_ids]
+ id_mapping = {value: idx for idx, value in enumerate(good_ids)}
+ new_edges = list()
+ for edge_pair in edges:
+ new_pair = (id_mapping[edge_pair[0]], id_mapping[edge_pair[1]])
+ new_edges.append(new_pair)
+ new_edges = np.array(new_edges)
+ return good_corners, good_confs, new_edges
+
+
+def process_image(img):
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ img = skimage.img_as_float(img)
+ img = img.transpose((2, 0, 1))
+ img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
+ img = torch.Tensor(img).cuda()
+ img = img.unsqueeze(0)
+ return img
+
+
+def plot_heatmap(results, filename):
+ # generate 2 2d grids for the x & y bounds
+ # import pdb; pdb.set_trace()
+ y, x = np.meshgrid(np.linspace(0, 255, 256), np.linspace(0, 255, 256))
+
+ z = results[::-1, :]
+ # x and y are bounds, so z should be the value *inside* those bounds.
+ # Therefore, remove the last value from the z array.
+ z = z[:-1, :-1]
+
+ fig, ax = plt.subplots()
+
+ c = ax.pcolormesh(y, x, z, cmap='RdBu', vmin=0, vmax=1)
+ # set the limits of the plot to the limits of the data
+ ax.axis([x.min(), x.max(), y.min(), y.max()])
+ fig.colorbar(c, ax=ax)
+ fig.savefig(filename)
+ plt.close()
+
+
+def convert_annot(annot):
+ corners = np.array(list(annot.keys()))
+ corners_mapping = {tuple(c): idx for idx, c in enumerate(corners)}
+ edges = set()
+ for corner, connections in annot.items():
+ idx_c = corners_mapping[tuple(corner)]
+ for other_c in connections:
+ idx_other_c = corners_mapping[tuple(other_c)]
+ if (idx_c, idx_other_c) not in edges and (idx_other_c, idx_c) not in edges:
+ edges.add((idx_c, idx_other_c))
+ edges = np.array(list(edges))
+ gt_data = {
+ 'corners': corners,
+ 'edges': edges
+ }
+ return gt_data
+
+
+def get_args_parser():
+ parser = argparse.ArgumentParser('Holistic edge attention transformer', add_help=False)
+ parser.add_argument('--dataset', default='outdoor',
+ help='the dataset for experiments, outdoor/s3d_floorplan')
+ parser.add_argument('--checkpoint_path', default='',
+ help='path to the checkpoints of the model')
+ parser.add_argument('--image_size', default=256, type=int)
+ parser.add_argument('--viz_base', default='./results/viz',
+ help='path to save the intermediate visualizations')
+ parser.add_argument('--save_base', default='./results/npy',
+ help='path to save the prediction results in npy files')
+ parser.add_argument('--infer_times', default=3, type=int)
+ return parser
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser('HEAT inference', parents=[get_args_parser()])
+ args = parser.parse_args()
+ main(args.dataset, args.checkpoint_path, args.image_size, args.viz_base, args.save_base,
+ infer_times=args.infer_times)
diff --git a/metrics/get_metric.py b/metrics/get_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..02bcacde44f6c68932843ccba6ca2494da98e8a8
--- /dev/null
+++ b/metrics/get_metric.py
@@ -0,0 +1,219 @@
+import os
+import numpy as np
+import pickle
+import cv2
+from metrics.new_utils import *
+
+
+class Metric():
+ def calc(self, gt_data, conv_data, thresh=8.0, iou_thresh=0.7):
+ ### compute corners precision/recall
+ gts = gt_data['corners']
+ dets = conv_data['corners']
+
+ per_sample_corner_tp = 0.0
+ per_sample_corner_fp = 0.0
+ per_sample_corner_length = gts.shape[0]
+ found = [False] * gts.shape[0]
+ c_det_annot = {}
+
+
+ # for each corner detection
+ for i, det in enumerate(dets):
+ # get closest gt
+ near_gt = [0, 999999.0, (0.0, 0.0)]
+ for k, gt in enumerate(gts):
+ dist = np.linalg.norm(gt - det)
+ if dist < near_gt[1]:
+ near_gt = [k, dist, gt]
+ if near_gt[1] <= thresh and not found[near_gt[0]]:
+ per_sample_corner_tp += 1.0
+ found[near_gt[0]] = True
+ c_det_annot[i] = near_gt[0]
+ else:
+ per_sample_corner_fp += 1.0
+
+ per_corner_score = {
+ 'recall': per_sample_corner_tp / gts.shape[0],
+ 'precision': per_sample_corner_tp / (per_sample_corner_tp + per_sample_corner_fp + 1e-8)
+ }
+
+ ### compute edges precision/recall
+ per_sample_edge_tp = 0.0
+ per_sample_edge_fp = 0.0
+ edge_corner_annots = gt_data['edges']
+ per_sample_edge_length = edge_corner_annots.shape[0]
+
+ false_edge_ids = []
+ match_gt_ids = set()
+
+ for l, e_det in enumerate(conv_data['edges']):
+ c1, c2 = e_det
+
+ # check if corners are mapped
+ if (c1 not in c_det_annot.keys()) or (c2 not in c_det_annot.keys()):
+ per_sample_edge_fp += 1.0
+ false_edge_ids.append(l)
+ continue
+ # check hit
+ c1_prime = c_det_annot[c1]
+ c2_prime = c_det_annot[c2]
+ is_hit = False
+
+ for k, e_annot in enumerate(edge_corner_annots):
+ c3, c4 = e_annot
+ if ((c1_prime == c3) and (c2_prime == c4)) or ((c1_prime == c4) and (c2_prime == c3)):
+ is_hit = True
+ match_gt_ids.add(k)
+ break
+
+ # hit
+ if is_hit:
+ per_sample_edge_tp += 1.0
+ else:
+ per_sample_edge_fp += 1.0
+ false_edge_ids.append(l)
+
+ per_edge_score = {
+ 'recall': per_sample_edge_tp / edge_corner_annots.shape[0],
+ 'precision': per_sample_edge_tp / (per_sample_edge_tp + per_sample_edge_fp + 1e-8)
+ }
+
+ # compute regions precision/recall
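+ # regions are the 4-connected components of the complement of the rendered graph mask;
+ # the background component and tiny components (< 20 px) are discarded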
+ conv_mask = render(corners=conv_data['corners'], edges=conv_data['edges'], render_pad=0, edge_linewidth=1)[0]
+ conv_mask = 1 - conv_mask
+ conv_mask = conv_mask.astype(np.uint8)
+ labels, region_mask = cv2.connectedComponents(conv_mask, connectivity=4)
+
+ #cv2.imwrite('mask-pred.png', region_mask.astype(np.uint8) * 20)
+
+ background_label = region_mask[0, 0]
+ all_conv_masks = []
+ for region_i in range(1, labels):
+ if region_i == background_label:
+ continue
+ the_region = region_mask == region_i
+ if the_region.sum() < 20:
+ continue
+ all_conv_masks.append(the_region)
+
+ gt_mask = render(corners=gt_data['corners'], edges=gt_data['edges'], render_pad=0, edge_linewidth=1)[0]
+ gt_mask = 1 - gt_mask
+ gt_mask = gt_mask.astype(np.uint8)
+ labels, region_mask = cv2.connectedComponents(gt_mask, connectivity=4)
+
+ #cv2.imwrite('mask-gt.png', region_mask.astype(np.uint8) * 20)
+
+ background_label = region_mask[0, 0]
+ all_gt_masks = []
+ for region_i in range(1, labels):
+ if region_i == background_label:
+ continue
+ the_region = region_mask == region_i
+ if the_region.sum() < 20:
+ continue
+ all_gt_masks.append(the_region)
+
+ per_sample_region_tp = 0.0
+ per_sample_region_fp = 0.0
+ per_sample_region_length = len(all_gt_masks)
+ found = [False] * len(all_gt_masks)
+ for i, r_det in enumerate(all_conv_masks):
+ # gt closest gt
+ near_gt = [0, 0, None]
+ for k, r_gt in enumerate(all_gt_masks):
+ iou = np.logical_and(r_gt, r_det).sum() / float(np.logical_or(r_gt, r_det).sum())
+ if iou > near_gt[1]:
+ near_gt = [k, iou, r_gt]
+ if near_gt[1] >= iou_thresh and not found[near_gt[0]]:
+ per_sample_region_tp += 1.0
+ found[near_gt[0]] = True
+ else:
+ per_sample_region_fp += 1.0
+
+ per_region_score = {
+ 'recall': per_sample_region_tp / len(all_gt_masks),
+ 'precision': per_sample_region_tp / (per_sample_region_tp + per_sample_region_fp + 1e-8)
+ }
+
+ return {
+ 'corner_tp': per_sample_corner_tp,
+ 'corner_fp': per_sample_corner_fp,
+ 'corner_length': per_sample_corner_length,
+ 'edge_tp': per_sample_edge_tp,
+ 'edge_fp': per_sample_edge_fp,
+ 'edge_length': per_sample_edge_length,
+ 'region_tp': per_sample_region_tp,
+ 'region_fp': per_sample_region_fp,
+ 'region_length': per_sample_region_length,
+ 'corner': per_corner_score,
+ 'edge': per_edge_score,
+ 'region': per_region_score
+ }
+
+
+def compute_metrics(gt_data, pred_data):
+ metric = Metric()
+ score = metric.calc(gt_data, pred_data)
+ return score
+
+
+def get_recall_and_precision(tp, fp, length):
+ recall = tp / (length + 1e-8)
+ precision = tp / (tp + fp + 1e-8)
+ return recall, precision
+
+
+if __name__ == '__main__':
+ base_path = './'
+ gt_datapath = '../data/cities_dataset/annot'
+ metric = Metric()
+ corner_tp = 0.0
+ corner_fp = 0.0
+ corner_length = 0.0
+ edge_tp = 0.0
+ edge_fp = 0.0
+ edge_length = 0.0
+ region_tp = 0.0
+ region_fp = 0.0
+ region_length = 0.0
+ for file_name in os.listdir(base_path):
+ if len(file_name) < 10:
+ continue
+ f = open(os.path.join(base_path, file_name), 'rb')
+ gt_data = np.load(os.path.join(gt_datapath, file_name + '.npy'), allow_pickle=True).tolist()
+ candidate = pickle.load(f)
+ conv_corners = candidate.graph.getCornersArray()
+ conv_edges = candidate.graph.getEdgesArray()
+ conv_data = {'corners': conv_corners, 'edges': conv_edges}
+ score = metric.calc(gt_data, conv_data)
+ corner_tp += score['corner_tp']
+ corner_fp += score['corner_fp']
+ corner_length += score['corner_length']
+ edge_tp += score['edge_tp']
+ edge_fp += score['edge_fp']
+ edge_length += score['edge_length']
+ region_tp += score['region_tp']
+ region_fp += score['region_fp']
+ region_length += score['region_length']
+
+ f = open(os.path.join(base_path, 'score.txt'), 'w')
+ # corner
+ recall, precision = get_recall_and_precision(corner_tp, corner_fp, corner_length)
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
+ print('corners - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
+ f.write('corners - precision: %.3f recall: %.3f f_score: %.3f\n' % (precision, recall, f_score))
+
+ # edge
+ recall, precision = get_recall_and_precision(edge_tp, edge_fp, edge_length)
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
+ print('edges - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
+ f.write('edges - precision: %.3f recall: %.3f f_score: %.3f\n' % (precision, recall, f_score))
+
+ # region
+ recall, precision = get_recall_and_precision(region_tp, region_fp, region_length)
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
+ print('regions - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
+ f.write('regions - precision: %.3f recall: %.3f f_score: %.3f\n' % (precision, recall, f_score))
+
+ f.close()
diff --git a/metrics/new_utils.py b/metrics/new_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b713ac47d687af7e6d8ad5c0957f80a4a5db52
--- /dev/null
+++ b/metrics/new_utils.py
@@ -0,0 +1,2100 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import cv2
+import threading
+import os
+import skimage
+import random
+import time
+
+TWO_CORNER_MINIMUM_DISTANCE = 5
+SAFE_NUM = 3
+score_weights = (1., 2., 100.)
+
+
+#########################################################################################
+################################# General Functions #####################################
+#########################################################################################
+def swap_two_corner_place(corners, edges, id1, id2):
+ for edge_i in range(edges.shape[0]):
+ if edges[edge_i, 0] == id1:
+ edges[edge_i, 0] = id2
+ elif edges[edge_i, 0] == id2:
+ edges[edge_i, 0] = id1
+ if edges[edge_i, 1] == id1:
+ edges[edge_i, 1] = id2
+ elif edges[edge_i, 1] == id2:
+ edges[edge_i, 1] = id1
+ temp = corners[id1].copy()
+ corners[id1] = corners[id2]
+ corners[id2] = temp
+ return corners, edges
+
+
+def get_neighbor_corner_id(corner_id, edges):
+ where = np.where(edges == corner_id)
+ return edges[where[0], 1 - where[1]]
+
+
+def swap_two_edge_place(edges, id1, id2):
+ temp = edges[id1].copy()
+ edges[id1] = edges[id2]
+ edges[id2] = temp
+ return edges
+
+
+def degree_of_three_corners(cornerA, cornerB, cornerM):
+ # cornerM is middle corner
+ AM_length = l2_distance(cornerA, cornerM)
+ BM_length = l2_distance(cornerB, cornerM)
+ dot = np.dot((cornerA[0] - cornerM[0], cornerA[1] - cornerM[1]),
+ (cornerB[0] - cornerM[0], cornerB[1] - cornerM[1]))
+ cos = dot / (AM_length + 1e-8) / (BM_length + 1e-8)
+ cos = min(1, max(-1, cos))
+ degree = np.arccos(cos)
+ return degree / np.pi * 180
+
+
+def sort_graph(corners, edges):
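+ # put the graph into a canonical order: selection-sort corners by coordinate (re-indexing the
+ # edges along the way), then order each edge's endpoints and finally sort the edge list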
+ corners = corners.copy()
+ edges = edges.copy()
+ for corner_i in range(corners.shape[0]):
+ min_id = -1
+ min_pos = corners[corner_i]
+ for corner_j in range(corner_i + 1, corners.shape[0]):
+ if (corners[corner_j, 0] < min_pos[0]) or \
+ (corners[corner_j, 0] == min_pos[0] and corners[corner_j, 1] < min_pos[1]):
+ min_pos = corners[corner_j]
+ min_id = corner_j
+ if min_id != -1:
+ corners, edges = swap_two_corner_place(corners, edges, corner_i, min_id)
+
+ for edge_i in range(edges.shape[0]):
+ if edges[edge_i, 0] > edges[edge_i, 1]:
+ temp = edges[edge_i, 0]
+ edges[edge_i, 0] = edges[edge_i, 1]
+ edges[edge_i, 1] = temp
+
+ for edge_i in range(edges.shape[0]):
+ min_id = -1
+ min_pos = edges[edge_i]
+ for edge_j in range(edge_i + 1, edges.shape[0]):
+ if (edges[edge_j, 0] < min_pos[0]) or \
+ (edges[edge_j, 0] == min_pos[0] and edges[edge_j, 1] < min_pos[1]):
+ min_pos = edges[edge_j]
+ min_id = edge_j
+ if min_id != -1:
+ edges = swap_two_edge_place(edges, edge_i, min_id)
+
+ return corners, edges
+
+
+def IOU(maskA, maskB):
+ return np.logical_and(maskA, maskB).sum() / np.logical_or(maskA, maskB).sum()
+
+
+def render(corners, edges, render_pad=0, edge_linewidth=2, corner_size=3, scale=1.):
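+ # rasterize the graph into a 2-channel mask (channel 0: edges, channel 1: corners);
+ # corners are stored as (y, x), hence the swapped indexing in the cv2 calls below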
+ size = int(256 * scale)
+ mask = np.ones((2, size, size)) * render_pad
+
+ corners = np.round(corners.copy() * scale).astype(int)
+ for edge_i in range(edges.shape[0]):
+ a = edges[edge_i, 0]
+ b = edges[edge_i, 1]
+ mask[0] = cv2.line(mask[0], (int(corners[a, 1]), int(corners[a, 0])),
+ (int(corners[b, 1]), int(corners[b, 0])), 1.0, thickness=edge_linewidth)
+ for corner_i in range(corners.shape[0]):
+ mask[1] = cv2.circle(mask[1], (int(corners[corner_i, 1]), int(corners[corner_i, 0])), corner_size, 1.0, -1)
+
+ return mask
+
+
+def patch_samples(edge_num, batch_size):
+ num = edge_num // batch_size
+ patchs = []
+ for i in range(num):
+ patchs.append([i * batch_size + j for j in range(batch_size)])
+
+ if edge_num % batch_size != 0:
+ patchs.append([j for j in range(batch_size * num, edge_num)])
+
+ return patchs
+
+
+def l2_distance(x1, x2):
+ return np.sqrt((x1[0] - x2[0]) ** 2 + (x1[1] - x2[1]) ** 2)
+
+
+def triangle_region(A, B, C):
+ l1 = np.linalg.norm(np.array(A) - np.array(B))
+ l2 = np.linalg.norm(np.array(A) - np.array(C))
+ l3 = np.linalg.norm(np.array(B) - np.array(C))
+ p = (l1 + l2 + l3) / 2
+ area = np.sqrt(np.abs(p * (p - l1) * (p - l2) * (p - l3)))
+ return area
+
+
+def remove_intersection_and_duplicate(corners, edges, name):
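+ # split any pair of properly intersecting edges at their intersection point (snapping to an
+ # existing endpoint when one lies within 3 px), then merge near-duplicate corners and drop
+ # corners left with degree 0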
+ over_all_flag = False
+ ori_corners = corners.copy()
+ ori_edges = edges.copy()
+ while True:
+ flag = False
+ for edge_i in range(edges.shape[0]):
+ for edge_j in range(edge_i + 1, edges.shape[0]):
+ corner11 = corners[edges[edge_i, 0]]
+ corner12 = corners[edges[edge_i, 1]]
+ corner21 = corners[edges[edge_j, 0]]
+ corner22 = corners[edges[edge_j, 1]]
+
+ y1 = corner11[0]
+ x1 = corner11[1]
+ y2 = corner12[0]
+ x2 = corner12[1]
+ a1 = y1 - y2
+ b1 = x2 - x1
+ c1 = x1 * y2 - x2 * y1
+ flag1 = (a1 * corner21[1] + b1 * corner21[0] + c1) * (a1 * corner22[1] + b1 * corner22[0] + c1)
+
+ y1 = corner21[0]
+ x1 = corner21[1]
+ y2 = corner22[0]
+ x2 = corner22[1]
+ a2 = y1 - y2
+ b2 = x2 - x1
+ c2 = x1 * y2 - x2 * y1
+ flag2 = (a2 * corner11[1] + b2 * corner11[0] + c2) * (a2 * corner12[1] + b2 * corner12[0] + c2)
+
+ if flag1 < -1e-5 and flag2 < -1e-5:
+ # intersection!
+ over_all_flag = True
+ flag = True
+
+ new_x = (c2 * b1 - c1 * b2) / (a1 * b2 - a2 * b1)
+ new_y = (a2 * c1 - a1 * c2) / (a1 * b2 - a2 * b1)
+
+ temp_d = 3
+ temp_id = -1
+ if l2_distance((new_y, new_x), corner11) < temp_d:
+ temp_id = edges[edge_i, 0]
+ temp_d = l2_distance((new_y, new_x), corner11)
+ if l2_distance((new_y, new_x), corner12) < temp_d:
+ temp_id = edges[edge_i, 1]
+ temp_d = l2_distance((new_y, new_x), corner12)
+ if l2_distance((new_y, new_x), corner21) < temp_d:
+ temp_id = edges[edge_j, 0]
+ temp_d = l2_distance((new_y, new_x), corner21)
+ if l2_distance((new_y, new_x), corner22) < temp_d:
+ temp_id = edges[edge_j, 1]
+ temp_d = l2_distance((new_y, new_x), corner22)
+ if temp_id != -1:
+ if edges[edge_i, 0] != temp_id and edges[edge_i, 1] != temp_id:
+ tt = edges[edge_i, 0]
+ edges[edge_i, 0] = temp_id
+ edges = np.append(edges, np.array([(temp_id, tt)]), 0)
+ if edges[edge_j, 0] != temp_id and edges[edge_j, 1] != temp_id:
+ tt = edges[edge_j, 0]
+ edges[edge_j, 0] = temp_id
+ edges = np.append(edges, np.array([(temp_id, tt)]), 0)
+ else:
+ corners = np.append(corners, np.array([(new_y, new_x)]), 0)
+ edge_id1 = edges[edge_i, 1]
+ edge_id2 = edges[edge_j, 1]
+ edges[edge_i, 1] = corners.shape[0] - 1
+ edges[edge_j, 1] = corners.shape[0] - 1
+ edges = np.append(edges, np.array([(edge_id1, corners.shape[0] - 1)]), 0)
+ edges = np.append(edges, np.array([(edge_id2, corners.shape[0] - 1)]), 0)
+ break
+ if flag:
+ break
+ if flag:
+ continue
+ break
+
+ # remove duplicate and zero degree
+ graph = Graph(np.round(corners), edges)
+ for corner_i in reversed(range(len(graph.getCorners()))):
+ corner_ele1 = graph.getCorners()[corner_i]
+ for corner_j in reversed(range(corner_i)):
+ corner_ele2 = graph.getCorners()[corner_j]
+ if l2_distance(corner_ele1.x, corner_ele2.x) < 3:
+ connected_edge = graph.getEdgeConnected(corner_ele1)
+ for edge_ele in connected_edge:
+ if edge_ele.x[0] == corner_ele1:
+ another = edge_ele.x[1]
+ else:
+ another = edge_ele.x[0]
+ if another == corner_ele2:
+ graph.remove(edge_ele)
+ edge_ele.x = (another, corner_ele2)
+ graph.remove(corner_ele1)
+ for corner_ele in graph.getCorners():
+ if graph.getCornerDegree(corner_ele) == 0:
+ graph.remove(corner_ele)
+
+ corners = graph.getCornersArray()
+ edges = graph.getEdgesArray()
+ # if over_all_flag:
+ # plt.subplot(121)
+ # ori = render(ori_corners, ori_edges, edge_linewidth=1, corner_size=1)
+ # temp = np.concatenate((ori.transpose((1,2,0)), np.zeros((ori.shape[1],ori.shape[2],1))),2)
+ # plt.imshow(temp)
+ # plt.subplot(122)
+ # new_ = render(corners, edges, edge_linewidth=1, corner_size=1)
+ # temp = np.concatenate((new_.transpose((1,2,0)), np.zeros((new_.shape[1],new_.shape[2],1))),2)
+ # plt.imshow(temp)
+ # plt.show()
+
+ return corners, edges
+
+
+def get_two_edge_intersection_location(corner11, corner12, corner21, corner22):
+ y1 = corner11[0]
+ x1 = corner11[1]
+ y2 = corner12[0]
+ x2 = corner12[1]
+ a1 = y1 - y2
+ b1 = x2 - x1
+ c1 = x1 * y2 - x2 * y1
+
+ y1 = corner21[0]
+ x1 = corner21[1]
+ y2 = corner22[0]
+ x2 = corner22[1]
+ a2 = y1 - y2
+ b2 = x2 - x1
+ c2 = x1 * y2 - x2 * y1
+
+ l = a1 * b2 - a2 * b1
+ if l == 0:
+ l = 1e-5
+
+ new_x = (c2 * b1 - c1 * b2) / l
+ new_y = (a2 * c1 - a1 * c2) / l
+
+ return round(new_y), round(new_x)
+
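+# Illustrative (hypothetical) values: with corners given as (y, x), the segments
+# ((0, 0), (10, 10)) and ((0, 10), (10, 0)) meet at their common midpoint, so
+#   get_two_edge_intersection_location((0, 0), (10, 10), (0, 10), (10, 0)) == (5, 5)
+# Note that the function intersects the two supporting lines; callers are expected to check that
+# the returned point actually lies on (or close to) both segments.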
+
+def get_distance_of_corner_and_edge(corner1, corner2, corner):
+ x = corner[0]
+ y = corner[1]
+ x1 = corner1[0]
+ y1 = corner1[1]
+ x2 = corner2[0]
+ y2 = corner2[1]
+
+ cross = (x2 - x1) * (x - x1) + (y2 - y1) * (y - y1)
+ if cross <= 0:
+ # dist to corner1
+ return np.linalg.norm((x - x1, y - y1))
+
+ d2 = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)
+ if cross >= d2:
+ # dist to corner2
+ return np.linalg.norm((x - x2, y - y2))
+
+ r = cross / d2
+ px = x1 + (x2 - x1) * r
+ py = y1 + (y2 - y1) * r
+ return np.linalg.norm((x - px, y - py))
+
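+# Distance from a point to a line segment, e.g. (hypothetical values) for the segment (0, 0)-(10, 0):
+#   get_distance_of_corner_and_edge((0, 0), (10, 0), (5, 5))  -> 5.0  (projection falls inside the segment)
+#   get_distance_of_corner_and_edge((0, 0), (10, 0), (-3, 4)) -> 5.0  (closest to the first endpoint)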
+
+#########################################################################################
+################################# Dataset Functions #####################################
+#########################################################################################
+def EuclideanDistance(A, B):
+ BT = B.transpose()
+ vecProd = np.dot(A, BT)
+
+ SqA = A ** 2
+ sumSqA = np.matrix(np.sum(SqA, axis=1))
+ sumSqAEx = np.tile(sumSqA.transpose(), (1, vecProd.shape[1]))
+
+ SqB = B ** 2
+ sumSqB = np.sum(SqB, axis=1)
+ sumSqBEx = np.tile(sumSqB, (vecProd.shape[0], 1))
+ SqED = sumSqBEx + sumSqAEx - 2 * vecProd
+ SqED[SqED < 0] = 0.0
+ ED = np.sqrt(SqED)
+ return ED
+
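+# Pairwise distances via the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, returned as an
+# np.matrix of shape (len(A), len(B)). Hypothetical example:
+#   EuclideanDistance(np.array([[0, 0], [3, 4]]), np.array([[0, 0]])) -> matrix([[0.], [5.]])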
+
+def samedirection(conv_corner_id, gt_corner_id, conv_corners, gt_corners, conv_edges, gt_edges):
+ # degree
+ if np.where(conv_edges == conv_corner_id)[0].shape[0] != np.where(gt_edges == gt_corner_id)[0].shape[0]:
+ return False
+
+ # direction
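+ # edge directions around each corner are quantized into 24 bins of 15 degrees; the two corners are
+ # considered to have the same direction pattern only if every ground-truth direction can be matched
+ # to a candidate direction within roughly 1.3 bins (about 20 degrees), see the matching loop below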
+ place = np.where(conv_edges == conv_corner_id)
+ neighbor_id = conv_edges[place[0], 1 - place[1]]
+
+ distance = conv_corners[conv_corner_id] - conv_corners[neighbor_id]
+ direction = np.arctan2(distance[:, 0], distance[:, 1]) * 180 / np.pi / 15
+ direction = (direction + 24) % 24
+
+ conv_dir = np.sort(direction)
+
+ place = np.where(gt_edges == gt_corner_id)
+ neighbor_id = gt_edges[place[0], 1 - place[1]]
+
+ distance = gt_corners[gt_corner_id] - gt_corners[neighbor_id]
+ direction = np.arctan2(distance[:, 0], distance[:, 1]) * 180 / np.pi / 15
+ direction = (direction + 24) % 24
+
+ gt_dir = np.sort(direction)
+
+ conv_dir = list(conv_dir)
+ gt_dir = list(gt_dir)
+ for angle in gt_dir:
+ temp = sorted(conv_dir, key=lambda x: min(np.abs(x - angle), 24 - np.abs(x - angle)))
+ if min(np.abs(temp[0] - angle), 24 - np.abs(temp[0] - angle)) <= 1.3:
+ conv_dir.remove(temp[0])
+ else:
+ return False
+ return True
+
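+# Merge ground-truth corners that found no location match in the candidate graph into a nearby
+# matched neighbor (within 8 px): the neighbor inherits the unmatched corner's edges and is pulled
+# toward it with a 0.7/0.3 weighted average, keeping the simplified ground truth comparable to the
+# candidate graph.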
+
+def simplify_gt(gt_match_location, gt_corner, gt_edge):
+ graph = Graph(np.round(gt_corner), gt_edge)
+ for idx, corner in enumerate(graph.getCorners()):
+ # use score to store the matching info
+ corner.store_score(gt_match_location[idx])
+
+ for idx, corner in enumerate(graph.getCorners()):
+ if corner.get_score() is None:
+ connected_edges = graph.getEdgeConnected(corner)
+ neighbor_corners = []
+ for edge in connected_edges:
+ if edge.x[0] != corner:
+ neighbor_corners.append(edge.x[0])
+ continue
+ if edge.x[1] != corner:
+ neighbor_corners.append(edge.x[1])
+ continue
+ raise BaseException()
+ neighbor_corners = sorted(neighbor_corners, key=lambda ele: l2_distance(ele.x, corner.x))
+ for neighbor_ele in neighbor_corners:
+ if l2_distance(neighbor_ele.x, corner.x) > 8:
+ break
+ if neighbor_ele.get_score() is None:
+ continue
+ # found a suitable neighbor to replace this corner
+ for ele in neighbor_corners:
+ if ele == neighbor_ele:
+ continue
+ graph.add_edge(ele, neighbor_ele)
+ neighbor_ele.x = (0.7 * neighbor_ele.x[0] + 0.3 * corner.x[0],
+ 0.7 * neighbor_ele.x[1] + 0.3 * corner.x[1])
+ graph.remove(corner)
+ break
+ return graph.getCornersArray(), graph.getEdgesArray()
+
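+# Match candidate corners against ground-truth corners in two passes: a strict pass that requires the
+# same degree and compatible edge directions (gt_match_same_degree), and a greedy location-only pass
+# within a 7 px radius (gt_match_location). The first return value is the set of candidate corner ids
+# left unmatched by the strict pass.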
+
+def get_wrong_corners(corners, gt_corners, edges, gt_edges):
+ corners = corners.copy()
+ gt_corners = gt_corners.copy()
+ edges = edges.copy()
+ gt_edges = gt_edges.copy()
+ dist_matrix = EuclideanDistance(gt_corners, corners)
+ assigned_id = set()
+ gt_match_same_degree = []
+ gt_match_location = []
+ for gt_i in range(gt_corners.shape[0]):
+ sort_id = np.argsort(dist_matrix[gt_i]).__array__()[0]
+ flag = True
+ for id_ in sort_id:
+ if dist_matrix[gt_i, id_] > 7:
+ break
+ same_dir = samedirection(id_, gt_i, corners, gt_corners, edges, gt_edges)
+ if not same_dir:
+ break
+ elif id_ not in assigned_id:
+ assigned_id.add(id_)
+ gt_match_same_degree.append(id_)
+ flag = False
+ break
+ if flag:
+ gt_match_same_degree.append(None)
+
+ matched = []
+ gt_match_location = [None for _ in range(gt_corners.shape[0])]
+ for gt_i in sorted(list(range(gt_corners.shape[0])), key=lambda i: np.min(dist_matrix[i])):
+ sort_id = np.argsort(dist_matrix[gt_i]).__array__()[0]
+ if dist_matrix[gt_i, sort_id[0]] > 7:
+ gt_match_location[gt_i] = None
+ else:
+ for c_i in sort_id:
+ if c_i in matched:
+ continue
+ if dist_matrix[gt_i, c_i] > 7:
+ gt_match_location[gt_i] = None
+ break
+ else:
+ gt_match_location[gt_i] = c_i
+ matched.append(c_i)
+ break
+
+ return set(range(corners.shape[0])) - assigned_id, gt_match_same_degree, gt_match_location
+
+
+def get_wrong_edges(corners, gt_corners, edges, gt_edges, gt_match):
+ edges = edges.copy()
+ gt_edges = gt_edges.copy()
+
+ all_possible_good_edges = []
+ for edge_i in range(gt_edges.shape[0]):
+ if gt_match[gt_edges[edge_i, 0]] is None or gt_match[gt_edges[edge_i, 1]] is None:
+ continue
+ all_possible_good_edges.append((gt_match[gt_edges[edge_i, 0]], gt_match[gt_edges[edge_i, 1]]))
+ false_edge_id = []
+ for edge_i in range(edges.shape[0]):
+ id1 = edges[edge_i][0]
+ id2 = edges[edge_i][1]
+ if (id1, id2) not in all_possible_good_edges and (id2, id1) not in all_possible_good_edges:
+ false_edge_id.append(edge_i)
+ continue
+
+ return false_edge_id
+
+
+def get_corner_bin_map(corners, corner_list_for_each_bin, bin_size=10):
+ bin_map = np.zeros((bin_size, 256, 256))
+ for bin_i in range(bin_size):
+ bin_map[bin_i] = render(corners[corner_list_for_each_bin[bin_i]], np.array([]), render_pad=0)[1]
+ return bin_map
+
+
+#########################################################################################
+################################ Searching Functions ####################################
+#########################################################################################
+def visualization(candidate, show=True):
+ corners = candidate.graph.getCornersArray()
+ edges = candidate.graph.getEdgesArray()
+ mask = render(corners, edges)
+ mask = np.transpose(np.concatenate((mask, np.zeros((1, 256, 256))), 0), (1, 2, 0))
+ plt.imshow(mask)
+ if show:
+ plt.show()
+
+
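+# Proper segment-segment intersection test between two edge Elements, using the same straddle test as
+# in remove_intersection_and_duplicate: each endpoint of one edge is substituted into the other edge's
+# implicit line equation, and a strictly negative product for both edges means the segments cross.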
+def check_intersection(edge1, edge2):
+ corner11 = edge1.x[0].x
+ corner12 = edge1.x[1].x
+ corner21 = edge2.x[0].x
+ corner22 = edge2.x[1].x
+
+ y1 = corner11[0]
+ x1 = corner11[1]
+ y2 = corner12[0]
+ x2 = corner12[1]
+ a = y1 - y2
+ b = x2 - x1
+ c = x1 * y2 - x2 * y1
+ flag1 = (a * corner21[1] + b * corner21[0] + c) * (a * corner22[1] + b * corner22[0] + c)
+
+ y1 = corner21[0]
+ x1 = corner21[1]
+ y2 = corner22[0]
+ x2 = corner22[1]
+ a = y1 - y2
+ b = x2 - x1
+ c = x1 * y2 - x2 * y1
+ flag2 = (a * corner11[1] + b * corner11[0] + c) * (a * corner12[1] + b * corner12[0] + c)
+
+ if flag1 < -1e-6 and flag2 < -1e-6:
+ return True
+
+ return False
+
+
+def adding_a_corner_by_triangle_operation(candidate):
+ new_candidates = []
+ name = candidate.name
+ gt_mask = region_cache.get_region(name)
+ gt_mask = gt_mask > 0.4
+ gt_mask_grow = cv2.dilate(gt_mask.astype(np.float64), np.ones((3, 3), np.uint8), iterations=6) > 0
+
+ # get the current candidate region mask
+ conv_mask = render(corners=candidate.graph.getCornersArray(), edges=candidate.graph.getEdgesArray(),
+ render_pad=0, edge_linewidth=1)[0]
+ conv_mask = 1 - conv_mask
+ conv_mask = conv_mask.astype(np.uint8)
+ labels, region_mask = cv2.connectedComponents(conv_mask, connectivity=4)
+
+ background_label = region_mask[0, 0]
+ all_masks = []
+ for region_i in range(1, labels):
+ if region_i == background_label:
+ continue
+ the_region = region_mask == region_i
+ if the_region.sum() < 20:
+ continue
+ all_masks.append(the_region)
+
+ candidate_mask = (np.sum(all_masks, 0) + (1 - conv_mask)) > 0
+
+ final_mask = np.logical_xor(gt_mask_grow, np.logical_and(candidate_mask, gt_mask_grow))
+
+ for corner_i in range(random.randint(0, 16), 256, 16):
+ for corner_j in range(random.randint(0, 16), 256, 16):
+ if candidate.addable((corner_i, corner_j)):
+ if final_mask[corner_i, corner_j] == True: # inside the region
+ new_corner = Element((corner_i, corner_j))
+ new_candidate = candidate.generate_new_candidate_add_a_corner(new_corner)
+ new_graph = new_candidate.graph
+ corners = new_graph.getCorners()
+
+ # find two suitable existing corners to form a triangle (no intersections and no colinear edges)
+ for id_A in range(len(corners)):
+ ele_A = corners[id_A]
+ if ele_A == new_corner:
+ continue
+ for id_B in range(id_A + 1, len(corners)):
+ ele_B = corners[id_B]
+ if ele_B == new_corner:
+ continue
+ if new_graph.has_edge(new_corner, ele_A) is not None:
+ raise BaseException('should not have an edge in this case')
+ if new_graph.has_edge(new_corner, ele_B) is not None:
+ raise BaseException('should not have an edge in this case')
+ temp_edge1 = Element((new_corner, ele_A))
+ temp_edge2 = Element((new_corner, ele_B))
+
+ # check if addable
+ if new_candidate.addable(temp_edge1) is False:
+ continue
+ if new_candidate.addable(temp_edge2) is False:
+ continue
+
+ # avoid intersection
+ if new_graph.checkIntersectionEdge(temp_edge1):
+ continue
+ if new_graph.checkIntersectionEdge(temp_edge2):
+ continue
+
+ # avoid too small triangle
+ if triangle_region(new_corner.x, ele_A.x, ele_B.x) < 20:
+ continue
+
+ ### avoid colinear edges (only relevant in the fold case)
+ # for edge1
+ neighbor_edges = new_graph.getEdgeConnected(temp_edge1)
+ flag_ = True
+ for neighbor in neighbor_edges:
+ if new_corner in neighbor.x:
+ raise BaseException('new corner should not be in any edge')
+ elif ele_A in neighbor.x:
+ shared_corner = ele_A
+ else:
+ raise BaseException('error.')
+ two_neighbor = {neighbor.x[0], neighbor.x[1], ele_A, new_corner}
+ two_neighbor.remove(shared_corner)
+ assert len(two_neighbor) == 2
+ two_neighbor = tuple(two_neighbor)
+
+ line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
+ line2 = np.array(shared_corner.x) - np.array(two_neighbor[1].x)
+ cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
+ cos = min(1, max(-1, cos))
+ if np.arccos(cos) < np.pi / 9: # 20 degree
+ flag_ = False
+ break
+ if flag_ is False:
+ continue
+ # for edge2
+ neighbor_edges = new_graph.getEdgeConnected(temp_edge2)
+ flag_ = True
+ for neighbor in neighbor_edges:
+ if new_corner in neighbor.x:
+ raise BaseException('new corner should not be in any edge')
+ elif ele_B in neighbor.x:
+ shared_corner = ele_B
+ else:
+ raise BaseException('error.')
+ two_neighbor = {neighbor.x[0], neighbor.x[1], ele_B, new_corner}
+ two_neighbor.remove(shared_corner)
+ assert len(two_neighbor) == 2
+ two_neighbor = tuple(two_neighbor)
+
+ line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
+ line2 = np.array(shared_corner.x) - np.array(two_neighbor[1].x)
+ cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
+ cos = min(1, max(-1, cos))
+ if np.arccos(cos) < np.pi / 9: # 20 degree
+ flag_ = False
+ break
+ if flag_ is False:
+ continue
+
+ # make new candidate
+ try:
+ new_ = new_candidate.generate_new_candidate_add_an_edge(new_corner, ele_A)
+ new_ = new_.generate_new_candidate_add_an_edge(new_corner, ele_B)
+ new_candidates.append(new_)
+ except:
+ continue
+ # plt.subplot(151)
+ # visualization(candidate, show=False)
+ # plt.subplot(152)
+ # plt.imshow(final_mask)
+ # plt.subplot(153)
+ # plt.imshow(candidate_mask)
+ # plt.subplot(154)
+ # plt.imshow(gt_mask_grow)
+ # plt.subplot(155)
+ # visualization(new_, show=False)
+ # plt.show()
+
+ return new_candidates
+
+
+def adding_an_edge_from_new_corner_operation(candidate):
+ new_candidates = []
+ name = candidate.name
+ gt_mask = region_cache.get_region(name)
+ gt_mask = gt_mask > 0.4
+ gt_mask_grow = cv2.dilate(gt_mask.astype(np.float64), np.ones((3, 3), np.uint8), iterations=6) > 0
+
+ # get the current candidate region mask
+ conv_mask = render(corners=candidate.graph.getCornersArray(), edges=candidate.graph.getEdgesArray(),
+ render_pad=0, edge_linewidth=1)[0]
+ conv_mask = 1 - conv_mask
+ conv_mask = conv_mask.astype(np.uint8)
+ labels, region_mask = cv2.connectedComponents(conv_mask, connectivity=4)
+ background_label = region_mask[0, 0]
+ all_masks = []
+ for region_i in range(1, labels):
+ if region_i == background_label:
+ continue
+ the_region = region_mask == region_i
+ if the_region.sum() < 20:
+ continue
+ all_masks.append(the_region)
+ candidate_mask = (np.sum(all_masks, 0) + (1 - conv_mask)) > 0
+
+ final_mask = np.logical_xor(gt_mask_grow, np.logical_and(candidate_mask, gt_mask_grow))
+ for corner_i in range(random.randint(0, 16), 256, 16):
+ for corner_j in range(random.randint(0, 16), 256, 16):
+ if candidate.addable((corner_i, corner_j)):
+ if final_mask[corner_i, corner_j] == True:
+ # inside the region
+ new_corner = Element((corner_i, corner_j))
+ new_candidate = candidate.generate_new_candidate_add_a_corner(new_corner)
+ new_graph = new_candidate.graph
+ corners = new_graph.getCorners()
+
+ # find a suitable existing corner that can form
+ # a new edge with new_corner (no intersections and no colinear edges)
+ for corner_ele in corners:
+ if corner_ele == new_corner:
+ continue
+ if new_graph.has_edge(new_corner, corner_ele) is not None:
+ raise BaseException('should not have an edge in this case')
+ temp_edge = Element((new_corner, corner_ele))
+
+ # check if addable
+ if new_candidate.addable(temp_edge) is False:
+ continue
+
+ # avoid intersection
+ if new_graph.checkIntersectionEdge(temp_edge):
+ continue
+
+ # avoid colinear edge
+ neighbor_edges = new_graph.getEdgeConnected(temp_edge)
+ flag_ = True
+ for neighbor in neighbor_edges:
+ if new_corner in neighbor.x:
+ raise BaseException('new corner should not be in any edge')
+ elif corner_ele in neighbor.x:
+ shared_corner = corner_ele
+ else:
+ raise BaseException('error.')
+ two_neighbor = {neighbor.x[0], neighbor.x[1], corner_ele, new_corner}
+ two_neighbor.remove(shared_corner)
+ assert len(two_neighbor) == 2
+ two_neighbor = tuple(two_neighbor)
+
+ line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
+ line2 = np.array(shared_corner.x) - np.array(two_neighbor[1].x)
+ cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
+ cos = min(1, max(-1, cos))
+ if np.arccos(cos) < np.pi / 9: # 20 degree
+ flag_ = False
+ break
+ if flag_ is False:
+ continue
+
+ # make new candidate
+ try:
+ new_ = new_candidate.generate_new_candidate_add_an_edge(new_corner, corner_ele)
+ new_candidates.append(new_)
+ except:
+ continue
+
+ return new_candidates
+
+
+def removing_a_corner_operation(candidate):
+ new_candidates = []
+ graph = candidate.graph
+ corners = graph.getCorners()
+ for the_corner in corners:
+ if candidate.removable(the_corner):
+ try:
+ new_ = candidate.generate_new_candidate_remove_a_corner(the_corner)
+ new_candidates.append(new_)
+ except:
+ continue
+
+ return new_candidates
+
+
+def removing_a_colinear_corner_operation(candidate):
+ new_candidates = []
+ graph = candidate.graph
+ corners = graph.getCorners()
+ for the_corner in corners:
+ if candidate.removable(the_corner):  # colinearity check not needed here (previously: and graph.checkColinearCorner(the_corner))
+ try:
+ new_ = candidate.generate_new_candidate_remove_a_colinear_corner(the_corner)
+
+ if new_.graph.checkIntersectionEdge():
+ continue
+ new_candidates.append(new_)
+ except:
+ continue
+
+ return new_candidates
+
+
+def adding_an_edge_operation(candidate):
+ new_candidates = []
+ graph = candidate.graph
+ corners = graph.getCorners()
+ for corner_i in range(len(corners)):
+ cornerA = corners[corner_i]
+ for corner_j in range(corner_i + 1, len(corners)):
+ cornerB = corners[corner_j]
+ if graph.has_edge(cornerA, cornerB) is not None:
+ continue
+
+ temp_edge = Element((cornerA, cornerB))
+ # check if addable (not in existed_before dict)
+ if candidate.addable(temp_edge) is False:
+ continue
+
+ if graph.checkIntersectionEdge(temp_edge):
+ continue
+
+ # avoid adding a colinear edge
+ neighbor_edges = graph.getEdgeConnected(temp_edge)
+ flag_ = True
+ for neighbor in neighbor_edges:
+ if cornerA in neighbor.x:
+ shared_corner = cornerA
+ elif cornerB in neighbor.x:
+ shared_corner = cornerB
+ else:
+ raise BaseException('error.')
+ two_neighbor = {neighbor.x[0], neighbor.x[1], cornerA, cornerB}
+ two_neighbor.remove(shared_corner)
+ assert len(two_neighbor) == 2
+ two_neighbor = tuple(two_neighbor)
+
+ line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
+ line2 = np.array(two_neighbor[1].x) - np.array(shared_corner.x)
+ cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
+ cos = min(1, max(-1, cos))
+ if np.arccos(cos) < np.pi / 18 or np.arccos(cos) > np.pi - np.pi / 18: # 10 degree
+ flag_ = False
+ break
+ if flag_ is False:
+ continue
+
+ # make new candidate
+ try:
+ new_ = candidate.generate_new_candidate_add_an_edge(cornerA, cornerB)
+ new_candidates.append(new_)
+ except:
+ continue
+
+ return new_candidates
+
+
+def removing_an_edge_operation(candidate):
+ new_candidates = []
+ graph = candidate.graph
+ edges = graph.getEdges()
+ for edge_ele in edges:
+ if candidate.removable(edge_ele):
+ try:
+ new_ = candidate.generate_new_candidate_remove_an_edge(edge_ele)
+ new_candidates.append(new_)
+ except:
+ continue
+
+ return new_candidates
+
+
+def adding_an_edge_from_gt(candidate, gt_data):
+ new_candidates = []
+ corners_array = candidate.graph.getCornersArray()
+ edges_array = candidate.graph.getEdgesArray()
+
+ gt_corners = gt_data['corners'].copy()
+ gt_edges = gt_data['edges'].copy()
+
+ _, _, map_same_location = get_wrong_corners(
+ corners_array, gt_corners, edges_array, gt_edges)
+
+ gt_corners, gt_edges = simplify_gt(map_same_location, gt_corners, gt_edges)
+
+ _, _, map_same_location = get_wrong_corners(
+ corners_array, gt_corners, edges_array, gt_edges)
+
+ for corner_i in range(gt_corners.shape[0]):
+ if map_same_location[corner_i] is None:
+ # doesn't exist in candidate
+ neighbor_id = get_neighbor_corner_id(corner_i, gt_edges)
+ for corner_j in neighbor_id:
+ if map_same_location[corner_j] is not None:
+ # a corner exists in the candidate that maps to this neighbor corner
+ new_candidate = candidate.copy()
+ new_corner = Element(
+ (
+ int(np.round(gt_corners[corner_i, 0])), int(np.round(gt_corners[corner_i, 1]))
+ )
+ )
+ if new_candidate.addable(new_corner) is False:
+ continue
+ # the new corner cannot be too close to an existing edge
+ flag = False
+ for edge_ele in new_candidate.graph.getEdges():
+ if get_distance_of_corner_and_edge(edge_ele.x[0].x, edge_ele.x[1].x, new_corner.x) < 7:
+ flag = True
+ break
+ if flag:
+ continue
+
+ new_corner = new_candidate.addCorner(new_corner)
+ neighbor_index = map_same_location[corner_j]
+ neighbor_corner = new_candidate.graph.getCorners()[neighbor_index]
+ new_edge = new_candidate.addEdge(new_corner, neighbor_corner)
+ if new_candidate.graph.checkIntersectionEdge(new_edge):
+ continue
+ new_candidates.append(new_candidate)
+
+ return new_candidates
+
+
+def adding_a_corner_from_two_edges_extension(candidate):
+ new_candidates = []
+ graph = candidate.graph
+ edges = candidate.graph.getEdges()
+ for edge_i in range(len(edges)):
+ for edge_j in range(edge_i + 1, len(edges)):
+ edgeA = edges[edge_i]
+ edgeB = edges[edge_j]
+ if graph.isNeighbor(edgeA, edgeB):
+ continue
+ intersection_loc = get_two_edge_intersection_location(edgeA.x[0].x, edgeA.x[1].x, edgeB.x[0].x,
+ edgeB.x[1].x)
+ if intersection_loc[0] >= 255 or intersection_loc[1] >= 255 or \
+ intersection_loc[0] <= 0 or intersection_loc[1] <= 0:
+ continue
+ # intersection point can not be too close to an edge
+ flag = False
+ for edge_ele in graph.getEdges():
+ if get_distance_of_corner_and_edge(edge_ele.x[0].x, edge_ele.x[1].x, intersection_loc) < 7:
+ flag = True
+ break
+ if flag:
+ continue
+ new_candidate = candidate.copy()
+ new_graph = new_candidate.graph
+ new_edgeA = new_graph.getRealElement(edgeA)
+ new_edgeB = new_graph.getRealElement(edgeB)
+ new_corner = Element(intersection_loc)
+ if new_candidate.addable(new_corner) is False:
+ continue
+ new_corner = new_candidate.addCorner_v2(new_corner)
+ # get cornerA and cornerB from edgeA, edgeB
+ if l2_distance(new_corner.x, new_edgeA.x[0].x) < l2_distance(new_corner.x, new_edgeA.x[1].x):
+ cornerA = new_edgeA.x[0]
+ else:
+ cornerA = new_edgeA.x[1]
+ if l2_distance(new_corner.x, new_edgeB.x[0].x) < l2_distance(new_corner.x, new_edgeB.x[1].x):
+ cornerB = new_edgeB.x[0]
+ else:
+ cornerB = new_edgeB.x[1]
+
+ # new edge can not be too short
+ if l2_distance(cornerA.x, new_corner.x) < 7:
+ continue
+ if l2_distance(cornerB.x, new_corner.x) < 7:
+ continue
+
+ # new intersection cannot be too flat
+ if degree_of_three_corners(cornerA.x, cornerB.x, new_corner.x) > 165:
+ continue
+
+ flag = False
+ for edge_ele in new_graph.getEdges():
+ if new_corner in edge_ele.x and cornerA in edge_ele.x:
+ flag = True
+ break
+ if edge_ele.x[0] not in (new_corner, cornerA):
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[0].x)
+ if l <= 7:
+ flag = True
+ break
+ if edge_ele.x[1] not in (new_corner, cornerA):
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[1].x)
+ if l <= 7:
+ flag = True
+ break
+ if flag:
+ continue
+ add_edgeA = new_candidate.addEdge(new_corner, cornerA)
+ if new_graph.checkIntersectionEdge(add_edgeA):
+ continue
+
+ flag = False
+ for edge_ele in new_graph.getEdges():
+ if new_corner in edge_ele.x and cornerB in edge_ele.x:
+ flag = True
+ break
+ if edge_ele.x[0] not in (new_corner, cornerB):
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[0].x)
+ if l <= 7:
+ flag = True
+ break
+ if edge_ele.x[1] not in (new_corner, cornerB):
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[1].x)
+ if l <= 7:
+ flag = True
+ break
+ if flag:
+ continue
+ add_edgeB = new_candidate.addEdge(new_corner, cornerB)
+ if new_graph.checkIntersectionEdge(add_edgeB):
+ continue
+
+ # make real new candidate
+ # new_candidate = candidate.copy()
+ # new_graph = new_candidate.graph
+ # new_corner = Element(intersection_loc)
+ # new_corner = new_graph.add_corner_v2(new_corner)
+ # new_candidate = new_candidate.generate_new_candidate_add_an_edge(new_corner, cornerA)
+ # new_candidate = new_candidate.generate_new_candidate_add_an_edge(new_corner, cornerB)
+
+ new_candidates.append(new_candidate)
+ return new_candidates
+
+
+def adding_a_corner_from_parallel(candidate):
+ new_candidates = []
+ graph = candidate.graph
+ edges = candidate.graph.getEdges()
+ for edge_i in range(len(edges)):
+ for edge_j in range(edge_i + 1, len(edges)):
+ edgeA = edges[edge_i]
+ edgeB = edges[edge_j]
+ # get intersection loc
+ if graph.isNeighbor(edgeA, edgeB):
+ shared_corner = edgeA.x[0] if edgeA.x[0] in edgeB.x else edgeA.x[1]
+ intersection_loc = shared_corner.x
+ else:
+ intersection_loc = get_two_edge_intersection_location(
+ edgeA.x[0].x, edgeA.x[1].x, edgeB.x[0].x, edgeB.x[1].x)
+ if intersection_loc[0] >= 255 or intersection_loc[1] >= 255 or \
+ intersection_loc[0] <= 0 or intersection_loc[1] <= 0:
+ continue
+
+ # get another two loc
+ locA = edgeA.x[1].x if \
+ l2_distance(edgeA.x[0].x, intersection_loc) < l2_distance(edgeA.x[1].x, intersection_loc) else \
+ edgeA.x[0].x
+ locB = edgeB.x[1].x if \
+ l2_distance(edgeB.x[0].x, intersection_loc) < l2_distance(edgeB.x[1].x, intersection_loc) else \
+ edgeB.x[0].x
+
+ # get new loc
+ new_loc = (locA[0] + locB[0] - intersection_loc[0], locA[1] + locB[1] - intersection_loc[1])
+ if new_loc[0] >= 255 or new_loc[1] >= 255 or \
+ new_loc[0] <= 0 or new_loc[1] <= 0:
+ continue
+
+ new_corner = Element(new_loc)
+ new_candidate = candidate.copy()
+ new_graph = new_candidate.graph
+ edgeA = new_graph.getRealElement(edgeA)
+ edgeB = new_graph.getRealElement(edgeB)
+ if new_candidate.addable(new_corner) is False:
+ continue
+ new_corner = new_candidate.addCorner_v2(new_corner)
+ # get cornerA and cornerB from edgeA, edgeB
+ cornerA = edgeA.x[1] if l2_distance(edgeA.x[0].x, intersection_loc) < l2_distance(edgeA.x[1].x,
+ intersection_loc) \
+ else edgeA.x[0]
+ cornerB = edgeB.x[1] if l2_distance(edgeB.x[0].x, intersection_loc) < l2_distance(edgeB.x[1].x,
+ intersection_loc) \
+ else edgeB.x[0]
+
+ # new edge can not be too short
+ if l2_distance(cornerA.x, new_corner.x) < 12:
+ continue
+ if l2_distance(cornerB.x, new_corner.x) < 12:
+ continue
+
+ flag = False
+ for edge_ele in new_graph.getEdges():
+ if new_corner in edge_ele.x and cornerA in edge_ele.x:
+ flag = True
+ break
+ if edge_ele.x[0] not in (new_corner, cornerA):
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[0].x)
+ if l <= 7:
+ flag = True
+ break
+ if edge_ele.x[1] not in (new_corner, cornerA):
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[1].x)
+ if l <= 7:
+ flag = True
+ break
+ if flag:
+ continue
+ add_edgeA = new_candidate.addEdge(new_corner, cornerA)
+ if new_graph.checkIntersectionEdge(add_edgeA):
+ continue
+
+ flag = False
+ for edge_ele in new_graph.getEdges():
+ if new_corner in edge_ele.x and cornerB in edge_ele.x:
+ flag = True
+ break
+ if edge_ele.x[0] not in (new_corner, cornerB):
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[0].x)
+ if l <= 7:
+ flag = True
+ break
+ if edge_ele.x[1] not in (new_corner, cornerB):
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[1].x)
+ if l <= 7:
+ flag = True
+ break
+ if flag:
+ continue
+ add_edgeB = new_candidate.addEdge(new_corner, cornerB)
+ if new_graph.checkIntersectionEdge(add_edgeB):
+ continue
+
+ new_candidates.append(new_candidate)
+ return new_candidates
+
+
+def adding_a_orthogonal_edge(candidate):
+ new_candidates = []
+ graph = candidate.graph
+ edges = candidate.graph.getEdges()
+ for edge in edges:
+ cornerA = edge.x[0]
+ cornerB = edge.x[1]
+
+ # get orthogonal direction
+ dir_ = (cornerA.x[1] - cornerB.x[1], cornerB.x[0] - cornerA.x[0])
+
+ for the_corner in edge.x:
+ temp_orth_loc = (the_corner.x[0] - dir_[0], the_corner.x[1] - dir_[1])
+ for inter_edge in edges:
+ if inter_edge == edge:
+ continue
+ if the_corner in inter_edge.x:
+ continue
+ intersection_loc = get_two_edge_intersection_location(
+ the_corner.x, temp_orth_loc, inter_edge.x[0].x, inter_edge.x[1].x
+ )
+ if intersection_loc[0] >= 255 or intersection_loc[1] >= 255 or \
+ intersection_loc[0] <= 0 or intersection_loc[1] <= 0:
+ continue
+ if np.dot((inter_edge.x[0].x[0] - intersection_loc[0], inter_edge.x[0].x[1] - intersection_loc[1]),
+ (inter_edge.x[1].x[0] - intersection_loc[0], inter_edge.x[1].x[1] - intersection_loc[1])) > 0:
+ # the intersection is not inside inter_edge but on its extension
+ continue
+ if l2_distance(intersection_loc, inter_edge.x[0].x) < 5 or \
+ l2_distance(intersection_loc, inter_edge.x[1].x) < 5:
+ continue
+
+ # avoid a near-zero or near-180-degree angle with the neighboring edges
+ flag = False
+ neighbor_corners = graph.getNeighborCorner(the_corner)
+ for corner_ele in neighbor_corners:
+ if corner_ele in edge.x:
+ continue
+ if degree_of_three_corners(corner_ele.x, intersection_loc, the_corner.x) < 15:
+ flag = True
+ break
+ if degree_of_three_corners(corner_ele.x, intersection_loc, the_corner.x) > 165:
+ flag = True
+ break
+ if flag:
+ continue
+
+ new_candidate = candidate.copy()
+ new_graph = new_candidate.graph
+ new_corner = Element(intersection_loc)
+ if new_candidate.addable(new_corner) is False:
+ continue
+ new_corner = new_candidate.addCorner_v2(new_corner)
+
+ # new edge can not be too short
+ if l2_distance(new_corner.x, the_corner.x) < 7:
+ continue
+
+ add_edge = new_candidate.addEdge(new_corner, new_graph.getRealElement(the_corner))
+ if new_graph.checkIntersectionEdge(add_edge):
+ continue
+
+ new_candidates.append(new_candidate)
+ return new_candidates
+
+
+class _thread(threading.Thread):
+ def __init__(self, threadID, name, candidate, lock, result_list, func):
+ threading.Thread.__init__(self)
+ self.threadID = threadID
+ self.name = name
+ self.candidate = candidate
+ self.lock = lock
+ self.result_list = result_list
+ self.func = func
+
+ def run(self):
+ print('running id: ', self.name)
+ start_time = time.time()
+ candidates = self.func(self.candidate)
+ print('test: =================================', self.name, len(candidates))
+ self.lock.acquire()
+ self.result_list.extend(candidates)
+ self.lock.release()
+ print(self.name, "spend time: {}s".format(time.time() - start_time))
+
+
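+# Enumerate successor candidates for training: each editing operation below is applied to the current
+# candidate and, whenever it yields at least one result, a single randomly chosen successor is kept,
+# so the training-time search stays much cheaper than the full enumeration in candidate_enumerate().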
+def candidate_enumerate_training(candidate, gt):
+ new_candidates = []
+ # remove a corner
+ try:
+ new_ = removing_a_corner_operation(candidate)
+ if len(new_) > 0:
+ new_candidates.append(random.choice(new_))
+ except:
+ print('something wrong with remove a corner !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
+
+ # remove a colinear corner
+ try:
+ new_ = removing_a_colinear_corner_operation(candidate)
+ if len(new_) > 0:
+ new_candidates.append(random.choice(new_))
+ except:
+ print('something wrong with remove a colinear corner !!!!!!!!!!!!!!!!!!!!!!!!!!!')
+
+ # remove an edge
+ try:
+ new_ = removing_an_edge_operation(candidate)
+ if len(new_) > 0:
+ new_candidates.append(random.choice(new_))
+ except:
+ print('something wrong with remove an edge !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
+
+ # add an edge from an existing corner
+ try:
+ new_ = adding_an_edge_operation(candidate)
+ if len(new_) > 0:
+ new_candidates.append(random.choice(new_))
+ except:
+ print('something wrong with add an edge from existing corner !!!!!!!!!!!!!!!!!!!!')
+
+ # add a corner from two edges
+ try:
+ new_ = adding_a_corner_from_two_edges_extension(candidate)
+ if len(new_) > 0:
+ new_candidates.append(random.choice(new_))
+ except:
+ print('something wrong with add a corner from two edges !!!!!!!!!!!!!!!!!!!!!!!!')
+
+ try:
+ new_ = adding_a_corner_from_parallel(candidate)
+ if len(new_) > 0:
+ new_candidates.append(random.choice(new_))
+ except:
+ print('something wrong with add a corner from parallel !!!!!!!!!!!!!!!!!!!!!!!!')
+
+ # add an edge from gt
+ try:
+ new_ = adding_an_edge_from_gt(candidate, gt)
+ if len(new_) > 0:
+ new_candidates.append(random.choice(new_))
+ except:
+ print('something wrong with add an edge from gt !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
+
+ # add an orthogonal edge
+ try:
+ new_ = adding_a_orthogonal_edge(candidate)
+ if len(new_) > 0:
+ new_candidates.append(random.choice(new_))
+ except:
+ print('something wrong with add an orthogonal edge !!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
+ return new_candidates
+
+
+def candidate_enumerate(candidate):
+ new_candidates = []
+ new_candidates.extend(removing_a_corner_operation(candidate))
+ new_candidates.extend(removing_a_colinear_corner_operation(candidate))
+ new_candidates.extend(removing_an_edge_operation(candidate))
+ new_candidates.extend(adding_an_edge_operation(candidate))
+ new_candidates.extend(adding_a_corner_from_two_edges_extension(candidate))
+ new_candidates.extend(adding_a_corner_from_parallel(candidate))
+ new_candidates.extend(adding_a_orthogonal_edge(candidate))
+
+ return new_candidates
+
+
+def candidate_enumerate_thread(candidate):
+ new_candidates = []
+ lock = threading.Lock()
+
+ thread1 = _thread(1, 'remove_a_corner', candidate, lock, new_candidates, removing_a_corner_operation)
+ thread2 = _thread(2, 'remove_a_colinear_corner', candidate, lock, new_candidates,
+ removing_a_colinear_corner_operation)
+ thread3 = _thread(3, 'add_an_edge', candidate, lock, new_candidates, adding_an_edge_operation)
+ thread4 = _thread(4, 'remove_an_edge', candidate, lock, new_candidates, removing_an_edge_operation)
+
+ thread1.start()
+ thread2.start()
+ thread3.start()
+ thread4.start()
+
+ threads = []
+ threads.append(thread1)
+ threads.append(thread2)
+ threads.append(thread3)
+ threads.append(thread4)
+
+ for t in threads:
+ t.join()
+
+ return new_candidates
+
+
+def reduce_duplicate_candidate(candidates):
+ i = 0
+ while i < len(candidates):
+ for j in reversed(range(i + 1, len(candidates))):
+ if candidates[i].equal(candidates[j]):
+ del candidates[j]
+ i = i + 1
+ return candidates
+
+
+def save_candidate_image(candidate, base_path, base_name):
+ corners = candidate.graph.getCornersArray()
+ edges = candidate.graph.getEdgesArray()
+ # graph svg
+ svg = svg_generate(corners, edges, base_name, samecolor=True)
+ svg.saveas(os.path.join(base_path, base_name + '.svg'))
+ # corner image
+ temp_mask = np.zeros((256, 256))
+ for ele in candidate.graph.getCorners():
+ if ele.get_score() < 0:
+ temp_mask = cv2.circle(temp_mask, ele.x[::-1], 3, 1, -1)
+ fig = plt.figure(frameon=False)
+ fig.set_size_inches(1, 1)
+ ax = plt.Axes(fig, [0., 0., 1., 1.])
+ ax.set_axis_off()
+ fig.add_axes(ax)
+ ax.imshow(temp_mask, aspect='auto')
+ fig.savefig(os.path.join(base_path, base_name + '_corner.png'), dpi=256)
+ # edges image
+ temp_mask = np.zeros((256, 256))
+ for ele in candidate.graph.getEdges():
+ if ele.get_score() < 0:
+ A = ele.x[0]
+ B = ele.x[1]
+ temp_mask = cv2.line(temp_mask, A.x[::-1], B.x[::-1], 1, thickness=1)
+ ax.imshow(temp_mask, aspect='auto')
+ fig.savefig(os.path.join(base_path, base_name + '_edge.png'), dpi=256)
+ # no region image is needed
+ plt.close()
+
+
+#########################################################################################
+###################################### Class ############################################
+#########################################################################################
+
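+# An Element is either a corner or an edge of the reconstruction graph, e.g. (illustrative usage):
+#   corner = Element((y, x))                # a tuple of two ints, image coordinates
+#   edge = Element((corner_a, corner_b))    # a tuple of two previously created corner Elements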
+class Element:
+ def __init__(self, x, safe_count=0):
+ assert type(x) is tuple
+ assert type(x[0]) == int or type(x[0]) == Element
+ assert type(x[1]) == int or type(x[1]) == Element
+ self.x = x
+ self.__score = None
+ self.safe_count = safe_count
+
+ def store_score(self, score):
+ self.__score = score
+
+ def get_score(self):
+ return self.__score
+
+ def equal(self, ele):
+ if type(self.x[0]) != type(ele.x[0]):
+ return False
+ if type(self.x[0]) == int:
+ # corner
+ return True if self.x[0] == ele.x[0] and self.x[1] == ele.x[1] else False
+ if type(self.x[0]) == Element:
+ # edge
+ if self.x[0].equal(ele.x[0]) and self.x[1].equal(ele.x[1]):
+ return True
+ if self.x[1].equal(ele.x[0]) and self.x[0].equal(ele.x[1]):
+ return True
+ return False
+ raise BaseException('unsupported element type')
+
+
+class regionCache():
+ def __init__(self, datapath):
+ self.cache = {}
+ self.datapath = datapath
+
+ def get_region(self, name):
+ if name in self.cache.keys():
+ return self.cache[name]
+ gt_mask = np.load(os.path.join(self.datapath, name + '.npy'))
+ if len(self.cache) == 5:
+ self.cache.pop(list(self.cache.keys())[0])
+ self.cache[name] = gt_mask
+ return gt_mask
+
+
+class imgCache():
+ def __init__(self, datapath):
+ self.cache = {}
+ self.datapath = datapath
+
+ def get_image(self, name):
+ if name in self.cache.keys():
+ return self.cache[name]
+ img = skimage.img_as_float(plt.imread(os.path.join(self.datapath, 'rgb', name + '.jpg')))
+ if len(self.cache) == 5:
+ self.cache.pop(list(self.cache.keys())[0])
+ self.cache[name] = img
+ return img
+
+
+class Graph:
+ def __init__(self, corners, edges):
+ corners, edges = sort_graph(corners, edges)
+
+ self.__corners = []
+ for corner_i in range(corners.shape[0]):
+ self.__corners.append(
+ Element(
+ tuple(
+ (int(corners[corner_i, 0]), int(corners[corner_i, 1]))
+ )
+ )
+ )
+ self.__edges = []
+ for edge_i in range(edges.shape[0]):
+ self.__edges.append(Element((self.__corners[edges[edge_i, 0]], self.__corners[edges[edge_i, 1]])))
+ self.__regions = []
+ self.__regions.append(Element((0, 0))) # we use entire region here
+
+ @classmethod
+ def initialFromTuple(cls, corners, edges):
+ edge_index = []
+ for item in edges:
+ a = corners.index(item[0])
+ b = corners.index(item[1])
+ edge_index.append((a, b))
+ edge_index = np.array(edge_index)
+ corners = np.array(corners)
+ return cls(corners, edge_index)
+
+ def store_score(self, corner_score=None, edge_score=None, region_score=None):
+ '''
+ :param corner_score: np array size: len(corners)
+ :param edge_score: np array size: len(edges)
+ :param region_score: np.array size: len(regions)
+ :return:
+ '''
+ if corner_score is not None:
+ for idx, element in enumerate(self.__corners):
+ element.store_score(corner_score[idx])
+ if edge_score is not None:
+ for idx, element in enumerate(self.__edges):
+ element.store_score(edge_score[idx])
+ if region_score is not None:
+ for idx, element in enumerate(self.__regions):
+ element.store_score(region_score[idx])
+ return
+
+ def getCornersArray(self):
+ c = []
+ for ele in self.__corners:
+ c.append(ele.x)
+ return np.array(c)
+
+ def getEdgesArray(self):
+ c = []
+ for ele in self.__edges:
+ corner1 = ele.x[0]
+ corner2 = ele.x[1]
+ idx1 = self.__corners.index(corner1)
+ idx2 = self.__corners.index(corner2)
+ c.append([idx1, idx2])
+ return np.array(c)
+
+ def getCorners(self):
+ return self.__corners
+
+ def getRegions(self):
+ return self.__regions
+
+ def getEdges(self):
+ return self.__edges
+
+ def graph_score(self):
+ corner_score = 0
+ for ele in self.__corners:
+ corner_score += ele.get_score()
+ edge_score = 0
+ for ele in self.__edges:
+ edge_score += ele.get_score()
+ region_score = 0
+ for ele in self.__regions:
+ region_score += ele.get_score()
+ return score_weights[0] * corner_score + score_weights[1] * edge_score + score_weights[2] * region_score
+
+ def corner_score(self):
+ corner_score = 0
+ for ele in self.__corners:
+ corner_score += ele.get_score()
+ return corner_score
+
+ def edge_score(self):
+ edge_score = 0
+ for ele in self.__edges:
+ edge_score += ele.get_score()
+ return edge_score
+
+ def region_score(self):
+ region_score = 0
+ for ele in self.__regions:
+ region_score += ele.get_score()
+ return region_score
+
+ def remove(self, ele):
+ '''
+ :param ele: the element to remove; related elements (edges incident to a removed corner, corners left with zero degree) are removed as well
+ :return: set() of removed elements
+ '''
+ # corner
+ removed = set()
+ if ele in self.__corners:
+ self.__corners.remove(ele)
+ removed.add(ele)
+ # remove edge that has the corner
+ for idx in reversed(range(len(self.__edges))):
+ edge_ele = self.__edges[idx]
+ if ele in edge_ele.x:
+ removed = removed.union(self.remove(edge_ele))
+ # edge
+ elif ele in self.__edges:
+ self.__edges.remove(ele)
+ removed.add(ele)
+ corner1 = ele.x[0]
+ corner2 = ele.x[1]
+ if corner1.safe_count == 0:
+ # can be deleted
+ _count = 0
+ for edge_ele in self.__edges:
+ if corner1 in edge_ele.x:
+ _count += 1
+ if _count == 0:
+ removed = removed.union(self.remove(corner1))
+ if corner2.safe_count == 0:
+ # can be deleted
+ _count = 0
+ for edge_ele in self.__edges:
+ if corner2 in edge_ele.x:
+ _count += 1
+ if _count == 0:
+ removed = removed.union(self.remove(corner2))
+ return removed
+
+ def has_edge(self, ele1, ele2):
+ """
+ :param ele1: corner1
+ :param ele2: corner2
+ :return: edge or none
+ """
+ for edge_ele in self.__edges:
+ if ele1 in edge_ele.x and ele2 in edge_ele.x:
+ return edge_ele
+ return None
+
+ def add_edge(self, ele1, ele2):
+ temp = self.has_edge(ele1, ele2)
+ if temp is not None:
+ temp.safe_count = SAFE_NUM
+ return temp
+ new_ele = Element((ele1, ele2), safe_count=SAFE_NUM)
+ self.__edges.append(new_ele)
+ return new_ele
+
+ def add_corner(self, ele):
+ for corner in self.__corners:
+ if corner.x == ele.x:
+ corner.safe_count = SAFE_NUM
+ return corner
+ ele.safe_count = SAFE_NUM
+ self.__corners.append(ele)
+ return ele
+
+ def add_corner_v2(self, ele):
+ # if the new corner is near an existing corner, return the existing corner
+ # if the new corner lies on an existing edge, split that edge
+ for corner in self.__corners:
+ if l2_distance(corner.x, ele.x) < 5:
+ corner.safe_count = SAFE_NUM
+ return corner
+ min_d = 256
+ the_edge = None
+ for edge in self.__edges:
+ temp = get_distance_of_corner_and_edge(edge.x[0].x, edge.x[1].x, ele.x)
+ if temp < min_d:
+ min_d = temp
+ the_edge = edge
+ if min_d < 3:
+ # split edge
+ corner1 = the_edge.x[0]
+ corner2 = the_edge.x[1]
+ new_ele = Element((corner1, ele), safe_count=the_edge.safe_count)
+ self.__edges.append(new_ele)
+ new_ele = Element((corner2, ele), safe_count=the_edge.safe_count)
+ self.__edges.append(new_ele)
+ self.__edges.remove(the_edge)
+ ele.safe_count = SAFE_NUM
+ self.__corners.append(ele)
+ return ele
+
+ def checkColinearCorner(self, ele):
+ if self.getCornerDegree(ele) != 2:
+ return False
+ edge_in = []
+ for edge_ele in self.__edges:
+ if ele in edge_ele.x:
+ edge_in.append(edge_ele)
+ if len(edge_in) == 2:
+ break
+ two_neighbor = {edge_in[0].x[0], edge_in[0].x[1], edge_in[1].x[0], edge_in[1].x[1]}
+ two_neighbor.remove(ele)
+ two_neighbor = tuple(two_neighbor)
+ if self.has_edge(two_neighbor[0], two_neighbor[1]) is not None:
+ return False
+
+ line1 = np.array(ele.x) - np.array(two_neighbor[0].x)
+ line2 = np.array(two_neighbor[1].x) - np.array(ele.x)
+ cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
+ cos = min(1, max(-1, cos))
+ if np.arccos(cos) < np.pi / 9: # 20 degree
+ return True
+ return False
+
+ def checkIntersectionEdge(self, ele=None):
+ if ele is None:
+ for edge_i in range(len(self.__edges)):
+ for edge_j in range(edge_i + 1, len(self.__edges)):
+ if check_intersection(self.__edges[edge_i], self.__edges[edge_j]):
+ return True
+ return False
+ for edge_ele in self.__edges:
+ if ele == edge_ele:
+ continue
+ if check_intersection(edge_ele, ele):
+ return True
+ return False
+
+ def getCornerDegree(self, ele):
+ degree = 0
+ for edge_ele in self.__edges:
+ if ele in edge_ele.x:
+ degree += 1
+ return degree
+
+ def getEdgeConnected(self, ele):
+ out_ = set()
+ if type(ele.x[0]) == int:
+ # corner
+ for edge_ele in self.__edges:
+ if ele in edge_ele.x:
+ out_.add(edge_ele)
+ return out_
+ if type(ele.x[0]) == Element:
+ # Edge
+ out_ = out_.union(self.getEdgeConnected(ele.x[0]))
+ out_ = out_.union(self.getEdgeConnected(ele.x[1]))
+ if ele in out_:
+ out_.remove(ele)
+ return out_
+
+ def getNeighborCorner(self, ele):
+ out_ = set()
+ for edge_ele in self.__edges:
+ if ele == edge_ele.x[0]:
+ out_.add(edge_ele.x[1])
+ if ele == edge_ele.x[1]:
+ out_.add(edge_ele.x[0])
+ return out_
+
+ def getRealElement(self, ele):
+ # edge
+ if type(ele.x[0]) == Element:
+ for e in self.__edges:
+ if (e.x[0].x == ele.x[0].x and e.x[1].x == ele.x[1].x) or \
+ (e.x[1].x == ele.x[0].x and e.x[0].x == ele.x[1].x):
+ return e
+ raise BaseException("no same edge exists.")
+ # corner
+ elif type(ele.x[0]) == int:
+ for c in self.__corners:
+ if c.x == ele.x:
+ return c
+ raise BaseException("no same corner exists.")
+
+ def copy(self):
+ corners = self.getCornersArray()
+ edges = self.getEdgesArray()
+ new_graph = Graph(corners, edges)
+ for idx, ele in enumerate(self.__corners):
+ new_graph.__corners[idx].store_score(self.__corners[idx].get_score())
+ for idx, ele in enumerate(self.__edges):
+ new_graph.__edges[idx].store_score(self.__edges[idx].get_score())
+ for idx, ele in enumerate(self.__regions):
+ new_graph.__regions[idx].store_score(self.__regions[idx].get_score())
+ return new_graph
+
+ def update_safe_count(self):
+ for ele in self.__corners:
+ if ele.safe_count > 0:
+ ele.safe_count -= 1
+ for ele in self.__edges:
+ if ele.safe_count > 0:
+ ele.safe_count -= 1
+
+ def isNeighbor(self, element1, element2):
+ '''
+ :param element1:
+ :param element2:
+ :return: True / False
+ '''
+ if element1 == element2:
+ return False
+ if type(element1.x[0]) != type(element2.x[0]):
+ # corner and edge
+ return False
+ if type(element1.x[0]) == int:
+ # both are corner type
+ for edge_ele in self.__edges:
+ if edge_ele.x[0] == element1 and edge_ele.x[1] == element2:
+ return True
+ if edge_ele.x[0] == element2 and edge_ele.x[1] == element1:
+ return True
+ return False
+ if type(element1.x[0]) == Element:
+ # both are edge type
+ if len({element1.x[0], element1.x[1], element2.x[0], element2.x[1]}) < 4:
+ return True
+ return False
+
+ def equal(self, graph):
+ if len(self.__corners) != len(graph.__corners) or \
+ len(self.__edges) != len(graph.__edges):
+ return False
+ for corner_i in range(len(self.__corners)):
+ if self.__corners[corner_i].equal(graph.__corners[corner_i]) is False:
+ return False
+ for edge_i in range(len(self.__edges)):
+ if self.__edges[edge_i].equal(graph.__edges[edge_i]) is False:
+ return False
+
+ return True
+
+
+class Candidate:
+ def __init__(self, graph, name, corner_existed_before, edge_existed_before):
+ '''
+ :param graph: Class graph
+ :param name: string, data name
+ :param corner_existed_before: dict {(x_i, y_i): c_i, ...} mapping corner locations to counts; after each
+ search step every count is decremented, and entries whose count reaches 0 are removed
+ :param edge_existed_before: dict {((x_i1, y_i1), (x_i2, y_i2)): c_i, ...} with the same behaviour for edges
+ '''
+ self.graph = graph
+ self.name = name
+ self.corner_existed_before = corner_existed_before
+ self.edge_existed_before = edge_existed_before
+
+ @classmethod
+ def initial(cls, graph, name):
+ return cls(graph, name, {}, {})
+
+ def update(self):
+ # decrement the count of every existed-before element
+ for key in self.corner_existed_before.keys():
+ self.corner_existed_before[key] -= 1
+ for key in self.edge_existed_before.keys():
+ self.edge_existed_before[key] -= 1
+
+ # drop entries whose count has reached zero
+ for key in list(self.corner_existed_before.keys()):
+ if self.corner_existed_before[key] == 0:
+ self.corner_existed_before.pop(key)
+
+ for key in list(self.edge_existed_before.keys()):
+ if self.edge_existed_before[key] == 0:
+ self.edge_existed_before.pop(key)
+
+ # update graph
+ self.graph.update_safe_count()
+
+ def copy(self):
+ corner_existed_before = self.corner_existed_before.copy()
+ edge_existed_before = self.edge_existed_before.copy()
+ new_graph = self.graph.copy()
+ return Candidate(new_graph, self.name, corner_existed_before, edge_existed_before)
+
+ def removable(self, ele):
+ '''
+ :param ele: the Element to test for removability
+ :return:
+ '''
+ assert type(ele) == Element
+ # an element is removable only once its safe count has expired
+ return True if ele.safe_count == 0 else False
+
+ def addable(self, ele):
+ if type(ele) == Element:
+ if type(ele.x[0]) == Element:
+ # edge
+ for edge in self.graph.getEdges():
+ c1 = edge.x[0]
+ c2 = edge.x[1]
+ if (ele.x[0].x == c1.x and ele.x[1].x == c2.x) or \
+ (ele.x[1].x == c1.x and ele.x[0].x == c2.x):
+ # already existed
+ return False
+ corner1_loc = ele.x[0].x
+ corner2_loc = ele.x[1].x
+ if (corner1_loc, corner2_loc) in self.edge_existed_before.keys() or \
+ (corner2_loc, corner1_loc) in self.edge_existed_before.keys():
+ return False
+ return True
+ else:
+ # corner
+ for corner in self.graph.getCorners():
+ if l2_distance(ele.x, corner.x) < TWO_CORNER_MINIMUM_DISTANCE:
+ # already existed
+ return False
+ if ele.x in self.corner_existed_before.keys():
+ return False
+ return True
+ else: # (x,y) or ((x1,y1),(x2,y2))
+ if type(ele[0]) == tuple:
+ # edge
+ corner1_loc = ele[0]
+ corner2_loc = ele[1]
+ for edge in self.graph.getEdges():
+ c1 = edge.x[0]
+ c2 = edge.x[1]
+ if (corner1_loc == c1.x and corner2_loc == c2.x) or \
+ (corner2_loc == c1.x and corner1_loc == c2.x):
+ # already existed
+ return False
+ if (corner1_loc, corner2_loc) in self.edge_existed_before.keys() or \
+ (corner2_loc, corner1_loc) in self.edge_existed_before.keys():
+ return False
+ return True
+ else:
+ # corner
+ for corner in self.graph.getCorners():
+ if l2_distance(ele, corner.x) < TWO_CORNER_MINIMUM_DISTANCE:
+ # already existed
+ return False
+ if ele in self.corner_existed_before.keys():
+ return False
+ return True
+
+ def addCorner(self, ele):
+ if ele.x in self.corner_existed_before.keys():
+ raise BaseException('cannot add the corner')
+ new_ele = self.graph.add_corner(ele)  # the returned element may differ from ele (merged with an existing corner)
+ return new_ele
+
+ def addCorner_v2(self, ele):
+ if ele.x in self.corner_existed_before.keys():
+ raise BaseException('cannot add the corner')
+ new_ele = self.graph.add_corner_v2(ele)
+ return new_ele
+
+ def addEdge(self, ele1, ele2):
+ corner1 = ele1
+ corner2 = ele2
+ assert corner1 in self.graph.getCorners()
+ assert corner2 in self.graph.getCorners()
+ if (corner1.x, corner2.x) in self.edge_existed_before.keys() or \
+ (corner2.x, corner1.x) in self.edge_existed_before.keys():
+ raise BaseException('cannot add the edge')
+ new_ele = self.graph.add_edge(corner1, corner2)
+ return new_ele
+
+ def removeCorner(self, ele):
+ if ele.x in self.corner_existed_before.keys():
+ raise BaseException('already existed.')
+ self.corner_existed_before[ele.x] = SAFE_NUM
+
+ def removeEdge(self, ele):
+ corner1 = ele.x[0]
+ corner2 = ele.x[1]
+ loc1 = corner1.x
+ loc2 = corner2.x
+ if (loc1[0] > loc2[0]) or (loc1[0] == loc2[0] and loc1[1] > loc2[1]):
+ loc1 = corner2.x
+ loc2 = corner1.x
+ if (loc1, loc2) in self.edge_existed_before.keys():
+ raise BaseException('already existed.')
+ self.edge_existed_before[(loc1, loc2)] = SAFE_NUM
+
+ def generate_new_candidate_remove_a_colinear_corner(self, ele):
+ # the caller must check that ele is a colinear corner before calling this method
+ new_candidate = self.copy()
+ new_graph = new_candidate.graph
+ ele = new_graph.getRealElement(ele)
+
+ # find two neighbor corners
+ temp = set()
+ for element in new_graph.getEdgeConnected(ele):
+ # edge
+ if type(element.x[0]) == Element:
+ temp.add(element.x[0])
+ temp.add(element.x[1])
+ temp.remove(ele)
+ temp = tuple(temp)
+ assert len(temp) == 2
+
+ # add edge to two neighbor corners
+ # (add before removing, otherwise the neighbor corners could be dropped for having zero degree)
+ # special case: no need to check existed_before; instead, remove the entry if it is already in the existed_before dict
+ added = new_graph.add_edge(temp[0], temp[1])
+ if (temp[0].x, temp[1].x) in self.edge_existed_before.keys():
+ self.edge_existed_before.pop((temp[0].x, temp[1].x))
+ if (temp[1].x, temp[0].x) in self.edge_existed_before.keys():
+ self.edge_existed_before.pop((temp[1].x, temp[0].x))
+
+ # remove
+ removed = new_graph.remove(ele)
+
+ # add removed elements into existed before
+ for element in removed:
+ # edge
+ if type(element.x[0]) == Element:
+ new_candidate.removeEdge(element)
+ # corner
+ elif type(element.x[0]) == int:
+ new_candidate.removeCorner(element)
+ else:
+ raise BaseException('wrong type.')
+
+ # modify scores that need to be recounted
+ # all corners are recounted
+ for element in new_graph.getCorners():
+ element.store_score(None)
+
+ # edges that are neighbors to the removed edges OR new edges will be recounted
+ for element in new_graph.getEdges():
+ for modified_ele in removed.union({added}):
+ if new_graph.isNeighbor(element, modified_ele):
+ element.store_score(None)
+ break
+
+ # all regions are recounted
+ for element in new_graph.getRegions():
+ element.store_score(None)
+
+ return new_candidate
+
+ def generate_new_candidate_remove_a_corner(self, ele):
+ # the caller must check that ele is removable before calling this method
+ new_candidate = self.copy()
+ new_graph = new_candidate.graph
+ ele = new_graph.getRealElement(ele)
+ removed = new_graph.remove(ele)
+
+ # add removed elements into existed before
+ for element in removed:
+ # edge
+ if type(element.x[0]) == Element:
+ corner1 = element.x[0]
+ corner2 = element.x[1]
+ loc1 = corner1.x
+ loc2 = corner2.x
+ if (loc1[0] > loc2[0]) or (loc1[0] == loc2[0] and loc1[1] > loc2[1]):
+ loc1 = corner2.x
+ loc2 = corner1.x
+ if (loc1, loc2) in self.edge_existed_before.keys():
+ raise BaseException('already existed.')
+ new_candidate.edge_existed_before[(loc1, loc2)] = SAFE_NUM
+ # corner
+ elif type(element.x[0]) == int:
+ if element.x in self.corner_existed_before.keys():
+ raise BaseException('already existed.')
+ new_candidate.corner_existed_before[element.x] = SAFE_NUM
+ else:
+ raise BaseException('wrong type.')
+
+ # modify scores that need to be recounted
+ # all corners are recounted
+ for element in new_graph.getCorners():
+ element.store_score(None)
+
+ # edges that are neighbors to the removed edges will be recounted
+ for element in new_graph.getEdges():
+ for removed_ele in removed:
+ if new_graph.isNeighbor(element, removed_ele):
+ element.store_score(None)
+ break
+
+ # all regions are recounted
+ for element in new_graph.getRegions():
+ element.store_score(None)
+
+ return new_candidate
+
+ def generate_new_candidate_add_an_edge(self, ele1, ele2):
+ # the caller must check addability before calling this method
+ new_candidate = self.copy()
+ new_graph = new_candidate.graph
+ ele1 = new_graph.getRealElement(ele1)
+ ele2 = new_graph.getRealElement(ele2)
+
+ # add edge
+ new_ele = new_candidate.addEdge(ele1, ele2)
+
+ # modify scores that need to be recounted
+ # all corners are recounted
+ for element in new_graph.getCorners():
+ element.store_score(None)
+
+ # edges that are neighbors to the added edges will be recounted
+ for element in new_graph.getEdges():
+ if new_graph.isNeighbor(element, new_ele):
+ element.store_score(None)
+
+ # all regions are recounted
+ for element in new_graph.getRegions():
+ element.store_score(None)
+
+ return new_candidate
+
+ def generate_new_candidate_remove_an_edge(self, ele):
+ # the caller must check that ele is removable before calling this method
+ new_candidate = self.copy()
+ new_graph = new_candidate.graph
+ ele = new_graph.getRealElement(ele)
+ removed = new_graph.remove(ele)
+
+ # add removed elements into existed before
+ for element in removed:
+ # edge
+ if type(element.x[0]) == Element:
+ corner1 = element.x[0]
+ corner2 = element.x[1]
+ loc1 = corner1.x
+ loc2 = corner2.x
+ if (loc1[0] > loc2[0]) or (loc1[0] == loc2[0] and loc1[1] > loc2[1]):
+ loc1 = corner2.x
+ loc2 = corner1.x
+ if (loc1, loc2) in self.edge_existed_before.keys():
+ raise BaseException('already existed.')
+ new_candidate.edge_existed_before[(loc1, loc2)] = SAFE_NUM
+ # corner
+ elif type(element.x[0]) == int:
+ if element.x in self.corner_existed_before.keys():
+ raise BaseException('already existed.')
+ new_candidate.corner_existed_before[element.x] = SAFE_NUM
+ else:
+ raise BaseException('wrong type.')
+
+ # modify scores that need to be recounted
+ # all corners are recounted
+ for element in new_graph.getCorners():
+ element.store_score(None)
+
+ # edges that are neighbors to the removed edges will be recounted
+ for element in new_graph.getEdges():
+ for removed_ele in removed:
+ if new_graph.isNeighbor(element, removed_ele):
+ element.store_score(None)
+ break
+
+ # all regions are recounted
+ for element in new_graph.getRegions():
+ element.store_score(None)
+
+ return new_candidate
+
+ def generate_new_candidate_add_a_new_triangle(self, ele_new, ele1, ele2):
+ # add a new corner together with two new edges to the graph
+ # the caller must check that ele_new is addable before calling this method
+ new_candidate = self.copy()
+ new_graph = new_candidate.graph
+ ele1 = new_graph.getRealElement(ele1)
+ ele2 = new_graph.getRealElement(ele2)
+
+ # add corner
+ ele_new = new_candidate.addCorner(ele_new)  # the returned element may differ from ele_new
+
+ # no scores need to be recomputed at this point
+
+ # add two_new edge (ele1, ele_new) and (ele2, ele_new)
+ new_candidate = new_candidate.generate_new_candidate_add_an_edge(ele_new, ele1)
+ new_candidate = new_candidate.generate_new_candidate_add_an_edge(ele_new, ele2)
+
+ return new_candidate
+
+ def generate_new_candidate_add_a_corner(self, ele):
+ # the caller must check that ele is addable before calling this method
+ new_candidate = self.copy()
+ new_graph = new_candidate.graph
+
+ # add corner
+ ele = new_candidate.addCorner(ele)
+
+ # modify scores that need to be recounted
+ # all corners are recounted
+ for element in new_graph.getCorners():
+ element.store_score(None)
+
+ # no edge need to be recounted
+ # all regions are recounted
+ for element in new_graph.getRegions():
+ element.store_score(None)
+
+ return new_candidate
+
+ def equal(self, candidate):
+ return self.graph.equal(candidate.graph)
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/corner_models.py b/models/corner_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3c45ea5617ae4b6d3f46f8e834da06715c1d636
--- /dev/null
+++ b/models/corner_models.py
@@ -0,0 +1,275 @@
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+import numpy as np
+import math
+from models.deformable_transformer import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \
+ DeformableTransformerDecoder, DeformableAttnDecoderLayer
+from models.ops.modules import MSDeformAttn
+from models.resnet import convrelu
+from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+from einops.layers.torch import Rearrange
+from typing import Dict
+from utils.misc import NestedTensor
+
+
+class HeatCorner(nn.Module):
+ """
+ The corner model of HEAT mirrors the edge model up to the edge-filtering stage, i.e. it only performs
+ per-candidate prediction without relational modeling.
+ """
+ def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels, ):
+ super(HeatCorner, self).__init__()
+ self.input_dim = input_dim
+ self.hidden_dim = hidden_dim
+ self.num_feature_levels = num_feature_levels
+
+ if num_feature_levels > 1:
+ num_backbone_outs = len(backbone_strides)
+ input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = backbone_num_channels[_]
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ for _ in range(num_feature_levels - num_backbone_outs):
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ in_channels = hidden_dim
+ self.input_proj = nn.ModuleList(input_proj_list)
+ else:
+ self.input_proj = nn.ModuleList([
+ nn.Sequential(
+ nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )])
+
+ self.patch_size = 4
+ patch_dim = (self.patch_size ** 2) * input_dim
+ self.to_patch_embedding = nn.Sequential(
+ Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size),
+ nn.Linear(patch_dim, input_dim),
+ nn.Linear(input_dim, hidden_dim),
+ )
+
+ self.pixel_pe_fc = nn.Linear(input_dim, hidden_dim)
+ self.transformer = CornerTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1,
+ dim_feedforward=1024, dropout=0.1)
+
+ self.img_pos = PositionEmbeddingSine(hidden_dim // 2)
+
+ @staticmethod
+ def get_ms_feat(xs, img_mask):
+ out: Dict[str, NestedTensor] = {}
+ for name, x in sorted(xs.items()):
+ m = img_mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+ out[name] = NestedTensor(x, mask)
+ return out
+
+ @staticmethod
+ def get_decoder_reference_points(height, width, device):
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
+ torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device))
+ ref_y = ref_y.reshape(-1)[None] / height
+ ref_x = ref_x.reshape(-1)[None] / width
+ ref = torch.stack((ref_x, ref_y), -1)
+ return ref
+
+ def forward(self, image_feats, feat_mask, pixels_feat, pixels, all_image_feats):
+ # process image features
+ features = self.get_ms_feat(image_feats, feat_mask)
+
+ srcs = []
+ masks = []
+ all_pos = []
+
+ new_features = list()
+ for name, x in sorted(features.items()):
+ new_features.append(x)
+ features = new_features
+
+ for l, feat in enumerate(features):
+ src, mask = feat.decompose()
+ mask = mask.to(src.device)
+ srcs.append(self.input_proj[l](src))
+ pos = self.img_pos(src).to(src.dtype)
+ all_pos.append(pos)
+ masks.append(mask)
+ assert mask is not None
+
+ if self.num_feature_levels > len(srcs):
+ _len_srcs = len(srcs)
+ for l in range(_len_srcs, self.num_feature_levels):
+ if l == _len_srcs:
+ src = self.input_proj[l](features[-1].tensors)
+ else:
+ src = self.input_proj[l](srcs[-1])
+ m = feat_mask
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device)
+ pos_l = self.img_pos(src).to(src.dtype)
+ srcs.append(src)
+ masks.append(mask)
+ all_pos.append(pos_l)
+
+ sp_inputs = self.to_patch_embedding(pixels_feat)
+
+ # compute the reference points
+ H_tgt = W_tgt = int(np.sqrt(sp_inputs.shape[1]))
+ reference_points_s1 = self.get_decoder_reference_points(H_tgt, W_tgt, sp_inputs.device)
+
+ corner_logits = self.transformer(srcs, masks, all_pos, sp_inputs, reference_points_s1, all_image_feats)
+ return corner_logits
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x):
+ mask = torch.zeros([x.shape[0], x.shape[2], x.shape[3]]).bool().to(x.device)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class CornerTransformer(nn.Module):
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
+ dim_feedforward=1024, dropout=0.1,
+ activation="relu", return_intermediate_dec=False,
+ num_feature_levels=4, dec_n_points=4, enc_n_points=4,
+ ):
+ super(CornerTransformer, self).__init__()
+
+ encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, enc_n_points)
+ self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
+
+ decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, dec_n_points)
+ self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False)
+
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+
+ # upconv layers
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+ self.conv_up1 = convrelu(256 + 256, 256, 3, 1)
+ self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
+ self.conv_original_size2 = convrelu(64 + 128, d_model, 3, 1)
+ self.output_fc_1 = nn.Linear(d_model, 1)
+ self.output_fc_2 = nn.Linear(d_model, 1)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MSDeformAttn):
+ m._reset_parameters()
+ normal_(self.level_embed)
+
+ def get_valid_ratio(self, mask):
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def forward(self, srcs, masks, pos_embeds, query_embed, reference_points, all_image_feats):
+ # prepare input for encoder
+ src_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
+ bs, c, h, w = src.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+ src = src.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ src_flatten.append(src)
+ mask_flatten.append(mask)
+ src_flatten = torch.cat(src_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+
+ # encoder
+ memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten,
+ mask_flatten)
+
+ # prepare input for decoder
+ bs, _, c = memory.shape
+
+ tgt = query_embed
+
+ # per-pixel decoder (single deformable-attention layer, no self-attention)
+ hs_pixels_s1, _ = self.per_edge_decoder(tgt, reference_points, memory,
+ spatial_shapes, level_start_index, valid_ratios, query_embed,
+ mask_flatten)
+
+ feats_s1, preds_s1 = self.generate_corner_preds(hs_pixels_s1, all_image_feats)
+
+ return preds_s1
+
+ def generate_corner_preds(self, outputs, conv_outputs):
+ B, L, C = outputs.shape
+ side = int(np.sqrt(L))
+ outputs = outputs.view(B, side, side, C)
+ outputs = outputs.permute(0, 3, 1, 2)
+ outputs = torch.cat([outputs, conv_outputs['layer1']], dim=1)
+ x = self.conv_up1(outputs)
+
+ x = self.upsample(x)
+ x = torch.cat([x, conv_outputs['layer0']], dim=1)
+ x = self.conv_up0(x)
+
+ x = self.upsample(x)
+ x = torch.cat([x, conv_outputs['x_original']], dim=1)
+ x = self.conv_original_size2(x)
+
+ logits = x.permute(0, 2, 3, 1)
+ preds = self.output_fc_1(logits)
+ preds = preds.squeeze(-1).sigmoid()
+ return logits, preds
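
The sinusoidal image position embedding above has no learned state, so its output shape can be sanity-checked in isolation. A minimal sketch (assuming the repo root is on `PYTHONPATH` and the deformable-attention extension in `models/ops` has been built so that `models.corner_models` imports cleanly; `hidden_dim=256` is an assumed value):

```python
import torch
from models.corner_models import PositionEmbeddingSine

# HeatCorner builds this as PositionEmbeddingSine(hidden_dim // 2); 256 // 2 = 128 assumed here
pos_embed = PositionEmbeddingSine(num_pos_feats=128)
feat = torch.randn(2, 256, 64, 64)   # (batch, channels, H, W) backbone feature map
pos = pos_embed(feat)
print(pos.shape)                     # torch.Size([2, 256, 64, 64]): y- and x-embeddings concatenated
```
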
diff --git a/models/corner_to_edge.py b/models/corner_to_edge.py
new file mode 100644
index 0000000000000000000000000000000000000000..f959d5ebdb89d433d75a87809d59847d51d516c7
--- /dev/null
+++ b/models/corner_to_edge.py
@@ -0,0 +1,232 @@
+import torch
+import numpy as np
+import scipy.ndimage.filters as filters
+import cv2
+import itertools
+
+NEIGHBOUR_SIZE = 5
+MATCH_THRESH = 5
+LOCAL_MAX_THRESH = 0.01
+viz_count = 0
+
+# pre-compute all combinations to generate edge candidates faster
+all_combibations = dict()
+for length in range(2, 351):
+ ids = np.arange(length)
+ combs = np.array(list(itertools.combinations(ids, 2)))
+ all_combibations[length] = combs
+
+
+def prepare_edge_data(c_outputs, annots, images, max_corner_num):
+ bs = c_outputs.shape[0]
+ # prepares parameters for each sample of the batch
+ all_results = list()
+
+ for b_i in range(bs):
+ annot = annots[b_i]
+ output = c_outputs[b_i]
+ results = process_each_sample({'annot': annot, 'output': output, 'viz_img': images[b_i]}, max_corner_num)
+ all_results.append(results)
+
+ processed_corners = [item['corners'] for item in all_results]
+ edge_coords = [item['edges'] for item in all_results]
+ edge_labels = [item['labels'] for item in all_results]
+
+ edge_info = {
+ 'edge_coords': edge_coords,
+ 'edge_labels': edge_labels,
+ 'processed_corners': processed_corners
+ }
+
+ edge_data = collate_edge_info(edge_info)
+ return edge_data
+
+
+def process_annot(annot, do_round=True):
+ corners = np.array(list(annot.keys()))
+ ind = np.lexsort(corners.T) # sort the g.t. corners to fix the order for the matching later
+ corners = corners[ind] # sorted by y, then x
+ corner_mapping = {tuple(k): v for v, k in enumerate(corners)}
+
+ edges = list()
+ for c, connections in annot.items():
+ for other_c in connections:
+ edge_pair = (corner_mapping[c], corner_mapping[tuple(other_c)])
+ edges.append(edge_pair)
+ corner_degrees = [len(annot[tuple(c)]) for c in corners]
+ if do_round:
+ corners = corners.round()
+ return corners, edges, corner_degrees
+
+
+def process_each_sample(data, max_corner_num):
+ annot = data['annot']
+ output = data['output']
+
+ preds = output.detach().cpu().numpy()
+
+ data_max = filters.maximum_filter(preds, NEIGHBOUR_SIZE)
+ maxima = (preds == data_max)
+ data_min = filters.minimum_filter(preds, NEIGHBOUR_SIZE)
+ diff = ((data_max - data_min) > 0)
+ maxima[diff == 0] = 0
+ local_maximas = np.where((maxima > 0) & (preds > LOCAL_MAX_THRESH))
+ pred_corners = np.stack(local_maximas, axis=-1)[:, [1, 0]] # to (x, y format)
+
+ # produce edge labels from pred corners here
+
+ processed_corners, edges, labels = get_edge_label_mix_gt(pred_corners, annot, max_corner_num)
+ # global viz_count
+ # viz_img = data['viz_img']
+ # output_path = './viz_training/{}_example_gt.png'.format(viz_count)
+ # _visualize_edge_training_data(processed_corners, edges, labels, viz_img, output_path)
+ # viz_count += 1
+
+ results = {
+ 'corners': processed_corners,
+ 'edges': edges,
+ 'labels': labels,
+ }
+ return results
+
+
+def get_edge_label_mix_gt(pred_corners, annot, max_corner_num):
+ ind = np.lexsort(pred_corners.T) # sort the pred corners to fix the order for matching
+ pred_corners = pred_corners[ind] # sorted by y, then x
+ gt_corners, edge_pairs, corner_degrees = process_annot(annot)
+
+ output_to_gt = dict()
+ gt_to_output = dict()
+ diff = np.sqrt(((pred_corners[:, None] - gt_corners) ** 2).sum(-1))
+ diff = diff.T
+
+ if len(pred_corners) > 0:
+ for target_i, target in enumerate(gt_corners):
+ dist = diff[target_i]
+ if len(output_to_gt) > 0:
+ dist[list(output_to_gt.keys())] = 1000 # ignore already matched pred corners
+ min_dist = dist.min()
+ min_idx = dist.argmin()
+ if min_dist < MATCH_THRESH and min_idx not in output_to_gt: # a positive match
+ output_to_gt[min_idx] = (target_i, min_dist)
+ gt_to_output[target_i] = min_idx
+
+ all_corners = gt_corners.copy()
+
+ # replace matched g.t. corners with pred corners
+ for gt_i in range(len(gt_corners)):
+ if gt_i in gt_to_output:
+ all_corners[gt_i] = pred_corners[gt_to_output[gt_i]]
+
+ nm_pred_ids = [i for i in range(len(pred_corners)) if i not in output_to_gt]
+ nm_pred_ids = np.random.permutation(nm_pred_ids)
+ if len(nm_pred_ids) > 0:
+ nm_pred_corners = pred_corners[nm_pred_ids]
+ #if len(nm_pred_ids) + len(all_corners) <= 150:
+ if len(nm_pred_ids) + len(all_corners) <= max_corner_num:
+ all_corners = np.concatenate([all_corners, nm_pred_corners], axis=0)
+ else:
+ #all_corners = np.concatenate([all_corners, nm_pred_corners[:(150 - len(gt_corners)), :]], axis=0)
+ all_corners = np.concatenate([all_corners, nm_pred_corners[:(max_corner_num - len(gt_corners)), :]], axis=0)
+
+ processed_corners, edges, edge_ids, labels = _get_edges(all_corners, edge_pairs)
+
+ return processed_corners, edges, labels
+
+
+def _get_edges(corners, edge_pairs):
+ ind = np.lexsort(corners.T)
+ corners = corners[ind] # sorted by y, then x
+ corners = corners.round()
+ id_mapping = {old: new for new, old in enumerate(ind)}
+
+ all_ids = all_combibations[len(corners)]
+ edges = corners[all_ids]
+ labels = np.zeros(edges.shape[0])
+
+ N = len(corners)
+ edge_pairs = [(id_mapping[p[0]], id_mapping[p[1]]) for p in edge_pairs]
+ edge_pairs = [p for p in edge_pairs if p[0] < p[1]]
+ pos_ids = [int((2 * N - 1 - p[0]) * p[0] / 2 + p[1] - p[0] - 1) for p in edge_pairs]
+ labels[pos_ids] = 1
+
+ edge_ids = np.array(all_ids)
+ return corners, edges, edge_ids, labels
+
+
+def collate_edge_info(data):
+ batched_data = {}
+ lengths_info = {}
+ for field in data.keys():
+ batch_values = data[field]
+ all_lens = [len(value) for value in batch_values]
+ max_len = max(all_lens)
+ pad_value = 0
+ batch_values = [pad_sequence(value, max_len, pad_value) for value in batch_values]
+ batch_values = np.stack(batch_values, axis=0)
+
+ if field in ['edge_coords', 'edge_labels', 'gt_values']:
+ batch_values = torch.Tensor(batch_values).long()
+ if field in ['processed_corners', 'edge_coords']:
+ lengths_info[field] = all_lens
+ batched_data[field] = batch_values
+
+ # Add lengths and masks into the data; the mask follows the Transformer input convention, True means padding
+ for field, lengths in lengths_info.items():
+ lengths_str = field + '_lengths'
+ batched_data[lengths_str] = torch.Tensor(lengths).long()
+ mask = torch.arange(max(lengths))
+ mask = mask.unsqueeze(0).repeat(batched_data[field].shape[0], 1)
+ mask = mask >= batched_data[lengths_str].unsqueeze(-1)
+ mask_str = field + '_mask'
+ batched_data[mask_str] = mask
+
+ return batched_data
+
+
+def pad_sequence(seq, length, pad_value=0):
+ if len(seq) == length:
+ return seq
+ else:
+ pad_len = length - len(seq)
+ if len(seq.shape) == 1:
+ if pad_value == 0:
+ paddings = np.zeros([pad_len, ])
+ else:
+ paddings = np.ones([pad_len, ]) * pad_value
+ else:
+ if pad_value == 0:
+ paddings = np.zeros([pad_len, ] + list(seq.shape[1:]))
+ else:
+ paddings = np.ones([pad_len, ] + list(seq.shape[1:])) * pad_value
+ padded_seq = np.concatenate([seq, paddings], axis=0)
+ return padded_seq
+
+
+def get_infer_edge_pairs(corners, confs):
+ ind = np.lexsort(corners.T)
+ corners = corners[ind] # sorted by y, then x
+ confs = confs[ind]
+
+ edge_ids = all_combibations[len(corners)]
+ edge_coords = corners[edge_ids]
+
+ edge_coords = torch.tensor(np.array(edge_coords)).unsqueeze(0).long()
+ mask = torch.zeros([edge_coords.shape[0], edge_coords.shape[1]]).bool()
+ edge_ids = torch.tensor(np.array(edge_ids))
+ return corners, confs, edge_coords, mask, edge_ids
+
+
+def _visualize_edge_training_data(corners, edges, edge_labels, image, save_path):
+ image = image.transpose([1, 2, 0])
+ image = (image * 255).astype(np.uint8)
+ image = np.ascontiguousarray(image)
+
+ for edge, label in zip(edges, edge_labels):
+ if label == 1:
+ cv2.line(image, tuple(edge[0].astype(int)), tuple(edge[1].astype(int)), (255, 255, 0), 2)
+
+ for c in corners:
+ cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
+
+ cv2.imwrite(save_path, image)
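
The positive-label lookup in `_get_edges` relies on `itertools.combinations(range(N), 2)` enumerating pairs `(i, j)` with `i < j` in a fixed order, so a pair's position in that enumeration is `(2N - 1 - i) * i / 2 + (j - i - 1)`. A small self-contained check of that identity:

```python
import itertools
import numpy as np

N = 7
combs = np.array(list(itertools.combinations(range(N), 2)))
for i, j in [(0, 1), (2, 5), (4, 6)]:
    idx = int((2 * N - 1 - i) * i / 2 + j - i - 1)   # same formula as in _get_edges
    assert tuple(combs[idx]) == (i, j)
print("pair-index formula matches the itertools.combinations order")
```
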
diff --git a/models/deformable_transformer.py b/models/deformable_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..be347036aa279a293e478cf0d0141995c3310450
--- /dev/null
+++ b/models/deformable_transformer.py
@@ -0,0 +1,236 @@
+import copy
+import torch
+from torch import nn, Tensor
+from models.ops.modules import MSDeformAttn
+import torch.nn.functional as F
+
+
+class DeformableTransformerEncoderLayer(nn.Module):
+ def __init__(self,
+ d_model=256, d_ffn=1024,
+ dropout=0.1, activation="relu",
+ n_levels=4, n_heads=8, n_points=4):
+ super().__init__()
+
+ # self attention
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation)
+ self.dropout2 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout3 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, src):
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+ src = src + self.dropout3(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
+ # self attention
+ src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index,
+ padding_mask)
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ # ffn
+ src = self.forward_ffn(src)
+
+ return src
+
+
+class DeformableTransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ reference_points_list = []
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
+ output = src
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
+ for _, layer in enumerate(self.layers):
+ output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
+
+ return output
+
+
+class DeformableAttnDecoderLayer(nn.Module):
+ def __init__(self, d_model=256, d_ffn=1024,
+ dropout=0.1, activation="relu",
+ n_levels=4, n_heads=8, n_points=4):
+ super().__init__()
+ # cross attention
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation)
+ self.dropout3 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout4 = nn.Dropout(dropout)
+ self.norm3 = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, tgt):
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout4(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index,
+ src_padding_mask=None,
+ key_padding_mask=None):
+ # cross attention
+ tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
+ reference_points,
+ src, src_spatial_shapes, level_start_index, src_padding_mask)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+
+ # ffn
+ tgt = self.forward_ffn(tgt)
+
+ return tgt
+
+
+
+class DeformableTransformerDecoderLayer(nn.Module):
+ def __init__(self, d_model=256, d_ffn=1024,
+ dropout=0.1, activation="relu",
+ n_levels=4, n_heads=8, n_points=4):
+ super().__init__()
+ # cross attention
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # self attention
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation)
+ self.dropout3 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout4 = nn.Dropout(dropout)
+ self.norm3 = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, tgt):
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout4(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index,
+ src_padding_mask=None,
+ key_padding_mask=None,
+ get_image_feat=True):
+ # self attention
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = \
+ self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), key_padding_mask=key_padding_mask)[
+ 0].transpose(0, 1)
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+
+ if get_image_feat:
+ # cross attention
+ tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
+ reference_points,
+ src, src_spatial_shapes, level_start_index, src_padding_mask)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+
+ # ffn
+ tgt = self.forward_ffn(tgt)
+
+ return tgt
+
+
+class DeformableTransformerDecoder(nn.Module):
+ def __init__(self, decoder_layer, num_layers, return_intermediate=False, with_sa=True):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.return_intermediate = return_intermediate
+ # whether the decoder layers apply self-attention (the per-edge filtering decoder does not)
+ self.with_sa = with_sa
+
+ def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
+ query_pos=None, src_padding_mask=None, key_padding_mask=None, get_image_feat=True):
+ output = tgt
+
+ intermediate = []
+ intermediate_reference_points = []
+ for lid, layer in enumerate(self.layers):
+ if reference_points.shape[-1] == 4:
+ reference_points_input = reference_points[:, :, None] \
+ * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
+ if self.with_sa:
+ output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index,
+ src_padding_mask, key_padding_mask, get_image_feat)
+ else:
+ output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes,
+ src_level_start_index,
+ src_padding_mask, key_padding_mask)
+
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(reference_points)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), torch.stack(intermediate_reference_points)
+
+ return output, reference_points
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
diff --git a/models/edge_models.py b/models/edge_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c7bbb19a981d67277cd61844b820494e2ed2095
--- /dev/null
+++ b/models/edge_models.py
@@ -0,0 +1,314 @@
+# coding=utf-8
+import torch
+import torch.nn as nn
+import numpy as np
+from models.mlp import MLP
+from models.deformable_transformer import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \
+ DeformableTransformerDecoder, DeformableTransformerDecoderLayer, DeformableAttnDecoderLayer
+from models.ops.modules import MSDeformAttn
+from models.corner_models import PositionEmbeddingSine
+from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+import torch.nn.functional as F
+from typing import Dict
+from utils.misc import NestedTensor
+
+
+class HeatEdge(nn.Module):
+ def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels, ):
+ super(HeatEdge, self).__init__()
+ self.input_dim = input_dim
+ self.hidden_dim = hidden_dim
+ self.num_feature_levels = num_feature_levels
+
+ if num_feature_levels > 1:
+ num_backbone_outs = len(backbone_strides)
+ input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = backbone_num_channels[_]
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ for _ in range(num_feature_levels - num_backbone_outs):
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ in_channels = hidden_dim
+ self.input_proj = nn.ModuleList(input_proj_list)
+ else:
+ self.input_proj = nn.ModuleList([
+ nn.Sequential(
+ nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )])
+
+ self.img_pos = PositionEmbeddingSine(hidden_dim // 2)
+
+ self.edge_input_fc = nn.Linear(input_dim * 2, hidden_dim)
+ self.output_fc = MLP(input_dim=hidden_dim, hidden_dim=hidden_dim // 2, output_dim=2, num_layers=2)
+
+ self.transformer = EdgeTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1,
+ num_decoder_layers=6, dim_feedforward=1024, dropout=0.1)
+
+ @staticmethod
+ def get_ms_feat(xs, img_mask):
+ out: Dict[str, NestedTensor] = {}
+ for name, x in sorted(xs.items()):
+ m = img_mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+ out[name] = NestedTensor(x, mask)
+ return out
+
+ def forward(self, image_feats, feat_mask, corner_outputs, edge_coords, edge_masks, gt_values, corner_nums,
+ max_candidates, do_inference=False):
+ # Prepare ConvNet features
+ features = self.get_ms_feat(image_feats, feat_mask)
+
+ srcs = []
+ masks = []
+ all_pos = []
+
+ new_features = list()
+ for name, x in sorted(features.items()):
+ new_features.append(x)
+ features = new_features
+
+ for l, feat in enumerate(features):
+ src, mask = feat.decompose()
+ mask = mask.to(src.device)
+ srcs.append(self.input_proj[l](src))
+ pos = self.img_pos(src).to(src.dtype)
+ all_pos.append(pos)
+ masks.append(mask)
+ assert mask is not None
+
+ if self.num_feature_levels > len(srcs):
+ _len_srcs = len(srcs)
+ for l in range(_len_srcs, self.num_feature_levels):
+ if l == _len_srcs:
+ src = self.input_proj[l](features[-1].tensors)
+ else:
+ src = self.input_proj[l](srcs[-1])
+ m = feat_mask
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device)
+ pos_l = self.img_pos(src).to(src.dtype)
+ srcs.append(src)
+ masks.append(mask)
+ all_pos.append(pos_l)
+
+ bs = edge_masks.size(0)
+ num_edges = edge_masks.size(1)
+
+ corner_feats = corner_outputs
+ edge_feats = list()
+ for b_i in range(bs):
+ feats = corner_feats[b_i, edge_coords[b_i, :, :, 1], edge_coords[b_i, :, :, 0], :]
+ edge_feats.append(feats)
+ edge_feats = torch.stack(edge_feats, dim=0)
+ edge_feats = edge_feats.view(bs, num_edges, -1)
+
+ edge_inputs = self.edge_input_fc(edge_feats.view(bs * num_edges, -1))
+ edge_inputs = edge_inputs.view(bs, num_edges, -1)
+
+ edge_center = (edge_coords[:, :, 0, :].float() + edge_coords[:, :, 1, :].float()) / 2
+ edge_center = edge_center / feat_mask.shape[1]
+
+ logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values = self.transformer(srcs,
+ masks,
+ all_pos,
+ edge_inputs,
+ edge_center,
+ gt_values,
+ edge_masks,
+ corner_nums,
+ max_candidates,
+ do_inference)
+
+ return logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values
+
+
+class EdgeTransformer(nn.Module):
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
+ num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
+ activation="relu", return_intermediate_dec=False,
+ num_feature_levels=4, dec_n_points=4, enc_n_points=4,
+ ):
+ super(EdgeTransformer, self).__init__()
+
+ encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, enc_n_points)
+ self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
+
+ decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, dec_n_points)
+ # one-layer decoder, without self-attention layers
+ self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False)
+
+ decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, dec_n_points)
+
+ # edge decoder w/ self-attention layers (image-aware decoder and geom-only decoder)
+ self.relational_decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers,
+ return_intermediate_dec, with_sa=True)
+
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+
+ self.gt_label_embed = nn.Embedding(3, d_model)
+
+ self.input_fc_hb = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2)
+ self.input_fc_rel = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2)
+
+ self.output_fc_1 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
+ self.output_fc_2 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
+ self.output_fc_3 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MSDeformAttn):
+ m._reset_parameters()
+ normal_(self.level_embed)
+
+ def get_valid_ratio(self, mask):
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def forward(self, srcs, masks, pos_embeds, query_embed, reference_points, labels, key_padding_mask, corner_nums,
+ max_candidates, do_inference=False):
+ # prepare input for encoder
+ src_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
+ bs, c, h, w = src.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+ src = src.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ src_flatten.append(src)
+ mask_flatten.append(mask)
+ src_flatten = torch.cat(src_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+
+ # encoder
+ memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten,
+ mask_flatten)
+
+ # prepare input for decoder
+ bs, _, c = memory.shape
+
+ tgt = query_embed
+
+ # per-edge filtering with single-layer decoder (no self-attn)
+ hs_per_edge, _ = self.per_edge_decoder(tgt, reference_points, memory,
+ spatial_shapes, level_start_index, valid_ratios, query_embed,
+ mask_flatten)
+ logits_per_edge = self.output_fc_1(hs_per_edge).permute(0, 2, 1)
+ filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids = self.candidate_filtering(
+ logits_per_edge,
+ hs_per_edge, query_embed, reference_points,
+ labels,
+ key_padding_mask, corner_nums, max_candidates)
+
+ # generate the info for masked training
+ if not do_inference:
+ filtered_gt_values = self.generate_gt_masking(filtered_labels, filtered_mask)
+ else:
+ filtered_gt_values = filtered_labels
+ gt_info = self.gt_label_embed(filtered_gt_values)
+
+ # relational decoder with image feature (image-aware decoder)
+ hybrid_prim_hs = self.input_fc_hb(torch.cat([filtered_hs, gt_info], dim=-1))
+
+ hs, inter_references = self.relational_decoder(hybrid_prim_hs, filtered_rp, memory,
+ spatial_shapes, level_start_index, valid_ratios, filtered_query,
+ mask_flatten,
+ key_padding_mask=filtered_mask, get_image_feat=True)
+
+ logits_final_hb = self.output_fc_2(hs).permute(0, 2, 1)
+
+ # relational decoder without image feature (geom-only decoder)
+ rel_prim_hs = self.input_fc_rel(torch.cat([filtered_query, gt_info], dim=-1))
+
+ hs_rel, _ = self.relational_decoder(rel_prim_hs, filtered_rp, memory,
+ spatial_shapes, level_start_index, valid_ratios, filtered_query,
+ mask_flatten,
+ key_padding_mask=filtered_mask, get_image_feat=False)
+
+ logits_final_rel = self.output_fc_3(hs_rel).permute(0, 2, 1)
+
+ return logits_per_edge, logits_final_hb, logits_final_rel, selected_ids, filtered_mask, filtered_gt_values
+
+ @staticmethod
+ def candidate_filtering(logits, hs, query, rp, labels, key_padding_mask, corner_nums, max_candidates):
+ """
+ Filter out the easy negatives from the edge candidates and update the edge information accordingly
+ """
+ B, L, _ = hs.shape
+ preds = logits.detach().softmax(1)[:, 1, :] # BxL
+ preds[key_padding_mask == True] = -1 # ignore the masking parts
+ sorted_ids = torch.argsort(preds, dim=-1, descending=True)
+ filtered_hs = list()
+ filtered_mask = list()
+ filtered_query = list()
+ filtered_rp = list()
+ filtered_labels = list()
+ selected_ids = list()
+ for b_i in range(B):
+ num_candidates = corner_nums[b_i] * 3
+ ids = sorted_ids[b_i, :max_candidates[b_i]]
+ filtered_hs.append(hs[b_i][ids])
+ new_mask = key_padding_mask[b_i][ids]
+ new_mask[num_candidates:] = True
+ filtered_mask.append(new_mask)
+ filtered_query.append(query[b_i][ids])
+ filtered_rp.append(rp[b_i][ids])
+ filtered_labels.append(labels[b_i][ids])
+ selected_ids.append(ids)
+ filtered_hs = torch.stack(filtered_hs, dim=0)
+ filtered_mask = torch.stack(filtered_mask, dim=0)
+ filtered_query = torch.stack(filtered_query, dim=0)
+ filtered_rp = torch.stack(filtered_rp, dim=0)
+ filtered_labels = torch.stack(filtered_labels, dim=0)
+ selected_ids = torch.stack(selected_ids, dim=0)
+
+ return filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids
+
+ @staticmethod
+ def generate_gt_masking(labels, mask):
+ """
+ Generate the labels for masked training on the fly; the masked fraction is sampled from [0.5, 1.0] per sample
+ """
+ bs = labels.shape[0]
+ gt_values = torch.zeros_like(mask).long()
+ for b_i in range(bs):
+ edge_length = (mask[b_i] == 0).sum()
+ rand_ratio = np.random.rand() * 0.5 + 0.5
+ gt_rand = torch.rand(edge_length)
+ gt_flag = torch.zeros(edge_length)
+ gt_flag[torch.where(gt_rand >= rand_ratio)] = 1
+ gt_idx = torch.where(gt_flag == 1)
+ pred_idx = torch.where(gt_flag == 0)
+ gt_values[b_i, gt_idx[0]] = labels[b_i, gt_idx[0]]
+ gt_values[b_i, pred_idx[0]] = 2 # use 2 to represent unknown value, need to predict
+ return gt_values
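
The masked-training scheme in `generate_gt_masking` reveals the ground-truth label (0/1) for a random subset of the unmasked edges and marks the rest with the token 2, which the relational decoder must predict. A self-contained sketch of the same idea, re-implemented here for illustration so it does not require the compiled deformable-attention op (the helper name is hypothetical):

```python
import numpy as np
import torch

def masked_gt_values(labels, mask):
    # labels: (B, L) edge labels in {0, 1}; mask: (B, L), True marks padded slots
    gt_values = torch.zeros_like(mask).long()
    for b_i in range(labels.shape[0]):
        edge_length = int((mask[b_i] == 0).sum())
        rand_ratio = np.random.rand() * 0.5 + 0.5          # masked fraction in [0.5, 1.0]
        reveal = torch.rand(edge_length) >= rand_ratio     # ~[0%, 50%] of labels stay visible
        gt_values[b_i, :edge_length][reveal] = labels[b_i, :edge_length][reveal]
        gt_values[b_i, :edge_length][~reveal] = 2          # 2 = unknown, to be predicted
    return gt_values

labels = torch.randint(0, 2, (1, 10))
mask = torch.zeros(1, 10).bool()
print(masked_gt_values(labels, mask))                      # row of values in {0, 1, 2}
```
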
diff --git a/models/loss.py b/models/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ce9ca81cc05af6b1e23bffb76e2cf450d9b0380
--- /dev/null
+++ b/models/loss.py
@@ -0,0 +1,63 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from utils.geometry_utils import edge_acc
+
+
+class CornerCriterion(nn.Module):
+ def __init__(self, image_size):
+ super().__init__()
+ self.loss_rate = 9
+
+ def forward(self, outputs_s1, targets, gauss_targets, epoch=0):
+ # Compute the acc first, use the acc to guide the setup of loss weight
+ preds_s1 = (outputs_s1 >= 0.5).float()
+ pos_target_ids = torch.where(targets == 1)
+ correct = (preds_s1[pos_target_ids] == targets[pos_target_ids]).float().sum()
+ recall_s1 = correct / len(pos_target_ids[0])
+
+ rate = self.loss_rate
+
+ loss_weight = (gauss_targets > 0.5).float() * rate + 1
+ loss_s1 = F.binary_cross_entropy(outputs_s1, gauss_targets, weight=loss_weight, reduction='none')
+ loss_s1 = loss_s1.sum(-1).sum(-1).mean()
+
+ return loss_s1, recall_s1
+
+
+class EdgeCriterion(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.edge_loss = nn.CrossEntropyLoss(weight=torch.tensor([0.33, 1.0]).cuda(), reduction='none')
+
+ def forward(self, logits_s1, logits_s2_hybrid, logits_s2_rel, s2_ids, s2_edge_mask, edge_labels, edge_lengths,
+ edge_mask, s2_gt_values):
+ # loss for edge filtering
+ s1_losses = self.edge_loss(logits_s1, edge_labels)
+ s1_losses[torch.where(edge_mask == True)] = 0
+ s1_losses = s1_losses[torch.where(s1_losses > 0)].sum() / edge_mask.shape[0]
+ gt_values = torch.ones_like(edge_mask).long() * 2
+ s1_acc = edge_acc(logits_s1, edge_labels, edge_lengths, gt_values)
+
+ # loss for stage-2
+ s2_labels = torch.gather(edge_labels, 1, s2_ids)
+
+ # the image-aware decoder
+ s2_losses_hybrid = self.edge_loss(logits_s2_hybrid, s2_labels)
+ s2_losses_hybrid[torch.where((s2_edge_mask == True) | (s2_gt_values != 2))] = 0
+ # aggregate the loss into the final scalar
+ s2_losses_hybrid = s2_losses_hybrid[torch.where(s2_losses_hybrid > 0)].sum() / s2_edge_mask.shape[0]
+ s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
+ # compute edge-level acc
+ s2_acc_hybrid = edge_acc(logits_s2_hybrid, s2_labels, s2_edge_lengths, s2_gt_values)
+
+ # the geom-only decoder
+ s2_losses_rel = self.edge_loss(logits_s2_rel, s2_labels)
+ s2_losses_rel[torch.where((s2_edge_mask == True) | (s2_gt_values != 2))] = 0
+ # aggregate the loss into the final scalar
+ s2_losses_rel = s2_losses_rel[torch.where(s2_losses_rel > 0)].sum() / s2_edge_mask.shape[0]
+ s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
+ # compute edge-level f1-score
+ s2_acc_rel = edge_acc(logits_s2_rel, s2_labels, s2_edge_lengths, s2_gt_values)
+
+ return s1_losses, s1_acc, s2_losses_hybrid, s2_acc_hybrid, s2_losses_rel, s2_acc_rel
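
The corner criterion up-weights pixels near a ground-truth corner (Gaussian target above 0.5) by `loss_rate + 1 = 10` relative to background pixels before summing the binary cross-entropy. A toy illustration of that weighting (self-contained, not the training code):

```python
import torch
import torch.nn.functional as F

preds = torch.tensor([0.9, 0.2, 0.7, 0.1])           # per-pixel corner confidences
gauss_targets = torch.tensor([1.0, 0.0, 0.6, 0.0])   # Gaussian-smoothed corner heatmap values
loss_weight = (gauss_targets > 0.5).float() * 9 + 1  # 10x weight on corner pixels, 1x elsewhere
loss = F.binary_cross_entropy(preds, gauss_targets, weight=loss_weight, reduction='none')
print(loss)   # corner pixels contribute ~10x more than background pixels with the same error
```
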
diff --git a/models/mlp.py b/models/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c38680662de93b2f67ab51aafaf367a20c329ad6
--- /dev/null
+++ b/models/mlp.py
@@ -0,0 +1,21 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super(MLP, self).__init__()
+ self.output_dim = output_dim
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ B, N, D = x.size()
+ x = x.reshape(B*N, D)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ x = x.view(B, N, self.output_dim)
+ return x
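
A quick shape check for this head: it expects a batched sequence `(B, N, D)` and flattens it internally before the per-token MLP (assumes the repo root is on `PYTHONPATH`):

```python
import torch
from models.mlp import MLP

head = MLP(input_dim=256, hidden_dim=128, output_dim=2, num_layers=2)
x = torch.randn(4, 100, 256)   # e.g. 100 edge candidates with 256-d features per sample
print(head(x).shape)           # torch.Size([4, 100, 2])
```
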
diff --git a/models/ops/functions/__init__.py b/models/ops/functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a2197bda3199aa32cafc5b9d396479609853dd2
--- /dev/null
+++ b/models/ops/functions/__init__.py
@@ -0,0 +1,10 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from .ms_deform_attn_func import MSDeformAttnFunction
+
diff --git a/models/ops/functions/ms_deform_attn_func.py b/models/ops/functions/ms_deform_attn_func.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c5df8cf5d23aca963eec6c1133c180b37289607
--- /dev/null
+++ b/models/ops/functions/ms_deform_attn_func.py
@@ -0,0 +1,61 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+import MultiScaleDeformableAttention as MSDA
+
+
+class MSDeformAttnFunction(Function):
+ @staticmethod
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
+ ctx.im2col_step = im2col_step
+ output = MSDA.ms_deform_attn_forward(
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
+ grad_value, grad_sampling_loc, grad_attn_weight = \
+ MSDA.ms_deform_attn_backward(
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
+
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+
+
+def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
+ # for debugging and testing only;
+ # use the CUDA version in practice
+ N_, S_, M_, D_ = value.shape
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
+ # N_*M_, D_, Lq_, P_
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
+ mode='bilinear', padding_mode='zeros', align_corners=False)
+ sampling_value_list.append(sampling_value_l_)
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
+ return output.transpose(1, 2).contiguous()
diff --git a/models/ops/make.sh b/models/ops/make.sh
new file mode 100644
index 0000000000000000000000000000000000000000..65bb7b9358fc37e0e7521621a0f8803b93c347d1
--- /dev/null
+++ b/models/ops/make.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+python3 setup.py build install --user
diff --git a/models/ops/modules/__init__.py b/models/ops/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f82cb1ad9d634a87b54ba6a71b58a230bcade5fe
--- /dev/null
+++ b/models/ops/modules/__init__.py
@@ -0,0 +1,9 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from .ms_deform_attn import MSDeformAttn
diff --git a/models/ops/modules/ms_deform_attn.py b/models/ops/modules/ms_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..663d64a3d2c33f56474b1f01ce7b1162f4966986
--- /dev/null
+++ b/models/ops/modules/ms_deform_attn.py
@@ -0,0 +1,115 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import warnings
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.init import xavier_uniform_, constant_
+
+from ..functions import MSDeformAttnFunction
+
+
+def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+ return (n & (n-1) == 0) and n != 0
+
+
+class MSDeformAttn(nn.Module):
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
+ """
+ Multi-Scale Deformable Attention Module
+ :param d_model hidden dimension
+ :param n_levels number of feature levels
+ :param n_heads number of attention heads
+ :param n_points number of sampling points per attention head per feature level
+ """
+ super().__init__()
+ if d_model % n_heads != 0:
+ raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
+ _d_per_head = d_model // n_heads
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
+ if not _is_power_of_2(_d_per_head):
+ warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
+ "which is more efficient in our CUDA implementation.")
+
+ self.im2col_step = 64
+
+ self.d_model = d_model
+ self.n_levels = n_levels
+ self.n_heads = n_heads
+ self.n_points = n_points
+
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+ self.value_proj = nn.Linear(d_model, d_model)
+ self.output_proj = nn.Linear(d_model, d_model)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ constant_(self.sampling_offsets.weight.data, 0.)
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
+ for i in range(self.n_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ constant_(self.attention_weights.weight.data, 0.)
+ constant_(self.attention_weights.bias.data, 0.)
+ xavier_uniform_(self.value_proj.weight.data)
+ constant_(self.value_proj.bias.data, 0.)
+ xavier_uniform_(self.output_proj.weight.data)
+ constant_(self.output_proj.bias.data, 0.)
+
+ def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
+ """
+ :param query (N, Length_{query}, C)
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
+
+ :return output (N, Length_{query}, C)
+ """
+ N, Len_q, _ = query.shape
+ N, Len_in, _ = input_flatten.shape
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
+
+ value = self.value_proj(input_flatten)
+ if input_padding_mask is not None:
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
+ # N, Len_q, n_heads, n_levels, n_points, 2
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
+ sampling_locations = reference_points[:, :, None, :, None, :] \
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+ else:
+ raise ValueError(
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
+ output = MSDeformAttnFunction.apply(
+ value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
+ output = self.output_proj(output)
+ return output
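
The module only runs after the CUDA extension has been compiled (see `models/ops/make.sh`) and with tensors on a GPU. The sketch below just illustrates the expected input shapes under those assumptions; all sizes are arbitrary:

```python
import torch
from models.ops.modules import MSDeformAttn

attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4).cuda()
spatial_shapes = torch.tensor([[16, 16], [8, 8]], device='cuda')    # (n_levels, 2)
level_start_index = torch.tensor([0, 256], device='cuda')           # flattened offsets per level
len_in = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())   # 16*16 + 8*8 = 320
query = torch.randn(1, 50, 256, device='cuda')                      # (N, Len_q, C)
reference_points = torch.rand(1, 50, 2, 2, device='cuda')           # (N, Len_q, n_levels, 2) in [0, 1]
src = torch.randn(1, len_in, 256, device='cuda')                    # flattened multi-scale features
out = attn(query, reference_points, src, spatial_shapes, level_start_index)
print(out.shape)   # torch.Size([1, 50, 256])
```
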
diff --git a/models/ops/setup.py b/models/ops/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0131bc21cf1b45b90fcf174e2c53e4c08e9c641
--- /dev/null
+++ b/models/ops/setup.py
@@ -0,0 +1,71 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+import os
+import glob
+
+import torch
+
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+def get_extensions():
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+ extensions_dir = os.path.join(this_dir, "src")
+
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+ sources = main_file + source_cpu
+ extension = CppExtension
+ extra_compile_args = {"cxx": []}
+ define_macros = []
+
+ if torch.cuda.is_available() and CUDA_HOME is not None:
+ extension = CUDAExtension
+ sources += source_cuda
+ define_macros += [("WITH_CUDA", None)]
+ extra_compile_args["nvcc"] = [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ ]
+ else:
+ raise NotImplementedError('CUDA is not available')
+
+ sources = [os.path.join(extensions_dir, s) for s in sources]
+ include_dirs = [extensions_dir]
+ ext_modules = [
+ extension(
+ "MultiScaleDeformableAttention",
+ sources,
+ include_dirs=include_dirs,
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args,
+ )
+ ]
+ return ext_modules
+
+setup(
+ name="MultiScaleDeformableAttention",
+ version="1.0",
+ author="Weijie Su",
+ url="https://github.com/fundamentalvision/Deformable-DETR",
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
+ packages=find_packages(exclude=("configs", "tests",)),
+ ext_modules=get_extensions(),
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
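
After running `models/ops/make.sh`, a quick import check confirms the extension built under the name declared above and exposes the two kernels called by `models/ops/functions/ms_deform_attn_func.py` (a sketch; assumes the build succeeded):

```python
import MultiScaleDeformableAttention as MSDA

# the autograd Function wraps these forward/backward kernels
print(hasattr(MSDA, "ms_deform_attn_forward"), hasattr(MSDA, "ms_deform_attn_backward"))
```
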
diff --git a/models/ops/src/cpu/ms_deform_attn_cpu.cpp b/models/ops/src/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e1bf854de1f3860d20b6fef5c1a17817c268e70a
--- /dev/null
+++ b/models/ops/src/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,41 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+    AT_ERROR("Not implemented on CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+    AT_ERROR("Not implemented on CPU");
+}
+
diff --git a/models/ops/src/cpu/ms_deform_attn_cpu.h b/models/ops/src/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..81b7b58a3d9502bbb684dc84687a526dedf94cae
--- /dev/null
+++ b/models/ops/src/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,33 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
+
diff --git a/models/ops/src/cuda/ms_deform_attn_cuda.cu b/models/ops/src/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d6d583647cce987196d5ad1968a8a365a379e774
--- /dev/null
+++ b/models/ops/src/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,153 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+#include "cuda/ms_deform_im2col_cuda.cuh"
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+            value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+            spatial_shapes.data<int64_t>(),
+            level_start_index.data<int64_t>(),
+            sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+            attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+            batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+            columns.data<scalar_t>());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+            grad_output_g.data<scalar_t>(),
+            value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+            spatial_shapes.data<int64_t>(),
+            level_start_index.data<int64_t>(),
+            sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+            attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+            batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+            grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+            grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+            grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
\ No newline at end of file
diff --git a/models/ops/src/cuda/ms_deform_attn_cuda.h b/models/ops/src/cuda/ms_deform_attn_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..c7ae53f99c820ce6193b608ad344550348a0b42c
--- /dev/null
+++ b/models/ops/src/cuda/ms_deform_attn_cuda.h
@@ -0,0 +1,30 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
diff --git a/models/ops/src/cuda/ms_deform_im2col_cuda.cuh b/models/ops/src/cuda/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..6bc2acb7aea0eab2e9e91e769a16861e1652c284
--- /dev/null
+++ b/models/ops/src/cuda/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1327 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THCAtomics.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
+
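+// Bilinear interpolation of one (head, channel) value at the fractional location (h, w);
+// corners that fall outside the feature map contribute zero.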
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
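+// Backward of the bilinear sampling above: accumulates value gradients with atomicAdd and writes
+// the sampling-location and attention-weight gradients to the caller-provided buffers
+// (shared-memory caches in the reduction kernels below).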
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
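+// Same backward computation, but the sampling-location and attention-weight gradients are
+// accumulated directly in global memory with atomicAdd (no shared-memory reduction).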
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
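+// Forward kernel: each thread computes one output element indexed by (batch, query, head, channel),
+// summing bilinear samples weighted by the attention weights over all levels and sampling points.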
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
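+// Backward kernel for small channel counts: per-thread partial gradients go to statically sized
+// shared memory (blockSize known at compile time) and thread 0 reduces them serially.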
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
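+// Same as the v1 kernel above, but the shared-memory partial gradients are combined with a
+// binary tree reduction instead of a serial sum.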
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
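+// Variant with dynamically sized shared memory (block size only known at launch time);
+// thread 0 reduces the partial gradients serially over blockDim.x.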
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
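+// Dynamically sized shared memory with a tree reduction; the spre bookkeeping also handles
+// block sizes that are not a power of two.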
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
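+// Like the v2 reduction, but the per-block results are added to global memory with atomicAdd,
+// since one query's channels may span several blocks when channels exceed the block size.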
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
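+// Fallback kernel without shared memory: every gradient update goes straight to global memory
+// (used for large channel counts that are not a multiple of 1024).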
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
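+// Host-side launcher for the forward pass: one thread per output element, followed by a CUDA error check.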
+template <typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+ const scalar_t* data_value,
+ const int64_t* data_spatial_shapes,
+ const int64_t* data_level_start_index,
+ const scalar_t* data_sampling_loc,
+ const scalar_t* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* data_col)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+  ms_deformable_im2col_gpu_kernel<scalar_t>
+      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
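+// Host-side launcher for the backward pass: picks a reduction kernel from the channel count
+// (statically sized shared memory for power-of-two channels up to 1024, dynamically sized shared
+// memory otherwise, and atomic/global-memory variants when channels exceed 1024).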
+template <typename scalar_t>
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+      ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
\ No newline at end of file
diff --git a/models/ops/src/ms_deform_attn.h b/models/ops/src/ms_deform_attn.h
new file mode 100644
index 0000000000000000000000000000000000000000..ac0ef2ec25f7d0ee51ca2d807b159ddf85652017
--- /dev/null
+++ b/models/ops/src/ms_deform_attn.h
@@ -0,0 +1,62 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include "cpu/ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "cuda/ms_deform_attn_cuda.h"
+#endif
+
+
+at::Tensor
+ms_deform_attn_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_forward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_backward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
diff --git a/models/ops/src/vision.cpp b/models/ops/src/vision.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2201f63a51dca16d0b31148ed2c9e8e47ec15bdc
--- /dev/null
+++ b/models/ops/src/vision.cpp
@@ -0,0 +1,16 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include "ms_deform_attn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
diff --git a/models/ops/test.py b/models/ops/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dbf6d5547d131f01a8c5c28b76557bd27a9334b
--- /dev/null
+++ b/models/ops/test.py
@@ -0,0 +1,89 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import time
+import torch
+import torch.nn as nn
+from torch.autograd import gradcheck
+
+from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
+
+N, M, D = 1, 2, 2
+Lq, L, P = 2, 2, 2
+shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
+level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
+S = sum([(H*W).item() for H, W in shapes])
+
+
+torch.manual_seed(3)
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_double():
+ value = torch.rand(N, S, M, D).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
+ output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
+ fwdok = torch.allclose(output_cuda, output_pytorch)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_float():
+ value = torch.rand(N, S, M, D).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
+ output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
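+# Numerically checks the CUDA gradients w.r.t. value, sampling locations and attention weights
+# with torch.autograd.gradcheck in double precision.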
+def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
+
+ value = torch.rand(N, S, M, channels).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ func = MSDeformAttnFunction.apply
+
+ value.requires_grad = grad_value
+ sampling_locations.requires_grad = grad_sampling_loc
+ attention_weights.requires_grad = grad_attn_weight
+
+ gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
+
+ print(f'* {gradok} check_gradient_numerical(D={channels})')
+
+
+if __name__ == '__main__':
+ check_forward_equal_with_pytorch_double()
+ check_forward_equal_with_pytorch_float()
+
+ for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
+ check_gradient_numerical(channels, True, True, True)
+
+
+
diff --git a/models/resnet.py b/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a9cdeda451e27e56d40b59aee7ad603434e5a16
--- /dev/null
+++ b/models/resnet.py
@@ -0,0 +1,167 @@
+import torch
+import torch.nn as nn
+from torchvision import models
+
+
+def convrelu(in_channels, out_channels, kernel, padding):
+ return nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
+ nn.ReLU(inplace=True),
+ )
+
+
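+# ResNet-50 backbone that exposes the layer2-4 feature maps (strides 8/16/32) together with a
+# dict of all intermediate features and an all-zero padding mask.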
+class ResNetBackbone(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.base_model = models.resnet50(pretrained=False)
+ self.base_layers = list(self.base_model.children())
+
+ self.conv_original_size0 = convrelu(3, 64, 3, 1)
+ self.conv_original_size1 = convrelu(64, 64, 3, 1)
+ self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
+        self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 256, x.H/4, x.W/4)
+        self.layer2 = self.base_layers[5]  # size=(N, 512, x.H/8, x.W/8)
+        self.layer3 = self.base_layers[6]  # size=(N, 1024, x.H/16, x.W/16)
+        self.layer4 = self.base_layers[7]  # size=(N, 2048, x.H/32, x.W/32)
+
+ self.strides = [8, 16, 32]
+ self.num_channels = [512, 1024, 2048]
+
+ def forward(self, inputs):
+ x_original = self.conv_original_size0(inputs)
+ x_original = self.conv_original_size1(x_original)
+ layer0 = self.layer0(inputs)
+ layer1 = self.layer1(layer0)
+ layer2 = self.layer2(layer1)
+ layer3 = self.layer3(layer2)
+ layer4 = self.layer4(layer3)
+
+ xs = {"0": layer2, "1": layer3, "2": layer4}
+ all_feats = {'layer0': layer0, 'layer1': layer1, 'layer2': layer2,
+ 'layer3': layer3, 'layer4': layer4, 'x_original': x_original}
+
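+        # all-zero mask: the inputs are not padded, but a padding mask is returned alongside the features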
+ mask = torch.zeros(inputs.shape)[:, 0, :, :].to(layer4.device)
+ return xs, mask, all_feats
+
+ def train(self, mode=True):
+        # Override train() so that BatchNorm layers stay frozen (eval mode) even while training
+ nn.Module.train(self, mode)
+ if mode:
+ # fix all bn layers
+ def set_bn_eval(m):
+ classname = m.__class__.__name__
+ if classname.find('BatchNorm') != -1:
+ m.eval()
+
+ self.apply(set_bn_eval)
+
+
+class ResNetUNet(nn.Module):
+ def __init__(self, n_class, out_dim=None, ms_feat=False):
+ super().__init__()
+
+ self.return_ms_feat = ms_feat
+ self.out_dim = out_dim
+
+ self.base_model = models.resnet50(pretrained=True)
+ self.base_layers = list(self.base_model.children())
+
+        self.layer0 = nn.Sequential(*self.base_layers[:3])  # size=(N, 64, x.H/2, x.W/2)
+        # self.layer0_1x1 = convrelu(64, 64, 1, 0)
+        self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 256, x.H/4, x.W/4)
+        # self.layer1_1x1 = convrelu(256, 256, 1, 0)
+        self.layer2 = self.base_layers[5]  # size=(N, 512, x.H/8, x.W/8)
+        # self.layer2_1x1 = convrelu(512, 512, 1, 0)
+        self.layer3 = self.base_layers[6]  # size=(N, 1024, x.H/16, x.W/16)
+        # self.layer3_1x1 = convrelu(1024, 1024, 1, 0)
+        self.layer4 = self.base_layers[7]  # size=(N, 2048, x.H/32, x.W/32)
+        # self.layer4_1x1 = convrelu(2048, 2048, 1, 0)
+
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+ self.conv_up3 = convrelu(1024 + 2048, 1024, 3, 1)
+ self.conv_up2 = convrelu(512 + 1024, 512, 3, 1)
+ self.conv_up1 = convrelu(256 + 512, 256, 3, 1)
+ self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
+ # self.conv_up1 = convrelu(512, 256, 3, 1)
+ # self.conv_up0 = convrelu(256, 128, 3, 1)
+
+ self.conv_original_size0 = convrelu(3, 64, 3, 1)
+ self.conv_original_size1 = convrelu(64, 64, 3, 1)
+ self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
+ # self.conv_last = nn.Conv2d(128, n_class, 1)
+ self.conv_last = nn.Conv2d(64, n_class, 1)
+ if out_dim:
+ self.conv_out = nn.Conv2d(64, out_dim, 1)
+ # self.conv_out = nn.Conv2d(128, out_dim, 1)
+
+ # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
+ self.strides = [8, 16, 32]
+ self.num_channels = [512, 1024, 2048]
+
+ def forward(self, inputs):
+ x_original = self.conv_original_size0(inputs)
+ x_original = self.conv_original_size1(x_original)
+
+ layer0 = self.layer0(inputs)
+ layer1 = self.layer1(layer0)
+ layer2 = self.layer2(layer1)
+ layer3 = self.layer3(layer2)
+ layer4 = self.layer4(layer3)
+
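+        # U-Net style decoder: repeatedly upsample, concatenate the matching encoder
+        # feature map (skip connection) and fuse with a 3x3 conv + ReLU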
+ # layer4 = self.layer4_1x1(layer4)
+ x = self.upsample(layer4)
+ # layer3 = self.layer3_1x1(layer3)
+ x = torch.cat([x, layer3], dim=1)
+ x = self.conv_up3(x)
+ layer3_up = x
+
+ x = self.upsample(x)
+ # layer2 = self.layer2_1x1(layer2)
+ x = torch.cat([x, layer2], dim=1)
+ x = self.conv_up2(x)
+ layer2_up = x
+
+ x = self.upsample(x)
+ # layer1 = self.layer1_1x1(layer1)
+ x = torch.cat([x, layer1], dim=1)
+ x = self.conv_up1(x)
+
+ x = self.upsample(x)
+ # layer0 = self.layer0_1x1(layer0)
+ x = torch.cat([x, layer0], dim=1)
+ x = self.conv_up0(x)
+
+ x = self.upsample(x)
+ x = torch.cat([x, x_original], dim=1)
+ x = self.conv_original_size2(x)
+
+ out = self.conv_last(x)
+ out = out.sigmoid().squeeze(1)
+
+ # xs = {"0": layer2, "1": layer3, "2": layer4}
+ xs = {"0": layer2_up, "1": layer3_up, "2": layer4}
+ mask = torch.zeros(inputs.shape)[:, 0, :, :].to(layer4.device)
+ # ms_feats = self.ms_feat(xs, mask)
+
+ if self.return_ms_feat:
+ if self.out_dim:
+ out_feat = self.conv_out(x)
+ out_feat = out_feat.permute(0, 2, 3, 1)
+ return xs, mask, out, out_feat
+ else:
+ return xs, mask, out
+ else:
+ return out
+
+ def train(self, mode=True):
+        # Override train() so that BatchNorm layers stay frozen (eval mode) even while training
+ nn.Module.train(self, mode)
+ if mode:
+ # fix all bn layers
+ def set_bn_eval(m):
+ classname = m.__class__.__name__
+ if classname.find('BatchNorm') != -1:
+ m.eval()
+
+ self.apply(set_bn_eval)
diff --git a/models/stacked_hg.py b/models/stacked_hg.py
new file mode 100644
index 0000000000000000000000000000000000000000..33b397642bed84447e96e424e9296ed485b3124d
--- /dev/null
+++ b/models/stacked_hg.py
@@ -0,0 +1,246 @@
+"""
+Hourglass network inserted in the pre-activated Resnet
+Use lr=0.01 for current version
+(c) Nan Xue (HAWP)
+(c) Yichao Zhou (LCNN)
+(c) YANG, Wei
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ["HourglassNet", "hg"]
+
+
+class Bottleneck2D(nn.Module):
+ expansion = 2
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck2D, self).__init__()
+
+ self.bn1 = nn.BatchNorm2d(inplanes)
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
+ self.bn3 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.bn1(x)
+ out = self.relu(out)
+ out = self.conv1(out)
+
+ out = self.bn2(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ out = self.bn3(out)
+ out = self.relu(out)
+ out = self.conv3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+
+ return out
+
+
+class Hourglass(nn.Module):
+ def __init__(self, block, num_blocks, planes, depth):
+ super(Hourglass, self).__init__()
+ self.depth = depth
+ self.block = block
+ self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
+
+ def _make_residual(self, block, num_blocks, planes):
+ layers = []
+ for i in range(0, num_blocks):
+ layers.append(block(planes * block.expansion, planes))
+ return nn.Sequential(*layers)
+
+ def _make_hour_glass(self, block, num_blocks, planes, depth):
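+        # each depth level holds 3 residual modules (skip branch, pre- and post-recursion low branch);
+        # the innermost level (i == 0) gets a 4th module that forms the bottleneck of the hourglass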
+ hg = []
+ for i in range(depth):
+ res = []
+ for j in range(3):
+ res.append(self._make_residual(block, num_blocks, planes))
+ if i == 0:
+ res.append(self._make_residual(block, num_blocks, planes))
+ hg.append(nn.ModuleList(res))
+ return nn.ModuleList(hg)
+
+ def _hour_glass_forward(self, n, x):
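+        # up1 keeps features at the current resolution; the pooled branch recurses one level
+        # deeper, then is upsampled and added back onto up1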
+ up1 = self.hg[n - 1][0](x)
+ low1 = F.max_pool2d(x, 2, stride=2)
+ low1 = self.hg[n - 1][1](low1)
+
+ if n > 1:
+ low2 = self._hour_glass_forward(n - 1, low1)
+ else:
+ low2 = self.hg[n - 1][3](low1)
+ low3 = self.hg[n - 1][2](low2)
+ up2 = F.interpolate(low3, scale_factor=2)
+ out = up1 + up2
+ return out
+
+ def forward(self, x):
+ return self._hour_glass_forward(self.depth, x)
+
+
+class HourglassNet(nn.Module):
+ """Hourglass model from Newell et al ECCV 2016"""
+
+ def __init__(self, inplanes, num_feats, block, head, depth, num_stacks, num_blocks, num_classes):
+ super(HourglassNet, self).__init__()
+
+ self.inplanes = inplanes
+ self.num_feats = num_feats
+ self.num_stacks = num_stacks
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
+ self.bn1 = nn.BatchNorm2d(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.layer1 = self._make_residual(block, self.inplanes, 1)
+ self.layer2 = self._make_residual(block, self.inplanes, 1)
+ self.layer3 = self._make_residual(block, self.num_feats, 1)
+ self.maxpool = nn.MaxPool2d(2, stride=2)
+
+ # build hourglass modules
+ ch = self.num_feats * block.expansion
+ # vpts = []
+ hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
+ for i in range(num_stacks):
+ hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
+ res.append(self._make_residual(block, self.num_feats, num_blocks))
+ fc.append(self._make_fc(ch, ch))
+ score.append(head(ch, num_classes))
+ # vpts.append(VptsHead(ch))
+ # vpts.append(nn.Linear(ch, 9))
+ # score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
+ # score[i].bias.data[0] += 4.6
+ # score[i].bias.data[2] += 4.6
+ if i < num_stacks - 1:
+ fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
+ score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
+ self.hg = nn.ModuleList(hg)
+ self.res = nn.ModuleList(res)
+ self.fc = nn.ModuleList(fc)
+ self.score = nn.ModuleList(score)
+ # self.vpts = nn.ModuleList(vpts)
+ self.fc_ = nn.ModuleList(fc_)
+ self.score_ = nn.ModuleList(score_)
+
+ def _make_residual(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ )
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def _make_fc(self, inplanes, outplanes):
+ bn = nn.BatchNorm2d(inplanes)
+ conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
+ return nn.Sequential(conv, bn, self.relu)
+
+ def forward(self, x):
+ out = []
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ x = self.layer1(x)
+ x = self.maxpool(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+
+ for i in range(self.num_stacks):
+ y = self.hg[i](x)
+ y = self.res[i](y)
+ y = self.fc[i](y)
+ score = self.score[i](y)
+ out.append(score)
+
+ if i < self.num_stacks - 1:
+ fc_ = self.fc_[i](y)
+ score_ = self.score_[i](score)
+ x = x + fc_ + score_
+
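+        # per-stack score maps, reversed so the last (most refined) stack comes first,
+        # together with the feature map of the final stack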
+ return out[::-1], y
+
+ def train(self, mode=True):
+        # Override train() so that BatchNorm layers stay frozen (eval mode) even while training
+ nn.Module.train(self, mode)
+ if mode:
+ # fix all bn layers
+ def set_bn_eval(m):
+ classname = m.__class__.__name__
+ if classname.find('BatchNorm') != -1:
+ m.eval()
+
+ self.apply(set_bn_eval)
+
+
+class MultitaskHead(nn.Module):
+ def __init__(self, input_channels, num_class, head_size):
+ super(MultitaskHead, self).__init__()
+
+ m = int(input_channels / 4)
+ heads = []
+ for output_channels in sum(head_size, []):
+ heads.append(
+ nn.Sequential(
+ nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(m, output_channels, kernel_size=1),
+ )
+ )
+ self.heads = nn.ModuleList(heads)
+ assert num_class == sum(sum(head_size, []))
+
+ def forward(self, x):
+ return torch.cat([head(x) for head in self.heads], dim=1)
+
+
+def build_hg():
+ inplanes = 64
+    num_feats = 256 // 2
+ depth = 4
+ num_stacks = 2
+ num_blocks = 1
+ head_size = [[2], [2]]
+
+ out_feature_channels = 256
+
+ num_class = sum(sum(head_size, []))
+ model = HourglassNet(
+ block=Bottleneck2D,
+        inplanes=inplanes,
+        num_feats=num_feats,
+        depth=depth,
+        head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size),
+        num_stacks=num_stacks,
+        num_blocks=num_blocks,
+        num_classes=num_class)
+
+ model.out_feature_channels = out_feature_channels
+
+ return model
+
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..093dbd5d91ae570192939c19b8d2b9dcb078884d
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,33 @@
+'''
+Author: [egrt]
+Date: 2022-08-23 13:21:27
+LastEditors: [egrt]
+LastEditTime: 2022-08-23 13:45:21
+Description:
+'''
+#--------------------------------------------------------------#
+#   Run prediction on a single image.
+#   The result is saved to save_path below (assets/test_out.jpg by default).
+#--------------------------------------------------------------#
+from PIL import Image
+
+from HEAT import HEAT
+
+if __name__ == "__main__":
+ heat = HEAT()
+    #----------------------------#
+    #   Save path for the output image
+    #----------------------------#
+ save_path = "assets/test_out.jpg"
+
+ while True:
+ img = input('Input image filename:')
+ try:
+ image = Image.open(img)
+ except:
+ print('Open Error! Try again!')
+ continue
+ else:
+ r_image = heat.detect_one_image(image)
+ r_image.save(save_path)
+ r_image.show()
\ No newline at end of file
diff --git a/qualitative_outdoor/generate_html.py b/qualitative_outdoor/generate_html.py
new file mode 100644
index 0000000000000000000000000000000000000000..79c37e7615a73c36448bcfcf90fc9819035ffd74
--- /dev/null
+++ b/qualitative_outdoor/generate_html.py
@@ -0,0 +1,64 @@
+import os
+import os.path as osp
+import numpy as np
+
+
+# minimal HTML skeleton for the comparison page
+head = '''
+<html>
+<head>
+<meta charset="utf-8">
+</head>
+<body>
+<table border="1" style="border-collapse: collapse; text-align: center;">
+'''
+
+end = '''
+</table>
+</body>
+</html>
+'''
+
+def writeHTML(out_path, results_dirs):
+ f = open(out_path, 'w')
+ f.write(head + '\n')
+    f.write('<tr>'
+            '<th> ID </th>'
+            '<th> Input </th>'
+            '<th> ConvMPN </th>'
+            '<th> Exp-cls </th>'
+            '<th> HAWP </th>'
+            '<th> LETR </th>'
+            '<th> HEAT (Ours) </th>'
+            '<th> G.T. </th>'
+            '</tr>')
+
+ fileids_path = '../data/cities_dataset/valid_list.txt'
+ img_base = '../data/cities_dataset/rgb'
+ with open(fileids_path) as ff:
+ file_ids = ff.readlines()
+ file_ids = file_ids[50:]
+ file_ids = [file_id.strip() for file_id in file_ids]
+ permuted_ids = np.random.permutation(file_ids)
+ file_ids = permuted_ids[:100]
+
+ for file_id in file_ids:
+        row_str = '<tr>'
+        row_str += '<td> {} </td>'.format(file_id)
+        row_str += '<td> <img src="{}"> </td>'.format(os.path.join(img_base, file_id + '.jpg'))
+        for dir_idx, result_dir in enumerate(results_dirs):
+            pred_filepath = osp.join(result_dir, '{}.png'.format(file_id))
+            row_str += '<td> <img src="{}"> </td>'.format(pred_filepath)
+        row_str += '</tr>'
+ f.write(row_str + '\n')
+
+ f.write(end + '\n')
+
+
+if __name__ == '__main__':
+ results_dirs = ['svg_images_256/convmpn', 'svg_images_256/exp_cls', 'svg_images_256/hawp', 'svg_images_256/letr', 'svg_images_256/heat', 'svg_images_256/gt']
+
+ writeHTML(out_path='./outdoor_qual.html', results_dirs=results_dirs)
diff --git a/qualitative_outdoor/plot_utils.py b/qualitative_outdoor/plot_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4addb7ae021d2e86ba1e925f1bde3dc22ede1fc
--- /dev/null
+++ b/qualitative_outdoor/plot_utils.py
@@ -0,0 +1,43 @@
+import cv2
+import svgwrite
+import colorsys
+
+
+def plot_preds(image, corners, edges):
+ for line in edges:
+ cv2.line(image, tuple(line[:2]), tuple(line[2:]), (255, 255, 0), 2)
+ for c in corners:
+ cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
+ return image
+
+
+def random_colors(N, bright=True, same=False, colors=None):
+ brightness = 1.0 if bright else 0.7
+ if colors is None or same:
+ if same:
+ hsv = [(0, 1, brightness) for i in range(N)]
+ else:
+ hsv = [(i / N, 1, brightness) for i in range(N)]
+ else:
+ hsv = [(colors[i], 1, brightness) for i in range(N)]
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
+ return colors
+
+
+def svg_generate(image_link, corners, edges, name, size=512):
+ dwg = svgwrite.Drawing(name + '.svg', size=('{}'.format(size), '{}'.format(size)))
+ shapes = dwg.add(dwg.g(id='shape', fill='black'))
+ # colors = random_colors(len(edges), same=True)
+ shapes.add(dwg.image(href=image_link, size=(size, size)))
+
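+    # predictions are given in 256x256 pixel coordinates; rescale them to the requested canvas size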
+ scale = size / 256
+ for i, edge in enumerate(edges):
+ x = edge[:2] * scale
+ y = edge[2:] * scale
+ shapes.add(dwg.line((int(x[0]), int(x[1])), (int(y[0]), int(y[1])),
+ stroke="#EE6507", stroke_width=3*scale, opacity=0.7))
+
+    for corner in corners:
+        shapes.add(dwg.circle((int(corner[0] * scale), int(corner[1] * scale)), r=4*scale,
+                              stroke='green', fill='white', stroke_width=2*scale, opacity=0.8))
+ return dwg
diff --git a/qualitative_outdoor/visualize_gt.py b/qualitative_outdoor/visualize_gt.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a5db87e5bf11dd5310ffe05dd8f040089d12c3a
--- /dev/null
+++ b/qualitative_outdoor/visualize_gt.py
@@ -0,0 +1,46 @@
+import os
+import json
+import cv2
+import numpy as np
+from plot_utils import plot_preds, svg_generate
+import cairosvg
+
+image_base = '../data/cities_dataset/rgb/'
+annot_base = '../data/cities_dataset/annot/'
+data_filename = '../data/cities_dataset/valid_list.txt'
+with open(data_filename) as f:
+ filenames = f.readlines()
+
+filenames = filenames[50:]
+filenames = [filename.strip() for filename in filenames]
+
+
+for filename in filenames:
+ image_path = os.path.join(image_base, filename + '.jpg')
+ # image = cv2.imread(image_path)
+ annot_path = os.path.join(annot_base, filename + '.npy')
+
+ annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()
+    corners = np.array(list(annot.keys())).astype(int)
+
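+    # collect each ground-truth edge only once, regardless of the direction it is stored in the annotation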
+ edges = set()
+ for c, others in annot.items():
+ for other_c in others:
+ edge = (c[0], c[1], other_c[0], other_c[1])
+ edge_2 = (other_c[0], other_c[1], c[0], c[1])
+ if edge not in edges and edge_2 not in edges:
+ edges.add(edge)
+
+    edges = np.array(list(edges)).astype(int)
+
+ # image = plot_preds(image, corners, edges)
+ # out_path = os.path.join(out_base, filename + '.png')
+ # cv2.imwrite(out_path, image)
+
+ svg = svg_generate(image_path, corners, edges, name='temp', size=256)
+ svg_path = './svg_results/' + 'tmp.svg'
+ svg.saveas(svg_path)
+ svg_img_path = './svg_images_256/gt/' + '{}.png'.format(filename)
+ cairosvg.svg2png(url=svg_path, write_to=svg_img_path)
+
+
diff --git a/qualitative_outdoor/visualize_npy.py b/qualitative_outdoor/visualize_npy.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0b431988a9869d5a0d1d412b6f493f250987235
--- /dev/null
+++ b/qualitative_outdoor/visualize_npy.py
@@ -0,0 +1,46 @@
+import os
+import json
+import cv2
+import numpy as np
+import cairosvg
+from plot_utils import plot_preds, svg_generate
+
+image_base = '../../data/outdoor/cities_dataset/rgb/'
+svg_base = './svg_results'
+
+if not os.path.exists(svg_base):
+ os.makedirs(svg_base)
+
+data_filename = '../data/outdoor/cities_dataset/valid_list.txt'
+with open(data_filename) as f:
+ filenames = f.readlines()
+
+filenames = filenames[50:] # according to previous works, the testing samples are the last 350 samples of the val split
+filenames = [filename.strip() for filename in filenames]
+idx_to_filename = {idx: filename for idx, filename in enumerate(filenames)}
+
+method_name = 'heat'
+results_base = '../results/npy_outdoor_test_256/'
+
+svg_method_base = os.path.join(svg_base, method_name)
+if not os.path.exists(svg_method_base):
+ os.makedirs(svg_method_base)
+
+for result_filename in sorted(os.listdir(results_base)):
+ file_idx = int(result_filename[:-12])
+ filename = idx_to_filename[file_idx]
+
+ image_path = os.path.join(image_base, filename + '.jpg')
+
+ results_path = os.path.join(results_base, result_filename)
+ results = np.load(results_path, allow_pickle=True).tolist()
+    corners = results['corners'].astype(int)
+ edge_ids = results['edges']
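+    # expand each edge's pair of corner indices into a flat (x1, y1, x2, y2) row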
+ edges = corners[edge_ids].reshape(edge_ids.shape[0], -1)
+
+ svg = svg_generate(image_path, corners, edges, name='temp', size=256)
+ svg_path = os.path.join(svg_base, 'tmp.svg')
+ svg.saveas(svg_path) # save the svg file temporarily
+
+ svg_img_path = os.path.join(svg_method_base, '{}.png'.format(filename))
+ cairosvg.svg2png(url=svg_path, write_to=svg_img_path)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6c4ca193d81bfabf2f027c201c3ce13ab5c35ccd
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,27 @@
+Cython==0.29.22
+defusedxml==0.6.0
+einops==0.4.1
+future==0.18.2
+imageio==2.16.1
+matplotlib==3.3.4
+MultiScaleDeformableAttention==1.0
+numpy==1.20.1
+opencv-python==4.4.0.44
+packaging==20.9
+Pillow==9.0.1
+prometheus-client==0.9.0
+prompt-toolkit==3.0.16
+ptyprocess==0.7.0
+pycparser==2.20
+Pygments==2.8.0
+python-dateutil==2.8.1
+scikit-image==0.19.2
+scikit-learn==1.0
+scipy==1.6.1
+six==1.15.0
+torch==1.5.1
+torchvision==0.6.1
+cairosvg==2.5.2
+svgwrite==1.4.2
+shapely==1.8.2
+gradio==2.5.3
\ No newline at end of file
diff --git a/s3d_floorplan_eval/DataRW/DataRW.py b/s3d_floorplan_eval/DataRW/DataRW.py
new file mode 100644
index 0000000000000000000000000000000000000000..f91b444851712ae903ee8d05ccd6a7658f0f6915
--- /dev/null
+++ b/s3d_floorplan_eval/DataRW/DataRW.py
@@ -0,0 +1,4 @@
+
+class DataRW:
+ def __init__(self, options):
+ pass
\ No newline at end of file
diff --git a/s3d_floorplan_eval/DataRW/S3DRW.py b/s3d_floorplan_eval/DataRW/S3DRW.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e88dbfc29317c4a6348287ce5ad451fa062b0d6
--- /dev/null
+++ b/s3d_floorplan_eval/DataRW/S3DRW.py
@@ -0,0 +1,142 @@
+import numpy as np
+import cv2
+import torch
+import os
+import time
+
+from DataRW.DataRW import DataRW
+from S3DLoader.S3DLoader import S3DLoader
+
+class S3DRW(DataRW):
+ def __init__(self, options):
+ """
+        Class for accessing Structured3D (S3D) floorplan data
+
+ :param options:
+ """
+ # initialize the base class variables
+        super(S3DRW, self).__init__(options)
+
+ self.options = options
+
+ self.dataset_path = options.dataset_path
+ self.scene_id = options.scene_id
+
+ self.mcts_path = options.mcts_path
+ self.creation_time = int(time.time())
+
+ self.device = torch.device("cpu")
+
+ # mode = "train"
+ # mode = "online_eval"
+ mode = "test"
+ # For validation only
+ # self.loader = S3DLoader(options, 'online_eval').dataset
+ self.loader = S3DLoader(options, mode).dataset
+
+ # gt_sample = iter(floornet_loader.dataset[int(self.scene_id)])
+ # self.gt_sample = floornet_loader.load_sample(list(iter(floornet_loader.dataset))[int(self.scene_id)])
+
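+        # scene ids are numbered globally across splits (train: 0-2999, val: 3000-3249, test: 3250-3499),
+        # so the split-local index is recovered by subtracting the split offset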
+ if mode == "online_eval":
+ scene_ind = int(self.scene_id[6:]) - 3000
+ elif mode == "test":
+ scene_ind = int(self.scene_id[6:]) - 3250
+ elif mode == "train":
+ scene_ind = int(self.scene_id[6:])
+ else:
+ assert False
+
+ # print(len(list(iter(self.s3d_loader.data))))
+ self.gt_sample = gt_sample = self.loader[scene_ind]
+ self.gt_sample["density_map"] = torch.tensor(self.gt_sample["density_map"][None], device=self.device)
+ self.gt_sample["room_map"] = torch.tensor(self.gt_sample["room_map"][None,:,:,None], device=self.device)
+ self.gt_sample["wall_map"] = torch.tensor(self.gt_sample["wall_map"][None,:,:,None], device=self.device)
+
+
+ self.density_map = self.gt_sample['density_map'][:,:,:,None]
+
+ self.h, self.w = self.density_map.shape[1], self.density_map.shape[2]
+
+ self.generate_input_map_from_props = self.generate_input_dict_from_room_props
+
+ def get_gt_solution(self):
+ """
+ Read top-view density map of the scene
+
+ :return:
+ """
+ img_path = os.path.join(self.dataset_path, str(self.scene_id) + "_density.png")
+ density_map = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)[:,:, 0][None,:,:,None]
+
+ density_map = torch.from_numpy(density_map).to(self.device)
+
+ dm_min = torch.min(density_map)
+ dm_max = torch.max(density_map)
+
+ density_map = (density_map - dm_min) / (dm_max - dm_min)
+
+ return density_map.type(torch.cuda.FloatTensor)
+
+ def polygonize_mask(self, pm, return_mask=True):
+ pm_np = pm.cpu().numpy()
+
+ room_mask = 255 * (pm_np == 1)
+ room_mask = room_mask.astype(np.uint8)
+ room_mask_inv = 255 - room_mask
+
+ ret, thresh = cv2.threshold(room_mask_inv, 250, 255, cv2.THRESH_BINARY_INV)
+
+ contours, hierarchy = cv2.findContours(thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
+
+ cnt = contours[0]
+ max_area = cv2.contourArea(cnt)
+
+ for cont in contours:
+ if cv2.contourArea(cont) > max_area:
+ cnt = cont
+ max_area = cv2.contourArea(cont)
+
+ # define main island contour approx. and hull
+ perimeter = cv2.arcLength(cnt, True)
+ epsilon = 0.01 * cv2.arcLength(cnt, True)
+ approx = cv2.approxPolyDP(cnt, epsilon, True)
+
+ # approx = np.concatenate([approx, approx[0][None]], axis=0)
+ approx = approx.astype(np.int32).reshape((1, -1, 2))
+
+ if return_mask:
+ room_filled_map = np.zeros((self.h, self.w))
+ cv2.fillPoly(room_filled_map, approx, color=1.)
+
+ room_filled_map = torch.tensor(room_filled_map[:,:], dtype=torch.float32, device=self.device)
+
+ return room_filled_map
+ else:
+ approx_tensor = torch.tensor(approx, device=self.device)
+ return approx_tensor
+
+ def generate_input_dict_from_room_props(self, room_prop_list, score_function, use_thresh=False):
+ """
+
+ :param room_prop_list:
+ :type room_prop_list: list of FloorPlanRoomProp
+ :param score_function:
+ :return:
+ """
+
+ if score_function == "room_maskrcnn_iou":
+ inputs = self.generate_input_dict_for_room_maskrcnn_iou(room_prop_list)
+ elif score_function == "room_iou":
+ inputs = self.generate_input_dict_for_room_iou(room_prop_list, use_thresh=use_thresh)
+ else:
+ assert "generate_input_dict_from_room_props for %s not implemented" % score_function
+
+ return inputs
+
+
+
+
+
+
+
+
diff --git a/s3d_floorplan_eval/DataRW/wrong_annotatios.py b/s3d_floorplan_eval/DataRW/wrong_annotatios.py
new file mode 100644
index 0000000000000000000000000000000000000000..f089c63b9e759160f1d22e38829b671ab54aa8a5
--- /dev/null
+++ b/s3d_floorplan_eval/DataRW/wrong_annotatios.py
@@ -0,0 +1 @@
+wrong_s3d_annotations_list = [3261, 3271, 3276, 3296, 3342, 3387, 3398, 3466, 3496]
\ No newline at end of file
diff --git a/s3d_floorplan_eval/Evaluator/Evaluator.py b/s3d_floorplan_eval/Evaluator/Evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..d11e0f68671ddc751185755871eb918f314fd720
--- /dev/null
+++ b/s3d_floorplan_eval/Evaluator/Evaluator.py
@@ -0,0 +1,457 @@
+import os
+import torch
+import matplotlib.pyplot as plt
+import cv2
+import numpy as np
+from scipy.spatial import Delaunay
+import shapely
+from shapely.geometry import Polygon, MultiPolygon, LineString, MultiLineString
+
+corner_metric_thresh = 10
+angle_metric_thresh = 5
+
+
+
+# colormap_255 = [[i, i, i] for i in range(40)]
+
+class Evaluator():
+ def __init__(self, data_rw, options):
+ self.data_rw = data_rw
+ self.options = options
+
+ self.device = torch.device("cuda")
+
+ def polygonize_mask(self, mask, degree, return_mask=True):
+ h, w = mask.shape[0], mask.shape[1]
+
+ room_mask = 255 * (mask == 1)
+ room_mask = room_mask.astype(np.uint8)
+ room_mask_inv = 255 - room_mask
+
+ ret, thresh = cv2.threshold(room_mask_inv, 250, 255, cv2.THRESH_BINARY_INV)
+
+ contours, hierarchy = cv2.findContours(thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
+
+ cnt = contours[0]
+ max_area = cv2.contourArea(cnt)
+
+ for cont in contours:
+ if cv2.contourArea(cont) > max_area:
+ cnt = cont
+ max_area = cv2.contourArea(cont)
+
+ perimeter = cv2.arcLength(cnt, True)
+ # epsilon = 0.01 * cv2.arcLength(cnt, True)
+ epsilon = degree * cv2.arcLength(cnt, True)
+ approx = cv2.approxPolyDP(cnt, epsilon, True)
+
+ # approx = np.concatenate([approx, approx[0][None]], axis=0)
+ approx = approx.astype(np.int32).reshape((-1, 2))
+
+ # approx_tensor = torch.tensor(approx, device=self.device)
+
+ # return approx_tensor
+ if return_mask:
+ room_filled_map = np.zeros((h, w))
+ cv2.fillPoly(room_filled_map, [approx], color=1.)
+
+ return approx, room_filled_map
+ else:
+ return approx
+
+ def print_res_str_for_latex(self, quant_result_dict):
+
+ str_fields = ""
+ str_values = ""
+
+ avg_value_prec = 0
+ avg_value_rec = 0
+ for k_ind, k in enumerate(quant_result_dict.keys()):
+ str_fields += " & " + k
+ str_values += " & %.2f " % quant_result_dict[k]
+
+ if k_ind % 2 == 0:
+ avg_value_prec += quant_result_dict[k] / 3
+ else:
+ avg_value_rec += quant_result_dict[k] / 3
+
+        str_fields += " & tm_prec & tm_rec"
+
+ str_values += " & %.2f " % avg_value_prec
+ str_values += " & %.2f " % avg_value_rec
+
+ str_fields += " \\\\"
+ str_values += " \\\\"
+
+ print(str_fields)
+ print(str_values)
+
+ def calc_gradient(self, room_map):
+ grad_x = np.abs(room_map[:, 1:] - room_map[:, :-1])
+ grad_y = np.abs(room_map[1:] - room_map[:-1])
+
+ grad_xy = np.zeros_like(room_map)
+ grad_xy[1:] = grad_y
+ grad_xy[:, 1:] = np.maximum(grad_x, grad_xy[:,1:])
+
+ plt.figure()
+ plt.axis("off")
+ plt.imshow(grad_xy, cmap="gray")
+ # plt.show()
+ plt.savefig("grad.png", bbox_inches='tight')
+
+ plt.figure()
+ plt.axis("off")
+ plt.imshow(room_map, cmap="gray")
+ # plt.show()
+ plt.savefig("joint_mask.png", bbox_inches='tight')
+ assert False
+
+ def evaluate_scene(self, room_polys, show=False, name="ours", dataset_type="s3d"):
+
+ with torch.no_grad():
+ joint_room_map = np.zeros((self.options.height, self.options.width))
+
+ edge_map = np.zeros_like(joint_room_map)
+ room_filled_map = np.ones([joint_room_map.shape[0], joint_room_map.shape[1], 3])
+
+ density_map = self.data_rw.density_map.cpu().numpy()[0]
+ img_size = (density_map.shape[0], density_map.shape[0])
+
+ for room_ind, poly in enumerate(room_polys):
+ cv2.polylines(edge_map, [poly], isClosed=True, color=1.)
+ cv2.fillPoly(joint_room_map, [poly], color=1.)
+
+ joint_room_map_vis = np.ones([joint_room_map.shape[0], joint_room_map.shape[1], 3])
+
+ # Ground Truth
+
+ gt_polys_list = self.data_rw.gt_sample["polygons_list"]
+ gt_polys_list = [np.concatenate([poly, poly[None, 0]]) for poly in gt_polys_list]
+
+ ignore_mask_region = self.data_rw.gt_sample["wall_map"].cpu().numpy()[0, :, :, 0]
+
+ img_size = (joint_room_map.shape[0], joint_room_map.shape[1])
+            quant_result_dict = self.get_quantitative(gt_polys_list, ignore_mask_region, room_polys, img_size=img_size, dataset_type=dataset_type)
+
+ return quant_result_dict
+
+ def get_quantitative(self, gt_polys, ignore_mask_region, pred_polys=None, masks_list=None, img_size=(256, 256), dataset_type="s3d"):
+ def get_room_metric():
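+            # a predicted room counts as correct when it is matched to a ground-truth room
+            # (IoU > 0.5) and does not overlap any other predicted room after erosion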
+ pred_overlaps = [False] * len(pred_room_map_list)
+
+ for pred_ind1 in range(len(pred_room_map_list) - 1):
+ pred_map1 = pred_room_map_list[pred_ind1]
+
+ for pred_ind2 in range(pred_ind1 + 1, len(pred_room_map_list)):
+ pred_map2 = pred_room_map_list[pred_ind2]
+
+ if dataset_type == "s3d":
+ kernel = np.ones((5, 5), np.uint8)
+ else:
+ kernel = np.ones((3, 3), np.uint8)
+
+ # todo: for our method, the rooms share corners and edges, need to check here
+ pred_map1_er = cv2.erode(pred_map1, kernel)
+ pred_map2_er = cv2.erode(pred_map2, kernel)
+
+ intersection = (pred_map1_er + pred_map2_er) == 2
+ # intersection = (pred_map1 + pred_map2) == 2
+
+ intersection_area = np.sum(intersection)
+
+ if intersection_area >= 1:
+ pred_overlaps[pred_ind1] = True
+ pred_overlaps[pred_ind2] = True
+
+ # import pdb; pdb.set_trace()
+            room_metric = [bool((1 - pred_overlaps[ind]) * pred2gt_exists[ind]) for ind in range(len(pred_polys))]
+
+ return room_metric
+
+ def get_corner_metric():
+
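+            # for each ground-truth corner of a matched room, the nearest predicted corner is
+            # marked correct when it lies within corner_metric_thresh pixels and is not already taken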
+ room_corners_metric = []
+ for pred_poly_ind, gt_poly_ind in enumerate(pred2gt_indices):
+ p_poly = pred_polys[pred_poly_ind][:-1] # Last vertex = First vertex
+
+ p_poly_corner_metrics = [False] * p_poly.shape[0]
+ if not room_metric[pred_poly_ind]:
+ room_corners_metric += p_poly_corner_metrics
+ continue
+
+ gt_poly = gt_polys[gt_poly_ind][:-1]
+
+ # for v in p_poly:
+ # v_dists = np.linalg.norm(v[None,:] - gt_poly, axis=1, ord=2)
+ # v_min_dist = np.min(v_dists)
+ #
+ # v_tp = v_min_dist <= 10
+ # room_corners_metric.append(v_tp)
+
+ for v in gt_poly:
+ v_dists = np.linalg.norm(v[None,:] - p_poly, axis=1, ord=2)
+ v_min_dist_ind = np.argmin(v_dists)
+ v_min_dist = v_dists[v_min_dist_ind]
+
+ if not p_poly_corner_metrics[v_min_dist_ind]:
+ v_tp = v_min_dist <= corner_metric_thresh
+ p_poly_corner_metrics[v_min_dist_ind] = v_tp
+
+ room_corners_metric += p_poly_corner_metrics
+
+ return room_corners_metric
+
+ def get_angle_metric():
+
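+            # for corners that pass the distance test, compare the interior angle of the predicted
+            # polygon with the ground-truth angle at the matched corner; correct when the
+            # difference is at most angle_metric_thresh degrees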
+ def get_line_vector(p1, p2):
+ p1 = np.concatenate((p1, np.array([1])))
+ p2 = np.concatenate((p2, np.array([1])))
+
+ line_vector = -np.cross(p1, p2)
+
+ return line_vector
+
+ def get_poly_orientation(my_poly):
+ angles_sum = 0
+ for v_ind, _ in enumerate(my_poly):
+ if v_ind < len(my_poly) - 1:
+ v_sides = my_poly[[v_ind - 1, v_ind, v_ind, v_ind + 1], :]
+ else:
+ v_sides = my_poly[[v_ind - 1, v_ind, v_ind, 0], :]
+
+ v1_vector = get_line_vector(v_sides[0], v_sides[1])
+ v1_vector = v1_vector / (np.linalg.norm(v1_vector, ord=2) + 1e-4)
+ v2_vector = get_line_vector(v_sides[2], v_sides[3])
+ v2_vector = v2_vector / (np.linalg.norm(v2_vector, ord=2) + 1e-4)
+
+ orientation = (v_sides[1, 1] - v_sides[0, 1]) * (v_sides[3, 0] - v_sides[1, 0]) - (
+ v_sides[3, 1] - v_sides[1, 1]) * (
+ v_sides[1, 0] - v_sides[0, 0])
+
+ v1_vector_2d = v1_vector[:2] / (v1_vector[2] + 1e-4)
+ v2_vector_2d = v2_vector[:2] / (v2_vector[2] + 1e-4)
+
+ v1_vector_2d = v1_vector_2d / (np.linalg.norm(v1_vector_2d, ord=2) + 1e-4)
+ v2_vector_2d = v2_vector_2d / (np.linalg.norm(v2_vector_2d, ord=2) + 1e-4)
+
+ angle_cos = v1_vector_2d.dot(v2_vector_2d)
+ angle_cos = np.clip(angle_cos, -1, 1)
+
+ # G.T. has clockwise orientation, remove minus in the equation
+
+ angle = np.sign(orientation) * np.abs(np.arccos(angle_cos))
+ angle_degree = angle * 180 / np.pi
+
+ angles_sum += angle_degree
+
+ return np.sign(angles_sum)
+
+ def get_angle_v_sides(inp_v_sides, poly_orient):
+ v1_vector = get_line_vector(inp_v_sides[0], inp_v_sides[1])
+ v1_vector = v1_vector / (np.linalg.norm(v1_vector, ord=2) + 1e-4)
+ v2_vector = get_line_vector(inp_v_sides[2], inp_v_sides[3])
+ v2_vector = v2_vector / (np.linalg.norm(v2_vector, ord=2) + 1e-4)
+
+ orientation = (inp_v_sides[1, 1] - inp_v_sides[0, 1]) * (inp_v_sides[3, 0] - inp_v_sides[1, 0]) - (
+ inp_v_sides[3, 1] - inp_v_sides[1, 1]) * (
+ inp_v_sides[1, 0] - inp_v_sides[0, 0])
+
+ v1_vector_2d = v1_vector[:2] / (v1_vector[2]+ 1e-4)
+ v2_vector_2d = v2_vector[:2] / (v2_vector[2]+ 1e-4)
+
+ v1_vector_2d = v1_vector_2d / (np.linalg.norm(v1_vector_2d, ord=2) + 1e-4)
+ v2_vector_2d = v2_vector_2d / (np.linalg.norm(v2_vector_2d, ord=2) + 1e-4)
+
+ angle_cos = v1_vector_2d.dot(v2_vector_2d)
+ angle_cos = np.clip(angle_cos, -1, 1)
+
+ angle = poly_orient * np.sign(orientation) * np.arccos(angle_cos)
+ angle_degree = angle * 180 / np.pi
+
+ return angle_degree
+
+ room_angles_metric = []
+ for pred_poly_ind, gt_poly_ind in enumerate(pred2gt_indices):
+ p_poly = pred_polys[pred_poly_ind][:-1] # Last vertex = First vertex
+
+ p_poly_angle_metrics = [False] * p_poly.shape[0]
+ if not room_metric[pred_poly_ind]:
+ room_angles_metric += p_poly_angle_metrics
+ continue
+
+ gt_poly = gt_polys[gt_poly_ind][:-1]
+
+ # for v in p_poly:
+ # v_dists = np.linalg.norm(v[None,:] - gt_poly, axis=1, ord=2)
+ # v_min_dist = np.min(v_dists)
+ #
+ # v_tp = v_min_dist <= 10
+ # room_corners_metric.append(v_tp)
+
+ gt_poly_orient = get_poly_orientation(gt_poly)
+ p_poly_orient = get_poly_orientation(p_poly)
+
+ for v_gt_ind, v in enumerate(gt_poly):
+ v_dists = np.linalg.norm(v[None,:] - p_poly, axis=1, ord=2)
+ v_ind = np.argmin(v_dists)
+ v_min_dist = v_dists[v_ind]
+
+ if v_min_dist > corner_metric_thresh:
+ # room_angles_metric.append(False)
+ continue
+
+ if v_ind < len(p_poly) - 1:
+ v_sides = p_poly[[v_ind - 1, v_ind, v_ind, v_ind + 1], :]
+ else:
+ v_sides = p_poly[[v_ind - 1, v_ind, v_ind, 0], :]
+
+ v_sides = v_sides.reshape((4,2))
+ pred_angle_degree = get_angle_v_sides(v_sides, p_poly_orient)
+
+ # Note: replacing some variables with values from the g.t. poly
+
+ if v_gt_ind < len(gt_poly) - 1:
+ v_sides = gt_poly[[v_gt_ind - 1, v_gt_ind, v_gt_ind, v_gt_ind + 1], :]
+ else:
+ v_sides = gt_poly[[v_gt_ind - 1, v_gt_ind, v_gt_ind, 0], :]
+
+ v_sides = v_sides.reshape((4, 2))
+ gt_angle_degree = get_angle_v_sides(v_sides, gt_poly_orient)
+
+ angle_metric = np.abs(pred_angle_degree - gt_angle_degree)
+
+ # room_angles_metric.append(angle_metric < 5)
+ p_poly_angle_metrics[v_ind] = angle_metric <= angle_metric_thresh
+
+ # if angle_metric > 5:
+ # print(v_gt_ind, angle_metric)
+ # print(pred_angle_degree, gt_angle_degree)
+ # input("?")
+
+
+ room_angles_metric += p_poly_angle_metrics
+
+ for am, cm in zip(room_angles_metric, corner_metric):
+ assert not (cm == False and am == True), "cm: %d am: %d" %(cm, am)
+
+ return room_angles_metric
+
+ def poly_map_sort_key(x):
+ return np.sum(x[1])
+
+ h, w = img_size
+
+ gt_room_map_list = []
+ for room_ind, poly in enumerate(gt_polys):
+ room_map = np.zeros((h, w))
+ cv2.fillPoly(room_map, [poly], color=1.)
+
+ gt_room_map_list.append(room_map)
+
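+        # process ground-truth rooms from largest to smallest area before matching predictions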
+ gt_polys_sorted_indcs = [i[0] for i in sorted(enumerate(gt_room_map_list), key=poly_map_sort_key, reverse=True)]
+
+ gt_polys = [gt_polys[ind] for ind in gt_polys_sorted_indcs]
+ gt_room_map_list = [gt_room_map_list[ind] for ind in gt_polys_sorted_indcs]
+
+ if pred_polys is not None:
+ pred_room_map_list = []
+ for room_ind, poly in enumerate(pred_polys):
+ room_map = np.zeros((h, w))
+ cv2.fillPoly(room_map, [poly], color=1.)
+
+ pred_room_map_list.append(room_map)
+ else:
+ pred_room_map_list = masks_list
+
+ gt2pred_indices = [-1] * len(gt_polys)
+ gt2pred_exists = [False] * len(gt_polys)
+
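+        # match each ground-truth room to the predicted room with the highest IoU (required > 0.5),
+        # ignoring pixels inside the wall/ignore mask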
+ for gt_ind, gt_map in enumerate(gt_room_map_list):
+
+ best_iou = 0.
+ best_ind = -1
+ for pred_ind, pred_map in enumerate(pred_room_map_list):
+
+ intersection = (1 - ignore_mask_region) * ((pred_map + gt_map) == 2)
+ union = (1 - ignore_mask_region) * ((pred_map + gt_map) >= 1)
+
+ iou = np.sum(intersection) / (np.sum(union) + 1)
+
+ if iou > best_iou and iou > 0.5:
+ best_iou = iou
+ best_ind = pred_ind
+
+ # plt.figure()
+ # plt.subplot(121)
+ # plt.imshow(pred_map)
+ # plt.subplot(122)
+ # plt.imshow(gt_map)
+ # plt.show()
+ # if best_ind == -1:
+ # plt.figure()
+ # plt.imshow(gt_map)
+ # plt.show()
+
+ gt2pred_indices[gt_ind] = best_ind
+ gt2pred_exists[gt_ind] = best_ind != -1
+
+ # if best_ind == -1:
+ # plt.figure()
+ # plt.imshow(gt_map)
+ # plt.show()
+
+ pred2gt_exists = [True if pred_ind in gt2pred_indices else False for pred_ind, _ in enumerate(pred_polys)]
+ pred2gt_indices = [gt2pred_indices.index(pred_ind) if pred_ind in gt2pred_indices else -1 for pred_ind, _ in enumerate(pred_polys)]
+
+ # print(gt2pred_indices)
+ # print(pred2gt_indices)
+ # assert False
+
+ # import pdb; pdb.set_trace()
+ room_metric = get_room_metric()
+ if len(pred_polys) == 0:
+ room_metric_prec = 0
+ else:
+ room_metric_prec = sum(room_metric) / float(len(pred_polys))
+ room_metric_rec = sum(room_metric) / float(len(gt_polys))
+
+
+ corner_metric = get_corner_metric()
+ pred_corners_n = sum([poly.shape[0] - 1 for poly in pred_polys])
+ gt_corners_n = sum([poly.shape[0] - 1 for poly in gt_polys])
+
+ if pred_corners_n > 0:
+ corner_metric_prec = sum(corner_metric) / float(pred_corners_n)
+ else:
+ corner_metric_prec = 0
+ corner_metric_rec = sum(corner_metric) / float(gt_corners_n)
+
+
+ angles_metric = get_angle_metric()
+
+ if pred_corners_n > 0:
+ angles_metric_prec = sum(angles_metric) / float(pred_corners_n)
+ else:
+ angles_metric_prec = 0
+ angles_metric_rec = sum(angles_metric) / float(gt_corners_n)
+
+ assert room_metric_prec <= 1
+ assert room_metric_rec <= 1
+ assert corner_metric_prec <= 1
+ assert corner_metric_rec <= 1
+ assert angles_metric_prec <= 1
+ assert angles_metric_rec <= 1
+
+ result_dict = {
+ 'room_prec': room_metric_prec,
+ 'room_rec': room_metric_rec,
+ 'corner_prec': corner_metric_prec,
+ 'corner_rec': corner_metric_rec,
+ 'angles_prec': angles_metric_prec,
+ 'angles_rec': angles_metric_rec,
+ }
+
+ return result_dict
diff --git a/s3d_floorplan_eval/S3DLoader/S3DLoader.py b/s3d_floorplan_eval/S3DLoader/S3DLoader.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a6dcd92258952f6f20694a76595f423984ecc9c
--- /dev/null
+++ b/s3d_floorplan_eval/S3DLoader/S3DLoader.py
@@ -0,0 +1,306 @@
+
+import torch
+from torch.utils.data import Dataset, DataLoader
+import torch.utils.data.distributed
+import os
+import cv2
+import json
+
+from S3DLoader.s3d_utils import *
+from S3DLoader.poly_utils import *
+
+
+class S3DLoader(object):
+ def __init__(self, args, mode, generate_input_candidates=False):
+ self.mode = mode
+ self.seed = 8978
+ np.random.seed(seed=self.seed)
+
+ if hasattr(args, 'network_mode'):
+ self.function_mode = args.network_mode
+ else:
+ self.function_mode = "S"
+
+ if hasattr(args, 'batch_size'):
+ self.batch_size = args.batch_size
+ else:
+ self.batch_size = 1
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print('Selected device is:', device)
+ self.device = device
+
+ if mode == 'train':
+ self.dataset = self.create_dataset(args, mode, generate_input_candidates)
+ self.augment = True
+
+ self.data = DataLoader(self.dataset, self.batch_size,
+ drop_last=True,
+ collate_fn=self.collate_fn,
+ shuffle=True)
+
+ self.sample_n = len(self.dataset)
+
+ elif mode == 'online_eval' or mode == 'test':
+ self.dataset = self.create_dataset(args, mode, generate_input_candidates)
+ self.augment = False
+ # self.batch_size = 4
+
+ self.sample_n = len(self.dataset)
+
+ self.data = DataLoader(self.dataset, self.batch_size,
+ drop_last=True,
+ collate_fn=self.collate_fn)
+
+
+ elif mode == 'test':
+ self.dataset = self.create_dataset(args, mode)
+ self.augment = False
+ self.batch_size = 1
+
+ self.sample_n = 20
+
+ self.data = DataLoader(self.dataset, self.batch_size,
+ num_workers=1,
+ drop_last=True,
+ collate_fn=self.collate_fn)
+
+ # elif mode == 'test':
+ # self.dataset = self.create_dataset(args, mode)
+ # self.augment = False
+ #
+ # self.data = DataLoader(self.dataset,
+ # 1,
+ # shuffle=False,
+ # num_workers=1)
+ # self.sample_n = 20
+
+ else:
+ print('mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
+
+ def collate_fn(self, samples):
+
+ # wall_maps = [torch.tensor(s["wall_map"][None,:,:,None], device=self.device) for s in samples]
+ room_maps = [torch.tensor(s["room_map"][None,:,:,None], device=self.device) for s in samples]
+ input_maps = [torch.tensor(s["input_map"][None], device=self.device) for s in samples]
+ scores = [torch.tensor(s["score"][None], device=self.device) for s in samples]
+
+ torch_sample = {}
+ torch_sample["room_map"] = torch.cat(room_maps, dim=0)
+ # torch_sample["wall_map"] = torch.cat(wall_maps, dim=0)
+ torch_sample["input_map"] = torch.cat(input_maps, dim=0)
+ torch_sample["score"] = torch.cat(scores, dim=0)
+
+
+ for key, value in torch_sample.items():
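+            # a tensor equals itself everywhere unless it contains NaN; the second assert rejects Inf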
+ assert torch.all(torch_sample[key] == torch_sample[key])
+ assert torch.all(torch.logical_not(torch.isinf(torch_sample[key])))
+
+ return torch_sample
+
+ def create_dataset(self, args, mode, generate_input_candidates):
+ #dataset_path = "../Structured3D/montefloor_data"
+ self.args = args
+ dataset_path = args.dataset_path
+
+ if mode == "train":
+ scenes_path = os.path.join(dataset_path, "train")
+
+ dataset = S3DDataset(args, scenes_path, None,
+ num_scenes=3000, generate_input_candidates=generate_input_candidates, mode=mode)
+
+ elif mode == "online_eval":
+ scenes_path = os.path.join(dataset_path, "val")
+
+ dataset = S3DDataset(args, scenes_path, None,
+ num_scenes=250, generate_input_candidates=generate_input_candidates, mode=mode)
+ elif mode == "test":
+ scenes_path = os.path.join(dataset_path, "test")
+ # scenes_path = os.path.join(dataset_path, "val")
+
+ dataset = S3DDataset(args, scenes_path, None,
+ num_scenes=250, generate_input_candidates=generate_input_candidates, mode=mode)
+
+ return dataset
+
+ def load_sample(self, sample_batch):
+ """
+ Identity function. Everything is already loaded in Dataset class for Structured 3D
+ :param sample_batch:
+ :return:
+ """
+ return sample_batch
+
+
+class S3DDataset(Dataset):
+ def __init__(self, options, scenes_path, score_gen, num_scenes, generate_input_candidates, mode):
+ print("Creating Structured3D Dataset with %d scenes..." % num_scenes)
+ self.options = options
+ self.score_gen = None
+
+ self.mode = mode
+
+ self.scenes_path = scenes_path
+ self.floor_data_folder_name = ""
+
+ self.scenes_list = os.listdir(scenes_path)
+ self.scenes_list.sort()
+
+ inv_scenes = ["scene_01155", "scene_01852", "scene_01192", "scene_01816"]
+ self.scenes_list = [s for s in self.scenes_list if s not in inv_scenes]
+ self.scenes_list = self.scenes_list[:num_scenes]
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.device = device
+
+ self.gen_input_candidates = generate_input_candidates
+
+ def __getitem__(self, item):
+ scene_name = self.scenes_list[item]
+ sample = self.load_scene(scene_name)
+
+ return sample
+
+ def __len__(self):
+ return len(self.scenes_list)
+
+ def load_density_map(self, sp):
+ """
+ Load density map
+
+ :param sp:
+ :return:
+ """
+ density_path = os.path.join(sp, self.floor_data_folder_name, "density.png")
+ density_map = cv2.imread(density_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) / 255.
+
+ if self.gen_input_candidates:
+ thresh = np.maximum(np.random.random(), 0.8)
+ density_map = np.minimum(density_map, thresh) / thresh
+
+ if self.mode != "test":
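+                # random gamma augmentation: the exponent is drawn uniformly from [1.0, 1.5)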
+ pow = np.random.random()
+ pow = (1.5 - 1.) * (pow - 1) + 1.5
+ density_map = density_map ** pow
+
+
+ return density_map.astype(np.float32)
+
+ def load_annotation(self, sp):
+ """
+ Load annotation dict
+
+ :param sp:
+ :return:
+ :rtype: dict
+ """
+ anno_path = os.path.join(sp, self.floor_data_folder_name, "annotation_3d.json")
+ with open(anno_path, "r") as f:
+ anno_dict = json.load(f)
+
+ return anno_dict
+
+ def load_scene(self, scene_name):
+ """
+ Load scene
+
+ :param scene_name:
+ :return:
+ """
+
+ def cvt_tmp_sample_to_torch():
+ torch_sample = {}
+
+ room_map = torch.tensor(np.array(sample['room_map']), device=self.device)[None]
+ # room_map = kornia.morphology.dilation(room_map[:, None], kernel=torch.ones((5, 5), device=self.device))[:,0]
+
+ torch_sample['room_map'] = room_map
+
+ if 'input_map' in sample.keys():
+ torch_sample['input_map'] = torch.tensor(np.array(sample['input_map']), device=self.device)[None]
+ torch_sample['cand_inst'] = torch.tensor(np.array(sample['cand_inst']), device=self.device)[None]
+ torch_sample['cand_confidence'] = torch.tensor(np.array(sample['cand_confidence']), device=self.device)[
+ None]
+
+ else:
+ torch_sample['density_map'] = torch.tensor(np.array(sample['density_map']), device=self.device)[None]
+ torch_sample['wall_map'] = torch.tensor(np.array(sample['wall_map']), device=self.device)[None]
+ # torch_sample['room_map'] = torch.tensor(np.array(sample['room_map']), device=self.device)[None]
+ torch_sample['polygons_list'] = [torch.tensor(poly, device=self.device)[None] for poly in sample['polygons_list']]
+
+ return torch_sample
+
+ sp = os.path.join(self.scenes_path, scene_name)
+ sample = {}
+ sample["scene_name"] = scene_name
+
+ scene_anno = self.load_annotation(sp)
+
+ # density_map = torch.tensor(np.array(density_map))[None]
+ density_map = self.load_density_map(sp)
+
+ self.generate_room_map(sample, scene_anno, density_map)
+
+ sample['density_map'] = density_map
+
+ # import pdb; pdb.set_trace()
+ for key, value in sample.items():
+ assert np.all(value == value), "%s contains NaN" % key
+
+ # import matplotlib.pyplot as plt
+ # plt.figure()
+ # plt.subplot(131)
+ # plt.title(scene_name)
+ # plt.imshow(density_map)
+ # plt.subplot(132)
+ # plt.imshow(sample["room_map"])
+ # plt.subplot(133)
+ # # plt.imshow(sample["input_map"][:,:,1])
+ # # plt.imshow(sample["cand_inst"][:,:,0])
+ # plt.show()
+
+ return sample
+
+ def generate_room_map(self, sample, annos, density_map):
+ """
+
+ :param density_map:
+ :param sample:
+ :param annos:
+ :return:
+ """
+
+ h, w = density_map.shape
+
+ polys = parse_floor_plan_polys(annos)
+
+ room_map, polygons_list = generate_floorplan(annos, polys, h, w, ignore_types=['outwall', 'door', 'window'], constant_color=False, shuffle=self.gen_input_candidates)
+
+ room_map = cv2.dilate(room_map, np.ones((5,5)))
+
+
+ wall_map, _ = generate_floorplan(annos, polys, h, w, ignore_types=[], include_types=['outwall'], constant_color=True)
+ wall_map *= (room_map == 0)
+
+ sample['room_map'] = room_map.astype(np.float32)
+ sample['wall_map'] = wall_map.astype(np.float32)
+
+ sample['polygons_list'] = polygons_list
+
+ def generate_density(self, points, width=256, height=256):
+ image_res_tensor = torch.tensor([width, height], device=self.device).reshape(1, 1, 2)
+
+ coordinates = torch.round(points[:, :, :2] * image_res_tensor)
+ coordinates = torch.minimum(torch.maximum(coordinates, torch.zeros_like(image_res_tensor)),
+ image_res_tensor - 1).type(torch.cuda.LongTensor)
+
+ density = torch.zeros((self.batch_size, height, width), dtype=torch.float, device=self.device)
+
+ for i in range(self.batch_size):
+ unique_coordinates, counts = torch.unique(coordinates[i], return_counts=True, dim=0)
+
+ density[i, unique_coordinates[:, 1], unique_coordinates[:, 0]] = counts.type(torch.cuda.FloatTensor)
+ density[i] = density[i] / torch.max(density[i])
+
+ return density
diff --git a/s3d_floorplan_eval/S3DLoader/poly_utils.py b/s3d_floorplan_eval/S3DLoader/poly_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9fedd2f511c711a567612e21e0620a19e4fee78
--- /dev/null
+++ b/s3d_floorplan_eval/S3DLoader/poly_utils.py
@@ -0,0 +1,29 @@
+import numpy as np
+
+
+def rotate_poly(poly, angle, flip_h):
+ """
+ Rotate poly
+
+ :param poly:
+ :return:
+ """
+
+ px, py = poly[:, 0], poly[:, 1]
+
+ angle_rad = angle * np.pi / 180
+
+ qx = np.cos(angle_rad) * px - np.sin(angle_rad) * py
+ qy = np.sin(angle_rad) * px + np.cos(angle_rad) * py
+
+ if flip_h:
+ qx = -qx
+
+ rotated_poly = np.zeros_like(poly)
+ rotated_poly[:, 0] = qx
+ rotated_poly[:, 1] = qy
+
+ # print("p", poly)
+ # print("r", rotated_poly)
+
+ return rotated_poly
\ No newline at end of file
diff --git a/s3d_floorplan_eval/S3DLoader/s3d_utils.py b/s3d_floorplan_eval/S3DLoader/s3d_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fadf664e66e8851dbc5bccb6629fb43837bc759
--- /dev/null
+++ b/s3d_floorplan_eval/S3DLoader/s3d_utils.py
@@ -0,0 +1,133 @@
+"""
+This code is an adaptation that uses Structured 3D for the code base.
+
+Reference: https://github.com/bertjiazheng/Structured3D
+"""
+
+import numpy as np
+import cv2
+from shapely.geometry import Polygon
+import random
+
+def parse_floor_plan_polys(annos):
+ planes = []
+ for semantic in annos['semantics']:
+ for planeID in semantic['planeID']:
+ if annos['planes'][planeID]['type'] == 'floor':
+ planes.append({'planeID': planeID, 'type': semantic['type']})
+
+ if semantic['type'] == 'outwall':
+ outerwall_planes = semantic['planeID']
+
+ # extract hole vertices
+ lines_holes = []
+ for semantic in annos['semantics']:
+ if semantic['type'] in ['window', 'door']:
+ for planeID in semantic['planeID']:
+ lines_holes.extend(np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist())
+ lines_holes = np.unique(lines_holes)
+
+ # junctions on the floor
+ junctions = np.array([junc['coordinate'] for junc in annos['junctions']])
+ junction_floor = np.where(np.isclose(junctions[:, -1], 0))[0]
+
+ # construct each polygon
+ polygons = []
+ for plane in planes:
+ lineIDs = np.where(np.array(annos['planeLineMatrix'][plane['planeID']]))[0].tolist()
+ junction_pairs = [np.where(np.array(annos['lineJunctionMatrix'][lineID]))[0].tolist() for lineID in lineIDs]
+ polygon = convert_lines_to_vertices(junction_pairs)
+ polygons.append([polygon[0], plane['type']])
+
+ outerwall_floor = []
+ for planeID in outerwall_planes:
+ lineIDs = np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist()
+ lineIDs = np.setdiff1d(lineIDs, lines_holes)
+ junction_pairs = [np.where(np.array(annos['lineJunctionMatrix'][lineID]))[0].tolist() for lineID in lineIDs]
+ for start, end in junction_pairs:
+ if start in junction_floor and end in junction_floor:
+ outerwall_floor.append([start, end])
+
+ outerwall_polygon = convert_lines_to_vertices(outerwall_floor)
+ polygons.append([outerwall_polygon[0], 'outwall'])
+
+ return polygons
+
+def convert_lines_to_vertices(lines):
+ """
+ convert line representation to polygon vertices
+
+ """
+ polygons = []
+ lines = np.array(lines)
+
+ polygon = None
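+    # walk the junction graph: repeatedly append the line segment that shares the polygon's
+    # current last vertex until the loop closes, then start a new polygon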
+ while len(lines) != 0:
+ if polygon is None:
+ polygon = lines[0].tolist()
+ lines = np.delete(lines, 0, 0)
+
+ lineID, juncID = np.where(lines == polygon[-1])
+ vertex = lines[lineID[0], 1 - juncID[0]]
+ lines = np.delete(lines, lineID, 0)
+
+ if vertex in polygon:
+ polygons.append(polygon)
+ polygon = None
+ else:
+ polygon.append(vertex)
+
+ return polygons
+
+def generate_floorplan(annos, polygons, height, width, ignore_types, include_types=None, fillpoly=True, constant_color=False, shuffle=False):
+ """
+ plot floorplan
+
+ """
+
+ floor_map = np.zeros((height, width))
+
+ junctions = np.array([junc['coordinate'][:2] for junc in annos['junctions']])
+
+ room_ind = 0
+ if shuffle:
+ room_ind = np.random.randint(0, 2)
+
+ polygons_list = []
+ for poly_ind, (polygon, poly_type) in enumerate(polygons):
+ if poly_type in ignore_types:
+ continue
+ if include_types is not None and poly_type not in include_types:
+ continue
+
+ polygon = junctions[np.array(polygon)].astype(np.int32)
+
+ poly_shapely = Polygon(polygon)
+ area = poly_shapely.area
+
+ # assert area > 10
+ if area < 100:
+ continue
+
+ polygons_list.append(polygon)
+
+ if shuffle:
+ random.shuffle(polygons_list)
+ for poly_ind, polygon in enumerate(polygons_list):
+
+ if shuffle:
+ room_ind += np.random.randint(1, 2)
+ else:
+ room_ind += 1
+
+ if fillpoly:
+ if constant_color:
+ clr = 1.
+ else:
+ clr = room_ind
+ cv2.fillPoly(floor_map, [polygon], color=clr)
+ else:
+ assert constant_color
+ cv2.polylines(floor_map, [polygon], isClosed=True, color=1., thickness=3)
+
+ return floor_map, polygons_list
\ No newline at end of file
diff --git a/s3d_floorplan_eval/convert_density.py b/s3d_floorplan_eval/convert_density.py
new file mode 100644
index 0000000000000000000000000000000000000000..64bff846a303c7fb4b76882b1dab27e98dd9d4ed
--- /dev/null
+++ b/s3d_floorplan_eval/convert_density.py
@@ -0,0 +1,21 @@
+import os
+import numpy as np
+import cv2
+
+
+
+
+source = '../Structured3D/montefloor_data/test/'
+dst = './viz_density'
+
+for dirname in sorted(os.listdir(source)):
+ density_path = os.path.join(source, dirname, 'density.png')
+ density_img = cv2.imread(density_path)
+ density = 255 - density_img
+ out_path = os.path.join(dst, dirname + '.png')
+ out_alpha_path = os.path.join(dst, dirname + '_alpha.png')
+    alphas = np.zeros([density.shape[0], density.shape[1], 1], dtype=np.uint8)
+ alphas[density_img.sum(axis=-1) > 0] = 255
+ density_alpha = np.concatenate([density, alphas], axis=-1)
+ cv2.imwrite(out_path, density)
+ cv2.imwrite(out_alpha_path, density_alpha)
\ No newline at end of file
diff --git a/s3d_floorplan_eval/evaluate_solution.py b/s3d_floorplan_eval/evaluate_solution.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac335777f63a68e1c283cbbe6fdcac71dfa4e97c
--- /dev/null
+++ b/s3d_floorplan_eval/evaluate_solution.py
@@ -0,0 +1,111 @@
+import copy
+import functools
+import numpy as np
+import os
+
+from Evaluator.Evaluator import Evaluator
+from options import MCSSOptions
+from DataRW.S3DRW import S3DRW
+from DataRW.wrong_annotatios import wrong_s3d_annotations_list
+from planar_graph_utils import get_regions_from_pg
+
+
+room_polys_def = [np.array([[191, 150],
+ [191, 70],
+ [222, 70],
+ [222, 150],
+ [191, 150]]), np.array([[232, 65],
+ [232, 11],
+ [202, 11],
+ [202, 65],
+ [232, 65]]), np.array([[ 47, 50],
+ [ 47, 150],
+ [ 24, 150],
+ [ 24, 50],
+ [ 47, 50]]), np.array([[199, 156],
+ [199, 234],
+ [146, 234],
+ [146, 156],
+ [199, 156]]), np.array([[109, 184],
+ [120, 184],
+ [120, 156],
+ [ 50, 156],
+ [ 50, 234],
+ [109, 234],
+ [109, 184]]), np.array([[110, 234],
+ [144, 234],
+ [144, 187],
+ [110, 187],
+ [110, 234]]), np.array([[ 50, 50],
+ [ 50, 150],
+ [123, 150],
+ [123, 184],
+ [144, 184],
+ [144, 150],
+ [190, 150],
+ [190, 70],
+ [108, 70],
+ [108, 50],
+ [ 50, 50]])]
+
+pg_base = '../results/npy_heat_s3d_256/'
+
+options = MCSSOptions()
+opts = options.parse()
+
+if __name__ == '__main__':
+
+ # data_rw = FloorNetRW(opts)
+
+ if opts.scene_id == "val":
+
+ opts.scene_id = "scene_03250" # Temp. value
+ data_rw = S3DRW(opts)
+ scene_list = data_rw.loader.scenes_list
+
+ quant_result_dict = None
+ quant_result_maskrcnn_dict = None
+ scene_counter = 0
+ for scene_ind, scene in enumerate(scene_list):
+ if int(scene[6:]) in wrong_s3d_annotations_list:
+ continue
+
+ print("------------")
+ curr_opts = copy.deepcopy(opts)
+ curr_opts.scene_id = scene
+ curr_data_rw = S3DRW(curr_opts)
+ print("Running Evaluation for scene %s" % scene)
+
+ evaluator = Evaluator(curr_data_rw, curr_opts)
+
+ # TODO load your room polygons into room_polys, list of polygons (n x 2)
+ # room_polys = np.array([[[0,0], [200, 0], [200, 200]]]) # Placeholder
+
+ pg_path = os.path.join(pg_base, scene[6:] + '.npy')
+ example_pg = np.load(pg_path, allow_pickle=True).tolist()
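+        # convert the predicted planar graph (corners + edges) into closed room polygons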
+ regions = get_regions_from_pg(example_pg, corner_sorted=True)
+
+ room_polys = regions
+ # room_polys = room_polys_def # Placeholder
+
+
+ quant_result_dict_scene =\
+ evaluator.evaluate_scene(room_polys=room_polys)
+
+ if quant_result_dict is None:
+ quant_result_dict = quant_result_dict_scene
+ else:
+ for k in quant_result_dict.keys():
+ quant_result_dict[k] += quant_result_dict_scene[k]
+
+ scene_counter += 1
+
+ # break
+
+ for k in quant_result_dict.keys():
+ quant_result_dict[k] /= float(scene_counter)
+
+ print("Our: ", quant_result_dict)
+
+ print("Ours")
+ evaluator.print_res_str_for_latex(quant_result_dict)
diff --git a/s3d_floorplan_eval/generate_html.py b/s3d_floorplan_eval/generate_html.py
new file mode 100644
index 0000000000000000000000000000000000000000..9236d3c55ade675e4c47a2b4518da7e62ff4f518
--- /dev/null
+++ b/s3d_floorplan_eval/generate_html.py
@@ -0,0 +1,59 @@
+import numpy as np
+import os.path as osp
+
+head = '''
+<html>
+<head>
+<meta charset="utf-8">
+</head>
+<body>
+<table border="1">
+'''
+
+end = '''
+</table>
+</body>
+</html>
+'''
+
+def writeHTML(out_path, results_dirs):
+ f = open(out_path, 'w')
+ f.write(head + '\n')
+ f.write('<tr>'
+ '<th> ID </th>'
+ '<th> Input </th>'
+ '<th> HAWP </th>'
+ '<th> LETR </th>'
+ '<th> HEAT (Ours) </th>'
+ '<th> Ground-truth </th>'
+ '</tr>')
+
+ wrong_s3d_annotations_list = [3261, 3271, 3276, 3296, 3342, 3387, 3398, 3466, 3496]
+ file_ids = ['0{}'.format(x) for x in range(3250, 3500) if x not in wrong_s3d_annotations_list]
+ permuted_ids = np.random.permutation(file_ids)
+ file_ids = permuted_ids[:100]
+
+ for file_id in file_ids:
+ row_str = '<tr>'
+ row_str += '<td> {} </td>'.format(file_id)
+ for dir_idx, result_dir in enumerate(results_dirs):
+ if dir_idx == 0:
+ pred_filepath = osp.join(result_dir, 'scene_{}_alpha.png'.format(file_id))
+ row_str += '<td> <img src="{}"> </td>'.format(pred_filepath)
+ else:
+ pred_filepath = osp.join(result_dir, '{}.png'.format(file_id))
+ row_str += '<td> <img src="{}"> </td>'.format(pred_filepath)
+ row_str += '</tr>'
+ f.write(row_str + '\n')
+
+ f.write(end + '\n')
+ f.close()
+
+
+if __name__ == '__main__':
+ results_dirs = ['viz_density', 'viz_hawp', 'viz_letr', 'viz_heat_th5', 'viz_gt']
+
+ writeHTML(out_path='./indoor_qual.html', results_dirs=results_dirs)
diff --git a/s3d_floorplan_eval/options.py b/s3d_floorplan_eval/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cab53089e322d251f9b27274ef2abe62d5e7c95
--- /dev/null
+++ b/s3d_floorplan_eval/options.py
@@ -0,0 +1,59 @@
+from __future__ import absolute_import, division, print_function
+
+import os
+import argparse
+
+file_dir = os.path.dirname(__file__) # the directory that options.py resides in
+
+class MCSSOptions:
+ def __init__(self):
+ self.parser = argparse.ArgumentParser(description="MCSSFloor options")
+
+ # PATHS
+ self.parser.add_argument("--mcts_path",
+ type=str,
+ help="the name of the MonteFloorNet model",
+ default="/media/sinisa/Sinisa_hdd_data/Sinisa_Projects/corridor_localisation/experiments/MonteFloorNet_experiments/room_shape_experiments/Structured3D_test/")
+
+
+ self.parser.add_argument("--dataset_path",
+ type=str,
+ help="path to the dataset",
+ default="")
+
+ self.parser.add_argument("--dataset_type",
+ type=str,
+ help="the name of the dataset type",
+ choices=["floornet", "s3d", "fsp"])
+ self.parser.add_argument("--scene_id",
+ type=str,
+ help="the name of the scene",
+ default="0")
+
+ self.parser.add_argument("--min_scene_ind",
+ type=int,
+ help="the minimum scene index",
+ default=0)
+ self.parser.add_argument("--max_scene_ind",
+ type=int,
+ help="the maximum scene index",
+ default=251)
+
+ # MonteFloorNet options
+ # self.parser.add_argument("--model_S_path",
+ # help="the name of the MonteFloorNet model",
+ # default="/home/sinisa/tmp/current_experiments/montefloornet_S_model_camera_ready16/best_models/weights_935")
+
+
+ self.parser.add_argument("--height",
+ type=int,
+ help="input image height",
+ default=256)
+ self.parser.add_argument("--width",
+ type=int,
+ help="input image width",
+ default=256)
+
+ def parse(self):
+ self.options = self.parser.parse_args()
+ return self.options
diff --git a/s3d_floorplan_eval/planar_graph_utils.py b/s3d_floorplan_eval/planar_graph_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f02c24b9f7ecd3c6c3de8f1f96ed172696539ef8
--- /dev/null
+++ b/s3d_floorplan_eval/planar_graph_utils.py
@@ -0,0 +1,357 @@
+import os
+import numpy as np
+import cv2
+from scipy import ndimage
+from shapely.geometry import Polygon
+
+
+def extract_regions(adj_mat, corners, corner_sorted):
+ all_regions = list()
+ cur_idx = 0
+ corners = corners.astype(np.int)
+ nb_orders = _sort_neighours(adj_mat, corners)
+ while cur_idx is not None:
+ regions = _get_regions_for_corner(cur_idx, adj_mat, nb_orders)
+ all_regions.extend(regions)
+ cur_idx = _get_new_start(adj_mat, cur_idx, corners)
+
+ outwall_idx = get_outwall(all_regions, corners, corner_sorted)
+ all_regions.pop(outwall_idx)
+
+ # all_regions = filter_regions(all_regions) # only used for drawing visualization
+ # return all_regions
+
+ all_regions_coords = [corners[regions] for regions in all_regions]
+ return all_regions_coords
+
+
+def get_outwall(all_regions, corners, corner_sorted):
+ """
+ Find the outermost boundary loop, which should be discarded
+ """
+ if corner_sorted:
+ regions_for_top_bot = np.nonzero([(0 in region and len(corners) - 1 in region) for region in all_regions])[0]
+ if len(regions_for_top_bot) == 1:
+ return regions_for_top_bot[0]
+ else:
+ areas = [_compute_region_area(corners[all_regions[idx]]) for idx in range(len(all_regions))]
+ max_idx = np.argmax(areas)
+ return max_idx
+ else:
+ areas = [_compute_region_area(corners[all_regions[idx]]) for idx in range(len(all_regions))]
+ max_idx = np.argmax(areas)
+ return max_idx
+
+
+
+# def filter_regions(all_regions):
+# areas = [_compute_region_area(corners[all_regions[idx]]) for idx in range(len(all_regions))]
+# all_regions = [region for idx, region in enumerate(all_regions) if areas[idx] > 20]
+# return all_regions
+
+
+def _compute_region_area(region):
+ edge_map = np.zeros([256, 256])
+ for idx, c in enumerate(region[:-1]):
+ cv2.line(edge_map, tuple(c), tuple(region[idx + 1]), 1, 3)
+ reverse_edge_map = 1 - edge_map
+ label, num_features = ndimage.label(reverse_edge_map)
+ if num_features < 2:
+ return 0
+ # import pdb; pdb.set_trace()
+ # raise ValueError('Invalid region structure')
+ bg_label = label[0, 0]
+ num_labels = [(label == l).sum() for l in range(1, num_features + 1)]
+ num_labels[bg_label - 1] = 0
+ room_label = np.argmax(num_labels) + 1
+ area = (label == room_label).sum()
+ return area
+
+
+def _get_regions_for_corner(cur_idx, adj_mat, nb_orders):
+ regions = list()
+ if adj_mat[cur_idx].sum() == 0:
+ raise ValueError('Zero-degree corner, should not reach here')
+ # elif adj_mat[cur_idx].sum() == 1: # remove the connection if only one neighbour
+ # other_idx = nb_orders[0]
+ # import pdb; pdb.set_trace()
+ # adj_mat[cur_idx, other_idx] = 0
+ else:
+ v_s = cur_idx
+ know_v_q = False
+ while v_s is not None:
+ if not know_v_q:
+ v_p, v_q = _find_wedge_nbs(v_s, nb_orders, adj_mat)
+ if v_p is None: # cannot find proper wedge, remove this corner
+ adj_mat[v_s, :] = 0
+ adj_mat[:, v_s] = 0
+ break
+ else:
+ assert v_q is not None, 'v_q should be known here'
+ v_p = _find_wedge_third_v(v_q, v_s, nb_orders, adj_mat, dir=-1)
+ if v_p is None:
+ adj_mat[v_s, :] = 0
+ adj_mat[:, v_s] = 0
+ break
+ cur_region = [v_p, v_s, ]
+ # try:
+ assert adj_mat[v_p, v_s] == 1, 'Wrong connection matrix!'
+ # except:
+ # import pdb; pdb.set_trace()
+ adj_mat[v_p, v_s] = 0
+ region_i = 0
+ closed_polygon = False
+ while v_q is not None: # tracking the current region
+ cur_region.append(v_q)
+ assert adj_mat[v_s, v_q] == 1, 'Wrong connection matrix!'
+ adj_mat[v_s, v_q] = 0
+ # update the nb order list for the current v_s
+ if v_q == cur_region[0]: # get a closed polygon
+ closed_polygon = True
+ break
+ else:
+ v_p = cur_region[region_i + 1]
+ v_s = cur_region[region_i + 2]
+ v_q = _find_wedge_third_v(v_p, v_s, nb_orders, adj_mat, dir=1)
+ if v_q is None:
+ closed_polygon = False
+ break
+ region_i += 1
+
+ if closed_polygon: # find a closed region, keep iteration
+ regions.append(cur_region)
+ found_next = False
+ for temp_i in range(1, len(cur_region)):
+ if adj_mat[cur_region[temp_i], cur_region[temp_i - 1]] == 1:
+ found_next = True
+ v_s_idx = temp_i
+ break
+ if not found_next:
+ v_s = None
+ else:
+ v_s = cur_region[v_s_idx]
+ v_q = cur_region[v_s_idx - 1]
+ know_v_q = True
+ else: # no closed region, directly quit the search for the current v_s
+ break
+ return regions
+
+
+def _find_wedge_nbs(v_s, nb_orders, adj_mat):
+ sorted_nbs = nb_orders[v_s]
+ start_idx = 0
+ while True:
+ if start_idx == -len(sorted_nbs):
+ return None, None
+ v_p, v_q = sorted_nbs[start_idx], sorted_nbs[start_idx - 1]
+ if adj_mat[v_p, v_s] == 1 and adj_mat[v_s, v_q] == 1:
+ return v_p, v_q
+ else:
+ start_idx -= 1
+
+
+def _find_wedge_third_v(v1, v2, nb_orders, adj_mat, dir):
+ sorted_nbs = nb_orders[v2]
+ v1_idx = sorted_nbs.index(v1)
+ if dir == 1:
+ v3_idx = v1_idx - 1
+ while adj_mat[v2, sorted_nbs[v3_idx]] == 0:
+ if sorted_nbs[v3_idx] == v1:
+ return None
+ v3_idx -= 1
+ elif dir == -1:
+ v3_idx = v1_idx + 1 if v1_idx <= len(sorted_nbs) - 2 else 0
+ while adj_mat[sorted_nbs[v3_idx], v2] == 0:
+ if sorted_nbs[v3_idx] == v1:
+ return None
+ v3_idx = v3_idx + 1 if v3_idx <= len(sorted_nbs) - 2 else 0
+ else:
+ raise ValueError('unknown dir {}'.format(dir))
+ return sorted_nbs[v3_idx]
+
+
+def _get_new_start(adj_mat, cur_idx, corners):
+ for i in range(cur_idx, len(corners)):
+ if adj_mat[i].sum() > 0:
+ return i
+ return None
+
+
+def _sort_neighours(adj_mat, corners):
+ nb_orders = dict()
+ for idx, c in enumerate(corners):
+ nb_ids = np.nonzero(adj_mat[idx])[0]
+ nb_degrees = [_compute_degree(c, corners[other_idx]) for other_idx in nb_ids]
+ degree_ranks = np.argsort(nb_degrees)
+ sort_nb_ids = [nb_ids[i] for i in degree_ranks]
+ nb_orders[idx] = sort_nb_ids
+ return nb_orders
+
+
+def _compute_degree(c1, c2):
+ vec = (c2[0] - c1[0], -(c2[1] - c1[1])) # note that the y direction should be flipped (image coord system)
+ cos = (vec[0] * 1 + vec[1] * 0) / np.sqrt(vec[0] ** 2 + vec[1] ** 2)
+ theta = np.arccos(cos)
+ if vec[1] < 0:
+ theta = np.pi * 2 - theta
+ return theta
+
+
+def preprocess_pg(pg):
+ corners = pg['corners']
+ edge_pairs = pg['edges']
+ adj_mat = np.zeros([len(corners), len(corners)])
+ for edge_pair in edge_pairs:
+ c1, c2 = edge_pair
+ adj_mat[c1][c2] = 1
+ adj_mat[c2][c1] = 1
+
+ return corners, adj_mat
+
+
+def cleanup_pg(pg):
+ corners = pg['corners']
+ edge_pairs = pg['edges']
+ adj_list = [[] for _ in range(len(corners))]
+
+ for edge_pair in edge_pairs:
+ adj_list[edge_pair[0]].append(edge_pair[1])
+ adj_list[edge_pair[1]].append(edge_pair[0])
+
+ for idx in range(len(corners)):
+ if len(adj_list[idx]) < 2:
+ _remove_corner(idx, adj_list)
+
+ new_corners = list()
+ removed_ids = list()
+ old_to_new = dict()
+ counter = 0
+ for c_i in range(len(adj_list)):
+ if len(adj_list[c_i]) > 0:
+ assert len(adj_list[c_i]) >= 2
+ new_corners.append(corners[c_i])
+ old_to_new[c_i] = counter
+ counter += 1
+ else:
+ removed_ids.append(c_i)
+
+ new_edges = list()
+ for c_i_1 in range(len(adj_list)):
+ for c_i_2 in adj_list[c_i_1]:
+ if c_i_1 < c_i_2:
+ new_edge = (old_to_new[c_i_1], old_to_new[c_i_2])
+ new_edges.append(new_edge)
+ new_corners = np.array(new_corners)
+ new_edges = np.array(new_edges)
+ new_pg = {
+ 'corners': new_corners,
+ 'edges': new_edges,
+ }
+ return new_pg
+
+
+def _remove_corner(idx, adj_list):
+ assert len(adj_list[idx]) <= 1
+ if len(adj_list[idx]) == 0:
+ return
+ nbs = list(adj_list[idx])
+ adj_list[idx].pop(0)
+ for nb in nbs:
+ adj_list[nb].remove(idx)
+ if len(adj_list[nb]) < 2:
+ _remove_corner(nb, adj_list)
+
+
+def get_regions_from_pg(pg, corner_sorted):
+ pg = cleanup_pg(pg)
+ corners, adj_mat = preprocess_pg(pg)
+ if len(corners) == 0:
+ regions = []
+ else:
+ regions = extract_regions(adj_mat, corners, corner_sorted)
+ return regions
+
+
+def convert_annot(annot):
+ corners = np.array(list(annot.keys()))
+ corners_mapping = {tuple(c): idx for idx, c in enumerate(corners)}
+ edges = set()
+ for corner, connections in annot.items():
+ idx_c = corners_mapping[tuple(corner)]
+ for other_c in connections:
+ idx_other_c = corners_mapping[tuple(other_c)]
+ if (idx_c, idx_other_c) not in edges and (idx_other_c, idx_c) not in edges:
+ edges.add((idx_c, idx_other_c))
+ edges = np.array(list(edges))
+ pg_data = {
+ 'corners': corners,
+ 'edges': edges
+ }
+ return pg_data
+
+
+colors_12 = [
+ "#DCECC9",
+ "#B3DDCC",
+ "#8ACDCE",
+ "#62BED2",
+ "#46AACE",
+ "#3D91BE",
+ "#3677AE",
+ "#2D5E9E",
+ "#24448E",
+ "#1C2B7F",
+ "#162165",
+ "#11174B",
+]
+
+
+def plot_floorplan_with_regions(regions, corners, edges, scale):
+ colors = colors_12[:8]
+
+ regions = [(region * scale / 256).round().astype(np.int) for region in regions]
+ corners = (corners * scale / 256).round().astype(np.int)
+
+ # define the color map
+ room_colors = [colors[i % 8] for i in range(len(regions))]
+
+ colorMap = [tuple(int(h[i:i + 2], 16) for i in (1, 3, 5)) for h in room_colors]
+ colorMap = np.asarray(colorMap)
+ if len(regions) > 0:
+ colorMap = np.concatenate([np.full(shape=(1, 3), fill_value=0), colorMap], axis=0).astype(
+ np.uint8)
+ else:
+ colorMap = np.concatenate([np.full(shape=(1, 3), fill_value=0)], axis=0).astype(
+ np.uint8)
+ # when using opencv, we need to flip, from RGB to BGR
+ colorMap = colorMap[:, ::-1]
+
+ alpha_channels = np.zeros(colorMap.shape[0], dtype=np.uint8)
+ alpha_channels[1:len(regions) + 1] = 150
+
+ colorMap = np.concatenate([colorMap, np.expand_dims(alpha_channels, axis=-1)], axis=-1)
+
+ room_map = np.zeros([scale, scale]).astype(np.int32)
+ # sort regions
+ if len(regions) > 1:
+ avg_corner = [region.mean(axis=0) for region in regions]
+ ind = np.argsort(np.array(avg_corner)[:, 0], axis=0)
+ regions = np.array(regions)[ind]
+
+ for idx, polygon in enumerate(regions):
+ cv2.fillPoly(room_map, [polygon], color=idx + 1)
+
+ image = colorMap[room_map.reshape(-1)].reshape((scale, scale, 4))
+
+ pointColor = tuple((np.array([0.95, 0.3, 0.3, 1]) * 255).astype(np.uint8).tolist())
+ for point in corners:
+ cv2.circle(image, tuple(point), color=pointColor, radius=12, thickness=-1)
+ cv2.circle(image, tuple(point), color=(255, 255, 255, 255), radius=6, thickness=-1)
+
+ for edge in edges:
+ c1 = corners[edge[0]]
+ c2 = corners[edge[1]]
+ cv2.line(image, tuple(c1), tuple(c2), color=(0, 0, 0, 255), thickness=3)
+
+ return image
+
diff --git a/s3d_floorplan_eval/visualize_npy.py b/s3d_floorplan_eval/visualize_npy.py
new file mode 100644
index 0000000000000000000000000000000000000000..299df7d92869ca6678d3d91837f9247c89243672
--- /dev/null
+++ b/s3d_floorplan_eval/visualize_npy.py
@@ -0,0 +1,31 @@
+import os
+import numpy as np
+import cv2
+from planar_graph_utils import get_regions_from_pg, plot_floorplan_with_regions
+
+# example_pg = {
+# (127, 20): [(20, 120), (234, 120)],
+# (20, 120): [(127, 20), (234, 120), (20, 240)],
+# (234, 120): [(127, 20), (20, 120), (234, 240)],
+# (20, 240): [(20, 120), (234, 240)],
+# (234, 240): [(234, 120), (20, 240)],
+# }
+
+pg_base = '../results/npy_heat_s3d_256/'
+viz_base = './viz_heat'
+if not os.path.exists(viz_base):
+ os.makedirs(viz_base)
+
+for filename in sorted(os.listdir(pg_base)):
+ pg_path = os.path.join(pg_base, filename)
+ example_pg = np.load(pg_path, allow_pickle=True).tolist()
+
+ corners = example_pg['corners']
+ corners = corners.astype(np.int)
+ edges = example_pg['edges']
+
+ print('Processing file: {}'.format(filename))
+ regions = get_regions_from_pg(example_pg, corner_sorted=True)
+ print('num of extracted regions {}'.format(len(regions)))
+ floorplan_image = plot_floorplan_with_regions(regions, corners, edges, scale=1000)
+ cv2.imwrite(os.path.join(viz_base, '{}.png'.format(filename[:-4])), floorplan_image)
diff --git a/s3d_preprocess/.gitignore b/s3d_preprocess/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..5f124176ca7e3a7f6ca3bf2dd72b087c1cf46c2b
--- /dev/null
+++ b/s3d_preprocess/.gitignore
@@ -0,0 +1,8 @@
+*.tar
+*.zip
+*.png
+.DS_Store
+__pycache__
+s3d_raw
+s3d_floorplan
+montefloor_data
diff --git a/s3d_preprocess/DataProcessing/FloorRW.py b/s3d_preprocess/DataProcessing/FloorRW.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3c61696bf75412e26ba7b7f8a845d058d8356f8
--- /dev/null
+++ b/s3d_preprocess/DataProcessing/FloorRW.py
@@ -0,0 +1,294 @@
+import numpy as np
+import open3d as o3d
+import os
+import json
+import time
+import cv2
+import matplotlib.pyplot as plt
+from shapely.geometry import Polygon
+from descartes.patch import PolygonPatch
+import io
+import PIL
+import time
+
+from misc.figures import plot_coords
+from misc.colors import colormap_255, semantics_cmap
+from visualize_3d import visualize_floorplan
+
+from DataProcessing.PointCloudReaderPanorama import PointCloudReaderPanorama
+
+
+class FloorRW:
+ def __init__(self):
+ self.dataset_path = "./s3d_raw/" # set the path to the raw data here
+ self.mode = "train"
+ self.scenes_path = os.path.join(self.dataset_path, self.mode)
+
+ self.out_folder = "floor_data_with_normals"
+ self.density_map_file_name = "density.png"
+ self.normals_map_file_name = "normals.png"
+ self.anno_file_name = "annotation_3d.json"
+ self.vis_file_name = "vis.jpg"
+
+ self.coco_floor_json_path = os.path.join(self.dataset_path, self.mode + "_floor.json")
+
+ # TODO Don't change these values, Adapt PointCloudReaderPanorama first
+ self.w = 256
+ self.h = 256
+
+ self.invalid_scenes = ["scene_00183", "scene_01155", "scene_01816"]
+
+ def generate_floors(self):
+ scenes = os.listdir(self.scenes_path)
+ scenes.sort()
+
+
+
+ for scene_ind, scene in enumerate(scenes):
+ # if scene == "scene_01155":
+ # continue
+ if scene in self.invalid_scenes:
+ continue
+ # if scene_ind < 178:
+ # # # if scene_ind != 0:
+ # continue
+ print("%d / %d Current scene %s" % (scene_ind + 1, len(scenes), scene))
+ start_time = time.time()
+
+ scene_path = os.path.join(self.scenes_path, scene)
+ # annotation_json = self.normalize_annotations(scene_path, {})
+
+ reader = PointCloudReaderPanorama(scene_path, random_level=0, generate_color=True, generate_normal=False)
+ density_map, normals_map, normalization_dict = reader.generate_density()
+
+ normalized_annotations = self.normalize_annotations(scene_path, normalization_dict)
+
+ # visualize_floorplan(normalized_annotations)
+ # self.vis_scene_data(density_map, normalized_annotations)
+ # reader.visualize()
+ self.export_scene(scene_path, density_map, normals_map, normalized_annotations)
+
+ print("Scene processing time %.3f" % (time.time() - start_time))
+
+ def generate_coco_json(self):
+ scenes = os.listdir(self.scenes_path)
+ scenes.sort()
+
+ img_id = -1
+ instance_id = -1
+ coco_dict = {}
+ coco_dict["images"] = []
+ coco_dict["annotations"] = []
+ coco_dict["categories"] = [{"supercategory": "room", "id": 1, "name": "room"}]
+ for scene_ind, scene in enumerate(scenes):
+ # if scene_ind != 66:
+ # continue
+ if scene in self.invalid_scenes:
+ continue
+ # if scene_ind > 1000:
+ # break
+
+ img_id += 1
+ print("%d / %d Current scene %s" % (scene_ind + 1, len(scenes), scene))
+
+ scene_path = os.path.join(self.scenes_path, scene)
+
+ img_relative_path = os.path.join("./", scene, self.out_folder, self.density_map_file_name)
+ annos_path = os.path.join(scene_path, self.out_folder, self.anno_file_name)
+
+ with open(annos_path, "r") as f:
+ annos = json.load(f)
+
+ img_dict = {}
+ img_dict["file_name"] = img_relative_path
+ img_dict["id"] = img_id
+ img_dict["width"] = self.w
+ img_dict["height"] = self.h
+
+ coco_annotation_dict_list = self.parse_coco_annotation(annos, instance_id, img_id)
+
+ coco_dict["images"].append(img_dict)
+ coco_dict["annotations"] += coco_annotation_dict_list
+ instance_id += len(coco_annotation_dict_list)
+
+ with open(self.coco_floor_json_path, 'w') as f:
+ json.dump(coco_dict, f)
+
+ def parse_coco_annotation(self, annos, curr_instance_id, curr_img_id):
+ polygons = visualize_floorplan(annos, vis=False, ret=True)
+
+ ignore_types = ['outwall', 'door', 'window']
+
+ coco_annotation_dict_list = []
+ junctions = np.array([junc['coordinate'][:2] for junc in annos['junctions']])
+ for (poly, poly_type) in polygons:
+ if poly_type in ignore_types:
+ continue
+
+ poly = junctions[np.array(poly)]
+ poly_shapely = Polygon(poly)
+ area = poly_shapely.area
+
+ # assert area > 10
+ if area < 100:
+ continue
+
+ rectangle_shapely = poly_shapely.envelope
+
+ coco_seg_poly = []
+ for p in poly:
+ coco_seg_poly += list(p)
+
+ # Slightly wider bounding box
+ bound_pad = 5
+ bb_x, bb_y = rectangle_shapely.exterior.xy
+ bb_x = np.unique(bb_x)
+ bb_y = np.unique(bb_y)
+ bb_x_min = np.maximum(np.min(bb_x) - bound_pad, 0)
+ bb_y_min = np.maximum(np.min(bb_y) - bound_pad, 0)
+
+ bb_x_max = np.minimum(np.max(bb_x) + bound_pad, self.w - 1)
+ bb_y_max = np.minimum(np.max(bb_y) + bound_pad, self.h - 1)
+
+ bb_width = (bb_x_max - bb_x_min)
+ bb_height = (bb_y_max - bb_y_min)
+
+ coco_bb = [bb_x_min, bb_y_min, bb_width, bb_height]
+
+ curr_instance_id += 1
+ coco_annotation_dict = {
+ "segmentation": [coco_seg_poly],
+ "area": area,
+ "iscrowd": 0,
+ "image_id": curr_img_id,
+ "bbox": coco_bb,
+ "category_id": 1,
+ "id": curr_instance_id}
+ coco_annotation_dict_list.append(coco_annotation_dict)
+
+ # plt.figure()
+ # plt.imshow(np.zeros((256, 256)))
+ # x, y = poly_shapely.exterior.xy
+ # plt.plot(x, y, "r")
+ #
+ # plt.plot([bb_x_min, bb_x_min + bb_width, bb_x_min + bb_width, bb_x_min, bb_x_min],
+ # [bb_y_min, bb_y_min, bb_y_min + bb_height, bb_y_min + bb_height, bb_y_min], "b")
+ #
+ # plt.show()
+
+ return coco_annotation_dict_list
+
+
+
+ def export_scene(self, scene_path, density_map, normals_map, annos):
+ def export_density():
+ density_path = os.path.join(floorplan_folder_path, self.density_map_file_name)
+ density_uint8 = (density_map * 255).astype(np.uint8)
+ cv2.imwrite(density_path, density_uint8)
+
+ def export_normals():
+ normals_path = os.path.join(floorplan_folder_path, self.normals_map_file_name)
+ normals_uint8 = (np.clip(normals_map, 0, 1) * 255).astype(np.uint8)
+ cv2.imwrite(normals_path, normals_uint8)
+
+ def export_annos():
+ anno_path = os.path.join(floorplan_folder_path, self.anno_file_name)
+ with open(anno_path, 'w') as f:
+ json.dump(annos, f)
+
+ def export_vis():
+ vis_path = os.path.join(floorplan_folder_path, self.vis_file_name)
+ vis = self.vis_scene_data(density_map, annos, show=False)
+ if vis is not None:
+ cv2.imwrite(vis_path, vis)
+ else:
+ print("Visualization is None. Skip exporting the visualization...")
+
+ floorplan_folder_path = os.path.join(scene_path, self.out_folder)
+ if not os.path.isdir(floorplan_folder_path):
+ os.mkdir(floorplan_folder_path)
+
+ export_density()
+ export_normals()
+ export_annos()
+ export_vis()
+
+ def normalize_annotations(self, scene_path, normalization_dict):
+ annotation_path = os.path.join(scene_path, "annotation_3d.json")
+ with open(annotation_path, "r") as f:
+ annotation_json = json.load(f)
+
+ for line in annotation_json["lines"]:
+ point = line["point"]
+ point = self.normalize_point(point, normalization_dict)
+ line["point"] = point
+
+ for junction in annotation_json["junctions"]:
+ point = junction["coordinate"]
+ point = self.normalize_point(point, normalization_dict)
+ junction["coordinate"] = point
+
+ normalization_dict["min_coords"] = normalization_dict["min_coords"].tolist()
+ normalization_dict["max_coords"] = normalization_dict["max_coords"].tolist()
+ normalization_dict["image_res"] = normalization_dict["image_res"].tolist()
+
+ return annotation_json
+
+ def normalize_point(self, point, normalization_dict):
+
+ min_coords = normalization_dict["min_coords"]
+ max_coords = normalization_dict["max_coords"]
+ image_res = normalization_dict["image_res"]
+
+ point_2d = \
+ np.round(
+ (point[:2] - min_coords[:2]) / (max_coords[:2] - min_coords[:2]) * image_res)
+ point_2d = np.minimum(np.maximum(point_2d, np.zeros_like(image_res)),
+ image_res - 1)
+
+ point[:2] = point_2d.tolist()
+
+ return point
+
+ def vis_scene_data(self, density_map, annos, show=True):
+ polygons = visualize_floorplan(annos, vis=False, ret=True)
+
+ if polygons is None:
+ return None
+
+ fig = plt.figure()
+ gs = fig.add_gridspec(1, 2)
+
+ ax0 = fig.add_subplot(gs[0, 0])
+ ax0.imshow(density_map)
+ plt.axis('equal')
+ plt.axis('off')
+
+ ax1 = fig.add_subplot(gs[0, 1])
+
+ junctions = np.array([junc['coordinate'][:2] for junc in annos['junctions']])
+ for (polygon, poly_type) in polygons:
+ polygon = Polygon(junctions[np.array(polygon)])
+ plot_coords(ax1, polygon.exterior, alpha=0.5)
+ if poly_type == 'outwall':
+ patch = PolygonPatch(polygon, facecolor=semantics_cmap[poly_type], alpha=0)
+ else:
+ patch = PolygonPatch(polygon, facecolor=semantics_cmap[poly_type], alpha=0.5)
+ ax1.add_patch(patch)
+
+ ax1.set_ylim(density_map.shape[0], 0)
+ ax1.set_xlim(0, density_map.shape[1])
+ plt.axis('equal')
+ plt.axis('off')
+
+ buf = io.BytesIO()
+ plt.savefig(buf, format='png', bbox_inches='tight')
+ buf.seek(0)
+
+ if show:
+ plt.show()
+ plt.close()
+
+ vis = PIL.Image.open(buf)
+ vis = np.array(vis)
+ return vis
diff --git a/s3d_preprocess/DataProcessing/PointCloudReaderPanorama.py b/s3d_preprocess/DataProcessing/PointCloudReaderPanorama.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed7f5699e0741cfa2bb24a391a34c9f1fd65c608
--- /dev/null
+++ b/s3d_preprocess/DataProcessing/PointCloudReaderPanorama.py
@@ -0,0 +1,274 @@
+import cv2
+import open3d as o3d
+import os
+from sklearn.preprocessing import normalize
+import json
+import matplotlib.pyplot as plt
+
+from sem_seg_utils import *
+from visualize_3d import visualize_wireframe
+
+NUM_SECTIONS = -1
+
+class PointCloudReaderPanorama():
+
+ def __init__(self, path, resolution="full", random_level=0, generate_color=False, generate_normal=False):
+ self.path = path
+ self.random_level = random_level
+ self.resolution = resolution
+ self.generate_color = generate_color
+ self.generate_normal = generate_normal
+ sections = [p for p in os.listdir(os.path.join(path, "2D_rendering"))]
+ self.depth_paths = [os.path.join(*[path, "2D_rendering", p, "panorama", self.resolution, "depth.png"]) for p in sections]
+ self.rgb_paths = [os.path.join(*[path, "2D_rendering", p, "panorama", self.resolution, "rgb_coldlight.png"]) for p in sections]
+ self.normal_paths = [os.path.join(*[path, "2D_rendering", p, "panorama", self.resolution, "normal.png"]) for p in sections]
+ self.camera_paths = [os.path.join(*[path, "2D_rendering", p, "panorama", "camera_xyz.txt"]) for p in sections]
+ self.camera_centers = self.read_camera_center()
+ self.point_cloud = self.generate_point_cloud(self.random_level, color=self.generate_color, normal=self.generate_normal)
+
+ def read_camera_center(self):
+ camera_centers = []
+ for i in range(len(self.camera_paths)):
+ with open(self.camera_paths[i], 'r') as f:
+ line = f.readline()
+ center = list(map(float, line.strip().split(" ")))
+ camera_centers.append(np.asarray([center[0], center[1], center[2]]))
+ return camera_centers
+
+ def generate_point_cloud(self, random_level=0, color=False, normal=False):
+ coords = []
+ colors = []
+ normals = []
+ points = {}
+
+ # Getting Coordinates
+ for i in range(len(self.depth_paths)):
+ depth_img = cv2.imread(self.depth_paths[i], cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)
+ x_tick = 180.0/depth_img.shape[0]
+ y_tick = 360.0/depth_img.shape[1]
+
+ rgb_img = cv2.imread(self.rgb_paths[i])
+ rgb_img = cv2.cvtColor(rgb_img, code=cv2.COLOR_BGR2RGB)
+ normal_img = cv2.imread(self.normal_paths[i])
+
+ for x in range(0, depth_img.shape[0]):
+ for y in range(0, depth_img.shape[1]):
+ # elevation alpha ranges from 90 to -90 degrees, azimuth beta from -180 to 180 degrees
+ alpha = 90 - (x * x_tick)
+ beta = y * y_tick - 180
+
+ depth = depth_img[x,y] + np.random.random()*random_level
+
+ if depth > 500.:
+ z_offset = depth*np.sin(np.deg2rad(alpha))
+ xy_offset = depth*np.cos(np.deg2rad(alpha))
+ x_offset = xy_offset * np.sin(np.deg2rad(beta))
+ y_offset = xy_offset * np.cos(np.deg2rad(beta))
+ point = np.asarray([x_offset, y_offset, z_offset])
+ coords.append(point + self.camera_centers[i])
+ colors.append(rgb_img[x, y])
+ # normals.append(normalize(normal_img[x, y].reshape(-1, 1)).ravel())
+ # break
+
+ coords = np.asarray(coords)
+ colors = np.asarray(colors) / 255.0
+ # normals = np.asarray(normals)
+
+ coords[:,:2] = np.round(coords[:,:2] / 10) * 10.
+ coords[:,2] = np.round(coords[:,2] / 100) * 100.
+ unique_coords, unique_ind = np.unique(coords, return_index=True, axis=0)
+
+ coords = coords[unique_ind]
+ colors = colors[unique_ind]
+ # normals = normals[unique_ind]
+
+
+ points['coords'] = coords
+ points['colors'] = colors
+ # points['normals'] = normals
+
+ # if color:
+ # # Getting RGB color
+ # for i in range(len(self.rgb_paths)):
+ # rgb_img = cv2.imread(self.rgb_paths[i])
+ # rgb_img = cv2.cvtColor(rgb_img, code=cv2.COLOR_BGR2RGB)
+ # for x in range(0, rgb_img.shape[0], 2):
+ # for y in range(0, rgb_img.shape[1], 2):
+ # colors.append(rgb_img[x, y])
+ # points['colors'] = np.asarray(colors)/255.0
+ # if normal:
+ # # Getting Normal
+ # for i in range(len(self.normal_paths)):
+ # normal_img = cv2.imread(self.normal_paths[i])
+ # for x in range(0, normal_img.shape[0], 2):
+ # for y in range(0, normal_img.shape[1], 2):
+ # normals.append(normalize(normal_img[x, y].reshape(-1, 1)).ravel())
+ # points['normals'] = normals
+
+ print("Pointcloud size:", points['coords'].shape[0])
+ return points
+
+ def get_point_cloud(self):
+ return self.point_cloud
+
+ def generate_density(self, width=256, height=256):
+
+ ps = self.point_cloud["coords"] * -1
+ ps[:,0] *= -1
+ ps[:,1] *= -1
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(ps)
+ pcd.estimate_normals()
+
+ # zs = np.round(ps[:,2] / 100) * 100
+ # zs, zs_ind = np.unique(zs, return_index=True, axis=0)
+ # ps_ind = ps[:, :2] ==
+ # print("Generate density...")
+
+ image_res = np.array((width, height))
+
+ max_coords = np.max(ps, axis=0)
+ min_coords = np.min(ps, axis=0)
+ max_m_min = max_coords - min_coords
+
+ max_coords = max_coords + 0.1 * max_m_min
+ min_coords = min_coords - 0.1 * max_m_min
+
+ normalization_dict = {}
+ normalization_dict["min_coords"] = min_coords
+ normalization_dict["max_coords"] = max_coords
+ normalization_dict["image_res"] = image_res
+
+
+ # coordinates = np.round(points[:, :2] / max_coordinates[None,:2] * image_res[None])
+ coordinates = \
+ np.round(
+ (ps[:, :2] - min_coords[None, :2]) / (max_coords[None,:2] - min_coords[None, :2]) * image_res[None])
+ coordinates = np.minimum(np.maximum(coordinates, np.zeros_like(image_res)),
+ image_res - 1)
+
+ density = np.zeros((height, width), dtype=np.float32)
+
+ unique_coordinates, counts = np.unique(coordinates, return_counts=True, axis=0)
+ # print(np.unique(counts))
+ # counts = np.minimum(counts, 1e2)
+
+ unique_coordinates = unique_coordinates.astype(np.int32)
+
+ density[unique_coordinates[:, 1], unique_coordinates[:, 0]] = counts
+ density = density / np.max(density)
+ # print(np.unique(density))
+
+ normals = np.array(pcd.normals)
+ normals_map = np.zeros((density.shape[0], density.shape[1], 3))
+
+ import time
+ start_time = time.time()
+ for i, unique_coord in enumerate(unique_coordinates):
+ # print(normals[unique_ind])
+ normals_indcs = np.argwhere(np.all(coordinates[::10] == unique_coord, axis=1))[:,0]
+ normals_map[unique_coordinates[i, 1], unique_coordinates[i, 0], :] = np.mean(normals[::10][normals_indcs, :], axis=0)
+
+ print("Time for normals: ", time.time() - start_time)
+
+ normals_map = (np.clip(normals_map,0,1) * 255).astype(np.uint8)
+
+ # plt.figure()
+ # plt.imshow(normals_map)
+ # plt.show()
+
+ return density, normals_map, normalization_dict
+
+ def visualize(self, export_path=None):
+ pcd = o3d.geometry.PointCloud()
+
+ points = self.point_cloud['coords']
+
+ print(np.max(points, axis=0))
+ indices = np.where(points[:, 2] < 2000)
+
+ points = points[indices]
+ points[:,1] *= -1
+ points[:,:] /= 1000
+ pcd.points = o3d.utility.Vector3dVector(points)
+
+ if self.generate_normal:
+ normals = self.point_cloud['normals']
+ normals = normals[indices]
+ pcd.normals = o3d.utility.Vector3dVector(normals)
+ if self.generate_color:
+ colors = self.point_cloud['colors']
+ colors = colors[indices]
+ pcd.colors = o3d.utility.Vector3dVector(colors)
+
+
+ with open("/media/sinisa/Sinisa_hdd_data/Sinisa_Projects/corridor_localisation/Datasets/Structured_3D_dataset/Structured3D/Structured3D_0/Structured3D/train/scene_00015/annotation_3d.json") as file:
+ annos = json.load(file)
+
+
+
+ # wireframe_geo_list = visualize_wireframe(annos, vis=False, ret=True)
+ # o3d.visualization.draw_geometries([pcd] + wireframe_geo_list)
+ # o3d.visualization.draw_geometries([pcd])
+
+ pcd.estimate_normals()
+
+ # radii = 0.01
+ # mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, radii)
+
+ # alpha = 0.1
+ # tetra_mesh, pt_map = o3d.geometry.TetraMesh.create_from_point_cloud(pcd)
+ # mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_alpha_shape(pcd, alpha, tetra_mesh, pt_map)
+
+ o3d.visualization.draw_geometries([pcd])
+
+ if export_path is not None:
+
+ o3d.io.write_point_cloud(export_path, pcd)
+
+ # o3d.visualization.draw_geometries([pcd])
+
+ def export_ply(self, path):
+ '''
+ ply
+ format ascii 1.0
+ comment Mars model by Paul Bourke
+ element vertex 259200
+ property float x
+ property float y
+ property float z
+ property uchar r
+ property uchar g
+ property uchar b
+ property float nx
+ property float ny
+ property float nz
+ end_header
+ '''
+ with open(path, "w") as f:
+ f.write("ply\n")
+ f.write("format ascii 1.0\n")
+ f.write("element vertex %d\n" % self.point_cloud['coords'].shape[0])
+ f.write("property float x\n")
+ f.write("property float y\n")
+ f.write("property float z\n")
+ if self.generate_color:
+ f.write("property uchar red\n")
+ f.write("property uchar green\n")
+ f.write("property uchar blue\n")
+ if self.generate_normal:
+ f.write("property float nx\n")
+ f.write("property float ny\n")
+ f.write("property float nz\n")
+ f.write("end_header\n")
+ for i in range(self.point_cloud['coords'].shape[0]):
+ normal = []
+ color = []
+ coord = self.point_cloud['coords'][i].tolist()
+ if self.generate_color:
+ color = list(map(int, (self.point_cloud['colors'][i]*255).tolist()))
+ if self.generate_normal:
+ normal = self.point_cloud['normals'][i].tolist()
+ data = coord + color + normal
+ f.write(" ".join(list(map(str,data)))+'\n')
diff --git a/s3d_preprocess/DataProcessing/PointCloudReaderPerspective.py b/s3d_preprocess/DataProcessing/PointCloudReaderPerspective.py
new file mode 100644
index 0000000000000000000000000000000000000000..3562fa0aae9f9b6176d43ebbcc8008d3a6fff4ce
--- /dev/null
+++ b/s3d_preprocess/DataProcessing/PointCloudReaderPerspective.py
@@ -0,0 +1,395 @@
+import cv2
+import open3d
+import os
+import matplotlib.pyplot as plt
+from PIL import Image
+import json
+
+from misc.utils import parse_camera_info
+from sem_seg_utils import *
+from visualize_3d import visualize_wireframe
+
+class PointCloudReaderPerspective():
+
+ def __init__(self, path, resolution="full", random_level=0, generate_color=False, generate_normal=False,
+ generate_segmentation=False):
+ perspective_str = "perspective"
+ self.path = path
+ self.random_level = random_level
+ self.resolution = resolution
+ self.generate_color = generate_color
+ self.generate_normal = generate_normal
+ self.generate_segmentation = generate_segmentation
+ sections = sorted([p for p in os.listdir(os.path.join(path, "2D_rendering"))])
+
+ sections_views = [sorted(os.listdir(os.path.join(*[path, "2D_rendering", p, perspective_str, self.resolution]))) \
+ if os.path.isdir(os.path.join(*[path, "2D_rendering", p, perspective_str, self.resolution])) \
+ else [] \
+ for p in sections]
+
+ self.depth_paths = []
+ self.rgb_paths = []
+ self.seg_paths = []
+ self.normal_paths = []
+ self.pose_paths = []
+ for p, views in zip(sections, sections_views):
+ if not os.path.isdir(os.path.join(*[path, "2D_rendering", p, perspective_str, self.resolution])):
+ continue
+
+ self.depth_paths += [os.path.join(*[path, "2D_rendering", p, perspective_str, self.resolution, v, "depth.png"]) for v in views]
+ self.rgb_paths += [os.path.join(*[path, "2D_rendering", p, perspective_str, self.resolution, v, "rgb_rawlight.png"]) for v in views]
+ self.seg_paths += [os.path.join(*[path, "2D_rendering", p, perspective_str, self.resolution, v, "semantic.png"]) for v in views]
+ self.normal_paths += [os.path.join(*[path, "2D_rendering", p, perspective_str, self.resolution, v, "normal.png"]) for v in views]
+ self.pose_paths += [os.path.join(*[path, "2D_rendering", p, perspective_str, self.resolution, v, "camera_pose.txt"]) for v in views]
+
+ self.point_cloud = self.generate_point_cloud(self.random_level, color=self.generate_color,
+ normal=self.generate_normal,
+ seg=self.generate_segmentation)
+
+
+ def read_camera_center(self):
+ camera_centers = []
+ print(self.camera_paths)
+ print(self.depth_paths)
+ for i in range(len(self.camera_paths)):
+ with open(self.camera_paths[i], 'r') as f:
+ line = f.readline()
+ center = list(map(float, line.strip().split(" ")))
+ camera_centers.append(np.asarray([center[0], center[1], center[2]]))
+ print(camera_centers)
+ return camera_centers
+
+ def generate_point_cloud(self, random_level=0, color=False, normal=False, seg=False):
+ coords = []
+ colors = []
+ segmentations = []
+ normals = []
+ points = {}
+
+ # Getting Coordinates
+ for i in range(len(self.depth_paths)):
+ print(i)
+ # i = 13
+ W, H = (1280, 720)
+ depth_img = cv2.imread(self.depth_paths[i], cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR) / 1000.
+ inv_depth_mask = depth_img < .2
+ depth_img[inv_depth_mask] = .2 # Why does this fix the problem?
+ # rgb_img = cv2.imread(self.rgb_paths[i])
+ # plt.subplot(121)
+ # plt.imshow(rgb_img)
+ # plt.subplot(122)
+ # plt.imshow(depth_img)
+ # plt.show()
+
+ camera_pose = np.loadtxt(self.pose_paths[i])
+ rot, trans, K = parse_camera_info(camera_pose, H, W, inverse=True)
+
+ pose = np.eye(4)
+ pose[:3, :3] = rot
+ pose[:3, 3] = trans / 1000.
+ inv_pose = np.linalg.inv(pose)
+
+ xs, ys = np.meshgrid(range(W), range(H), indexing='xy')
+
+ # xyz_homo = np.concatenate([xyz, np.ones_like(xs)], axis=0)
+ # xyz_h_global = pose.dot(xyz_homo).T
+ # xyz_global = xyz_h_global[:, :3] / xyz_h_global[:, 3][:, None]
+
+ if color:
+ rgb_img = cv2.imread(self.rgb_paths[i])
+ rgb_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2RGB)
+ # xs, ys = np.meshgrid(range(1280), range(720), indexing='xy')
+ if seg:
+ seg_img = Image.open(self.seg_paths[i])
+ # xs, ys = np.meshgrid(range(1280), range(720), indexing='xy')
+ seg_labels = np.array(seg_img.convert(mode="P", palette=create_color_palette()))
+
+ def seg_grad(seg1):
+ # [-1 0 1] kernel
+ dx = np.abs(seg1[:, 2:] - seg1[:, :-2])
+ dy = np.abs(seg1[2:, :] - seg1[:-2, :])
+
+ grad = np.zeros_like(seg1)
+ grad[:, 1:-1] = dx
+ grad[1:-1, :] = np.maximum(grad[1:-1, :], dy)
+
+ grad = grad != 0
+ return grad
+
+ def depth_grad(depth1):
+ # [-1 0 1] kernel
+ dx = np.abs(depth1[:, 2:] - depth1[:, :-2])
+ dy = np.abs(depth1[2:, :] - depth1[:-2, :])
+
+ grad = np.zeros_like(depth1)
+ grad[:, 1:-1] = dx
+ grad[1:-1, :] = np.maximum(grad[1:-1, :], dy)
+
+ grad = np.abs(grad) > 0.1
+ return grad
+
+ grad_mask = np.logical_and(depth_grad(depth_img), seg_grad(seg_labels))
+ # kern = np.ones((3, 3), np.uint8)
+ # seg_mask = cv2.dilate((seg_mask).astype(np.uint8), kernel=kern, iterations=1)
+
+ # plt.imshow(seg_mask)
+ # plt.show()
+ # not_windows = np.argwhere(seg_labels != class_name_to_id['window'])
+ # ys = not_windows[:, 0]
+ # xs = not_windows[:, 1]
+ #
+ # seg_labels = np.tile(np.round(seg_labels)[ys, xs].reshape(-1, 1), reps=[1, 3])
+ # seg_labels = np.tile(seg_labels[:, :, None], reps=[1, 1, 3]) / 255
+
+ # valid_mask = np.argwhere(valid_mask == 0)
+ # valid_mask = np.argwhere(np.logical_and(grad_mask == 0, seg_labels != class_name_to_id['window']))
+ valid_mask = np.argwhere(grad_mask == 0)
+
+ ys = valid_mask[:, 0]
+ xs = valid_mask[:, 1]
+
+ seg_labels[inv_depth_mask] = 38
+ seg_labels = np.tile(np.round(seg_labels)[ys, xs].reshape(-1, 1), reps=[1, 3])
+
+ zs = depth_img[ys, xs]
+ xs = xs.reshape(1, -1)
+ ys = ys.reshape(1, -1)
+ zs = zs.reshape(1, -1)
+
+ inverse_K = np.linalg.inv(K)
+
+ xyz = (inverse_K[:3, :3].dot(np.concatenate([xs, ys, np.ones_like(xs)], axis=0)))
+ xyz = zs * (xyz / np.linalg.norm(xyz, axis=0, ord=2))
+ # xyz = zs * xyz
+ xyz_o3d = open3d.geometry.PointCloud()
+ xyz_o3d.points = open3d.utility.Vector3dVector(xyz.T)
+ xyz_o3d.transform(pose)
+ xyz_global = np.asarray(xyz_o3d.points)
+
+ segmentations += list(seg_labels)
+ colors += list(rgb_img[ys, xs].reshape(-1,3))
+ coords += list(xyz_global)
+ # break
+
+ points['coords'] = np.asarray(coords) * 1000.
+ points['colors'] = np.asarray(colors) / 255.0
+ points['segs'] = np.asarray(segmentations)
+
+
+
+
+ # if normal:
+ # # Getting Normal
+ # for i in range(len(self.normal_paths)):
+ # print(self.normal_paths[i])
+ # normal_img = cv2.imread(self.normal_paths[i])
+ # for x in range(normal_img.shape[0]):
+ # for y in range(normal_img.shape[1]):
+ # normals.append(normalize(normal_img[x, y].reshape(-1, 1)).ravel())
+ # points['normals'] = normals
+
+
+
+ return points
+
+ def get_point_cloud(self):
+ return self.point_cloud
+
+ def display_inlier_outlier(self, cloud, ind):
+ inlier_cloud = cloud.select_down_sample(ind)
+ # outlier_cloud = cloud.select_down_sample(ind, invert=True)
+
+ print("Showing outliers (red) and inliers (gray): ")
+ # outlier_cloud.paint_uniform_color([1, 0, 0])
+ # inlier_cloud.paint_uniform_color([0.8, 0.8, 0.8])
+ # o3d.visualization.draw_geometries([inlier_cloud, outlier_cloud])
+ return inlier_cloud
+
+ def visualize(self, o3d_pcd=None):
+ # input("Max depth?")
+ print("Visualizing...")
+ pcd = open3d.geometry.PointCloud()
+
+ if o3d_pcd is None:
+ pcd.points = open3d.utility.Vector3dVector(self.point_cloud['coords'])
+ # if self.generate_normal:
+ # pcd.normals = open3d.utility.Vector3dVector(self.point_cloud['normals'])
+
+ # if False and self.generate_segmentation:
+ # pcd.colors = open3d.utility.Vector3dVector(self.point_cloud['segs'] / 255.)
+ # elif self.generate_color:
+ # pcd.colors = open3d.utility.Vector3dVector(self.point_cloud['colors'])
+ # pcd.colors = open3d.utility.Vector3dVector(self.point_cloud['colors'])
+ pcd.colors = open3d.utility.Vector3dVector(self.point_cloud['segs'] / 255.)
+ else:
+ pcd = o3d_pcd
+
+ vis = open3d.visualization.Visualizer()
+ # vis.create_window(window_name="O3D")
+ vis.create_window(window_name="O3D", width=1280, height=720, left=0, top=0,
+ visible=True) # use visible=True to visualize the point cloud
+ # vis.get_render_option().light_on = False
+ # vis.get_render_option().point_size = 20
+
+ vis.add_geometry(pcd)
+
+
+ with open("/media/sinisa/Sinisa_hdd_data/Sinisa_Projects/corridor_localisation/Datasets/Structured_3D_dataset/Structured3D/Structured3D_0/Structured3D/train/scene_00015/annotation_3d.json") as file:
+ annos = json.load(file)
+
+ wireframe_geo_list = visualize_wireframe(annos, vis=False, ret=True)
+
+ vis.add_geometry(wireframe_geo_list[0])
+ vis.add_geometry(wireframe_geo_list[1])
+
+ # for view_ind in range(len(self.pose_paths)):
+ # # if view_ind != 25:
+ # # continue
+ # W, H = (1280, 720)
+ # depth_img = cv2.imread(self.depth_paths[view_ind], cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR) / 1000.
+ #
+ # # rgb_img = cv2.imread(self.rgb_paths[i])
+ # # plt.subplot(121)
+ # # plt.imshow(rgb_img)
+ # # plt.subplot(122)
+ # # plt.imshow(depth_img)
+ # # plt.show()
+ #
+ # camera_pose = np.loadtxt(self.pose_paths[view_ind])
+ # rot, trans, K = parse_camera_info(camera_pose, H, W, inverse=True)
+ #
+ # pose = np.eye(4)
+ # pose[:3, :3] = rot
+ # pose[:3, 3] = trans / 1000.
+ #
+ # camera_param = vis.get_view_control().convert_to_pinhole_camera_parameters()
+ # fx, fy = camera_param.intrinsic.get_focal_length()
+ # cx = camera_param.intrinsic.intrinsic_matrix[0, 2]
+ # cy = camera_param.intrinsic.intrinsic_matrix[1, 2]
+ # camera_param.intrinsic.set_intrinsics(camera_param.intrinsic.width, camera_param.intrinsic.height,
+ # K[0, 0], K[1, 1], cx, cy)
+ # camera_param.extrinsic = np.linalg.inv(pose)
+ # ctr = vis.get_view_control()
+ # ctr.convert_from_pinhole_camera_parameters(camera_param)
+ # depth_render = vis.capture_depth_float_buffer(do_render=True)
+ # depth_render = np.asarray(depth_render)
+ #
+ #
+ # camera_param = vis.get_view_control().convert_to_pinhole_camera_parameters()
+ # print("My_intr", K)
+ # print("O3D_intr", camera_param.intrinsic.intrinsic_matrix)
+ # print("view ind", view_ind)
+ #
+ # print("Plot")
+ # plt.subplot(131)
+ # plt.imshow(depth_img)
+ # plt.subplot(132)
+ # plt.imshow(depth_render)
+ # plt.subplot(133)
+ # plt.imshow(np.abs(depth_render - depth_img))
+ # plt.show()
+
+ vis.run()
+ vis.destroy_window()
+
+ def generate_density(self, width=256, height=256):
+
+ ps = self.point_cloud["coords"]
+
+ unique_coords, unique_ind = np.unique(np.round(ps / 10) * 10., return_index=True, axis=0)
+
+ ps = unique_coords
+
+
+ image_res = np.array((width, height))
+
+ max_coords = np.max(ps, axis=0)
+ min_coords = np.min(ps, axis=0)
+ max_m_min = max_coords - min_coords
+
+ max_coords = max_coords + 0.1 * max_m_min
+ min_coords = min_coords - 0.1 * max_m_min
+
+
+ # coordinates = np.round(points[:, :2] / max_coordinates[None,:2] * image_res[None])
+ coordinates = \
+ np.round(
+ (ps[:, :2] - min_coords[None, :2]) / (max_coords[None,:2] - min_coords[None, :2]) * image_res[None])
+ coordinates = np.minimum(np.maximum(coordinates, np.zeros_like(image_res)),
+ image_res - 1)
+
+ density = np.zeros((height, width), dtype=np.float32)
+
+ unique_coordinates, counts = np.unique(coordinates, return_counts=True, axis=0)
+ print(np.unique(counts))
+ # counts = np.minimum(counts, 2e3)
+ #
+ unique_coordinates = unique_coordinates.astype(np.int32)
+
+ density[unique_coordinates[:, 1], unique_coordinates[:, 0]] = counts
+ density = density / np.max(density)
+ # print(np.unique(density))
+
+ plt.figure()
+ plt.imshow(density)
+ plt.show()
+
+ return density
+
+ def subsample_pcd(self, seg=False):
+ # input("Max depth?")
+ pcd = open3d.geometry.PointCloud()
+ pcd.points = open3d.utility.Vector3dVector(self.point_cloud['coords'])
+ # if self.generate_normal:
+ # pcd.normals = open3d.utility.Vector3dVector(self.point_cloud['normals'])
+
+ if seg:
+ pcd.colors = open3d.utility.Vector3dVector(self.point_cloud['segs'] / 255.)
+ else:
+ pcd.colors = open3d.utility.Vector3dVector(self.point_cloud['colors'])
+
+ final_pcd = pcd
+ final_pcd, inds = pcd.remove_statistical_outlier(nb_neighbors=10,
+ std_ratio=3.0)
+ #
+ final_pcd = final_pcd.uniform_down_sample(every_k_points=10)
+ return final_pcd
+
+ def export_ply_from_o3d_pcd(self, path, pcd, seg=False):
+ '''
+ ply
+ format ascii 1.0
+ comment Mars model by Paul Bourke
+ element vertex 259200
+ property float x
+ property float y
+ property float z
+ property uchar r
+ property uchar g
+ property uchar b
+ property float nx
+ property float ny
+ property float nz
+ end_header
+ '''
+
+ coords = np.asarray(pcd.points)
+ colors = (np.asarray(pcd.colors) * 255).astype(np.int32)
+ with open(path, "w") as f:
+ f.write("ply\n")
+ f.write("format ascii 1.0\n")
+ f.write("element vertex %d\n" % coords.shape[0])
+ f.write("property float x\n")
+ f.write("property float y\n")
+ f.write("property float z\n")
+ if self.generate_color:
+ f.write("property uchar red\n")
+ f.write("property uchar green\n")
+ f.write("property uchar blue\n")
+
+ f.write("end_header\n")
+ for i in range(coords.shape[0]):
+ coord = coords[i].tolist()
+ color = colors[i].tolist()
+ data = coord + color
+ f.write(" ".join(list(map(str,data)))+'\n')
diff --git a/s3d_preprocess/DataProcessing/path_variables.py b/s3d_preprocess/DataProcessing/path_variables.py
new file mode 100644
index 0000000000000000000000000000000000000000..330b9ebb077669cdfd27ae594847d69a99f808ad
--- /dev/null
+++ b/s3d_preprocess/DataProcessing/path_variables.py
@@ -0,0 +1,10 @@
+scenes_path = "/media/sinisa/Sinisa_hdd_data/Sinisa_Projects/corridor_localisation/Datasets/Structured_3D_dataset/Structured3D/Structured3D_0/Structured3D/test/"
+scene_name = "/scene_03314"
+scene_path = scenes_path + scene_name
+
+output_folder = '/home/sinisa/Sinisa_Projects/indoor_localisation/SceneLayout/S3D_SceneLayout_scenes/'
+output_path = output_folder + scene_name + '/'
+scene_segmented_ply_path = output_folder + scene_name + "/" + scene_name + '_segmented.ply'
+#----------------------------------------
+
+scene_ply_path = output_folder + scene_name + "/" + scene_name + '.ply'
diff --git a/s3d_preprocess/README.md b/s3d_preprocess/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b0abd9cfa26e4394da599e3079e20b498ee44cc3
--- /dev/null
+++ b/s3d_preprocess/README.md
@@ -0,0 +1,39 @@
+# Structured3D preprocessing for floorplan data
+
+We thank the authors of [MonteFloor](https://openaccess.thecvf.com/content/ICCV2021/papers/Stekovic_MonteFloor_Extending_MCTS_for_Reconstructing_Accurate_Large-Scale_Floor_Plans_ICCV_2021_paper.pdf) for providing the preprocessing scripts to generate the floorplan data from the Structured3D dataset.
+
+
+We prepare the training data for HEAT based on the generated density/normal images and the raw floorplan annotations. Note that all the data used in our paper can be downloaded from [our links](https://github.com/woodfrog/heat#data); this readme is a non-exhaustive explanation for those who are interested in the data preprocessing pipeline.
+
+
+## Generate floorplan data (the original readme provided by MonteFloor)
+
+This code is based on the Structured3D repository.
+
+To generate floorplans, run the generate_floors.py script:
+
+```
+python generate_floors.py
+```
+
+Prior to that, you should modify the path variables in `DataProcessing.FloorRW` (`dataset_path` and `mode`), as shown below.
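+
+For reference, the fields to adjust sit at the top of `FloorRW.__init__` (the values below are the defaults in this repo; point `dataset_path` at your raw Structured3D download and pick the split with `mode`):
+
+```
+class FloorRW:
+    def __init__(self):
+        self.dataset_path = "./s3d_raw/"  # path to the raw Structured3D data
+        self.mode = "train"               # split to process, e.g. train or test
+        self.scenes_path = os.path.join(self.dataset_path, self.mode)
+```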
+
+
+Some scenes have missing/wrong annotations. These are the indices that you should additionally exclude from the test set:
+
+```
+wrong_s3d_annotations_list = [3261, 3271, 3276, 3296, 3342, 3387, 3398, 3466, 3496]
+```
+
+
+## Generate the training annotations for HEAT
+
+In HEAT's formulation, each floorplan is represented by a planar graph. However, the raw annotations from Structured3D represent the floorplan by a list of closed loops. To prepare the **ground-truth training data** for HEAT, we need to further process the raw annotations to get proper planar graphs. We refer to the room merging step of [Floor-SP](https://arxiv.org/abs/1908.06702) and implement a merging algorithm (in ```generate_planar_graph.py```) to generate planar graphs from the raw annotations.
+
+**Note**: the generated planar graphs are **only used for training HEAT**. For evaluation, we extract the rooms from the estimated planar graph as closed loops and follow the original evaluation pipeline established by MonteFloor. Check the [quantitative evaluation section](https://github.com/woodfrog/heat#floorplan-reconstruction) for the details.
+
+Please run the script ```generate_planar_graph.py``` to merge the rooms and get the training annotations for HEAT.
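+
+The resulting `annot.npy` stores the planar graph as a dict that maps each corner `(x, y)` to the list of corners it is connected to. A minimal sketch for loading and drawing one annotation (the file names here are only an example):
+
+```
+import cv2
+import numpy as np
+
+annot = np.load('annot.npy', allow_pickle=True).tolist()  # {corner: [connected corners]}
+viz = cv2.imread('density.png')
+for corner, connections in annot.items():
+    for other in connections:
+        cv2.line(viz, (int(corner[0]), int(corner[1])), (int(other[0]), int(other[1])), (255, 255, 0), 2)
+for corner in annot:
+    cv2.circle(viz, (int(corner[0]), int(corner[1])), 3, (0, 0, 255), -1)
+cv2.imwrite('annot_viz.png', viz)
+```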
+
+
+
+
diff --git a/s3d_preprocess/data_organization.md b/s3d_preprocess/data_organization.md
new file mode 100644
index 0000000000000000000000000000000000000000..c16c38ab18ed83b808e9530cd1313e71c317ba8e
--- /dev/null
+++ b/s3d_preprocess/data_organization.md
@@ -0,0 +1,146 @@
+# Data Organization
+
+There is a separate subdirectory for every scene (*i.e.*, house design), which is named by a unique ID. Within each scene directory, there are separate directories for different types of data as follows:
+```
+scene_<sceneID>
+├── 2D_rendering
+│   └── <roomID>
+│       ├── panorama
+│       │   ├── <empty/simple/full>
+│       │   │   ├── rgb_<cold/raw/warm>light.png
+│       │   │   ├── semantic.png
+│       │   │   ├── albedo.png
+│       │   │   ├── depth.png
+│       │   │   └── normal.png
+│       │   ├── layout.txt
+│       │   └── camera_xyz.txt
+│       └── perspective
+│           └── full
+│               └── <positionID>
+│                   ├── rgb_rawlight.png
+│                   ├── semantic.png
+│                   ├── instance.png
+│                   ├── albedo.png
+│                   ├── depth.png
+│                   ├── normal.png
+│                   ├── layout.json
+│                   └── camera_pose.txt
+├── bbox_3d.json
+└── annotation_3d.json
+```
+
+# Annotation Format
+
+We provide the primitive- and relationship-based structure annotation for each scene, and an oriented bounding box for each object instance.
+
+**Structure annotation (`annotation_3d.json`)**: see all the room types [here](metadata/room_types.txt).
+```
+{
+ // PRIMITIVES
+ "junctions":[
+ {
+ "ID" : int,
+ "coordinate" : List[float] // 3D vector
+ }
+ ],
+ "lines": [
+ {
+ "ID" : int,
+ "point" : List[float], // 3D vector
+ "direction" : List[float] // 3D vector
+ }
+ ],
+ "planes": [
+ {
+ "ID" : int,
+ "type" : str, // ceiling, floor, wall
+ "normal" : List[float], // 3D vector, the normal points to the empty space
+ "offset" : float
+ }
+ ],
+ // RELATIONSHIPS
+ "semantics": [
+ {
+ "ID" : int,
+ "type" : str, // room type, door, window
+ "planeID" : List[int] // indices of the planes
+ }
+ ],
+ "planeLineMatrix" : Matrix[int], // matrix W_1 where the ij-th entry is 1 iff l_i is on p_j
+ "lineJunctionMatrix" : Matrix[int], // matrix W_2 where the mn-th entry is 1 iff x_m is on l_n
+ // OTHERS
+ "cuboids": [
+ {
+ "ID" : int,
+ "planeID" : List[int] // indices of the planes
+ }
+ ],
+ "manhattan": [
+ {
+ "ID" : int,
+ "planeID" : List[int] // indices of the planes
+ }
+ ]
+}
+```
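+
+The preprocessing and evaluation scripts in this repo read these fields directly. As a minimal sketch (the file path is just an example), the junctions lying on a given line can be looked up through `lineJunctionMatrix`:
+
+```
+import json
+import numpy as np
+
+with open('annotation_3d.json') as f:
+    annos = json.load(f)
+
+junctions = np.array([j['coordinate'] for j in annos['junctions']])  # (num_junctions, 3)
+line_junc_mat = np.array(annos['lineJunctionMatrix'])                # (num_lines, num_junctions)
+
+# 3D coordinates of the junctions lying on line 0
+junc_ids = np.nonzero(line_junc_mat[0])[0]
+print(junctions[junc_ids])
+```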
+
+**Bounding box (`bbox_3d.json`)**: the oriented bounding box annotation in world coordinates, same as [SUN RGB-D](http://rgbd.cs.princeton.edu).
+```
+[
+ {
+ "ID" : int, // instance id
+ "basis" : Matrix[float], // basis of the bounding box, one row is one basis vector
+ "coeffs" : List[float], // radii in each dimension
+ "centroid" : List[float] // 3D centroid of the bounding box
+ }
+]
+```
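+
+Assuming the SUN RGB-D convention referenced above, the eight corners of a box can be recovered from `basis`, `coeffs` and `centroid`; a minimal sketch (the file path is an example):
+
+```
+import itertools
+import json
+import numpy as np
+
+with open('bbox_3d.json') as f:
+    boxes = json.load(f)
+
+box = boxes[0]
+basis = np.array(box['basis'])        # (3, 3), one row per box axis
+coeffs = np.array(box['coeffs'])      # (3,) radii along each axis
+centroid = np.array(box['centroid'])  # (3,) box center
+
+# corner = centroid + sum_i (+/- coeffs[i]) * basis[i]
+signs = np.array(list(itertools.product([-1, 1], repeat=3)))  # (8, 3)
+corners = centroid + (signs * coeffs) @ basis                 # (8, 3)
+print(corners)
+```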
+
+For each image, we provide semantic, instance, albedo, depth, normal, and layout annotations, as well as the camera position. Please note that the layout and camera annotation formats differ between panoramic and perspective images.
+
+**Semantic annotation (`semantic.png`)**: unsigned 8-bit integers within a PNG. We use the [NYUv2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2) 40-label set; see all the label ids [here](metadata/labelids.txt).
+
+**Instance annotation for perspective (`instance.png`)**: unsigned 16-bit integers within a PNG. The maximum value (65535) denotes *background*.
+
+**Albedo data (`albedo.png`)**: unsigned 8-bit integers within a PNG.
+
+**Depth data (`depth.png`)**: unsigned 16-bit integers within a PNG. The units are millimeters, so a value of 1000 corresponds to one meter. A zero value denotes *no reading*.
+
+**Normal data (`normal.png`)**: unsigned 8-bit integers within a PNG (x, y, z). The stored integer value is 128 \* (1 + n), where n is the corresponding normal coordinate in the range [-1, 1].
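+
+Following the two conventions above, the depth and normal maps can be decoded roughly as follows (a sketch, not part of the preprocessing scripts; note that OpenCV loads the channels in BGR order):
+
+```
+import cv2
+import numpy as np
+
+depth_mm = cv2.imread('depth.png', cv2.IMREAD_ANYDEPTH)  # uint16, millimeters, 0 = no reading
+depth_m = depth_mm.astype(np.float32) / 1000.0            # meters
+
+normal_png = cv2.imread('normal.png')                     # uint8, stored as 128 * (1 + n), BGR order
+normal = normal_png.astype(np.float32) / 128.0 - 1.0      # back to [-1, 1]
+```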
+
+**Layout annotation for panorama (`layout.txt`)**: an ordered list of 2D positions of the junctions (same as [LayoutNet](https://github.com/zouchuhang/LayoutNet) and [HorizonNet](https://github.com/sunset1995/HorizonNet)). The order of the junctions is shown in the figure below. In our dataset, the cameras of the panoramas are aligned with the gravity direction, thus a pair of ceiling-wall and floor-wall junctions share the same x-axis coordinates.
+
+**Layout annotation for perspective (`layout.json`)**: We also include the junctions that are formed by line segments intersecting with each other or with the image boundary. We consider the visible and invisible parts caused by the room structure, not by furniture.
+```
+{
+ "junctions":[
+ {
+ "ID" : int, // corresponding 3D junction id; None corresponds to a fake 3D junction
+ "coordinate" : List[int], // 2D location in the camera coordinate
+ "isvisible" : bool // whether this junction is occluded by the other walls
+ }
+ ],
+ "planes": [
+ {
+ "ID" : int, // corresponding 3D plane id
+ "visible_mask" : List[List[int]], // visible segmentation mask, list of junction ids
+ "amodal_mask" : List[List[int]], // amodal segmentation mask, list of junction ids
+ "normal" : List[float], // normal in the camera coordinate
+ "offset" : float, // offset in the camera coordinate
+ "type" : str // ceiling, floor, wall
+ }
+ ]
+}
+```
+
+**Camera location for panorama (`camera_xyz.txt`)**: For each panoramic image, we only store the camera location in global coordinates. The direction of the camera is always along the negative y-axis. The global coordinate system is arbitrary, but the z-axis generally points upward.
+
+**Camera location for perspective (`camera_pose.txt`)**: For each perspective image, we store the camera location and pose in global coordinates.
+```
+vx vy vz tx ty tz ux uy uz xfov yfov 1
+```
+where `(vx, vy, vz)` is the eye viewpoint of the camera, `(tx, ty, tz)` is the view direction, `(ux, uy, uz)` is the up direction, and `xfov` and `yfov` are the half-angles of the horizontal and vertical fields of view of the camera in radians (the angle from the central ray to the leftmost/bottommost ray in the field of view), same as [Matterport3D](https://github.com/niessner/Matterport).
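+
+A minimal sketch for reading one such pose line (the repo's `misc.utils.parse_camera_info` performs the full conversion to a rotation and intrinsic matrix):
+
+```
+import numpy as np
+
+vals = np.loadtxt('camera_pose.txt')  # vx vy vz tx ty tz ux uy uz xfov yfov 1
+eye = vals[0:3]                 # camera position (vx, vy, vz)
+view_dir = vals[3:6]            # viewing direction (tx, ty, tz)
+up = vals[6:9]                  # up direction (ux, uy, uz)
+xfov, yfov = vals[9], vals[10]  # half-angles of the fields of view, in radians
+
+view_dir = view_dir / np.linalg.norm(view_dir)
+up = up / np.linalg.norm(up)
+```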
diff --git a/s3d_preprocess/generate_coco_json.py b/s3d_preprocess/generate_coco_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..48c30fcf2d409702b2147a130639f4f2f77e0217
--- /dev/null
+++ b/s3d_preprocess/generate_coco_json.py
@@ -0,0 +1,4 @@
+from DataProcessing.FloorRW import FloorRW
+
+floor_rw = FloorRW()
+floor_rw.generate_coco_json()
\ No newline at end of file
diff --git a/s3d_preprocess/generate_floors.py b/s3d_preprocess/generate_floors.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cff3a8537387bc7cc783e3744af0580ad367a7a
--- /dev/null
+++ b/s3d_preprocess/generate_floors.py
@@ -0,0 +1,4 @@
+from DataProcessing.FloorRW import FloorRW
+
+floor_rw = FloorRW()
+floor_rw.generate_floors()
\ No newline at end of file
diff --git a/s3d_preprocess/generate_planar_graph.py b/s3d_preprocess/generate_planar_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8df36ed15c22bcc57e5024a5a6f4770034aca8b
--- /dev/null
+++ b/s3d_preprocess/generate_planar_graph.py
@@ -0,0 +1,603 @@
+import os
+import json
+import cv2
+import numpy as np
+from collections import defaultdict
+from scipy import ndimage
+
+
+def generate_graph(annot, image_path, out_path):
+ lines = annot['lines']
+ junctions = annot['junctions']
+ line_junc_mat = np.array(annot['lineJunctionMatrix'])
+ planes = annot['planes']
+ plane_line_mat = annot['planeLineMatrix']
+ plane_to_line = np.array(annot['planeLineMatrix'])
+ line_to_plane = plane_to_line.T
+ semantics = annot['semantics']
+
+ all_room_edges = get_room_edges(semantics, planes, lines, junctions, plane_to_line, line_junc_mat)
+
+ all_room_edges = filter_rooms(all_room_edges, im_size=256)
+
+ all_colinear_pairs = find_all_colinear_paris(all_room_edges)
+ colinear_sets = combine_colinear_edges(all_colinear_pairs)
+
+ for colinear_set in colinear_sets:
+ edges_to_merge = list(colinear_set)
+ edges_to_merge = sorted(edges_to_merge, key=lambda x: -x[0])
+ merged_edges = merge_edges(edges_to_merge)
+
+ for merged_edge, old_edge in zip(merged_edges, edges_to_merge):
+ if len(merged_edge) > 0:
+ assert merged_edge[0][0] == old_edge[0]
+ room_idx = merged_edge[0][0]
+
+ if len(merged_edge) == 1 and merged_edge[0] == old_edge:
+ continue
+ # change room graph accordingly
+ replaced_idx = all_room_edges[room_idx].index(old_edge[1])
+ all_room_edges[room_idx].pop(replaced_idx)
+ for new_idx, new_edge in enumerate(merged_edge):
+ insert_idx = new_idx + replaced_idx
+ all_room_edges[room_idx].insert(insert_idx, new_edge[1])
+ else:
+ room_idx = old_edge[0]
+ replaced_idx = all_room_edges[room_idx].index(old_edge[1])
+ all_room_edges[room_idx].pop(replaced_idx)
+
+    # take the intersection for every room to recover the room structure
+ refined_room_edges = [adjust_room_edges(room_edges) for room_edges in all_room_edges]
+
+ # clean every room loop by removing I-shape corners
+ cleaned_room_edges = clean_room_edges(refined_room_edges)
+
+ global_graph = defaultdict(list)
+ for room_edges in cleaned_room_edges:
+ for edge in room_edges:
+ c1, c2 = edge
+ global_graph[c1] += [c2, ]
+ global_graph[c2] += [c1, ]
+ for corner in global_graph:
+ global_graph[corner] = list(set(global_graph[corner]))
+
+ annot_path = os.path.join(out_path, 'annot.npy')
+ np.save(annot_path, global_graph)
+
+ # draw the planar graph on the density map
+ viz_image = cv2.imread(image_path)
+ for c, connections in global_graph.items():
+ for other_c in connections:
+ cv2.line(viz_image, (int(c[0]), int(c[1])), (int(other_c[0]), int(other_c[1])), (255, 255, 0), 2)
+ for c in global_graph.keys():
+ cv2.circle(viz_image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
+
+ # viz_image = np.zeros([256, 266, 3]).astype(np.uint8)
+ # room_idx = 0
+ # for room_edges in cleaned_room_edges:
+ # for line_idx, edge in enumerate(room_edges):
+ # c1, c2 = np.array(edge).astype(np.int)
+ # cv2.line(viz_image, tuple(c1), tuple(c2), (255, 255, 0), 2)
+ # cv2.circle(viz_image, tuple(c1), 3, (0, 0, 255), -1)
+ # cv2.circle(viz_image, tuple(c2), 3, (0, 0, 255), -1)
+ # room_idx += 1
+ cv2.imwrite(os.path.join(out_path, 'planar_graph.png'), viz_image)
+
+
+def get_room_edges(semantics, planes, lines, junctions, plane_to_line, line_junc_mat):
+ room_edges = list()
+ for semantic in semantics:
+ plane_ids = semantic['planeID']
+ label = semantic['type']
+ if label in ['door', 'window', 'outwall']: # skip non-room elements
+ continue
+ all_planes = [planes[idx] for idx in plane_ids]
+ floor_planes = [plane for plane in all_planes if plane['type'] == 'floor']
+ assert len(floor_planes) == 1, 'There should be only one floor for each room'
+ floor_plane = floor_planes[0]
+ floor_plane_id = floor_plane['ID']
+ line_ids = np.where(plane_to_line[floor_plane_id])[0].tolist()
+ floor_lines_raw = [lines[line_id] for line_id in line_ids]
+ floor_lines = list()
+ for line_idx, floor_line in enumerate(floor_lines_raw):
+ c_id_1, c_id_2 = np.where(line_junc_mat[floor_line['ID']])[0].tolist()
+ c1 = tuple(junctions[c_id_1]['coordinate'][:2])
+ c2 = tuple(np.array(junctions[c_id_2]['coordinate'][:2]))
+ if c1 == c2:
+ continue
+ floor_lines.append((c1, c2))
+
+ floor_lines = list(set(floor_lines)) # remove duplications
+ floor_lines = sort_room_edges(floor_lines)
+ room_edges.append(floor_lines)
+ return room_edges
+
+
+def sort_room_edges(lines):
+ cur_id = 0
+ picked = [False] * len(lines)
+ id_list = [0, ]
+ while len(id_list) < len(lines):
+ line = lines[cur_id]
+ picked[cur_id] = True
+ check_ = [(line[1] in other) and not picked[other_idx] for other_idx, other in enumerate(lines)]
+ next_ids = np.nonzero(check_)[0]
+ try:
+ assert len(next_ids) == 1
+        except AssertionError:
+ raise WrongRoomError('Invalid room shape')
+ next_id = next_ids[0]
+ id_list.append(next_id)
+ if lines[next_id][1] == line[1]: # swap the two endpoints, to make the loop valid.
+ lines[next_id] = (lines[next_id][1], lines[next_id][0])
+ cur_id = next_id
+ if lines[next_id][1] == lines[0][0]: # already form a closed loop, then skip the remaining lines
+ break
+ sorted_lines = [lines[idx] for idx in id_list]
+ return sorted_lines
+
+
+def find_all_colinear_paris(all_room_edges):
+ colinear_pairs = list()
+ for room_idx, room_edges in enumerate(all_room_edges):
+ for edge_idx, edge in enumerate(room_edges):
+ for other_room_idx, other_edges in enumerate(all_room_edges):
+ if other_room_idx < room_idx:
+ continue
+ for other_edge_idx, other_edge in enumerate(other_edges):
+ if other_room_idx == room_idx and other_edge_idx <= edge_idx:
+ continue
+ if _check_colinear(edge, other_edge, line_dist_th=8):
+ ele1 = (room_idx, edge)
+ ele2 = (other_room_idx, other_edge)
+ colinear_pairs.append([ele1, ele2])
+ return colinear_pairs
+
+
+def combine_colinear_edges(colinear_pairs):
+ all_colinear_sets = list()
+ all_pairs = list(colinear_pairs) # make a copy of the input list
+ combined = [False] * len(colinear_pairs)
+
+ while len(all_pairs) > 0:
+ colinear_set = _combine_colinear_pairs(0, all_pairs, combined)
+ all_colinear_sets.append(colinear_set)
+ all_pairs = [all_pairs[i] for i in range(len(all_pairs)) if combined[i] is False]
+ combined = [False] * len(all_pairs)
+ return all_colinear_sets
+
+
+def _combine_colinear_pairs(idx, all_pairs, combined):
+ colinear_set = set(all_pairs[idx])
+ combined[idx] = True
+ for other_idx, pair in enumerate(all_pairs):
+ if not combined[other_idx] and (
+ all_pairs[idx][0] in all_pairs[other_idx] or all_pairs[idx][1] in all_pairs[other_idx]):
+ colinear_set = colinear_set.union(_combine_colinear_pairs(other_idx, all_pairs, combined))
+ return colinear_set
+
+
+def _check_colinear(e1, e2, line_dist_th=8):
+ # first check whether two line segments are parallel to each other, if not, return False directly
+ len_e1 = len_edge(e1)
+ len_e2 = len_edge(e2)
+ # we need to always make e2 the shorter one
+ if len_e1 < len_e2:
+ e1, e2 = e2, e1
+ v1_01 = (e1[1][0] - e1[0][0], e1[1][1] - e1[0][1])
+ v1_10 = (e1[0][0] - e1[1][0], e1[0][1] - e1[1][1])
+ v2_01 = (e2[1][0] - e2[0][0], e2[1][1] - e2[0][1])
+ v2_10 = (e2[0][0] - e2[1][0], e2[0][1] - e2[1][1])
+ len_1 = np.sqrt(v1_01[0] ** 2 + v1_01[1] ** 2)
+ len_2 = np.sqrt(v2_01[0] ** 2 + v2_01[1] ** 2)
+ if len_1 == 0 or len_2 == 0:
+ cos = 0
+ else:
+ cos = (v1_01[0] * v2_01[0] + v1_01[1] * v2_01[1]) / (len_1 * len_2)
+ if abs(cos) > 0.99:
+ # then check the distance between two parallel lines
+ len_10_20 = len_edge((e1[0], e2[0]))
+ len_10_21 = len_edge((e1[0], e2[1]))
+ len_11_20 = len_edge((e1[1], e2[0]))
+ len_11_21 = len_edge((e1[1], e2[1]))
+
+ # two endpoints are very close, then we can say these two edges are colinear
+ if np.min([len_10_20, len_10_21, len_11_20, len_11_21]) <= 5:
+ return True
+ # otherwise we need to check the distance first
+ v_10_20 = (e2[0][0] - e1[0][0], e2[0][1] - e1[0][1])
+ cos_11_10_20 = (v1_01[0] * v_10_20[0] + v1_01[1] * v_10_20[1]) / (len_1 * len_10_20)
+ sin_11_10_20 = np.sqrt(1 - cos_11_10_20 ** 2)
+ dist_20_e1 = len_10_20 * sin_11_10_20
+ if dist_20_e1 <= line_dist_th:
+            # we need to check whether they have some overlap
+ v_11_20 = (e2[0][0] - e1[1][0], e2[0][1] - e1[1][1])
+ cos_10_11_20 = (v1_10[0] * v_11_20[0] + v1_10[1] * v_11_20[1]) / (len_1 * len_11_20)
+ if cos_11_10_20 >= 0 and cos_10_11_20 >= 0:
+ return True
+ v_10_21 = (e2[1][0] - e1[0][0], e2[1][1] - e1[0][1])
+ cos_11_10_21 = (v1_01[0] * v_10_21[0] + v1_01[1] * v_10_21[1]) / (len_1 * len_10_21)
+ v_11_21 = (e2[1][0] - e1[1][0], e2[1][1] - e1[1][1])
+ cos_10_11_21 = (v1_10[0] * v_11_21[0] + v1_10[1] * v_11_21[1]) / (len_1 * len_11_21)
+ if cos_11_10_21 >= 0 and cos_10_11_21 >= 0:
+ return True
+ return False
+ else:
+            # the two parallel segments are farther apart than line_dist_th, so they are not colinear
+ return False
+ else:
+ return False
+
+
+def merge_edges(edges):
+ base_e = edges[0][1]
+ merged_edges = [edges[0], ]
+ base_len = np.sqrt((base_e[1][0] - base_e[0][0]) ** 2 + (base_e[1][1] - base_e[0][1]) ** 2)
+ base_unit_v = ((base_e[1][0] - base_e[0][0]) / base_len, (base_e[1][1] - base_e[0][1]) / base_len)
+
+ for edge in edges[1:]:
+ room_idx = edge[0]
+ e = edge[1]
+ v_b0e0 = (e[0][0] - base_e[0][0], e[0][1] - base_e[0][1])
+ proj_len = (v_b0e0[0] * base_unit_v[0] + v_b0e0[1] * base_unit_v[1])
+ proj_e0 = (int(base_e[0][0] + base_unit_v[0] * proj_len), int(base_e[0][1] + base_unit_v[1] * proj_len))
+ proj_e1 = (int(proj_e0[0] + e[1][0] - e[0][0]), int(proj_e0[1] + e[1][1] - e[0][1]))
+ new_e = (proj_e0, proj_e1)
+ new_edge = (room_idx, new_e)
+ merged_edges.append(new_edge)
+
+ adjusted_merged_edges = adjust_colinear_edges(merged_edges)
+
+ return adjusted_merged_edges
+
+
+def adjust_colinear_edges(edges):
+ base_corner = (edges[0][0], edges[0][1][0])
+ all_corners = [base_corner, (edges[0][0], edges[0][1][1])]
+ for edge in edges[1:]:
+ all_corners.append((edge[0], edge[1][0]))
+ all_corners.append((edge[0], edge[1][1]))
+ unit_v = unit_v_edge(edges[0][1])
+ corner_projs = list()
+ # FIXME: need to fix the corner coords here, it's wrong now!! they are not merged...
+ for room, other_c in all_corners:
+ v_base_c = (other_c[0] - base_corner[1][0], other_c[1] - base_corner[1][1])
+ proj = (unit_v[0] * v_base_c[0] + unit_v[1] * v_base_c[1])
+ corner_projs.append(proj)
+ order = np.argsort(corner_projs).tolist()
+
+ # # merge corners that are too close to the prev corner
+ # for o_idx, corner_idx in enumerate(order[1:]):
+ # corner = all_corners[corner_idx][1]
+ # prev_idx = order[o_idx]
+ # prev_corner = all_corners[prev_idx][1]
+ # dist = len_edge((corner, prev_corner))
+ # if dist <= 5:
+ # all_corners[corner_idx] = (all_corners[corner_idx][0], prev_corner)
+
+ adjusted_edges = list()
+ for idx, edge in enumerate(edges):
+ room_idx = edge[0]
+ idx_1 = idx * 2
+ idx_2 = idx * 2 + 1
+ adj_idx_1 = order.index(idx_1)
+ adj_idx_2 = order.index(idx_2)
+ step_direction = 1 if adj_idx_2 > adj_idx_1 else -1
+ adjusted_edge = list()
+ for o_idx in range(adj_idx_1, adj_idx_2, step_direction):
+ c_idx = order[o_idx]
+ next_c_idx = order[o_idx + step_direction]
+ segment = (room_idx, (all_corners[c_idx][1], all_corners[next_c_idx][1]))
+ if len_edge(segment[1]) == 0:
+ continue
+ adjusted_edge.append(segment)
+ adjusted_edges.append(adjusted_edge)
+ return adjusted_edges
+
+
+def adjust_room_edges(room_edges):
+ refined_room_edges = list()
+
+ init_room_edges = list(room_edges)
+ for edge_i, edge in enumerate(room_edges):
+ next_i = edge_i
+ while True:
+ next_i = next_i + 1 if next_i < len(room_edges) - 1 else 0
+ next_edge = room_edges[next_i]
+ if next_edge[0] != next_edge[1]:
+ break
+ if edge[1] == next_edge[0]: # no need for refining
+ refined_room_edges.append(edge)
+ else: # the two corners disagree, refinement is required
+ if edge[0] == edge[1]:
+                print('skip collapsed edge')
+ continue
+ unit_edge = unit_v_edge(edge)
+ ext_edge = ((edge[0][0] - unit_edge[0] * 50, edge[0][1] - unit_edge[1] * 50),
+ (edge[1][0] + unit_edge[0] * 50, edge[1][1] + unit_edge[1] * 50))
+ unit_next = unit_v_edge(next_edge)
+ ext_next = ((next_edge[0][0] - unit_next[0] * 50, next_edge[0][1] - unit_next[1] * 50),
+ (next_edge[1][0] + unit_next[0] * 50, next_edge[1][1] + unit_next[1] * 50))
+ intersec = get_intersection(ext_edge[0], ext_edge[1], ext_next[0], ext_next[1])
+ try:
+ assert intersec is not None
+            except AssertionError:
+                print('no intersection found, move the endpoint directly')
+ intersec = next_edge[0]
+ intersec = (int(np.round(intersec[0])), int(np.round(intersec[1])))
+ refined_edge = (edge[0], intersec)
+ refined_room_edges.append(refined_edge)
+ room_edges[edge_i] = refined_edge
+ room_edges[next_i] = (intersec, next_edge[1])
+ if next_i < edge_i:
+ refined_room_edges[next_i] = room_edges[next_i]
+
+ # drop collapsed edges
+ refined_room_edges = [edge for edge in refined_room_edges if edge[0] != edge[1]]
+ for edge_i in range(len(refined_room_edges)):
+ next_i = edge_i + 1 if edge_i < len(refined_room_edges) - 1 else 0
+ if refined_room_edges[edge_i][1] != refined_room_edges[next_i][0]:
+ new_edge = (refined_room_edges[edge_i][0], refined_room_edges[next_i][0])
+ refined_room_edges[edge_i] = new_edge
+ return refined_room_edges
+
+
+def clean_room_edges(all_room_edges):
+ refined_room_paths = [_extract_room_path(room_edges) for room_edges in all_room_edges]
+ corner_to_room = defaultdict(list)
+ for room_idx, room_path in enumerate(refined_room_paths):
+ for corner in room_path:
+ corner_to_room[corner].append(room_idx)
+ # remove I-shape corner used by only one room
+ for room_idx, room_edges in enumerate(all_room_edges):
+ cp_room_edges = list(room_edges)
+ rm_flag = True
+ while rm_flag:
+ rm_flag = False
+ for edge_i, edge in enumerate(cp_room_edges):
+ prev_i = edge_i - 1
+ prev_edge = cp_room_edges[prev_i]
+ if _check_colinear(prev_edge, edge, line_dist_th=5):
+ rm_candidate = edge[0]
+ if len(corner_to_room[rm_candidate]) == 1 and corner_to_room[rm_candidate][0] == room_idx:
+ cp_room_edges[prev_i] = (prev_edge[0], edge[1])
+ rm_flag = True
+ cp_room_edges.pop(edge_i)
+ break
+ next_i = edge_i + 1 if edge_i < len(cp_room_edges) - 1 else 0
+ next_edge = cp_room_edges[next_i]
+ if _check_colinear(next_edge, edge, line_dist_th=5):
+ rm_candidate = edge[1]
+ if len(corner_to_room[rm_candidate]) == 1 and corner_to_room[rm_candidate][0] == room_idx:
+ cp_room_edges[next_i] = (edge[0], next_edge[1])
+ rm_flag = True
+ cp_room_edges.pop(edge_i)
+ break
+ if len(cp_room_edges) != len(room_edges):
+ all_room_edges[room_idx] = cp_room_edges
+
+ corner_to_room = get_corner_to_room(all_room_edges)
+ all_corners = list(corner_to_room.keys())
+ corners_to_merge = find_corners_to_merge(all_corners, corner_to_room)
+ while corners_to_merge is not None:
+ num_aff = [len(corner_to_room[x]) for x in corners_to_merge]
+ order = np.argsort(num_aff)[::-1]
+ base_corner = corners_to_merge[order[0]]
+ for corner in corners_to_merge:
+ if corner == base_corner:
+ continue
+ all_room_edges = move_corner(corner, base_corner, corner_to_room, all_room_edges)
+
+ corner_to_room = get_corner_to_room(all_room_edges)
+ all_corners = list(corner_to_room.keys())
+ corners_to_merge = find_corners_to_merge(all_corners, corner_to_room)
+
+ # for room_idx, room_edges in enumerate(all_room_edges):
+ # cp_room_edges = list(room_edges)
+ # rm_flag = True
+ # while rm_flag:
+ # rm_flag = False
+ # for edge_i, edge in enumerate(cp_room_edges):
+ # len_e = len_edge(edge)
+ # if len_e <= 5:
+ # if len(corner_to_room[edge[0]]) == 1:
+ # prev_i = edge_i - 1
+ # prev_edge = cp_room_edges[prev_i]
+ # cp_room_edges[prev_i] = (prev_edge[0], edge[1])
+ # rm_flag = True
+ # cp_room_edges.pop(edge_i)
+ # break
+ # elif len(corner_to_room[edge[1]]) == 1:
+ # next_i = edge_i + 1 if edge_i < len(cp_room_edges) - 1 else 0
+ # next_edge = cp_room_edges[next_i]
+ # cp_room_edges[next_i] = (edge[0], next_edge[1])
+ # rm_flag = True
+ # cp_room_edges.pop(edge_i)
+ # else:
+ # continue
+ #
+ # if len(cp_room_edges) != len(room_edges):
+ # all_room_edges[room_idx] = cp_room_edges
+
+ return all_room_edges
+
+
+def move_corner(c, target, corner_to_room, all_room_edges):
+ rooms = corner_to_room[c]
+ for room_idx in rooms:
+ for edge_idx, edge in enumerate(all_room_edges[room_idx]):
+ if c in edge:
+ if c == edge[0]:
+ new_edge = (target, edge[1])
+ elif c == edge[1]:
+ new_edge = (edge[0], target)
+ else:
+ continue
+ all_room_edges[room_idx][edge_idx] = new_edge
+ return all_room_edges
+
+
+def find_corners_to_merge(all_corners, corner_to_room, th=5):
+ all_close_pairs = list()
+ for idx1, corner in enumerate(all_corners):
+ for idx2, other_corner in enumerate(all_corners):
+ if idx2 <= idx1:
+ continue
+ if len_edge((corner, other_corner)) <= th:
+ rooms_1 = tuple(sorted(corner_to_room[corner]))
+ rooms_2 = tuple(sorted(corner_to_room[other_corner]))
+ if rooms_1 == rooms_2:
+ continue
+ elif len(rooms_1) ==1:
+ if rooms_1[0] in list(rooms_2):
+ continue
+ else:
+ all_close_pairs.append([corner, other_corner])
+ elif len(rooms_2) ==1:
+ if rooms_2[0] in list(rooms_1):
+ continue
+ else:
+ all_close_pairs.append([corner, other_corner])
+ else:
+ all_close_pairs.append([corner, other_corner])
+
+ if len(all_close_pairs) == 0:
+ return None
+
+ close_set = find_one_close_set(all_close_pairs)
+ corners_to_merge = list(close_set)
+
+ return corners_to_merge
+
+
+def find_one_close_set(all_corner_pairs):
+    all_pairs = list(all_corner_pairs)  # make a copy of the input list
+    combined = [False] * len(all_corner_pairs)
+
+ close_set = _combine_colinear_pairs(0, all_pairs, combined)
+
+ return close_set
+
+
+def get_corner_to_room(all_room_edges):
+ room_paths = [_extract_room_path(room_edges) for room_edges in all_room_edges]
+ corner_to_room = defaultdict(list)
+ for room_idx, room_path in enumerate(room_paths):
+ for corner in room_path:
+ corner_to_room[corner].append(room_idx)
+ return corner_to_room
+
+
+def filter_rooms(all_room_edges, im_size):
+ # filter rooms that are covered by larger rooms
+ room_masks = list()
+ updated_room_edges = list()
+ for room_edges in all_room_edges:
+ room_mask = draw_room_seg_from_edges(room_edges, im_size)
+ if room_mask is not None and room_mask.sum() > 20: # remove too small rooms
+ room_masks.append(room_mask)
+ updated_room_edges.append(room_edges)
+ all_room_edges = updated_room_edges
+
+ removed = list()
+ for room_idx, room_mask in enumerate(room_masks):
+ # do not consider the current room, and do not consider removed rooms
+ other_masks = [room_masks[i] for i in range(len(all_room_edges)) if i != room_idx and i not in removed]
+ if len(other_masks) == 0: # if all other masks are removed..
+ other_masks_all = np.zeros([im_size, im_size])
+ else:
+ other_masks_all = np.clip(np.sum(np.stack(other_masks, axis=-1), axis=-1), 0, 1)
+ joint_mask = np.clip(other_masks_all + room_mask, 0, 1)
+ mask_area = room_mask.sum()
+ overlap_area = mask_area + other_masks_all.sum() - joint_mask.sum()
+ if overlap_area / mask_area > 0.5:
+ removed.append(room_idx)
+
+ all_room_edges = [all_room_edges[idx] for idx in range(len(all_room_edges)) if idx not in removed]
+
+ return all_room_edges
+
+
+## Utils
+
+class WrongRoomError(Exception):
+ pass
+
+def _extract_room_path(room_edges):
+ room_path = [edge[0] for edge in room_edges]
+ return room_path
+
+
+def len_edge(e):
+ return np.sqrt((e[1][0] - e[0][0]) ** 2 + (e[1][1] - e[0][1]) ** 2)
+
+
+def unit_v_edge(e):
+ len_e = len_edge(e)
+ assert len_e != 0
+ unit_v = ((e[1][0] - e[0][0]) / len_e, (e[1][1] - e[0][1]) / len_e)
+ return unit_v
+
+
+def get_intersection(p0, p1, p2, p3):
+ """
+ reference: StackOverflow https://stackoverflow.com/questions/563198/how-do-you-detect-where-two-line-segments-intersect#565282
+ """
+ s1_x = p1[0] - p0[0]
+ s1_y = p1[1] - p0[1]
+ s2_x = p3[0] - p2[0]
+ s2_y = p3[1] - p2[1]
+
+ s = (-s1_y * (p0[0] - p2[0]) + s1_x * (p0[1] - p2[1])) / (-s2_x * s1_y + s1_x * s2_y)
+ t = (s2_x * (p0[1] - p2[1]) - s2_y * (p0[0] - p2[0])) / (-s2_x * s1_y + s1_x * s2_y)
+
+ if 1 >= s >= 0 and 1 >= t >= 0:
+ i_x = p0[0] + (t * s1_x)
+ i_y = p0[1] + (t * s1_y)
+ return (i_x, i_y)
+ else:
+ return None
+
+
+def draw_room_seg_from_edges(edges, im_size):
+ edge_map = np.zeros([im_size, im_size])
+ for edge in edges:
+        edge = np.array(edge).astype(int)  # np.int was removed from recent NumPy versions
+ cv2.line(edge_map, tuple(edge[0]), tuple(edge[1]), 1, 3)
+ reverse_edge_map = 1 - edge_map
+ label, num_features = ndimage.label(reverse_edge_map)
+ if num_features < 2:
+ return None
+ bg_label = label[0, 0]
+ num_labels = [(label==l).sum() for l in range(1, num_features+1)]
+ num_labels[bg_label-1] = 0
+ room_label = np.argmax(num_labels) + 1
+ room_map = np.zeros([im_size, im_size])
+ room_map[np.where(label == room_label)] = 1
+
+ return room_map
+
+
+
+if __name__ == '__main__':
+ data_base = './montefloor_data/'
+ dir_names = list(sorted(os.listdir(data_base)))
+
+ invalid_scenes = list()
+
+ for dir_name in dir_names:
+ if 'scene' not in dir_name:
+ continue
+ data_dir = os.path.join(data_base, dir_name)
+ annot_path = os.path.join(data_dir, 'annotation_3d.json')
+ with open(annot_path) as f:
+ annot = json.load(f)
+ image_path = os.path.join(data_dir, 'density.png')
+
+ try:
+ generate_graph(annot, image_path, data_dir)
+ except WrongRoomError:
+ invalid_scenes.append(dir_name)
+ print('Finish processing data {}'.format(dir_name))
+
+ print('Failed on {} scenes with invalid rooms: {}'.format(len(invalid_scenes), invalid_scenes))
diff --git a/s3d_preprocess/generate_point_cloud.py b/s3d_preprocess/generate_point_cloud.py
new file mode 100644
index 0000000000000000000000000000000000000000..6669d6170c0e6d34dc8b5a408a6f7019e012ec73
--- /dev/null
+++ b/s3d_preprocess/generate_point_cloud.py
@@ -0,0 +1,31 @@
+from DataProcessing.path_variables import *
+from DataProcessing.PointCloudReaderPanorama import PointCloudReaderPanorama
+from DataProcessing.PointCloudReaderPerspective import PointCloudReaderPerspective
+
+if __name__ == "__main__":
+ scenes = [scene_path]
+ print(scenes)
+ for scene in scenes:
+
+ reader = PointCloudReaderPanorama(scene, random_level=0, generate_color=True, generate_normal=False)
+ path = "/home/sinisa/Sinisa_Projects/papers/ICCV21/supp_figures/blender_project/vis/" + scene_name + ".ply"
+ # reader.export_ply(path)
+ density_map = reader.generate_density()
+ reader.visualize(export_path=path)
+
+ # print("Creating point cloud from perspective views...")
+ # reader = PointCloudReaderPerspective(scene, random_level=0, generate_color=True, generate_normal=False,
+ # generate_segmentation=True)
+ # print("Subsampling point cloud...")
+ # o3d_pcd = reader.subsample_pcd(seg=False)
+ # reader.visualize(o3d_pcd)
+ # reader.generate_density()
+
+
+ # print("Writing point cloud...")
+ # reader.export_ply_from_o3d_pcd(scene_ply_path, o3d_pcd, seg=False)
+ #
+ # print("Subsampling segmented point cloud...")
+ # o3d_seg_pcd = reader.subsample_pcd(seg=True)
+ # print("Writing segmented point cloud...")
+ # reader.export_ply_from_o3d_pcd(scene_segmented_ply_path, o3d_seg_pcd, seg=True)
\ No newline at end of file
diff --git a/s3d_preprocess/label_names.txt b/s3d_preprocess/label_names.txt
new file mode 100644
index 0000000000000000000000000000000000000000..733d7647ba73adad55781b4223000c501b2a2b4a
--- /dev/null
+++ b/s3d_preprocess/label_names.txt
@@ -0,0 +1,41 @@
+__ignore__
+wall
+floor
+cabinet
+bed
+chair
+sofa
+table
+door
+window
+bookshelf
+picture
+counter
+blinds
+desk
+shelves
+curtain
+dresser
+pillow
+mirror
+floor mat
+clothes
+ceiling
+book
+refridgerator
+television
+paper
+towel
+shower curtain
+box
+whiteboard
+person
+night stand
+toilet
+sink
+lamp
+bathtub
+bag
+otherstructure
+otherfurniture
+otherprop
\ No newline at end of file
diff --git a/s3d_preprocess/metadata/errata.txt b/s3d_preprocess/metadata/errata.txt
new file mode 100644
index 0000000000000000000000000000000000000000..859568c5df2af217fe6688265bc6db8058dc91d5
--- /dev/null
+++ b/s3d_preprocess/metadata/errata.txt
@@ -0,0 +1,99 @@
+# invalid scene
+scene_01155
+scene_01714
+scene_01816
+scene_03398
+scene_01192
+scene_01852
+# a pair of junctions are not aligned along the x-axis
+scene_01778_room_858455
+# self-intersection layout
+scene_00010_room_846619
+scene_00043_room_1518
+scene_00043_room_3128
+scene_00043_room_474
+scene_00043_room_732
+scene_00043_room_856
+scene_00173_room_4722
+scene_00240_room_384
+scene_00325_room_970753
+scene_00335_room_686
+scene_00339_room_2193
+scene_00501_room_1840
+scene_00515_room_277475
+scene_00543_room_176
+scene_00587_room_9914
+scene_00703_room_762455
+scene_00703_room_771712
+scene_00728_room_5662
+scene_00828_room_607228
+scene_00865_room_1026
+scene_00865_room_1402
+scene_00875_room_739214
+scene_00917_room_188
+scene_00917_room_501284
+scene_00926_room_2290
+scene_00936_room_311
+scene_00937_room_1955
+scene_00986_room_141
+scene_01009_room_3234
+scene_01009_room_3571
+scene_01021_room_689126
+scene_01034_room_222021
+scene_01036_room_301
+scene_01043_room_2193
+scene_01104_room_875
+scene_01151_room_563
+scene_01165_room_204
+scene_01221_room_26619
+scene_01222_room_273364
+scene_01282_room_1917
+scene_01282_room_24057
+scene_01282_room_2631
+scene_01400_room_10576
+scene_01445_room_3495
+scene_01470_room_1413
+scene_01530_room_577
+scene_01670_room_291
+scene_01745_room_342
+scene_01759_room_3584
+scene_01759_room_3588
+scene_01772_room_897997
+scene_01774_room_143
+scene_01781_room_335
+scene_01781_room_878137
+scene_01786_room_5837
+scene_01916_room_2648
+scene_01993_room_849
+scene_01998_room_54762
+scene_02034_room_921879
+scene_02040_room_311
+scene_02046_room_1014
+scene_02046_room_834
+scene_02047_room_934954
+scene_02101_room_255228
+scene_02172_room_335
+scene_02235_room_799012
+scene_02274_room_4093
+scene_02326_room_836436
+scene_02334_room_869673
+scene_02357_room_118319
+scene_02484_room_43003
+scene_02499_room_1607
+scene_02499_room_977359
+scene_02509_room_687231
+scene_02542_room_671853
+scene_02564_room_702502
+scene_02580_room_724891
+scene_02650_room_877946
+scene_02659_room_577142
+scene_02690_room_586296
+scene_02706_room_823368
+scene_02788_room_815473
+scene_02889_room_848271
+scene_03035_room_631066
+scene_03120_room_830640
+scene_03327_room_315045
+scene_03376_room_800900
+scene_03399_room_337
+scene_03478_room_2193
\ No newline at end of file
diff --git a/s3d_preprocess/metadata/labelids.txt b/s3d_preprocess/metadata/labelids.txt
new file mode 100644
index 0000000000000000000000000000000000000000..54950e84e0c9227e4b518420ed8fed8bdb99a731
--- /dev/null
+++ b/s3d_preprocess/metadata/labelids.txt
@@ -0,0 +1,40 @@
+1 wall
+2 floor
+3 cabinet
+4 bed
+5 chair
+6 sofa
+7 table
+8 door
+9 window
+10 bookshelf
+11 picture
+12 counter
+13 blinds
+14 desk
+15 shelves
+16 curtain
+17 dresser
+18 pillow
+19 mirror
+20 floor mat
+21 clothes
+22 ceiling
+23 books
+24 refrigerator
+25 television
+26 paper
+27 towel
+28 shower curtain
+29 box
+30 whiteboard
+31 person
+32 nightstand
+33 toilet
+34 sink
+35 lamp
+36 bathtub
+37 bag
+38 otherstructure
+39 otherfurniture
+40 otherprop
\ No newline at end of file
diff --git a/s3d_preprocess/metadata/room_types.txt b/s3d_preprocess/metadata/room_types.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c244c5da7fe3f24be2c17ac829232be95b9853b2
--- /dev/null
+++ b/s3d_preprocess/metadata/room_types.txt
@@ -0,0 +1,16 @@
+living room
+kitchen
+bedroom
+bathroom
+balcony
+corridor
+dining room
+study
+studio
+store room
+garden
+laundry room
+office
+basement
+garage
+undefined
diff --git a/s3d_preprocess/misc/__init__.py b/s3d_preprocess/misc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/s3d_preprocess/misc/colors.py b/s3d_preprocess/misc/colors.py
new file mode 100644
index 0000000000000000000000000000000000000000..cda6bf0e11b8078fc9c4478314bbcb7622392027
--- /dev/null
+++ b/s3d_preprocess/misc/colors.py
@@ -0,0 +1,47 @@
+semantics_cmap = {
+ 'living room': '#e6194b',
+ 'kitchen': '#3cb44b',
+ 'bedroom': '#ffe119',
+ 'bathroom': '#0082c8',
+ 'balcony': '#f58230',
+ 'corridor': '#911eb4',
+ 'dining room': '#46f0f0',
+ 'study': '#f032e6',
+ 'studio': '#d2f53c',
+ 'store room': '#fabebe',
+ 'garden': '#008080',
+ 'laundry room': '#e6beff',
+ 'office': '#aa6e28',
+ 'basement': '#fffac8',
+ 'garage': '#800000',
+ 'undefined': '#aaffc3',
+ 'door': '#808000',
+ 'window': '#ffd7b4',
+ 'outwall': '#000000',
+}
+
+
+colormap_255 = [
+ [230, 25, 75],
+ [ 60, 180, 75],
+ [255, 225, 25],
+ [ 0, 130, 200],
+ [245, 130, 48],
+ [145, 30, 180],
+ [ 70, 240, 240],
+ [240, 50, 230],
+ [210, 245, 60],
+ [250, 190, 190],
+ [ 0, 128, 128],
+ [230, 190, 255],
+ [170, 110, 40],
+ [255, 250, 200],
+ [128, 0, 0],
+ [170, 255, 195],
+ [128, 128, 0],
+ [255, 215, 180],
+ [ 0, 0, 128],
+ [128, 128, 128],
+ [255, 255, 255],
+ [ 0, 0, 0]
+]
diff --git a/s3d_preprocess/misc/figures.py b/s3d_preprocess/misc/figures.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb90a3d0e0d4a098adf1bede66a8f4d3babf600c
--- /dev/null
+++ b/s3d_preprocess/misc/figures.py
@@ -0,0 +1,78 @@
+"""
+Copy from https://github.com/Toblerity/Shapely/blob/master/docs/code/figures.py
+"""
+
+from math import sqrt
+from shapely import affinity
+
+GM = (sqrt(5)-1.0)/2.0
+W = 8.0
+H = W*GM
+SIZE = (W, H)
+
+BLUE = '#6699cc'
+GRAY = '#999999'
+DARKGRAY = '#333333'
+YELLOW = '#ffcc33'
+GREEN = '#339933'
+RED = '#ff3333'
+BLACK = '#000000'
+
+COLOR_ISVALID = {
+ True: BLUE,
+ False: RED,
+}
+
+
+def plot_line(ax, ob, color=GRAY, zorder=1, linewidth=3, alpha=1):
+ x, y = ob.xy
+ ax.plot(x, y, color=color, linewidth=linewidth, solid_capstyle='round', zorder=zorder, alpha=alpha)
+
+
+def plot_coords(ax, ob, color=BLACK, zorder=1, alpha=1):
+ x, y = ob.xy
+ ax.plot(x, y, color=color, zorder=zorder, alpha=alpha)
+
+
+def color_isvalid(ob, valid=BLUE, invalid=RED):
+ if ob.is_valid:
+ return valid
+ else:
+ return invalid
+
+
+def color_issimple(ob, simple=BLUE, complex=YELLOW):
+ if ob.is_simple:
+ return simple
+ else:
+ return complex
+
+
+def plot_line_isvalid(ax, ob, **kwargs):
+ kwargs["color"] = color_isvalid(ob)
+ plot_line(ax, ob, **kwargs)
+
+
+def plot_line_issimple(ax, ob, **kwargs):
+ kwargs["color"] = color_issimple(ob)
+ plot_line(ax, ob, **kwargs)
+
+
+def plot_bounds(ax, ob, zorder=1, alpha=1):
+ x, y = zip(*list((p.x, p.y) for p in ob.boundary))
+ ax.plot(x, y, 'o', color=BLACK, zorder=zorder, alpha=alpha)
+
+
+def add_origin(ax, geom, origin):
+ x, y = xy = affinity.interpret_origin(geom, origin, 2)
+ ax.plot(x, y, 'o', color=GRAY, zorder=1)
+ ax.annotate(str(xy), xy=xy, ha='center',
+ textcoords='offset points', xytext=(0, 8))
+
+
+def set_limits(ax, x0, xN, y0, yN):
+ ax.set_xlim(x0, xN)
+ ax.set_xticks(range(x0, xN+1))
+ ax.set_ylim(y0, yN)
+ ax.set_yticks(range(y0, yN+1))
+ ax.set_aspect("equal")
diff --git a/s3d_preprocess/misc/panorama.py b/s3d_preprocess/misc/panorama.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcbb17d26d5acc1b672656c6ee8a85c151603e6f
--- /dev/null
+++ b/s3d_preprocess/misc/panorama.py
@@ -0,0 +1,243 @@
+"""
+Copy from https://github.com/sunset1995/pytorch-layoutnet/blob/master/pano.py
+"""
+import numpy as np
+import numpy.matlib as matlib
+
+
+def xyz_2_coorxy(xs, ys, zs, H=512, W=1024):
+ us = np.arctan2(xs, ys)
+ vs = -np.arctan(zs / np.sqrt(xs**2 + ys**2))
+ coorx = (us / (2 * np.pi) + 0.5) * W
+ coory = (vs / np.pi + 0.5) * H
+ return coorx, coory
+
+
+def coords2uv(coords, width, height):
+ """
+ Image coordinates (xy) to uv
+ """
+ middleX = width / 2 + 0.5
+ middleY = height / 2 + 0.5
+ uv = np.hstack([
+ (coords[:, [0]] - middleX) / width * 2 * np.pi,
+ -(coords[:, [1]] - middleY) / height * np.pi])
+ return uv
+
+
+def uv2xyzN(uv, planeID=1):
+ ID1 = (int(planeID) - 1 + 0) % 3
+ ID2 = (int(planeID) - 1 + 1) % 3
+ ID3 = (int(planeID) - 1 + 2) % 3
+ xyz = np.zeros((uv.shape[0], 3))
+ xyz[:, ID1] = np.cos(uv[:, 1]) * np.sin(uv[:, 0])
+ xyz[:, ID2] = np.cos(uv[:, 1]) * np.cos(uv[:, 0])
+ xyz[:, ID3] = np.sin(uv[:, 1])
+ return xyz
+
+
+def uv2xyzN_vec(uv, planeID):
+ """
+ vectorization version of uv2xyzN
+ @uv N x 2
+ @planeID N
+ """
+ assert (planeID.astype(int) != planeID).sum() == 0
+ planeID = planeID.astype(int)
+ ID1 = (planeID - 1 + 0) % 3
+ ID2 = (planeID - 1 + 1) % 3
+ ID3 = (planeID - 1 + 2) % 3
+ ID = np.arange(len(uv))
+ xyz = np.zeros((len(uv), 3))
+ xyz[ID, ID1] = np.cos(uv[:, 1]) * np.sin(uv[:, 0])
+ xyz[ID, ID2] = np.cos(uv[:, 1]) * np.cos(uv[:, 0])
+ xyz[ID, ID3] = np.sin(uv[:, 1])
+ return xyz
+
+
+def xyz2uvN(xyz, planeID=1):
+ ID1 = (int(planeID) - 1 + 0) % 3
+ ID2 = (int(planeID) - 1 + 1) % 3
+ ID3 = (int(planeID) - 1 + 2) % 3
+ normXY = np.sqrt(xyz[:, [ID1]] ** 2 + xyz[:, [ID2]] ** 2)
+ normXY[normXY < 0.000001] = 0.000001
+ normXYZ = np.sqrt(xyz[:, [ID1]] ** 2 + xyz[:, [ID2]] ** 2 + xyz[:, [ID3]] ** 2)
+ v = np.arcsin(xyz[:, [ID3]] / normXYZ)
+ u = np.arcsin(xyz[:, [ID1]] / normXY)
+ valid = (xyz[:, [ID2]] < 0) & (u >= 0)
+ u[valid] = np.pi - u[valid]
+ valid = (xyz[:, [ID2]] < 0) & (u <= 0)
+ u[valid] = -np.pi - u[valid]
+ uv = np.hstack([u, v])
+ uv[np.isnan(uv[:, 0]), 0] = 0
+ return uv
+
+
+def computeUVN(n, in_, planeID):
+ """
+ compute v given u and normal.
+ """
+ if planeID == 2:
+ n = np.array([n[1], n[2], n[0]])
+ elif planeID == 3:
+ n = np.array([n[2], n[0], n[1]])
+ bc = n[0] * np.sin(in_) + n[1] * np.cos(in_)
+ bs = n[2]
+ out = np.arctan(-bc / (bs + 1e-9))
+ return out
+
+
+def computeUVN_vec(n, in_, planeID):
+ """
+ vectorization version of computeUVN
+ @n N x 3
+ @in_ MN x 1
+ @planeID N
+ """
+ n = n.copy()
+ if (planeID == 2).sum():
+ n[planeID == 2] = np.roll(n[planeID == 2], 2, axis=1)
+ if (planeID == 3).sum():
+ n[planeID == 3] = np.roll(n[planeID == 3], 1, axis=1)
+ n = np.repeat(n, in_.shape[0] // n.shape[0], axis=0)
+ assert n.shape[0] == in_.shape[0]
+ bc = n[:, [0]] * np.sin(in_) + n[:, [1]] * np.cos(in_)
+ bs = n[:, [2]]
+ out = np.arctan(-bc / (bs + 1e-9))
+ return out
+
+
+def lineFromTwoPoint(pt1, pt2):
+ """
+ Generate line segment based on two points on panorama
+ pt1, pt2: two points on panorama
+ line:
+ 1~3-th dim: normal of the line
+ 4-th dim: the projection dimension ID
+ 5~6-th dim: the u of line segment endpoints in projection plane
+ """
+ numLine = pt1.shape[0]
+ lines = np.zeros((numLine, 6))
+ n = np.cross(pt1, pt2)
+ n = n / (matlib.repmat(np.sqrt(np.sum(n ** 2, 1, keepdims=True)), 1, 3) + 1e-9)
+ lines[:, 0:3] = n
+
+ areaXY = np.abs(np.sum(n * matlib.repmat([0, 0, 1], numLine, 1), 1, keepdims=True))
+ areaYZ = np.abs(np.sum(n * matlib.repmat([1, 0, 0], numLine, 1), 1, keepdims=True))
+ areaZX = np.abs(np.sum(n * matlib.repmat([0, 1, 0], numLine, 1), 1, keepdims=True))
+ planeIDs = np.argmax(np.hstack([areaXY, areaYZ, areaZX]), axis=1) + 1
+ lines[:, 3] = planeIDs
+
+ for i in range(numLine):
+ uv = xyz2uvN(np.vstack([pt1[i, :], pt2[i, :]]), lines[i, 3])
+ umax = uv[:, 0].max() + np.pi
+ umin = uv[:, 0].min() + np.pi
+ if umax - umin > np.pi:
+ lines[i, 4:6] = np.array([umax, umin]) / 2 / np.pi
+ else:
+ lines[i, 4:6] = np.array([umin, umax]) / 2 / np.pi
+
+ return lines
+
+
+def lineIdxFromCors(cor_all, im_w, im_h):
+ assert len(cor_all) % 2 == 0
+ uv = coords2uv(cor_all, im_w, im_h)
+ xyz = uv2xyzN(uv)
+ lines = lineFromTwoPoint(xyz[0::2], xyz[1::2])
+ num_sample = max(im_h, im_w)
+
+ cs, rs = [], []
+ for i in range(lines.shape[0]):
+ n = lines[i, 0:3]
+ sid = lines[i, 4] * 2 * np.pi
+ eid = lines[i, 5] * 2 * np.pi
+ if eid < sid:
+ x = np.linspace(sid, eid + 2 * np.pi, num_sample)
+ x = x % (2 * np.pi)
+ else:
+ x = np.linspace(sid, eid, num_sample)
+
+ u = -np.pi + x.reshape(-1, 1)
+ v = computeUVN(n, u, lines[i, 3])
+ xyz = uv2xyzN(np.hstack([u, v]), lines[i, 3])
+ uv = xyz2uvN(xyz, 1)
+
+ r = np.minimum(np.floor((uv[:, 0] + np.pi) / (2 * np.pi) * im_w) + 1,
+ im_w).astype(np.int32)
+ c = np.minimum(np.floor((np.pi / 2 - uv[:, 1]) / np.pi * im_h) + 1,
+ im_h).astype(np.int32)
+ cs.extend(r - 1)
+ rs.extend(c - 1)
+ return rs, cs
+
+
+def draw_boundary_from_cor_id(cor_id, img_src):
+ im_h, im_w = img_src.shape[:2]
+ cor_all = [cor_id]
+ for i in range(len(cor_id)):
+ cor_all.append(cor_id[i, :])
+ cor_all.append(cor_id[(i+2) % len(cor_id), :])
+ cor_all = np.vstack(cor_all)
+
+ rs, cs = lineIdxFromCors(cor_all, im_w, im_h)
+ rs = np.array(rs)
+ cs = np.array(cs)
+
+ panoEdgeC = img_src.astype(np.uint8)
+ for dx, dy in [[-1, 0], [1, 0], [0, 0], [0, 1], [0, -1]]:
+ panoEdgeC[np.clip(rs + dx, 0, im_h - 1), np.clip(cs + dy, 0, im_w - 1), 0] = 0
+ panoEdgeC[np.clip(rs + dx, 0, im_h - 1), np.clip(cs + dy, 0, im_w - 1), 1] = 0
+ panoEdgeC[np.clip(rs + dx, 0, im_h - 1), np.clip(cs + dy, 0, im_w - 1), 2] = 255
+
+ return panoEdgeC
+
+
+def coorx2u(x, w=1024):
+ return ((x + 0.5) / w - 0.5) * 2 * np.pi
+
+
+def coory2v(y, h=512):
+ return ((y + 0.5) / h - 0.5) * np.pi
+
+
+def u2coorx(u, w=1024):
+ return (u / (2 * np.pi) + 0.5) * w - 0.5
+
+
+def v2coory(v, h=512):
+ return (v / np.pi + 0.5) * h - 0.5
+
+
+def uv2xy(u, v, z=-50):
+ c = z / np.tan(v)
+ x = c * np.cos(u)
+ y = c * np.sin(u)
+ return x, y
+
+
+def pano_connect_points(p1, p2, z=-50, w=1024, h=512):
+ u1 = coorx2u(p1[0], w)
+ v1 = coory2v(p1[1], h)
+ u2 = coorx2u(p2[0], w)
+ v2 = coory2v(p2[1], h)
+
+ x1, y1 = uv2xy(u1, v1, z)
+ x2, y2 = uv2xy(u2, v2, z)
+
+ if abs(p1[0] - p2[0]) < w / 2:
+ pstart = np.ceil(min(p1[0], p2[0]))
+ pend = np.floor(max(p1[0], p2[0]))
+ else:
+ pstart = np.ceil(max(p1[0], p2[0]))
+ pend = np.floor(min(p1[0], p2[0]) + w)
+ coorxs = (np.arange(pstart, pend + 1) % w).astype(np.float64)
+ vx = x2 - x1
+ vy = y2 - y1
+ us = coorx2u(coorxs, w)
+ ps = (np.tan(us) * x1 - y1) / (vy - np.tan(us) * vx)
+ cs = np.sqrt((x1 + ps * vx) ** 2 + (y1 + ps * vy) ** 2)
+ vs = np.arctan2(z, cs)
+ coorys = v2coory(vs)
+
+ return np.stack([coorxs, coorys], axis=-1)
diff --git a/s3d_preprocess/misc/utils.py b/s3d_preprocess/misc/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..14ad31780f17a657990d47f034a9a36dbc5b1744
--- /dev/null
+++ b/s3d_preprocess/misc/utils.py
@@ -0,0 +1,145 @@
+"""
+Adapted from https://github.com/thusiyuan/cooperative_scene_parsing/blob/master/utils/sunrgbd_utils.py
+"""
+import numpy as np
+
+
+def normalize(vector):
+ return vector / np.linalg.norm(vector)
+
+
+def parse_camera_info(camera_info, height, width, inverse=False):
+ """ extract intrinsic and extrinsic matrix
+ """
+ lookat = normalize(camera_info[3:6])
+ up = normalize(camera_info[6:9])
+
+ W = lookat
+ U = np.cross(W, up)
+ V = - np.cross(W, U)
+
+ if inverse:
+ rot = np.linalg.inv(np.vstack((U, -V, W)))
+ trans = camera_info[:3]
+ else:
+ rot = np.vstack((U, V, W))
+ trans = camera_info[:3]
+
+ xfov = camera_info[9]
+ yfov = camera_info[10]
+
+ K = np.diag([1, 1, 1]).astype(np.float32)
+
+ K[0, 2] = (width) / 2.
+ K[1, 2] = (height) / 2.
+
+ K[0, 0] = K[0, 2] / np.tan(xfov)
+ K[1, 1] = K[1, 2] / np.tan(yfov)
+
+# tan_half_fov = window_height_ / (intrinsic.intrinsic_matrix_(1, 1) * 2.0);
+# fov_rad = std::atan(tan_half_fov) * 2.0;
+
+ return rot, trans, K
+
+
+def flip_towards_viewer(normals, points):
+ points = points / np.linalg.norm(points)
+ proj = points.dot(normals[:2, :].T)
+ flip = np.where(proj > 0)
+ normals[flip, :] = -normals[flip, :]
+ return normals
+
+
+def get_corners_of_bb3d(basis, coeffs, centroid):
+ corners = np.zeros((8, 3))
+ # order the basis
+ index = np.argsort(np.abs(basis[:, 0]))[::-1]
+    # handle ties where two components have the same absolute value
+ if index[2] != 2:
+ index[1:] = index[1:][::-1]
+ basis = basis[index, :]
+ coeffs = coeffs[index]
+    # Now we know the basis vectors are ordered as X, Y, Z. Next, flip the basis vectors towards the viewer
+ basis = flip_towards_viewer(basis, centroid)
+ coeffs = np.abs(coeffs)
+ corners[0, :] = -basis[0, :] * coeffs[0] + basis[1, :] * coeffs[1] + basis[2, :] * coeffs[2]
+ corners[1, :] = basis[0, :] * coeffs[0] + basis[1, :] * coeffs[1] + basis[2, :] * coeffs[2]
+ corners[2, :] = basis[0, :] * coeffs[0] + -basis[1, :] * coeffs[1] + basis[2, :] * coeffs[2]
+ corners[3, :] = -basis[0, :] * coeffs[0] + -basis[1, :] * coeffs[1] + basis[2, :] * coeffs[2]
+
+ corners[4, :] = -basis[0, :] * coeffs[0] + basis[1, :] * coeffs[1] + -basis[2, :] * coeffs[2]
+ corners[5, :] = basis[0, :] * coeffs[0] + basis[1, :] * coeffs[1] + -basis[2, :] * coeffs[2]
+ corners[6, :] = basis[0, :] * coeffs[0] + -basis[1, :] * coeffs[1] + -basis[2, :] * coeffs[2]
+ corners[7, :] = -basis[0, :] * coeffs[0] + -basis[1, :] * coeffs[1] + -basis[2, :] * coeffs[2]
+ corners = corners + np.tile(centroid, (8, 1))
+ return corners
+
+
+def get_corners_of_bb3d_no_index(basis, coeffs, centroid):
+ corners = np.zeros((8, 3))
+ coeffs = np.abs(coeffs)
+ corners[0, :] = -basis[0, :] * coeffs[0] + basis[1, :] * coeffs[1] + basis[2, :] * coeffs[2]
+ corners[1, :] = basis[0, :] * coeffs[0] + basis[1, :] * coeffs[1] + basis[2, :] * coeffs[2]
+ corners[2, :] = basis[0, :] * coeffs[0] + -basis[1, :] * coeffs[1] + basis[2, :] * coeffs[2]
+ corners[3, :] = -basis[0, :] * coeffs[0] + -basis[1, :] * coeffs[1] + basis[2, :] * coeffs[2]
+
+ corners[4, :] = -basis[0, :] * coeffs[0] + basis[1, :] * coeffs[1] + -basis[2, :] * coeffs[2]
+ corners[5, :] = basis[0, :] * coeffs[0] + basis[1, :] * coeffs[1] + -basis[2, :] * coeffs[2]
+ corners[6, :] = basis[0, :] * coeffs[0] + -basis[1, :] * coeffs[1] + -basis[2, :] * coeffs[2]
+ corners[7, :] = -basis[0, :] * coeffs[0] + -basis[1, :] * coeffs[1] + -basis[2, :] * coeffs[2]
+
+ corners = corners + np.tile(centroid, (8, 1))
+ return corners
+
+
+def project_3d_points_to_2d(points3d, R_ex, K):
+ """
+ Project 3d points from camera-centered coordinate to 2D image plane
+ Parameters
+ ----------
+ points3d: numpy array
+ 3d location of point
+ R_ex: numpy array
+ extrinsic camera parameter
+ K: numpy array
+ intrinsic camera parameter
+ Returns
+ -------
+ points2d: numpy array
+ 2d location of the point
+ """
+ points3d = R_ex.dot(points3d.T).T
+ x3 = points3d[:, 0]
+ y3 = -points3d[:, 1]
+ z3 = np.abs(points3d[:, 2])
+ xx = x3 * K[0, 0] / z3 + K[0, 2]
+ yy = y3 * K[1, 1] / z3 + K[1, 2]
+ points2d = np.vstack((xx, yy))
+ return points2d
+
+
+def project_struct_bdb_to_2d(basis, coeffs, center, R_ex, K):
+ """
+ Project 3d bounding box to 2d bounding box
+ Parameters
+ ----------
+ basis, coeffs, center, R_ex, K
+ : K is the intrinsic camera parameter matrix
+ : Rtilt is the extrinsic camera parameter matrix in right hand coordinates
+ Returns
+ -------
+ bdb2d: dict
+ Keys: {'x1', 'x2', 'y1', 'y2'}
+ The (x1, y1) position is at the top left corner,
+ the (x2, y2) position is at the bottom right corner
+ """
+ corners3d = get_corners_of_bb3d(basis, coeffs, center)
+ corners = project_3d_points_to_2d(corners3d, R_ex, K)
+ bdb2d = dict()
+ bdb2d['x1'] = int(max(np.min(corners[0, :]), 1)) # x1
+ bdb2d['y1'] = int(max(np.min(corners[1, :]), 1)) # y1
+ bdb2d['x2'] = int(min(np.max(corners[0, :]), 2*K[0, 2])) # x2
+ bdb2d['y2'] = int(min(np.max(corners[1, :]), 2*K[1, 2])) # y2
+ # if not check_bdb(bdb2d, 2*K[0, 2], 2*K[1, 2]):
+ # bdb2d = None
+ return bdb2d
diff --git a/s3d_preprocess/my_visualize_3d.py b/s3d_preprocess/my_visualize_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd9524de20662bb394301563b27132e61e436e59
--- /dev/null
+++ b/s3d_preprocess/my_visualize_3d.py
@@ -0,0 +1,398 @@
+import os
+import json
+import argparse
+
+import open3d
+import pymesh
+import numpy as np
+import matplotlib.pyplot as plt
+from shapely.geometry import Polygon
+from descartes.patch import PolygonPatch
+
+from misc.figures import plot_coords
+from misc.colors import colormap_255, semantics_cmap
+
+
+def visualize_wireframe(annos):
+ """visualize wireframe
+ """
+ colormap = np.array(colormap_255) / 255
+
+ junctions = np.array([item['coordinate'] for item in annos['junctions']])
+ _, junction_pairs = np.where(np.array(annos['lineJunctionMatrix']))
+ junction_pairs = junction_pairs.reshape(-1, 2)
+
+ # extract hole lines
+ lines_holes = []
+ for semantic in annos['semantics']:
+ if semantic['type'] in ['window', 'door']:
+ for planeID in semantic['planeID']:
+ lines_holes.extend(np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist())
+ lines_holes = np.unique(lines_holes)
+
+ # extract cuboid lines
+ cuboid_lines = []
+ for cuboid in annos['cuboids']:
+ for planeID in cuboid['planeID']:
+ cuboid_lineID = np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist()
+ cuboid_lines.extend(cuboid_lineID)
+ cuboid_lines = np.unique(cuboid_lines)
+ cuboid_lines = np.setdiff1d(cuboid_lines, lines_holes)
+
+ # visualize junctions
+ connected_junctions = junctions[np.unique(junction_pairs)]
+ connected_colors = np.repeat(colormap[0].reshape(1, 3), len(connected_junctions), axis=0)
+
+ junction_set = open3d.geometry.PointCloud()
+ junction_set.points = open3d.utility.Vector3dVector(connected_junctions)
+ junction_set.colors = open3d.utility.Vector3dVector(connected_colors)
+
+ # visualize line segments
+ line_colors = np.repeat(colormap[5].reshape(1, 3), len(junction_pairs), axis=0)
+
+ # color holes
+ if len(lines_holes) != 0:
+ line_colors[lines_holes] = colormap[6]
+
+ # color cuboids
+ if len(cuboid_lines) != 0:
+ line_colors[cuboid_lines] = colormap[2]
+
+ line_set = open3d.geometry.LineSet()
+ line_set.points = open3d.utility.Vector3dVector(junctions)
+ line_set.lines = open3d.utility.Vector2iVector(junction_pairs)
+ line_set.colors = open3d.utility.Vector3dVector(line_colors)
+
+ open3d.visualization.draw_geometries([junction_set, line_set])
+
+
+def project(x, meta):
+ """ project 3D to 2D for polygon clipping
+ """
+ proj_axis = max(range(3), key=lambda i: abs(meta['normal'][i]))
+
+ return tuple(c for i, c in enumerate(x) if i != proj_axis)
+
+
+def project_inv(x, meta):
+ """ recover 3D points from 2D
+ """
+ # Returns the vector w in the walls' plane such that project(w) equals x.
+ proj_axis = max(range(3), key=lambda i: abs(meta['normal'][i]))
+
+ w = list(x)
+ w[proj_axis:proj_axis] = [0.0]
+ c = -meta['offset']
+ for i in range(3):
+ c -= w[i] * meta['normal'][i]
+ c /= meta['normal'][proj_axis]
+ w[proj_axis] = c
+ return tuple(w)
+
+
+def triangulate(points):
+ """ triangulate the plane for operation and visualization
+ """
+
+ num_points = len(points)
+    indices = np.arange(num_points, dtype=int)  # np.int was removed from recent NumPy versions
+ segments = np.vstack((indices, np.roll(indices, -1))).T
+
+ tri = pymesh.triangle()
+ tri.points = np.array(points)
+
+ tri.segments = segments
+ tri.verbosity = 0
+ tri.run()
+
+ return tri.mesh
+
+
+def clip_polygon(polygons, vertices_hole, junctions, meta):
+ """ clip polygon the hole
+ """
+ if len(polygons) == 1:
+ junctions = [junctions[vertex] for vertex in polygons[0]]
+ mesh_wall = triangulate(junctions)
+
+ vertices = np.array(mesh_wall.vertices)
+ faces = np.array(mesh_wall.faces)
+
+ return vertices, faces
+
+ else:
+ wall = []
+ holes = []
+ for polygon in polygons:
+ if np.any(np.intersect1d(polygon, vertices_hole)):
+ holes.append(polygon)
+ else:
+ wall.append(polygon)
+
+ # extract junctions on this plane
+ indices = []
+ junctions_wall = []
+ for plane in wall:
+ for vertex in plane:
+ indices.append(vertex)
+ junctions_wall.append(junctions[vertex])
+
+ junctions_holes = []
+ for plane in holes:
+ junctions_hole = []
+ for vertex in plane:
+ indices.append(vertex)
+ junctions_hole.append(junctions[vertex])
+ junctions_holes.append(junctions_hole)
+
+ junctions_wall = [project(x, meta) for x in junctions_wall]
+ junctions_holes = [[project(x, meta) for x in junctions_hole] for junctions_hole in junctions_holes]
+
+ mesh_wall = triangulate(junctions_wall)
+
+ for hole in junctions_holes:
+ mesh_hole = triangulate(hole)
+ mesh_wall = pymesh.boolean(mesh_wall, mesh_hole, 'difference')
+
+ vertices = [project_inv(vertex, meta) for vertex in mesh_wall.vertices]
+
+ return vertices, np.array(mesh_wall.faces)
+
+
+def draw_geometries_with_back_face(geometries):
+ vis = open3d.visualization.Visualizer()
+ vis.create_window()
+ render_option = vis.get_render_option()
+ render_option.mesh_show_back_face = True
+ for geometry in geometries:
+ vis.add_geometry(geometry)
+ vis.run()
+ vis.destroy_window()
+
+
+def convert_lines_to_vertices(lines):
+ """convert line representation to polygon vertices
+ """
+ polygons = []
+ lines = np.array(lines)
+
+ polygon = None
+ while len(lines) != 0:
+ if polygon is None:
+ polygon = lines[0].tolist()
+ lines = np.delete(lines, 0, 0)
+
+ lineID, juncID = np.where(lines == polygon[-1])
+ vertex = lines[lineID[0], 1 - juncID[0]]
+ lines = np.delete(lines, lineID, 0)
+
+ if vertex in polygon:
+ polygons.append(polygon)
+ polygon = None
+ else:
+ polygon.append(vertex)
+
+ return polygons
+
+
+def visualize_plane(annos, args, eps=0.9):
+ """visualize plane
+ """
+ colormap = np.array(colormap_255) / 255
+ junctions = [item['coordinate'] for item in annos['junctions']]
+
+ if args.color == 'manhattan':
+ manhattan = dict()
+ for planes in annos['manhattan']:
+ for planeID in planes['planeID']:
+ manhattan[planeID] = planes['ID']
+
+ # extract hole vertices
+ lines_holes = []
+ for semantic in annos['semantics']:
+ if semantic['type'] in ['window', 'door']:
+ for planeID in semantic['planeID']:
+ lines_holes.extend(np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist())
+
+ lines_holes = np.unique(lines_holes)
+ _, vertices_holes = np.where(np.array(annos['lineJunctionMatrix'])[lines_holes])
+ vertices_holes = np.unique(vertices_holes)
+
+ # load polygons
+ polygons = []
+ for semantic in annos['semantics']:
+ for planeID in semantic['planeID']:
+ plane_anno = annos['planes'][planeID]
+ lineIDs = np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist()
+ junction_pairs = [np.where(np.array(annos['lineJunctionMatrix'][lineID]))[0].tolist() for lineID in lineIDs]
+ polygon = convert_lines_to_vertices(junction_pairs)
+ vertices, faces = clip_polygon(polygon, vertices_holes, junctions, plane_anno)
+ polygons.append([vertices, faces, planeID, plane_anno['normal'], plane_anno['type'], semantic['type']])
+
+ plane_set = []
+ plane_set_types = []
+ for i, (vertices, faces, planeID, normal, plane_type, semantic_type) in enumerate(polygons):
+ # ignore the room ceiling
+ # if semantic_type not in ['door', 'window']:
+ # if plane_type == 'ceiling' and semantic_type not in ['door', 'window']:
+ if semantic_type in ['door', 'window']:
+ continue
+
+ plane_vis = open3d.geometry.TriangleMesh()
+
+ plane_vis.vertices = open3d.utility.Vector3dVector(vertices)
+ plane_vis.triangles = open3d.utility.Vector3iVector(faces)
+
+ if args.color == 'normal':
+ if np.dot(normal, [1, 0, 0]) > eps:
+ plane_vis.paint_uniform_color(colormap[0])
+ elif np.dot(normal, [-1, 0, 0]) > eps:
+ plane_vis.paint_uniform_color(colormap[1])
+ elif np.dot(normal, [0, 1, 0]) > eps:
+ plane_vis.paint_uniform_color(colormap[2])
+ elif np.dot(normal, [0, -1, 0]) > eps:
+ plane_vis.paint_uniform_color(colormap[3])
+ elif np.dot(normal, [0, 0, 1]) > eps:
+ plane_vis.paint_uniform_color(colormap[4])
+ elif np.dot(normal, [0, 0, -1]) > eps:
+ plane_vis.paint_uniform_color(colormap[5])
+ else:
+ plane_vis.paint_uniform_color(colormap[6])
+ elif args.color == 'manhattan':
+ # paint each plane with manhattan world
+ if planeID not in manhattan.keys():
+ plane_vis.paint_uniform_color(colormap[6])
+ else:
+ plane_vis.paint_uniform_color(colormap[manhattan[planeID]])
+
+ plane_set.append(plane_vis)
+ plane_set_types.append(plane_type)
+
+ draw_geometries_with_back_face(plane_set)
+
+ save_path = args.savepath
+ if save_path:
+ import sem_seg_utils as ss
+
+ triangle_pcd_np_list = []
+ triangle_pcd_colors_np_list = []
+ for plane_triangle, plane_type in zip(plane_set, plane_set_types):
+ triangle_pcd = plane_triangle.sample_points_uniformly(number_of_points=100)
+ triangle_pcd_np = np.array(triangle_pcd.points)
+ triangle_pcd_color_np = np.ones_like(triangle_pcd_np) * ss.class_name_to_id[plane_type]
+ triangle_pcd_np_list.append(triangle_pcd_np)
+ triangle_pcd_colors_np_list.append(triangle_pcd_color_np)
+
+ final_pcd_np = np.concatenate(triangle_pcd_np_list, axis=0)
+ final_pcd_colors_np = np.concatenate(triangle_pcd_colors_np_list, axis=0)
+ o3d_final_pcd = open3d.geometry.PointCloud()
+ o3d_final_pcd.points = open3d.utility.Vector3dVector(final_pcd_np)
+ o3d_final_pcd.colors = open3d.utility.Vector3dVector(final_pcd_colors_np / 255.)
+
+ scene_name = "scene_" + "{0:0=5}".format(args.scene)
+
+ file_name = save_path + "/" + scene_name + "/" + scene_name + "_segmented.ply"
+ open3d.io.write_point_cloud(file_name, o3d_final_pcd)
+
+
+
+def plot_floorplan(annos, polygons):
+ """plot floorplan
+ """
+ fig = plt.figure()
+ ax = fig.add_subplot(1, 1, 1)
+
+ junctions = np.array([junc['coordinate'][:2] for junc in annos['junctions']])
+ for (polygon, poly_type) in polygons:
+ polygon = Polygon(junctions[np.array(polygon)])
+ plot_coords(ax, polygon.exterior, alpha=0.5)
+ if poly_type == 'outwall':
+ patch = PolygonPatch(polygon, facecolor=semantics_cmap[poly_type], alpha=0)
+ else:
+ patch = PolygonPatch(polygon, facecolor=semantics_cmap[poly_type], alpha=0.5)
+ ax.add_patch(patch)
+
+ plt.axis('equal')
+ plt.axis('off')
+ plt.show()
+
+
+def visualize_floorplan(annos):
+ """visualize floorplan
+ """
+ # extract the floor in each semantic for floorplan visualization
+ planes = []
+ for semantic in annos['semantics']:
+ for planeID in semantic['planeID']:
+ if annos['planes'][planeID]['type'] == 'floor':
+ planes.append({'planeID': planeID, 'type': semantic['type']})
+
+ if semantic['type'] == 'outwall':
+ outerwall_planes = semantic['planeID']
+
+ # extract hole vertices
+ lines_holes = []
+ for semantic in annos['semantics']:
+ if semantic['type'] in ['window', 'door']:
+ for planeID in semantic['planeID']:
+ lines_holes.extend(np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist())
+ lines_holes = np.unique(lines_holes)
+
+ # junctions on the floor
+ junctions = np.array([junc['coordinate'] for junc in annos['junctions']])
+ junction_floor = np.where(np.isclose(junctions[:, -1], 0))[0]
+
+ # construct each polygon
+ polygons = []
+ for plane in planes:
+ lineIDs = np.where(np.array(annos['planeLineMatrix'][plane['planeID']]))[0].tolist()
+ junction_pairs = [np.where(np.array(annos['lineJunctionMatrix'][lineID]))[0].tolist() for lineID in lineIDs]
+ polygon = convert_lines_to_vertices(junction_pairs)
+ polygons.append([polygon[0], plane['type']])
+
+ outerwall_floor = []
+ for planeID in outerwall_planes:
+ lineIDs = np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist()
+ lineIDs = np.setdiff1d(lineIDs, lines_holes)
+ junction_pairs = [np.where(np.array(annos['lineJunctionMatrix'][lineID]))[0].tolist() for lineID in lineIDs]
+ for start, end in junction_pairs:
+ if start in junction_floor and end in junction_floor:
+ outerwall_floor.append([start, end])
+
+ outerwall_polygon = convert_lines_to_vertices(outerwall_floor)
+ polygons.append([outerwall_polygon[0], 'outwall'])
+
+ plot_floorplan(annos, polygons)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Structured3D 3D Visualization")
+ parser.add_argument("--path", required=True,
+ help="dataset path", metavar="DIR")
+ parser.add_argument("--scene", required=True,
+ help="scene id", type=int)
+ parser.add_argument("--type", choices=("floorplan", "wireframe", "plane"),
+ default="plane", type=str)
+ parser.add_argument("--color", choices=["normal", "manhattan"],
+ default="normal", type=str)
+ parser.add_argument("--savepath", type=str)
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ # load annotations from json
+ with open(os.path.join(args.path, f"scene_{args.scene:05d}", "annotation_3d.json")) as file:
+ annos = json.load(file)
+
+ if args.type == "wireframe":
+ visualize_wireframe(annos)
+ elif args.type == "plane":
+ visualize_plane(annos, args)
+ elif args.type == "floorplan":
+ visualize_floorplan(annos)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/s3d_preprocess/organize_data.py b/s3d_preprocess/organize_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1d39bdb76253d394fc42c96b3bbb4164ada6eeb
--- /dev/null
+++ b/s3d_preprocess/organize_data.py
@@ -0,0 +1,46 @@
+import os
+from shutil import copyfile
+
+
+data_base = './montefloor_data/'
+dir_names = list(sorted(os.listdir(data_base)))
+out_path = './s3d_floorplan'
+
+wrong_s3d_annotations_list = [3261, 3271, 3276, 3296, 3342, 3387, 3398, 3466, 3496]
+
+train_list = []
+val_list = []
+test_list = []
+
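+# scene split by ID: 0-2999 -> train, 3000-3249 -> validation, 3250 and above -> test
+# (this mirrors the official Structured3D scene split of 3000/250/250 scenes)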
+for dir_name in dir_names:
+ data_dir = os.path.join(data_base, dir_name)
+ annot_path = os.path.join(data_dir, 'annot.npy')
+ if not os.path.exists(annot_path):
+ continue
+ data_id = int(dir_name[-5:])
+ if data_id in wrong_s3d_annotations_list:
+ continue
+ annot_dst = os.path.join(out_path, 'annot', dir_name[-5:] + '.npy')
+ density_dst = os.path.join(out_path, 'density', dir_name[-5:] + '.png')
+ normal_dst = os.path.join(out_path, 'normals', dir_name[-5:] + '.png')
+ density_src = os.path.join(data_dir, 'density.png')
+ normal_src = os.path.join(data_dir, 'normals.png')
+ copyfile(normal_src, normal_dst)
+ copyfile(density_src, density_dst)
+ copyfile(annot_path, annot_dst)
+ if 0 <= data_id < 3000:
+ train_list.append(dir_name[-5:])
+ elif data_id < 3250:
+ val_list.append(dir_name[-5:])
+ else:
+ test_list.append(dir_name[-5:])
+
+with open(os.path.join(out_path, 'train_list.txt'), 'w') as f:
+ for item in train_list:
+ f.write(item + '\n')
+with open(os.path.join(out_path, 'valid_list.txt'), 'w') as f:
+ for item in val_list:
+ f.write(item + '\n')
+with open(os.path.join(out_path, 'test_list.txt'), 'w') as f:
+ for item in test_list:
+ f.write(item + '\n')
diff --git a/s3d_preprocess/sem_seg_utils.py b/s3d_preprocess/sem_seg_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..689fd4fcfc8f080a9147b2ca64e9d06765125e07
--- /dev/null
+++ b/s3d_preprocess/sem_seg_utils.py
@@ -0,0 +1,103 @@
+import numpy as np
+
+wall_type = 1
+floor_type = 2
+ceil_type = 22
+
+SCANNET_COLORMAP = [
+[0., 0., 0.],
+[174., 199., 232.],
+[152., 223., 138.],
+[31., 119., 180.],
+[255., 187., 120.],
+[188., 189., 34.],
+[140., 86., 75.],
+[255., 152., 150.],
+[214., 39., 40.],
+[197., 176., 213.],
+[148., 103., 189.],
+[196., 156., 148.],
+[23., 190., 207.],
+[247., 182., 210.],
+[66., 188., 102.],
+[219., 219., 141.],
+[140., 57., 197.],
+[202., 185., 52.],
+[51., 176., 203.],
+[200., 54., 131.],
+[92., 193., 61.],
+[78., 71., 183.],
+[172., 114., 82.],
+[255., 127., 14.],
+[91., 163., 138.],
+[153., 98., 156.],
+[140., 153., 101.],
+[158., 218., 229.],
+[100., 125., 154.],
+[178., 127., 135.],
+[146., 111., 194.],
+[44., 160., 44.],
+[112., 128., 144.],
+[96., 207., 209.],
+[227., 119., 194.],
+[213., 92., 176.],
+[94., 106., 211.],
+[82., 84., 163.],
+[100., 85., 144.]]
+
+SCANNET_COLORMAP = np.asarray(SCANNET_COLORMAP) / 255.
+
+class_names = []
+class_name_to_id = {}
+for i, line in enumerate(open("label_names.txt").readlines()):
+    class_id = i  # class IDs follow the line order in label_names.txt, starting from 0
+ class_name = line.strip()
+ class_name_to_id[class_name] = class_id
+ class_names.append(class_name)
+class_names = tuple(class_names)
+
+# color palette for nyu40 labels
+def create_color_palette():
+ return [
+ (0, 0, 0),
+ (174, 199, 232), # wall
+ (152, 223, 138), # floor
+ (31, 119, 180), # cabinet
+ (255, 187, 120), # bed
+ (188, 189, 34), # chair
+ (140, 86, 75), # sofa
+ (255, 152, 150), # table
+ (214, 39, 40), # door
+ (197, 176, 213), # window
+ (148, 103, 189), # bookshelf
+ (196, 156, 148), # picture
+ (23, 190, 207), # counter
+ (178, 76, 76),
+ (247, 182, 210), # desk
+ (66, 188, 102),
+ (219, 219, 141), # curtain
+ (140, 57, 197),
+ (202, 185, 52),
+ (51, 176, 203),
+ (200, 54, 131),
+ (92, 193, 61),
+ (78, 71, 183),
+ (172, 114, 82),
+ (255, 127, 14), # refrigerator
+ (91, 163, 138),
+ (153, 98, 156),
+ (140, 153, 101),
+ (158, 218, 229), # shower curtain
+ (100, 125, 154),
+ (178, 127, 135),
+ (120, 185, 128),
+ (146, 111, 194),
+ (44, 160, 44), # toilet
+ (112, 128, 144), # sink
+ (96, 207, 209),
+ (227, 119, 194), # bathtub
+ (213, 92, 176),
+ (94, 106, 211),
+ (82, 84, 163), # otherfurn
+ (100, 85, 144)
+ ]
\ No newline at end of file
diff --git a/s3d_preprocess/visualize_3d.py b/s3d_preprocess/visualize_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..963d3894461d5d3c2bdcabd8e5d7167709d7a267
--- /dev/null
+++ b/s3d_preprocess/visualize_3d.py
@@ -0,0 +1,396 @@
+import os
+import json
+import argparse
+
+import open3d
+# import pymesh  # required by triangulate() and clip_polygon() for the 'plane' visualization; install pymesh and uncomment to enable it
+import numpy as np
+import matplotlib.pyplot as plt
+from shapely.geometry import Polygon
+from descartes.patch import PolygonPatch
+
+from misc.figures import plot_coords
+from misc.colors import colormap_255, semantics_cmap
+
+
+def visualize_wireframe(annos, vis=True, ret=False):
+ """visualize wireframe
+ """
+ colormap = np.array(colormap_255) / 255
+
+ junctions = np.array([item['coordinate'] for item in annos['junctions']])
+ _, junction_pairs = np.where(np.array(annos['lineJunctionMatrix']))
+ junction_pairs = junction_pairs.reshape(-1, 2)
+
+ # extract hole lines
+ lines_holes = []
+ for semantic in annos['semantics']:
+ if semantic['type'] in ['window', 'door']:
+ for planeID in semantic['planeID']:
+ lines_holes.extend(np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist())
+ lines_holes = np.unique(lines_holes)
+
+ # extract cuboid lines
+ cuboid_lines = []
+ for cuboid in annos['cuboids']:
+ for planeID in cuboid['planeID']:
+ cuboid_lineID = np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist()
+ cuboid_lines.extend(cuboid_lineID)
+ cuboid_lines = np.unique(cuboid_lines)
+ cuboid_lines = np.setdiff1d(cuboid_lines, lines_holes)
+
+ # visualize junctions
+ connected_junctions = junctions[np.unique(junction_pairs)]
+ connected_colors = np.repeat(colormap[0].reshape(1, 3), len(connected_junctions), axis=0)
+
+ junction_set = open3d.geometry.PointCloud()
+ junction_set.points = open3d.utility.Vector3dVector(connected_junctions)
+ junction_set.colors = open3d.utility.Vector3dVector(connected_colors)
+
+ # visualize line segments
+ line_colors = np.repeat(colormap[5].reshape(1, 3), len(junction_pairs), axis=0)
+
+ # color holes
+ if len(lines_holes) != 0:
+ line_colors[lines_holes] = colormap[6]
+
+ # color cuboids
+ if len(cuboid_lines) != 0:
+ line_colors[cuboid_lines] = colormap[2]
+
+ line_set = open3d.geometry.LineSet()
+ line_set.points = open3d.utility.Vector3dVector(junctions)
+ line_set.lines = open3d.utility.Vector2iVector(junction_pairs)
+ line_set.colors = open3d.utility.Vector3dVector(line_colors)
+
+ if vis:
+ open3d.visualization.draw_geometries([junction_set, line_set])
+
+ if ret:
+ return [junction_set, line_set]
+
+def project(x, meta):
+ """ project 3D to 2D for polygon clipping
+ """
+ proj_axis = max(range(3), key=lambda i: abs(meta['normal'][i]))
+
+ return tuple(c for i, c in enumerate(x) if i != proj_axis)
+
+
+def project_inv(x, meta):
+ """ recover 3D points from 2D
+ """
+ # Returns the vector w in the walls' plane such that project(w) equals x.
+ proj_axis = max(range(3), key=lambda i: abs(meta['normal'][i]))
+
+ w = list(x)
+ w[proj_axis:proj_axis] = [0.0]
+ c = -meta['offset']
+ for i in range(3):
+ c -= w[i] * meta['normal'][i]
+ c /= meta['normal'][proj_axis]
+ w[proj_axis] = c
+ return tuple(w)
+
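+# Illustrative sketch (not part of the original script): project() drops the coordinate
+# along the plane's dominant normal axis, and project_inv() recovers it from the plane
+# equation normal . w + offset = 0. The helper below is a hypothetical self-check.
+def _project_roundtrip_example():
+    # a horizontal plane z = 2 has normal (0, 0, 1) and offset -2
+    meta = {'normal': [0.0, 0.0, 1.0], 'offset': -2.0}
+    p2d = project((3.0, 4.0, 2.0), meta)   # -> (3.0, 4.0)
+    p3d = project_inv(p2d, meta)           # -> (3.0, 4.0, 2.0)
+    assert np.allclose(p3d, (3.0, 4.0, 2.0))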
+
+def triangulate(points):
+ """ triangulate the plane for operation and visualization
+ """
+
+ num_points = len(points)
+    indices = np.arange(num_points, dtype=np.int32)  # np.int is removed in recent NumPy versions
+ segments = np.vstack((indices, np.roll(indices, -1))).T
+
+ tri = pymesh.triangle()
+ tri.points = np.array(points)
+
+ tri.segments = segments
+ tri.verbosity = 0
+ tri.run()
+
+ return tri.mesh
+
+
+def clip_polygon(polygons, vertices_hole, junctions, meta):
+ """ clip polygon the hole
+ """
+ if len(polygons) == 1:
+ junctions = [junctions[vertex] for vertex in polygons[0]]
+ mesh_wall = triangulate(junctions)
+
+ vertices = np.array(mesh_wall.vertices)
+ faces = np.array(mesh_wall.faces)
+
+ return vertices, faces
+
+ else:
+ wall = []
+ holes = []
+ for polygon in polygons:
+ if np.any(np.intersect1d(polygon, vertices_hole)):
+ holes.append(polygon)
+ else:
+ wall.append(polygon)
+
+ # extract junctions on this plane
+ indices = []
+ junctions_wall = []
+ for plane in wall:
+ for vertex in plane:
+ indices.append(vertex)
+ junctions_wall.append(junctions[vertex])
+
+ junctions_holes = []
+ for plane in holes:
+ junctions_hole = []
+ for vertex in plane:
+ indices.append(vertex)
+ junctions_hole.append(junctions[vertex])
+ junctions_holes.append(junctions_hole)
+
+ junctions_wall = [project(x, meta) for x in junctions_wall]
+ junctions_holes = [[project(x, meta) for x in junctions_hole] for junctions_hole in junctions_holes]
+
+ mesh_wall = triangulate(junctions_wall)
+
+ for hole in junctions_holes:
+ mesh_hole = triangulate(hole)
+ mesh_wall = pymesh.boolean(mesh_wall, mesh_hole, 'difference')
+
+ vertices = [project_inv(vertex, meta) for vertex in mesh_wall.vertices]
+
+ return vertices, np.array(mesh_wall.faces)
+
+
+def draw_geometries_with_back_face(geometries):
+ vis = open3d.visualization.Visualizer()
+ vis.create_window()
+ render_option = vis.get_render_option()
+ render_option.mesh_show_back_face = True
+ for geometry in geometries:
+ vis.add_geometry(geometry)
+ vis.run()
+ vis.destroy_window()
+
+
+def convert_lines_to_vertices(lines):
+ """convert line representation to polygon vertices
+ """
+ polygons = []
+ lines = np.array(lines)
+
+ polygon = None
+ while len(lines) != 0:
+ if polygon is None:
+ polygon = lines[0].tolist()
+ lines = np.delete(lines, 0, 0)
+
+ lineID, juncID = np.where(lines == polygon[-1])
+ vertex = lines[lineID[0], 1 - juncID[0]]
+ lines = np.delete(lines, lineID, 0)
+
+ if vertex in polygon:
+ polygons.append(polygon)
+ polygon = None
+ else:
+ polygon.append(vertex)
+
+ return polygons
+
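+# Illustrative sketch (not part of the original script): convert_lines_to_vertices()
+# chains unordered junction-index pairs into closed loops by repeatedly appending the
+# line that shares the loop's last vertex until the loop closes.
+def _convert_lines_example():
+    # four lines describing a square, given in arbitrary pair order
+    loops = convert_lines_to_vertices([[0, 1], [2, 3], [1, 2], [3, 0]])
+    assert list(loops[0]) == [0, 1, 2, 3]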
+
+def visualize_plane(annos, args, eps=0.9):
+ """visualize plane
+ """
+ colormap = np.array(colormap_255) / 255
+ junctions = [item['coordinate'] for item in annos['junctions']]
+
+ if args.color == 'manhattan':
+ manhattan = dict()
+ for planes in annos['manhattan']:
+ for planeID in planes['planeID']:
+ manhattan[planeID] = planes['ID']
+
+ # extract hole vertices
+ lines_holes = []
+ for semantic in annos['semantics']:
+ if semantic['type'] in ['window', 'door']:
+ for planeID in semantic['planeID']:
+ lines_holes.extend(np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist())
+
+ lines_holes = np.unique(lines_holes)
+ _, vertices_holes = np.where(np.array(annos['lineJunctionMatrix'])[lines_holes])
+ vertices_holes = np.unique(vertices_holes)
+
+
+ # load polygons
+ polygons = []
+ for semantic in annos['semantics']:
+ for planeID in semantic['planeID']:
+ plane_anno = annos['planes'][planeID]
+ lineIDs = np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist()
+ junction_pairs = [np.where(np.array(annos['lineJunctionMatrix'][lineID]))[0].tolist() for lineID in lineIDs]
+ polygon = convert_lines_to_vertices(junction_pairs)
+ vertices, faces = clip_polygon(polygon, vertices_holes, junctions, plane_anno)
+ polygons.append([vertices, faces, planeID, plane_anno['normal'], plane_anno['type'], semantic['type']])
+
+ plane_set = []
+ for i, (vertices, faces, planeID, normal, plane_type, semantic_type) in enumerate(polygons):
+        # skip the room ceiling as well as door/window planes
+        if plane_type == 'ceiling' or semantic_type in ['door', 'window']:
+ continue
+
+ plane_vis = open3d.geometry.TriangleMesh()
+
+ plane_vis.vertices = open3d.utility.Vector3dVector(vertices)
+ plane_vis.triangles = open3d.utility.Vector3iVector(faces)
+
+ if args.color == 'normal':
+ if np.dot(normal, [1, 0, 0]) > eps:
+ plane_vis.paint_uniform_color(colormap[0])
+ elif np.dot(normal, [-1, 0, 0]) > eps:
+ plane_vis.paint_uniform_color(colormap[1])
+ elif np.dot(normal, [0, 1, 0]) > eps:
+ plane_vis.paint_uniform_color(colormap[2])
+ elif np.dot(normal, [0, -1, 0]) > eps:
+ plane_vis.paint_uniform_color(colormap[3])
+ elif np.dot(normal, [0, 0, 1]) > eps:
+ plane_vis.paint_uniform_color(colormap[4])
+ elif np.dot(normal, [0, 0, -1]) > eps:
+ plane_vis.paint_uniform_color(colormap[5])
+ else:
+ plane_vis.paint_uniform_color(colormap[6])
+ elif args.color == 'manhattan':
+ # paint each plane with manhattan world
+ if planeID not in manhattan.keys():
+ plane_vis.paint_uniform_color(colormap[6])
+ else:
+ plane_vis.paint_uniform_color(colormap[manhattan[planeID]])
+
+ plane_set.append(plane_vis)
+
+ draw_geometries_with_back_face(plane_set)
+
+def plot_floorplan(annos, polygons):
+ """plot floorplan
+ """
+ fig = plt.figure()
+ ax = fig.add_subplot(1, 1, 1)
+
+ junctions = np.array([junc['coordinate'][:2] for junc in annos['junctions']])
+ for (polygon, poly_type) in polygons:
+ polygon = Polygon(junctions[np.array(polygon)])
+ plot_coords(ax, polygon.exterior, alpha=0.5)
+ if poly_type == 'outwall':
+ patch = PolygonPatch(polygon, facecolor=semantics_cmap[poly_type], alpha=0)
+ else:
+ patch = PolygonPatch(polygon, facecolor=semantics_cmap[poly_type], alpha=0.5)
+ ax.add_patch(patch)
+
+ plt.axis('equal')
+ plt.axis('off')
+ plt.show()
+
+
+def visualize_floorplan(annos, vis=True, ret=False):
+ """visualize floorplan
+ """
+ # extract the floor in each semantic for floorplan visualization
+ planes = []
+ for semantic in annos['semantics']:
+ for planeID in semantic['planeID']:
+ if annos['planes'][planeID]['type'] == 'floor':
+ planes.append({'planeID': planeID, 'type': semantic['type']})
+
+ if semantic['type'] == 'outwall':
+ outerwall_planes = semantic['planeID']
+
+ # extract hole vertices
+ lines_holes = []
+ for semantic in annos['semantics']:
+ if semantic['type'] in ['window', 'door']:
+ for planeID in semantic['planeID']:
+ lines_holes.extend(np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist())
+ lines_holes = np.unique(lines_holes)
+
+ # junctions on the floor
+ junctions = np.array([junc['coordinate'] for junc in annos['junctions']])
+ junction_floor = np.where(np.isclose(junctions[:, -1], 0))[0]
+
+ # construct each polygon
+ polygons = []
+ for plane in planes:
+ lineIDs = np.where(np.array(annos['planeLineMatrix'][plane['planeID']]))[0].tolist()
+ junction_pairs = [np.where(np.array(annos['lineJunctionMatrix'][lineID]))[0].tolist() for lineID in lineIDs]
+ polygon = convert_lines_to_vertices(junction_pairs)
+ polygons.append([polygon[0], plane['type']])
+
+
+ outerwall_floor = []
+ valid_outer_wall = True
+ for planeID in outerwall_planes:
+ lineIDs = np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist()
+ lineIDs = np.setdiff1d(lineIDs, lines_holes)
+ junction_pairs = [np.where(np.array(annos['lineJunctionMatrix'][lineID]))[0].tolist() for lineID in lineIDs]
+
+ for jp in junction_pairs:
+ if len(jp) != 2:
+ valid_outer_wall = False
+ break
+
+ if not valid_outer_wall:
+ break
+
+ for jp in junction_pairs:
+ if len(jp) != 2:
+ continue
+ start, end = jp
+ if start in junction_floor and end in junction_floor:
+ outerwall_floor.append([start, end])
+
+    if valid_outer_wall:
+        outerwall_polygon = convert_lines_to_vertices(outerwall_floor)
+        polygons.append([outerwall_polygon[0], 'outwall'])
+    else:
+        # the outer-wall annotation is malformed for this scene, so skip it and return None
+        polygons = None
+
+    if vis and polygons is not None:
+        plot_floorplan(annos, polygons)
+
+    if ret:
+        return polygons
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Structured3D 3D Visualization")
+ parser.add_argument("--path", required=True,
+ help="dataset path", metavar="DIR")
+ parser.add_argument("--scene", required=True,
+ help="scene id", type=int)
+ parser.add_argument("--type", choices=("floorplan", "wireframe", "plane"),
+ default="plane", type=str)
+ parser.add_argument("--color", choices=["normal", "manhattan"],
+ default="normal", type=str)
+ parser.add_argument("--savepath", type=str)
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ # load annotations from json
+ with open(os.path.join(args.path, f"scene_{args.scene:05d}", "annotation_3d.json")) as file:
+ annos = json.load(file)
+
+ if args.type == "wireframe":
+ visualize_wireframe(annos)
+ elif args.type == "plane":
+ visualize_plane(annos, args)
+ elif args.type == "floorplan":
+ visualize_floorplan(annos)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/s3d_preprocess/visualize_bbox.py b/s3d_preprocess/visualize_bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d73a3d4f1587d37ec77db4a2557a453929e1c93
--- /dev/null
+++ b/s3d_preprocess/visualize_bbox.py
@@ -0,0 +1,87 @@
+import os
+import json
+import argparse
+
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+
+from misc.utils import get_corners_of_bb3d_no_index, project_3d_points_to_2d, parse_camera_info
+
+
+def visualize_bbox(args):
+ with open(os.path.join(args.path, f"scene_{args.scene:05d}", "bbox_3d.json")) as file:
+ annos = json.load(file)
+
+ id2index = dict()
+ for index, object in enumerate(annos):
+ id2index[object.get('ID')] = index
+
+ scene_path = os.path.join(args.path, f"scene_{args.scene:05d}", "2D_rendering")
+
+ for room_id in np.sort(os.listdir(scene_path)):
+ room_path = os.path.join(scene_path, room_id, "perspective", "full")
+
+ if not os.path.exists(room_path):
+ continue
+
+ for position_id in np.sort(os.listdir(room_path)):
+ position_path = os.path.join(room_path, position_id)
+
+ image = cv2.imread(os.path.join(position_path, 'rgb_rawlight.png'))
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ height, width, _ = image.shape
+
+ instance = cv2.imread(os.path.join(position_path, 'instance.png'), cv2.IMREAD_UNCHANGED)
+
+ camera_info = np.loadtxt(os.path.join(position_path, 'camera_pose.txt'))
+
+ rot, trans, K = parse_camera_info(camera_info, height, width)
+
+ plt.figure()
+ plt.imshow(image)
+
+ for index in np.unique(instance)[:-1]:
+ # for each instance in current image
+ bbox = annos[id2index[index]]
+
+ basis = np.array(bbox['basis'])
+ coeffs = np.array(bbox['coeffs'])
+ centroid = np.array(bbox['centroid'])
+
+ corners = get_corners_of_bb3d_no_index(basis, coeffs, centroid)
+ corners = corners - trans
+
+ gt2dcorners = project_3d_points_to_2d(corners, rot, K)
+
+ num_corner = gt2dcorners.shape[1] // 2
+ plt.plot(np.hstack((gt2dcorners[0, :num_corner], gt2dcorners[0, 0])),
+ np.hstack((gt2dcorners[1, :num_corner], gt2dcorners[1, 0])), 'r')
+ plt.plot(np.hstack((gt2dcorners[0, num_corner:], gt2dcorners[0, num_corner])),
+ np.hstack((gt2dcorners[1, num_corner:], gt2dcorners[1, num_corner])), 'b')
+ for i in range(num_corner):
+ plt.plot(gt2dcorners[0, [i, i + num_corner]], gt2dcorners[1, [i, i + num_corner]], 'y')
+
+ plt.axis('off')
+ plt.axis([0, width, height, 0])
+ plt.show()
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Structured3D 3D Bounding Box Visualization")
+ parser.add_argument("--path", required=True,
+ help="dataset path", metavar="DIR")
+ parser.add_argument("--scene", required=True,
+ help="scene id", type=int)
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ visualize_bbox(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/s3d_preprocess/visualize_layout.py b/s3d_preprocess/visualize_layout.py
new file mode 100644
index 0000000000000000000000000000000000000000..51a3b785afd159b3b843526e53d0b15773493540
--- /dev/null
+++ b/s3d_preprocess/visualize_layout.py
@@ -0,0 +1,93 @@
+import os
+import json
+import argparse
+
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+from shapely.geometry import Polygon
+from descartes.patch import PolygonPatch
+
+from misc.panorama import draw_boundary_from_cor_id
+from misc.colors import colormap_255
+
+
+def visualize_panorama(args):
+ """visualize panorama layout
+ """
+ scene_path = os.path.join(args.path, f"scene_{args.scene:05d}", "2D_rendering")
+
+ for room_id in np.sort(os.listdir(scene_path)):
+ room_path = os.path.join(scene_path, room_id, "panorama")
+
+ cor_id = np.loadtxt(os.path.join(room_path, "layout.txt"))
+ img_src = cv2.imread(os.path.join(room_path, "full", "rgb_rawlight.png"))
+ img_src = cv2.cvtColor(img_src, cv2.COLOR_BGR2RGB)
+ img_viz = draw_boundary_from_cor_id(cor_id, img_src)
+
+ plt.axis('off')
+ plt.imshow(img_viz)
+ plt.show()
+
+
+def visualize_perspective(args):
+ """visualize perspective layout
+ """
+ colors = np.array(colormap_255) / 255
+
+ scene_path = os.path.join(args.path, f"scene_{args.scene:05d}", "2D_rendering")
+
+ for room_id in np.sort(os.listdir(scene_path)):
+ room_path = os.path.join(scene_path, room_id, "perspective", "full")
+
+ if not os.path.exists(room_path):
+ continue
+
+ for position_id in np.sort(os.listdir(room_path)):
+ position_path = os.path.join(room_path, position_id)
+
+ image = cv2.imread(os.path.join(position_path, "rgb_rawlight.png"))
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ with open(os.path.join(position_path, "layout.json")) as f:
+ annos = json.load(f)
+
+ fig = plt.figure()
+ for i, key in enumerate(['amodal_mask', 'visible_mask']):
+ ax = fig.add_subplot(2, 1, i + 1)
+ plt.axis('off')
+ plt.imshow(image)
+
+ for i, planes in enumerate(annos['planes']):
+ if len(planes[key]):
+ for plane in planes[key]:
+ polygon = Polygon([annos['junctions'][id]['coordinate'] for id in plane])
+ patch = PolygonPatch(polygon, facecolor=colors[i], alpha=0.5)
+ ax.add_patch(patch)
+
+ plt.title(key)
+ plt.show()
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Structured3D 2D Layout Visualization")
+ parser.add_argument("--path", required=True,
+ help="dataset path", metavar="DIR")
+ parser.add_argument("--scene", required=True,
+ help="scene id", type=int)
+ parser.add_argument("--type", choices=["perspective", "panorama"], required=True,
+ help="type of camera", type=str)
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ if args.type == 'panorama':
+ visualize_panorama(args)
+ elif args.type == 'perspective':
+ visualize_perspective(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/s3d_preprocess/visualize_mesh.py b/s3d_preprocess/visualize_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a578e9ca5efa5aa96070189fc402076d51b55b7
--- /dev/null
+++ b/s3d_preprocess/visualize_mesh.py
@@ -0,0 +1,262 @@
+import os
+import json
+import argparse
+
+import cv2
+import open3d
+import numpy as np
+from panda3d.core import Triangulator
+
+from misc.panorama import xyz_2_coorxy
+from visualize_3d import convert_lines_to_vertices
+
+
+def E2P(image, corner_i, corner_j, wall_height, camera, resolution=512, is_wall=True):
+ """convert panorama to persepctive image
+ """
+ corner_i = corner_i - camera
+ corner_j = corner_j - camera
+
+ if is_wall:
+ xs = np.linspace(corner_i[0], corner_j[0], resolution)[None].repeat(resolution, 0)
+ ys = np.linspace(corner_i[1], corner_j[1], resolution)[None].repeat(resolution, 0)
+ zs = np.linspace(-camera[-1], wall_height - camera[-1], resolution)[:, None].repeat(resolution, 1)
+ else:
+ xs = np.linspace(corner_i[0], corner_j[0], resolution)[None].repeat(resolution, 0)
+ ys = np.linspace(corner_i[1], corner_j[1], resolution)[:, None].repeat(resolution, 1)
+ zs = np.zeros_like(xs) + wall_height - camera[-1]
+
+ coorx, coory = xyz_2_coorxy(xs, ys, zs)
+
+ persp = cv2.remap(image, coorx.astype(np.float32), coory.astype(np.float32),
+ cv2.INTER_CUBIC, borderMode=cv2.BORDER_WRAP)
+
+ return persp
+
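+# Illustrative sketch (not part of the original script): E2P() samples a perspective
+# crop of one wall from the equirectangular panorama. The values below are made up and
+# only illustrate the expected shapes, assuming the default 512x1024 Structured3D
+# panorama and that misc.panorama.xyz_2_coorxy preserves the input array shape.
+def _e2p_shape_example():
+    fake_pano = np.zeros((512, 1024, 3), dtype=np.uint8)
+    corner_i = np.array([1000.0, -500.0, 0.0])
+    corner_j = np.array([1000.0, 500.0, 0.0])
+    camera = np.array([0.0, 0.0, 1400.0])
+    wall_tex = E2P(fake_pano, corner_i, corner_j, wall_height=2800.0, camera=camera)
+    assert wall_tex.shape == (512, 512, 3)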
+
+def create_plane_mesh(vertices, vertices_floor, textures, texture_floor, texture_ceiling,
+ delta_height, ignore_ceiling=False):
+ # create mesh for 3D floorplan visualization
+ triangles = []
+ triangle_uvs = []
+
+ # the number of vertical walls
+ num_walls = len(vertices)
+
+ # 1. vertical wall (always rectangle)
+ num_vertices = 0
+ for i in range(len(vertices)):
+ # hardcode triangles for each vertical wall
+ triangle = np.array([[0, 2, 1], [2, 0, 3]])
+ triangles.append(triangle + num_vertices)
+ num_vertices += 4
+
+ triangle_uv = np.array(
+ [
+ [i / (num_walls + 2), 0],
+ [i / (num_walls + 2), 1],
+ [(i+1) / (num_walls + 2), 1],
+ [(i+1) / (num_walls + 2), 0]
+ ],
+ dtype=np.float32
+ )
+ triangle_uvs.append(triangle_uv)
+
+ # 2. floor and ceiling
+ # Since the floor and ceiling may not be a rectangle, triangulate the polygon first.
+ tri = Triangulator()
+ for i in range(len(vertices_floor)):
+ tri.add_vertex(vertices_floor[i, 0], vertices_floor[i, 1])
+
+ for i in range(len(vertices_floor)):
+ tri.add_polygon_vertex(i)
+
+ tri.triangulate()
+
+ # polygon triangulation
+ triangle = []
+ for i in range(tri.getNumTriangles()):
+ triangle.append([tri.get_triangle_v0(i), tri.get_triangle_v1(i), tri.get_triangle_v2(i)])
+ triangle = np.array(triangle)
+
+ # add triangles for floor and ceiling
+ triangles.append(triangle + num_vertices)
+ num_vertices += len(np.unique(triangle))
+ if not ignore_ceiling:
+ triangles.append(triangle + num_vertices)
+
+ # texture for floor and ceiling
+ vertices_floor_min = np.min(vertices_floor[:, :2], axis=0)
+ vertices_floor_max = np.max(vertices_floor[:, :2], axis=0)
+
+ # normalize to [0, 1]
+ triangle_uv = (vertices_floor[:, :2] - vertices_floor_min) / (vertices_floor_max - vertices_floor_min)
+ triangle_uv[:, 0] = (triangle_uv[:, 0] + num_walls) / (num_walls + 2)
+
+ triangle_uvs.append(triangle_uv)
+
+ # normalize to [0, 1]
+ triangle_uv = (vertices_floor[:, :2] - vertices_floor_min) / (vertices_floor_max - vertices_floor_min)
+ triangle_uv[:, 0] = (triangle_uv[:, 0] + num_walls + 1) / (num_walls + 2)
+
+ triangle_uvs.append(triangle_uv)
+
+ # 3. Merge wall, floor, and ceiling
+ vertices.append(vertices_floor)
+ vertices.append(vertices_floor + delta_height)
+ vertices = np.concatenate(vertices, axis=0)
+
+ triangles = np.concatenate(triangles, axis=0)
+
+ textures.append(texture_floor)
+ textures.append(texture_ceiling)
+ textures = np.concatenate(textures, axis=1)
+
+ triangle_uvs = np.concatenate(triangle_uvs, axis=0)
+
+ mesh = open3d.geometry.TriangleMesh(
+ vertices=open3d.utility.Vector3dVector(vertices),
+ triangles=open3d.utility.Vector3iVector(triangles)
+ )
+ mesh.compute_vertex_normals()
+
+ mesh.texture = open3d.geometry.Image(textures)
+ mesh.triangle_uvs = np.array(triangle_uvs[triangles.reshape(-1), :], dtype=np.float64)
+ return mesh
+
+
+def verify_normal(corner_i, corner_j, delta_height, plane_normal):
+ edge_a = corner_j + delta_height - corner_i
+ edge_b = delta_height
+
+ normal = np.cross(edge_a, edge_b)
+ normal /= np.linalg.norm(normal, ord=2)
+
+ inner_product = normal.dot(plane_normal)
+
+ if inner_product > 1e-8:
+ return False
+ else:
+ return True
+
+
+def visualize_mesh(args):
+ """visualize as water-tight mesh
+ """
+
+ image = cv2.imread(os.path.join(args.path, f"scene_{args.scene:05d}", "2D_rendering",
+ str(args.room), "panorama/full/rgb_rawlight.png"))
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ # load room annotations
+ with open(os.path.join(args.path, f"scene_{args.scene:05d}" , "annotation_3d.json")) as f:
+ annos = json.load(f)
+
+ # load camera info
+ camera_center = np.loadtxt(os.path.join(args.path, f"scene_{args.scene:05d}", "2D_rendering",
+ str(args.room), "panorama", "camera_xyz.txt"))
+
+ # parse corners
+ junctions = np.array([item['coordinate'] for item in annos['junctions']])
+ lines_holes = []
+ for semantic in annos['semantics']:
+ if semantic['type'] in ['window', 'door']:
+ for planeID in semantic['planeID']:
+ lines_holes.extend(np.where(np.array(annos['planeLineMatrix'][planeID]))[0].tolist())
+
+ lines_holes = np.unique(lines_holes)
+ _, vertices_holes = np.where(np.array(annos['lineJunctionMatrix'])[lines_holes])
+ vertices_holes = np.unique(vertices_holes)
+
+ # parse annotations
+ walls = dict()
+ walls_normal = dict()
+ for semantic in annos['semantics']:
+ if semantic['ID'] != int(args.room):
+ continue
+
+ # find junctions of ceiling and floor
+ for planeID in semantic['planeID']:
+ plane_anno = annos['planes'][planeID]
+
+ if plane_anno['type'] != 'wall':
+ lineIDs = np.where(np.array(annos['planeLineMatrix'][planeID]))[0]
+ lineIDs = np.setdiff1d(lineIDs, lines_holes)
+ junction_pairs = [np.where(np.array(annos['lineJunctionMatrix'][lineID]))[0].tolist() for lineID in lineIDs]
+ wall = convert_lines_to_vertices(junction_pairs)
+ walls[plane_anno['type']] = wall[0]
+
+ # save normal of the vertical walls
+ for planeID in semantic['planeID']:
+ plane_anno = annos['planes'][planeID]
+
+ if plane_anno['type'] == 'wall':
+ lineIDs = np.where(np.array(annos['planeLineMatrix'][planeID]))[0]
+ lineIDs = np.setdiff1d(lineIDs, lines_holes)
+ junction_pairs = [np.where(np.array(annos['lineJunctionMatrix'][lineID]))[0].tolist() for lineID in lineIDs]
+ wall = convert_lines_to_vertices(junction_pairs)
+ walls_normal[tuple(np.intersect1d(wall, walls['floor']))] = plane_anno['normal']
+
+    # assuming the floor lies at z = 0, the wall height equals the mean z of the ceiling junctions
+ wall_height = np.mean(junctions[walls['ceiling']], axis=0)[-1]
+ delta_height = np.array([0, 0, wall_height])
+
+ # list of corner index
+ wall_floor = walls['floor']
+
+ corners = [] # 3D coordinate for each wall
+ textures = [] # texture for each wall
+
+ # wall
+ for i, j in zip(wall_floor, np.roll(wall_floor, shift=-1)):
+ corner_i, corner_j = junctions[i], junctions[j]
+
+ flip = verify_normal(corner_i, corner_j, delta_height, walls_normal[tuple(sorted([i, j]))])
+
+ if flip:
+ corner_j, corner_i = corner_i, corner_j
+
+ texture = E2P(image, corner_i, corner_j, wall_height, camera_center)
+
+ corner = np.array([corner_i, corner_i + delta_height, corner_j + delta_height, corner_j])
+
+ corners.append(corner)
+ textures.append(texture)
+
+ # floor and ceiling
+ # the floor/ceiling texture is cropped by the maximum bounding box
+ corner_floor = junctions[wall_floor]
+ corner_min = np.min(corner_floor, axis=0)
+ corner_max = np.max(corner_floor, axis=0)
+ texture_floor = E2P(image, corner_min, corner_max, 0, camera_center, is_wall=False)
+ texture_ceiling = E2P(image, corner_min, corner_max, wall_height, camera_center, is_wall=False)
+
+ # create mesh
+ mesh = create_plane_mesh(corners, corner_floor, textures, texture_floor, texture_ceiling,
+ delta_height, ignore_ceiling=args.ignore_ceiling)
+
+ # visualize mesh
+ open3d.visualization.draw_geometries([mesh])
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Structured3D Textured Mesh Visualization")
+ parser.add_argument("--path", required=True,
+ help="dataset path", metavar="DIR")
+ parser.add_argument("--scene", required=True,
+ help="scene id", type=int)
+ parser.add_argument("--room", required=True,
+ help="room id", type=int)
+ parser.add_argument("--ignore_ceiling", action='store_true',
+ help="ignore ceiling for better visualization")
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ visualize_mesh(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..489e52ff85fb48285f8bd822f7d60cb05d2c2240
--- /dev/null
+++ b/train.py
@@ -0,0 +1,291 @@
+import torch
+import torch.nn as nn
+import os
+import time
+import datetime
+import argparse
+from pathlib import Path
+from torch.utils.data import DataLoader
+from arguments import get_args_parser
+from datasets.outdoor_buildings import OutdoorBuildingDataset
+from datasets.s3d_floorplans import S3DFloorplanDataset
+from datasets.data_utils import collate_fn, get_pixel_features
+from models.corner_models import HeatCorner
+from models.edge_models import HeatEdge
+from models.resnet import ResNetBackbone
+from models.loss import CornerCriterion, EdgeCriterion
+from models.corner_to_edge import prepare_edge_data
+import utils.misc as utils
+
+
+def train_one_epoch(image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, data_loader,
+ optimizer,
+ epoch, max_norm, args):
+ backbone.train()
+ corner_model.train()
+ edge_model.train()
+ corner_criterion.train()
+ edge_criterion.train()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=100, fmt='{value:.6f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ print_freq = args.print_freq
+
+ # get the positional encodings for all pixels
+ pixels, pixel_features = get_pixel_features(image_size)
+ pixel_features = pixel_features.cuda()
+
+ for data in metric_logger.log_every(data_loader, print_freq, header):
+ corner_outputs, corner_loss, corner_recall, s1_logits, s2_logits_hb, s2_logits_rel, s1_losses, s2_losses_hb, \
+ s2_losses_rel, s1_acc, s2_acc_hb, s2_acc_rel = run_model(
+ data,
+ pixels,
+ pixel_features,
+ backbone,
+ corner_model,
+ edge_model,
+ epoch,
+ corner_criterion,
+ edge_criterion,
+ args)
+
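+        # total loss: the stage-1 edge loss, the two stage-2 edge losses (hb / rel heads),
+        # and the corner loss down-weighted by --lambda_corner (0.05 by default)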
+ loss = s1_losses + s2_losses_hb + s2_losses_rel + corner_loss * args.lambda_corner
+
+ loss_dict = {'loss_e_s1': s1_losses, 'loss_e_s2_hb': s2_losses_hb, 'loss_e_s2_rel': s2_losses_rel,
+ 'edge_acc_s1': s1_acc, 'edge_acc_s2_hb': s2_acc_hb, 'edge_acc_s2_rel': s2_acc_rel,
+ 'loss_c_s1': corner_loss, 'corner_recall': corner_recall}
+ loss_value = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+
+ if max_norm > 0:
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), max_norm)
+ torch.nn.utils.clip_grad_norm_(corner_model.parameters(), max_norm)
+ torch.nn.utils.clip_grad_norm_(edge_model.parameters(), max_norm)
+
+ optimizer.step()
+ metric_logger.update(loss=loss_value, **loss_dict)
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+ print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+def run_model(data, pixels, pixel_features, backbone, corner_model, edge_model, epoch, corner_criterion, edge_criterion,
+ args):
+ image = data['img'].cuda()
+ annots = data['annot']
+ raw_images = data['raw_img']
+ pixel_labels = data['pixel_labels'].cuda()
+ gauss_labels = data['gauss_labels'].cuda()
+
+ pixel_features = pixel_features.unsqueeze(0).repeat(image.shape[0], 1, 1, 1)
+
+ # get corner preds from corner model
+ image_feats, feat_mask, all_image_feats = backbone(image)
+ preds_s1 = corner_model(image_feats, feat_mask, pixel_features, pixels, all_image_feats)
+
+ corner_loss_s1, corner_recall = corner_criterion(preds_s1, pixel_labels, gauss_labels, epoch)
+
+ # get edge candidates and corresponding G.T.
+ c_outputs = preds_s1
+ edge_data = prepare_edge_data(c_outputs, annots, raw_images, args.max_corner_num)
+
+ edge_coords = edge_data['edge_coords'].cuda()
+ edge_mask = edge_data['edge_coords_mask'].cuda()
+ edge_lengths = edge_data['edge_coords_lengths'].cuda()
+ edge_labels = edge_data['edge_labels'].cuda()
+ corner_nums = edge_data['processed_corners_lengths']
+
+ # run the edge model
+ max_candidates = torch.stack([corner_nums.max() * args.corner_to_edge_multiplier] * len(corner_nums), dim=0)
+ logits_s1, logits_s2_hb, logits_s2_rel, s2_ids, s2_edge_mask, s2_gt_values = edge_model(image_feats, feat_mask,
+ pixel_features,
+ edge_coords, edge_mask,
+ edge_labels,
+ corner_nums,
+ max_candidates)
+
+ s1_losses, s1_acc, s2_losses_hb, s2_acc_hb, s2_losses_rel, s2_acc_rel = edge_criterion(logits_s1, logits_s2_hb,
+ logits_s2_rel, s2_ids,
+ s2_edge_mask,
+ edge_labels, edge_lengths,
+ edge_mask, s2_gt_values)
+
+ return c_outputs, corner_loss_s1, corner_recall, logits_s1, logits_s2_hb, logits_s2_rel, s1_losses, s2_losses_hb, \
+ s2_losses_rel, s1_acc, s2_acc_hb, s2_acc_rel
+
+
+@torch.no_grad()
+def evaluate(image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, data_loader, epoch,
+ args):
+ backbone.eval()
+ corner_model.eval()
+ edge_model.eval()
+ corner_criterion.eval()
+ edge_criterion.eval()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Test:'
+
+ pixels, pixel_features = get_pixel_features(image_size)
+ pixel_features = pixel_features.cuda()
+
+ for data in metric_logger.log_every(data_loader, 10, header):
+ c_outputs, corner_loss, corner_recall, s1_logits, \
+ s2_logits_hb, s2_logits_rel, s1_losses, s2_losses_hb, s2_losses_rel, s1_acc, s2_acc_hb, s2_acc_rel = run_model(
+ data,
+ pixels,
+ pixel_features,
+ backbone,
+ corner_model,
+ edge_model,
+ epoch,
+ corner_criterion,
+ edge_criterion,
+ args)
+
+ loss_dict = {'loss_e_s1': s1_losses,
+ 'loss_e_s2_hb': s2_losses_hb,
+ 'loss_e_s2_rel': s2_losses_rel,
+ 'edge_acc_s1': s1_acc,
+ 'edge_acc_s2_hb': s2_acc_hb,
+ 'edge_acc_s2_rel': s2_acc_rel,
+ 'loss_c_s1': corner_loss,
+ 'corner_recall': corner_recall}
+
+ loss = s1_losses + s2_losses_hb + s2_losses_rel + corner_loss * args.lambda_corner
+ loss_value = loss.item()
+ metric_logger.update(loss=loss_value, **loss_dict)
+
+ print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+def main():
+ parser = argparse.ArgumentParser('HEAT training', parents=[get_args_parser()])
+ args = parser.parse_args()
+ image_size = args.image_size
+ if args.exp_dataset == 'outdoor':
+ data_path = './data/outdoor/cities_dataset'
+ det_path = './data/outdoor/det_final'
+ train_dataset = OutdoorBuildingDataset(data_path, det_path, phase='train', image_size=image_size, rand_aug=True,
+ inference=False)
+ test_dataset = OutdoorBuildingDataset(data_path, det_path, phase='valid', image_size=image_size, rand_aug=False,
+ inference=False)
+ elif args.exp_dataset == 's3d_floorplan':
+ data_path = './data/s3d_floorplan'
+ train_dataset = S3DFloorplanDataset(data_path, phase='train', rand_aug=True, inference=False)
+ test_dataset = S3DFloorplanDataset(data_path, phase='valid', rand_aug=False, inference=False)
+ else:
+ raise ValueError('Unknown dataset: {}'.format(args.exp_dataset))
+
+ train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
+ collate_fn=collate_fn, drop_last=True)
+ test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=args.num_workers,
+ collate_fn=collate_fn)
+
+ backbone = ResNetBackbone()
+ strides = backbone.strides
+ num_channels = backbone.num_channels
+
+ corner_model = HeatCorner(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
+ backbone_num_channels=num_channels)
+ backbone = nn.DataParallel(backbone)
+ backbone = backbone.cuda()
+ corner_model = nn.DataParallel(corner_model)
+ corner_model = corner_model.cuda()
+
+ edge_model = HeatEdge(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
+ backbone_num_channels=num_channels)
+ edge_model = nn.DataParallel(edge_model)
+ edge_model = edge_model.cuda()
+
+ corner_criterion = CornerCriterion(image_size=image_size)
+ edge_criterion = EdgeCriterion()
+
+ backbone_params = [p for p in backbone.parameters()]
+ corner_params = [p for p in corner_model.parameters()]
+ edge_params = [p for p in edge_model.parameters()]
+
+ all_params = corner_params + edge_params + backbone_params
+ optimizer = torch.optim.AdamW(all_params, lr=args.lr, weight_decay=args.weight_decay)
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
+ start_epoch = args.start_epoch
+
+ if args.resume:
+ ckpt = torch.load(args.resume)
+ backbone.load_state_dict(ckpt['backbone'])
+ corner_model.load_state_dict(ckpt['corner_model'])
+ edge_model.load_state_dict(ckpt['edge_model'])
+ optimizer.load_state_dict(ckpt['optimizer'])
+ lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
+ lr_scheduler.step_size = args.lr_drop
+
+ print('Resume from ckpt file {}, starting from epoch {}'.format(args.resume, ckpt['epoch']))
+ start_epoch = ckpt['epoch'] + 1
+
+ n_backbone_parameters = sum(p.numel() for p in backbone_params if p.requires_grad)
+ n_corner_parameters = sum(p.numel() for p in corner_params if p.requires_grad)
+ n_edge_parameters = sum(p.numel() for p in edge_params if p.requires_grad)
+ n_all_parameters = sum(p.numel() for p in all_params if p.requires_grad)
+ print('number of trainable backbone params:', n_backbone_parameters)
+ print('number of trainable corner params:', n_corner_parameters)
+ print('number of trainable edge params:', n_edge_parameters)
+ print('number of all trainable params:', n_all_parameters)
+
+ print("Start training")
+ start_time = time.time()
+
+ output_dir = Path(args.output_dir)
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ best_acc = 0
+ for epoch in range(start_epoch, args.epochs):
+ train_stats = train_one_epoch(
+ image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, train_dataloader,
+ optimizer,
+ epoch, args.clip_max_norm, args)
+ lr_scheduler.step()
+
+ if args.run_validation:
+ val_stats = evaluate(
+ image_size, backbone, corner_model, edge_model, corner_criterion, edge_criterion, test_dataloader,
+ epoch, args
+ )
+
+ val_acc = (val_stats['edge_acc_s1'] + val_stats['edge_acc_s2_hb']) / 2
+ if val_acc > best_acc:
+ is_best = True
+ best_acc = val_acc
+ else:
+ is_best = False
+ else:
+ val_acc = 0
+ is_best = False
+
+ if args.output_dir:
+ checkpoint_paths = [output_dir / 'checkpoint.pth']
+ if is_best:
+ checkpoint_paths.append(output_dir / 'checkpoint_best.pth')
+
+ for checkpoint_path in checkpoint_paths:
+ torch.save({
+ 'backbone': backbone.state_dict(),
+ 'corner_model': corner_model.state_dict(),
+ 'edge_model': edge_model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch,
+ 'args': args,
+ 'val_acc': val_acc,
+ }, checkpoint_path)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/geometry_utils.py b/utils/geometry_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e691a36e4b21c706a5485810b838b11afd33829a
--- /dev/null
+++ b/utils/geometry_utils.py
@@ -0,0 +1,167 @@
+import torch
+import numpy as np
+import cv2
+
+
+def building_metric(logits, label):
+ preds = torch.argmax(logits, dim=-1)
+ true_ids = torch.where(label==1)
+ num_true = true_ids[0].shape[0]
+ tp = (preds[true_ids] == 1).sum().double()
+ recall = tp / num_true
+ prec = tp / (preds == 1).sum()
+ fscore = 2 * recall * prec / (prec + recall)
+ return recall, prec, fscore
+
+
+def edge_acc(logits, label, lengths, gt_values):
+ """
+ edge f1-score for training/validation logging
+ """
+ all_acc = list()
+ for i in range(logits.shape[0]):
+ length = lengths[i]
+ gt_value = gt_values[i, :length]
+ pred_idx = torch.where(gt_value == 2)
+ if len(pred_idx[0]) == 0:
+ continue
+ else:
+ preds = torch.argmax(logits[i, :, :length][:, pred_idx[0]], dim=0)
+ gts = label[i, :length][pred_idx[0]]
+ pos_ids = torch.where(gts == 1)
+ correct = (preds[pos_ids] == gts[pos_ids]).sum().float()
+ num_pos_gt = len(pos_ids[0])
+ recall = correct / num_pos_gt if num_pos_gt > 0 else torch.tensor(0)
+ num_pos_pred = (preds == 1).sum().float()
+ prec = correct / num_pos_pred if num_pos_pred > 0 else torch.tensor(0)
+ f_score = 2.0 * prec * recall / (recall + prec + 1e-8)
+ f_score = f_score.cpu()
+ all_acc.append(f_score)
+ if len(all_acc) > 1:
+ all_acc = torch.stack(all_acc, 0)
+ avg_acc = all_acc.mean()
+ else:
+ avg_acc = all_acc[0]
+ return avg_acc
+
+
+def corner_eval(targets, outputs):
+ assert isinstance(targets, np.ndarray)
+ assert isinstance(outputs, np.ndarray)
+ output_to_gt = dict()
+ gt_to_output = dict()
+ for target_i, target in enumerate(targets):
+ dist = (outputs - target) ** 2
+ dist = np.sqrt(dist.sum(axis=-1))
+ min_dist = dist.min()
+ min_idx = dist.argmin()
+ if min_dist < 5 and min_idx not in output_to_gt: # a positive match
+ output_to_gt[min_idx] = target_i
+ gt_to_output[target_i] = min_idx
+ tp = len(output_to_gt)
+ prec = tp / len(outputs)
+ recall = tp / len(targets)
+ return prec, recall
+
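+# Illustrative sketch (not part of the original module): corner_eval() matches each GT
+# corner to its nearest prediction if it lies within ~5 px and has not been claimed yet,
+# then reports precision/recall over those matches.
+def _corner_eval_example():
+    targets = np.array([[0.0, 0.0], [10.0, 10.0]])
+    outputs = np.array([[1.0, 1.0], [50.0, 50.0]])
+    prec, recall = corner_eval(targets, outputs)
+    # only the first prediction matches a GT corner -> precision = recall = 0.5
+    assert prec == 0.5 and recall == 0.5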
+
+def rectify_data(image, annot):
+ rows, cols, ch = image.shape
+    bins = [0 for _ in range(180)]  # one bin per degree over [0, 180)
+ # edges vote for directions
+
+ gauss_weights = [0.1, 0.2, 0.5, 1, 0.5, 0.2, 0.1]
+
+ for src, connections in annot.items():
+ for end in connections:
+ edge = [(end[0] - src[0]), -(end[1] - src[1])]
+ edge_len = np.sqrt(edge[0] ** 2 + edge[1] ** 2)
+ if edge_len <= 10: # skip too short edges
+ continue
+ if edge[0] == 0:
+ bin_id = 90
+ else:
+ theta = np.arctan(edge[1] / edge[0]) / np.pi * 180
+ if edge[0] * edge[1] < 0:
+ theta += 180
+ bin_id = int(theta.round())
+ if bin_id == 180:
+ bin_id = 0
+ for offset in range(-3, 4):
+ bin_idx = bin_id + offset
+ if bin_idx >= 180:
+ bin_idx -= 180
+ bins[bin_idx] += np.sqrt(edge[1] ** 2 + edge[0] ** 2) * gauss_weights[offset + 2]
+
+ bins = np.array(bins)
+ sorted_ids = np.argsort(bins)[::-1]
+ bin_1 = sorted_ids[0]
+ remained_ids = [idx for idx in sorted_ids if angle_dist(bin_1, idx) >= 30]
+ bin_2 = remained_ids[0]
+ if bin_1 < bin_2:
+ bin_1, bin_2 = bin_2, bin_1
+
+ dir_1, dir_2 = bin_1, bin_2
+ # compute the affine parameters, and apply affine transform to the image
+ origin = [127, 127]
+ p1_old = [127 + 100 * np.cos(dir_1 / 180 * np.pi), 127 - 100 * np.sin(dir_1 / 180 * np.pi)]
+ p2_old = [127 + 100 * np.cos(dir_2 / 180 * np.pi), 127 - 100 * np.sin(dir_2 / 180 * np.pi)]
+ pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32)
+ p1_new = [127, 27] # y_axis
+ p2_new = [227, 127] # x_axis
+ pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32)
+
+ M1 = cv2.getAffineTransform(pts1, pts2)
+
+ all_corners = list(annot.keys())
+ all_corners_ = np.array(all_corners)
+ ones = np.ones([all_corners_.shape[0], 1])
+ all_corners_ = np.concatenate([all_corners_, ones], axis=-1)
+ new_corners = np.matmul(M1, all_corners_.T).T
+
+ M = np.concatenate([M1, np.array([[0, 0, 1]])], axis=0)
+
+ x_max = new_corners[:, 0].max()
+ x_min = new_corners[:, 0].min()
+ y_max = new_corners[:, 1].max()
+ y_min = new_corners[:, 1].min()
+
+ side_x = (x_max - x_min) * 0.1
+ side_y = (y_max - y_min) * 0.1
+ right_border = x_max + side_x
+ left_border = x_min - side_x
+ bot_border = y_max + side_y
+ top_border = y_min - side_y
+ pts1 = np.array([[left_border, top_border], [right_border, top_border], [right_border, bot_border]]).astype(
+ np.float32)
+ pts2 = np.array([[5, 5], [250, 5], [250, 250]]).astype(np.float32)
+ M_scale = cv2.getAffineTransform(pts1, pts2)
+
+ M = np.matmul(np.concatenate([M_scale, np.array([[0, 0, 1]])], axis=0), M)
+
+ new_image = cv2.warpAffine(image, M[:2, :], (cols, rows), borderValue=(255, 255, 255))
+ all_corners_ = np.concatenate([all_corners, ones], axis=-1)
+ new_corners = np.matmul(M[:2, :], all_corners_.T).T
+
+ corner_mapping = dict()
+ for idx, corner in enumerate(all_corners):
+ corner_mapping[corner] = new_corners[idx]
+
+ new_annot = dict()
+ for corner, connections in annot.items():
+ new_corner = corner_mapping[corner]
+ tuple_new_corner = tuple(new_corner)
+ new_annot[tuple_new_corner] = list()
+ for to_corner in connections:
+ new_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])
+
+    # return the rectified image, the remapped annotations, and the composed affine matrix
+ return new_image, new_annot, M
+
+
+def angle_dist(a1, a2):
+ if a1 > a2:
+ a1, a2 = a2, a1
+ d1 = a2 - a1
+ d2 = a1 + 180 - a2
+ dist = min(d1, d2)
+ return dist
diff --git a/utils/image_utils.py b/utils/image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..91b78b85867bfa74a063b021b3c54847598ced68
--- /dev/null
+++ b/utils/image_utils.py
@@ -0,0 +1,237 @@
+import math
+from io import BytesIO
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import skimage.draw
+from PIL import Image
+
+
+def get_image_size(filepath):
+ im = Image.open(filepath)
+ return im.size
+
+
+def load_image(image_filepath):
+ image = Image.open(image_filepath)
+ image.load()
+ image_array = np.array(image, dtype=np.uint8)
+ image.close()
+ return image_array
+
+
+def padded_boundingbox(boundingbox, padding):
+ boundingbox_new = np.empty_like(boundingbox)
+ boundingbox_new[0:2] = boundingbox[0:2] + padding
+ boundingbox_new[2:4] = boundingbox[2:4] - padding
+ return boundingbox_new
+
+
+def center_bbox(spatial_shape, output_shape):
+ """
+ Return a bbox centered in spatial_shape with size output_shape
+
+ :param spatial_shape:
+ :param output_shape:
+ :return:
+ """
+ center = (spatial_shape[0] / 2, spatial_shape[1] / 2)
+ half_output_shape = (output_shape[0] / 2, output_shape[1] / 2)
+ bbox = [center[0] - half_output_shape[0], center[1] - half_output_shape[1], center[0] + half_output_shape[0], center[1] + half_output_shape[1]]
+ bbox = bbox_to_int(bbox)
+ return bbox
+
+
+def bbox_add_margin(bbox, margin):
+ bbox_new = bbox.copy()
+ bbox_new[0:2] -= margin
+ bbox_new[2:4] += margin
+ return bbox_new
+
+
+def bbox_to_int(bbox):
+ bbox_new = [
+ int(np.floor(bbox[0])),
+ int(np.floor(bbox[1])),
+ int(np.ceil(bbox[2])),
+ int(np.ceil(bbox[3])),
+ ]
+ return bbox_new
+
+
+def draw_line_aa_in_patch(edge, patch_bounds):
+ rr, cc, prob = skimage.draw.line_aa(edge[0][0], edge[0][1], edge[1][0], edge[1][1])
+ keep_mask = (patch_bounds[0] <= rr) & (rr < patch_bounds[2]) \
+ & (patch_bounds[1] <= cc) & (cc < patch_bounds[3])
+ rr = rr[keep_mask]
+ cc = cc[keep_mask]
+ prob = prob[keep_mask]
+ return rr, cc, prob
+
+
+def convert_array_to_jpg_bytes(image_array, mode=None):
+ img = Image.fromarray(image_array, mode=mode)
+ output = BytesIO()
+ img.save(output, format="JPEG", quality=90)
+ contents = output.getvalue()
+ output.close()
+ return contents
+
+
+def displacement_map_to_transformation_maps(disp_field_map):
+ disp_field_map = disp_field_map.astype(np.float32)
+ i = np.arange(disp_field_map.shape[0], dtype=np.float32)
+ j = np.arange(disp_field_map.shape[1], dtype=np.float32)
+ iv, jv = np.meshgrid(i, j, indexing="ij")
+ reverse_map_i = iv + disp_field_map[:, :, 1]
+ reverse_map_j = jv + disp_field_map[:, :, 0]
+ return reverse_map_i, reverse_map_j
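+
+# Illustrative sketch (not part of the original module): a zero displacement field yields
+# identity maps, so cv2.remap() would reproduce the input image unchanged.
+def _zero_displacement_example():
+    disp = np.zeros((4, 6, 2), dtype=np.float32)
+    map_i, map_j = displacement_map_to_transformation_maps(disp)
+    assert np.array_equal(map_i[:, 0], np.arange(4, dtype=np.float32))
+    assert np.array_equal(map_j[0, :], np.arange(6, dtype=np.float32))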
+
+def apply_displacement_field_to_image(image, disp_field_map):
+ trans_map_i, trans_map_j = displacement_map_to_transformation_maps(disp_field_map)
+ misaligned_image = cv2.remap(image, trans_map_j, trans_map_i, cv2.INTER_CUBIC)
+ return misaligned_image
+
+
+def apply_displacement_fields_to_image(image, disp_field_maps):
+ disp_field_map_count = disp_field_maps.shape[0]
+ misaligned_image_list = []
+ for i in range(disp_field_map_count):
+ misaligned_image = apply_displacement_field_to_image(image, disp_field_maps[i, :, :, :])
+ misaligned_image_list.append(misaligned_image)
+ return misaligned_image_list
+
+
+def get_axis_patch_count(length, stride, patch_res):
+ total_double_padding = patch_res - stride
+ patch_count = max(1, int(math.ceil((length - total_double_padding) / stride)))
+ return patch_count
+
+
+def compute_patch_boundingboxes(image_size, stride, patch_res):
+ """
+
+ @param image_size:
+ @param stride:
+ @param patch_res:
+ @return: [[row_start, col_start, row_end, col_end], ...]
+ """
+ im_rows = image_size[0]
+ im_cols = image_size[1]
+
+ row_patch_count = get_axis_patch_count(im_rows, stride, patch_res)
+ col_patch_count = get_axis_patch_count(im_cols, stride, patch_res)
+
+ patch_boundingboxes = []
+ for i in range(0, row_patch_count):
+ if i < row_patch_count - 1:
+ row_slice_begin = i * stride
+ row_slice_end = row_slice_begin + patch_res
+ else:
+ row_slice_end = im_rows
+ row_slice_begin = row_slice_end - patch_res
+ for j in range(0, col_patch_count):
+ if j < col_patch_count - 1:
+ col_slice_begin = j*stride
+ col_slice_end = col_slice_begin + patch_res
+ else:
+ col_slice_end = im_cols
+ col_slice_begin = col_slice_end - patch_res
+
+ patch_boundingbox = np.array([row_slice_begin, col_slice_begin, row_slice_end, col_slice_end], dtype=np.int32)
+ assert row_slice_end - row_slice_begin == col_slice_end - col_slice_begin == patch_res, "ERROR: patch does not have the requested shape"
+ patch_boundingboxes.append(patch_boundingbox)
+
+ return patch_boundingboxes
+
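+# Illustrative sketch (not part of the original module): with stride == patch_res the
+# image is tiled without overlap except for the last row/column, which is shifted back
+# so every patch keeps the requested resolution.
+def _patch_boundingboxes_example():
+    boxes = compute_patch_boundingboxes((5, 10), stride=4, patch_res=4)
+    # 2 x 3 patches; the last column starts at col 6 instead of 8 to stay inside the image
+    assert len(boxes) == 6
+    assert list(boxes[-1]) == [1, 6, 5, 10]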
+
+def clip_boundingbox(boundingbox, clip_list):
+ assert len(boundingbox) == len(clip_list), "len(boundingbox) should be equal to len(clip_values)"
+ clipped_boundingbox = []
+ for bb_value, clip in zip(boundingbox[:2], clip_list[:2]):
+ clipped_value = max(clip, bb_value)
+ clipped_boundingbox.append(clipped_value)
+ for bb_value, clip in zip(boundingbox[2:], clip_list[2:]):
+ clipped_value = min(clip, bb_value)
+ clipped_boundingbox.append(clipped_value)
+ return clipped_boundingbox
+
+
+def crop_or_pad_image_with_boundingbox(image, patch_boundingbox):
+ im_rows = image.shape[0]
+ im_cols = image.shape[1]
+
+ row_padding_before = max(0, - patch_boundingbox[0])
+ col_padding_before = max(0, - patch_boundingbox[1])
+ row_padding_after = max(0, patch_boundingbox[2] - im_rows)
+ col_padding_after = max(0, patch_boundingbox[3] - im_cols)
+
+ # Center padding:
+ row_padding = row_padding_before + row_padding_after
+ col_padding = col_padding_before + col_padding_after
+ row_padding_before = row_padding // 2
+ col_padding_before = col_padding // 2
+ row_padding_after = row_padding - row_padding // 2
+ col_padding_after = col_padding - col_padding // 2
+
+ clipped_patch_boundingbox = clip_boundingbox(patch_boundingbox, [0, 0, im_rows, im_cols])
+
+ if len(image.shape) == 2:
+ patch = image[clipped_patch_boundingbox[0]:clipped_patch_boundingbox[2], clipped_patch_boundingbox[1]:clipped_patch_boundingbox[3]]
+ patch = np.pad(patch, [(row_padding_before, row_padding_after), (col_padding_before, col_padding_after)], mode="constant")
+ elif len(image.shape) == 3:
+ patch = image[clipped_patch_boundingbox[0]:clipped_patch_boundingbox[2], clipped_patch_boundingbox[1]:clipped_patch_boundingbox[3], :]
+ patch = np.pad(patch, [(row_padding_before, row_padding_after), (col_padding_before, col_padding_after), (0, 0)], mode="constant")
+ else:
+ print("Image input does not have the right shape/")
+ patch = None
+ return patch
+
+
+def make_grid(images, padding=2, pad_value=0, return_offsets=False):
+ nmaps = images.shape[0]
+ ymaps = int(math.floor(math.sqrt(nmaps)))
+ xmaps = nmaps // ymaps
+ height, width = int(images.shape[1] + padding), int(images.shape[2] + padding)
+ grid = np.zeros((height * ymaps + padding, width * xmaps + padding, images.shape[3])) + pad_value
+ k = 0
+ offsets = []
+ for y in range(ymaps):
+ for x in range(xmaps):
+ if k >= nmaps:
+ break
+ x_offset = x * width + padding
+ y_offset = y * height + padding
+ grid[y * height + padding:(y+1) * height, x * width + padding:(x+1) * width, :] = images[k]
+ offsets.append((x_offset, y_offset))
+ k = k + 1
+ if return_offsets:
+ return grid, offsets
+ else:
+ return grid
+
+
+if __name__ == "__main__":
+ im_rows = 5
+ im_cols = 10
+ stride = 1
+ patch_res = 15
+
+ image = np.random.randint(0, 256, size=(im_rows, im_cols, 3), dtype=np.uint8)
+ image = Image.fromarray(image)
+ image = np.array(image)
+ plt.ion()
+ plt.figure(1)
+ plt.imshow(image)
+ plt.show()
+
+ # Cut patches
+ patch_boundingboxes = compute_patch_boundingboxes(image.shape[0:2], stride, patch_res)
+
+ plt.figure(2)
+
+ for patch_boundingbox in patch_boundingboxes:
+ patch = crop_or_pad_image_with_boundingbox(image, patch_boundingbox)
+ plt.imshow(patch)
+ plt.show()
+ input("Press to finish...")
diff --git a/utils/misc.py b/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d695d362e9c33dfdff71b3cb196e34af1e53c6b5
--- /dev/null
+++ b/utils/misc.py
@@ -0,0 +1,198 @@
+import torch
+import time
+from collections import defaultdict, deque
+import datetime
+from typing import Optional, List
+from torch import Tensor
+
+
+@torch.no_grad()
+def accuracy(output, target, topk=(1,)):
+ """Computes the precision@k for the specified values of k"""
+ if target.numel() == 0:
+ return [torch.zeros([], device=output.device)]
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=100, fmt=None):
+ if fmt is None:
+ fmt = "{median:.3f} ({global_avg:.3f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ # return self.deque[-1]  # note: the window average is reported instead of the latest raw value
+ return self.avg
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ # this SmoothedValue has no distributed sync method, so guard the call to keep single-process runs working
+ for meter in self.meters.values():
+ if hasattr(meter, 'synchronize_between_processes'):
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None, length_total=None):
+ i = 0
+ if length_total is None:
+ length_total = len(iterable)
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(length_total))) + 'd'
+ if torch.cuda.is_available():
+ log_msg = self.delimiter.join([
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}',
+ 'max mem: {memory:.0f}'
+ ])
+ else:
+ log_msg = self.delimiter.join([
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ])
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == length_total - 1:
+ eta_seconds = iter_time.global_avg * (length_total - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, length_total, eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, length_total, eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / length_total))
+
+
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device, non_blocking=False):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device, non_blocking=non_blocking)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device, non_blocking=non_blocking)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def record_stream(self, *args, **kwargs):
+ self.tensors.record_stream(*args, **kwargs)
+ if self.mask is not None:
+ self.mask.record_stream(*args, **kwargs)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
diff --git a/utils/nn_utils.py b/utils/nn_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa0fe83cb3b6849d170609f21f4607ffacc217e0
--- /dev/null
+++ b/utils/nn_utils.py
@@ -0,0 +1,46 @@
+import torch
+import math
+
+
+def positional_encoding_2d(d_model, height, width):
+ """
+ :param d_model: dimension of the model
+ :param height: height of the positions
+ :param width: width of the positions
+ :return: d_model*height*width position matrix
+ """
+ if d_model % 4 != 0:
+ raise ValueError("Cannot use 2D sin/cos positional encoding with "
+ "a dimension that is not a multiple of 4 (got dim={:d})".format(d_model))
+ pe = torch.zeros(d_model, height, width)
+ # Each dimension use half of d_model
+ d_model = int(d_model / 2)
+ div_term = torch.exp(torch.arange(0., d_model, 2) *
+ -(math.log(10000.0) / d_model))
+ pos_w = torch.arange(0., width).unsqueeze(1)
+ pos_h = torch.arange(0., height).unsqueeze(1)
+ pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
+ pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
+ pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
+ pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
+
+ return pe
+
+
+def positional_encoding_1d(d_model, length):
+ """
+ :param d_model: dimension of the model
+ :param length: length of positions
+ :return: length*d_model position matrix
+ """
+ if d_model % 2 != 0:
+ raise ValueError("Cannot use sin/cos positional encoding with "
+ "odd dim (got dim={:d})".format(d_model))
+ pe = torch.zeros(length, d_model)
+ position = torch.arange(0, length).unsqueeze(1)
+ div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
+ -(math.log(10000.0) / d_model)))
+ pe[:, 0::2] = torch.sin(position.float() * div_term)
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
+
+ return pe
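
The two helpers in utils/nn_utils.py build fixed sinusoidal position tables for 1D sequences and 2D feature maps. A quick shape check (not part of the repo) illustrating what each returns:

```python
from utils.nn_utils import positional_encoding_1d, positional_encoding_2d

pe_seq = positional_encoding_1d(d_model=256, length=100)           # (length, d_model)
pe_map = positional_encoding_2d(d_model=256, height=64, width=64)  # (d_model, height, width)
print(pe_seq.shape)  # torch.Size([100, 256])
print(pe_map.shape)  # torch.Size([256, 64, 64])
```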