Egrt committed
Commit 424188c · 0 parents (initial commit)
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitignore +18 -0
  2. HEAT.py +460 -0
  3. LICENSE +674 -0
  4. README.md +21 -0
  5. app.py +33 -0
  6. arguments.py +33 -0
  7. assets/img/pipeline.png +0 -0
  8. assets/img/problem_description.png +0 -0
  9. datasets/__init__.py +0 -0
  10. datasets/corners.py +183 -0
  11. datasets/data_utils.py +57 -0
  12. datasets/outdoor_buildings.py +183 -0
  13. datasets/s3d_floorplans.py +187 -0
  14. images/test.jpg +0 -0
  15. infer.py +455 -0
  16. metrics/get_metric.py +219 -0
  17. metrics/new_utils.py +2100 -0
  18. models/__init__.py +0 -0
  19. models/corner_models.py +275 -0
  20. models/corner_to_edge.py +232 -0
  21. models/deformable_transformer.py +236 -0
  22. models/edge_models.py +314 -0
  23. models/loss.py +63 -0
  24. models/mlp.py +21 -0
  25. models/ops/functions/__init__.py +10 -0
  26. models/ops/functions/ms_deform_attn_func.py +61 -0
  27. models/ops/make.sh +10 -0
  28. models/ops/modules/__init__.py +9 -0
  29. models/ops/modules/ms_deform_attn.py +115 -0
  30. models/ops/setup.py +71 -0
  31. models/ops/src/cpu/ms_deform_attn_cpu.cpp +41 -0
  32. models/ops/src/cpu/ms_deform_attn_cpu.h +33 -0
  33. models/ops/src/cuda/ms_deform_attn_cuda.cu +153 -0
  34. models/ops/src/cuda/ms_deform_attn_cuda.h +30 -0
  35. models/ops/src/cuda/ms_deform_im2col_cuda.cuh +1327 -0
  36. models/ops/src/ms_deform_attn.h +62 -0
  37. models/ops/src/vision.cpp +16 -0
  38. models/ops/test.py +89 -0
  39. models/resnet.py +167 -0
  40. models/stacked_hg.py +246 -0
  41. predict.py +33 -0
  42. qualitative_outdoor/generate_html.py +64 -0
  43. qualitative_outdoor/plot_utils.py +43 -0
  44. qualitative_outdoor/visualize_gt.py +46 -0
  45. qualitative_outdoor/visualize_npy.py +46 -0
  46. requirements.txt +27 -0
  47. s3d_floorplan_eval/DataRW/DataRW.py +4 -0
  48. s3d_floorplan_eval/DataRW/S3DRW.py +142 -0
  49. s3d_floorplan_eval/DataRW/wrong_annotatios.py +1 -0
  50. s3d_floorplan_eval/Evaluator/Evaluator.py +457 -0
.gitignore ADDED
@@ -0,0 +1,18 @@
+ ./data
+ .DS_Store
+ viz/*
+ *.tar
+ *.pdf
+ *.zip
+ svg_*
+ *.html
+ models/ops/build
+ models/ops/dist
+ models/ops/*egg-info
+ __pycache__
+ results
+ montefloor_data
+ .idea
+ model_data/checkpoints
+ model_data/heat_checkpoints
+ shpfile/
HEAT.py ADDED
@@ -0,0 +1,460 @@
+ '''
+ Author: [egrt]
+ Date: 2022-08-23 11:44:15
+ LastEditors: Egrt
+ LastEditTime: 2022-11-23 15:25:35
+ Description: HEAT model loading and prediction
+ '''
+ import torch
+ import torch.nn as nn
+ from models.resnet import ResNetBackbone
+ from models.corner_models import HeatCorner
+ from models.edge_models import HeatEdge
+ from models.corner_to_edge import get_infer_edge_pairs
+ from datasets.data_utils import get_pixel_features
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+ from utils import image_utils
+ from osgeo import gdal, ogr, osr
+ from tqdm import tqdm
+ import os
+ import scipy.ndimage
+ import numpy as np
+ import cv2
+ import skimage
+ 
+ class HEAT(object):
+     #-----------------------------------------#
+     # Note: remember to update model_path
+     #-----------------------------------------#
+     _defaults = {
+         #-----------------------------------------------#
+         # model_data points to the checkpoint of the full network
+         #-----------------------------------------------#
+         "model_data"               : 'model_data/heat_checkpoints/checkpoints/ckpts_heat_outdoor_256/checkpoint.pth',
+         #-----------------------------------------------#
+         # image_size is the pixel size of images fed to the model
+         #-----------------------------------------------#
+         "image_size"               : [256, 256],
+         #-----------------------------------------------#
+         # patch_size is the side length of each image patch
+         #-----------------------------------------------#
+         "patch_size"               : 512,
+         #-----------------------------------------------#
+         # patch_overlap is the overlap between patches, in pixels
+         #-----------------------------------------------#
+         "patch_overlap"            : 0,
+         #-----------------------------------------------#
+         # corner_thresh is the confidence threshold for predicted corners
+         #-----------------------------------------------#
+         "corner_thresh"            : 0.01,
+         #-----------------------------------------------#
+         # maximum number of candidate edges per corner (must not exceed 6)
+         #-----------------------------------------------#
+         "corner_to_edge_multiplier": 3,
+         #-----------------------------------------------#
+         # number of iterations of edge-inference filtering
+         #-----------------------------------------------#
+         "infer_times"              : 3,
+         #-------------------------------#
+         # Whether to use CUDA
+         # Set to False if no GPU is available
+         #-------------------------------#
+         "cuda"                     : False,
+     }
+ 
+     #---------------------------------------------------#
+     # Initialize HEAT
+     #---------------------------------------------------#
+     def __init__(self, **kwargs):
+         self.__dict__.update(self._defaults)
+         for name, value in kwargs.items():
+             setattr(self, name, value)
+         self.generate()
+ 
+     def generate(self):
+         # Download the full network checkpoint from the Hugging Face Hub
+         filepath = hf_hub_download(repo_id="Egrt/HEAT", filename="checkpoint.pth")
+         self.model = torch.load(filepath, map_location='cuda' if self.cuda else 'cpu')
+         # Load the backbone
+         self.backbone = ResNetBackbone()
+         strides = self.backbone.strides
+         num_channels = self.backbone.num_channels
+         self.backbone = nn.DataParallel(self.backbone)
+         if self.cuda:
+             self.backbone = self.backbone.cuda()
+         self.backbone.eval()
+         # Load the corner detection model
+         self.corner_model = HeatCorner(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
+                                        backbone_num_channels=num_channels)
+         self.corner_model = nn.DataParallel(self.corner_model)
+         if self.cuda:
+             self.corner_model = self.corner_model.cuda()
+         self.corner_model.eval()
+         # Load the edge detection model
+         self.edge_model = HeatEdge(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
+                                    backbone_num_channels=num_channels)
+         self.edge_model = nn.DataParallel(self.edge_model)
+         if self.cuda:
+             self.edge_model = self.edge_model.cuda()
+         self.edge_model.eval()
+         # Load the weights of each sub-model
+         self.backbone.load_state_dict(self.model['backbone'])
+         self.corner_model.load_state_dict(self.model['corner_model'])
+         self.edge_model.load_state_dict(self.model['edge_model'])
+ 
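Since `__init__` copies `_defaults` into the instance dictionary before applying keyword arguments, any of the settings above can be overridden per instance. A minimal sketch with illustrative values:

    # Override any default at construction time (values here are illustrative).
    heat = HEAT(cuda=True, corner_thresh=0.1, patch_size=1024)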
+     def detect_one_image(self, image):
+         #---------------------------------------------------------#
+         # Convert the image to RGB here so a grayscale image does
+         # not crash the prediction. Only RGB input is supported;
+         # every other image type is converted to RGB.
+         #---------------------------------------------------------#
+         image = cvtColor(image)
+         # Decide whether the image needs to be split into patches
+         if image.size[0] < self.patch_size or image.size[1] < self.patch_size:
+             is_slice = False
+         else:
+             is_slice = True
+         if is_slice:
+             # Convert the original image to a numpy array
+             image = np.array(image, dtype=np.uint8)
+             # Keep a copy of the input for visualization
+             viz_image = image.copy()
+             height, width = image.shape[0], image.shape[1]
+             # Scale factor between the patch size and the model input size
+             scale = self.patch_size / self.image_size[0]
+             # Initialize the corner and edge containers
+             pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np = [], [], [], [], []
+             # Start slicing
+             stride = self.patch_size - self.patch_overlap
+             patch_boundingboxes = image_utils.compute_patch_boundingboxes((height, width),
+                                                                           stride=stride,
+                                                                           patch_res=self.patch_size)
+             edge_len = 0
+             # Predict on every patch
+             for bbox in tqdm(patch_boundingboxes, desc="Predicting with patches", leave=False):
+                 # Crop the patch
+                 crop_image = image[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
+                 # Convert the numpy array back to a PIL Image
+                 crop_image = Image.fromarray(crop_image)
+                 try:
+                     pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, _ = self.predict_no_patching(crop_image)
+                 except RuntimeError as e:
+                     print("ERROR: " + str(e))
+                     print("INFO: reduce patch_size until it fits in memory")
+                     raise e
+                 # Map patch-local corner coordinates back into the full image
+                 pred_corners[:, 0] = pred_corners[:, 0] * scale + bbox[0]
+                 pred_corners[:, 1] = pred_corners[:, 1] * scale + bbox[1]
+                 pred_corners_viz = pred_corners
+                 viz_image = visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, edges=pos_edges,
+                                                       edge_confs=edge_confs, shpfile=False)
+ 
+             hr_image = Image.fromarray(np.uint8(viz_image))
+         else:
+             pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, viz_image = self.predict_no_patching(image)
+             #---------------------------------------------------------#
+             # Inference ends here; draw the corners and edges onto the
+             # original image from the predicted corner coordinates.
+             #---------------------------------------------------------#
+             pred_corners_viz = pred_corners
+             image_result = visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, edges=pos_edges,
+                                                      edge_confs=edge_confs, shpfile=True)
+             hr_image = Image.fromarray(np.uint8(image_result))
+         return hr_image
+ 
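To make the patch-to-image coordinate mapping above concrete: with the defaults patch_size=512 and image_size=[256, 256], scale is 512 / 256 = 2, so a corner predicted at (100, 50) inside a patch whose bounding box starts at (512, 0) lands at (712, 100) in the full image. A sketch of that arithmetic (the bbox values are hypothetical):

    patch_size, model_input = 512, 256
    scale = patch_size / model_input          # 2.0 with the defaults
    bbox = (512, 0, 1024, 512)                # hypothetical patch (x0, y0, x1, y1)
    corner_in_patch = (100, 50)               # predicted at model resolution
    x = corner_in_patch[0] * scale + bbox[0]  # 712.0
    y = corner_in_patch[1] * scale + bbox[1]  # 100.0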
+     #---------------------------------------------------------#
+     # Predict an image without patching.
+     # Returns the predicted corner coordinates and edges.
+     #---------------------------------------------------------#
+     def predict_no_patching(self, image):
+         image = image.resize(tuple(self.image_size), Image.BICUBIC)
+         # Convert the PIL Image to numpy
+         image = np.array(image, dtype=np.uint8)
+         # Keep a copy of the input for visualization
+         viz_image = image.copy()
+         # preprocess image numpy->tensor
+         image = process_image(image)
+         # Positional encodings for every pixel; the default image size is 256
+         pixels, pixel_features = get_pixel_features(image_size=self.image_size[0])
+         # Run the models
+         with torch.no_grad():
+             image_feats, feat_mask, all_image_feats = self.backbone(image)
+             pixel_features = pixel_features.unsqueeze(0).repeat(image.shape[0], 1, 1, 1)
+             preds_s1 = self.corner_model(image_feats, feat_mask, pixel_features, pixels, all_image_feats)
+ 
+             c_outputs = preds_s1
+             # Predicted corner heatmap
+             c_outputs_np = c_outputs[0].detach().cpu().numpy()
+             # Keep the coordinates of corners scoring above the threshold
+             pos_indices = np.where(c_outputs_np >= self.corner_thresh)
+             pred_corners = pixels[pos_indices]
+             # Confidence of each predicted corner
+             pred_confs = c_outputs_np[pos_indices]
+             # Non-maximum suppression based on corner confidence
+             pred_corners, pred_confs = corner_nms(pred_corners, pred_confs, image_size=c_outputs.shape[1])
+             # Enumerate all pairs of corners to build the candidate edges
+             pred_corners, pred_confs, edge_coords, edge_mask, edge_ids = get_infer_edge_pairs(pred_corners, pred_confs)
+             # Number of corners
+             corner_nums = torch.tensor([len(pred_corners)]).to(image.device)
+             max_candidates = torch.stack([corner_nums.max() * self.corner_to_edge_multiplier] * len(corner_nums), dim=0)
+             # Unordered set of unique positive edge ids
+             all_pos_ids = set()
+             # Confidence of each accepted edge
+             all_edge_confs = dict()
+             # Iterate the edge inference (3 times by default)
+             for tt in range(self.infer_times):
+                 if tt == 0:
+                     # gt_values has the same shape as the edge mask, initialized to 0
+                     gt_values = torch.zeros_like(edge_mask).long()
+                     # then every entry is set to 2 (undecided)
+                     gt_values[:, :] = 2
+ 
+                 # Predict the edges
+                 s1_logits, s2_logits_hb, s2_logits_rel, selected_ids, s2_mask, s2_gt_values = self.edge_model(image_feats,
+                     feat_mask, pixel_features, edge_coords, edge_mask, gt_values, corner_nums, max_candidates, True)
+                 num_total = s1_logits.shape[2]
+                 num_selected = selected_ids.shape[1]
+                 num_filtered = num_total - num_selected
+                 # Turn the logits into probability distributions over (0, 1)
+                 s1_preds = s1_logits.squeeze().softmax(0)
+                 s2_preds_rel = s2_logits_rel.squeeze().softmax(0)
+                 s2_preds_hb = s2_logits_hb.squeeze().softmax(0)
+                 s1_preds_np = s1_preds[1, :].detach().cpu().numpy()
+                 s2_preds_rel_np = s2_preds_rel[1, :].detach().cpu().numpy()
+                 s2_preds_hb_np = s2_preds_hb[1, :].detach().cpu().numpy()
+ 
+                 selected_ids = selected_ids.squeeze().detach().cpu().numpy()
+                 # Filter: scores in (0.9, 1) become True, (0.01, 0.9) stay Undecided, (0, 0.01) become False
+                 if tt != self.infer_times - 1:
+                     s2_preds_np = s2_preds_hb_np
+ 
+                     pos_edge_ids = np.where(s2_preds_np >= 0.9)
+                     neg_edge_ids = np.where(s2_preds_np <= 0.01)
+                     for pos_id in pos_edge_ids[0]:
+                         actual_id = selected_ids[pos_id]
+                         if gt_values[0, actual_id] != 2:
+                             continue
+                         all_pos_ids.add(actual_id)
+                         all_edge_confs[actual_id] = s2_preds_np[pos_id]
+                         gt_values[0, actual_id] = 1
+                     for neg_id in neg_edge_ids[0]:
+                         actual_id = selected_ids[neg_id]
+                         if gt_values[0, actual_id] != 2:
+                             continue
+                         gt_values[0, actual_id] = 0
+                     num_to_pred = (gt_values == 2).sum()
+                     if num_to_pred <= num_filtered:
+                         break
+                 else:
+                     s2_preds_np = s2_preds_hb_np
+ 
+                     pos_edge_ids = np.where(s2_preds_np >= 0.5)
+                     for pos_id in pos_edge_ids[0]:
+                         actual_id = selected_ids[pos_id]
+                         if bool(s2_mask[0][pos_id]) or gt_values[0, actual_id] != 2:
+                             continue
+                         all_pos_ids.add(actual_id)
+                         all_edge_confs[actual_id] = s2_preds_np[pos_id]
+             pos_edge_ids = list(all_pos_ids)
+             edge_confs = [all_edge_confs[idx] for idx in pos_edge_ids]
+             pos_edges = edge_ids[pos_edge_ids].cpu().numpy()
+             edge_confs = np.array(edge_confs)
+ 
+         if self.image_size[0] != 256:
+             pred_corners = pred_corners / (self.image_size[0] / 256)
+ 
+         return pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, viz_image
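The inference loop above is a three-state labelling scheme: every candidate edge starts as undecided (2); on each pass, edges scored above 0.9 are frozen as positive (1) and edges below 0.01 as negative (0), and only still-undecided edges are re-scored, with a relaxed 0.5 threshold on the final pass. A toy numpy sketch of the thresholding rule (the scores are made up):

    import numpy as np

    scores = np.array([0.95, 0.5, 0.005, 0.92])  # made-up per-edge probabilities
    labels = np.full(scores.shape, 2)            # 2 = undecided
    labels[scores >= 0.9] = 1                    # confident positives are frozen
    labels[scores <= 0.01] = 0                   # confident negatives are frozen
    # labels -> [1, 2, 0, 1]; only index 1 is re-scored on the next pass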
+ #---------------------------------------------------------#
+ # Convert the image to RGB so a grayscale image does not
+ # crash the prediction. Only RGB input is supported; every
+ # other image type is converted to RGB.
+ #---------------------------------------------------------#
+ def cvtColor(image):
+     if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
+         return image
+     else:
+         image = image.convert('RGB')
+         return image
+ 
+ #---------------------------------------------------------#
+ # Rank the corners by confidence and keep only those that
+ # are local confidence maxima (non-maximum suppression)
+ #---------------------------------------------------------#
+ def corner_nms(preds, confs, image_size):
+     data = np.zeros([image_size, image_size])
+     neighborhood_size = 5
+     threshold = 0
+ 
+     for i in range(len(preds)):
+         data[preds[i, 1], preds[i, 0]] = confs[i]
+ 
+     data_max = scipy.ndimage.maximum_filter(data, neighborhood_size)
+     maxima = (data == data_max)
+     data_min = scipy.ndimage.minimum_filter(data, neighborhood_size)
+     diff = ((data_max - data_min) > threshold)
+     maxima[diff == 0] = 0
+ 
+     results = np.where(maxima > 0)
+     filtered_preds = np.stack([results[1], results[0]], axis=-1)
+ 
+     new_confs = list()
+     for i, pred in enumerate(filtered_preds):
+         new_confs.append(data[pred[1], pred[0]])
+     new_confs = np.array(new_confs)
+ 
+     return filtered_preds, new_confs
+ 
+ def process_image(img):
+     mean = [0.485, 0.456, 0.406]
+     std = [0.229, 0.224, 0.225]
+     img = skimage.img_as_float(img)
+     img = img.transpose((2, 0, 1))
+     img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
+     img = torch.Tensor(img)
+     if torch.cuda.is_available():
+         img = img.cuda()
+     img = img.unsqueeze(0)
+     return img
+ 
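corner_nms above rasterizes the corner confidences onto an image-sized grid and keeps only the corners that are local maxima within a 5 x 5 neighborhood. A toy call with made-up inputs:

    import numpy as np

    # Two corners 2 px apart: only the higher-confidence one survives NMS.
    preds = np.array([[10, 10], [12, 10]])  # (x, y) integer coordinates
    confs = np.array([0.8, 0.6])
    kept, kept_confs = corner_nms(preds, confs, image_size=256)
    # kept -> [[10, 10]], kept_confs -> [0.8]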
+ def postprocess_preds(corners, confs, edges):
+     corner_degrees = dict()
+     for edge_i, edge_pair in enumerate(edges):
+         corner_degrees[edge_pair[0]] = corner_degrees.setdefault(edge_pair[0], 0) + 1
+         corner_degrees[edge_pair[1]] = corner_degrees.setdefault(edge_pair[1], 0) + 1
+     good_ids = [i for i in range(len(corners)) if i in corner_degrees]
+     if len(good_ids) == len(corners):
+         return corners, confs, edges
+     else:
+         good_corners = corners[good_ids]
+         good_confs = confs[good_ids]
+         id_mapping = {value: idx for idx, value in enumerate(good_ids)}
+         new_edges = list()
+         for edge_pair in edges:
+             new_pair = (id_mapping[edge_pair[0]], id_mapping[edge_pair[1]])
+             new_edges.append(new_pair)
+         new_edges = np.array(new_edges)
+         return good_corners, good_confs, new_edges
+ 
+ #---------------------------------------------------------#
+ # Draw the predicted corners and edges onto the input image.
+ # Unlike the original code, the image object is returned
+ # directly instead of being saved to a fixed path.
+ #---------------------------------------------------------#
+ def visualize_cond_generation(positive_pixels, confs, image, gt_corners=None, prec=None, recall=None,
+                               image_masks=None, edges=None, edge_confs=None, shpfile=False):
+     # Copy the input image
+     image = image.copy()
+     if confs is not None:
+         viz_confs = confs
+ 
+     if edges is not None:
+         preds = positive_pixels.astype(int)
+         c_degrees = dict()
+         for edge_i, edge_pair in enumerate(edges):
+             conf = (edge_confs[edge_i] * 2) - 1
+             cv2.line(image, tuple(preds[edge_pair[0]]), tuple(preds[edge_pair[1]]), (255 * conf, 255 * conf, 0), 2)
+             c_degrees[edge_pair[0]] = c_degrees.setdefault(edge_pair[0], 0) + 1
+             c_degrees[edge_pair[1]] = c_degrees.setdefault(edge_pair[1], 0) + 1
+ 
+     for idx, c in enumerate(positive_pixels):
+         if edges is not None and idx not in c_degrees:
+             continue
+         if confs is None:
+             cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
+         else:
+             cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255 * viz_confs[idx]), -1)
+         # if edges is not None:
+         #     cv2.putText(image, '{}'.format(c_degrees[idx]), (int(c[0]), int(c[1] - 5)), cv2.FONT_HERSHEY_SIMPLEX,
+         #                 0.5, (255, 0, 0), 1, cv2.LINE_AA)
+ 
+     if gt_corners is not None:
+         for c in gt_corners:
+             cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 255, 0), -1)
+ 
+     if image_masks is not None:
+         mask_ids = np.where(image_masks == 1)[0]
+         for mask_id in mask_ids:
+             y_idx = mask_id // 64
+             x_idx = (mask_id - y_idx * 64)
+             x_coord = x_idx * 4
+             y_coord = y_idx * 4
+             cv2.rectangle(image, (x_coord, y_coord), (x_coord + 3, y_coord + 3), (127, 127, 0), thickness=-1)
+ 
+     # if confs is not None:
+     #     cv2.putText(image, 'max conf: {:.3f}'.format(confs.max()), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
+     #                 0.5, (255, 255, 0), 1, cv2.LINE_AA)
+     if prec is not None:
+         if isinstance(prec, tuple):
+             cv2.putText(image, 'edge p={:.2f}, edge r={:.2f}'.format(prec[0], recall[0]), (20, 20),
+                         cv2.FONT_HERSHEY_SIMPLEX,
+                         0.5, (255, 255, 0), 1, cv2.LINE_AA)
+             cv2.putText(image, 'region p={:.2f}, region r={:.2f}'.format(prec[1], recall[1]), (20, 40),
+                         cv2.FONT_HERSHEY_SIMPLEX,
+                         0.5, (255, 255, 0), 1, cv2.LINE_AA)
+         else:
+             cv2.putText(image, 'prec={:.2f}, recall={:.2f}'.format(prec, recall), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
+                         0.5, (255, 255, 0), 1, cv2.LINE_AA)
+ 
+     # Optionally write a shapefile
+     if shpfile:
+         preds = positive_pixels.astype(int)
+         # Build the list of line segments
+         Polyline = []
+         for edge_i, edge_pair in enumerate(edges):
+             Polyline.append([preds[edge_pair[0]], preds[edge_pair[1]]])
+         Polyline = np.array(Polyline, dtype=np.int32)
+         # Write the shapefile
+         writeShp(save_file_dir="shpfile", Polyline=Polyline)
+ 
+     return image
+ 
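writeShp below serializes all predicted edges into a single MULTILINESTRING geometry in WKT, one `(x0 y0,x1 y1)` pair per segment. For two hypothetical segments the string it builds looks like this:

    # Sketch of the WKT string writeShp assembles (segments are hypothetical).
    segments = [((0, 0), (10, 0)), ((10, 0), (10, 5))]
    wkt = 'MULTILINESTRING({})'.format(
        ','.join('({} {},{} {})'.format(x0, y0, x1, y1)
                 for (x0, y0), (x1, y1) in segments))
    # -> 'MULTILINESTRING((0 0,10 0),(10 0,10 5))'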
+ def writeShp(save_file_dir="shpfile", Polyline=None):
+     # Create the output directory
+     if not os.path.exists(save_file_dir):
+         os.makedirs(save_file_dir)
+     # Support non-ASCII (e.g. Chinese) file paths
+     gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
+     # Support non-ASCII characters in attribute-table fields
+     gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
+     # Register all OGR drivers
+     ogr.RegisterAll()
+     # Create the shapefile
+     strDriverName = "ESRI Shapefile"
+     oDriver = ogr.GetDriverByName(strDriverName)
+     if oDriver is None:
+         return "Driver not available: " + strDriverName
+     # Create the data source
+     file_path = os.path.join(save_file_dir, "result.shp")
+     oDS = oDriver.CreateDataSource(file_path)
+     if oDS is None:
+         return "Failed to create file: result.shp"
+     if Polyline is not None:
+         # Create a layer with the WGS84 coordinate system
+         papszLCO = []
+         geosrs = osr.SpatialReference()
+         geosrs.SetWellKnownGeogCS("WGS84")
+         # Lines:  ogr_type = ogr.wkbLineString
+         # Points: ogr_type = ogr.wkbPoint
+         ogr_type = ogr.wkbMultiLineString
+         # Areas use Polygon, lines use Polyline, points use Point
+         oLayer = oDS.CreateLayer("Polyline", geosrs, ogr_type, papszLCO)
+         if oLayer is None:
+             return "Failed to create the layer!"
+         # Create the attribute table
+         # Create the id field
+         oId = ogr.FieldDefn("id", ogr.OFTInteger)
+         oLayer.CreateField(oId, 1)
+         # Create the name field
+         oName = ogr.FieldDefn("name", ogr.OFTString)
+         oLayer.CreateField(oName, 1)
+         oDefn = oLayer.GetLayerDefn()
+         # Create the features
+         # Dataset rows: wkt_geom; id; name
+         point_str_list = ['({} {},{} {})'.format(row[0, 0], row[0, 1], row[1, 0], row[1, 1]) for row in Polyline]
+         Polyline_Wkt = ','.join(point_str_list)
+         features = ['Polyline0;MULTILINESTRING({})'.format(Polyline_Wkt)]
+         for index, f in enumerate(features):
+             oFeaturePolygon = ogr.Feature(oDefn)
+             oFeaturePolygon.SetField("id", index)
+             oFeaturePolygon.SetField("name", f.split(";")[0])
+             geomPolygon = ogr.CreateGeometryFromWkt(f.split(";")[1])
+             oFeaturePolygon.SetGeometry(geomPolygon)
+             oLayer.CreateFeature(oFeaturePolygon)
+     # Close the data source when done
+     oDS.Destroy()
+     return "Dataset created successfully!"
LICENSE ADDED
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+ software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+ to take away your freedom to share and change the works. By contrast,
+ the GNU General Public License is intended to guarantee your freedom to
+ share and change all versions of a program--to make sure it remains free
+ software for all its users. We, the Free Software Foundation, use the
+ GNU General Public License for most of our software; it applies also to
+ any other work released this way by its authors. You can apply it to
+ your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+ price. Our General Public Licenses are designed to make sure that you
+ have the freedom to distribute copies of free software (and charge for
+ them if you wish), that you receive source code or can get it if you
+ want it, that you can change the software or use pieces of it in new
+ free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+ these rights or asking you to surrender the rights. Therefore, you have
+ certain responsibilities if you distribute copies of the software, or if
+ you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+ gratis or for a fee, you must pass on to the recipients the same
+ freedoms that you received. You must make sure that they, too, receive
+ or can get the source code. And you must show them these terms so they
+ know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+ (1) assert copyright on the software, and (2) offer you this License
+ giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+ that there is no warranty for this free software. For both users' and
+ authors' sake, the GPL requires that modified versions be marked as
+ changed, so that their problems will not be attributed erroneously to
+ authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+ modified versions of the software inside them, although the manufacturer
+ can do so. This is fundamentally incompatible with the aim of
+ protecting users' freedom to change the software. The systematic
+ pattern of such abuse occurs in the area of products for individuals to
+ use, which is precisely where it is most unacceptable. Therefore, we
+ have designed this version of the GPL to prohibit the practice for those
+ products. If such problems arise substantially in other domains, we
+ stand ready to extend this provision to those domains in future versions
+ of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+ States should not allow patents to restrict development and use of
+ software on general-purpose computers, but in those that do, we wish to
+ avoid the special danger that patents applied to a free program could
+ make it effectively proprietary. To prevent this, the GPL assures that
+ patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+ modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+ works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+ License. Each licensee is addressed as "you". "Licensees" and
+ "recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+ in a fashion requiring copyright permission, other than the making of an
+ exact copy. The resulting work is called a "modified version" of the
+ earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+ on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+ permission, would make you directly or secondarily liable for
+ infringement under applicable copyright law, except executing it on a
+ computer or modifying a private copy. Propagation includes copying,
+ distribution (with or without modification), making available to the
+ public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+ parties to make or receive copies. Mere interaction with a user through
+ a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+ to the extent that it includes a convenient and prominently visible
+ feature that (1) displays an appropriate copyright notice, and (2)
+ tells the user that there is no warranty for the work (except to the
+ extent that warranties are provided), that licensees may convey the
+ work under this License, and how to view a copy of this License. If
+ the interface presents a list of user commands or options, such as a
+ menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+ for making modifications to it. "Object code" means any non-source
+ form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+ standard defined by a recognized standards body, or, in the case of
+ interfaces specified for a particular programming language, one that
+ is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+ than the work as a whole, that (a) is included in the normal form of
+ packaging a Major Component, but which is not part of that Major
+ Component, and (b) serves only to enable use of the work with that
+ Major Component, or to implement a Standard Interface for which an
+ implementation is available to the public in source code form. A
+ "Major Component", in this context, means a major essential component
+ (kernel, window system, and so on) of the specific operating system
+ (if any) on which the executable work runs, or a compiler used to
+ produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+ the source code needed to generate, install, and (for an executable
+ work) run the object code and to modify the work, including scripts to
+ control those activities. However, it does not include the work's
+ System Libraries, or general-purpose tools or generally available free
+ programs which are used unmodified in performing those activities but
+ which are not part of the work. For example, Corresponding Source
+ includes interface definition files associated with source files for
+ the work, and the source code for shared libraries and dynamically
+ linked subprograms that the work is specifically designed to require,
+ such as by intimate data communication or control flow between those
+ subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+ can regenerate automatically from other parts of the Corresponding
+ Source.
+
+ The Corresponding Source for a work in source code form is that
+ same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+ copyright on the Program, and are irrevocable provided the stated
+ conditions are met. This License explicitly affirms your unlimited
+ permission to run the unmodified Program. The output from running a
+ covered work is covered by this License only if the output, given its
+ content, constitutes a covered work. This License acknowledges your
+ rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+ convey, without conditions so long as your license otherwise remains
+ in force. You may convey covered works to others for the sole purpose
+ of having them make modifications exclusively for you, or provide you
+ with facilities for running those works, provided that you comply with
+ the terms of this License in conveying all material for which you do
+ not control copyright. Those thus making or running the covered works
+ for you must do so exclusively on your behalf, under your direction
+ and control, on terms that prohibit them from making any copies of
+ your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+ the conditions stated below. Sublicensing is not allowed; section 10
+ makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+ measure under any applicable law fulfilling obligations under article
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
+ similar laws prohibiting or restricting circumvention of such
+ measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+ circumvention of technological measures to the extent such circumvention
+ is effected by exercising rights under this License with respect to
+ the covered work, and you disclaim any intention to limit operation or
+ modification of the work as a means of enforcing, against the work's
+ users, your or third parties' legal rights to forbid circumvention of
+ technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+ receive it, in any medium, provided that you conspicuously and
+ appropriately publish on each copy an appropriate copyright notice;
+ keep intact all notices stating that this License and any
+ non-permissive terms added in accord with section 7 apply to the code;
+ keep intact all notices of the absence of any warranty; and give all
+ recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+ and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+ produce it from the Program, in the form of source code under the
+ terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+ works, which are not by their nature extensions of the covered work,
+ and which are not combined with it such as to form a larger program,
+ in or on a volume of a storage or distribution medium, is called an
+ "aggregate" if the compilation and its resulting copyright are not
+ used to limit the access or legal rights of the compilation's users
+ beyond what the individual works permit. Inclusion of a covered work
+ in an aggregate does not cause this License to apply to the other
+ parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+ of sections 4 and 5, provided that you also convey the
+ machine-readable Corresponding Source under the terms of this License,
+ in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+ from the Corresponding Source as a System Library, need not be
+ included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+ tangible personal property which is normally used for personal, family,
+ or household purposes, or (2) anything designed or sold for incorporation
+ into a dwelling. In determining whether a product is a consumer product,
+ doubtful cases shall be resolved in favor of coverage. For a particular
+ product received by a particular user, "normally used" refers to a
+ typical or common use of that class of product, regardless of the status
+ of the particular user or of the way in which the particular user
+ actually uses, or expects or is expected to use, the product. A product
+ is a consumer product regardless of whether the product has substantial
+ commercial, industrial or non-consumer uses, unless such uses represent
+ the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+ procedures, authorization keys, or other information required to install
+ and execute modified versions of a covered work in that User Product from
+ a modified version of its Corresponding Source. The information must
+ suffice to ensure that the continued functioning of the modified object
+ code is in no case prevented or interfered with solely because
+ modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+ specifically for use in, a User Product, and the conveying occurs as
+ part of a transaction in which the right of possession and use of the
+ User Product is transferred to the recipient in perpetuity or for a
+ fixed term (regardless of how the transaction is characterized), the
+ Corresponding Source conveyed under this section must be accompanied
+ by the Installation Information. But this requirement does not apply
+ if neither you nor any third party retains the ability to install
+ modified object code on the User Product (for example, the work has
+ been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+ requirement to continue to provide support service, warranty, or updates
+ for a work that has been modified or installed by the recipient, or for
+ the User Product in which it has been modified or installed. Access to a
+ network may be denied when the modification itself materially and
+ adversely affects the operation of the network or violates the rules and
+ protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+ in accord with this section must be in a format that is publicly
+ documented (and with an implementation available to the public in
+ source code form), and must require no special password or key for
+ unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+ License by making exceptions from one or more of its conditions.
+ Additional permissions that are applicable to the entire Program shall
+ be treated as though they were included in this License, to the extent
+ that they are valid under applicable law. If additional permissions
+ apply only to part of the Program, that part may be used separately
+ under those permissions, but the entire Program remains governed by
+ this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+ remove any additional permissions from that copy, or from any part of
+ it. (Additional permissions may be written to require their own
+ removal in certain cases when you modify the work.) You may place
+ additional permissions on material, added by you to a covered work,
+ for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+ add to a covered work, you may (if authorized by the copyright holders of
+ that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+ restrictions" within the meaning of section 10. If the Program as you
+ received it, or any part of it, contains a notice stating that it is
+ governed by this License along with a term that is a further
+ restriction, you may remove that term. If a license document contains
+ a further restriction but permits relicensing or conveying under this
+ License, you may add to a covered work material governed by the terms
+ of that license document, provided that the further restriction does
+ not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+ must place, in the relevant source files, a statement of the
+ additional terms that apply to those files, or a notice indicating
+ where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+ form of a separately written license, or stated as exceptions;
+ the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+ provided under this License. Any attempt otherwise to propagate or
+ modify it is void, and will automatically terminate your rights under
+ this License (including any patent licenses granted under the third
+ paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+ license from a particular copyright holder is reinstated (a)
+ provisionally, unless and until the copyright holder explicitly and
+ finally terminates your license, and (b) permanently, if the copyright
+ holder fails to notify you of the violation by some reasonable means
+ prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+ reinstated permanently if the copyright holder notifies you of the
+ violation by some reasonable means, this is the first time you have
+ received notice of violation of this License (for any work) from that
+ copyright holder, and you cure the violation prior to 30 days after
+ your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+ licenses of parties who have received copies or rights from you under
+ this License. If your rights have been terminated and not permanently
+ reinstated, you do not qualify to receive new licenses for the same
+ material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+ run a copy of the Program. Ancillary propagation of a covered work
+ occurring solely as a consequence of using peer-to-peer transmission
+ to receive a copy likewise does not require acceptance. However,
+ nothing other than this License grants you permission to propagate or
+ modify any covered work. These actions infringe copyright if you do
+ not accept this License. Therefore, by modifying or propagating a
+ covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+ receives a license from the original licensors, to run, modify and
+ propagate that work, subject to this License. You are not responsible
+ for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+ organization, or substantially all assets of one, or subdividing an
+ organization, or merging organizations. If propagation of a covered
+ work results from an entity transaction, each party to that
+ transaction who receives a copy of the work also receives whatever
+ licenses to the work the party's predecessor in interest had or could
+ give under the previous paragraph, plus a right to possession of the
+ Corresponding Source of the work from the predecessor in interest, if
+ the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+ rights granted or affirmed under this License. For example, you may
+ not impose a license fee, royalty, or other charge for exercise of
+ rights granted under this License, and you may not initiate litigation
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
+ any patent claim is infringed by making, using, selling, offering for
+ sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+ License of the Program or a work on which the Program is based. The
+ work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+ owned or controlled by the contributor, whether already acquired or
+ hereafter acquired, that would be infringed by some manner, permitted
+ by this License, of making, using, or selling its contributor version,
+ but do not include claims that would be infringed only as a
+ consequence of further modification of the contributor version. For
+ purposes of this definition, "control" includes the right to grant
+ patent sublicenses in a manner consistent with the requirements of
+ this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+ patent license under the contributor's essential patent claims, to
+ make, use, sell, offer for sale, import and otherwise run, modify and
+ propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+ agreement or commitment, however denominated, not to enforce a patent
+ (such as an express permission to practice a patent or covenant not to
+ sue for patent infringement). To "grant" such a patent license to a
+ party means to make such an agreement or commitment not to enforce a
+ patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+ and the Corresponding Source of the work is not available for anyone
+ to copy, free of charge and under the terms of this License, through a
+ publicly available network server or other readily accessible means,
+ then you must either (1) cause the Corresponding Source to be so
+ available, or (2) arrange to deprive yourself of the benefit of the
+ patent license for this particular work, or (3) arrange, in a manner
+ consistent with the requirements of this License, to extend the patent
+ license to downstream recipients. "Knowingly relying" means you have
+ actual knowledge that, but for the patent license, your conveying the
+ covered work in a country, or your recipient's use of the covered work
+ in a country, would infringe one or more identifiable patents in that
+ country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+ arrangement, you convey, or propagate by procuring conveyance of, a
+ covered work, and grant a patent license to some of the parties
+ receiving the covered work authorizing them to use, propagate, modify
+ or convey a specific copy of the covered work, then the patent license
+ you grant is automatically extended to all recipients of the covered
+ work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+ the scope of its coverage, prohibits the exercise of, or is
+ conditioned on the non-exercise of one or more of the rights that are
+ specifically granted under this License. You may not convey a covered
+ work if you are a party to an arrangement with a third party that is
+ in the business of distributing software, under which you make payment
+ to the third party based on the extent of your activity of conveying
+ the work, and under which the third party grants, to any of the
+ parties who would receive the covered work from you, a discriminatory
+ patent license (a) in connection with copies of the covered work
+ conveyed by you (or copies made from those copies), or (b) primarily
+ for and in connection with specific products or compilations that
+ contain the covered work, unless you entered into that arrangement,
+ or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+ any implied license or other defenses to infringement that may
+ otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+ otherwise) that contradict the conditions of this License, they do not
+ excuse you from the conditions of this License. If you cannot convey a
+ covered work so as to satisfy simultaneously your obligations under this
+ License and any other pertinent obligations, then as a consequence you may
+ not convey it at all. For example, if you agree to terms that obligate you
+ to collect a royalty for further conveying from those to whom you convey
+ the Program, the only way you could satisfy both those terms and this
+ License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+ permission to link or combine any covered work with a work licensed
+ under version 3 of the GNU Affero General Public License into a single
+ combined work, and to convey the resulting work. The terms of this
+ License will continue to apply to the part which is the covered work,
+ but the special requirements of the GNU Affero General Public License,
+ section 13, concerning interaction through a network will apply to the
+ combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+ the GNU General Public License from time to time. Such new versions will
+ be similar in spirit to the present version, but may differ in detail to
+ address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+ Program specifies that a certain numbered version of the GNU General
+ Public License "or any later version" applies to it, you have the
+ option of following the terms and conditions either of that numbered
+ version or of any later version published by the Free Software
+ Foundation. If the Program does not specify a version number of the
+ GNU General Public License, you may choose any version ever published
+ by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+ versions of the GNU General Public License can be used, that proxy's
+ public statement of acceptance of a version permanently authorizes you
+ to choose that version for the Program.
+
+ Later license versions may give you additional or different
+ permissions. However, no additional obligations are imposed on any
+ author or copyright holder as a result of your choosing to follow a
+ later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+ SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+ above cannot be given local legal effect according to their terms,
+ reviewing courts shall apply local law that most closely approximates
+ an absolute waiver of all civil liability in connection with the
+ Program, unless a warranty or assumption of liability accompanies a
+ copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+ possible use to the public, the best way to achieve this is to make it
+ free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+ to attach them to the start of each source file to most effectively
+ state the exclusion of warranty; and each file should have at least
+ the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+ notice like this when it starts in an interactive mode:
+
+ <program> Copyright (C) <year> <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+ The hypothetical commands `show w' and `show c' should show the appropriate
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
README.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!--
2
+ * @Description:
3
+ * @Author: Egrt
4
+ * @Date: 2022-11-23 15:20:00
5
+ * @LastEditors: Egrt
6
+ * @LastEditTime: 2022-11-23 15:27:41
7
+ -->
8
+ ---
9
+ title: HEAT
10
+ emoji: 📈
11
+ colorFrom: indigo
12
+ colorTo: yellow
13
+ sdk: gradio
14
+ sdk_version: 3.11.0
15
+ app_file: app.py
16
+ pinned: false
17
+ license: apache-2.0
18
+ ---
19
+
20
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
21
+
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Author: Egrt
3
+ Date: 2022-01-13 13:34:10
4
+ LastEditors: [egrt]
5
+ LastEditTime: 2022-08-15 19:40:32
6
+ FilePath: \HEAT\app.py
7
+ '''
8
+ from HEAT import HEAT
9
+ import gradio as gr
10
+ import os
11
+ heat = HEAT()
12
+
13
+ # --------模型推理---------- #
14
+ def inference(img):
15
+ image_result = heat.detect_one_image(img)
16
+ return image_result
17
+
18
+ # --------网页信息---------- #
19
+ title = "HEAT"
20
+ description = "HEAT: Holistic Edge Attention Transformer for Structured Reconstruction @Luuuu"
21
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.10257' target='_blank'>HEAT: Holistic Edge Attention Transformer for Structured Reconstruction </a> | <a href='https://github.com/JingyunLiang/SwinIR' target='_blank'>Github Repo</a></p>"
22
+ example_img_dir = 'images/'
23
+ example_img_name = os.listdir(example_img_dir)
24
+ examples=[[os.path.join(example_img_dir, image_path)] for image_path in example_img_name if image_path.endswith(('.jpg','.jpeg', '.png'))]
25
+ gr.Interface(
26
+ inference,
27
+ [gr.Image(type="pil", label="Input")],
28
+ gr.Image(type="pil", label="Output"),
29
+ title=title,
30
+ description=description,
31
+ article=article,
32
+ examples=examples
33
+ ).launch()
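
A quick way to sanity-check the handler above without launching the web UI is to call the same function directly; a minimal sketch, assuming detect_one_image accepts and returns a PIL image (which is what the Gradio input/output types imply):

# Sketch: exercise the Gradio handler locally (not part of the commit).
from PIL import Image
from HEAT import HEAT

heat = HEAT()
img = Image.open('images/test.jpg')   # sample image shipped with this commit
result = heat.detect_one_image(img)   # same call the inference() handler makes
result.save('test_result.png')        # assumes a PIL image is returned, per the output type above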
arguments.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
+ def get_args_parser():
5
+ parser = argparse.ArgumentParser('Holistic edge attention transformer', add_help=False)
6
+ parser.add_argument('--exp_dataset', default='outdoor',
7
+ help='the dataset for experiments, outdoor/s3d_floorplan')
8
+ parser.add_argument('--lr', default=2e-4, type=float)
9
+ parser.add_argument('--batch_size', default=16, type=int)
10
+ parser.add_argument('--weight_decay', default=1e-5, type=float)
11
+ parser.add_argument('--epochs', default=800, type=int)
12
+ parser.add_argument('--lr_drop', default=600, type=int)
13
+ parser.add_argument('--clip_max_norm', default=0.1, type=float,
14
+ help='gradient clipping max norm')
15
+ parser.add_argument('--print_freq', default=40, type=int)
16
+ parser.add_argument('--output_dir', default='./checkpoints/ckpts_heat_outdoor_256',
17
+ help='path where to save, empty for no saving')
18
+ parser.add_argument('--resume', default='',
19
+ help='resume from checkpoint')
20
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
21
+ help='start epoch')
22
+ parser.add_argument('--num_workers', default=4, type=int)
23
+ parser.add_argument('--image_size', default=256, type=int)
24
+ parser.add_argument('--max_corner_num', default=150, type=int,
25
+ help='the max number of corners allowed in the experiments')
26
+ parser.add_argument('--corner_to_edge_multiplier', default=3, type=int,
27
+ help='the max number of edges based on the number of corner candidates (assuming the '
28
+ 'average degree never greater than 6)')
29
+ parser.add_argument('--lambda_corner', default=0.05, type=float,
30
+ help='weight of the corner-prediction loss term')
31
+ parser.add_argument('--run_validation', action='store_true',
32
+ help='Whether run validation or not, default: False')
33
+ return parser
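
The parser above is meant to be composed into a script-level parser (infer.py below follows the same pattern); a minimal usage sketch:

# Sketch: consume get_args_parser() as a parent parser.
import argparse
from arguments import get_args_parser

parser = argparse.ArgumentParser('HEAT training', parents=[get_args_parser()])
args = parser.parse_args(['--exp_dataset', 's3d_floorplan', '--batch_size', '8'])
print(args.exp_dataset, args.lr, args.lambda_corner)  # s3d_floorplan 0.0002 0.05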
assets/img/pipeline.png ADDED
assets/img/problem_description.png ADDED
datasets/__init__.py ADDED
File without changes
datasets/corners.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from torch.utils.data import Dataset
3
+ from scipy.ndimage import gaussian_filter
4
+ import cv2
5
+
6
+ mean = [0.485, 0.456, 0.406]
7
+ std = [0.229, 0.224, 0.225]
8
+
9
+
10
+ class CornersDataset(Dataset):
11
+ def __init__(self, image_size=256, inference=False):
12
+ super(CornersDataset, self).__init__()
13
+ self.image_size = image_size
14
+ self.inference = inference
15
+ self._data_names = []
16
+
17
+ def __len__(self):
18
+ raise len(self._data_names)
19
+
20
+ def __getitem__(self, idx):
21
+ raise NotImplementedError
22
+
23
+ def process_data(self, data):
24
+ img = data['image']
25
+ corners = data['corners']
26
+ annot = data['annot']
27
+
28
+ # pre-process the image to use ImageNet-pretrained backbones
29
+ img = img.transpose((2, 0, 1))
30
+ raw_img = img.copy()
31
+ img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
32
+ img = img.astype(np.float32)
33
+
34
+ corners = np.array(corners)
35
+
36
+ all_data = {
37
+ "annot": annot,
38
+ "name": data['name'],
39
+ 'img': img,
40
+ 'annot_path': data['annot_path'],
41
+ 'img_path': data['img_path'],
42
+ 'det_path': data['det_path'],
43
+ 'raw_img': raw_img,
44
+ }
45
+
46
+ # corner labels for training
47
+ if not self.inference:
48
+ pixel_labels, gauss_labels = self.get_corner_labels(corners)
49
+ all_data['pixel_labels'] = pixel_labels
50
+ all_data['gauss_labels'] = gauss_labels
51
+
52
+ return all_data
53
+
54
+ def get_corner_labels(self, corners):
55
+ labels = np.zeros((self.image_size, self.image_size))
56
+ corners = corners.round()
57
+ xint, yint = corners[:, 0].astype(np.int), corners[:, 1].astype(np.int)
58
+ labels[yint, xint] = 1
59
+
60
+ gauss_labels = gaussian_filter(labels, sigma=2)
61
+ gauss_labels = gauss_labels / gauss_labels.max()
62
+ return labels, gauss_labels
63
+
64
+ def resize_data(self, image, annot, det_corners):
65
+ new_image = cv2.resize(image, (self.image_size, self.image_size))
66
+ new_annot = {}
67
+ r = self.image_size / 256
68
+ for c, connections in annot.items():
69
+ new_c = tuple(np.array(c) * r)
70
+ new_connections = [other_c * r for other_c in connections]
71
+ new_annot[new_c] = new_connections
72
+ new_dets = det_corners * r if det_corners is not None else None
73
+ return new_image, new_annot, new_dets
74
+
75
+ def random_aug_annot(self, img, annot, det_corners=None):
76
+ # do random flipping
77
+ img, annot, det_corners = self.random_flip(img, annot, det_corners)
78
+
79
+ # prepare random augmentation parameters (only do random rotation for now)
80
+ theta = np.random.randint(0, 360) / 360 * np.pi * 2
81
+ r = self.image_size / 256
82
+ origin = [127 * r, 127 * r]
83
+ p1_new = [127 * r + 100 * np.sin(theta) * r, 127 * r - 100 * np.cos(theta) * r]
84
+ p2_new = [127 * r + 100 * np.cos(theta) * r, 127 * r + 100 * np.sin(theta) * r]
85
+ p1_old = [127 * r, 127 * r - 100 * r] # y_axis
86
+ p2_old = [127 * r + 100 * r, 127 * r] # x_axis
87
+ pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32)
88
+ pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32)
89
+ M_rot = cv2.getAffineTransform(pts1, pts2)
90
+
91
+ # Combine annotation corners and detection corners
92
+ all_corners = list(annot.keys())
93
+ if det_corners is not None:
94
+ for i in range(det_corners.shape[0]):
95
+ all_corners.append(tuple(det_corners[i]))
96
+ all_corners_ = np.array(all_corners)
97
+
98
+ # Do the corner transform within a big matrix transformation
99
+ corner_mapping = dict()
100
+ ones = np.ones([all_corners_.shape[0], 1])
101
+ all_corners_ = np.concatenate([all_corners_, ones], axis=-1)
102
+ aug_corners = np.matmul(M_rot, all_corners_.T).T
103
+
104
+ for idx, corner in enumerate(all_corners):
105
+ corner_mapping[corner] = aug_corners[idx]
106
+
107
+ # If the transformed geometry goes beyond image boundary, we simply re-do the augmentation
108
+ new_corners = np.array(list(corner_mapping.values()))
109
+ if new_corners.min() <= 0 or new_corners.max() >= (self.image_size - 1):
110
+ # return self.random_aug_annot(img, annot, det_corners)
111
+ return img, annot, None, det_corners
112
+
113
+ # build the new annot dict
114
+ aug_annot = dict()
115
+ for corner, connections in annot.items():
116
+ new_corner = corner_mapping[corner]
117
+ tuple_new_corner = tuple(new_corner)
118
+ aug_annot[tuple_new_corner] = list()
119
+ for to_corner in connections:
120
+ aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])
121
+
122
+ # Also transform the image correspondingly
123
+ rows, cols, ch = img.shape
124
+ new_img = cv2.warpAffine(img, M_rot, (cols, rows), borderValue=(255, 255, 255))
125
+
126
+ y_start = (new_img.shape[0] - self.image_size) // 2
127
+ x_start = (new_img.shape[1] - self.image_size) // 2
128
+ aug_img = new_img[y_start:y_start + self.image_size, x_start:x_start + self.image_size, :]
129
+
130
+ if det_corners is None:
131
+ return aug_img, aug_annot, corner_mapping, None
132
+ else:
133
+ aug_det_corners = list()
134
+ for corner in det_corners:
135
+ new_corner = corner_mapping[tuple(corner)]
136
+ aug_det_corners.append(new_corner)
137
+ aug_det_corners = np.array(aug_det_corners)
138
+ return aug_img, aug_annot, corner_mapping, aug_det_corners
139
+
140
+ def random_flip(self, img, annot, det_corners):
141
+ height, width, _ = img.shape
142
+ rand_int = np.random.randint(0, 4)
143
+ if rand_int == 0:
144
+ return img, annot, det_corners
145
+
146
+ all_corners = list(annot.keys())
147
+ if det_corners is not None:
148
+ for i in range(det_corners.shape[0]):
149
+ all_corners.append(tuple(det_corners[i]))
150
+ new_corners = np.array(all_corners)
151
+
152
+ if rand_int == 1:
153
+ img = img[:, ::-1, :]
154
+ new_corners[:, 0] = width - new_corners[:, 0]
155
+ elif rand_int == 2:
156
+ img = img[::-1, :, :]
157
+ new_corners[:, 1] = height - new_corners[:, 1]
158
+ else:
159
+ img = img[::-1, ::-1, :]
160
+ new_corners[:, 0] = width - new_corners[:, 0]
161
+ new_corners[:, 1] = height - new_corners[:, 1]
162
+
163
+ new_corners = np.clip(new_corners, 0, self.image_size - 1) # clip into [0, 255]
164
+ corner_mapping = dict()
165
+ for idx, corner in enumerate(all_corners):
166
+ corner_mapping[corner] = new_corners[idx]
167
+
168
+ aug_annot = dict()
169
+ for corner, connections in annot.items():
170
+ new_corner = corner_mapping[corner]
171
+ tuple_new_corner = tuple(new_corner)
172
+ aug_annot[tuple_new_corner] = list()
173
+ for to_corner in connections:
174
+ aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])
175
+
176
+ if det_corners is not None:
177
+ aug_det_corners = list()
178
+ for corner in det_corners:
179
+ new_corner = corner_mapping[tuple(corner)]
180
+ aug_det_corners.append(new_corner)
181
+ det_corners = np.array(aug_det_corners)
182
+
183
+ return img, aug_annot, det_corners
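
The corner supervision built by get_corner_labels above is a binary hit map smoothed into a peak-normalized Gaussian heatmap; a self-contained sketch of that step on synthetic corners:

# Standalone sketch of get_corner_labels(): mark corner pixels, blur, renormalize.
import numpy as np
from scipy.ndimage import gaussian_filter

image_size = 256
corners = np.array([[40.2, 60.7], [128.0, 128.0]])   # synthetic (x, y) corners

labels = np.zeros((image_size, image_size))
rounded = corners.round()
xint, yint = rounded[:, 0].astype(int), rounded[:, 1].astype(int)
labels[yint, xint] = 1                               # note the (y, x) indexing

gauss_labels = gaussian_filter(labels, sigma=2)
gauss_labels = gauss_labels / gauss_labels.max()     # peak value becomes exactly 1.0
print(gauss_labels.max(), gauss_labels[61, 40] > 0)  # 1.0 True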
datasets/data_utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import ImageFilter
2
+ from torchvision import transforms
3
+ import numpy as np
4
+ from utils.nn_utils import positional_encoding_2d
5
+ from torch.utils.data.dataloader import default_collate
6
+
7
+
8
+ def RandomBlur(radius=2.):
9
+ blur = GaussianBlur(radius=radius)
10
+ full_transform = transforms.RandomApply([blur], p=.3)
11
+ return full_transform
12
+
13
+
14
+ class ImageFilterTransform(object):
15
+
16
+ def __init__(self):
17
+ raise NotImplementedError
18
+
19
+ def __call__(self, img):
20
+ return img.filter(self.filter)
21
+
22
+
23
+ class GaussianBlur(ImageFilterTransform):
24
+
25
+ def __init__(self, radius=2.):
26
+ self.filter = ImageFilter.GaussianBlur(radius=radius)
27
+
28
+
29
+ def collate_fn(data):
30
+ batched_data = {}
31
+ for field in data[0].keys():
32
+ if field in ['annot', 'rec_mat']:
33
+ batch_values = [item[field] for item in data]
34
+ else:
35
+ batch_values = default_collate([d[field] for d in data])
36
+ if field in ['pixel_features', 'pixel_labels', 'gauss_labels']:
37
+ batch_values = batch_values.float()
38
+ batched_data[field] = batch_values
39
+
40
+ return batched_data
41
+
42
+
43
+ def get_pixel_features(image_size, d_pe=128):
44
+ all_pe = positional_encoding_2d(d_pe, image_size, image_size)
45
+ pixels_x = np.arange(0, image_size)
46
+ pixels_y = np.arange(0, image_size)
47
+
48
+ xv, yv = np.meshgrid(pixels_x, pixels_y)
49
+ all_pixels = list()
50
+ for i in range(xv.shape[0]):
51
+ pixs = np.stack([xv[i], yv[i]], axis=-1)
52
+ all_pixels.append(pixs)
53
+ pixels = np.stack(all_pixels, axis=0)
54
+
55
+ pixel_features = all_pe[:, pixels[:, :, 1], pixels[:, :, 0]]
56
+ pixel_features = pixel_features.permute(1, 2, 0)
57
+ return pixels, pixel_features
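
get_pixel_features pairs every pixel coordinate with a fixed 2-D positional encoding; a short usage sketch, assuming positional_encoding_2d (from utils.nn_utils, not shown in this view) returns a (d_pe, H, W) torch tensor, as the indexing and .permute above imply:

# Sketch: per-pixel positional encodings consumed by the corner/edge models.
from datasets.data_utils import get_pixel_features

pixels, pixel_features = get_pixel_features(image_size=256, d_pe=128)
print(pixels.shape)          # (256, 256, 2): the (x, y) coordinate of every pixel
print(pixel_features.shape)  # torch.Size([256, 256, 128]): one 128-d code per pixel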
datasets/outdoor_buildings.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from datasets.corners import CornersDataset
3
+ import os
4
+ import skimage
5
+ import cv2
6
+ from torchvision import transforms
7
+ from PIL import Image
8
+ from datasets.data_utils import RandomBlur
9
+
10
+ class OutdoorBuildingDataset(CornersDataset):
11
+ def __init__(self, data_path, det_path, phase='train', image_size=256, rand_aug=True,
12
+ inference=False):
13
+ super(OutdoorBuildingDataset, self).__init__(image_size, inference)
14
+ self.data_path = data_path
15
+ self.det_path = det_path
16
+ self.phase = phase
17
+ self.rand_aug = rand_aug
18
+ self.image_size = image_size
19
+ self.inference = inference
20
+
21
+ blur_transform = RandomBlur()
22
+ self.train_transform = transforms.Compose([
23
+ transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
24
+ transforms.RandomGrayscale(p=0.3),
25
+ blur_transform])
26
+
27
+ if phase == 'train':
28
+ datalistfile = os.path.join(data_path, 'train_list.txt')
29
+ self.training = True
30
+ else:
31
+ datalistfile = os.path.join(data_path, 'valid_list.txt')
32
+ self.training = False
33
+ with open(datalistfile, 'r') as f:
34
+ _data_names = f.readlines()
35
+ if phase == 'train':
36
+ self._data_names = _data_names
37
+ else:
38
+ # based on the data split rule from previous works
39
+ if phase == 'valid':
40
+ self._data_names = _data_names[:50]
41
+ elif phase == 'test':
42
+ self._data_names = _data_names[50:]
43
+ else:
44
+ raise ValueError('Invalid phase {}'.format(phase))
45
+
46
+ def __len__(self):
47
+ return len(self._data_names)
48
+
49
+ def __getitem__(self, idx):
50
+ data_name = self._data_names[idx][:-1]
51
+ annot_path = os.path.join(self.data_path, 'annot', data_name + '.npy')
52
+ annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()
53
+ det_path = os.path.join(self.det_path, data_name + '.npy')
54
+ det_corners = np.array(np.load(det_path, allow_pickle=True)) # [N, 2]
55
+ det_corners = det_corners[:, ::-1] # turn into x,y format
56
+
57
+ img_path = os.path.join(self.data_path, 'rgb', data_name + '.jpg')
58
+ rgb = cv2.imread(img_path)
59
+
60
+ if self.image_size != 256:
61
+ rgb, annot, det_corners = self.resize_data(rgb, annot, det_corners)
62
+
63
+ if self.rand_aug:
64
+ image, annot, corner_mapping, det_corners = self.random_aug_annot(rgb, annot, det_corners=det_corners)
65
+ else:
66
+ image = rgb
67
+ rec_mat = None
68
+
69
+ corners = np.array(list(annot.keys()))[:, [1, 0]]
70
+
71
+ if not self.inference and len(corners) > 100:
72
+ new_idx = np.random.randint(0, len(self))
73
+ return self.__getitem__(new_idx)
74
+
75
+ if self.training:
76
+ # Optionally jitter the g.t. corners (std is 0 here, so the noise is disabled)
77
+ corners += np.random.normal(0, 0, size=corners.shape)
78
+ pil_img = Image.fromarray(image)
79
+ image = self.train_transform(pil_img)
80
+ image = np.array(image)
81
+ image = skimage.img_as_float(image)
82
+
83
+ # sort by the second value and then the first value, here the corners are in the format of (y, x)
84
+ sort_idx = np.lexsort(corners.T)
85
+ corners = corners[sort_idx]
86
+
87
+ corner_list = []
88
+ for corner_i in range(corners.shape[0]):
89
+ corner_list.append((corners[corner_i][1], corners[corner_i][0])) # to (x, y) format
90
+
91
+ raw_data = {
92
+ 'name': data_name,
93
+ 'corners': corner_list,
94
+ 'annot': annot,
95
+ 'image': image,
96
+ 'rec_mat': rec_mat,
97
+ 'annot_path': annot_path,
98
+ 'det_path': det_path,
99
+ 'img_path': img_path,
100
+ }
101
+
102
+ return self.process_data(raw_data)
103
+
104
+ def random_aug_annot(self, img, annot, det_corners=None):
105
+ # do random flipping
106
+ img, annot, det_corners = self.random_flip(img, annot, det_corners)
107
+
108
+ # prepare random augmentation parameters (only do random rotation for now)
109
+ theta = np.random.randint(0, 360) / 360 * np.pi * 2
110
+ r = self.image_size / 256
111
+ origin = [127 * r, 127 * r]
112
+ p1_new = [127 * r + 100 * np.sin(theta) * r, 127 * r - 100 * np.cos(theta) * r]
113
+ p2_new = [127 * r + 100 * np.cos(theta) * r, 127 * r + 100 * np.sin(theta) * r]
114
+ p1_old = [127 * r, 127 * r - 100 * r] # y_axis
115
+ p2_old = [127 * r + 100 * r, 127 * r] # x_axis
116
+ pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32)
117
+ pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32)
118
+ M_rot = cv2.getAffineTransform(pts1, pts2)
119
+
120
+ # Combine annotation corners and detection corners
121
+ all_corners = list(annot.keys())
122
+ if det_corners is not None:
123
+ for i in range(det_corners.shape[0]):
124
+ all_corners.append(tuple(det_corners[i]))
125
+ all_corners_ = np.array(all_corners)
126
+
127
+ # Do the corner transform within a big matrix transformation
128
+ corner_mapping = dict()
129
+ ones = np.ones([all_corners_.shape[0], 1])
130
+ all_corners_ = np.concatenate([all_corners_, ones], axis=-1)
131
+ aug_corners = np.matmul(M_rot, all_corners_.T).T
132
+
133
+ for idx, corner in enumerate(all_corners):
134
+ corner_mapping[corner] = aug_corners[idx]
135
+
136
+ # If the transformed geometry goes beyond image boundary, we simply re-do the augmentation
137
+ new_corners = np.array(list(corner_mapping.values()))
138
+ if new_corners.min() <= 0 or new_corners.max() >= (self.image_size - 1):
139
+ # return self.random_aug_annot(img, annot, det_corners)
140
+ return img, annot, None, det_corners
141
+
142
+ # build the new annot dict
143
+ aug_annot = dict()
144
+ for corner, connections in annot.items():
145
+ new_corner = corner_mapping[corner]
146
+ tuple_new_corner = tuple(new_corner)
147
+ aug_annot[tuple_new_corner] = list()
148
+ for to_corner in connections:
149
+ aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])
150
+
151
+ # Also transform the image correspondingly
152
+ rows, cols, ch = img.shape
153
+ new_img = cv2.warpAffine(img, M_rot, (cols, rows), borderValue=(255, 255, 255))
154
+
155
+ y_start = (new_img.shape[0] - self.image_size) // 2
156
+ x_start = (new_img.shape[1] - self.image_size) // 2
157
+ aug_img = new_img[y_start:y_start + self.image_size, x_start:x_start + self.image_size, :]
158
+
159
+ if det_corners is None:
160
+ return aug_img, aug_annot, corner_mapping, None
161
+ else:
162
+ aug_det_corners = list()
163
+ for corner in det_corners:
164
+ new_corner = corner_mapping[tuple(corner)]
165
+ aug_det_corners.append(new_corner)
166
+ aug_det_corners = np.array(aug_det_corners)
167
+ return aug_img, aug_annot, corner_mapping, aug_det_corners
168
+
169
+
170
+
171
+ if __name__ == '__main__':
172
+ from torch.utils.data import DataLoader
+ from datasets.data_utils import collate_fn
173
+
174
+ DATAPATH = './data/cities_dataset'
175
+ DET_PATH = './data/det_final'
176
+ train_dataset = OutdoorBuildingDataset(DATAPATH, DET_PATH, phase='train')
177
+ train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0,
178
+ collate_fn=collate_fn)
179
+ for i, item in enumerate(train_dataloader):
180
+ import pdb; pdb.set_trace()
183
+ print(item)
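
The lexsort call in __getitem__ above orders the (y, x) corner rows by their last key first, i.e. by x and then by y, which is exactly what the in-code comment describes; a tiny worked example:

# Worked example of the corner-ordering convention used above.
import numpy as np

corners = np.array([[5, 9], [5, 2], [1, 7]])  # (y, x) rows
sort_idx = np.lexsort(corners.T)              # last key (the x column) is primary
print(corners[sort_idx])                      # [[5 2] [1 7] [5 9]]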
datasets/s3d_floorplans.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from datasets.corners import CornersDataset
3
+ import os
4
+ import skimage
5
+ import cv2
6
+ import itertools
7
+
8
+
9
+ mean = [0.485, 0.456, 0.406]
10
+ std = [0.229, 0.224, 0.225]
11
+
12
+ all_combinations = dict()
13
+ for length in range(2, 351):
14
+ ids = np.arange(length)
15
+ combs = np.array(list(itertools.combinations(ids, 2)))
16
+ all_combibations[length] = combs
17
+
18
+
19
+ class S3DFloorplanDataset(CornersDataset):
20
+ def __init__(self, data_path, phase='train', image_size=256, rand_aug=True, inference=False):
21
+ super(S3DFloorplanDataset, self).__init__(image_size, inference)
22
+ self.data_path = data_path
23
+ self.phase = phase
24
+ self.rand_aug = rand_aug
25
+
26
+ if phase == 'train':
27
+ datalistfile = os.path.join(data_path, 'train_list.txt')
28
+ self.training = True
29
+ elif phase == 'valid':
30
+ datalistfile = os.path.join(data_path, 'valid_list.txt')
31
+ self.training = False
32
+ else:
33
+ datalistfile = os.path.join(data_path, 'test_list.txt')
34
+ self.training = False
35
+ with open(datalistfile, 'r') as f:
36
+ self._data_names = f.readlines()
37
+
38
+ def __len__(self):
39
+ return len(self._data_names)
40
+
41
+ def __getitem__(self, idx):
42
+ data_name = self._data_names[idx][:-1]
43
+ annot_path = os.path.join(self.data_path, 'annot', data_name + '.npy')
44
+ annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()
45
+
46
+ density_path = os.path.join(self.data_path, 'density', data_name + '.png')
47
+ normal_path = os.path.join(self.data_path, 'normals', data_name + '.png')
48
+
49
+ density = cv2.imread(density_path)
50
+ normal = cv2.imread(normal_path)
51
+ rgb = np.maximum(density, normal)
52
+
53
+ if self.image_size != 256:
54
+ rgb, annot, det_corners = self.resize_data(rgb, annot, None)
55
+
56
+ if self.rand_aug:
57
+ image, annot, _ = self.random_aug_annot(rgb, annot, det_corners=None)
58
+ else:
59
+ image = rgb
60
+ rec_mat = None
61
+
62
+ corners = np.array(list(annot.keys()))[:, [1, 0]]
63
+
64
+ if not self.inference and len(corners) > 150:
65
+ new_idx = np.random.randint(0, len(self))
66
+ return self.__getitem__(new_idx)
67
+
68
+ if self.training:
69
+ # Optionally jitter the g.t. corners (std is 0 here, so the noise is disabled)
70
+ corners += np.random.normal(0, 0, size=corners.shape)
71
+
72
+ image = skimage.img_as_float(image)
73
+
74
+ # sort by the second value and then the first value, here the corners are in the format of (y, x)
75
+ sort_idx = np.lexsort(corners.T)
76
+ corners = corners[sort_idx]
77
+
78
+ corner_list = []
79
+ for corner_i in range(corners.shape[0]):
80
+ corner_list.append((corners[corner_i][1], corners[corner_i][0])) # to (x, y) format
81
+
82
+ raw_data = {
83
+ 'name': data_name,
84
+ 'corners': corner_list,
85
+ 'annot': annot,
86
+ 'image': image,
87
+ 'rec_mat': rec_mat,
88
+ 'annot_path': annot_path,
89
+ 'img_path': density_path,
90
+ }
91
+
92
+ return self.process_data(raw_data)
93
+
94
+ def process_data(self, data):
95
+ img = data['image']
96
+ corners = data['corners']
97
+ annot = data['annot']
98
+
99
+ # pre-process the image to use ImageNet-pretrained backbones
100
+ img = img.transpose((2, 0, 1))
101
+ raw_img = img.copy()
102
+ img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
103
+ img = img.astype(np.float32)
104
+
105
+ corners = np.array(corners)
106
+
107
+ all_data = {
108
+ "annot": annot,
109
+ "name": data['name'],
110
+ 'img': img,
111
+ 'annot_path': data['annot_path'],
112
+ 'img_path': data['img_path'],
113
+ 'raw_img': raw_img,
114
+ }
115
+
116
+ # corner labels
117
+ if not self.inference:
118
+ pixel_labels, gauss_labels = self.get_corner_labels(corners)
119
+ all_data['pixel_labels'] = pixel_labels
120
+ all_data['gauss_labels'] = gauss_labels
121
+
122
+ return all_data
123
+
124
+ def random_aug_annot(self, img, annot, det_corners=None):
125
+ # do random flipping
126
+ img, annot, det_corners = self.random_flip(img, annot, det_corners)
127
+ # return img, annot, None
128
+
129
+ # prepare random augmentation parameters (only do random rotation for now)
130
+ theta = np.random.randint(0, 360) / 360 * np.pi * 2
131
+ r = self.image_size / 256
132
+ origin = [127 * r, 127 * r]
133
+ p1_new = [127 * r + 100 * np.sin(theta) * r, 127 * r - 100 * np.cos(theta) * r]
134
+ p2_new = [127 * r + 100 * np.cos(theta) * r, 127 * r + 100 * np.sin(theta) * r]
135
+ p1_old = [127 * r, 127 * r - 100 * r] # y_axis
136
+ p2_old = [127 * r + 100 * r, 127 * r] # x_axis
137
+ pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32)
138
+ pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32)
139
+ M_rot = cv2.getAffineTransform(pts1, pts2)
140
+
141
+ # Combine annotation corners and detection corners
142
+ all_corners = list(annot.keys())
143
+ if det_corners is not None:
144
+ for i in range(det_corners.shape[0]):
145
+ all_corners.append(tuple(det_corners[i]))
146
+ all_corners_ = np.array(all_corners)
147
+
148
+ # Do the per-corner transform
149
+ # Done in a big matrix transformation to save processing time.
150
+ corner_mapping = dict()
151
+ ones = np.ones([all_corners_.shape[0], 1])
152
+ all_corners_ = np.concatenate([all_corners_, ones], axis=-1)
153
+ aug_corners = np.matmul(M_rot, all_corners_.T).T
154
+
155
+ for idx, corner in enumerate(all_corners):
156
+ corner_mapping[corner] = aug_corners[idx]
157
+
158
+ # If the transformed geometry goes beyond image boundary, we simply re-do the augmentation
159
+ new_corners = np.array(list(corner_mapping.values()))
160
+ if new_corners.min() <= 0 or new_corners.max() >= (self.image_size - 1):
161
+ # return self.random_aug_annot(img, annot, det_corners)
162
+ return img, annot, None
163
+
164
+ # build the new annot dict
165
+ aug_annot = dict()
166
+ for corner, connections in annot.items():
167
+ new_corner = corner_mapping[corner]
168
+ tuple_new_corner = tuple(new_corner)
169
+ aug_annot[tuple_new_corner] = list()
170
+ for to_corner in connections:
171
+ aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])
172
+
173
+ # Also transform the image correspondingly
174
+ rows, cols, ch = img.shape
175
+ new_img = cv2.warpAffine(img, M_rot, (cols, rows), borderValue=(255, 255, 255))
176
+
177
+ y_start = (new_img.shape[0] - self.image_size) // 2
178
+ x_start = (new_img.shape[1] - self.image_size) // 2
179
+ aug_img = new_img[y_start:y_start + self.image_size, x_start:x_start + self.image_size, :]
180
+
181
+ return aug_img, aug_annot, None
182
+
183
+
184
+
185
+
186
+
187
+
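
The network input built in __getitem__ above is a per-pixel maximum over the density and normal renderings; a small sketch with synthetic arrays:

# Sketch of the density/normal fusion (synthetic images instead of real renderings).
import numpy as np

density = np.zeros((256, 256, 3), dtype=np.uint8)
normal = np.zeros((256, 256, 3), dtype=np.uint8)
density[100:120, 100:120] = 200     # fake point-density blob
normal[110:130, 110:130] = 150      # fake normal rendering
rgb = np.maximum(density, normal)   # brighter source wins per pixel and channel
print(rgb[115, 115], rgb[125, 125]) # [200 200 200] [150 150 150]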
images/test.jpg ADDED
infer.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import DataLoader
4
+ from datasets.outdoor_buildings import OutdoorBuildingDataset
5
+ from datasets.s3d_floorplans import S3DFloorplanDataset
6
+ from datasets.data_utils import collate_fn, get_pixel_features
7
+ from models.resnet import ResNetBackbone
8
+ from models.corner_models import HeatCorner
9
+ from models.edge_models import HeatEdge
10
+ from models.corner_to_edge import get_infer_edge_pairs
11
+ from utils.geometry_utils import corner_eval
12
+ import numpy as np
13
+ import cv2
14
+ import os
15
+ import scipy.ndimage as filters  # the filters submodule is deprecated; the functions live at the top level
16
+ import matplotlib.pyplot as plt
17
+ from metrics.get_metric import compute_metrics, get_recall_and_precision
18
+ import skimage
19
+ import argparse
20
+
21
+
22
+ def visualize_cond_generation(positive_pixels, confs, image, save_path, gt_corners=None, prec=None, recall=None,
23
+ image_masks=None, edges=None, edge_confs=None):
24
+ image = image.copy() # get a new copy of the original image
25
+ if confs is not None:
26
+ viz_confs = confs
27
+
28
+ if edges is not None:
29
+ preds = positive_pixels.astype(int)
30
+ c_degrees = dict()
31
+ for edge_i, edge_pair in enumerate(edges):
32
+ conf = (edge_confs[edge_i] * 2) - 1
33
+ cv2.line(image, tuple(preds[edge_pair[0]]), tuple(preds[edge_pair[1]]), (255 * conf, 255 * conf, 0), 2)
34
+ c_degrees[edge_pair[0]] = c_degrees.setdefault(edge_pair[0], 0) + 1
35
+ c_degrees[edge_pair[1]] = c_degrees.setdefault(edge_pair[1], 0) + 1
36
+
37
+ for idx, c in enumerate(positive_pixels):
38
+ if edges is not None and idx not in c_degrees:
39
+ continue
40
+ if confs is None:
41
+ cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
42
+ else:
43
+ cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255 * viz_confs[idx]), -1)
44
+ # if edges is not None:
45
+ # cv2.putText(image, '{}'.format(c_degrees[idx]), (int(c[0]), int(c[1] - 5)), cv2.FONT_HERSHEY_SIMPLEX,
46
+ # 0.5, (255, 0, 0), 1, cv2.LINE_AA)
47
+
48
+ if gt_corners is not None:
49
+ for c in gt_corners:
50
+ cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 255, 0), -1)
51
+
52
+ if image_masks is not None:
53
+ mask_ids = np.where(image_masks == 1)[0]
54
+ for mask_id in mask_ids:
55
+ y_idx = mask_id // 64
56
+ x_idx = (mask_id - y_idx * 64)
57
+ x_coord = x_idx * 4
58
+ y_coord = y_idx * 4
59
+ cv2.rectangle(image, (x_coord, y_coord), (x_coord + 3, y_coord + 3), (127, 127, 0), thickness=-1)
60
+
61
+ # if confs is not None:
62
+ # cv2.putText(image, 'max conf: {:.3f}'.format(confs.max()), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
63
+ # 0.5, (255, 255, 0), 1, cv2.LINE_AA)
64
+ if prec is not None:
65
+ if isinstance(prec, tuple):
66
+ cv2.putText(image, 'edge p={:.2f}, edge r={:.2f}'.format(prec[0], recall[0]), (20, 20),
67
+ cv2.FONT_HERSHEY_SIMPLEX,
68
+ 0.5, (255, 255, 0), 1, cv2.LINE_AA)
69
+ cv2.putText(image, 'region p={:.2f}, region r={:.2f}'.format(prec[1], recall[1]), (20, 40),
70
+ cv2.FONT_HERSHEY_SIMPLEX,
71
+ 0.5, (255, 255, 0), 1, cv2.LINE_AA)
72
+ else:
73
+ cv2.putText(image, 'prec={:.2f}, recall={:.2f}'.format(prec, recall), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
74
+ 0.5, (255, 255, 0), 1, cv2.LINE_AA)
75
+ cv2.imwrite(save_path, image)
76
+
77
+
78
+ def corner_nms(preds, confs, image_size):
79
+ data = np.zeros([image_size, image_size])
80
+ neighborhood_size = 5
81
+ threshold = 0
82
+
83
+ for i in range(len(preds)):
84
+ data[preds[i, 1], preds[i, 0]] = confs[i]
85
+
86
+ data_max = filters.maximum_filter(data, neighborhood_size)
87
+ maxima = (data == data_max)
88
+ data_min = filters.minimum_filter(data, neighborhood_size)
89
+ diff = ((data_max - data_min) > threshold)
90
+ maxima[diff == 0] = 0
91
+
92
+ results = np.where(maxima > 0)
93
+ filtered_preds = np.stack([results[1], results[0]], axis=-1)
94
+
95
+ new_confs = list()
96
+ for i, pred in enumerate(filtered_preds):
97
+ new_confs.append(data[pred[1], pred[0]])
98
+ new_confs = np.array(new_confs)
99
+
100
+ return filtered_preds, new_confs
101
+
102
+
103
+ def main(dataset, ckpt_path, image_size, viz_base, save_base, infer_times):
104
+ ckpt = torch.load(ckpt_path)
105
+ print('Load from ckpts of epoch {}'.format(ckpt['epoch']))
106
+ ckpt_args = ckpt['args']
107
+ if dataset == 'outdoor':
108
+ data_path = './data/outdoor/cities_dataset'
109
+ det_path = './data/outdoor/det_final'
110
+ test_dataset = OutdoorBuildingDataset(data_path, det_path, phase='test', image_size=image_size, rand_aug=False,
111
+ inference=True)
112
+ elif dataset == 's3d_floorplan':
113
+ data_path = './data/s3d_floorplan'
114
+ test_dataset = S3DFloorplanDataset(data_path, phase='test', rand_aug=False, inference=True)
115
+ else:
116
+ raise ValueError('Unknown dataset type: {}'.format(dataset))
117
+
118
+ test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0,
119
+ collate_fn=collate_fn)
120
+
121
+ backbone = ResNetBackbone()
122
+ strides = backbone.strides
123
+ num_channels = backbone.num_channels
124
+ backbone = nn.DataParallel(backbone)
125
+ backbone = backbone.cuda()
126
+ backbone.eval()
127
+ corner_model = HeatCorner(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
128
+ backbone_num_channels=num_channels)
129
+ corner_model = nn.DataParallel(corner_model)
130
+ corner_model = corner_model.cuda()
131
+ corner_model.eval()
132
+
133
+ edge_model = HeatEdge(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
134
+ backbone_num_channels=num_channels)
135
+ edge_model = nn.DataParallel(edge_model)
136
+ edge_model = edge_model.cuda()
137
+ edge_model.eval()
138
+
139
+ backbone.load_state_dict(ckpt['backbone'])
140
+ corner_model.load_state_dict(ckpt['corner_model'])
141
+ edge_model.load_state_dict(ckpt['edge_model'])
142
+ print('Loaded saved model from {}'.format(ckpt_path))
143
+
144
+ if not os.path.exists(viz_base):
145
+ os.makedirs(viz_base)
146
+ if not os.path.exists(save_base):
147
+ os.makedirs(save_base)
148
+
149
+ all_prec = list()
150
+ all_recall = list()
151
+
152
+ corner_tp = 0.0
153
+ corner_fp = 0.0
154
+ corner_length = 0.0
155
+ edge_tp = 0.0
156
+ edge_fp = 0.0
157
+ edge_length = 0.0
158
+ region_tp = 0.0
159
+ region_fp = 0.0
160
+ region_length = 0.0
161
+
162
+ # get the positional encodings for all pixels
163
+ pixels, pixel_features = get_pixel_features(image_size=image_size)
164
+
165
+ for data_i, data in enumerate(test_dataloader):
166
+ image = data['img'].cuda()
167
+ img_path = data['img_path'][0]
168
+ annot_path = data['annot_path'][0]
169
+ annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()
170
+
171
+ with torch.no_grad():
172
+ pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np = get_results(image, annot, backbone,
173
+ corner_model,
174
+ edge_model,
175
+ pixels, pixel_features,
176
+ ckpt_args, infer_times,
177
+ corner_thresh=0.01,
178
+ image_size=image_size)
179
+
180
+ # viz_image = cv2.imread(img_path)
181
+ positive_pixels = np.array(list(annot.keys())).round()
182
+
183
+ viz_image = data['raw_img'][0].cpu().numpy().transpose(1, 2, 0)
184
+ viz_image = (viz_image * 255).astype(np.uint8)
185
+
186
+ # visualize G.T.
187
+ gt_path = os.path.join(viz_base, '{}_gt.png'.format(data_i))
188
+ visualize_cond_generation(positive_pixels, None, viz_image, gt_path, gt_corners=None, image_masks=None)
189
+
190
+ if len(pred_corners) > 0:
191
+ prec, recall = corner_eval(positive_pixels, pred_corners)
192
+ else:
193
+ prec = recall = 0
194
+ all_prec.append(prec)
195
+ all_recall.append(recall)
196
+
197
+ if pred_confs.shape[0] == 0:
198
+ pred_confs = None
199
+
200
+ if image_size != 256:
201
+ pred_corners_viz = pred_corners * (image_size / 256)
202
+ else:
203
+ pred_corners_viz = pred_corners
204
+ recon_path = os.path.join(viz_base, '{}_pred_corner.png'.format(data_i))
205
+ visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, recon_path, gt_corners=None, prec=prec,
206
+ recall=recall)
207
+
208
+ pred_corners, pred_confs, pos_edges = postprocess_preds(pred_corners, pred_confs, pos_edges)
209
+
210
+ pred_data = {
211
+ 'corners': pred_corners,
212
+ 'edges': pos_edges,
213
+ }
214
+
215
+ if dataset == 's3d_floorplan':
216
+ save_filename = os.path.basename(annot_path)
217
+ save_npy_path = os.path.join(save_base, save_filename)
218
+ np.save(save_npy_path, pred_data)
219
+ else:
220
+ save_results = {
221
+ 'corners': pred_corners,
222
+ 'edges': pos_edges,
223
+ 'image_path': img_path,
224
+ }
225
+ save_path = os.path.join(save_base, '{}_results.npy'.format(data_i))
226
+ np.save(save_path, save_results)
227
+
228
+ gt_data = convert_annot(annot)
229
+
230
+ score = compute_metrics(gt_data, pred_data)
231
+
232
+ edge_recall, edge_prec = get_recall_and_precision(score['edge_tp'], score['edge_fp'], score['edge_length'])
233
+ region_recall, region_prec = get_recall_and_precision(score['region_tp'], score['region_fp'],
234
+ score['region_length'])
235
+ er_recall = (edge_recall, region_recall)
236
+ er_prec = (edge_prec, region_prec)
237
+
238
+ if image_size != 256:
239
+ pred_corners_viz = pred_corners * (image_size / 256)
240
+ else:
241
+ pred_corners_viz = pred_corners
242
+ recon_path = os.path.join(viz_base, '{}_pred_edge.png'.format(data_i))
243
+ visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, recon_path, gt_corners=None, prec=er_prec,
244
+ recall=er_recall, edges=pos_edges, edge_confs=edge_confs)
245
+ corner_tp += score['corner_tp']
246
+ corner_fp += score['corner_fp']
247
+ corner_length += score['corner_length']
248
+ edge_tp += score['edge_tp']
249
+ edge_fp += score['edge_fp']
250
+ edge_length += score['edge_length']
251
+ region_tp += score['region_tp']
252
+ region_fp += score['region_fp']
253
+ region_length += score['region_length']
254
+
255
+ print('Finish inference for sample No.{}'.format(data_i))
256
+ avg_prec = np.array(all_prec).mean()
257
+ avg_recall = np.array(all_recall).mean()
258
+
259
+ recall, precision = get_recall_and_precision(corner_tp, corner_fp, corner_length)
260
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
261
+ print('corners - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
262
+
263
+ # edge
264
+ recall, precision = get_recall_and_precision(edge_tp, edge_fp, edge_length)
265
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
266
+ print('edges - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
267
+
268
+ # region
269
+ recall, precision = get_recall_and_precision(region_tp, region_fp, region_length)
270
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
271
+ print('regions - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
272
+
273
+ print('Avg prec: {}, Avg recall: {}'.format(avg_prec, avg_recall))
274
+
275
+
276
+ def get_results(image, annot, backbone, corner_model, edge_model, pixels, pixel_features,
277
+ args, infer_times, corner_thresh=0.5, image_size=256):
278
+ image_feats, feat_mask, all_image_feats = backbone(image)
279
+ pixel_features = pixel_features.unsqueeze(0).repeat(image.shape[0], 1, 1, 1)
280
+ preds_s1 = corner_model(image_feats, feat_mask, pixel_features, pixels, all_image_feats)
281
+
282
+ c_outputs = preds_s1
283
+ # get predicted corners
284
+ c_outputs_np = c_outputs[0].detach().cpu().numpy()
285
+ pos_indices = np.where(c_outputs_np >= corner_thresh)
286
+ pred_corners = pixels[pos_indices]
287
+ pred_confs = c_outputs_np[pos_indices]
288
+ pred_corners, pred_confs = corner_nms(pred_corners, pred_confs, image_size=c_outputs.shape[1])
289
+
290
+ pred_corners, pred_confs, edge_coords, edge_mask, edge_ids = get_infer_edge_pairs(pred_corners, pred_confs)
291
+
292
+ corner_nums = torch.tensor([len(pred_corners)]).to(image.device)
293
+ max_candidates = torch.stack([corner_nums.max() * args.corner_to_edge_multiplier] * len(corner_nums), dim=0)
294
+
295
+ all_pos_ids = set()
296
+ all_edge_confs = dict()
297
+
298
+ for tt in range(infer_times):
299
+ if tt == 0:
300
+ gt_values = torch.zeros_like(edge_mask).long()
301
+ gt_values[:, :] = 2
302
+
303
+ # run the edge model
304
+ s1_logits, s2_logits_hb, s2_logits_rel, selected_ids, s2_mask, s2_gt_values = edge_model(image_feats, feat_mask,
305
+ pixel_features,
306
+ edge_coords, edge_mask,
307
+ gt_values, corner_nums,
308
+ max_candidates,
309
+ True)
310
+ # do_inference=True)
311
+
312
+ num_total = s1_logits.shape[2]
313
+ num_selected = selected_ids.shape[1]
314
+ num_filtered = num_total - num_selected
315
+
316
+ s1_preds = s1_logits.squeeze().softmax(0)
317
+ s2_preds_rel = s2_logits_rel.squeeze().softmax(0)
318
+ s2_preds_hb = s2_logits_hb.squeeze().softmax(0)
319
+ s1_preds_np = s1_preds[1, :].detach().cpu().numpy()
320
+ s2_preds_rel_np = s2_preds_rel[1, :].detach().cpu().numpy()
321
+ s2_preds_hb_np = s2_preds_hb[1, :].detach().cpu().numpy()
322
+
323
+ selected_ids = selected_ids.squeeze().detach().cpu().numpy()
324
+ if tt != infer_times - 1:
325
+ s2_preds_np = s2_preds_hb_np
326
+
327
+ pos_edge_ids = np.where(s2_preds_np >= 0.9)
328
+ neg_edge_ids = np.where(s2_preds_np <= 0.01)
329
+ for pos_id in pos_edge_ids[0]:
330
+ actual_id = selected_ids[pos_id]
331
+ if gt_values[0, actual_id] != 2:
332
+ continue
333
+ all_pos_ids.add(actual_id)
334
+ all_edge_confs[actual_id] = s2_preds_np[pos_id]
335
+ gt_values[0, actual_id] = 1
336
+ for neg_id in neg_edge_ids[0]:
337
+ actual_id = selected_ids[neg_id]
338
+ if gt_values[0, actual_id] != 2:
339
+ continue
340
+ gt_values[0, actual_id] = 0
341
+ num_to_pred = (gt_values == 2).sum()
342
+ if num_to_pred <= num_filtered:
343
+ break
344
+ else:
345
+ s2_preds_np = s2_preds_hb_np
346
+
347
+ pos_edge_ids = np.where(s2_preds_np >= 0.5)
348
+ for pos_id in pos_edge_ids[0]:
349
+ actual_id = selected_ids[pos_id]
350
+ if bool(s2_mask[0][pos_id]) or gt_values[0, actual_id] != 2:  # 'is True' on a tensor is always False
351
+ continue
352
+ all_pos_ids.add(actual_id)
353
+ all_edge_confs[actual_id] = s2_preds_np[pos_id]
354
+
355
+ # print('Inference time {}'.format(tt+1))
356
+ pos_edge_ids = list(all_pos_ids)
357
+ edge_confs = [all_edge_confs[idx] for idx in pos_edge_ids]
358
+ pos_edges = edge_ids[pos_edge_ids].cpu().numpy()
359
+ edge_confs = np.array(edge_confs)
360
+
361
+ if image_size != 256:
362
+ pred_corners = pred_corners / (image_size / 256)
363
+
364
+ return pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np
365
+
366
+
367
+ def postprocess_preds(corners, confs, edges):
368
+ corner_degrees = dict()
369
+ for edge_i, edge_pair in enumerate(edges):
370
+ corner_degrees[edge_pair[0]] = corner_degrees.setdefault(edge_pair[0], 0) + 1
371
+ corner_degrees[edge_pair[1]] = corner_degrees.setdefault(edge_pair[1], 0) + 1
372
+ good_ids = [i for i in range(len(corners)) if i in corner_degrees]
373
+ if len(good_ids) == len(corners):
374
+ return corners, confs, edges
375
+ else:
376
+ good_corners = corners[good_ids]
377
+ good_confs = confs[good_ids]
378
+ id_mapping = {value: idx for idx, value in enumerate(good_ids)}
379
+ new_edges = list()
380
+ for edge_pair in edges:
381
+ new_pair = (id_mapping[edge_pair[0]], id_mapping[edge_pair[1]])
382
+ new_edges.append(new_pair)
383
+ new_edges = np.array(new_edges)
384
+ return good_corners, good_confs, new_edges
385
+
386
+
387
+ def process_image(img):
388
+ mean = [0.485, 0.456, 0.406]
389
+ std = [0.229, 0.224, 0.225]
390
+ img = skimage.img_as_float(img)
391
+ img = img.transpose((2, 0, 1))
392
+ img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
393
+ img = torch.Tensor(img).cuda()
394
+ img = img.unsqueeze(0)
395
+ return img
396
+
397
+
398
+ def plot_heatmap(results, filename):
399
+ # generate 2 2d grids for the x & y bounds
400
+ # import pdb; pdb.set_trace()
401
+ y, x = np.meshgrid(np.linspace(0, 255, 256), np.linspace(0, 255, 256))
402
+
403
+ z = results[::-1, :]
404
+ # x and y are bounds, so z should be the value *inside* those bounds.
405
+ # Therefore, remove the last value from the z array.
406
+ z = z[:-1, :-1]
407
+
408
+ fig, ax = plt.subplots()
409
+
410
+ c = ax.pcolormesh(y, x, z, cmap='RdBu', vmin=0, vmax=1)
411
+ # set the limits of the plot to the limits of the data
412
+ ax.axis([x.min(), x.max(), y.min(), y.max()])
413
+ fig.colorbar(c, ax=ax)
414
+ fig.savefig(filename)
415
+ plt.close()
416
+
417
+
418
+ def convert_annot(annot):
419
+ corners = np.array(list(annot.keys()))
420
+ corners_mapping = {tuple(c): idx for idx, c in enumerate(corners)}
421
+ edges = set()
422
+ for corner, connections in annot.items():
423
+ idx_c = corners_mapping[tuple(corner)]
424
+ for other_c in connections:
425
+ idx_other_c = corners_mapping[tuple(other_c)]
426
+ if (idx_c, idx_other_c) not in edges and (idx_other_c, idx_c) not in edges:
427
+ edges.add((idx_c, idx_other_c))
428
+ edges = np.array(list(edges))
429
+ gt_data = {
430
+ 'corners': corners,
431
+ 'edges': edges
432
+ }
433
+ return gt_data
434
+
435
+
436
+ def get_args_parser():
437
+ parser = argparse.ArgumentParser('Holistic edge attention transformer', add_help=False)
438
+ parser.add_argument('--dataset', default='outdoor',
439
+ help='the dataset for experiments, outdoor/s3d_floorplan')
440
+ parser.add_argument('--checkpoint_path', default='',
441
+ help='path to the checkpoints of the model')
442
+ parser.add_argument('--image_size', default=256, type=int)
443
+ parser.add_argument('--viz_base', default='./results/viz',
444
+ help='path to save the intermediate visualizations')
445
+ parser.add_argument('--save_base', default='./results/npy',
446
+ help='path to save the prediction results in npy files')
447
+ parser.add_argument('--infer_times', default=3, type=int)
448
+ return parser
449
+
450
+
451
+ if __name__ == '__main__':
452
+ parser = argparse.ArgumentParser('HEAT inference', parents=[get_args_parser()])
453
+ args = parser.parse_args()
454
+ main(args.dataset, args.checkpoint_path, args.image_size, args.viz_base, args.save_base,
455
+ infer_times=args.infer_times)
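
corner_nms above keeps a detection only if it is the confidence maximum within its 5-pixel neighborhood; a self-contained demo on synthetic detections (using a simple positivity check in place of the min/max-difference test):

# Demo of local-maximum corner NMS, mirroring corner_nms() above.
import numpy as np
import scipy.ndimage as ndi

preds = np.array([[50, 50], [52, 50], [120, 80]])  # (x, y) detections
confs = np.array([0.9, 0.6, 0.8])

data = np.zeros((256, 256))
data[preds[:, 1], preds[:, 0]] = confs             # confidence map, indexed (y, x)
data_max = ndi.maximum_filter(data, size=5)
maxima = (data == data_max) & (data > 0)           # keep only positive local peaks
ys, xs = np.where(maxima)
print(np.stack([xs, ys], axis=-1))                 # [[ 50  50] [120  80]]: the 0.6 corner is suppressed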
metrics/get_metric.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pickle
4
+ import cv2
5
+ from metrics.new_utils import *
6
+
7
+
8
+ class Metric():
9
+ def calc(self, gt_data, conv_data, thresh=8.0, iou_thresh=0.7):
10
+ ### compute corners precision/recall
11
+ gts = gt_data['corners']
12
+ dets = conv_data['corners']
13
+
14
+ per_sample_corner_tp = 0.0
15
+ per_sample_corner_fp = 0.0
16
+ per_sample_corner_length = gts.shape[0]
17
+ found = [False] * gts.shape[0]
18
+ c_det_annot = {}
19
+
20
+
21
+ # for each corner detection
22
+ for i, det in enumerate(dets):
23
+ # get closest gt
24
+ near_gt = [0, 999999.0, (0.0, 0.0)]
25
+ for k, gt in enumerate(gts):
26
+ dist = np.linalg.norm(gt - det)
27
+ if dist < near_gt[1]:
28
+ near_gt = [k, dist, gt]
29
+ if near_gt[1] <= thresh and not found[near_gt[0]]:
30
+ per_sample_corner_tp += 1.0
31
+ found[near_gt[0]] = True
32
+ c_det_annot[i] = near_gt[0]
33
+ else:
34
+ per_sample_corner_fp += 1.0
35
+
36
+ per_corner_score = {
37
+ 'recall': per_sample_corner_tp / gts.shape[0],
38
+ 'precision': per_sample_corner_tp / (per_sample_corner_tp + per_sample_corner_fp + 1e-8)
39
+ }
40
+
41
+ ### compute edges precision/recall
42
+ per_sample_edge_tp = 0.0
43
+ per_sample_edge_fp = 0.0
44
+ edge_corner_annots = gt_data['edges']
45
+ per_sample_edge_length = edge_corner_annots.shape[0]
46
+
47
+ false_edge_ids = []
48
+ match_gt_ids = set()
49
+
50
+ for l, e_det in enumerate(conv_data['edges']):
51
+ c1, c2 = e_det
52
+
53
+ # check if corners are mapped
54
+ if (c1 not in c_det_annot.keys()) or (c2 not in c_det_annot.keys()):
55
+ per_sample_edge_fp += 1.0
56
+ false_edge_ids.append(l)
57
+ continue
58
+ # check hit
59
+ c1_prime = c_det_annot[c1]
60
+ c2_prime = c_det_annot[c2]
61
+ is_hit = False
62
+
63
+ for k, e_annot in enumerate(edge_corner_annots):
64
+ c3, c4 = e_annot
65
+ if ((c1_prime == c3) and (c2_prime == c4)) or ((c1_prime == c4) and (c2_prime == c3)):
66
+ is_hit = True
67
+ match_gt_ids.add(k)
68
+ break
69
+
70
+ # hit
71
+ if is_hit:
72
+ per_sample_edge_tp += 1.0
73
+ else:
74
+ per_sample_edge_fp += 1.0
75
+ false_edge_ids.append(l)
76
+
77
+ per_edge_score = {
78
+ 'recall': per_sample_edge_tp / edge_corner_annots.shape[0],
79
+ 'precision': per_sample_edge_tp / (per_sample_edge_tp + per_sample_edge_fp + 1e-8)
80
+ }
81
+
82
+ # computer regions precision/recall
83
+ conv_mask = render(corners=conv_data['corners'], edges=conv_data['edges'], render_pad=0, edge_linewidth=1)[0]
84
+ conv_mask = 1 - conv_mask
85
+ conv_mask = conv_mask.astype(np.uint8)
86
+ labels, region_mask = cv2.connectedComponents(conv_mask, connectivity=4)
87
+
88
+ #cv2.imwrite('mask-pred.png', region_mask.astype(np.uint8) * 20)
89
+
90
+ background_label = region_mask[0, 0]
91
+ all_conv_masks = []
92
+ for region_i in range(1, labels):
93
+ if region_i == background_label:
94
+ continue
95
+ the_region = region_mask == region_i
96
+ if the_region.sum() < 20:
97
+ continue
98
+ all_conv_masks.append(the_region)
99
+
100
+ gt_mask = render(corners=gt_data['corners'], edges=gt_data['edges'], render_pad=0, edge_linewidth=1)[0]
101
+ gt_mask = 1 - gt_mask
102
+ gt_mask = gt_mask.astype(np.uint8)
103
+ labels, region_mask = cv2.connectedComponents(gt_mask, connectivity=4)
104
+
105
+ #cv2.imwrite('mask-gt.png', region_mask.astype(np.uint8) * 20)
106
+
107
+ background_label = region_mask[0, 0]
108
+ all_gt_masks = []
109
+ for region_i in range(1, labels):
110
+ if region_i == background_label:
111
+ continue
112
+ the_region = region_mask == region_i
113
+ if the_region.sum() < 20:
114
+ continue
115
+ all_gt_masks.append(the_region)
116
+
117
+ per_sample_region_tp = 0.0
118
+ per_sample_region_fp = 0.0
119
+ per_sample_region_length = len(all_gt_masks)
120
+ found = [False] * len(all_gt_masks)
121
+ for i, r_det in enumerate(all_conv_masks):
122
+ # gt closest gt
123
+ near_gt = [0, 0, None]
124
+ for k, r_gt in enumerate(all_gt_masks):
125
+ iou = np.logical_and(r_gt, r_det).sum() / float(np.logical_or(r_gt, r_det).sum())
126
+ if iou > near_gt[1]:
127
+ near_gt = [k, iou, r_gt]
128
+ if near_gt[1] >= iou_thresh and not found[near_gt[0]]:
129
+ per_sample_region_tp += 1.0
130
+ found[near_gt[0]] = True
131
+ else:
132
+ per_sample_region_fp += 1.0
133
+
134
+ per_region_score = {
135
+ 'recall': per_sample_region_tp / len(all_gt_masks),
136
+ 'precision': per_sample_region_tp / (per_sample_region_tp + per_sample_region_fp + 1e-8)
137
+ }
138
+
139
+ return {
140
+ 'corner_tp': per_sample_corner_tp,
141
+ 'corner_fp': per_sample_corner_fp,
142
+ 'corner_length': per_sample_corner_length,
143
+ 'edge_tp': per_sample_edge_tp,
144
+ 'edge_fp': per_sample_edge_fp,
145
+ 'edge_length': per_sample_edge_length,
146
+ 'region_tp': per_sample_region_tp,
147
+ 'region_fp': per_sample_region_fp,
148
+ 'region_length': per_sample_region_length,
149
+ 'corner': per_corner_score,
150
+ 'edge': per_edge_score,
151
+ 'region': per_region_score
152
+ }
153
+
154
+
155
+ def compute_metrics(gt_data, pred_data):
156
+ metric = Metric()
157
+ score = metric.calc(gt_data, pred_data)
158
+ return score
159
+
160
+
161
+ def get_recall_and_precision(tp, fp, length):
162
+ recall = tp / (length + 1e-8)
163
+ precision = tp / (tp + fp + 1e-8)
164
+ return recall, precision
165
+
166
+
167
+ if __name__ == '__main__':
168
+ base_path = './'
169
+ gt_datapath = '../data/cities_dataset/annot'
170
+ metric = Metric()
171
+ corner_tp = 0.0
172
+ corner_fp = 0.0
173
+ corner_length = 0.0
174
+ edge_tp = 0.0
175
+ edge_fp = 0.0
176
+ edge_length = 0.0
177
+ region_tp = 0.0
178
+ region_fp = 0.0
179
+ region_length = 0.0
180
+ for file_name in os.listdir(base_path):
181
+ if len(file_name) < 10:
182
+ continue
183
+ f = open(os.path.join(base_path, file_name), 'rb')
184
+ gt_data = np.load(os.path.join(gt_datapath, file_name + '.npy'), allow_pickle=True).tolist()
185
+ candidate = pickle.load(f)
+ f.close()
186
+ conv_corners = candidate.graph.getCornersArray()
187
+ conv_edges = candidate.graph.getEdgesArray()
188
+ conv_data = {'corners': conv_corners, 'edges': conv_edges}
189
+ score = metric.calc(gt_data, conv_data)
190
+ corner_tp += score['corner_tp']
191
+ corner_fp += score['corner_fp']
192
+ corner_length += score['corner_length']
193
+ edge_tp += score['edge_tp']
194
+ edge_fp += score['edge_fp']
195
+ edge_length += score['edge_length']
196
+ region_tp += score['region_tp']
197
+ region_fp += score['region_fp']
198
+ region_length += score['region_length']
199
+
200
+ f = open(os.path.join(base_path, 'score.txt'), 'w')
201
+ # corner
202
+ recall, precision = get_recall_and_precision(corner_tp, corner_fp, corner_length)
203
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
204
+ print('corners - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
205
+ f.write('corners - precision: %.3f recall: %.3f f_score: %.3f\n' % (precision, recall, f_score))
206
+
207
+ # edge
208
+ recall, precision = get_recall_and_precision(edge_tp, edge_fp, edge_length)
209
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
210
+ print('edges - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
211
+ f.write('edges - precision: %.3f recall: %.3f f_score: %.3f\n' % (precision, recall, f_score))
212
+
213
+ # region
214
+ recall, precision = get_recall_and_precision(region_tp, region_fp, region_length)
215
+ f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
216
+ print('regions - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
217
+ f.write('regions - precision: %.3f recall: %.3f f_score: %.3f\n' % (precision, recall, f_score))
218
+
219
+ f.close()
metrics/new_utils.py ADDED
@@ -0,0 +1,2100 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import cv2
4
+ import threading
5
+ import os
6
+ import skimage
7
+ import random
8
+ import time
9
+
10
+ TWO_CORNER_MINIMUM_DISTANCE = 5
11
+ SAFE_NUM = 3
12
+ score_weights = (1., 2., 100.)
13
+
14
+
15
+ #########################################################################################
16
+ ################################# General Functions #####################################
17
+ #########################################################################################
18
+ def swap_two_corner_place(corners, edges, id1, id2):
19
+ for edge_i in range(edges.shape[0]):
20
+ if edges[edge_i, 0] == id1:
21
+ edges[edge_i, 0] = id2
22
+ elif edges[edge_i, 0] == id2:
23
+ edges[edge_i, 0] = id1
24
+ if edges[edge_i, 1] == id1:
25
+ edges[edge_i, 1] = id2
26
+ elif edges[edge_i, 1] == id2:
27
+ edges[edge_i, 1] = id1
28
+ temp = corners[id1].copy()
29
+ corners[id1] = corners[id2]
30
+ corners[id2] = temp
31
+ return corners, edges
32
+
33
+
34
+ def get_neighbor_corner_id(corner_id, edges):
35
+ where = np.where(edges == corner_id)
36
+ return edges[where[0], 1 - where[1]]
37
+
38
+
39
+ def swap_two_edge_place(edges, id1, id2):
40
+ temp = edges[id1].copy()
41
+ edges[id1] = edges[id2]
42
+ edges[id2] = temp
43
+ return edges
44
+
45
+
46
+ def degree_of_three_corners(cornerA, cornerB, cornerM):
47
+ # cornerM is middle corner
48
+ AM_length = l2_distance(cornerA, cornerM)
49
+ BM_length = l2_distance(cornerB, cornerM)
50
+ dot = np.dot((cornerA[0] - cornerM[0], cornerA[1] - cornerM[1]),
51
+ (cornerB[0] - cornerM[0], cornerB[1] - cornerM[1]))
52
+ cos = dot / (AM_length + 1e-8) / (BM_length + 1e-8)
53
+ cos = min(1, max(-1, cos))
54
+ degree = np.arccos(cos)
55
+ return degree / np.pi * 180
56
+
57
+
58
+ def sort_graph(corners, edges):
59
+ corners = corners.copy()
60
+ edges = edges.copy()
61
+ for corner_i in range(corners.shape[0]):
62
+ min_id = -1
63
+ min_pos = corners[corner_i]
64
+ for corner_j in range(corner_i + 1, corners.shape[0]):
65
+ if (corners[corner_j, 0] < min_pos[0]) or \
66
+ (corners[corner_j, 0] == min_pos[0] and corners[corner_j, 1] < min_pos[1]):
67
+ min_pos = corners[corner_j]
68
+ min_id = corner_j
69
+ if min_id != -1:
70
+ corners, edges = swap_two_corner_place(corners, edges, corner_i, min_id)
71
+
72
+ for edge_i in range(edges.shape[0]):
73
+ if edges[edge_i, 0] > edges[edge_i, 1]:
74
+ temp = edges[edge_i, 0]
75
+ edges[edge_i, 0] = edges[edge_i, 1]
76
+ edges[edge_i, 1] = temp
77
+
78
+ for edge_i in range(edges.shape[0]):
79
+ min_id = -1
80
+ min_pos = edges[edge_i]
81
+ for edge_j in range(edge_i + 1, edges.shape[0]):
82
+ if (edges[edge_j, 0] < min_pos[0]) or \
83
+ (edges[edge_j, 0] == min_pos[0] and edges[edge_j, 1] < min_pos[1]):
84
+ min_pos = edges[edge_j]
85
+ min_id = edge_j
86
+ if min_id != -1:
87
+ edges = swap_two_edge_place(edges, edge_i, min_id)
88
+
89
+ return corners, edges
90
+
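sort_graph canonicalizes a graph (corners lexicographically, each edge's endpoints ordered, then edges lexicographically) so that two equivalent graphs compare equal element by element. The selection sort above is O(n^2); a vectorized equivalent, offered only as a sketch and not part of the original file, could look like:

    import numpy as np

    def sort_graph_np(corners, edges):
        # Lexsort corners by (row, col), remap edge indices accordingly,
        # order each edge pair, then lexsort the edge list.
        order = np.lexsort((corners[:, 1], corners[:, 0]))
        remap = np.empty(len(corners), dtype=int)
        remap[order] = np.arange(len(corners))
        corners = corners[order]
        edges = np.sort(remap[edges], axis=1)
        edges = edges[np.lexsort((edges[:, 1], edges[:, 0]))]
        return corners, edges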
91
+
92
+ def IOU(maskA, maskB):
93
+ return np.logical_and(maskA, maskB).sum() / np.logical_or(maskA, maskB).sum()
94
+
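A quick sanity check for IOU (note it assumes the union is non-empty, otherwise the division is 0/0):

    import numpy as np

    a = np.zeros((4, 4), dtype=bool); a[:2] = True   # 8 pixels
    b = np.zeros((4, 4), dtype=bool); b[1:3] = True  # 8 pixels, 4 shared
    print(IOU(a, b))  # 4 / 12 = 0.333...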
95
+
96
+ def render(corners, edges, render_pad=0, edge_linewidth=2, corner_size=3, scale=1.):
97
+ size = int(256 * scale)
98
+ mask = np.ones((2, size, size)) * render_pad
99
+
100
+ corners = np.round(corners.copy() * scale).astype(int)  # np.int was removed in NumPy 1.24
101
+ for edge_i in range(edges.shape[0]):
102
+ a = edges[edge_i, 0]
103
+ b = edges[edge_i, 1]
104
+ mask[0] = cv2.line(mask[0], (int(corners[a, 1]), int(corners[a, 0])),
105
+ (int(corners[b, 1]), int(corners[b, 0])), 1.0, thickness=edge_linewidth)
106
+ for corner_i in range(corners.shape[0]):
107
+ mask[1] = cv2.circle(mask[1], (int(corners[corner_i, 1]), int(corners[corner_i, 0])), corner_size, 1.0, -1)
108
+
109
+ return mask
110
+
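render rasterizes a graph into a 2-channel mask (channel 0: edges, channel 1: corners) on a 256x256 canvas; note that coordinates are (y, x), which is why they are flipped before the cv2 calls. A usage sketch with a made-up square:

    import numpy as np

    corners = np.array([[60, 60], [60, 180], [180, 180], [180, 60]])  # (y, x)
    edges = np.array([[0, 1], [1, 2], [2, 3], [3, 0]])
    mask = render(corners, edges)          # shape (2, 256, 256)
    edge_map, corner_map = mask[0], mask[1]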
111
+
112
+ def patch_samples(edge_num, batch_size):
113
+ num = edge_num // batch_size
114
+ patchs = []
115
+ for i in range(num):
116
+ patchs.append([i * batch_size + j for j in range(batch_size)])
117
+
118
+ if edge_num % batch_size != 0:
119
+ patchs.append([j for j in range(batch_size * num, edge_num)])
120
+
121
+ return patchs
122
+
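patch_samples simply chunks range(edge_num) into batches, with a shorter final batch for the remainder:

    print(patch_samples(7, 3))  # [[0, 1, 2], [3, 4, 5], [6]]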
123
+
124
+ def l2_distance(x1, x2):
125
+ return np.sqrt((x1[0] - x2[0]) ** 2 + (x1[1] - x2[1]) ** 2)
126
+
127
+
128
+ def triangle_region(A, B, C):
129
+ l1 = np.linalg.norm(np.array(A) - np.array(B))
130
+ l2 = np.linalg.norm(np.array(A) - np.array(C))
131
+ l3 = np.linalg.norm(np.array(B) - np.array(C))
132
+ p = (l1 + l2 + l3) / 2
133
+ area = np.sqrt(np.abs(p * (p - l1) * (p - l2) * (p - l3)))
134
+ return area
135
+
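triangle_region is Heron's formula (the np.abs guards against tiny negative values from floating-point error on near-degenerate triangles). A worked check with a 3-4-5 right triangle:

    print(triangle_region((0, 0), (3, 0), (3, 4)))  # semi-perimeter 6 -> area 6.0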
136
+
137
+ def remove_intersection_and_duplicate(corners, edges, name):
138
+ over_all_flag = False
139
+ ori_corners = corners.copy()
140
+ ori_edges = edges.copy()
141
+ while True:
142
+ flag = False
143
+ for edge_i in range(edges.shape[0]):
144
+ for edge_j in range(edge_i + 1, edges.shape[0]):
145
+ corner11 = corners[edges[edge_i, 0]]
146
+ corner12 = corners[edges[edge_i, 1]]
147
+ corner21 = corners[edges[edge_j, 0]]
148
+ corner22 = corners[edges[edge_j, 1]]
149
+
150
+ y1 = corner11[0]
151
+ x1 = corner11[1]
152
+ y2 = corner12[0]
153
+ x2 = corner12[1]
154
+ a1 = y1 - y2
155
+ b1 = x2 - x1
156
+ c1 = x1 * y2 - x2 * y1
157
+ flag1 = (a1 * corner21[1] + b1 * corner21[0] + c1) * (a1 * corner22[1] + b1 * corner22[0] + c1)
158
+
159
+ y1 = corner21[0]
160
+ x1 = corner21[1]
161
+ y2 = corner22[0]
162
+ x2 = corner22[1]
163
+ a2 = y1 - y2
164
+ b2 = x2 - x1
165
+ c2 = x1 * y2 - x2 * y1
166
+ flag2 = (a2 * corner11[1] + b2 * corner11[0] + c2) * (a2 * corner12[1] + b2 * corner12[0] + c2)
167
+
168
+ if flag1 < -1e-5 and flag2 < -1e-5:
169
+ # intersection!
170
+ over_all_flag = True
171
+ flag = True
172
+
173
+ new_x = (c2 * b1 - c1 * b2) / (a1 * b2 - a2 * b1)
174
+ new_y = (a2 * c1 - a1 * c2) / (a1 * b2 - a2 * b1)
175
+
176
+ temp_d = 3
177
+ temp_id = -1
178
+ if l2_distance((new_y, new_x), corner11) < temp_d:
179
+ temp_id = edges[edge_i, 0]
180
+ temp_d = l2_distance((new_y, new_x), corner11)
181
+ if l2_distance((new_y, new_x), corner12) < temp_d:
182
+ temp_id = edges[edge_i, 1]
183
+ temp_d = l2_distance((new_y, new_x), corner12)
184
+ if l2_distance((new_y, new_x), corner21) < temp_d:
185
+ temp_id = edges[edge_j, 0]
186
+ temp_d = l2_distance((new_y, new_x), corner21)
187
+ if l2_distance((new_y, new_x), corner22) < temp_d:
188
+ temp_id = edges[edge_j, 1]
189
+ temp_d = l2_distance((new_y, new_x), corner22)
190
+ if temp_id != -1:
191
+ if edges[edge_i, 0] != temp_id and edges[edge_i, 1] != temp_id:
192
+ tt = edges[edge_i, 0]
193
+ edges[edge_i, 0] = temp_id
194
+ edges = np.append(edges, np.array([(temp_id, tt)]), 0)
195
+ if edges[edge_j, 0] != temp_id and edges[edge_j, 1] != temp_id:
196
+ tt = edges[edge_j, 0]
197
+ edges[edge_j, 0] = temp_id
198
+ edges = np.append(edges, np.array([(temp_id, tt)]), 0)
199
+ else:
200
+ corners = np.append(corners, np.array([(new_y, new_x)]), 0)
201
+ edge_id1 = edges[edge_i, 1]
202
+ edge_id2 = edges[edge_j, 1]
203
+ edges[edge_i, 1] = corners.shape[0] - 1
204
+ edges[edge_j, 1] = corners.shape[0] - 1
205
+ edges = np.append(edges, np.array([(edge_id1, corners.shape[0] - 1)]), 0)
206
+ edges = np.append(edges, np.array([(edge_id2, corners.shape[0] - 1)]), 0)
207
+ break
208
+ if flag:
209
+ break
210
+ if flag:
211
+ continue
212
+ break
213
+
214
+ # remove duplicate corners and zero-degree corners
215
+ graph = Graph(np.round(corners), edges)
216
+ for corner_i in reversed(range(len(graph.getCorners()))):
217
+ corner_ele1 = graph.getCorners()[corner_i]
218
+ for corner_j in reversed(range(corner_i)):
219
+ corner_ele2 = graph.getCorners()[corner_j]
220
+ if l2_distance(corner_ele1.x, corner_ele2.x) < 3:
221
+ connected_edge = graph.getEdgeConnected(corner_ele1)
222
+ for edge_ele in connected_edge:
223
+ if edge_ele.x[0] == corner_ele1:
224
+ another = edge_ele.x[1]
225
+ else:
226
+ another = edge_ele.x[0]
227
+ if another == corner_ele2:
228
+ graph.remove(edge_ele)
229
+ edge_ele.x = (another, corner_ele2)
230
+ graph.remove(corner_ele1)
231
+ for corner_ele in graph.getCorners():
232
+ if graph.getCornerDegree(corner_ele) == 0:
233
+ graph.remove(corner_ele)
234
+
235
+ corners = graph.getCornersArray()
236
+ edges = graph.getEdgesArray()
237
+ # if over_all_flag:
238
+ # plt.subplot(121)
239
+ # ori = render(ori_corners, ori_edges, edge_linewidth=1, corner_size=1)
240
+ # temp = np.concatenate((ori.transpose((1,2,0)), np.zeros((ori.shape[1],ori.shape[2],1))),2)
241
+ # plt.imshow(temp)
242
+ # plt.subplot(122)
243
+ # new_ = render(corners, edges, edge_linewidth=1, corner_size=1)
244
+ # temp = np.concatenate((new_.transpose((1,2,0)), np.zeros((new_.shape[1],new_.shape[2],1))),2)
245
+ # plt.imshow(temp)
246
+ # plt.show()
247
+
248
+ return corners, edges
249
+
250
+
251
+ def get_two_edge_intersection_location(corner11, corner12, corner21, corner22):
252
+ y1 = corner11[0]
253
+ x1 = corner11[1]
254
+ y2 = corner12[0]
255
+ x2 = corner12[1]
256
+ a1 = y1 - y2
257
+ b1 = x2 - x1
258
+ c1 = x1 * y2 - x2 * y1
259
+
260
+ y1 = corner21[0]
261
+ x1 = corner21[1]
262
+ y2 = corner22[0]
263
+ x2 = corner22[1]
264
+ a2 = y1 - y2
265
+ b2 = x2 - x1
266
+ c2 = x1 * y2 - x2 * y1
267
+
268
+ l = a1 * b2 - a2 * b1
269
+ if l == 0:
270
+ l = 1e-5
271
+
272
+ new_x = (c2 * b1 - c1 * b2) / l
273
+ new_y = (a2 * c1 - a1 * c2) / l
274
+
275
+ return round(new_y), round(new_x)
276
+
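get_two_edge_intersection_location puts each supporting line into implicit form a*x + b*y + c = 0 (with the file's (y, x) point order) and solves the 2x2 system by Cramer's rule; near-parallel lines are handled by clamping the denominator to 1e-5 rather than raising. A quick check:

    # The diagonals of a 10x10 box cross in the middle.
    print(get_two_edge_intersection_location((0, 0), (10, 10), (0, 10), (10, 0)))  # (5, 5)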
277
+
278
+ def get_distance_of_corner_and_edge(corner1, corner2, corner):
279
+ x = corner[0]
280
+ y = corner[1]
281
+ x1 = corner1[0]
282
+ y1 = corner1[1]
283
+ x2 = corner2[0]
284
+ y2 = corner2[1]
285
+
286
+ cross = (x2 - x1) * (x - x1) + (y2 - y1) * (y - y1)
287
+ if cross <= 0:
288
+ # dist to corner1
289
+ return np.linalg.norm((x - x1, y - y1))
290
+
291
+ d2 = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)
292
+ if cross >= d2:
293
+ # dist to corner2
294
+ return np.linalg.norm((x - x2, y - y2))
295
+
296
+ r = cross / d2
297
+ px = x1 + (x2 - x1) * r
298
+ py = y1 + (y2 - y1) * r
299
+ return np.linalg.norm((x - px, y - py))
300
+
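get_distance_of_corner_and_edge is the standard point-to-segment distance: project the point onto the segment's line and clamp the projection to the endpoints. Two quick checks:

    print(get_distance_of_corner_and_edge((0, 0), (10, 0), (5, 3)))   # 3.0, projection falls inside
    print(get_distance_of_corner_and_edge((0, 0), (10, 0), (-4, 3)))  # 5.0, clamped to an endpoint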
301
+
302
+ #########################################################################################
303
+ ################################# Dataset Functions #####################################
304
+ #########################################################################################
305
+ def EuclideanDistance(A, B):
306
+ BT = B.transpose()
307
+ vecProd = np.dot(A, BT)
308
+
309
+ SqA = A ** 2
310
+ sumSqA = np.sum(SqA, axis=1, keepdims=True)  # np.matrix is deprecated; keep everything as ndarray
311
+ sumSqAEx = np.tile(sumSqA, (1, vecProd.shape[1]))
312
+
313
+ SqB = B ** 2
314
+ sumSqB = np.sum(SqB, axis=1)
315
+ sumSqBEx = np.tile(sumSqB, (vecProd.shape[0], 1))
316
+ SqED = sumSqBEx + sumSqAEx - 2 * vecProd
317
+ SqED[SqED < 0] = 0.0
318
+ ED = np.sqrt(SqED)
319
+ return ED
320
+
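EuclideanDistance builds the full pairwise distance matrix from the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; clamping SqED at zero absorbs floating-point error. It can be cross-checked against a plain broadcast form:

    import numpy as np

    A = np.array([[0.0, 0.0], [3.0, 4.0]])
    B = np.array([[0.0, 0.0], [6.0, 8.0]])
    D = EuclideanDistance(A, B)
    D_ref = np.linalg.norm(A[:, None, :] - B[None, :, :], axis=-1)
    assert np.allclose(D, D_ref)  # [[0, 10], [5, 5]]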
321
+
322
+ def samedirection(conv_corner_id, gt_corner_id, conv_corners, gt_corners, conv_edges, gt_edges):
323
+ # degree
324
+ if np.where(conv_edges == conv_corner_id)[0].shape[0] != np.where(gt_edges == gt_corner_id)[0].shape[0]:
325
+ return False
326
+
327
+ # direction
328
+ place = np.where(conv_edges == conv_corner_id)
329
+ neighbor_id = conv_edges[place[0], 1 - place[1]]
330
+
331
+ distance = conv_corners[conv_corner_id] - conv_corners[neighbor_id]
332
+ direction = np.arctan2(distance[:, 0], distance[:, 1]) * 180 / np.pi / 15
333
+ direction = (direction + 24) % 24
334
+
335
+ conv_dir = np.sort(direction)
336
+
337
+ place = np.where(gt_edges == gt_corner_id)
338
+ neighbor_id = gt_edges[place[0], 1 - place[1]]
339
+
340
+ distance = gt_corners[gt_corner_id] - gt_corners[neighbor_id]
341
+ direction = np.arctan2(distance[:, 0], distance[:, 1]) * 180 / np.pi / 15
342
+ direction = (direction + 24) % 24
343
+
344
+ gt_dir = np.sort(direction)
345
+
346
+ conv_dir = list(conv_dir)
347
+ gt_dir = list(gt_dir)
348
+ for angle in gt_dir:
349
+ temp = sorted(conv_dir, key=lambda x: min(np.abs(x - angle), 24 - np.abs(x - angle)))
350
+ if min(np.abs(temp[0] - angle), 24 - np.abs(temp[0] - angle)) <= 1.3:
351
+ conv_dir.remove(temp[0])
352
+ else:
353
+ return False
354
+ return True
355
+
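samedirection first compares corner degrees and then the multiset of incident edge directions, quantized into 24 bins of 15 degrees; the min(|a-b|, 24-|a-b|) term is the circular distance on that ring, so bins 23.5 and 0.5 count as 1 bin apart. Isolated for clarity:

    def ring_dist(a, b, n_bins=24):
        d = abs(a - b)
        return min(d, n_bins - d)

    print(ring_dist(23.5, 0.5))  # 1.0, not 23.0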
356
+
357
+ def simplify_gt(gt_match_location, gt_corner, gt_edge):
358
+ graph = Graph(np.round(gt_corner), gt_edge)
359
+ for idx, corner in enumerate(graph.getCorners()):
360
+ # use score to store the matching info
361
+ corner.store_score(gt_match_location[idx])
362
+
363
+ for idx, corner in enumerate(graph.getCorners()):
364
+ if corner.get_score() is None:
365
+ connected_edges = graph.getEdgeConnected(corner)
366
+ neighbor_corners = []
367
+ for edge in connected_edges:
368
+ if edge.x[0] != corner:
369
+ neighbor_corners.append(edge.x[0])
370
+ continue
371
+ if edge.x[1] != corner:
372
+ neighbor_corners.append(edge.x[1])
373
+ continue
374
+ raise BaseException()
375
+ neighbor_corners = sorted(neighbor_corners, key=lambda ele: l2_distance(ele.x, corner.x))
376
+ for neighbor_ele in neighbor_corners:
377
+ if l2_distance(neighbor_ele.x, corner.x) > 8:
378
+ break
379
+ if neighbor_ele.get_score() is None:
380
+ continue
381
+ # found a suitable matched neighbor; merge this corner into it
382
+ for ele in neighbor_corners:
383
+ if ele == neighbor_ele:
384
+ continue
385
+ graph.add_edge(ele, neighbor_ele)
386
+ neighbor_ele.x = (0.7 * neighbor_ele.x[0] + 0.3 * corner.x[0],
387
+ 0.7 * neighbor_ele.x[1] + 0.3 * corner.x[1])
388
+ graph.remove(corner)
389
+ break
390
+ return graph.getCornersArray(), graph.getEdgesArray()
391
+
392
+
393
+ def get_wrong_corners(corners, gt_corners, edges, gt_edges):
394
+ corners = corners.copy()
395
+ gt_corners = gt_corners.copy()
396
+ edges = edges.copy()
397
+ gt_edges = gt_edges.copy()
398
+ dist_matrix = EuclideanDistance(gt_corners, corners)
399
+ assigned_id = set()
400
+ gt_match_same_degree = []
401
+ gt_match_location = []
402
+ for gt_i in range(gt_corners.shape[0]):
403
+ sort_id = np.argsort(dist_matrix[gt_i])  # dist_matrix is a plain ndarray
404
+ flag = True
405
+ for id_ in sort_id:
406
+ if dist_matrix[gt_i, id_] > 7:
407
+ break
408
+ same_dir = samedirection(id_, gt_i, corners, gt_corners, edges, gt_edges)
409
+ if not same_dir:
410
+ break
411
+ elif id_ not in assigned_id:
412
+ assigned_id.add(id_)
413
+ gt_match_same_degree.append(id_)
414
+ flag = False
415
+ break
416
+ if flag:
417
+ gt_match_same_degree.append(None)
418
+
419
+ matched = []
420
+ gt_match_location = [None for _ in range(gt_corners.shape[0])]
421
+ for gt_i in sorted(list(range(gt_corners.shape[0])), key=lambda i: np.min(dist_matrix[i])):
422
+ sort_id = np.argsort(dist_matrix[gt_i])
423
+ if dist_matrix[gt_i, sort_id[0]] > 7:
424
+ gt_match_location[gt_i] = None
425
+ else:
426
+ for c_i in sort_id:
427
+ if c_i in matched:
428
+ continue
429
+ if dist_matrix[gt_i, c_i] > 7:
430
+ gt_match_location[gt_i] = None
431
+ break
432
+ else:
433
+ gt_match_location[gt_i] = c_i
434
+ matched.append(c_i)
435
+ break
436
+
437
+ return set(range(corners.shape[0])) - assigned_id, gt_match_same_degree, gt_match_location
438
+
439
+
440
+ def get_wrong_edges(corners, gt_corners, edges, gt_edges, gt_match):
441
+ edges = edges.copy()
442
+ gt_edges = gt_edges.copy()
443
+
444
+ all_possible_good_edges = []
445
+ for edge_i in range(gt_edges.shape[0]):
446
+ if gt_match[gt_edges[edge_i, 0]] is None or gt_match[gt_edges[edge_i, 1]] is None:
447
+ continue
448
+ all_possible_good_edges.append((gt_match[gt_edges[edge_i, 0]], gt_match[gt_edges[edge_i, 1]]))
449
+ false_edge_id = []
450
+ for edge_i in range(edges.shape[0]):
451
+ id1 = edges[edge_i][0]
452
+ id2 = edges[edge_i][1]
453
+ if (id1, id2) not in all_possible_good_edges and (id2, id1) not in all_possible_good_edges:
454
+ false_edge_id.append(edge_i)
455
+ continue
456
+
457
+ return false_edge_id
458
+
459
+
460
+ def get_corner_bin_map(corners, corner_list_for_each_bin, bin_size=10):
461
+ bin_map = np.zeros((bin_size, 256, 256))
462
+ for bin_i in range(bin_size):
463
+ bin_map[bin_i] = render(corners[corner_list_for_each_bin[bin_i]], np.array([]), render_pad=0)[1]
464
+ return bin_map
465
+
466
+
467
+ #########################################################################################
468
+ ################################ Searching Functions ####################################
469
+ #########################################################################################
470
+ def visualization(candidate, show=True):
471
+ corners = candidate.graph.getCornersArray()
472
+ edges = candidate.graph.getEdgesArray()
473
+ mask = render(corners, edges)
474
+ mask = np.transpose(np.concatenate((mask, np.zeros((1, 256, 256))), 0), (1, 2, 0))
475
+ plt.imshow(mask)
476
+ if show:
477
+ plt.show()
478
+
479
+
480
+ def check_intersection(edge1, edge2):
481
+ corner11 = edge1.x[0].x
482
+ corner12 = edge1.x[1].x
483
+ corner21 = edge2.x[0].x
484
+ corner22 = edge2.x[1].x
485
+
486
+ y1 = corner11[0]
487
+ x1 = corner11[1]
488
+ y2 = corner12[0]
489
+ x2 = corner12[1]
490
+ a = y1 - y2
491
+ b = x2 - x1
492
+ c = x1 * y2 - x2 * y1
493
+ flag1 = (a * corner21[1] + b * corner21[0] + c) * (a * corner22[1] + b * corner22[0] + c)
494
+
495
+ y1 = corner21[0]
496
+ x1 = corner21[1]
497
+ y2 = corner22[0]
498
+ x2 = corner22[1]
499
+ a = y1 - y2
500
+ b = x2 - x1
501
+ c = x1 * y2 - x2 * y1
502
+ flag2 = (a * corner11[1] + b * corner11[0] + c) * (a * corner12[1] + b * corner12[0] + c)
503
+
504
+ if flag1 < -1e-6 and flag2 < -1e-6:
505
+ return True
506
+
507
+ return False
508
+
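check_intersection reports a proper crossing: both flag products are negative exactly when each segment's endpoints straddle the other's supporting line, and the -1e-6 slack excludes touching or collinear cases. A small check using the Element wrapper defined later in this file:

    e1 = Element((Element((0, 0)), Element((10, 10))))
    e2 = Element((Element((0, 10)), Element((10, 0))))
    print(check_intersection(e1, e2))  # True: segments cross at (5, 5)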
509
+
510
+ def adding_a_corner_by_triangle_operation(candidate):
511
+ new_candidates = []
512
+ name = candidate.name
513
+ gt_mask = region_cache.get_region(name)
514
+ gt_mask = gt_mask > 0.4
515
+ gt_mask_grow = cv2.dilate(gt_mask.astype(np.float64), np.ones((3, 3), np.uint8), iterations=6) > 0
516
+
517
+ # get the current candidate region mask
518
+ conv_mask = render(corners=candidate.graph.getCornersArray(), edges=candidate.graph.getEdgesArray(),
519
+ render_pad=0, edge_linewidth=1)[0]
520
+ conv_mask = 1 - conv_mask
521
+ conv_mask = conv_mask.astype(np.uint8)
522
+ labels, region_mask = cv2.connectedComponents(conv_mask, connectivity=4)
523
+
524
+ background_label = region_mask[0, 0]
525
+ all_masks = []
526
+ for region_i in range(1, labels):
527
+ if region_i == background_label:
528
+ continue
529
+ the_region = region_mask == region_i
530
+ if the_region.sum() < 20:
531
+ continue
532
+ all_masks.append(the_region)
533
+
534
+ candidate_mask = (np.sum(all_masks, 0) + (1 - conv_mask)) > 0
535
+
536
+ final_mask = np.logical_xor(gt_mask_grow, np.logical_and(candidate_mask, gt_mask_grow))
537
+
538
+ for corner_i in range(random.randint(0, 16), 256, 16):
539
+ for corner_j in range(random.randint(0, 16), 256, 16):
540
+ if candidate.addable((corner_i, corner_j)):
541
+ if final_mask[corner_i, corner_j]:  # inside the region
542
+ new_corner = Element((corner_i, corner_j))
543
+ new_candidate = candidate.generate_new_candidate_add_a_corner(new_corner)
544
+ new_graph = new_candidate.graph
545
+ corners = new_graph.getCorners()
546
+
547
+ # find two suitable existing corners to form a triangle (no intersections, no collinear edges)
548
+ for id_A in range(len(corners)):
549
+ ele_A = corners[id_A]
550
+ if ele_A == new_corner:
551
+ continue
552
+ for id_B in range(id_A + 1, len(corners)):
553
+ ele_B = corners[id_B]
554
+ if ele_B == new_corner:
555
+ continue
556
+ if new_graph.has_edge(new_corner, ele_A) is not None:
557
+ raise BaseException('should not have edge in this case')
558
+ if new_graph.has_edge(new_corner, ele_B) is not None:
559
+ raise BaseException('should not have edge in this case')
560
+ temp_edge1 = Element((new_corner, ele_A))
561
+ temp_edge2 = Element((new_corner, ele_B))
562
+
563
+ # check if addable
564
+ if new_candidate.addable(temp_edge1) is False:
565
+ continue
566
+ if new_candidate.addable(temp_edge2) is False:
567
+ continue
568
+
569
+ # avoid intersection
570
+ if new_graph.checkIntersectionEdge(temp_edge1):
571
+ continue
572
+ if new_graph.checkIntersectionEdge(temp_edge2):
573
+ continue
574
+
575
+ # avoid too small triangle
576
+ if triangle_region(new_corner.x, ele_A.x, ele_B.x) < 20:
577
+ continue
578
+
579
+ ### avoid collinear edges (fold case only)
580
+ # for edge1
581
+ neighbor_edges = new_graph.getEdgeConnected(temp_edge1)
582
+ flag_ = True
583
+ for neighbor in neighbor_edges:
584
+ if new_corner in neighbor.x:
585
+ raise BaseException('new corner should not in any edge')
586
+ elif ele_A in neighbor.x:
587
+ shared_corner = ele_A
588
+ else:
589
+ raise BaseException('error.')
590
+ two_neighbor = {neighbor.x[0], neighbor.x[1], ele_A, new_corner}
591
+ two_neighbor.remove(shared_corner)
592
+ assert len(two_neighbor) == 2
593
+ two_neighbor = tuple(two_neighbor)
594
+
595
+ line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
596
+ line2 = np.array(shared_corner.x) - np.array(two_neighbor[1].x)
597
+ cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
598
+ cos = min(1, max(-1, cos))
599
+ if np.arccos(cos) < np.pi / 9: # 20 degree
600
+ flag_ = False
601
+ break
602
+ if flag_ is False:
603
+ continue
604
+ # for edge2
605
+ neighbor_edges = new_graph.getEdgeConnected(temp_edge2)
606
+ flag_ = True
607
+ for neighbor in neighbor_edges:
608
+ if new_corner in neighbor.x:
609
+ raise BaseException('new corner should not in any edge')
610
+ elif ele_B in neighbor.x:
611
+ shared_corner = ele_B
612
+ else:
613
+ raise BaseException('error.')
614
+ two_neighbor = {neighbor.x[0], neighbor.x[1], ele_B, new_corner}
615
+ two_neighbor.remove(shared_corner)
616
+ assert len(two_neighbor) == 2
617
+ two_neighbor = tuple(two_neighbor)
618
+
619
+ line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
620
+ line2 = np.array(shared_corner.x) - np.array(two_neighbor[1].x)
621
+ cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
622
+ cos = min(1, max(-1, cos))
623
+ if np.arccos(cos) < np.pi / 9: # 20 degree
624
+ flag_ = False
625
+ break
626
+ if flag_ is False:
627
+ continue
628
+
629
+ # make new candidate
630
+ try:
631
+ new_ = new_candidate.generate_new_candidate_add_an_edge(new_corner, ele_A)
632
+ new_ = new_.generate_new_candidate_add_an_edge(new_corner, ele_B)
633
+ new_candidates.append(new_)
634
+ except Exception:
635
+ continue
636
+ # plt.subplot(151)
637
+ # visualization(candidate, show=False)
638
+ # plt.subplot(152)
639
+ # plt.imshow(final_mask)
640
+ # plt.subplot(153)
641
+ # plt.imshow(candidate_mask)
642
+ # plt.subplot(154)
643
+ # plt.imshow(gt_mask_grow)
644
+ # plt.subplot(155)
645
+ # visualization(new_, show=False)
646
+ # plt.show()
647
+
648
+ return new_candidates
649
+
650
+
651
+ def adding_an_edge_from_new_corner_operation(candidate):
652
+ new_candidates = []
653
+ name = candidate.name
654
+ gt_mask = region_cache.get_region(name)
655
+ gt_mask = gt_mask > 0.4
656
+ gt_mask_grow = cv2.dilate(gt_mask.astype(np.float64), np.ones((3, 3), np.uint8), iterations=6) > 0
657
+
658
+ # get the current candidate region mask
659
+ conv_mask = render(corners=candidate.graph.getCornersArray(), edges=candidate.graph.getEdgesArray(),
660
+ render_pad=0, edge_linewidth=1)[0]
661
+ conv_mask = 1 - conv_mask
662
+ conv_mask = conv_mask.astype(np.uint8)
663
+ labels, region_mask = cv2.connectedComponents(conv_mask, connectivity=4)
664
+ background_label = region_mask[0, 0]
665
+ all_masks = []
666
+ for region_i in range(1, labels):
667
+ if region_i == background_label:
668
+ continue
669
+ the_region = region_mask == region_i
670
+ if the_region.sum() < 20:
671
+ continue
672
+ all_masks.append(the_region)
673
+ candidate_mask = (np.sum(all_masks, 0) + (1 - conv_mask)) > 0
674
+
675
+ final_mask = np.logical_xor(gt_mask_grow, np.logical_and(candidate_mask, gt_mask_grow))
676
+ for corner_i in range(random.randint(0, 16), 256, 16):
677
+ for corner_j in range(random.randint(0, 16), 256, 16):
678
+ if candidate.addable((corner_i, corner_j)):
679
+ if final_mask[corner_i, corner_j]:
680
+ # inside the region
681
+ new_corner = Element((corner_i, corner_j))
682
+ new_candidate = candidate.generate_new_candidate_add_a_corner(new_corner)
683
+ new_graph = new_candidate.graph
684
+ corners = new_graph.getCorners()
685
+
686
+ # find a suitable existing corner that can form
687
+ # a new edge with new_corner (no intersection, no collinear edges)
688
+ for corner_ele in corners:
689
+ if corner_ele == new_corner:
690
+ continue
691
+ if new_graph.has_edge(new_corner, corner_ele) is not None:
692
+ raise BaseException('should not have edge in this case')
693
+ temp_edge = Element((new_corner, corner_ele))
694
+
695
+ # check if addable
696
+ if new_candidate.addable(temp_edge) is False:
697
+ continue
698
+
699
+ # avoid intersection
700
+ if new_graph.checkIntersectionEdge(temp_edge):
701
+ continue
702
+
703
+ # avoid collinear edges
704
+ neighbor_edges = new_graph.getEdgeConnected(temp_edge)
705
+ flag_ = True
706
+ for neighbor in neighbor_edges:
707
+ if new_corner in neighbor.x:
708
+ raise BaseException('new corner should not in any edge')
709
+ elif corner_ele in neighbor.x:
710
+ shared_corner = corner_ele
711
+ else:
712
+ raise BaseException('error.')
713
+ two_neighbor = {neighbor.x[0], neighbor.x[1], corner_ele, new_corner}
714
+ two_neighbor.remove(shared_corner)
715
+ assert len(two_neighbor) == 2
716
+ two_neighbor = tuple(two_neighbor)
717
+
718
+ line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
719
+ line2 = np.array(shared_corner.x) - np.array(two_neighbor[1].x)
720
+ cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
721
+ cos = min(1, max(-1, cos))
722
+ if np.arccos(cos) < np.pi / 9: # 20 degree
723
+ flag_ = False
724
+ break
725
+ if flag_ is False:
726
+ continue
727
+
728
+ # make new candidate
729
+ try:
730
+ new_ = new_candidate.generate_new_candidate_add_an_edge(new_corner, corner_ele)
731
+ new_candidates.append(new_)
732
+ except Exception:
733
+ continue
734
+
735
+ return new_candidates
736
+
737
+
738
+ def removing_a_corner_operation(candidate):
739
+ new_candidates = []
740
+ graph = candidate.graph
741
+ corners = graph.getCorners()
742
+ for the_corner in corners:
743
+ if candidate.removable(the_corner):
744
+ try:
745
+ new_ = candidate.generate_new_candidate_remove_a_corner(the_corner)
746
+ new_candidates.append(new_)
747
+ except Exception:
748
+ continue
749
+
750
+ return new_candidates
751
+
752
+
753
+ def removing_a_colinear_corner_operation(candidate):
754
+ new_candidates = []
755
+ graph = candidate.graph
756
+ corners = graph.getCorners()
757
+ for the_corner in corners:
758
+ if candidate.removable(the_corner):  # no need to also check graph.checkColinearCorner(the_corner) here
759
+ try:
760
+ new_ = candidate.generate_new_candidate_remove_a_colinear_corner(the_corner)
761
+
762
+ if new_.graph.checkIntersectionEdge():
763
+ continue
764
+ new_candidates.append(new_)
765
+ except Exception:
766
+ continue
767
+
768
+ return new_candidates
769
+
770
+
771
+ def adding_an_edge_operation(candidate):
772
+ new_candidates = []
773
+ graph = candidate.graph
774
+ corners = graph.getCorners()
775
+ for corner_i in range(len(corners)):
776
+ cornerA = corners[corner_i]
777
+ for corner_j in range(corner_i + 1, len(corners)):
778
+ cornerB = corners[corner_j]
779
+ if graph.has_edge(cornerA, cornerB) is not None:
780
+ continue
781
+
782
+ temp_edge = Element((cornerA, cornerB))
783
+ # check if addable (not in existed_before dict)
784
+ if candidate.addable(temp_edge) is False:
785
+ continue
786
+
787
+ if graph.checkIntersectionEdge(temp_edge):
788
+ continue
789
+
790
+ # avoid adding a collinear edge
791
+ neighbor_edges = graph.getEdgeConnected(temp_edge)
792
+ flag_ = True
793
+ for neighbor in neighbor_edges:
794
+ if cornerA in neighbor.x:
795
+ shared_corner = cornerA
796
+ elif cornerB in neighbor.x:
797
+ shared_corner = cornerB
798
+ else:
799
+ raise BaseException('error.')
800
+ two_neighbor = {neighbor.x[0], neighbor.x[1], cornerA, cornerB}
801
+ two_neighbor.remove(shared_corner)
802
+ assert len(two_neighbor) == 2
803
+ two_neighbor = tuple(two_neighbor)
804
+
805
+ line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
806
+ line2 = np.array(two_neighbor[1].x) - np.array(shared_corner.x)
807
+ cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
808
+ cos = min(1, max(-1, cos))
809
+ if np.arccos(cos) < np.pi / 18 or np.arccos(cos) > np.pi - np.pi / 18: # 10 degree
810
+ flag_ = False
811
+ break
812
+ if flag_ is False:
813
+ continue
814
+
815
+ # make new candidate
816
+ try:
817
+ new_ = candidate.generate_new_candidate_add_an_edge(cornerA, cornerB)
818
+ new_candidates.append(new_)
819
+ except Exception:
820
+ continue
821
+
822
+ return new_candidates
823
+
824
+
825
+ def removing_an_edge_operation(candidate):
826
+ new_candidates = []
827
+ graph = candidate.graph
828
+ edges = graph.getEdges()
829
+ for edge_ele in edges:
830
+ if candidate.removable(edge_ele):
831
+ try:
832
+ new_ = candidate.generate_new_candidate_remove_an_edge(edge_ele)
833
+ new_candidates.append(new_)
834
+ except Exception:
835
+ continue
836
+
837
+ return new_candidates
838
+
839
+
840
+ def adding_an_edge_from_gt(candidate, gt_data):
841
+ new_candidates = []
842
+ corners_array = candidate.graph.getCornersArray()
843
+ edges_array = candidate.graph.getEdgesArray()
844
+
845
+ gt_corners = gt_data['corners'].copy()
846
+ gt_edges = gt_data['edges'].copy()
847
+
848
+ _, _, map_same_location = get_wrong_corners(
849
+ corners_array, gt_corners, edges_array, gt_edges)
850
+
851
+ gt_corners, gt_edges = simplify_gt(map_same_location, gt_corners, gt_edges)
852
+
853
+ _, _, map_same_location = get_wrong_corners(
854
+ corners_array, gt_corners, edges_array, gt_edges)
855
+
856
+ for corner_i in range(gt_corners.shape[0]):
857
+ if map_same_location[corner_i] is None:
858
+ # doesn't exist in candidate
859
+ neighbor_id = get_neighbor_corner_id(corner_i, gt_edges)
860
+ for corner_j in neighbor_id:
861
+ if map_same_location[corner_j] is not None:
862
+ # exist corner in candidate that maps neighbor corner
863
+ new_candidate = candidate.copy()
864
+ new_corner = Element(
865
+ (
866
+ int(np.round(gt_corners[corner_i, 0])), int(np.round(gt_corners[corner_i, 1]))
867
+ )
868
+ )
869
+ if new_candidate.addable(new_corner) is False:
870
+ continue
871
+ # the new corner must not be too close to an existing edge
872
+ flag = False
873
+ for edge_ele in new_candidate.graph.getEdges():
874
+ if get_distance_of_corner_and_edge(edge_ele.x[0].x, edge_ele.x[1].x, new_corner.x) < 7:
875
+ flag = True
876
+ break
877
+ if flag:
878
+ continue
879
+
880
+ new_corner = new_candidate.addCorner(new_corner)
881
+ neighbor_index = map_same_location[corner_j]
882
+ neighbor_corner = new_candidate.graph.getCorners()[neighbor_index]
883
+ new_edge = new_candidate.addEdge(new_corner, neighbor_corner)
884
+ if new_candidate.graph.checkIntersectionEdge(new_edge):
885
+ continue
886
+ new_candidates.append(new_candidate)
887
+
888
+ return new_candidates
889
+
890
+
891
+ def adding_a_corner_from_two_edges_extension(candidate):
892
+ new_candidates = []
893
+ graph = candidate.graph
894
+ edges = candidate.graph.getEdges()
895
+ for edge_i in range(len(edges)):
896
+ for edge_j in range(edge_i + 1, len(edges)):
897
+ edgeA = edges[edge_i]
898
+ edgeB = edges[edge_j]
899
+ if graph.isNeighbor(edgeA, edgeB):
900
+ continue
901
+ intersection_loc = get_two_edge_intersection_location(edgeA.x[0].x, edgeA.x[1].x, edgeB.x[0].x,
902
+ edgeB.x[1].x)
903
+ if intersection_loc[0] >= 255 or intersection_loc[1] >= 255 or \
904
+ intersection_loc[0] <= 0 or intersection_loc[1] <= 0:
905
+ continue
906
+ # the intersection point cannot be too close to an existing edge
907
+ flag = False
908
+ for edge_ele in graph.getEdges():
909
+ if get_distance_of_corner_and_edge(edge_ele.x[0].x, edge_ele.x[1].x, intersection_loc) < 7:
910
+ flag = True
911
+ break
912
+ if flag:
913
+ continue
914
+ new_candidate = candidate.copy()
915
+ new_graph = new_candidate.graph
916
+ new_edgeA = new_graph.getRealElement(edgeA)
917
+ new_edgeB = new_graph.getRealElement(edgeB)
918
+ new_corner = Element(intersection_loc)
919
+ if new_candidate.addable(new_corner) is False:
920
+ continue
921
+ new_corner = new_candidate.addCorner_v2(new_corner)
922
+ # get cornerA and cornerB from edgeA, edgeB
923
+ if l2_distance(new_corner.x, new_edgeA.x[0].x) < l2_distance(new_corner.x, new_edgeA.x[1].x):
924
+ cornerA = new_edgeA.x[0]
925
+ else:
926
+ cornerA = new_edgeA.x[1]
927
+ if l2_distance(new_corner.x, new_edgeB.x[0].x) < l2_distance(new_corner.x, new_edgeB.x[1].x):
928
+ cornerB = new_edgeB.x[0]
929
+ else:
930
+ cornerB = new_edgeB.x[1]
931
+
932
+ # new edges cannot be too short
933
+ if l2_distance(cornerA.x, new_corner.x) < 7:
934
+ continue
935
+ if l2_distance(cornerB.x, new_corner.x) < 7:
936
+ continue
937
+
938
+ # new intersection cannot be too flat
939
+ if degree_of_three_corners(cornerA.x, cornerB.x, new_corner.x) > 165:
940
+ continue
941
+
942
+ flag = False
943
+ for edge_ele in new_graph.getEdges():
944
+ if new_corner in edge_ele.x and cornerA in edge_ele.x:
945
+ flag = True
946
+ break
947
+ if edge_ele.x[0] not in (new_corner, cornerA):
948
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[0].x)
949
+ if l <= 7:
950
+ flag = True
951
+ break
952
+ if edge_ele.x[1] not in (new_corner, cornerA):
953
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[1].x)
954
+ if l <= 7:
955
+ flag = True
956
+ break
957
+ if flag:
958
+ continue
959
+ add_edgeA = new_candidate.addEdge(new_corner, cornerA)
960
+ if new_graph.checkIntersectionEdge(add_edgeA):
961
+ continue
962
+
963
+ flag = False
964
+ for edge_ele in new_graph.getEdges():
965
+ if new_corner in edge_ele.x and cornerB in edge_ele.x:
966
+ flag = True
967
+ break
968
+ if edge_ele.x[0] not in (new_corner, cornerB):
969
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[0].x)
970
+ if l <= 7:
971
+ flag = True
972
+ break
973
+ if edge_ele.x[1] not in (new_corner, cornerB):
974
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[1].x)
975
+ if l <= 7:
976
+ flag = True
977
+ break
978
+ if flag:
979
+ continue
980
+ add_edgeB = new_candidate.addEdge(new_corner, cornerB)
981
+ if new_graph.checkIntersectionEdge(add_edgeB):
982
+ continue
983
+
984
+ # make real new candidate
985
+ # new_candidate = candidate.copy()
986
+ # new_graph = new_candidate.graph
987
+ # new_corner = Element(intersection_loc)
988
+ # new_corner = new_graph.add_corner_v2(new_corner)
989
+ # new_candidate = new_candidate.generate_new_candidate_add_an_edge(new_corner, cornerA)
990
+ # new_candidate = new_candidate.generate_new_candidate_add_an_edge(new_corner, cornerB)
991
+
992
+ new_candidates.append(new_candidate)
993
+ return new_candidates
994
+
995
+
996
+ def adding_a_corner_from_parallel(candidate):
997
+ new_candidates = []
998
+ graph = candidate.graph
999
+ edges = candidate.graph.getEdges()
1000
+ for edge_i in range(len(edges)):
1001
+ for edge_j in range(edge_i + 1, len(edges)):
1002
+ edgeA = edges[edge_i]
1003
+ edgeB = edges[edge_j]
1004
+ # get intersection loc
1005
+ if graph.isNeighbor(edgeA, edgeB):
1006
+ shared_corner = edgeA.x[0] if edgeA.x[0] in edgeB.x else edgeA.x[1]
1007
+ intersection_loc = shared_corner.x
1008
+ else:
1009
+ intersection_loc = get_two_edge_intersection_location(
1010
+ edgeA.x[0].x, edgeA.x[1].x, edgeB.x[0].x, edgeB.x[1].x)
1011
+ if intersection_loc[0] >= 255 or intersection_loc[1] >= 255 or \
1012
+ intersection_loc[0] <= 0 or intersection_loc[1] <= 0:
1013
+ continue
1014
+
1015
+ # get another two loc
1016
+ locA = edgeA.x[1].x if \
1017
+ l2_distance(edgeA.x[0].x, intersection_loc) < l2_distance(edgeA.x[1].x, intersection_loc) else \
1018
+ edgeA.x[0].x
1019
+ locB = edgeB.x[1].x if \
1020
+ l2_distance(edgeB.x[0].x, intersection_loc) < l2_distance(edgeB.x[1].x, intersection_loc) else \
1021
+ edgeB.x[0].x
1022
+
1023
+ # get new loc
1024
+ new_loc = (locA[0] + locB[0] - intersection_loc[0], locA[1] + locB[1] - intersection_loc[1])
1025
+ if new_loc[0] >= 255 or new_loc[1] >= 255 or \
1026
+ new_loc[0] <= 0 or new_loc[1] <= 0:
1027
+ continue
1028
+
1029
+ new_corner = Element(new_loc)
1030
+ new_candidate = candidate.copy()
1031
+ new_graph = new_candidate.graph
1032
+ edgeA = new_graph.getRealElement(edgeA)
1033
+ edgeB = new_graph.getRealElement(edgeB)
1034
+ if new_candidate.addable(new_corner) is False:
1035
+ continue
1036
+ new_corner = new_candidate.addCorner_v2(new_corner)
1037
+ # get cornerA and cornerB from edgeA, edgeB
1038
+ cornerA = edgeA.x[1] if l2_distance(edgeA.x[0].x, intersection_loc) < l2_distance(edgeA.x[1].x,
1039
+ intersection_loc) \
1040
+ else edgeA.x[0]
1041
+ cornerB = edgeB.x[1] if l2_distance(edgeB.x[0].x, intersection_loc) < l2_distance(edgeB.x[1].x,
1042
+ intersection_loc) \
1043
+ else edgeB.x[0]
1044
+
1045
+ # new edges cannot be too short
1046
+ if l2_distance(cornerA.x, new_corner.x) < 12:
1047
+ continue
1048
+ if l2_distance(cornerB.x, new_corner.x) < 12:
1049
+ continue
1050
+
1051
+ flag = False
1052
+ for edge_ele in new_graph.getEdges():
1053
+ if new_corner in edge_ele.x and cornerA in edge_ele.x:
1054
+ flag = True
1055
+ break
1056
+ if edge_ele.x[0] not in (new_corner, cornerA):
1057
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[0].x)
1058
+ if l <= 7:
1059
+ flag = True
1060
+ break
1061
+ if edge_ele.x[1] not in (new_corner, cornerA):
1062
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[1].x)
1063
+ if l <= 7:
1064
+ flag = True
1065
+ break
1066
+ if flag:
1067
+ continue
1068
+ add_edgeA = new_candidate.addEdge(new_corner, cornerA)
1069
+ if new_graph.checkIntersectionEdge(add_edgeA):
1070
+ continue
1071
+
1072
+ flag = False
1073
+ for edge_ele in new_graph.getEdges():
1074
+ if new_corner in edge_ele.x and cornerB in edge_ele.x:
1075
+ flag = True
1076
+ break
1077
+ if edge_ele.x[0] not in (new_corner, cornerB):
1078
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[0].x)
1079
+ if l <= 7:
1080
+ flag = True
1081
+ break
1082
+ if edge_ele.x[1] not in (new_corner, cornerB):
1083
+ l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[1].x)
1084
+ if l <= 7:
1085
+ flag = True
1086
+ break
1087
+ if flag:
1088
+ continue
1089
+ add_edgeB = new_candidate.addEdge(new_corner, cornerB)
1090
+ if new_graph.checkIntersectionEdge(add_edgeB):
1091
+ continue
1092
+
1093
+ new_candidates.append(new_candidate)
1094
+ return new_candidates
1095
+
1096
+
1097
+ def adding_a_orthogonal_edge(candidate):
1098
+ new_candidates = []
1099
+ graph = candidate.graph
1100
+ edges = candidate.graph.getEdges()
1101
+ for edge in edges:
1102
+ cornerA = edge.x[0]
1103
+ cornerB = edge.x[1]
1104
+
1105
+ # get orthogonal direction
1106
+ dir_ = (cornerA.x[1] - cornerB.x[1], cornerB.x[0] - cornerA.x[0])
1107
+
1108
+ for the_corner in edge.x:
1109
+ temp_orth_loc = (the_corner.x[0] - dir_[0], the_corner.x[1] - dir_[1])
1110
+ for inter_edge in edges:
1111
+ if inter_edge == edge:
1112
+ continue
1113
+ if the_corner in inter_edge.x:
1114
+ continue
1115
+ intersection_loc = get_two_edge_intersection_location(
1116
+ the_corner.x, temp_orth_loc, inter_edge.x[0].x, inter_edge.x[1].x
1117
+ )
1118
+ if intersection_loc[0] >= 255 or intersection_loc[1] >= 255 or \
1119
+ intersection_loc[0] <= 0 or intersection_loc[1] <= 0:
1120
+ continue
1121
+ if np.dot((inter_edge.x[0].x[0] - intersection_loc[0], inter_edge.x[0].x[1] - intersection_loc[1]),
1122
+ (inter_edge.x[1].x[0] - intersection_loc[0], inter_edge.x[1].x[1] - intersection_loc[1])) > 0:
1123
+ # the intersection lies on the extension of inter_edge, not inside it
1124
+ continue
1125
+ if l2_distance(intersection_loc, inter_edge.x[0].x) < 5 or \
1126
+ l2_distance(intersection_loc, inter_edge.x[1].x) < 5:
1127
+ continue
1128
+
1129
+ # avoid a near-degenerate angle with neighboring edges
1130
+ flag = False
1131
+ neighbor_corners = graph.getNeighborCorner(the_corner)
1132
+ for corner_ele in neighbor_corners:
1133
+ if corner_ele in edge.x:
1134
+ continue
1135
+ if degree_of_three_corners(corner_ele.x, intersection_loc, the_corner.x) < 15:
1136
+ flag = True
1137
+ break
1138
+ if degree_of_three_corners(corner_ele.x, intersection_loc, the_corner.x) > 165:
1139
+ flag = True
1140
+ break
1141
+ if flag:
1142
+ continue
1143
+
1144
+ new_candidate = candidate.copy()
1145
+ new_graph = new_candidate.graph
1146
+ new_corner = Element(intersection_loc)
1147
+ if new_candidate.addable(new_corner) is False:
1148
+ continue
1149
+ new_corner = new_candidate.addCorner_v2(new_corner)
1150
+
1151
+ # new edges cannot be too short
1152
+ if l2_distance(new_corner.x, the_corner.x) < 7:
1153
+ continue
1154
+
1155
+ add_edge = new_candidate.addEdge(new_corner, new_graph.getRealElement(the_corner))
1156
+ if new_graph.checkIntersectionEdge(add_edge):
1157
+ continue
1158
+
1159
+ new_candidates.append(new_candidate)
1160
+ return new_candidates
1161
+
1162
+
1163
+ class _thread(threading.Thread):
1164
+ def __init__(self, threadID, name, candidate, lock, result_list, func):
1165
+ threading.Thread.__init__(self)
1166
+ self.threadID = threadID
1167
+ self.name = name
1168
+ self.candidate = candidate
1169
+ self.lock = lock
1170
+ self.result_list = result_list
1171
+ self.func = func
1172
+
1173
+ def run(self):
1174
+ print('running id: ', self.name)
1175
+ start_time = time.time()
1176
+ candidates = self.func(self.candidate)
1177
+ print('candidates generated:', self.name, len(candidates))
1178
+ self.lock.acquire()
1179
+ self.result_list.extend(candidates)
1180
+ self.lock.release()
1181
+ print(self.name, "spend time: {}s".format(time.time() - start_time))
1182
+
1183
+
1184
+ def candidate_enumerate_training(candidate, gt):
1185
+ new_candidates = []
1186
+ # remove a corner
1187
+ try:
1188
+ new_ = removing_a_corner_operation(candidate)
1189
+ if len(new_) > 0:
1190
+ new_candidates.append(random.choice(new_))
1191
+ except Exception:
1192
+ print('something went wrong while removing a corner')
1193
+
1194
+ # remove a colinear corner
1195
+ try:
1196
+ new_ = removing_a_colinear_corner_operation(candidate)
1197
+ if len(new_) > 0:
1198
+ new_candidates.append(random.choice(new_))
1199
+ except Exception:
1200
+ print('something went wrong while removing a collinear corner')
1201
+
1202
+ # remove an edge
1203
+ try:
1204
+ new_ = removing_an_edge_operation(candidate)
1205
+ if len(new_) > 0:
1206
+ new_candidates.append(random.choice(new_))
1207
+ except Exception:
1208
+ print('something went wrong while removing an edge')
1209
+
1210
+ # add an edge from an existing corner
1211
+ try:
1212
+ new_ = adding_an_edge_operation(candidate)
1213
+ if len(new_) > 0:
1214
+ new_candidates.append(random.choice(new_))
1215
+ except Exception:
1216
+ print('something went wrong while adding an edge from an existing corner')
1217
+
1218
+ # add a corner from two edges
1219
+ try:
1220
+ new_ = adding_a_corner_from_two_edges_extension(candidate)
1221
+ if len(new_) > 0:
1222
+ new_candidates.append(random.choice(new_))
1223
+ except Exception:
1224
+ print('something went wrong while adding a corner from two edges')
1225
+
1226
+ try:
1227
+ new_ = adding_a_corner_from_parallel(candidate)
1228
+ if len(new_) > 0:
1229
+ new_candidates.append(random.choice(new_))
1230
+ except Exception:
1232
+ print('something went wrong while adding a corner from parallel edges')
1232
+
1233
+ # add an edge from gt
1234
+ try:
1235
+ new_ = adding_an_edge_from_gt(candidate, gt)
1236
+ if len(new_) > 0:
1237
+ new_candidates.append(random.choice(new_))
1238
+ except Exception:
1240
+ print('something went wrong while adding an edge from gt')
1240
+
1241
+ # add an orthogonal edge
1242
+ try:
1243
+ new_ = adding_a_orthogonal_edge(candidate)
1244
+ if len(new_) > 0:
1245
+ new_candidates.append(random.choice(new_))
1246
+ except Exception:
1248
+ print('something went wrong while adding an orthogonal edge')
1248
+ return new_candidates
1249
+
1250
+
1251
+ def candidate_enumerate(candidate):
1252
+ new_candidates = []
1253
+ new_candidates.extend(removing_a_corner_operation(candidate))
1254
+ new_candidates.extend(removing_a_colinear_corner_operation(candidate))
1255
+ new_candidates.extend(removing_an_edge_operation(candidate))
1256
+ new_candidates.extend(adding_an_edge_operation(candidate))
1257
+ new_candidates.extend(adding_a_corner_from_two_edges_extension(candidate))
1258
+ new_candidates.extend(adding_a_corner_from_parallel(candidate))
1259
+ new_candidates.extend(adding_a_orthogonal_edge(candidate))
1260
+
1261
+ return new_candidates
1262
+
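The operators above enumerate the neighborhood of a candidate graph; paired with the scores stored on each element, they can drive a greedy search. The loop below is only an illustrative sketch (not code from this repository), and it assumes each candidate's element scores have already been populated via Graph.store_score:

    def greedy_search(candidate, steps=10):
        best = candidate
        for _ in range(steps):
            neighbors = reduce_duplicate_candidate(candidate_enumerate(best))
            if not neighbors:
                break
            top = max(neighbors, key=lambda c: c.graph.graph_score())
            if top.graph.graph_score() <= best.graph.graph_score():
                break  # local optimum reached
            best = top
        return best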
1263
+
1264
+ def candidate_enumerate_thread(candidate):
1265
+ new_candidates = []
1266
+ lock = threading.Lock()
1267
+
1268
+ thread1 = _thread(1, 'remove_a_corner', candidate, lock, new_candidates, removing_a_corner_operation)
1269
+ thread2 = _thread(2, 'remove_a_colinear_corner', candidate, lock, new_candidates,
1270
+ removing_a_colinear_corner_operation)
1271
+ thread3 = _thread(3, 'add_an_edge', candidate, lock, new_candidates, adding_an_edge_operation)
1272
+ thread4 = _thread(4, 'remove_an_edge', candidate, lock, new_candidates, removing_an_edge_operation)
1273
+
1274
+ thread1.start()
1275
+ thread2.start()
1276
+ thread3.start()
1277
+ thread4.start()
1278
+
1279
+ threads = []
1280
+ threads.append(thread1)
1281
+ threads.append(thread2)
1282
+ threads.append(thread3)
1283
+ threads.append(thread4)
1284
+
1285
+ for t in threads:
1286
+ t.join()
1287
+
1288
+ return new_candidates
1289
+
1290
+
1291
+ def reduce_duplicate_candidate(candidates):
1292
+ i = 0
1293
+ while i < len(candidates):
1294
+ for j in reversed(range(i + 1, len(candidates))):
1295
+ if candidates[i].equal(candidates[j]):
1296
+ del candidates[j]
1297
+ i = i + 1
1298
+ return candidates
1299
+
1300
+
1301
+ def save_candidate_image(candidate, base_path, base_name):
1302
+ corners = candidate.graph.getCornersArray()
1303
+ edges = candidate.graph.getEdgesArray()
1304
+ # graph svg
1305
+ svg = svg_generate(corners, edges, base_name, samecolor=True)
1306
+ svg.saveas(os.path.join(base_path, base_name + '.svg'))
1307
+ # corner image
1308
+ temp_mask = np.zeros((256, 256))
1309
+ for ele in candidate.graph.getCorners():
1310
+ if ele.get_score() < 0:
1311
+ temp_mask = cv2.circle(temp_mask, ele.x[::-1], 3, 1, -1)
1312
+ fig = plt.figure(frameon=False)
1313
+ fig.set_size_inches(1, 1)
1314
+ ax = plt.Axes(fig, [0., 0., 1., 1.])
1315
+ ax.set_axis_off()
1316
+ fig.add_axes(ax)
1317
+ ax.imshow(temp_mask, aspect='auto')
1318
+ fig.savefig(os.path.join(base_path, base_name + '_corner.png'), dpi=256)
1319
+ # edges image
1320
+ temp_mask = np.zeros((256, 256))
1321
+ for ele in candidate.graph.getEdges():
1322
+ if ele.get_score() < 0:
1323
+ A = ele.x[0]
1324
+ B = ele.x[1]
1325
+ temp_mask = cv2.line(temp_mask, A.x[::-1], B.x[::-1], 1, thickness=1)
1326
+ ax.imshow(temp_mask, aspect='auto')
1327
+ fig.savefig(os.path.join(base_path, base_name + '_edge.png'), dpi=256)
1328
+ # region no need fig
1329
+ plt.close()
1330
+
1331
+
1332
+ #########################################################################################
1333
+ ###################################### Class ############################################
1334
+ #########################################################################################
1335
+
1336
+ class Element:
1337
+ def __init__(self, x, safe_count=0):
1338
+ assert type(x) is tuple
1339
+ assert type(x[0]) == int or type(x[0]) == Element
1340
+ assert type(x[1]) == int or type(x[1]) == Element
1341
+ self.x = x
1342
+ self.__score = None
1343
+ self.safe_count = safe_count
1344
+
1345
+ def store_score(self, score):
1346
+ self.__score = score
1347
+
1348
+ def get_score(self):
1349
+ return self.__score
1350
+
1351
+ def equal(self, ele):
1352
+ if type(self.x[0]) != type(ele.x[0]):
1353
+ return False
1354
+ if type(self.x[0]) == int:
1355
+ # corner
1356
+ return self.x[0] == ele.x[0] and self.x[1] == ele.x[1]
1357
+ if type(self.x[0]) == Element:
1358
+ # edge
1359
+ if self.x[0].equal(ele.x[0]) and self.x[1].equal(ele.x[1]):
1360
+ return True
1361
+ if self.x[1].equal(ele.x[0]) and self.x[0].equal(ele.x[1]):
1362
+ return True
1363
+ return False
1364
+ raise BaseException('no implement type')
1365
+
1366
+
1367
+ class regionCache():
1368
+ def __init__(self, datapath):
1369
+ self.cache = {}
1370
+ self.datapath = datapath
1371
+
1372
+ def get_region(self, name):
1373
+ if name in self.cache:
1374
+ return self.cache[name]
1375
+ gt_mask = np.load(os.path.join(self.datapath, name + '.npy'))
1376
+ if len(self.cache) == 5:
1377
+ self.cache.pop(list(self.cache.keys())[0])
1378
+ self.cache[name] = gt_mask
1379
+ return gt_mask
1380
+
1381
+
1382
+ class imgCache():
1383
+ def __init__(self, datapath):
1384
+ self.cache = {}
1385
+ self.datapath = datapath
1386
+
1387
+ def get_image(self, name):
1388
+ if name in self.cache:
1389
+ return self.cache[name]
1390
+ img = skimage.img_as_float(plt.imread(os.path.join(self.datapath, 'rgb', name + '.jpg')))
1391
+ if len(self.cache) == 5:
1392
+ self.cache.pop(list(self.cache.keys())[0])
1393
+ self.cache[name] = img
1394
+ return img
1395
+
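regionCache and imgCache are hand-rolled 5-entry FIFO caches keyed by sample name. Where instance state is not needed, the same effect can be had with functools.lru_cache (a sketch under that assumption, not a drop-in replacement):

    import os
    from functools import lru_cache

    import numpy as np

    @lru_cache(maxsize=5)
    def load_region(datapath, name):
        # Cached load of a ground-truth region mask.
        return np.load(os.path.join(datapath, name + '.npy'))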
1396
+
1397
+ class Graph:
+     def __init__(self, corners, edges):
+         corners, edges = sort_graph(corners, edges)
+
+         self.__corners = []
+         for corner_i in range(corners.shape[0]):
+             self.__corners.append(
+                 Element(
+                     tuple(
+                         (int(corners[corner_i, 0]), int(corners[corner_i, 1]))
+                     )
+                 )
+             )
+         self.__edges = []
+         for edge_i in range(edges.shape[0]):
+             self.__edges.append(Element((self.__corners[edges[edge_i, 0]], self.__corners[edges[edge_i, 1]])))
+         self.__regions = []
+         self.__regions.append(Element((0, 0)))  # we use the entire region here
+
+     @classmethod
+     def initialFromTuple(cls, corners, edges):
+         edge_index = []
+         for item in edges:
+             a = corners.index(item[0])
+             b = corners.index(item[1])
+             edge_index.append((a, b))
+         edge_index = np.array(edge_index)
+         corners = np.array(corners)
+         return cls(corners, edge_index)
+
+     def store_score(self, corner_score=None, edge_score=None, region_score=None):
+         '''
+         :param corner_score: np.array of size len(corners)
+         :param edge_score: np.array of size len(edges)
+         :param region_score: np.array of size len(regions)
+         :return:
+         '''
+         if corner_score is not None:
+             for idx, element in enumerate(self.__corners):
+                 element.store_score(corner_score[idx])
+         if edge_score is not None:
+             for idx, element in enumerate(self.__edges):
+                 element.store_score(edge_score[idx])
+         if region_score is not None:
+             for idx, element in enumerate(self.__regions):
+                 element.store_score(region_score[idx])
+         return
+
+     def getCornersArray(self):
+         c = []
+         for ele in self.__corners:
+             c.append(ele.x)
+         return np.array(c)
+
+     def getEdgesArray(self):
+         c = []
+         for ele in self.__edges:
+             corner1 = ele.x[0]
+             corner2 = ele.x[1]
+             idx1 = self.__corners.index(corner1)
+             idx2 = self.__corners.index(corner2)
+             c.append([idx1, idx2])
+         return np.array(c)
+
+     def getCorners(self):
+         return self.__corners
+
+     def getRegions(self):
+         return self.__regions
+
+     def getEdges(self):
+         return self.__edges
+
+     def graph_score(self):
+         corner_score = 0
+         for ele in self.__corners:
+             corner_score += ele.get_score()
+         edge_score = 0
+         for ele in self.__edges:
+             edge_score += ele.get_score()
+         region_score = 0
+         for ele in self.__regions:
+             region_score += ele.get_score()
+         return score_weights[0] * corner_score + score_weights[1] * edge_score + score_weights[2] * region_score
+
+     def corner_score(self):
+         corner_score = 0
+         for ele in self.__corners:
+             corner_score += ele.get_score()
+         return corner_score
+
+     def edge_score(self):
+         edge_score = 0
+         for ele in self.__edges:
+             edge_score += ele.get_score()
+         return edge_score
+
+     def region_score(self):
+         region_score = 0
+         for ele in self.__regions:
+             region_score += ele.get_score()
+         return region_score
+
+     def remove(self, ele):
+         '''
+         :param ele: remove ele as well as the related elements
+         :return: set() of removed elements
+         '''
+         # corner
+         removed = set()
+         if ele in self.__corners:
+             self.__corners.remove(ele)
+             removed.add(ele)
+             # remove every edge that uses the corner
+             for idx in reversed(range(len(self.__edges))):
+                 edge_ele = self.__edges[idx]
+                 if ele in edge_ele.x:
+                     removed = removed.union(self.remove(edge_ele))
+         # edge
+         elif ele in self.__edges:
+             self.__edges.remove(ele)
+             removed.add(ele)
+             corner1 = ele.x[0]
+             corner2 = ele.x[1]
+             if corner1.safe_count == 0:
+                 # can be deleted if no remaining edge uses it
+                 _count = 0
+                 for edge_ele in self.__edges:
+                     if corner1 in edge_ele.x:
+                         _count += 1
+                 if _count == 0:
+                     removed = removed.union(self.remove(corner1))
+             if corner2.safe_count == 0:
+                 # can be deleted if no remaining edge uses it
+                 _count = 0
+                 for edge_ele in self.__edges:
+                     if corner2 in edge_ele.x:
+                         _count += 1
+                 if _count == 0:
+                     removed = removed.union(self.remove(corner2))
+         return removed
+
+     def has_edge(self, ele1, ele2):
+         """
+         :param ele1: corner1
+         :param ele2: corner2
+         :return: the edge or None
+         """
+         for edge_ele in self.__edges:
+             if ele1 in edge_ele.x and ele2 in edge_ele.x:
+                 return edge_ele
+         return None
+
+     def add_edge(self, ele1, ele2):
+         temp = self.has_edge(ele1, ele2)
+         if temp is not None:
+             temp.safe_count = SAFE_NUM
+             return temp
+         new_ele = Element((ele1, ele2), safe_count=SAFE_NUM)
+         self.__edges.append(new_ele)
+         return new_ele
+
+     def add_corner(self, ele):
+         for corner in self.__corners:
+             if corner.x == ele.x:
+                 corner.safe_count = SAFE_NUM
+                 return corner
+         ele.safe_count = SAFE_NUM
+         self.__corners.append(ele)
+         return ele
+
+     def add_corner_v2(self, ele):
+         # if the new corner is near an existing corner, return the existing corner;
+         # if the new corner lies on an edge, split that edge
+         for corner in self.__corners:
+             if l2_distance(corner.x, ele.x) < 5:
+                 corner.safe_count = SAFE_NUM
+                 return corner
+         min_d = 256
+         the_edge = None
+         for edge in self.__edges:
+             temp = get_distance_of_corner_and_edge(edge.x[0].x, edge.x[1].x, ele.x)
+             if temp < min_d:
+                 min_d = temp
+                 the_edge = edge
+         if min_d < 3:
+             # split the edge
+             corner1 = the_edge.x[0]
+             corner2 = the_edge.x[1]
+             new_ele = Element((corner1, ele), safe_count=the_edge.safe_count)
+             self.__edges.append(new_ele)
+             new_ele = Element((corner2, ele), safe_count=the_edge.safe_count)
+             self.__edges.append(new_ele)
+             self.__edges.remove(the_edge)
+         ele.safe_count = SAFE_NUM
+         self.__corners.append(ele)
+         return ele
+
+     def checkColinearCorner(self, ele):
+         if self.getCornerDegree(ele) != 2:
+             return False
+         edge_in = []
+         for edge_ele in self.__edges:
+             if ele in edge_ele.x:
+                 edge_in.append(edge_ele)
+             if len(edge_in) == 2:
+                 break
+         two_neighbor = {edge_in[0].x[0], edge_in[0].x[1], edge_in[1].x[0], edge_in[1].x[1]}
+         two_neighbor.remove(ele)
+         two_neighbor = tuple(two_neighbor)
+         if self.has_edge(two_neighbor[0], two_neighbor[1]) is not None:
+             return False
+
+         line1 = np.array(ele.x) - np.array(two_neighbor[0].x)
+         line2 = np.array(two_neighbor[1].x) - np.array(ele.x)
+         cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
+         cos = min(1, max(-1, cos))
+         if np.arccos(cos) < np.pi / 9:  # 20 degrees
+             return True
+         return False
+
+     def checkIntersectionEdge(self, ele=None):
+         if ele is None:
+             for edge_i in range(len(self.__edges)):
+                 for edge_j in range(edge_i + 1, len(self.__edges)):
+                     if check_intersection(self.__edges[edge_i], self.__edges[edge_j]):
+                         return True
+             return False
+         for edge_ele in self.__edges:
+             if ele == edge_ele:
+                 continue
+             if check_intersection(edge_ele, ele):
+                 return True
+         return False
+
+     def getCornerDegree(self, ele):
+         degree = 0
+         for edge_ele in self.__edges:
+             if ele in edge_ele.x:
+                 degree += 1
+         return degree
+
+     def getEdgeConnected(self, ele):
+         out_ = set()
+         if type(ele.x[0]) == int:
+             # corner
+             for edge_ele in self.__edges:
+                 if ele in edge_ele.x:
+                     out_.add(edge_ele)
+             return out_
+         if type(ele.x[0]) == Element:
+             # edge
+             out_ = out_.union(self.getEdgeConnected(ele.x[0]))
+             out_ = out_.union(self.getEdgeConnected(ele.x[1]))
+             if ele in out_:
+                 out_.remove(ele)
+             return out_
+
+     def getNeighborCorner(self, ele):
+         out_ = set()
+         for edge_ele in self.__edges:
+             if ele == edge_ele.x[0]:
+                 out_.add(edge_ele.x[1])
+             if ele == edge_ele.x[1]:
+                 out_.add(edge_ele.x[0])
+         return out_
+
+     def getRealElement(self, ele):
+         # edge
+         if type(ele.x[0]) == Element:
+             for e in self.__edges:
+                 if (e.x[0].x == ele.x[0].x and e.x[1].x == ele.x[1].x) or \
+                         (e.x[1].x == ele.x[0].x and e.x[0].x == ele.x[1].x):
+                     return e
+             raise BaseException("no same edge exists.")
+         # corner
+         elif type(ele.x[0]) == int:
+             for c in self.__corners:
+                 if c.x == ele.x:
+                     return c
+             raise BaseException("no same corner exists.")
+
+     def copy(self):
+         corners = self.getCornersArray()
+         edges = self.getEdgesArray()
+         new_graph = Graph(corners, edges)
+         for idx, ele in enumerate(self.__corners):
+             new_graph.__corners[idx].store_score(self.__corners[idx].get_score())
+         for idx, ele in enumerate(self.__edges):
+             new_graph.__edges[idx].store_score(self.__edges[idx].get_score())
+         for idx, ele in enumerate(self.__regions):
+             new_graph.__regions[idx].store_score(self.__regions[idx].get_score())
+         return new_graph
+
+     def update_safe_count(self):
+         for ele in self.__corners:
+             if ele.safe_count > 0:
+                 ele.safe_count -= 1
+         for ele in self.__edges:
+             if ele.safe_count > 0:
+                 ele.safe_count -= 1
+
+     def isNeighbor(self, element1, element2):
+         '''
+         :param element1:
+         :param element2:
+         :return: True / False
+         '''
+         if element1 == element2:
+             return False
+         if type(element1.x[0]) != type(element2.x[0]):
+             # one corner and one edge
+             return False
+         if type(element1.x[0]) == int:
+             # both are corners
+             for edge_ele in self.__edges:
+                 if edge_ele.x[0] == element1 and edge_ele.x[1] == element2:
+                     return True
+                 if edge_ele.x[0] == element2 and edge_ele.x[1] == element1:
+                     return True
+             return False
+         if type(element1.x[0]) == Element:
+             # both are edges
+             if len({element1.x[0], element1.x[1], element2.x[0], element2.x[1]}) < 4:
+                 return True
+             return False
+
+     def equal(self, graph):
+         if len(self.__corners) != len(graph.__corners) or \
+                 len(self.__edges) != len(graph.__edges):
+             return False
+         for corner_i in range(len(self.__corners)):
+             if self.__corners[corner_i].equal(graph.__corners[corner_i]) is False:
+                 return False
+         for edge_i in range(len(self.__edges)):
+             if self.__edges[edge_i].equal(graph.__edges[edge_i]) is False:
+                 return False
+
+         return True
+
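For reference, a minimal sketch of how this Graph container is driven, before moving on to the Candidate wrapper below. The corner coordinates, edge indices, and per-element scores are made-up toy values; `sort_graph` and `score_weights` are module-level definitions from earlier in this file (metrics/new_utils.py, per the change list above), and the import path is an assumption based on that file layout:

    import numpy as np
    from metrics.new_utils import Graph  # assumed module path

    # a made-up square: four corners and four boundary edges
    corners = np.array([[10, 10], [10, 60], [60, 10], [60, 60]])
    edges = np.array([[0, 1], [0, 2], [1, 3], [2, 3]])
    g = Graph(corners, edges)
    g.store_score(corner_score=np.full(4, 0.9),   # hypothetical confidences
                  edge_score=np.full(4, 0.8),
                  region_score=np.full(1, 0.7))
    print(g.graph_score())     # weighted sum controlled by the module-level score_weights
    print(g.getEdgesArray())   # index pairs after sort_graph re-orders the corners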
+ class Candidate:
+     def __init__(self, graph, name, corner_existed_before, edge_existed_before):
+         '''
+         :param graph: a Graph instance
+         :param name: string, data name
+         :param corner_existed_before: dict {(x_i, y_i): c_i, ...} with a count for each removed corner; after one
+                                       search step each count is decremented, and entries with count == 0 are dropped
+         :param edge_existed_before: dict {((x_i1, y_i1), (x_i2, y_i2)): c_i}
+         '''
+         self.graph = graph
+         self.name = name
+         self.corner_existed_before = corner_existed_before
+         self.edge_existed_before = edge_existed_before
+
+     @classmethod
+     def initial(cls, graph, name):
+         return cls(graph, name, {}, {})
+
+     def update(self):
+         # decrement the counts of all existed-before elements
+         for key in self.corner_existed_before.keys():
+             self.corner_existed_before[key] -= 1
+         for key in self.edge_existed_before.keys():
+             self.edge_existed_before[key] -= 1
+
+         # drop the entries whose counts reached zero
+         for key in list(self.corner_existed_before.keys()):
+             if self.corner_existed_before[key] == 0:
+                 self.corner_existed_before.pop(key)
+
+         for key in list(self.edge_existed_before.keys()):
+             if self.edge_existed_before[key] == 0:
+                 self.edge_existed_before.pop(key)
+
+         # update graph
+         self.graph.update_safe_count()
+
+     def copy(self):
+         corner_existed_before = self.corner_existed_before.copy()
+         edge_existed_before = self.edge_existed_before.copy()
+         new_graph = self.graph.copy()
+         return Candidate(new_graph, self.name, corner_existed_before, edge_existed_before)
+
+     def removable(self, ele):
+         '''
+         :param ele: an Element
+         :return: True if the element's safe count has expired
+         '''
+         assert type(ele) == Element
+         return ele.safe_count == 0
+
+     def addable(self, ele):
+         if type(ele) == Element:
+             if type(ele.x[0]) == Element:
+                 # edge
+                 for edge in self.graph.getEdges():
+                     c1 = edge.x[0]
+                     c2 = edge.x[1]
+                     if (ele.x[0].x == c1.x and ele.x[1].x == c2.x) or \
+                             (ele.x[1].x == c1.x and ele.x[0].x == c2.x):
+                         # already exists
+                         return False
+                 corner1_loc = ele.x[0].x
+                 corner2_loc = ele.x[1].x
+                 if (corner1_loc, corner2_loc) in self.edge_existed_before.keys() or \
+                         (corner2_loc, corner1_loc) in self.edge_existed_before.keys():
+                     return False
+                 return True
+             else:
+                 # corner
+                 for corner in self.graph.getCorners():
+                     if l2_distance(ele.x, corner.x) < TWO_CORNER_MINIMUM_DISTANCE:
+                         # already exists
+                         return False
+                 if ele.x in self.corner_existed_before.keys():
+                     return False
+                 return True
+         else:  # (x, y) or ((x1, y1), (x2, y2))
+             if type(ele[0]) == tuple:
+                 # edge
+                 corner1_loc = ele[0]
+                 corner2_loc = ele[1]
+                 for edge in self.graph.getEdges():
+                     c1 = edge.x[0]
+                     c2 = edge.x[1]
+                     if (corner1_loc == c1.x and corner2_loc == c2.x) or \
+                             (corner2_loc == c1.x and corner1_loc == c2.x):
+                         # already exists
+                         return False
+                 if (corner1_loc, corner2_loc) in self.edge_existed_before.keys() or \
+                         (corner2_loc, corner1_loc) in self.edge_existed_before.keys():
+                     return False
+                 return True
+             else:
+                 # corner
+                 for corner in self.graph.getCorners():
+                     if l2_distance(ele, corner.x) < TWO_CORNER_MINIMUM_DISTANCE:
+                         # already exists
+                         return False
+                 if ele in self.corner_existed_before.keys():
+                     return False
+                 return True
+
+     def addCorner(self, ele):
+         if ele.x in self.corner_existed_before.keys():
+             raise BaseException('cannot add the corner')
+         new_ele = self.graph.add_corner(ele)  # possibly changed
+         return new_ele
+
+     def addCorner_v2(self, ele):
+         if ele.x in self.corner_existed_before.keys():
+             raise BaseException('cannot add the corner')
+         new_ele = self.graph.add_corner_v2(ele)
+         return new_ele
+
+     def addEdge(self, ele1, ele2):
+         corner1 = ele1
+         corner2 = ele2
+         assert corner1 in self.graph.getCorners()
+         assert corner2 in self.graph.getCorners()
+         if (corner1.x, corner2.x) in self.edge_existed_before.keys() or \
+                 (corner2.x, corner1.x) in self.edge_existed_before.keys():
+             raise BaseException('cannot add the edge')
+         new_ele = self.graph.add_edge(corner1, corner2)
+         return new_ele
+
+     def removeCorner(self, ele):
+         if ele.x in self.corner_existed_before.keys():
+             raise BaseException('already existed.')
+         self.corner_existed_before[ele.x] = SAFE_NUM
+
+     def removeEdge(self, ele):
+         corner1 = ele.x[0]
+         corner2 = ele.x[1]
+         loc1 = corner1.x
+         loc2 = corner2.x
+         if (loc1[0] > loc2[0]) or (loc1[0] == loc2[0] and loc1[1] > loc2[1]):
+             loc1 = corner2.x
+             loc2 = corner1.x
+         if (loc1, loc2) in self.edge_existed_before.keys():
+             raise BaseException('already existed.')
+         self.edge_existed_before[(loc1, loc2)] = SAFE_NUM
+
+     def generate_new_candidate_remove_a_colinear_corner(self, ele):
+         # the caller needs to check that ele is a colinear corner beforehand
+         new_candidate = self.copy()
+         new_graph = new_candidate.graph
+         ele = new_graph.getRealElement(ele)
+
+         # find the two neighboring corners
+         temp = set()
+         for element in new_graph.getEdgeConnected(ele):
+             # edge
+             if type(element.x[0]) == Element:
+                 temp.add(element.x[0])
+                 temp.add(element.x[1])
+         temp.remove(ele)
+         temp = tuple(temp)
+         assert len(temp) == 2
+
+         # add an edge between the two neighboring corners
+         # (add before remove, in case the neighbor corners would otherwise be removed for zero degree);
+         # special case: no need to check existed_before, instead drop the pair from the existed-before dict
+         added = new_graph.add_edge(temp[0], temp[1])
+         if (temp[0].x, temp[1].x) in new_candidate.edge_existed_before.keys():
+             new_candidate.edge_existed_before.pop((temp[0].x, temp[1].x))
+         if (temp[1].x, temp[0].x) in new_candidate.edge_existed_before.keys():
+             new_candidate.edge_existed_before.pop((temp[1].x, temp[0].x))
+
+         # remove
+         removed = new_graph.remove(ele)
+
+         # add the removed elements into the existed-before sets
+         for element in removed:
+             # edge
+             if type(element.x[0]) == Element:
+                 new_candidate.removeEdge(element)
+             # corner
+             elif type(element.x[0]) == int:
+                 new_candidate.removeCorner(element)
+             else:
+                 raise BaseException('wrong type.')
+
+         # reset scores that need to be recomputed
+         # all corner scores are recomputed
+         for element in new_graph.getCorners():
+             element.store_score(None)
+
+         # edges that neighbor the removed edges OR the new edge are recomputed
+         for element in new_graph.getEdges():
+             for modified_ele in removed.union({added}):
+                 if new_graph.isNeighbor(element, modified_ele):
+                     element.store_score(None)
+                     break
+
+         # all region scores are recomputed
+         for element in new_graph.getRegions():
+             element.store_score(None)
+
+         return new_candidate
+
+     def generate_new_candidate_remove_a_corner(self, ele):
+         # the caller needs to check that ele is removable before calling this method
+         new_candidate = self.copy()
+         new_graph = new_candidate.graph
+         ele = new_graph.getRealElement(ele)
+         removed = new_graph.remove(ele)
+
+         # add the removed elements into the existed-before sets
+         for element in removed:
+             # edge
+             if type(element.x[0]) == Element:
+                 corner1 = element.x[0]
+                 corner2 = element.x[1]
+                 loc1 = corner1.x
+                 loc2 = corner2.x
+                 if (loc1[0] > loc2[0]) or (loc1[0] == loc2[0] and loc1[1] > loc2[1]):
+                     loc1 = corner2.x
+                     loc2 = corner1.x
+                 if (loc1, loc2) in self.edge_existed_before.keys():
+                     raise BaseException('already existed.')
+                 new_candidate.edge_existed_before[(loc1, loc2)] = SAFE_NUM
+             # corner
+             elif type(element.x[0]) == int:
+                 if element.x in self.corner_existed_before.keys():
+                     raise BaseException('already existed.')
+                 new_candidate.corner_existed_before[element.x] = SAFE_NUM
+             else:
+                 raise BaseException('wrong type.')
+
+         # reset scores that need to be recomputed
+         # all corner scores are recomputed
+         for element in new_graph.getCorners():
+             element.store_score(None)
+
+         # edges that neighbor the removed edges are recomputed
+         for element in new_graph.getEdges():
+             for removed_ele in removed:
+                 if new_graph.isNeighbor(element, removed_ele):
+                     element.store_score(None)
+                     break
+
+         # all region scores are recomputed
+         for element in new_graph.getRegions():
+             element.store_score(None)
+
+         return new_candidate
+
+     def generate_new_candidate_add_an_edge(self, ele1, ele2):
+         # the caller needs to check addable before calling this method
+         new_candidate = self.copy()
+         new_graph = new_candidate.graph
+         ele1 = new_graph.getRealElement(ele1)
+         ele2 = new_graph.getRealElement(ele2)
+
+         # add the edge
+         new_ele = new_candidate.addEdge(ele1, ele2)
+
+         # reset scores that need to be recomputed
+         # all corner scores are recomputed
+         for element in new_graph.getCorners():
+             element.store_score(None)
+
+         # edges that neighbor the added edge are recomputed
+         for element in new_graph.getEdges():
+             if new_graph.isNeighbor(element, new_ele):
+                 element.store_score(None)
+
+         # all region scores are recomputed
+         for element in new_graph.getRegions():
+             element.store_score(None)
+
+         return new_candidate
+
+     def generate_new_candidate_remove_an_edge(self, ele):
+         # the caller needs to check that ele is removable before calling this method
+         new_candidate = self.copy()
+         new_graph = new_candidate.graph
+         ele = new_graph.getRealElement(ele)
+         removed = new_graph.remove(ele)
+
+         # add the removed elements into the existed-before sets
+         for element in removed:
+             # edge
+             if type(element.x[0]) == Element:
+                 corner1 = element.x[0]
+                 corner2 = element.x[1]
+                 loc1 = corner1.x
+                 loc2 = corner2.x
+                 if (loc1[0] > loc2[0]) or (loc1[0] == loc2[0] and loc1[1] > loc2[1]):
+                     loc1 = corner2.x
+                     loc2 = corner1.x
+                 if (loc1, loc2) in self.edge_existed_before.keys():
+                     raise BaseException('already existed.')
+                 new_candidate.edge_existed_before[(loc1, loc2)] = SAFE_NUM
+             # corner
+             elif type(element.x[0]) == int:
+                 if element.x in self.corner_existed_before.keys():
+                     raise BaseException('already existed.')
+                 new_candidate.corner_existed_before[element.x] = SAFE_NUM
+             else:
+                 raise BaseException('wrong type.')
+
+         # reset scores that need to be recomputed
+         # all corner scores are recomputed
+         for element in new_graph.getCorners():
+             element.store_score(None)
+
+         # edges that neighbor the removed edges are recomputed
+         for element in new_graph.getEdges():
+             for removed_ele in removed:
+                 if new_graph.isNeighbor(element, removed_ele):
+                     element.store_score(None)
+                     break
+
+         # all region scores are recomputed
+         for element in new_graph.getRegions():
+             element.store_score(None)
+
+         return new_candidate
+
+     def generate_new_candidate_add_a_new_triangle(self, ele_new, ele1, ele2):
+         # this method adds a new corner together with two new edges into the graph;
+         # the caller needs to check that ele_new is addable before calling this method
+         new_candidate = self.copy()
+         new_graph = new_candidate.graph
+         ele1 = new_graph.getRealElement(ele1)
+         ele2 = new_graph.getRealElement(ele2)
+
+         # add the corner
+         ele_new = new_candidate.addCorner(ele_new)  # ele_new possibly changed
+
+         # no score needs to be recomputed at this point
+
+         # add the two new edges (ele1, ele_new) and (ele2, ele_new)
+         new_candidate = new_candidate.generate_new_candidate_add_an_edge(ele_new, ele1)
+         new_candidate = new_candidate.generate_new_candidate_add_an_edge(ele_new, ele2)
+
+         return new_candidate
+
+     def generate_new_candidate_add_a_corner(self, ele):
+         # the caller needs to check that ele is addable before calling this method
+         new_candidate = self.copy()
+         new_graph = new_candidate.graph
+
+         # add the corner
+         ele = new_candidate.addCorner(ele)
+
+         # reset scores that need to be recomputed
+         # all corner scores are recomputed
+         for element in new_graph.getCorners():
+             element.store_score(None)
+
+         # no edge score needs to be recomputed
+         # all region scores are recomputed
+         for element in new_graph.getRegions():
+             element.store_score(None)
+
+         return new_candidate
+
+     def equal(self, candidate):
+         return self.graph.equal(candidate.graph)
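Continuing the Graph sketch above, one local-search move with the Candidate wrapper might look like the following (a sketch only, reusing `g` from the earlier snippet; `SAFE_NUM` is the module-level constant, and the re-scoring step is left to the evaluation networks as in the rest of this file):

    from metrics.new_utils import Candidate  # assumed module path

    cand = Candidate.initial(g, name='scene_0000')
    edge = cand.graph.getEdges()[0]
    if cand.removable(edge):                  # safe_count == 0 for freshly built edges
        new_cand = cand.generate_new_candidate_remove_an_edge(edge)
        # elements whose scores were reset to None must be re-scored
        # before comparing graph_score() of the two candidates
    cand.update()                             # decrement safe / existed-before counts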
models/__init__.py ADDED
File without changes
models/corner_models.py ADDED
@@ -0,0 +1,275 @@
+ import torch
+ from torch import nn, Tensor
+ import torch.nn.functional as F
+ import numpy as np
+ import math
+ from typing import Dict
+ from models.deformable_transformer import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \
+     DeformableTransformerDecoder, DeformableAttnDecoderLayer
+ from models.ops.modules import MSDeformAttn
+ from models.resnet import convrelu
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+ from einops.layers.torch import Rearrange
+ from utils.misc import NestedTensor
+
+
+ class HeatCorner(nn.Module):
+     """
+     The corner model of HEAT matches the edge model up to the edge-filtering part, i.e. it makes only
+     per-candidate predictions without relational modeling.
+     """
+     def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels):
+         super(HeatCorner, self).__init__()
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.num_feature_levels = num_feature_levels
+
+         if num_feature_levels > 1:
+             num_backbone_outs = len(backbone_strides)
+             input_proj_list = []
+             for lvl in range(num_backbone_outs):
+                 in_channels = backbone_num_channels[lvl]
+                 input_proj_list.append(nn.Sequential(
+                     nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+                     nn.GroupNorm(32, hidden_dim),
+                 ))
+             for _ in range(num_feature_levels - num_backbone_outs):
+                 input_proj_list.append(nn.Sequential(
+                     nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
+                     nn.GroupNorm(32, hidden_dim),
+                 ))
+                 in_channels = hidden_dim
+             self.input_proj = nn.ModuleList(input_proj_list)
+         else:
+             self.input_proj = nn.ModuleList([
+                 nn.Sequential(
+                     nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1),
+                     nn.GroupNorm(32, hidden_dim),
+                 )])
+
+         self.patch_size = 4
+         patch_dim = (self.patch_size ** 2) * input_dim
+         self.to_patch_embedding = nn.Sequential(
+             Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size),
+             nn.Linear(patch_dim, input_dim),
+             nn.Linear(input_dim, hidden_dim),
+         )
+
+         self.pixel_pe_fc = nn.Linear(input_dim, hidden_dim)
+         self.transformer = CornerTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1,
+                                              dim_feedforward=1024, dropout=0.1)
+
+         self.img_pos = PositionEmbeddingSine(hidden_dim // 2)
+
+     @staticmethod
+     def get_ms_feat(xs, img_mask):
+         out: Dict[str, NestedTensor] = {}
+         for name, x in sorted(xs.items()):
+             m = img_mask
+             assert m is not None
+             mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+             out[name] = NestedTensor(x, mask)
+         return out
+
+     @staticmethod
+     def get_decoder_reference_points(height, width, device):
+         ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
+                                       torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device))
+         ref_y = ref_y.reshape(-1)[None] / height
+         ref_x = ref_x.reshape(-1)[None] / width
+         ref = torch.stack((ref_x, ref_y), -1)
+         return ref
+
+     def forward(self, image_feats, feat_mask, pixels_feat, pixels, all_image_feats):
+         # process image features
+         features = self.get_ms_feat(image_feats, feat_mask)
+
+         srcs = []
+         masks = []
+         all_pos = []
+
+         new_features = list()
+         for name, x in sorted(features.items()):
+             new_features.append(x)
+         features = new_features
+
+         for l, feat in enumerate(features):
+             src, mask = feat.decompose()
+             mask = mask.to(src.device)
+             srcs.append(self.input_proj[l](src))
+             pos = self.img_pos(src).to(src.dtype)
+             all_pos.append(pos)
+             masks.append(mask)
+             assert mask is not None
+
+         if self.num_feature_levels > len(srcs):
+             _len_srcs = len(srcs)
+             for l in range(_len_srcs, self.num_feature_levels):
+                 if l == _len_srcs:
+                     src = self.input_proj[l](features[-1].tensors)
+                 else:
+                     src = self.input_proj[l](srcs[-1])
+                 m = feat_mask
+                 mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device)
+                 pos_l = self.img_pos(src).to(src.dtype)
+                 srcs.append(src)
+                 masks.append(mask)
+                 all_pos.append(pos_l)
+
+         sp_inputs = self.to_patch_embedding(pixels_feat)
+
+         # compute the reference points
+         H_tgt = W_tgt = int(np.sqrt(sp_inputs.shape[1]))
+         reference_points_s1 = self.get_decoder_reference_points(H_tgt, W_tgt, sp_inputs.device)
+
+         corner_logits = self.transformer(srcs, masks, all_pos, sp_inputs, reference_points_s1, all_image_feats)
+         return corner_logits
+
+
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention Is All You Need paper, generalized to work on images.
+     """
+
+     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+     def forward(self, x):
+         mask = torch.zeros([x.shape[0], x.shape[2], x.shape[3]]).bool().to(x.device)
+         not_mask = ~mask
+         y_embed = not_mask.cumsum(1, dtype=torch.float32)
+         x_embed = not_mask.cumsum(2, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         return pos
+
+
+ class CornerTransformer(nn.Module):
+     def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
+                  dim_feedforward=1024, dropout=0.1,
+                  activation="relu", return_intermediate_dec=False,
+                  num_feature_levels=4, dec_n_points=4, enc_n_points=4,
+                  ):
+         super(CornerTransformer, self).__init__()
+
+         encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
+                                                           dropout, activation,
+                                                           num_feature_levels, nhead, enc_n_points)
+         self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
+
+         decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward,
+                                                         dropout, activation,
+                                                         num_feature_levels, nhead, dec_n_points)
+         self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False)
+
+         self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+
+         # upconv layers
+         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+         self.conv_up1 = convrelu(256 + 256, 256, 3, 1)
+         self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
+         self.conv_original_size2 = convrelu(64 + 128, d_model, 3, 1)
+         self.output_fc_1 = nn.Linear(d_model, 1)
+         self.output_fc_2 = nn.Linear(d_model, 1)
+
+         self._reset_parameters()
+
+     def _reset_parameters(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+         for m in self.modules():
+             if isinstance(m, MSDeformAttn):
+                 m._reset_parameters()
+         normal_(self.level_embed)
+
+     def get_valid_ratio(self, mask):
+         _, H, W = mask.shape
+         valid_H = torch.sum(~mask[:, :, 0], 1)
+         valid_W = torch.sum(~mask[:, 0, :], 1)
+         valid_ratio_h = valid_H.float() / H
+         valid_ratio_w = valid_W.float() / W
+         valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+         return valid_ratio
+
+     def forward(self, srcs, masks, pos_embeds, query_embed, reference_points, all_image_feats):
+         # prepare input for the encoder
+         src_flatten = []
+         mask_flatten = []
+         lvl_pos_embed_flatten = []
+         spatial_shapes = []
+         for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
+             bs, c, h, w = src.shape
+             spatial_shape = (h, w)
+             spatial_shapes.append(spatial_shape)
+             src = src.flatten(2).transpose(1, 2)
+             mask = mask.flatten(1)
+             pos_embed = pos_embed.flatten(2).transpose(1, 2)
+             lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+             lvl_pos_embed_flatten.append(lvl_pos_embed)
+             src_flatten.append(src)
+             mask_flatten.append(mask)
+         src_flatten = torch.cat(src_flatten, 1)
+         mask_flatten = torch.cat(mask_flatten, 1)
+         lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+         spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
+         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+         valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+
+         # encoder
+         memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten,
+                               mask_flatten)
+
+         # prepare input for the decoder
+         bs, _, c = memory.shape
+
+         tgt = query_embed
+
+         # per-pixel decoder (deformable cross-attention only, no self-attention)
+         hs_pixels_s1, _ = self.per_edge_decoder(tgt, reference_points, memory,
+                                                 spatial_shapes, level_start_index, valid_ratios, query_embed,
+                                                 mask_flatten)
+
+         feats_s1, preds_s1 = self.generate_corner_preds(hs_pixels_s1, all_image_feats)
+
+         return preds_s1
+
+     def generate_corner_preds(self, outputs, conv_outputs):
+         B, L, C = outputs.shape
+         side = int(np.sqrt(L))
+         outputs = outputs.view(B, side, side, C)
+         outputs = outputs.permute(0, 3, 1, 2)
+         outputs = torch.cat([outputs, conv_outputs['layer1']], dim=1)
+         x = self.conv_up1(outputs)
+
+         x = self.upsample(x)
+         x = torch.cat([x, conv_outputs['layer0']], dim=1)
+         x = self.conv_up0(x)
+
+         x = self.upsample(x)
+         x = torch.cat([x, conv_outputs['x_original']], dim=1)
+         x = self.conv_original_size2(x)
+
+         logits = x.permute(0, 2, 3, 1)
+         preds = self.output_fc_1(logits)
+         preds = preds.squeeze(-1).sigmoid()
+         return logits, preds
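As a quick shape check, the sine position embedding above can be exercised on a dummy feature map. This is only a sketch; it assumes the class as defined in this file, and uses the same hidden_dim // 2 convention as HeatCorner:

    import torch
    from models.corner_models import PositionEmbeddingSine

    pos_embed = PositionEmbeddingSine(num_pos_feats=128)   # hidden_dim // 2 for hidden_dim = 256
    feat = torch.zeros(2, 256, 64, 64)                     # dummy (B, C, H, W) backbone feature
    pos = pos_embed(feat)
    print(pos.shape)  # torch.Size([2, 256, 64, 64]): y/x sinusoids concatenated along channels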
models/corner_to_edge.py ADDED
@@ -0,0 +1,232 @@
+ import torch
+ import numpy as np
+ import scipy.ndimage.filters as filters
+ import cv2
+ import itertools
+
+ NEIGHBOUR_SIZE = 5
+ MATCH_THRESH = 5
+ LOCAL_MAX_THRESH = 0.01
+ viz_count = 0
+
+ # pre-compute all index combinations to generate edge candidates faster
+ all_combibations = dict()
+ for length in range(2, 351):
+     ids = np.arange(length)
+     combs = np.array(list(itertools.combinations(ids, 2)))
+     all_combibations[length] = combs
+
+
+ def prepare_edge_data(c_outputs, annots, images, max_corner_num):
+     bs = c_outputs.shape[0]
+     # prepare parameters for each sample of the batch
+     all_results = list()
+
+     for b_i in range(bs):
+         annot = annots[b_i]
+         output = c_outputs[b_i]
+         results = process_each_sample({'annot': annot, 'output': output, 'viz_img': images[b_i]}, max_corner_num)
+         all_results.append(results)
+
+     processed_corners = [item['corners'] for item in all_results]
+     edge_coords = [item['edges'] for item in all_results]
+     edge_labels = [item['labels'] for item in all_results]
+
+     edge_info = {
+         'edge_coords': edge_coords,
+         'edge_labels': edge_labels,
+         'processed_corners': processed_corners
+     }
+
+     edge_data = collate_edge_info(edge_info)
+     return edge_data
+
+
+ def process_annot(annot, do_round=True):
+     corners = np.array(list(annot.keys()))
+     ind = np.lexsort(corners.T)  # sort the g.t. corners to fix the order for the matching later
+     corners = corners[ind]  # sorted by y, then x
+     corner_mapping = {tuple(k): v for v, k in enumerate(corners)}
+
+     edges = list()
+     for c, connections in annot.items():
+         for other_c in connections:
+             edge_pair = (corner_mapping[c], corner_mapping[tuple(other_c)])
+             edges.append(edge_pair)
+     corner_degrees = [len(annot[tuple(c)]) for c in corners]
+     if do_round:
+         corners = corners.round()
+     return corners, edges, corner_degrees
+
+
+ def process_each_sample(data, max_corner_num):
+     annot = data['annot']
+     output = data['output']
+
+     preds = output.detach().cpu().numpy()
+
+     # non-maximum suppression: keep pixels that are the maximum of their neighborhood
+     data_max = filters.maximum_filter(preds, NEIGHBOUR_SIZE)
+     maxima = (preds == data_max)
+     data_min = filters.minimum_filter(preds, NEIGHBOUR_SIZE)
+     diff = ((data_max - data_min) > 0)
+     maxima[diff == 0] = 0
+     local_maximas = np.where((maxima > 0) & (preds > LOCAL_MAX_THRESH))
+     pred_corners = np.stack(local_maximas, axis=-1)[:, [1, 0]]  # to (x, y) format
+
+     # produce edge labels from the predicted corners here
+     processed_corners, edges, labels = get_edge_label_mix_gt(pred_corners, annot, max_corner_num)
+     # global viz_count
+     # viz_img = data['viz_img']
+     # output_path = './viz_training/{}_example_gt.png'.format(viz_count)
+     # _visualize_edge_training_data(processed_corners, edges, labels, viz_img, output_path)
+     # viz_count += 1
+
+     results = {
+         'corners': processed_corners,
+         'edges': edges,
+         'labels': labels,
+     }
+     return results
+
+
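The corner extraction above is a standard non-maximum suppression on the predicted heatmap: a pixel survives if it equals the maximum of its NEIGHBOUR_SIZE x NEIGHBOUR_SIZE neighborhood, the neighborhood is not constant, and the score exceeds LOCAL_MAX_THRESH. A self-contained sketch of the same pattern on a toy array (the peak positions and scores are made up):

    import numpy as np
    import scipy.ndimage.filters as filters

    preds = np.zeros((8, 8))
    preds[2, 3] = 0.9          # synthetic corner responses
    preds[5, 6] = 0.5
    data_max = filters.maximum_filter(preds, 5)
    maxima = (preds == data_max) & ((data_max - filters.minimum_filter(preds, 5)) > 0)
    ys, xs = np.where(maxima & (preds > 0.01))
    print(list(zip(xs, ys)))   # [(3, 2), (6, 5)] in (x, y) order, as in process_each_sample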
+ def get_edge_label_mix_gt(pred_corners, annot, max_corner_num):
+     ind = np.lexsort(pred_corners.T)  # sort the pred corners to fix the order for the matching
+     pred_corners = pred_corners[ind]  # sorted by y, then x
+     gt_corners, edge_pairs, corner_degrees = process_annot(annot)
+
+     output_to_gt = dict()
+     gt_to_output = dict()
+     diff = np.sqrt(((pred_corners[:, None] - gt_corners) ** 2).sum(-1))
+     diff = diff.T
+
+     if len(pred_corners) > 0:
+         for target_i, target in enumerate(gt_corners):
+             dist = diff[target_i]
+             if len(output_to_gt) > 0:
+                 dist[list(output_to_gt.keys())] = 1000  # ignore already matched pred corners
+             min_dist = dist.min()
+             min_idx = dist.argmin()
+             if min_dist < MATCH_THRESH and min_idx not in output_to_gt:  # a positive match
+                 output_to_gt[min_idx] = (target_i, min_dist)
+                 gt_to_output[target_i] = min_idx
+
+     all_corners = gt_corners.copy()
+
+     # replace matched g.t. corners with the corresponding pred corners
+     for gt_i in range(len(gt_corners)):
+         if gt_i in gt_to_output:
+             all_corners[gt_i] = pred_corners[gt_to_output[gt_i]]
+
+     nm_pred_ids = [i for i in range(len(pred_corners)) if i not in output_to_gt]
+     nm_pred_ids = np.random.permutation(nm_pred_ids)
+     if len(nm_pred_ids) > 0:
+         nm_pred_corners = pred_corners[nm_pred_ids]
+         if len(nm_pred_ids) + len(all_corners) <= max_corner_num:
+             all_corners = np.concatenate([all_corners, nm_pred_corners], axis=0)
+         else:
+             all_corners = np.concatenate([all_corners, nm_pred_corners[:(max_corner_num - len(gt_corners)), :]], axis=0)
+
+     processed_corners, edges, edge_ids, labels = _get_edges(all_corners, edge_pairs)
+
+     return processed_corners, edges, labels
+
+
+ def _get_edges(corners, edge_pairs):
+     ind = np.lexsort(corners.T)
+     corners = corners[ind]  # sorted by y, then x
+     corners = corners.round()
+     id_mapping = {old: new for new, old in enumerate(ind)}
+
+     all_ids = all_combibations[len(corners)]
+     edges = corners[all_ids]
+     labels = np.zeros(edges.shape[0])
+
+     N = len(corners)
+     edge_pairs = [(id_mapping[p[0]], id_mapping[p[1]]) for p in edge_pairs]
+     edge_pairs = [p for p in edge_pairs if p[0] < p[1]]
+     pos_ids = [int((2 * N - 1 - p[0]) * p[0] / 2 + p[1] - p[0] - 1) for p in edge_pairs]
+     labels[pos_ids] = 1
+
+     edge_ids = np.array(all_ids)
+     return corners, edges, edge_ids, labels
+
+
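The positive-label indexing in _get_edges relies on itertools.combinations enumerating pairs (i, j) with i < j in lexicographic order, so the pair (i, j) of N corners lands at row (2N - 1 - i) * i / 2 + (j - i - 1). A small check of that closed form (N = 4 is an arbitrary toy size):

    import itertools

    N = 4
    combs = list(itertools.combinations(range(N), 2))
    # combs == [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
    for i, j in combs:
        idx = int((2 * N - 1 - i) * i / 2 + j - i - 1)
        assert combs[idx] == (i, j)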
+ def collate_edge_info(data):
+     batched_data = {}
+     lengths_info = {}
+     for field in data.keys():
+         batch_values = data[field]
+         all_lens = [len(value) for value in batch_values]
+         max_len = max(all_lens)
+         pad_value = 0
+         batch_values = [pad_sequence(value, max_len, pad_value) for value in batch_values]
+         batch_values = np.stack(batch_values, axis=0)
+
+         if field in ['edge_coords', 'edge_labels', 'gt_values']:
+             batch_values = torch.Tensor(batch_values).long()
+         if field in ['processed_corners', 'edge_coords']:
+             lengths_info[field] = all_lens
+         batched_data[field] = batch_values
+
+     # add lengths and masks into the data; the mask follows the Transformer's input format, True means padding
+     for field, lengths in lengths_info.items():
+         lengths_str = field + '_lengths'
+         batched_data[lengths_str] = torch.Tensor(lengths).long()
+         mask = torch.arange(max(lengths))
+         mask = mask.unsqueeze(0).repeat(batched_data[field].shape[0], 1)
+         mask = mask >= batched_data[lengths_str].unsqueeze(-1)
+         mask_str = field + '_mask'
+         batched_data[mask_str] = mask
+
+     return batched_data
+
+
+ def pad_sequence(seq, length, pad_value=0):
+     if len(seq) == length:
+         return seq
+     else:
+         pad_len = length - len(seq)
+         if len(seq.shape) == 1:
+             if pad_value == 0:
+                 paddings = np.zeros([pad_len, ])
+             else:
+                 paddings = np.ones([pad_len, ]) * pad_value
+         else:
+             if pad_value == 0:
+                 paddings = np.zeros([pad_len, ] + list(seq.shape[1:]))
+             else:
+                 paddings = np.ones([pad_len, ] + list(seq.shape[1:])) * pad_value
+         padded_seq = np.concatenate([seq, paddings], axis=0)
+         return padded_seq
+
+
+ def get_infer_edge_pairs(corners, confs):
+     ind = np.lexsort(corners.T)
+     corners = corners[ind]  # sorted by y, then x
+     confs = confs[ind]
+
+     edge_ids = all_combibations[len(corners)]
+     edge_coords = corners[edge_ids]
+
+     edge_coords = torch.tensor(np.array(edge_coords)).unsqueeze(0).long()
+     mask = torch.zeros([edge_coords.shape[0], edge_coords.shape[1]]).bool()
+     edge_ids = torch.tensor(np.array(edge_ids))
+     return corners, confs, edge_coords, mask, edge_ids
+
+
+ def _visualize_edge_training_data(corners, edges, edge_labels, image, save_path):
+     image = image.transpose([1, 2, 0])
+     image = (image * 255).astype(np.uint8)
+     image = np.ascontiguousarray(image)
+
+     for edge, label in zip(edges, edge_labels):
+         if label == 1:
+             cv2.line(image, tuple(edge[0].astype(int)), tuple(edge[1].astype(int)), (255, 255, 0), 2)
+
+     for c in corners:
+         cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
+
+     cv2.imwrite(save_path, image)
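For reference, pad_sequence above right-pads along the first axis only, which is what collate_edge_info relies on when batching variable-length per-sample arrays. A minimal usage sketch, assuming the module imports as models.corner_to_edge:

    import numpy as np
    from models.corner_to_edge import pad_sequence

    seq = np.array([[1, 2], [3, 4], [5, 6]])   # e.g. three (x, y) corners of one sample
    padded = pad_sequence(seq, length=5)        # pad_value defaults to 0
    print(padded.shape)                         # (5, 2); the last two rows are zeros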
models/deformable_transformer.py ADDED
@@ -0,0 +1,236 @@
+ import copy
+ import torch
+ from torch import nn, Tensor
+ from models.ops.modules import MSDeformAttn
+ import torch.nn.functional as F
+
+
+ class DeformableTransformerEncoderLayer(nn.Module):
+     def __init__(self,
+                  d_model=256, d_ffn=1024,
+                  dropout=0.1, activation="relu",
+                  n_levels=4, n_heads=8, n_points=4):
+         super().__init__()
+
+         # self attention
+         self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+         self.dropout1 = nn.Dropout(dropout)
+         self.norm1 = nn.LayerNorm(d_model)
+
+         # ffn
+         self.linear1 = nn.Linear(d_model, d_ffn)
+         self.activation = _get_activation_fn(activation)
+         self.dropout2 = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(d_ffn, d_model)
+         self.dropout3 = nn.Dropout(dropout)
+         self.norm2 = nn.LayerNorm(d_model)
+
+     @staticmethod
+     def with_pos_embed(tensor, pos):
+         return tensor if pos is None else tensor + pos
+
+     def forward_ffn(self, src):
+         src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+         src = src + self.dropout3(src2)
+         src = self.norm2(src)
+         return src
+
+     def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
+         # self attention
+         src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index,
+                               padding_mask)
+         src = src + self.dropout1(src2)
+         src = self.norm1(src)
+
+         # ffn
+         src = self.forward_ffn(src)
+
+         return src
+
+
+ class DeformableTransformerEncoder(nn.Module):
+     def __init__(self, encoder_layer, num_layers):
+         super().__init__()
+         self.layers = _get_clones(encoder_layer, num_layers)
+         self.num_layers = num_layers
+
+     @staticmethod
+     def get_reference_points(spatial_shapes, valid_ratios, device):
+         reference_points_list = []
+         for lvl, (H_, W_) in enumerate(spatial_shapes):
+             ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+                                           torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
+             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
+             ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
+             ref = torch.stack((ref_x, ref_y), -1)
+             reference_points_list.append(ref)
+         reference_points = torch.cat(reference_points_list, 1)
+         reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+         return reference_points
+
+     def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
+         output = src
+         reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
+         for _, layer in enumerate(self.layers):
+             output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
+
+         return output
+
+
+ class DeformableAttnDecoderLayer(nn.Module):
+     def __init__(self, d_model=256, d_ffn=1024,
+                  dropout=0.1, activation="relu",
+                  n_levels=4, n_heads=8, n_points=4):
+         super().__init__()
+         # cross attention
+         self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+         self.dropout1 = nn.Dropout(dropout)
+         self.norm1 = nn.LayerNorm(d_model)
+
+         # ffn
+         self.linear1 = nn.Linear(d_model, d_ffn)
+         self.activation = _get_activation_fn(activation)
+         self.dropout3 = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(d_ffn, d_model)
+         self.dropout4 = nn.Dropout(dropout)
+         self.norm3 = nn.LayerNorm(d_model)
+
+     @staticmethod
+     def with_pos_embed(tensor, pos):
+         return tensor if pos is None else tensor + pos
+
+     def forward_ffn(self, tgt):
+         tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout4(tgt2)
+         tgt = self.norm3(tgt)
+         return tgt
+
+     def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index,
+                 src_padding_mask=None,
+                 key_padding_mask=None):
+         # cross attention
+         tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
+                                reference_points,
+                                src, src_spatial_shapes, level_start_index, src_padding_mask)
+         tgt = tgt + self.dropout1(tgt2)
+         tgt = self.norm1(tgt)
+
+         # ffn
+         tgt = self.forward_ffn(tgt)
+
+         return tgt
+
+
+ class DeformableTransformerDecoderLayer(nn.Module):
+     def __init__(self, d_model=256, d_ffn=1024,
+                  dropout=0.1, activation="relu",
+                  n_levels=4, n_heads=8, n_points=4):
+         super().__init__()
+         # cross attention
+         self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+         self.dropout1 = nn.Dropout(dropout)
+         self.norm1 = nn.LayerNorm(d_model)
+
+         # self attention
+         self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.norm2 = nn.LayerNorm(d_model)
+
+         # ffn
+         self.linear1 = nn.Linear(d_model, d_ffn)
+         self.activation = _get_activation_fn(activation)
+         self.dropout3 = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(d_ffn, d_model)
+         self.dropout4 = nn.Dropout(dropout)
+         self.norm3 = nn.LayerNorm(d_model)
+
+     @staticmethod
+     def with_pos_embed(tensor, pos):
+         return tensor if pos is None else tensor + pos
+
+     def forward_ffn(self, tgt):
+         tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout4(tgt2)
+         tgt = self.norm3(tgt)
+         return tgt
+
+     def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index,
+                 src_padding_mask=None,
+                 key_padding_mask=None,
+                 get_image_feat=True):
+         # self attention
+         q = k = self.with_pos_embed(tgt, query_pos)
+         tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1),
+                               key_padding_mask=key_padding_mask)[0].transpose(0, 1)
+         tgt = tgt + self.dropout2(tgt2)
+         tgt = self.norm2(tgt)
+
+         if get_image_feat:
+             # cross attention
+             tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
+                                    reference_points,
+                                    src, src_spatial_shapes, level_start_index, src_padding_mask)
+             tgt = tgt + self.dropout1(tgt2)
+             tgt = self.norm1(tgt)
+
+         # ffn
+         tgt = self.forward_ffn(tgt)
+
+         return tgt
+
+
+ class DeformableTransformerDecoder(nn.Module):
+     def __init__(self, decoder_layer, num_layers, return_intermediate=False, with_sa=True):
+         super().__init__()
+         self.layers = _get_clones(decoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.return_intermediate = return_intermediate
+         # `with_sa` toggles the self-attention layers (hack implementation, following Deformable DETR)
+         self.with_sa = with_sa
+
+     def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
+                 query_pos=None, src_padding_mask=None, key_padding_mask=None, get_image_feat=True):
+         output = tgt
+
+         intermediate = []
+         intermediate_reference_points = []
+         for lid, layer in enumerate(self.layers):
+             if reference_points.shape[-1] == 4:
+                 reference_points_input = reference_points[:, :, None] \
+                                          * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
+             else:
+                 assert reference_points.shape[-1] == 2
+                 reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
+             if self.with_sa:
+                 output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes,
+                                src_level_start_index,
+                                src_padding_mask, key_padding_mask, get_image_feat)
+             else:
+                 output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes,
+                                src_level_start_index,
+                                src_padding_mask, key_padding_mask)
+
+             if self.return_intermediate:
+                 intermediate.append(output)
+                 intermediate_reference_points.append(reference_points)
+
+         if self.return_intermediate:
+             return torch.stack(intermediate), torch.stack(intermediate_reference_points)
+
+         return output, reference_points
+
+
+ def _get_clones(module, N):
+     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+ def _get_activation_fn(activation):
+     """Return an activation function given a string"""
+     if activation == "relu":
+         return F.relu
+     if activation == "gelu":
+         return F.gelu
+     if activation == "glu":
+         return F.glu
+     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
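For intuition, the transformers in corner_models.py and edge_models.py flatten all feature levels into one token sequence and index it with level_start_index, exactly as computed in their forward passes. A worked example with an illustrative four-level pyramid:

    import torch

    spatial_shapes = torch.as_tensor([(64, 64), (32, 32), (16, 16), (8, 8)], dtype=torch.long)
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                                   spatial_shapes.prod(1).cumsum(0)[:-1]))
    print(level_start_index)  # tensor([   0, 4096, 5120, 5376])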
models/edge_models.py ADDED
@@ -0,0 +1,314 @@
+ # coding=utf-8
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from typing import Dict
+ from models.mlp import MLP
+ from models.deformable_transformer import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \
+     DeformableTransformerDecoder, DeformableTransformerDecoderLayer, DeformableAttnDecoderLayer
+ from models.ops.modules import MSDeformAttn
+ from models.corner_models import PositionEmbeddingSine
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+ import torch.nn.functional as F
+ from utils.misc import NestedTensor
+
+
+ class HeatEdge(nn.Module):
+     def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels):
+         super(HeatEdge, self).__init__()
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.num_feature_levels = num_feature_levels
+
+         if num_feature_levels > 1:
+             num_backbone_outs = len(backbone_strides)
+             input_proj_list = []
+             for lvl in range(num_backbone_outs):
+                 in_channels = backbone_num_channels[lvl]
+                 input_proj_list.append(nn.Sequential(
+                     nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+                     nn.GroupNorm(32, hidden_dim),
+                 ))
+             for _ in range(num_feature_levels - num_backbone_outs):
+                 input_proj_list.append(nn.Sequential(
+                     nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
+                     nn.GroupNorm(32, hidden_dim),
+                 ))
+                 in_channels = hidden_dim
+             self.input_proj = nn.ModuleList(input_proj_list)
+         else:
+             self.input_proj = nn.ModuleList([
+                 nn.Sequential(
+                     nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1),
+                     nn.GroupNorm(32, hidden_dim),
+                 )])
+
+         self.img_pos = PositionEmbeddingSine(hidden_dim // 2)
+
+         self.edge_input_fc = nn.Linear(input_dim * 2, hidden_dim)
+         self.output_fc = MLP(input_dim=hidden_dim, hidden_dim=hidden_dim // 2, output_dim=2, num_layers=2)
+
+         self.transformer = EdgeTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1,
+                                            num_decoder_layers=6, dim_feedforward=1024, dropout=0.1)
+
+     @staticmethod
+     def get_ms_feat(xs, img_mask):
+         out: Dict[str, NestedTensor] = {}
+         for name, x in sorted(xs.items()):
+             m = img_mask
+             assert m is not None
+             mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+             out[name] = NestedTensor(x, mask)
+         return out
+
+     def forward(self, image_feats, feat_mask, corner_outputs, edge_coords, edge_masks, gt_values, corner_nums,
+                 max_candidates, do_inference=False):
+         # prepare ConvNet features
+         features = self.get_ms_feat(image_feats, feat_mask)
+
+         srcs = []
+         masks = []
+         all_pos = []
+
+         new_features = list()
+         for name, x in sorted(features.items()):
+             new_features.append(x)
+         features = new_features
+
+         for l, feat in enumerate(features):
+             src, mask = feat.decompose()
+             mask = mask.to(src.device)
+             srcs.append(self.input_proj[l](src))
+             pos = self.img_pos(src).to(src.dtype)
+             all_pos.append(pos)
+             masks.append(mask)
+             assert mask is not None
+
+         if self.num_feature_levels > len(srcs):
+             _len_srcs = len(srcs)
+             for l in range(_len_srcs, self.num_feature_levels):
+                 if l == _len_srcs:
+                     src = self.input_proj[l](features[-1].tensors)
+                 else:
+                     src = self.input_proj[l](srcs[-1])
+                 m = feat_mask
+                 mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device)
+                 pos_l = self.img_pos(src).to(src.dtype)
+                 srcs.append(src)
+                 masks.append(mask)
+                 all_pos.append(pos_l)
+
+         bs = edge_masks.size(0)
+         num_edges = edge_masks.size(1)
+
+         # gather the two endpoint features of every candidate edge
+         corner_feats = corner_outputs
+         edge_feats = list()
+         for b_i in range(bs):
+             feats = corner_feats[b_i, edge_coords[b_i, :, :, 1], edge_coords[b_i, :, :, 0], :]
+             edge_feats.append(feats)
+         edge_feats = torch.stack(edge_feats, dim=0)
+         edge_feats = edge_feats.view(bs, num_edges, -1)
+
+         edge_inputs = self.edge_input_fc(edge_feats.view(bs * num_edges, -1))
+         edge_inputs = edge_inputs.view(bs, num_edges, -1)
+
+         # normalized edge midpoints serve as the deformable-attention reference points
+         edge_center = (edge_coords[:, :, 0, :].float() + edge_coords[:, :, 1, :].float()) / 2
+         edge_center = edge_center / feat_mask.shape[1]
+
+         logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values = self.transformer(
+             srcs, masks, all_pos, edge_inputs, edge_center, gt_values, edge_masks, corner_nums,
+             max_candidates, do_inference)
+
+         return logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values
+
+
+ class EdgeTransformer(nn.Module):
132
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
133
+ num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
134
+ activation="relu", return_intermediate_dec=False,
135
+ num_feature_levels=4, dec_n_points=4, enc_n_points=4,
136
+ ):
137
+ super(EdgeTransformer, self).__init__()
138
+
139
+ encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
140
+ dropout, activation,
141
+ num_feature_levels, nhead, enc_n_points)
142
+ self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
143
+
144
+ decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward,
145
+ dropout, activation,
146
+ num_feature_levels, nhead, dec_n_points)
147
+ # one-layer decoder, without self-attention layers
148
+ self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False)
149
+
150
+ decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
151
+ dropout, activation,
152
+ num_feature_levels, nhead, dec_n_points)
153
+
154
+ # edge decoder w/ self-attention layers (image-aware decoder and geom-only decoder)
155
+ self.relational_decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers,
156
+ return_intermediate_dec, with_sa=True)
157
+
158
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
159
+
160
+ self.gt_label_embed = nn.Embedding(3, d_model)
161
+
162
+ self.input_fc_hb = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2)
163
+ self.input_fc_rel = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2)
164
+
165
+ self.output_fc_1 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
166
+ self.output_fc_2 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
167
+ self.output_fc_3 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
168
+ self._reset_parameters()
169
+
170
+ def _reset_parameters(self):
171
+ for p in self.parameters():
172
+ if p.dim() > 1:
173
+ nn.init.xavier_uniform_(p)
174
+ for m in self.modules():
175
+ if isinstance(m, MSDeformAttn):
176
+ m._reset_parameters()
177
+ normal_(self.level_embed)
178
+
179
+ def get_valid_ratio(self, mask):
180
+ _, H, W = mask.shape
181
+ valid_H = torch.sum(~mask[:, :, 0], 1)
182
+ valid_W = torch.sum(~mask[:, 0, :], 1)
183
+ valid_ratio_h = valid_H.float() / H
184
+ valid_ratio_w = valid_W.float() / W
185
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
186
+ return valid_ratio
187
+
188
+ def forward(self, srcs, masks, pos_embeds, query_embed, reference_points, labels, key_padding_mask, corner_nums,
189
+ max_candidates, do_inference=False):
190
+ # prepare input for encoder
191
+ src_flatten = []
192
+ mask_flatten = []
193
+ lvl_pos_embed_flatten = []
194
+ spatial_shapes = []
195
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
196
+ bs, c, h, w = src.shape
197
+ spatial_shape = (h, w)
198
+ spatial_shapes.append(spatial_shape)
199
+ src = src.flatten(2).transpose(1, 2)
200
+ mask = mask.flatten(1)
201
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
202
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
203
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
204
+ src_flatten.append(src)
205
+ mask_flatten.append(mask)
206
+ src_flatten = torch.cat(src_flatten, 1)
207
+ mask_flatten = torch.cat(mask_flatten, 1)
208
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
209
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
210
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
211
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
212
+
213
+ # encoder
214
+ memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten,
215
+ mask_flatten)
216
+
217
+ # prepare input for decoder
218
+ bs, _, c = memory.shape
219
+
220
+ tgt = query_embed
221
+
222
+ # per-edge filtering with single-layer decoder (no self-attn)
223
+ hs_per_edge, _ = self.per_edge_decoder(tgt, reference_points, memory,
224
+ spatial_shapes, level_start_index, valid_ratios, query_embed,
225
+ mask_flatten)
226
+ logits_per_edge = self.output_fc_1(hs_per_edge).permute(0, 2, 1)
227
+ filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids = self.candidate_filtering(
228
+ logits_per_edge,
229
+ hs_per_edge, query_embed, reference_points,
230
+ labels,
231
+ key_padding_mask, corner_nums, max_candidates)
232
+
233
+ # generate the info for masked training
234
+ if not do_inference:
235
+ filtered_gt_values = self.generate_gt_masking(filtered_labels, filtered_mask)
236
+ else:
237
+ filtered_gt_values = filtered_labels
238
+ gt_info = self.gt_label_embed(filtered_gt_values)
239
+
240
+ # relational decoder with image feature (image-aware decoder)
241
+ hybrid_prim_hs = self.input_fc_hb(torch.cat([filtered_hs, gt_info], dim=-1))
242
+
243
+ hs, inter_references = self.relational_decoder(hybrid_prim_hs, filtered_rp, memory,
244
+ spatial_shapes, level_start_index, valid_ratios, filtered_query,
245
+ mask_flatten,
246
+ key_padding_mask=filtered_mask, get_image_feat=True)
247
+
248
+ logits_final_hb = self.output_fc_2(hs).permute(0, 2, 1)
249
+
250
+ # relational decoder without image feature (geom-only decoder)
251
+ rel_prim_hs = self.input_fc_rel(torch.cat([filtered_query, gt_info], dim=-1))
252
+
253
+ hs_rel, _ = self.relational_decoder(rel_prim_hs, filtered_rp, memory,
254
+ spatial_shapes, level_start_index, valid_ratios, filtered_query,
255
+ mask_flatten,
256
+ key_padding_mask=filtered_mask, get_image_feat=False)
257
+
258
+ logits_final_rel = self.output_fc_3(hs_rel).permute(0, 2, 1)
259
+
260
+ return logits_per_edge, logits_final_hb, logits_final_rel, selected_ids, filtered_mask, filtered_gt_values
261
+
262
+ @staticmethod
263
+ def candidate_filtering(logits, hs, query, rp, labels, key_padding_mask, corner_nums, max_candidates):
264
+ """
265
+ Filter out the easy-negatives from the edge candidates, and update the edge information correspondingly
266
+ """
267
+ B, L, _ = hs.shape
268
+ preds = logits.detach().softmax(1)[:, 1, :] # BxL
269
+ preds[key_padding_mask == True] = -1 # ignore the masking parts
270
+ sorted_ids = torch.argsort(preds, dim=-1, descending=True)
271
+ filtered_hs = list()
272
+ filtered_mask = list()
273
+ filtered_query = list()
274
+ filtered_rp = list()
275
+ filtered_labels = list()
276
+ selected_ids = list()
277
+ for b_i in range(B):
278
+ num_candidates = corner_nums[b_i] * 3
279
+ ids = sorted_ids[b_i, :max_candidates[b_i]]
280
+ filtered_hs.append(hs[b_i][ids])
281
+ new_mask = key_padding_mask[b_i][ids]
282
+ new_mask[num_candidates:] = True
283
+ filtered_mask.append(new_mask)
284
+ filtered_query.append(query[b_i][ids])
285
+ filtered_rp.append(rp[b_i][ids])
286
+ filtered_labels.append(labels[b_i][ids])
287
+ selected_ids.append(ids)
288
+ filtered_hs = torch.stack(filtered_hs, dim=0)
289
+ filtered_mask = torch.stack(filtered_mask, dim=0)
290
+ filtered_query = torch.stack(filtered_query, dim=0)
291
+ filtered_rp = torch.stack(filtered_rp, dim=0)
292
+ filtered_labels = torch.stack(filtered_labels, dim=0)
293
+ selected_ids = torch.stack(selected_ids, dim=0)
294
+
295
+ return filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids
296
+
297
+ @staticmethod
298
+ def generate_gt_masking(labels, mask):
299
+ """
300
+ Generate the info for masked training on-the-fly with ratio=0.5
301
+ """
302
+ bs = labels.shape[0]
303
+ gt_values = torch.zeros_like(mask).long()
304
+ for b_i in range(bs):
305
+ edge_length = (mask[b_i] == 0).sum()
306
+ rand_ratio = np.random.rand() * 0.5 + 0.5
307
+ gt_rand = torch.rand(edge_length)
308
+ gt_flag = torch.zeros(edge_length)
309
+ gt_flag[torch.where(gt_rand >= rand_ratio)] = 1
310
+ gt_idx = torch.where(gt_flag == 1)
311
+ pred_idx = torch.where(gt_flag == 0)
312
+ gt_values[b_i, gt_idx[0]] = labels[b_i, gt_idx[0]]
313
+ gt_values[b_i, pred_idx[0]] = 2 # use 2 to represent unknown value, need to predict
314
+ return gt_values
models/loss.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from utils.geometry_utils import edge_acc
5
+
6
+
7
+ class CornerCriterion(nn.Module):
8
+ def __init__(self, image_size):
9
+ super().__init__()
10
+ self.loss_rate = 9
11
+
12
+ def forward(self, outputs_s1, targets, gauss_targets, epoch=0):
13
+ # Compute the acc first, use the acc to guide the setup of loss weight
14
+ preds_s1 = (outputs_s1 >= 0.5).float()
15
+ pos_target_ids = torch.where(targets == 1)
16
+ correct = (preds_s1[pos_target_ids] == targets[pos_target_ids]).float().sum()
17
+ recall_s1 = correct / len(pos_target_ids[0])
18
+
19
+ rate = self.loss_rate
20
+
21
+ loss_weight = (gauss_targets > 0.5).float() * rate + 1
22
+ loss_s1 = F.binary_cross_entropy(outputs_s1, gauss_targets, weight=loss_weight, reduction='none')
23
+ loss_s1 = loss_s1.sum(-1).sum(-1).mean()
24
+
25
+ return loss_s1, recall_s1
26
+
27
+
28
+ class EdgeCriterion(nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+ self.edge_loss = nn.CrossEntropyLoss(weight=torch.tensor([0.33, 1.0]).cuda(), reduction='none')
32
+
33
+ def forward(self, logits_s1, logits_s2_hybrid, logits_s2_rel, s2_ids, s2_edge_mask, edge_labels, edge_lengths,
34
+ edge_mask, s2_gt_values):
35
+ # loss for edge filtering
36
+ s1_losses = self.edge_loss(logits_s1, edge_labels)
37
+ s1_losses[torch.where(edge_mask == True)] = 0
38
+ s1_losses = s1_losses[torch.where(s1_losses > 0)].sum() / edge_mask.shape[0]
39
+ gt_values = torch.ones_like(edge_mask).long() * 2
40
+ s1_acc = edge_acc(logits_s1, edge_labels, edge_lengths, gt_values)
41
+
42
+ # loss for stage-2
43
+ s2_labels = torch.gather(edge_labels, 1, s2_ids)
44
+
45
+ # the image-aware decoder
46
+ s2_losses_hybrid = self.edge_loss(logits_s2_hybrid, s2_labels)
47
+ s2_losses_hybrid[torch.where((s2_edge_mask == True) | (s2_gt_values != 2))] = 0
48
+ # aggregate the loss into the final scalar
49
+ s2_losses_hybrid = s2_losses_hybrid[torch.where(s2_losses_hybrid > 0)].sum() / s2_edge_mask.shape[0]
50
+ s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
51
+ # compute edge-level acc
52
+ s2_acc_hybrid = edge_acc(logits_s2_hybrid, s2_labels, s2_edge_lengths, s2_gt_values)
53
+
54
+ # the geom-only decoder
55
+ s2_losses_rel = self.edge_loss(logits_s2_rel, s2_labels)
56
+ s2_losses_rel[torch.where((s2_edge_mask == True) | (s2_gt_values != 2))] = 0
57
+ # aggregate the loss into the final scalar
58
+ s2_losses_rel = s2_losses_rel[torch.where(s2_losses_rel > 0)].sum() / s2_edge_mask.shape[0]
59
+ s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
60
+ # compute edge-level f1-score
61
+ s2_acc_rel = edge_acc(logits_s2_rel, s2_labels, s2_edge_lengths, s2_gt_values)
62
+
63
+ return s1_losses, s1_acc, s2_losses_hybrid, s2_acc_hybrid, s2_losses_rel, s2_acc_rel
models/mlp.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
+ class MLP(nn.Module):
6
+ """ Very simple multi-layer perceptron (also called FFN)"""
7
+
8
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
9
+ super(MLP, self).__init__()
10
+ self.output_dim = output_dim
11
+ self.num_layers = num_layers
12
+ h = [hidden_dim] * (num_layers - 1)
13
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
14
+
15
+ def forward(self, x):
16
+ B, N, D = x.size()
17
+ x = x.reshape(B*N, D)
18
+ for i, layer in enumerate(self.layers):
19
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
20
+ x = x.view(B, N, self.output_dim)
21
+ return x
models/ops/functions/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from .ms_deform_attn_func import MSDeformAttnFunction
10
+
models/ops/functions/ms_deform_attn_func.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.autograd import Function
16
+ from torch.autograd.function import once_differentiable
17
+
18
+ import MultiScaleDeformableAttention as MSDA
19
+
20
+
21
+ class MSDeformAttnFunction(Function):
22
+ @staticmethod
23
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
24
+ ctx.im2col_step = im2col_step
25
+ output = MSDA.ms_deform_attn_forward(
26
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
27
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
28
+ return output
29
+
30
+ @staticmethod
31
+ @once_differentiable
32
+ def backward(ctx, grad_output):
33
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
34
+ grad_value, grad_sampling_loc, grad_attn_weight = \
35
+ MSDA.ms_deform_attn_backward(
36
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
37
+
38
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
39
+
40
+
41
+ def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
42
+ # for debug and test only,
43
+ # need to use cuda version instead
44
+ N_, S_, M_, D_ = value.shape
45
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
46
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
47
+ sampling_grids = 2 * sampling_locations - 1
48
+ sampling_value_list = []
49
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
50
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
51
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
52
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
53
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
54
+ # N_*M_, D_, Lq_, P_
55
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
56
+ mode='bilinear', padding_mode='zeros', align_corners=False)
57
+ sampling_value_list.append(sampling_value_l_)
58
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
59
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
60
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
61
+ return output.transpose(1, 2).contiguous()
models/ops/make.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # ------------------------------------------------------------------------------------------------
3
+ # Deformable DETR
4
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------------------------------
7
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ # ------------------------------------------------------------------------------------------------
9
+
10
+ python3 setup.py build install --user
models/ops/modules/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from .ms_deform_attn import MSDeformAttn
models/ops/modules/ms_deform_attn.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import warnings
14
+ import math
15
+
16
+ import torch
17
+ from torch import nn
18
+ import torch.nn.functional as F
19
+ from torch.nn.init import xavier_uniform_, constant_
20
+
21
+ from ..functions import MSDeformAttnFunction
22
+
23
+
24
+ def _is_power_of_2(n):
25
+ if (not isinstance(n, int)) or (n < 0):
26
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
27
+ return (n & (n-1) == 0) and n != 0
28
+
29
+
30
+ class MSDeformAttn(nn.Module):
31
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
32
+ """
33
+ Multi-Scale Deformable Attention Module
34
+ :param d_model hidden dimension
35
+ :param n_levels number of feature levels
36
+ :param n_heads number of attention heads
37
+ :param n_points number of sampling points per attention head per feature level
38
+ """
39
+ super().__init__()
40
+ if d_model % n_heads != 0:
41
+ raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
42
+ _d_per_head = d_model // n_heads
43
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
44
+ if not _is_power_of_2(_d_per_head):
45
+ warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
46
+ "which is more efficient in our CUDA implementation.")
47
+
48
+ self.im2col_step = 64
49
+
50
+ self.d_model = d_model
51
+ self.n_levels = n_levels
52
+ self.n_heads = n_heads
53
+ self.n_points = n_points
54
+
55
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
56
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
57
+ self.value_proj = nn.Linear(d_model, d_model)
58
+ self.output_proj = nn.Linear(d_model, d_model)
59
+
60
+ self._reset_parameters()
61
+
62
+ def _reset_parameters(self):
63
+ constant_(self.sampling_offsets.weight.data, 0.)
64
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
65
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
66
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
67
+ for i in range(self.n_points):
68
+ grid_init[:, :, i, :] *= i + 1
69
+ with torch.no_grad():
70
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
71
+ constant_(self.attention_weights.weight.data, 0.)
72
+ constant_(self.attention_weights.bias.data, 0.)
73
+ xavier_uniform_(self.value_proj.weight.data)
74
+ constant_(self.value_proj.bias.data, 0.)
75
+ xavier_uniform_(self.output_proj.weight.data)
76
+ constant_(self.output_proj.bias.data, 0.)
77
+
78
+ def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
79
+ """
80
+ :param query (N, Length_{query}, C)
81
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
82
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
83
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
84
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
85
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
86
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
87
+
88
+ :return output (N, Length_{query}, C)
89
+ """
90
+ N, Len_q, _ = query.shape
91
+ N, Len_in, _ = input_flatten.shape
92
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
93
+
94
+ value = self.value_proj(input_flatten)
95
+ if input_padding_mask is not None:
96
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
97
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
98
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
99
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
100
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
101
+ # N, Len_q, n_heads, n_levels, n_points, 2
102
+ if reference_points.shape[-1] == 2:
103
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
104
+ sampling_locations = reference_points[:, :, None, :, None, :] \
105
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
106
+ elif reference_points.shape[-1] == 4:
107
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
108
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
109
+ else:
110
+ raise ValueError(
111
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
112
+ output = MSDeformAttnFunction.apply(
113
+ value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
114
+ output = self.output_proj(output)
115
+ return output
models/ops/setup.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ import os
10
+ import glob
11
+
12
+ import torch
13
+
14
+ from torch.utils.cpp_extension import CUDA_HOME
15
+ from torch.utils.cpp_extension import CppExtension
16
+ from torch.utils.cpp_extension import CUDAExtension
17
+
18
+ from setuptools import find_packages
19
+ from setuptools import setup
20
+
21
+ requirements = ["torch", "torchvision"]
22
+
23
+ def get_extensions():
24
+ this_dir = os.path.dirname(os.path.abspath(__file__))
25
+ extensions_dir = os.path.join(this_dir, "src")
26
+
27
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
28
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
29
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
30
+
31
+ sources = main_file + source_cpu
32
+ extension = CppExtension
33
+ extra_compile_args = {"cxx": []}
34
+ define_macros = []
35
+
36
+ if torch.cuda.is_available() and CUDA_HOME is not None:
37
+ extension = CUDAExtension
38
+ sources += source_cuda
39
+ define_macros += [("WITH_CUDA", None)]
40
+ extra_compile_args["nvcc"] = [
41
+ "-DCUDA_HAS_FP16=1",
42
+ "-D__CUDA_NO_HALF_OPERATORS__",
43
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
44
+ "-D__CUDA_NO_HALF2_OPERATORS__",
45
+ ]
46
+ else:
47
+ raise NotImplementedError('Cuda is not availabel')
48
+
49
+ sources = [os.path.join(extensions_dir, s) for s in sources]
50
+ include_dirs = [extensions_dir]
51
+ ext_modules = [
52
+ extension(
53
+ "MultiScaleDeformableAttention",
54
+ sources,
55
+ include_dirs=include_dirs,
56
+ define_macros=define_macros,
57
+ extra_compile_args=extra_compile_args,
58
+ )
59
+ ]
60
+ return ext_modules
61
+
62
+ setup(
63
+ name="MultiScaleDeformableAttention",
64
+ version="1.0",
65
+ author="Weijie Su",
66
+ url="https://github.com/fundamentalvision/Deformable-DETR",
67
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
68
+ packages=find_packages(exclude=("configs", "tests",)),
69
+ ext_modules=get_extensions(),
70
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
71
+ )
models/ops/src/cpu/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+
13
+ #include <ATen/ATen.h>
14
+ #include <ATen/cuda/CUDAContext.h>
15
+
16
+
17
+ at::Tensor
18
+ ms_deform_attn_cpu_forward(
19
+ const at::Tensor &value,
20
+ const at::Tensor &spatial_shapes,
21
+ const at::Tensor &level_start_index,
22
+ const at::Tensor &sampling_loc,
23
+ const at::Tensor &attn_weight,
24
+ const int im2col_step)
25
+ {
26
+ AT_ERROR("Not implement on cpu");
27
+ }
28
+
29
+ std::vector<at::Tensor>
30
+ ms_deform_attn_cpu_backward(
31
+ const at::Tensor &value,
32
+ const at::Tensor &spatial_shapes,
33
+ const at::Tensor &level_start_index,
34
+ const at::Tensor &sampling_loc,
35
+ const at::Tensor &attn_weight,
36
+ const at::Tensor &grad_output,
37
+ const int im2col_step)
38
+ {
39
+ AT_ERROR("Not implement on cpu");
40
+ }
41
+
models/ops/src/cpu/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ at::Tensor
15
+ ms_deform_attn_cpu_forward(
16
+ const at::Tensor &value,
17
+ const at::Tensor &spatial_shapes,
18
+ const at::Tensor &level_start_index,
19
+ const at::Tensor &sampling_loc,
20
+ const at::Tensor &attn_weight,
21
+ const int im2col_step);
22
+
23
+ std::vector<at::Tensor>
24
+ ms_deform_attn_cpu_backward(
25
+ const at::Tensor &value,
26
+ const at::Tensor &spatial_shapes,
27
+ const at::Tensor &level_start_index,
28
+ const at::Tensor &sampling_loc,
29
+ const at::Tensor &attn_weight,
30
+ const at::Tensor &grad_output,
31
+ const int im2col_step);
32
+
33
+
models/ops/src/cuda/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+ #include "cuda/ms_deform_im2col_cuda.cuh"
13
+
14
+ #include <ATen/ATen.h>
15
+ #include <ATen/cuda/CUDAContext.h>
16
+ #include <cuda.h>
17
+ #include <cuda_runtime.h>
18
+
19
+
20
+ at::Tensor ms_deform_attn_cuda_forward(
21
+ const at::Tensor &value,
22
+ const at::Tensor &spatial_shapes,
23
+ const at::Tensor &level_start_index,
24
+ const at::Tensor &sampling_loc,
25
+ const at::Tensor &attn_weight,
26
+ const int im2col_step)
27
+ {
28
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33
+
34
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
35
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
36
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
37
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
38
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
39
+
40
+ const int batch = value.size(0);
41
+ const int spatial_size = value.size(1);
42
+ const int num_heads = value.size(2);
43
+ const int channels = value.size(3);
44
+
45
+ const int num_levels = spatial_shapes.size(0);
46
+
47
+ const int num_query = sampling_loc.size(1);
48
+ const int num_point = sampling_loc.size(4);
49
+
50
+ const int im2col_step_ = std::min(batch, im2col_step);
51
+
52
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
53
+
54
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55
+
56
+ const int batch_n = im2col_step_;
57
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58
+ auto per_value_size = spatial_size * num_heads * channels;
59
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61
+ for (int n = 0; n < batch/im2col_step_; ++n)
62
+ {
63
+ auto columns = output_n.select(0, n);
64
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
65
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67
+ spatial_shapes.data<int64_t>(),
68
+ level_start_index.data<int64_t>(),
69
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72
+ columns.data<scalar_t>());
73
+
74
+ }));
75
+ }
76
+
77
+ output = output.view({batch, num_query, num_heads*channels});
78
+
79
+ return output;
80
+ }
81
+
82
+
83
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84
+ const at::Tensor &value,
85
+ const at::Tensor &spatial_shapes,
86
+ const at::Tensor &level_start_index,
87
+ const at::Tensor &sampling_loc,
88
+ const at::Tensor &attn_weight,
89
+ const at::Tensor &grad_output,
90
+ const int im2col_step)
91
+ {
92
+
93
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99
+
100
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
101
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
102
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
103
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
104
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
105
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
106
+
107
+ const int batch = value.size(0);
108
+ const int spatial_size = value.size(1);
109
+ const int num_heads = value.size(2);
110
+ const int channels = value.size(3);
111
+
112
+ const int num_levels = spatial_shapes.size(0);
113
+
114
+ const int num_query = sampling_loc.size(1);
115
+ const int num_point = sampling_loc.size(4);
116
+
117
+ const int im2col_step_ = std::min(batch, im2col_step);
118
+
119
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
120
+
121
+ auto grad_value = at::zeros_like(value);
122
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
123
+ auto grad_attn_weight = at::zeros_like(attn_weight);
124
+
125
+ const int batch_n = im2col_step_;
126
+ auto per_value_size = spatial_size * num_heads * channels;
127
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130
+
131
+ for (int n = 0; n < batch/im2col_step_; ++n)
132
+ {
133
+ auto grad_output_g = grad_output_n.select(0, n);
134
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
135
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136
+ grad_output_g.data<scalar_t>(),
137
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
138
+ spatial_shapes.data<int64_t>(),
139
+ level_start_index.data<int64_t>(),
140
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
141
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
142
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
144
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
145
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
146
+
147
+ }));
148
+ }
149
+
150
+ return {
151
+ grad_value, grad_sampling_loc, grad_attn_weight
152
+ };
153
+ }
models/ops/src/cuda/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ at::Tensor ms_deform_attn_cuda_forward(
15
+ const at::Tensor &value,
16
+ const at::Tensor &spatial_shapes,
17
+ const at::Tensor &level_start_index,
18
+ const at::Tensor &sampling_loc,
19
+ const at::Tensor &attn_weight,
20
+ const int im2col_step);
21
+
22
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23
+ const at::Tensor &value,
24
+ const at::Tensor &spatial_shapes,
25
+ const at::Tensor &level_start_index,
26
+ const at::Tensor &sampling_loc,
27
+ const at::Tensor &attn_weight,
28
+ const at::Tensor &grad_output,
29
+ const int im2col_step);
30
+
models/ops/src/cuda/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+ #include <cstring>
15
+
16
+ #include <ATen/ATen.h>
17
+ #include <ATen/cuda/CUDAContext.h>
18
+
19
+ #include <THC/THCAtomics.cuh>
20
+
21
+ #define CUDA_KERNEL_LOOP(i, n) \
22
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
23
+ i < (n); \
24
+ i += blockDim.x * gridDim.x)
25
+
26
+ const int CUDA_NUM_THREADS = 1024;
27
+ inline int GET_BLOCKS(const int N, const int num_threads)
28
+ {
29
+ return (N + num_threads - 1) / num_threads;
30
+ }
31
+
32
+
33
+ template <typename scalar_t>
34
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
35
+ const int &height, const int &width, const int &nheads, const int &channels,
36
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
37
+ {
38
+ const int h_low = floor(h);
39
+ const int w_low = floor(w);
40
+ const int h_high = h_low + 1;
41
+ const int w_high = w_low + 1;
42
+
43
+ const scalar_t lh = h - h_low;
44
+ const scalar_t lw = w - w_low;
45
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
46
+
47
+ const int w_stride = nheads * channels;
48
+ const int h_stride = width * w_stride;
49
+ const int h_low_ptr_offset = h_low * h_stride;
50
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
51
+ const int w_low_ptr_offset = w_low * w_stride;
52
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
53
+ const int base_ptr = m * channels + c;
54
+
55
+ scalar_t v1 = 0;
56
+ if (h_low >= 0 && w_low >= 0)
57
+ {
58
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
59
+ v1 = bottom_data[ptr1];
60
+ }
61
+ scalar_t v2 = 0;
62
+ if (h_low >= 0 && w_high <= width - 1)
63
+ {
64
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
65
+ v2 = bottom_data[ptr2];
66
+ }
67
+ scalar_t v3 = 0;
68
+ if (h_high <= height - 1 && w_low >= 0)
69
+ {
70
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
71
+ v3 = bottom_data[ptr3];
72
+ }
73
+ scalar_t v4 = 0;
74
+ if (h_high <= height - 1 && w_high <= width - 1)
75
+ {
76
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
77
+ v4 = bottom_data[ptr4];
78
+ }
79
+
80
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
81
+
82
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
83
+ return val;
84
+ }
85
+
86
+
87
+ template <typename scalar_t>
88
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
89
+ const int &height, const int &width, const int &nheads, const int &channels,
90
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
91
+ const scalar_t &top_grad,
92
+ const scalar_t &attn_weight,
93
+ scalar_t* &grad_value,
94
+ scalar_t* grad_sampling_loc,
95
+ scalar_t* grad_attn_weight)
96
+ {
97
+ const int h_low = floor(h);
98
+ const int w_low = floor(w);
99
+ const int h_high = h_low + 1;
100
+ const int w_high = w_low + 1;
101
+
102
+ const scalar_t lh = h - h_low;
103
+ const scalar_t lw = w - w_low;
104
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
105
+
106
+ const int w_stride = nheads * channels;
107
+ const int h_stride = width * w_stride;
108
+ const int h_low_ptr_offset = h_low * h_stride;
109
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
110
+ const int w_low_ptr_offset = w_low * w_stride;
111
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
112
+ const int base_ptr = m * channels + c;
113
+
114
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
115
+ const scalar_t top_grad_value = top_grad * attn_weight;
116
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
117
+
118
+ scalar_t v1 = 0;
119
+ if (h_low >= 0 && w_low >= 0)
120
+ {
121
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
122
+ v1 = bottom_data[ptr1];
123
+ grad_h_weight -= hw * v1;
124
+ grad_w_weight -= hh * v1;
125
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
126
+ }
127
+ scalar_t v2 = 0;
128
+ if (h_low >= 0 && w_high <= width - 1)
129
+ {
130
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
131
+ v2 = bottom_data[ptr2];
132
+ grad_h_weight -= lw * v2;
133
+ grad_w_weight += hh * v2;
134
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
135
+ }
136
+ scalar_t v3 = 0;
137
+ if (h_high <= height - 1 && w_low >= 0)
138
+ {
139
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
140
+ v3 = bottom_data[ptr3];
141
+ grad_h_weight += hw * v3;
142
+ grad_w_weight -= lh * v3;
143
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
144
+ }
145
+ scalar_t v4 = 0;
146
+ if (h_high <= height - 1 && w_high <= width - 1)
147
+ {
148
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
149
+ v4 = bottom_data[ptr4];
150
+ grad_h_weight += lw * v4;
151
+ grad_w_weight += lh * v4;
152
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
153
+ }
154
+
155
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
156
+ *grad_attn_weight = top_grad * val;
157
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
158
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
159
+ }
160
+
161
+
162
+ template <typename scalar_t>
163
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
164
+ const int &height, const int &width, const int &nheads, const int &channels,
165
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
166
+ const scalar_t &top_grad,
167
+ const scalar_t &attn_weight,
168
+ scalar_t* &grad_value,
169
+ scalar_t* grad_sampling_loc,
170
+ scalar_t* grad_attn_weight)
171
+ {
172
+ const int h_low = floor(h);
173
+ const int w_low = floor(w);
174
+ const int h_high = h_low + 1;
175
+ const int w_high = w_low + 1;
176
+
177
+ const scalar_t lh = h - h_low;
178
+ const scalar_t lw = w - w_low;
179
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
180
+
181
+ const int w_stride = nheads * channels;
182
+ const int h_stride = width * w_stride;
183
+ const int h_low_ptr_offset = h_low * h_stride;
184
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
185
+ const int w_low_ptr_offset = w_low * w_stride;
186
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
187
+ const int base_ptr = m * channels + c;
188
+
189
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
190
+ const scalar_t top_grad_value = top_grad * attn_weight;
191
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
192
+
193
+ scalar_t v1 = 0;
194
+ if (h_low >= 0 && w_low >= 0)
195
+ {
196
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
197
+ v1 = bottom_data[ptr1];
198
+ grad_h_weight -= hw * v1;
199
+ grad_w_weight -= hh * v1;
200
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
201
+ }
202
+ scalar_t v2 = 0;
203
+ if (h_low >= 0 && w_high <= width - 1)
204
+ {
205
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
206
+ v2 = bottom_data[ptr2];
207
+ grad_h_weight -= lw * v2;
208
+ grad_w_weight += hh * v2;
209
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
210
+ }
211
+ scalar_t v3 = 0;
212
+ if (h_high <= height - 1 && w_low >= 0)
213
+ {
214
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
215
+ v3 = bottom_data[ptr3];
216
+ grad_h_weight += hw * v3;
217
+ grad_w_weight -= lh * v3;
218
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
219
+ }
220
+ scalar_t v4 = 0;
221
+ if (h_high <= height - 1 && w_high <= width - 1)
222
+ {
223
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
224
+ v4 = bottom_data[ptr4];
225
+ grad_h_weight += lw * v4;
226
+ grad_w_weight += lh * v4;
227
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
228
+ }
229
+
230
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
231
+ atomicAdd(grad_attn_weight, top_grad * val);
232
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
233
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
234
+ }
235
+
236
+
237
+ template <typename scalar_t>
238
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
239
+ const scalar_t *data_value,
240
+ const int64_t *data_spatial_shapes,
241
+ const int64_t *data_level_start_index,
242
+ const scalar_t *data_sampling_loc,
243
+ const scalar_t *data_attn_weight,
244
+ const int batch_size,
245
+ const int spatial_size,
246
+ const int num_heads,
247
+ const int channels,
248
+ const int num_levels,
249
+ const int num_query,
250
+ const int num_point,
251
+ scalar_t *data_col)
252
+ {
253
+ CUDA_KERNEL_LOOP(index, n)
254
+ {
255
+ int _temp = index;
256
+ const int c_col = _temp % channels;
257
+ _temp /= channels;
258
+ const int sampling_index = _temp;
259
+ const int m_col = _temp % num_heads;
260
+ _temp /= num_heads;
261
+ const int q_col = _temp % num_query;
262
+ _temp /= num_query;
263
+ const int b_col = _temp;
264
+
265
+ scalar_t *data_col_ptr = data_col + index;
266
+ int data_weight_ptr = sampling_index * num_levels * num_point;
267
+ int data_loc_w_ptr = data_weight_ptr << 1;
268
+ const int qid_stride = num_heads * channels;
269
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
270
+ scalar_t col = 0;
271
+
272
+ for (int l_col=0; l_col < num_levels; ++l_col)
273
+ {
274
+ const int level_start_id = data_level_start_index[l_col];
275
+ const int spatial_h_ptr = l_col << 1;
276
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
277
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
278
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
279
+ for (int p_col=0; p_col < num_point; ++p_col)
280
+ {
281
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
282
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
283
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
284
+
285
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
286
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
287
+
288
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
289
+ {
290
+           col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+         }
+
+         data_weight_ptr += 1;
+         data_loc_w_ptr += 2;
+       }
+     }
+     *data_col_ptr = col;
+   }
+ }
+
+ template <typename scalar_t, unsigned int blockSize>
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+         const scalar_t *grad_col,
+         const scalar_t *data_value,
+         const int64_t *data_spatial_shapes,
+         const int64_t *data_level_start_index,
+         const scalar_t *data_sampling_loc,
+         const scalar_t *data_attn_weight,
+         const int batch_size,
+         const int spatial_size,
+         const int num_heads,
+         const int channels,
+         const int num_levels,
+         const int num_query,
+         const int num_point,
+         scalar_t *grad_value,
+         scalar_t *grad_sampling_loc,
+         scalar_t *grad_attn_weight)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+     __shared__ scalar_t cache_grad_attn_weight[blockSize];
+     unsigned int tid = threadIdx.x;
+     int _temp = index;
+     const int c_col = _temp % channels;
+     _temp /= channels;
+     const int sampling_index = _temp;
+     const int m_col = _temp % num_heads;
+     _temp /= num_heads;
+     const int q_col = _temp % num_query;
+     _temp /= num_query;
+     const int b_col = _temp;
+
+     const scalar_t top_grad = grad_col[index];
+
+     int data_weight_ptr = sampling_index * num_levels * num_point;
+     int data_loc_w_ptr = data_weight_ptr << 1;
+     const int grad_sampling_ptr = data_weight_ptr;
+     grad_sampling_loc += grad_sampling_ptr << 1;
+     grad_attn_weight += grad_sampling_ptr;
+     const int grad_weight_stride = 1;
+     const int grad_loc_stride = 2;
+     const int qid_stride = num_heads * channels;
+     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+     for (int l_col=0; l_col < num_levels; ++l_col)
+     {
+       const int level_start_id = data_level_start_index[l_col];
+       const int spatial_h_ptr = l_col << 1;
+       const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+       const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+       const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+       const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+       scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+       for (int p_col=0; p_col < num_point; ++p_col)
+       {
+         const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+         const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+         const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+         const scalar_t h_im = loc_h * spatial_h - 0.5;
+         const scalar_t w_im = loc_w * spatial_w - 0.5;
+         *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+         *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+         *(cache_grad_attn_weight+threadIdx.x)=0;
+         if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+         {
+           ms_deform_attn_col2im_bilinear(
+             data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+             top_grad, weight, grad_value_ptr,
+             cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+         }
+
+         __syncthreads();
+         if (tid == 0)
+         {
+           scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+           int sid=2;
+           for (unsigned int tid = 1; tid < blockSize; ++tid)
+           {
+             _grad_w += cache_grad_sampling_loc[sid];
+             _grad_h += cache_grad_sampling_loc[sid + 1];
+             _grad_a += cache_grad_attn_weight[tid];
+             sid += 2;
+           }
+
+
+           *grad_sampling_loc = _grad_w;
+           *(grad_sampling_loc + 1) = _grad_h;
+           *grad_attn_weight = _grad_a;
+         }
+         __syncthreads();
+
+         data_weight_ptr += 1;
+         data_loc_w_ptr += 2;
+         grad_attn_weight += grad_weight_stride;
+         grad_sampling_loc += grad_loc_stride;
+       }
+     }
+   }
+ }
+
+
+ template <typename scalar_t, unsigned int blockSize>
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+         const scalar_t *grad_col,
+         const scalar_t *data_value,
+         const int64_t *data_spatial_shapes,
+         const int64_t *data_level_start_index,
+         const scalar_t *data_sampling_loc,
+         const scalar_t *data_attn_weight,
+         const int batch_size,
+         const int spatial_size,
+         const int num_heads,
+         const int channels,
+         const int num_levels,
+         const int num_query,
+         const int num_point,
+         scalar_t *grad_value,
+         scalar_t *grad_sampling_loc,
+         scalar_t *grad_attn_weight)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+     __shared__ scalar_t cache_grad_attn_weight[blockSize];
+     unsigned int tid = threadIdx.x;
+     int _temp = index;
+     const int c_col = _temp % channels;
+     _temp /= channels;
+     const int sampling_index = _temp;
+     const int m_col = _temp % num_heads;
+     _temp /= num_heads;
+     const int q_col = _temp % num_query;
+     _temp /= num_query;
+     const int b_col = _temp;
+
+     const scalar_t top_grad = grad_col[index];
+
+     int data_weight_ptr = sampling_index * num_levels * num_point;
+     int data_loc_w_ptr = data_weight_ptr << 1;
+     const int grad_sampling_ptr = data_weight_ptr;
+     grad_sampling_loc += grad_sampling_ptr << 1;
+     grad_attn_weight += grad_sampling_ptr;
+     const int grad_weight_stride = 1;
+     const int grad_loc_stride = 2;
+     const int qid_stride = num_heads * channels;
+     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+     for (int l_col=0; l_col < num_levels; ++l_col)
+     {
+       const int level_start_id = data_level_start_index[l_col];
+       const int spatial_h_ptr = l_col << 1;
+       const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+       const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+       const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+       const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+       scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+       for (int p_col=0; p_col < num_point; ++p_col)
+       {
+         const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+         const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+         const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+         const scalar_t h_im = loc_h * spatial_h - 0.5;
+         const scalar_t w_im = loc_w * spatial_w - 0.5;
+         *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+         *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+         *(cache_grad_attn_weight+threadIdx.x)=0;
+         if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+         {
+           ms_deform_attn_col2im_bilinear(
+             data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+             top_grad, weight, grad_value_ptr,
+             cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+         }
+
+         __syncthreads();
+
+         for (unsigned int s=blockSize/2; s>0; s>>=1)
+         {
+           if (tid < s) {
+             const unsigned int xid1 = tid << 1;
+             const unsigned int xid2 = (tid + s) << 1;
+             cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+             cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+             cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+           }
+           __syncthreads();
+         }
+
+         if (tid == 0)
+         {
+           *grad_sampling_loc = cache_grad_sampling_loc[0];
+           *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+           *grad_attn_weight = cache_grad_attn_weight[0];
+         }
+         __syncthreads();
+
+         data_weight_ptr += 1;
+         data_loc_w_ptr += 2;
+         grad_attn_weight += grad_weight_stride;
+         grad_sampling_loc += grad_loc_stride;
+       }
+     }
+   }
+ }
+
+
+ template <typename scalar_t>
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+         const scalar_t *grad_col,
+         const scalar_t *data_value,
+         const int64_t *data_spatial_shapes,
+         const int64_t *data_level_start_index,
+         const scalar_t *data_sampling_loc,
+         const scalar_t *data_attn_weight,
+         const int batch_size,
+         const int spatial_size,
+         const int num_heads,
+         const int channels,
+         const int num_levels,
+         const int num_query,
+         const int num_point,
+         scalar_t *grad_value,
+         scalar_t *grad_sampling_loc,
+         scalar_t *grad_attn_weight)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     extern __shared__ int _s[];
+     scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+     scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+     unsigned int tid = threadIdx.x;
+     int _temp = index;
+     const int c_col = _temp % channels;
+     _temp /= channels;
+     const int sampling_index = _temp;
+     const int m_col = _temp % num_heads;
+     _temp /= num_heads;
+     const int q_col = _temp % num_query;
+     _temp /= num_query;
+     const int b_col = _temp;
+
+     const scalar_t top_grad = grad_col[index];
+
+     int data_weight_ptr = sampling_index * num_levels * num_point;
+     int data_loc_w_ptr = data_weight_ptr << 1;
+     const int grad_sampling_ptr = data_weight_ptr;
+     grad_sampling_loc += grad_sampling_ptr << 1;
+     grad_attn_weight += grad_sampling_ptr;
+     const int grad_weight_stride = 1;
+     const int grad_loc_stride = 2;
+     const int qid_stride = num_heads * channels;
+     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+     for (int l_col=0; l_col < num_levels; ++l_col)
+     {
+       const int level_start_id = data_level_start_index[l_col];
+       const int spatial_h_ptr = l_col << 1;
+       const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+       const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+       const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+       const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+       scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+       for (int p_col=0; p_col < num_point; ++p_col)
+       {
+         const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+         const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+         const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+         const scalar_t h_im = loc_h * spatial_h - 0.5;
+         const scalar_t w_im = loc_w * spatial_w - 0.5;
+         *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+         *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+         *(cache_grad_attn_weight+threadIdx.x)=0;
+         if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+         {
+           ms_deform_attn_col2im_bilinear(
+             data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+             top_grad, weight, grad_value_ptr,
+             cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+         }
+
+         __syncthreads();
+         if (tid == 0)
+         {
+           scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+           int sid=2;
+           for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+           {
+             _grad_w += cache_grad_sampling_loc[sid];
+             _grad_h += cache_grad_sampling_loc[sid + 1];
+             _grad_a += cache_grad_attn_weight[tid];
+             sid += 2;
+           }
+
+
+           *grad_sampling_loc = _grad_w;
+           *(grad_sampling_loc + 1) = _grad_h;
+           *grad_attn_weight = _grad_a;
+         }
+         __syncthreads();
+
+         data_weight_ptr += 1;
+         data_loc_w_ptr += 2;
+         grad_attn_weight += grad_weight_stride;
+         grad_sampling_loc += grad_loc_stride;
+       }
+     }
+   }
+ }
+
+ template <typename scalar_t>
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+         const scalar_t *grad_col,
+         const scalar_t *data_value,
+         const int64_t *data_spatial_shapes,
+         const int64_t *data_level_start_index,
+         const scalar_t *data_sampling_loc,
+         const scalar_t *data_attn_weight,
+         const int batch_size,
+         const int spatial_size,
+         const int num_heads,
+         const int channels,
+         const int num_levels,
+         const int num_query,
+         const int num_point,
+         scalar_t *grad_value,
+         scalar_t *grad_sampling_loc,
+         scalar_t *grad_attn_weight)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     extern __shared__ int _s[];
+     scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+     scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+     unsigned int tid = threadIdx.x;
+     int _temp = index;
+     const int c_col = _temp % channels;
+     _temp /= channels;
+     const int sampling_index = _temp;
+     const int m_col = _temp % num_heads;
+     _temp /= num_heads;
+     const int q_col = _temp % num_query;
+     _temp /= num_query;
+     const int b_col = _temp;
+
+     const scalar_t top_grad = grad_col[index];
+
+     int data_weight_ptr = sampling_index * num_levels * num_point;
+     int data_loc_w_ptr = data_weight_ptr << 1;
+     const int grad_sampling_ptr = data_weight_ptr;
+     grad_sampling_loc += grad_sampling_ptr << 1;
+     grad_attn_weight += grad_sampling_ptr;
+     const int grad_weight_stride = 1;
+     const int grad_loc_stride = 2;
+     const int qid_stride = num_heads * channels;
+     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+     for (int l_col=0; l_col < num_levels; ++l_col)
+     {
+       const int level_start_id = data_level_start_index[l_col];
+       const int spatial_h_ptr = l_col << 1;
+       const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+       const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+       const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+       const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+       scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+       for (int p_col=0; p_col < num_point; ++p_col)
+       {
+         const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+         const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+         const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+         const scalar_t h_im = loc_h * spatial_h - 0.5;
+         const scalar_t w_im = loc_w * spatial_w - 0.5;
+         *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+         *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+         *(cache_grad_attn_weight+threadIdx.x)=0;
+         if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+         {
+           ms_deform_attn_col2im_bilinear(
+             data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+             top_grad, weight, grad_value_ptr,
+             cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+         }
+
+         __syncthreads();
+
+         for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+         {
+           if (tid < s) {
+             const unsigned int xid1 = tid << 1;
+             const unsigned int xid2 = (tid + s) << 1;
+             cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+             cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+             cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+             if (tid + (s << 1) < spre)
+             {
+               cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+               cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+               cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+             }
+           }
+           __syncthreads();
+         }
+
+         if (tid == 0)
+         {
+           *grad_sampling_loc = cache_grad_sampling_loc[0];
+           *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+           *grad_attn_weight = cache_grad_attn_weight[0];
+         }
+         __syncthreads();
+
+         data_weight_ptr += 1;
+         data_loc_w_ptr += 2;
+         grad_attn_weight += grad_weight_stride;
+         grad_sampling_loc += grad_loc_stride;
+       }
+     }
+   }
+ }
+
+ template <typename scalar_t>
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+         const scalar_t *grad_col,
+         const scalar_t *data_value,
+         const int64_t *data_spatial_shapes,
+         const int64_t *data_level_start_index,
+         const scalar_t *data_sampling_loc,
+         const scalar_t *data_attn_weight,
+         const int batch_size,
+         const int spatial_size,
+         const int num_heads,
+         const int channels,
+         const int num_levels,
+         const int num_query,
+         const int num_point,
+         scalar_t *grad_value,
+         scalar_t *grad_sampling_loc,
+         scalar_t *grad_attn_weight)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     extern __shared__ int _s[];
+     scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+     scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+     unsigned int tid = threadIdx.x;
+     int _temp = index;
+     const int c_col = _temp % channels;
+     _temp /= channels;
+     const int sampling_index = _temp;
+     const int m_col = _temp % num_heads;
+     _temp /= num_heads;
+     const int q_col = _temp % num_query;
+     _temp /= num_query;
+     const int b_col = _temp;
+
+     const scalar_t top_grad = grad_col[index];
+
+     int data_weight_ptr = sampling_index * num_levels * num_point;
+     int data_loc_w_ptr = data_weight_ptr << 1;
+     const int grad_sampling_ptr = data_weight_ptr;
+     grad_sampling_loc += grad_sampling_ptr << 1;
+     grad_attn_weight += grad_sampling_ptr;
+     const int grad_weight_stride = 1;
+     const int grad_loc_stride = 2;
+     const int qid_stride = num_heads * channels;
+     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+     for (int l_col=0; l_col < num_levels; ++l_col)
+     {
+       const int level_start_id = data_level_start_index[l_col];
+       const int spatial_h_ptr = l_col << 1;
+       const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+       const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+       const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+       const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+       scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+       for (int p_col=0; p_col < num_point; ++p_col)
+       {
+         const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+         const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+         const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+         const scalar_t h_im = loc_h * spatial_h - 0.5;
+         const scalar_t w_im = loc_w * spatial_w - 0.5;
+         *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+         *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+         *(cache_grad_attn_weight+threadIdx.x)=0;
+         if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+         {
+           ms_deform_attn_col2im_bilinear(
+             data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+             top_grad, weight, grad_value_ptr,
+             cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+         }
+
+         __syncthreads();
+
+         for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+         {
+           if (tid < s) {
+             const unsigned int xid1 = tid << 1;
+             const unsigned int xid2 = (tid + s) << 1;
+             cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+             cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+             cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+             if (tid + (s << 1) < spre)
+             {
+               cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+               cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+               cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+             }
+           }
+           __syncthreads();
+         }
+
+         if (tid == 0)
+         {
+           atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+           atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+           atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+         }
+         __syncthreads();
+
+         data_weight_ptr += 1;
+         data_loc_w_ptr += 2;
+         grad_attn_weight += grad_weight_stride;
+         grad_sampling_loc += grad_loc_stride;
+       }
+     }
+   }
+ }
+
+
+ template <typename scalar_t>
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+         const scalar_t *grad_col,
+         const scalar_t *data_value,
+         const int64_t *data_spatial_shapes,
+         const int64_t *data_level_start_index,
+         const scalar_t *data_sampling_loc,
+         const scalar_t *data_attn_weight,
+         const int batch_size,
+         const int spatial_size,
+         const int num_heads,
+         const int channels,
+         const int num_levels,
+         const int num_query,
+         const int num_point,
+         scalar_t *grad_value,
+         scalar_t *grad_sampling_loc,
+         scalar_t *grad_attn_weight)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     int _temp = index;
+     const int c_col = _temp % channels;
+     _temp /= channels;
+     const int sampling_index = _temp;
+     const int m_col = _temp % num_heads;
+     _temp /= num_heads;
+     const int q_col = _temp % num_query;
+     _temp /= num_query;
+     const int b_col = _temp;
+
+     const scalar_t top_grad = grad_col[index];
+
+     int data_weight_ptr = sampling_index * num_levels * num_point;
+     int data_loc_w_ptr = data_weight_ptr << 1;
+     const int grad_sampling_ptr = data_weight_ptr;
+     grad_sampling_loc += grad_sampling_ptr << 1;
+     grad_attn_weight += grad_sampling_ptr;
+     const int grad_weight_stride = 1;
+     const int grad_loc_stride = 2;
+     const int qid_stride = num_heads * channels;
+     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+     for (int l_col=0; l_col < num_levels; ++l_col)
+     {
+       const int level_start_id = data_level_start_index[l_col];
+       const int spatial_h_ptr = l_col << 1;
+       const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+       const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+       const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+       const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+       scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+       for (int p_col=0; p_col < num_point; ++p_col)
+       {
+         const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+         const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+         const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+         const scalar_t h_im = loc_h * spatial_h - 0.5;
+         const scalar_t w_im = loc_w * spatial_w - 0.5;
+         if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+         {
+           ms_deform_attn_col2im_bilinear_gm(
+             data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+             top_grad, weight, grad_value_ptr,
+             grad_sampling_loc, grad_attn_weight);
+         }
+         data_weight_ptr += 1;
+         data_loc_w_ptr += 2;
+         grad_attn_weight += grad_weight_stride;
+         grad_sampling_loc += grad_loc_stride;
+       }
+     }
+   }
+ }
+
+
+ template <typename scalar_t>
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
+         const scalar_t* data_value,
+         const int64_t* data_spatial_shapes,
+         const int64_t* data_level_start_index,
+         const scalar_t* data_sampling_loc,
+         const scalar_t* data_attn_weight,
+         const int batch_size,
+         const int spatial_size,
+         const int num_heads,
+         const int channels,
+         const int num_levels,
+         const int num_query,
+         const int num_point,
+         scalar_t* data_col)
+ {
+   const int num_kernels = batch_size * num_query * num_heads * channels;
+   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+   const int num_threads = CUDA_NUM_THREADS;
+   ms_deformable_im2col_gpu_kernel<scalar_t>
+       <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+          0, stream>>>(
+       num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+       batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+   }
+
+ }
+
+ template <typename scalar_t>
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
+         const scalar_t* grad_col,
+         const scalar_t* data_value,
+         const int64_t * data_spatial_shapes,
+         const int64_t * data_level_start_index,
+         const scalar_t * data_sampling_loc,
+         const scalar_t * data_attn_weight,
+         const int batch_size,
+         const int spatial_size,
+         const int num_heads,
+         const int channels,
+         const int num_levels,
+         const int num_query,
+         const int num_point,
+         scalar_t* grad_value,
+         scalar_t* grad_sampling_loc,
+         scalar_t* grad_attn_weight)
+ {
+   const int num_threads = (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
+   const int num_kernels = batch_size * num_query * num_heads * channels;
+   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+   if (channels > 1024)
+   {
+     if ((channels & 1023) == 0)
+     {
+       ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              num_threads*3*sizeof(scalar_t), stream>>>(
+           num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+           data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+           num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+     }
+     else
+     {
+       ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              0, stream>>>(
+           num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+           data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+           num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+     }
+   }
+   else {
+     switch(channels)
+     {
+       case 1:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+             num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 2:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+             num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 4:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+             num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 8:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+             num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 16:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+             num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 32:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+             num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 64:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+             num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 128:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+             num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 256:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+             num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 512:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+             num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       case 1024:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                0, stream>>>(
+             num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+             data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+             num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         break;
+       default:
+         if (channels < 64)
+         {
+           ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+               <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                  num_threads*3*sizeof(scalar_t), stream>>>(
+               num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+               data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+               num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         }
+         else
+         {
+           ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+               <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+                  num_threads*3*sizeof(scalar_t), stream>>>(
+               num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
+               data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads, channels,
+               num_levels, num_query, num_point, grad_value, grad_sampling_loc, grad_attn_weight);
+         }
+     }
+   }
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+   }
+
+ }
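
The forward im2col kernel above computes exactly what the pure-PyTorch fallback in models/ops/functions/ms_deform_attn_func.py (imported as ms_deform_attn_core_pytorch by models/ops/test.py below) computes with F.grid_sample. A minimal sketch of that reference computation, assuming the standard Deformable DETR tensor layout of value (N, S, M, D), sampling_locations (N, Lq, M, L, P, 2) in [0, 1], and attention_weights (N, Lq, M, L, P):

    import torch
    import torch.nn.functional as F

    def ms_deform_attn_pytorch(value, value_spatial_shapes,
                               sampling_locations, attention_weights):
        N_, S_, M_, D_ = value.shape
        _, Lq_, M_, L_, P_, _ = sampling_locations.shape
        value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
        # map [0, 1] locations to the [-1, 1] range expected by grid_sample
        sampling_grids = 2 * sampling_locations - 1
        sampling_value_list = []
        for lid_, (H_, W_) in enumerate(value_spatial_shapes):
            # (N_, H_*W_, M_, D_) -> (N_*M_, D_, H_, W_)
            value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
            # (N_, Lq_, M_, P_, 2) -> (N_*M_, Lq_, P_, 2)
            sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
            # bilinear sampling, mirroring ms_deform_attn_im2col_bilinear above
            sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, mode='bilinear',
                                              padding_mode='zeros', align_corners=False)
            sampling_value_list.append(sampling_value_l_)
        # weight and sum over all levels and points
        attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
        output = (torch.stack(sampling_value_list, dim=-2).flatten(-2)
                  * attention_weights).sum(-1).view(N_, M_ * D_, Lq_)
        return output.transpose(1, 2).contiguous()
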
models/ops/src/ms_deform_attn.h ADDED
@@ -0,0 +1,62 @@
+ /*!
+ **************************************************************************************************
+ * Deformable DETR
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ **************************************************************************************************
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ **************************************************************************************************
+ */
+
+ #pragma once
+
+ #include "cpu/ms_deform_attn_cpu.h"
+
+ #ifdef WITH_CUDA
+ #include "cuda/ms_deform_attn_cuda.h"
+ #endif
+
+
+ at::Tensor
+ ms_deform_attn_forward(
+     const at::Tensor &value,
+     const at::Tensor &spatial_shapes,
+     const at::Tensor &level_start_index,
+     const at::Tensor &sampling_loc,
+     const at::Tensor &attn_weight,
+     const int im2col_step)
+ {
+   if (value.type().is_cuda())
+   {
+ #ifdef WITH_CUDA
+     return ms_deform_attn_cuda_forward(
+       value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+ #else
+     AT_ERROR("Not compiled with GPU support");
+ #endif
+   }
+   AT_ERROR("Not implemented on the CPU");
+ }
+
+ std::vector<at::Tensor>
+ ms_deform_attn_backward(
+     const at::Tensor &value,
+     const at::Tensor &spatial_shapes,
+     const at::Tensor &level_start_index,
+     const at::Tensor &sampling_loc,
+     const at::Tensor &attn_weight,
+     const at::Tensor &grad_output,
+     const int im2col_step)
+ {
+   if (value.type().is_cuda())
+   {
+ #ifdef WITH_CUDA
+     return ms_deform_attn_cuda_backward(
+       value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+ #else
+     AT_ERROR("Not compiled with GPU support");
+ #endif
+   }
+   AT_ERROR("Not implemented on the CPU");
+ }
+
models/ops/src/vision.cpp ADDED
@@ -0,0 +1,16 @@
+ /*!
+ **************************************************************************************************
+ * Deformable DETR
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ **************************************************************************************************
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ **************************************************************************************************
+ */
+
+ #include "ms_deform_attn.h"
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+   m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+   m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+ }
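
Once compiled, these two bindings are consumed from Python through a torch.autograd.Function wrapper (the MSDeformAttnFunction imported by test.py below). A hedged sketch of what that wrapper typically looks like; the extension module name MultiScaleDeformableAttention is an assumption taken from the stock Deformable DETR setup.py:

    import torch
    import MultiScaleDeformableAttention as MSDA  # assumed module name

    class MSDeformAttnFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, value, spatial_shapes, level_start_index,
                    sampling_loc, attn_weight, im2col_step):
            ctx.im2col_step = im2col_step
            output = MSDA.ms_deform_attn_forward(
                value, spatial_shapes, level_start_index,
                sampling_loc, attn_weight, ctx.im2col_step)
            ctx.save_for_backward(value, spatial_shapes, level_start_index,
                                  sampling_loc, attn_weight)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight = ctx.saved_tensors
            grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
                value, spatial_shapes, level_start_index,
                sampling_loc, attn_weight, grad_output.contiguous(), ctx.im2col_step)
            # one gradient per forward input; the shape/index tensors and
            # im2col_step are not differentiable
            return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
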
models/ops/test.py ADDED
@@ -0,0 +1,89 @@
+ # ------------------------------------------------------------------------------------------------
+ # Deformable DETR
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------------------------------
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ # ------------------------------------------------------------------------------------------------
+
+ from __future__ import absolute_import
+ from __future__ import print_function
+ from __future__ import division
+
+ import time
+ import torch
+ import torch.nn as nn
+ from torch.autograd import gradcheck
+
+ from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
+
+ N, M, D = 1, 2, 2
+ Lq, L, P = 2, 2, 2
+ shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
+ level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
+ S = sum([(H*W).item() for H, W in shapes])
+
+
+ torch.manual_seed(3)
+
+
+ @torch.no_grad()
+ def check_forward_equal_with_pytorch_double():
+     value = torch.rand(N, S, M, D).cuda() * 0.01
+     sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+     attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+     attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+     im2col_step = 2
+     output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
+     output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
+     fwdok = torch.allclose(output_cuda, output_pytorch)
+     max_abs_err = (output_cuda - output_pytorch).abs().max()
+     max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+     print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+ @torch.no_grad()
+ def check_forward_equal_with_pytorch_float():
+     value = torch.rand(N, S, M, D).cuda() * 0.01
+     sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+     attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+     attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+     im2col_step = 2
+     output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
+     output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
+     fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
+     max_abs_err = (output_cuda - output_pytorch).abs().max()
+     max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+     print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+ def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
+
+     value = torch.rand(N, S, M, channels).cuda() * 0.01
+     sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+     attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+     attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+     im2col_step = 2
+     func = MSDeformAttnFunction.apply
+
+     value.requires_grad = grad_value
+     sampling_locations.requires_grad = grad_sampling_loc
+     attention_weights.requires_grad = grad_attn_weight
+
+     gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
+
+     print(f'* {gradok} check_gradient_numerical(D={channels})')
+
+
+ if __name__ == '__main__':
+     check_forward_equal_with_pytorch_double()
+     check_forward_equal_with_pytorch_float()
+
+     for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
+         check_gradient_numerical(channels, True, True, True)
+
+
+
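
For concreteness, the two feature levels declared at the top of test.py flatten to 24 and 6 spatial locations, so the flattened value tensor has S = 30 rows and the levels begin at offsets [0, 24]; the snippet below is just that arithmetic, run on CPU:

    import torch

    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    # cumulative offsets of each flattened level, prefixed with 0
    level_start_index = torch.cat((shapes.new_zeros((1,)),
                                   shapes.prod(1).cumsum(0)[:-1]))
    assert level_start_index.tolist() == [0, 24]
    assert sum((H * W).item() for H, W in shapes) == 30
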
models/resnet.py ADDED
@@ -0,0 +1,167 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+
+
+ def convrelu(in_channels, out_channels, kernel, padding):
+     return nn.Sequential(
+         nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
+         nn.ReLU(inplace=True),
+     )
+
+
+ class ResNetBackbone(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.base_model = models.resnet50(pretrained=False)
+         self.base_layers = list(self.base_model.children())
+
+         self.conv_original_size0 = convrelu(3, 64, 3, 1)
+         self.conv_original_size1 = convrelu(64, 64, 3, 1)
+         self.layer0 = nn.Sequential(*self.base_layers[:3])  # size=(N, 64, x.H/2, x.W/2)
+         self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 64, x.H/4, x.W/4)
+         self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
+         self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
+         self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
+
+         self.strides = [8, 16, 32]
+         self.num_channels = [512, 1024, 2048]
+
+     def forward(self, inputs):
+         x_original = self.conv_original_size0(inputs)
+         x_original = self.conv_original_size1(x_original)
+         layer0 = self.layer0(inputs)
+         layer1 = self.layer1(layer0)
+         layer2 = self.layer2(layer1)
+         layer3 = self.layer3(layer2)
+         layer4 = self.layer4(layer3)
+
+         xs = {"0": layer2, "1": layer3, "2": layer4}
+         all_feats = {'layer0': layer0, 'layer1': layer1, 'layer2': layer2,
+                      'layer3': layer3, 'layer4': layer4, 'x_original': x_original}
+
+         mask = torch.zeros(inputs.shape)[:, 0, :, :].to(layer4.device)
+         return xs, mask, all_feats
+
+     def train(self, mode=True):
+         # Override train so that the training mode is set as we want
+         nn.Module.train(self, mode)
+         if mode:
+             # fix all bn layers
+             def set_bn_eval(m):
+                 classname = m.__class__.__name__
+                 if classname.find('BatchNorm') != -1:
+                     m.eval()
+
+             self.apply(set_bn_eval)
+
+
+ class ResNetUNet(nn.Module):
+     def __init__(self, n_class, out_dim=None, ms_feat=False):
+         super().__init__()
+
+         self.return_ms_feat = ms_feat
+         self.out_dim = out_dim
+
+         self.base_model = models.resnet50(pretrained=True)
+         self.base_layers = list(self.base_model.children())
+
+         self.layer0 = nn.Sequential(*self.base_layers[:3])  # size=(N, 64, x.H/2, x.W/2)
+         # self.layer0_1x1 = convrelu(64, 64, 1, 0)
+         self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 64, x.H/4, x.W/4)
+         # self.layer1_1x1 = convrelu(256, 256, 1, 0)
+         self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
+         # self.layer2_1x1 = convrelu(512, 512, 1, 0)
+         self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
+         # self.layer3_1x1 = convrelu(1024, 1024, 1, 0)
+         self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
+         # self.layer4_1x1 = convrelu(2048, 2048, 1, 0)
+
+         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+         self.conv_up3 = convrelu(1024 + 2048, 1024, 3, 1)
+         self.conv_up2 = convrelu(512 + 1024, 512, 3, 1)
+         self.conv_up1 = convrelu(256 + 512, 256, 3, 1)
+         self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
+         # self.conv_up1 = convrelu(512, 256, 3, 1)
+         # self.conv_up0 = convrelu(256, 128, 3, 1)
+
+         self.conv_original_size0 = convrelu(3, 64, 3, 1)
+         self.conv_original_size1 = convrelu(64, 64, 3, 1)
+         self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
+         # self.conv_last = nn.Conv2d(128, n_class, 1)
+         self.conv_last = nn.Conv2d(64, n_class, 1)
+         if out_dim:
+             self.conv_out = nn.Conv2d(64, out_dim, 1)
+             # self.conv_out = nn.Conv2d(128, out_dim, 1)
+
+         # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
+         self.strides = [8, 16, 32]
+         self.num_channels = [512, 1024, 2048]
+
+     def forward(self, inputs):
+         x_original = self.conv_original_size0(inputs)
+         x_original = self.conv_original_size1(x_original)
+
+         layer0 = self.layer0(inputs)
+         layer1 = self.layer1(layer0)
+         layer2 = self.layer2(layer1)
+         layer3 = self.layer3(layer2)
+         layer4 = self.layer4(layer3)
+
+         # layer4 = self.layer4_1x1(layer4)
+         x = self.upsample(layer4)
+         # layer3 = self.layer3_1x1(layer3)
+         x = torch.cat([x, layer3], dim=1)
+         x = self.conv_up3(x)
+         layer3_up = x
+
+         x = self.upsample(x)
+         # layer2 = self.layer2_1x1(layer2)
+         x = torch.cat([x, layer2], dim=1)
+         x = self.conv_up2(x)
+         layer2_up = x
+
+         x = self.upsample(x)
+         # layer1 = self.layer1_1x1(layer1)
+         x = torch.cat([x, layer1], dim=1)
+         x = self.conv_up1(x)
+
+         x = self.upsample(x)
+         # layer0 = self.layer0_1x1(layer0)
+         x = torch.cat([x, layer0], dim=1)
+         x = self.conv_up0(x)
+
+         x = self.upsample(x)
+         x = torch.cat([x, x_original], dim=1)
+         x = self.conv_original_size2(x)
+
+         out = self.conv_last(x)
+         out = out.sigmoid().squeeze(1)
+
+         # xs = {"0": layer2, "1": layer3, "2": layer4}
+         xs = {"0": layer2_up, "1": layer3_up, "2": layer4}
+         mask = torch.zeros(inputs.shape)[:, 0, :, :].to(layer4.device)
+         # ms_feats = self.ms_feat(xs, mask)
+
+         if self.return_ms_feat:
+             if self.out_dim:
+                 out_feat = self.conv_out(x)
+                 out_feat = out_feat.permute(0, 2, 3, 1)
+                 return xs, mask, out, out_feat
+             else:
+                 return xs, mask, out
+         else:
+             return out
+
+     def train(self, mode=True):
+         # Override train so that the training mode is set as we want
+         nn.Module.train(self, mode)
+         if mode:
+             # fix all bn layers
+             def set_bn_eval(m):
+                 classname = m.__class__.__name__
+                 if classname.find('BatchNorm') != -1:
+                     m.eval()
+
+             self.apply(set_bn_eval)
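
A quick shape check (a sketch; the batch size, the 256x256 input, and running with the repo root on PYTHONPATH are assumptions) confirms the strides and channel counts recorded in the constructors above:

    import torch
    from models.resnet import ResNetBackbone  # assumes repo root on sys.path

    backbone = ResNetBackbone()
    xs, mask, all_feats = backbone(torch.randn(1, 3, 256, 256))
    # "0"/"1"/"2" map to layer2/layer3/layer4: strides 8/16/32,
    # channels 512/1024/2048 for a ResNet-50 trunk
    for key, stride, ch in zip(["0", "1", "2"], backbone.strides, backbone.num_channels):
        n, c, h, w = xs[key].shape
        assert (c, h, w) == (ch, 256 // stride, 256 // stride)
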
models/stacked_hg.py ADDED
@@ -0,0 +1,246 @@
+ """
+ Hourglass network inserted in the pre-activated Resnet
+ Use lr=0.01 for current version
+ (c) Nan Xue (HAWP)
+ (c) Yichao Zhou (LCNN)
+ (c) YANG, Wei
+ """
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ __all__ = ["HourglassNet", "hg"]
+
+
+ class Bottleneck2D(nn.Module):
+     expansion = 2
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(Bottleneck2D, self).__init__()
+
+         self.bn1 = nn.BatchNorm2d(inplanes)
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
+         self.bn3 = nn.BatchNorm2d(planes)
+         self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.bn1(x)
+         out = self.relu(out)
+         out = self.conv1(out)
+
+         out = self.bn2(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+
+         out = self.bn3(out)
+         out = self.relu(out)
+         out = self.conv3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+
+         return out
+
+
+ class Hourglass(nn.Module):
+     def __init__(self, block, num_blocks, planes, depth):
+         super(Hourglass, self).__init__()
+         self.depth = depth
+         self.block = block
+         self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
+
+     def _make_residual(self, block, num_blocks, planes):
+         layers = []
+         for i in range(0, num_blocks):
+             layers.append(block(planes * block.expansion, planes))
+         return nn.Sequential(*layers)
+
+     def _make_hour_glass(self, block, num_blocks, planes, depth):
+         hg = []
+         for i in range(depth):
+             res = []
+             for j in range(3):
+                 res.append(self._make_residual(block, num_blocks, planes))
+             if i == 0:
+                 res.append(self._make_residual(block, num_blocks, planes))
+             hg.append(nn.ModuleList(res))
+         return nn.ModuleList(hg)
+
+     def _hour_glass_forward(self, n, x):
+         up1 = self.hg[n - 1][0](x)
+         low1 = F.max_pool2d(x, 2, stride=2)
+         low1 = self.hg[n - 1][1](low1)
+
+         if n > 1:
+             low2 = self._hour_glass_forward(n - 1, low1)
+         else:
+             low2 = self.hg[n - 1][3](low1)
+         low3 = self.hg[n - 1][2](low2)
+         up2 = F.interpolate(low3, scale_factor=2)
+         out = up1 + up2
+         return out
+
+     def forward(self, x):
+         return self._hour_glass_forward(self.depth, x)
+
+
+ class HourglassNet(nn.Module):
+     """Hourglass model from Newell et al ECCV 2016"""
+
+     def __init__(self, inplanes, num_feats, block, head, depth, num_stacks, num_blocks, num_classes):
+         super(HourglassNet, self).__init__()
+
+         self.inplanes = inplanes
+         self.num_feats = num_feats
+         self.num_stacks = num_stacks
+         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
+         self.bn1 = nn.BatchNorm2d(self.inplanes)
+         self.relu = nn.ReLU(inplace=True)
+         self.layer1 = self._make_residual(block, self.inplanes, 1)
+         self.layer2 = self._make_residual(block, self.inplanes, 1)
+         self.layer3 = self._make_residual(block, self.num_feats, 1)
+         self.maxpool = nn.MaxPool2d(2, stride=2)
+
+         # build hourglass modules
+         ch = self.num_feats * block.expansion
+         # vpts = []
+         hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
+         for i in range(num_stacks):
+             hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
+             res.append(self._make_residual(block, self.num_feats, num_blocks))
+             fc.append(self._make_fc(ch, ch))
+             score.append(head(ch, num_classes))
+             # vpts.append(VptsHead(ch))
+             # vpts.append(nn.Linear(ch, 9))
+             # score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
+             # score[i].bias.data[0] += 4.6
+             # score[i].bias.data[2] += 4.6
+             if i < num_stacks - 1:
+                 fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
+                 score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
+         self.hg = nn.ModuleList(hg)
+         self.res = nn.ModuleList(res)
+         self.fc = nn.ModuleList(fc)
+         self.score = nn.ModuleList(score)
+         # self.vpts = nn.ModuleList(vpts)
+         self.fc_ = nn.ModuleList(fc_)
+         self.score_ = nn.ModuleList(score_)
+
+     def _make_residual(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(
+                     self.inplanes,
+                     planes * block.expansion,
+                     kernel_size=1,
+                     stride=stride,
+                 )
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def _make_fc(self, inplanes, outplanes):
+         bn = nn.BatchNorm2d(inplanes)
+         conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
+         return nn.Sequential(conv, bn, self.relu)
+
+     def forward(self, x):
+         out = []
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+
+         x = self.layer1(x)
+         x = self.maxpool(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+
+         for i in range(self.num_stacks):
+             y = self.hg[i](x)
+             y = self.res[i](y)
+             y = self.fc[i](y)
+             score = self.score[i](y)
+             out.append(score)
+
+             if i < self.num_stacks - 1:
+                 fc_ = self.fc_[i](y)
+                 score_ = self.score_[i](score)
+                 x = x + fc_ + score_
+
+         return out[::-1], y
+
+     def train(self, mode=True):
+         # Override train so that the training mode is set as we want
+         nn.Module.train(self, mode)
+         if mode:
+             # fix all bn layers
+             def set_bn_eval(m):
+                 classname = m.__class__.__name__
+                 if classname.find('BatchNorm') != -1:
+                     m.eval()
+
+             self.apply(set_bn_eval)
+
+
+ class MultitaskHead(nn.Module):
+     def __init__(self, input_channels, num_class, head_size):
+         super(MultitaskHead, self).__init__()
+
+         m = int(input_channels / 4)
+         heads = []
+         for output_channels in sum(head_size, []):
+             heads.append(
+                 nn.Sequential(
+                     nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+                     nn.ReLU(inplace=True),
+                     nn.Conv2d(m, output_channels, kernel_size=1),
+                 )
+             )
+         self.heads = nn.ModuleList(heads)
+         assert num_class == sum(sum(head_size, []))
+
+     def forward(self, x):
+         return torch.cat([head(x) for head in self.heads], dim=1)
+
+
+ def build_hg():
+     inplanes = 64
+     num_feats = 256 // 2
+     depth = 4
+     num_stacks = 2
+     num_blocks = 1
+     head_size = [[2], [2]]
+
+     out_feature_channels = 256
+
+     num_class = sum(sum(head_size, []))
+     model = HourglassNet(
+         block=Bottleneck2D,
+         inplanes=inplanes,
+         num_feats=num_feats,
+         depth=depth,
+         head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size),
+         num_stacks=num_stacks,
+         num_blocks=num_blocks,
+         num_classes=num_class)
+
+     model.out_feature_channels = out_feature_channels
+
+     return model
+
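
As a sanity check of build_hg() (a sketch; the 256x256 input is an arbitrary choice): the stem's stride-2 conv plus the max-pool give 1/4-resolution outputs, each stack emits a 4-channel score map (head_size [[2], [2]]), and the returned feature map has 256 channels (num_feats * block expansion):

    import torch
    from models.stacked_hg import build_hg  # assumes repo root on sys.path

    model = build_hg()
    scores, feats = model(torch.randn(1, 3, 256, 256))
    assert len(scores) == 2                    # one score map per stack, last stack first
    assert scores[0].shape == (1, 4, 64, 64)   # num_class = sum(sum(head_size, []))
    assert feats.shape == (1, 256, 64, 64)     # out_feature_channels
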
predict.py ADDED
@@ -0,0 +1,33 @@
+ '''
+ Author: [egrt]
+ Date: 2022-08-23 13:21:27
+ LastEditors: [egrt]
+ LastEditTime: 2022-08-23 13:45:21
+ Description:
+ '''
+ #--------------------------------------------------------------#
+ #   Run prediction on a single image; the result is saved
+ #   under the project root (by default as
+ #   results/predict_out/predict_srgan.png).
+ #--------------------------------------------------------------#
+ from PIL import Image
+
+ from HEAT import HEAT
+
+ if __name__ == "__main__":
+     heat = HEAT()
+     #----------------------------#
+     #   Save path for the single output image
+     #----------------------------#
+     save_path = "assets/test_out.jpg"
+
+     while True:
+         img = input('Input image filename:')
+         try:
+             image = Image.open(img)
+         except:
+             print('Open Error! Try again!')
+             continue
+         else:
+             r_image = heat.detect_one_image(image)
+             r_image.save(save_path)
+             r_image.show()
qualitative_outdoor/generate_html.py ADDED
@@ -0,0 +1,64 @@
+ import os
+ import os.path as osp
+ import numpy as np
+
+
+ head = '''
+ <html>
+ <head>
+ <style>
+ td {text-align: center;}
+ </style>
+ </head>
+ <p>
+ </p>
+ <br>
+ <table border="1">
+ '''
+
+ end = '''
+ </table>
+ <br>
+ </html>
+ '''
+
+ def writeHTML(out_path, results_dirs):
+     f = open(out_path, 'w')
+     f.write(head + '\n')
+     f.write('<tr>'
+             '<td style="background-color:#FFFFFF"> ID </td> '
+             '<td style="background-color:#FFFFFF"> Input </td> '
+             '<td style="background-color:#FFFFFF"> ConvMPN </td> '
+             '<td style="background-color:#FFFFFF"> Exp-cls </td> '
+             '<td style="background-color:#FFFFFF"> HAWP </td> '
+             '<td style="background-color:#FFFFFF"> LETR </td> '
+             '<td style="background-color:#FFFFFF"> HEAT (Ours) </td> '
+             '<td style="background-color:#FFFFFF"> G.T. </td> '
+             '</tr>')
+
+     fileids_path = '../data/cities_dataset/valid_list.txt'
+     img_base = '../data/cities_dataset/rgb'
+     with open(fileids_path) as ff:
+         file_ids = ff.readlines()
+     file_ids = file_ids[50:]
+     file_ids = [file_id.strip() for file_id in file_ids]
+     permuted_ids = np.random.permutation(file_ids)
+     file_ids = permuted_ids[:100]
+
+     for file_id in file_ids:
+         row_str = '<tr>'
+         row_str += '<td> {} </td>'.format(file_id)
+         row_str += '<td> <img src="{}" width="180"> </td>'.format(os.path.join(img_base, file_id + '.jpg'))
+         for dir_idx, result_dir in enumerate(results_dirs):
+             pred_filepath = osp.join(result_dir, '{}.png'.format(file_id))
+             row_str += '<td> <img src="{}" width="180"> </td>'.format(pred_filepath)
+         row_str += '</tr>'
+         f.write(row_str + '\n')
+
+     f.write(end + '\n')
+
+
+ if __name__ == '__main__':
+     results_dirs = ['svg_images_256/convmpn', 'svg_images_256/exp_cls', 'svg_images_256/hawp', 'svg_images_256/letr', 'svg_images_256/heat', 'svg_images_256/gt']
+
+     writeHTML(out_path='./outdoor_qual.html', results_dirs=results_dirs)
qualitative_outdoor/plot_utils.py ADDED
@@ -0,0 +1,43 @@
1
+ import cv2
2
+ import svgwrite
3
+ import colorsys
4
+
5
+
6
+ def plot_preds(image, corners, edges):
7
+ for line in edges:
8
+ cv2.line(image, tuple(line[:2]), tuple(line[2:]), (255, 255, 0), 2)
9
+ for c in corners:
10
+ cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
11
+ return image
12
+
13
+
14
+ def random_colors(N, bright=True, same=False, colors=None):
15
+ brightness = 1.0 if bright else 0.7
16
+ if colors is None or same:
17
+ if same:
18
+ hsv = [(0, 1, brightness) for i in range(N)]
19
+ else:
20
+ hsv = [(i / N, 1, brightness) for i in range(N)]
21
+ else:
22
+ hsv = [(colors[i], 1, brightness) for i in range(N)]
23
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
24
+ return colors
25
+
26
+
27
+ def svg_generate(image_link, corners, edges, name, size=512):
28
+ dwg = svgwrite.Drawing(name + '.svg', size=('{}'.format(size), '{}'.format(size)))
29
+ shapes = dwg.add(dwg.g(id='shape', fill='black'))
30
+ # colors = random_colors(len(edges), same=True)
31
+ shapes.add(dwg.image(href=image_link, size=(size, size)))
32
+
33
+ scale = size / 256
34
+ for i, edge in enumerate(edges):
35
+ x = edge[:2] * scale
36
+ y = edge[2:] * scale
37
+ shapes.add(dwg.line((int(x[0]), int(x[1])), (int(y[0]), int(y[1])),
38
+ stroke="#EE6507", stroke_width=3*scale, opacity=0.7))
39
+
40
+ for i, corner in enumerate(corners):
41
+ shapes.add(dwg.circle((int(corner[0] * scale), int(corner[1] * scale)), r=4*scale,
42
+ stroke='green', fill='white', stroke_width=2*scale, opacity=0.8))
43
+ return dwg
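`svg_generate` returns an `svgwrite.Drawing`; rasterizing it goes through a temporary .svg file with cairosvg, exactly as the visualization scripts below do. A minimal usage sketch with made-up 256-pixel coordinates:

    import numpy as np
    import cairosvg
    from plot_utils import svg_generate

    corners = np.array([[60, 60], [200, 60], [200, 200]])       # toy corner coordinates
    edges = np.array([[60, 60, 200, 60], [200, 60, 200, 200]])  # rows of (x1, y1, x2, y2)
    dwg = svg_generate('input.jpg', corners, edges, name='temp', size=256)
    dwg.saveas('temp.svg')                                      # write the SVG to disk
    cairosvg.svg2png(url='temp.svg', write_to='preview.png')    # rasterize to PNG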
qualitative_outdoor/visualize_gt.py ADDED
@@ -0,0 +1,46 @@
1
+ import os
2
+ import json
3
+ import cv2
4
+ import numpy as np
5
+ from plot_utils import plot_preds, svg_generate
6
+ import cairosvg
7
+
8
+ image_base = '../data/cities_dataset/rgb/'
9
+ annot_base = '../data/cities_dataset/annot/'
10
+ data_filename = '../data/cities_dataset/valid_list.txt'
11
+ with open(data_filename) as f:
12
+ filenames = f.readlines()
13
+
14
+ filenames = filenames[50:]
15
+ filenames = [filename.strip() for filename in filenames]
16
+
17
+
18
+ for filename in filenames:
19
+ image_path = os.path.join(image_base, filename + '.jpg')
20
+ # image = cv2.imread(image_path)
21
+ annot_path = os.path.join(annot_base, filename + '.npy')
22
+
23
+ annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()
24
+ corners = np.array(list(annot.keys())).astype(int)
25
+
26
+ edges = set()
27
+ for c, others in annot.items():
28
+ for other_c in others:
29
+ edge = (c[0], c[1], other_c[0], other_c[1])
30
+ edge_2 = (other_c[0], other_c[1], c[0], c[1])
31
+ if edge not in edges and edge_2 not in edges:
32
+ edges.add(edge)
33
+
34
+ edges = np.array(list(edges)).astype(int)
35
+
36
+ # image = plot_preds(image, corners, edges)
37
+ # out_path = os.path.join(out_base, filename + '.png')
38
+ # cv2.imwrite(out_path, image)
39
+
40
+ svg = svg_generate(image_path, corners, edges, name='temp', size=256)
41
+ os.makedirs('./svg_results', exist_ok=True)
+ os.makedirs('./svg_images_256/gt', exist_ok=True)
+ svg_path = './svg_results/' + 'tmp.svg'
42
+ svg.saveas(svg_path)
43
+ svg_img_path = './svg_images_256/gt/' + '{}.png'.format(filename)
44
+ cairosvg.svg2png(url=svg_path, write_to=svg_img_path)
45
+
46
+
qualitative_outdoor/visualize_npy.py ADDED
@@ -0,0 +1,46 @@
1
+ import os
2
+ import json
3
+ import cv2
4
+ import numpy as np
5
+ import cairosvg
6
+ from plot_utils import plot_preds, svg_generate
7
+
8
+ image_base = '../../data/outdoor/cities_dataset/rgb/'
9
+ svg_base = './svg_results'
10
+
11
+ if not os.path.exists(svg_base):
12
+ os.makedirs(svg_base)
13
+
14
+ data_filename = '../data/outdoor/cities_dataset/valid_list.txt'
15
+ with open(data_filename) as f:
16
+ filenames = f.readlines()
17
+
18
+ filenames = filenames[50:] # according to previous works, the testing samples are the last 350 samples of the val split
19
+ filenames = [filename.strip() for filename in filenames]
20
+ idx_to_filename = {idx: filename for idx, filename in enumerate(filenames)}
21
+
22
+ method_name = 'heat'
23
+ results_base = '../results/npy_outdoor_test_256/'
24
+
25
+ svg_method_base = os.path.join(svg_base, method_name)
26
+ if not os.path.exists(svg_method_base):
27
+ os.makedirs(svg_method_base)
28
+
29
+ for result_filename in sorted(os.listdir(results_base)):
30
+ file_idx = int(result_filename[:-12])  # strip the 12-char suffix (likely '_results.npy')
31
+ filename = idx_to_filename[file_idx]
32
+
33
+ image_path = os.path.join(image_base, filename + '.jpg')
34
+
35
+ results_path = os.path.join(results_base, result_filename)
36
+ results = np.load(results_path, allow_pickle=True).tolist()
37
+ corners = results['corners'].astype(int)
38
+ edge_ids = results['edges']
39
+ edges = corners[edge_ids].reshape(edge_ids.shape[0], -1)
40
+
41
+ svg = svg_generate(image_path, corners, edges, name='temp', size=256)
42
+ svg_path = os.path.join(svg_base, 'tmp.svg')
43
+ svg.saveas(svg_path) # save the svg file temporarily
44
+
45
+ svg_img_path = os.path.join(svg_method_base, '{}.png'.format(filename))
46
+ cairosvg.svg2png(url=svg_path, write_to=svg_img_path)
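The `*.npy` result files consumed above are pickled dicts with a `corners` array (N x 2 coordinates) and an `edges` array (M x 2 indices into `corners`). A minimal writer sketch for a compatible file — the `_results.npy` suffix is an inference from the 12-character slice above:

    import numpy as np

    corners = np.array([[60.0, 60.0], [200.0, 60.0], [200.0, 200.0]])
    edges = np.array([[0, 1], [1, 2]])  # each row indexes two corners
    np.save('../results/npy_outdoor_test_256/0_results.npy',
            {'corners': corners, 'edges': edges})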
requirements.txt ADDED
@@ -0,0 +1,27 @@
1
+ Cython==0.29.22
2
+ defusedxml==0.6.0
3
+ einops==0.4.1
4
+ future==0.18.2
5
+ imageio==2.16.1
6
+ matplotlib==3.3.4
7
+ MultiScaleDeformableAttention==1.0  # built locally from models/ops, not a standard PyPI package
8
+ numpy==1.20.1
9
+ opencv-python==4.4.0.44
10
+ packaging==20.9
11
+ Pillow==9.0.1
12
+ prometheus-client==0.9.0
13
+ prompt-toolkit==3.0.16
14
+ ptyprocess==0.7.0
15
+ pycparser==2.20
16
+ Pygments==2.8.0
17
+ python-dateutil==2.8.1
18
+ scikit-image==0.19.2
19
+ scikit-learn==1.0
20
+ scipy==1.6.1
21
+ six==1.15.0
22
+ torch==1.5.1
23
+ torchvision==0.6.1
24
+ cairosvg==2.5.2
25
+ svgwrite==1.4.2
26
+ shapely==1.8.2
27
+ gradio==2.5.3
s3d_floorplan_eval/DataRW/DataRW.py ADDED
@@ -0,0 +1,4 @@
1
+
2
+ class DataRW:
3
+ def __init__(self, options):
4
+ pass
s3d_floorplan_eval/DataRW/S3DRW.py ADDED
@@ -0,0 +1,142 @@
1
+ import numpy as np
2
+ import cv2
3
+ import torch
4
+ import os
5
+ import time
6
+
7
+ from DataRW.DataRW import DataRW
8
+ from S3DLoader.S3DLoader import S3DLoader
9
+
10
+ class S3DRW(DataRW):
11
+ def __init__(self, options):
12
+ """
13
+ Class for accessing Structured3D (S3D) dataset related data
14
+
15
+ :param options:
16
+ """
17
+ # initialize the base class variables
18
+ super(S3DRW, self).__init__(options)
19
+
20
+ self.options = options
21
+
22
+ self.dataset_path = options.dataset_path
23
+ self.scene_id = options.scene_id
24
+
25
+ self.mcts_path = options.mcts_path
26
+ self.creation_time = int(time.time())
27
+
28
+ self.device = torch.device("cpu")
29
+
30
+ # mode = "train"
31
+ # mode = "online_eval"
32
+ mode = "test"
33
+ # For validation only
34
+ # self.loader = S3DLoader(options, 'online_eval').dataset
35
+ self.loader = S3DLoader(options, mode).dataset
36
+
37
+ # gt_sample = iter(floornet_loader.dataset[int(self.scene_id)])
38
+ # self.gt_sample = floornet_loader.load_sample(list(iter(floornet_loader.dataset))[int(self.scene_id)])
39
+
40
+ if mode == "online_eval":
41
+ scene_ind = int(self.scene_id[6:]) - 3000
42
+ elif mode == "test":
43
+ scene_ind = int(self.scene_id[6:]) - 3250
44
+ elif mode == "train":
45
+ scene_ind = int(self.scene_id[6:])
46
+ else:
47
+ raise ValueError("unknown mode: %s" % mode)
48
+
49
+ # print(len(list(iter(self.s3d_loader.data))))
50
+ self.gt_sample = gt_sample = self.loader[scene_ind]
51
+ self.gt_sample["density_map"] = torch.tensor(self.gt_sample["density_map"][None], device=self.device)
52
+ self.gt_sample["room_map"] = torch.tensor(self.gt_sample["room_map"][None,:,:,None], device=self.device)
53
+ self.gt_sample["wall_map"] = torch.tensor(self.gt_sample["wall_map"][None,:,:,None], device=self.device)
54
+
55
+
56
+ self.density_map = self.gt_sample['density_map'][:,:,:,None]
57
+
58
+ self.h, self.w = self.density_map.shape[1], self.density_map.shape[2]
59
+
60
+ self.generate_input_map_from_props = self.generate_input_dict_from_room_props
61
+
62
+ def get_gt_solution(self):
63
+ """
64
+ Read top-view density map of the scene
65
+
66
+ :return:
67
+ """
68
+ img_path = os.path.join(self.dataset_path, str(self.scene_id) + "_density.png")
69
+ density_map = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)[:,:, 0][None,:,:,None]
70
+
71
+ density_map = torch.from_numpy(density_map).to(self.device)
72
+
73
+ dm_min = torch.min(density_map)
74
+ dm_max = torch.max(density_map)
75
+
76
+ density_map = (density_map - dm_min) / (dm_max - dm_min)
77
+
78
+ return density_map.type(torch.cuda.FloatTensor)
79
+
80
+ def polygonize_mask(self, pm, return_mask=True):
81
+ pm_np = pm.cpu().numpy()
82
+
83
+ room_mask = 255 * (pm_np == 1)
84
+ room_mask = room_mask.astype(np.uint8)
85
+ room_mask_inv = 255 - room_mask
86
+
87
+ ret, thresh = cv2.threshold(room_mask_inv, 250, 255, cv2.THRESH_BINARY_INV)
88
+
89
+ contours, hierarchy = cv2.findContours(thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
90
+
91
+ cnt = contours[0]
92
+ max_area = cv2.contourArea(cnt)
93
+
94
+ for cont in contours:
95
+ if cv2.contourArea(cont) > max_area:
96
+ cnt = cont
97
+ max_area = cv2.contourArea(cont)
98
+
99
+ # define main island contour approx. and hull
100
+ perimeter = cv2.arcLength(cnt, True)
101
+ epsilon = 0.01 * cv2.arcLength(cnt, True)
102
+ approx = cv2.approxPolyDP(cnt, epsilon, True)
103
+
104
+ # approx = np.concatenate([approx, approx[0][None]], axis=0)
105
+ approx = approx.astype(np.int32).reshape((1, -1, 2))
106
+
107
+ if return_mask:
108
+ room_filled_map = np.zeros((self.h, self.w))
109
+ cv2.fillPoly(room_filled_map, approx, color=1.)
110
+
111
+ room_filled_map = torch.tensor(room_filled_map[:,:], dtype=torch.float32, device=self.device)
112
+
113
+ return room_filled_map
114
+ else:
115
+ approx_tensor = torch.tensor(approx, device=self.device)
116
+ return approx_tensor
117
+
118
+ def generate_input_dict_from_room_props(self, room_prop_list, score_function, use_thresh=False):
119
+ """
120
+
121
+ :param room_prop_list:
122
+ :type room_prop_list: list of FloorPlanRoomProp
123
+ :param score_function:
124
+ :return:
125
+ """
126
+
127
+ if score_function == "room_maskrcnn_iou":
128
+ inputs = self.generate_input_dict_for_room_maskrcnn_iou(room_prop_list)
129
+ elif score_function == "room_iou":
130
+ inputs = self.generate_input_dict_for_room_iou(room_prop_list, use_thresh=use_thresh)
131
+ else:
132
+ assert "generate_input_dict_from_room_props for %s not implemented" % score_function
133
+
134
+ return inputs
135
+
136
+
137
+
138
+
139
+
140
+
141
+
142
+
s3d_floorplan_eval/DataRW/wrong_annotatios.py ADDED
@@ -0,0 +1 @@
1
+ wrong_s3d_annotations_list = [3261, 3271, 3276, 3296, 3342, 3387, 3398, 3466, 3496]
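A list like this is typically consumed by the evaluation loop to skip scenes with known-bad ground truth. A hypothetical sketch, assuming scene ids follow the `scene_0XXXX` pattern and the test-split offsets used by S3DRW above:

    from DataRW.wrong_annotatios import wrong_s3d_annotations_list

    for scene_num in range(3250, 3500):                # test-split scene numbers (assumed)
        if scene_num in wrong_s3d_annotations_list:
            continue                                   # skip scenes with wrong annotations
        scene_id = 'scene_%05d' % scene_num            # e.g. 'scene_03261'
        # ... build S3DRW(options) with this scene_id and evaluate ...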
s3d_floorplan_eval/Evaluator/Evaluator.py ADDED
@@ -0,0 +1,457 @@
1
+ import os
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import numpy as np
6
+ from scipy.spatial import Delaunay
8
+ import shapely
9
+ from shapely.geometry import Polygon, MultiPolygon, LineString, MultiLineString
10
+
11
+ corner_metric_thresh = 10
12
+ angle_metric_thresh = 5
13
+
14
+
15
+
16
+ # colormap_255 = [[i, i, i] for i in range(40)]
17
+
18
+ class Evaluator():
19
+ def __init__(self, data_rw, options):
20
+ self.data_rw = data_rw
21
+ self.options = options
22
+
23
+ self.device = torch.device("cuda")
24
+
25
+ def polygonize_mask(self, mask, degree, return_mask=True):
26
+ h, w = mask.shape[0], mask.shape[1]
28
+
29
+ room_mask = 255 * (mask == 1)
30
+ room_mask = room_mask.astype(np.uint8)
31
+ room_mask_inv = 255 - room_mask
32
+
33
+ ret, thresh = cv2.threshold(room_mask_inv, 250, 255, cv2.THRESH_BINARY_INV)
34
+
35
+ contours, hierarchy = cv2.findContours(thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
36
+
37
+ cnt = contours[0]
38
+ max_area = cv2.contourArea(cnt)
39
+
40
+ for cont in contours:
41
+ if cv2.contourArea(cont) > max_area:
42
+ cnt = cont
43
+ max_area = cv2.contourArea(cont)
44
+
45
+ perimeter = cv2.arcLength(cnt, True)
46
+ # epsilon = 0.01 * cv2.arcLength(cnt, True)
47
+ epsilon = degree * cv2.arcLength(cnt, True)
48
+ approx = cv2.approxPolyDP(cnt, epsilon, True)
49
+
50
+ # approx = np.concatenate([approx, approx[0][None]], axis=0)
51
+ approx = approx.astype(np.int32).reshape((-1, 2))
52
+
53
+ # approx_tensor = torch.tensor(approx, device=self.device)
54
+
55
+ # return approx_tensor
56
+ if return_mask:
57
+ room_filled_map = np.zeros((h, w))
58
+ cv2.fillPoly(room_filled_map, [approx], color=1.)
59
+
60
+ return approx, room_filled_map
61
+ else:
62
+ return approx
63
+
64
+ def print_res_str_for_latex(self, quant_result_dict):
65
+
66
+ str_fields = ""
67
+ str_values = ""
68
+
69
+ avg_value_prec = 0
70
+ avg_value_rec = 0
71
+ for k_ind, k in enumerate(quant_result_dict.keys()):
72
+ str_fields += " & " + k
73
+ str_values += " & %.2f " % quant_result_dict[k]
74
+
75
+ if k_ind % 2 == 0:
76
+ avg_value_prec += quant_result_dict[k] / 3  # three prec/rec pairs: room, corner, angle
77
+ else:
78
+ avg_value_rec += quant_result_dict[k] / 3
79
+
80
+ str_fields += "tm_prec & tm_rec"
81
+
82
+ str_values += " & %.2f " % avg_value_prec
83
+ str_values += " & %.2f " % avg_value_rec
84
+
85
+ str_fields += " \\\\"
86
+ str_values += " \\\\"
87
+
88
+ print(str_fields)
89
+ print(str_values)
90
+
91
+ def calc_gradient(self, room_map):
92
+ grad_x = np.abs(room_map[:, 1:] - room_map[:, :-1])
93
+ grad_y = np.abs(room_map[1:] - room_map[:-1])
94
+
95
+ grad_xy = np.zeros_like(room_map)
96
+ grad_xy[1:] = grad_y
97
+ grad_xy[:, 1:] = np.maximum(grad_x, grad_xy[:,1:])
98
+
99
+ plt.figure()
100
+ plt.axis("off")
101
+ plt.imshow(grad_xy, cmap="gray")
102
+ # plt.show()
103
+ plt.savefig("grad.png", bbox_inches='tight')
104
+
105
+ plt.figure()
106
+ plt.axis("off")
107
+ plt.imshow(room_map, cmap="gray")
108
+ # plt.show()
109
+ plt.savefig("joint_mask.png", bbox_inches='tight')
110
+ assert False  # stops execution after dumping the debug figures
111
+
112
+ def evaluate_scene(self, room_polys, show=False, name="ours", dataset_type="s3d"):
113
+
114
+ with torch.no_grad():
115
+ joint_room_map = np.zeros((self.options.height, self.options.width))
116
+
117
+ edge_map = np.zeros_like(joint_room_map)
118
+ room_filled_map = np.ones([joint_room_map.shape[0], joint_room_map.shape[1], 3])
119
+
120
+ density_map = self.data_rw.density_map.cpu().numpy()[0]
121
+ img_size = (density_map.shape[0], density_map.shape[0])
122
+
123
+ for room_ind, poly in enumerate(room_polys):
124
+ cv2.polylines(edge_map, [poly], isClosed=True, color=1.)
125
+ cv2.fillPoly(joint_room_map, [poly], color=1.)
126
+
127
+ joint_room_map_vis = np.ones([joint_room_map.shape[0], joint_room_map.shape[1], 3])
128
+
129
+ # Ground Truth
130
+
131
+ gt_polys_list = self.data_rw.gt_sample["polygons_list"]
132
+ gt_polys_list = [np.concatenate([poly, poly[None, 0]]) for poly in gt_polys_list]
133
+
134
+ ignore_mask_region = self.data_rw.gt_sample["wall_map"].cpu().numpy()[0, :, :, 0]
135
+
136
+ img_size = (joint_room_map.shape[0], joint_room_map.shape[1])
137
+ quant_result_dict = self.get_quantitative(gt_polys_list, ignore_mask_region, room_polys, img_size=img_size, dataset_type=dataset_type)
138
+
139
+ return quant_result_dict
140
+
141
+ def get_quantitative(self, gt_polys, ignore_mask_region, pred_polys=None, masks_list=None, img_size=(256, 256), dataset_type="s3d"):
142
+ def get_room_metric():
143
+ pred_overlaps = [False] * len(pred_room_map_list)
144
+
145
+ for pred_ind1 in range(len(pred_room_map_list) - 1):
146
+ pred_map1 = pred_room_map_list[pred_ind1]
147
+
148
+ for pred_ind2 in range(pred_ind1 + 1, len(pred_room_map_list)):
149
+ pred_map2 = pred_room_map_list[pred_ind2]
150
+
151
+ if dataset_type == "s3d":
152
+ kernel = np.ones((5, 5), np.uint8)
153
+ else:
154
+ kernel = np.ones((3, 3), np.uint8)
155
+
156
+ # todo: for our method, the rooms share corners and edges, need to check here
157
+ pred_map1_er = cv2.erode(pred_map1, kernel)
158
+ pred_map2_er = cv2.erode(pred_map2, kernel)
159
+
160
+ intersection = (pred_map1_er + pred_map2_er) == 2
161
+ # intersection = (pred_map1 + pred_map2) == 2
162
+
163
+ intersection_area = np.sum(intersection)
164
+
165
+ if intersection_area >= 1:
166
+ pred_overlaps[pred_ind1] = True
167
+ pred_overlaps[pred_ind2] = True
168
+
169
+ # import pdb; pdb.set_trace()
170
+ room_metric = [bool((1 - pred_overlaps[ind]) * pred2gt_exists[ind]) for ind in range(len(pred_polys))]
171
+
172
+ return room_metric
173
+
174
+ def get_corner_metric():
175
+
176
+ room_corners_metric = []
177
+ for pred_poly_ind, gt_poly_ind in enumerate(pred2gt_indices):
178
+ p_poly = pred_polys[pred_poly_ind][:-1] # Last vertex = First vertex
179
+
180
+ p_poly_corner_metrics = [False] * p_poly.shape[0]
181
+ if not room_metric[pred_poly_ind]:
182
+ room_corners_metric += p_poly_corner_metrics
183
+ continue
184
+
185
+ gt_poly = gt_polys[gt_poly_ind][:-1]
186
+
187
+ # for v in p_poly:
188
+ # v_dists = np.linalg.norm(v[None,:] - gt_poly, axis=1, ord=2)
189
+ # v_min_dist = np.min(v_dists)
190
+ #
191
+ # v_tp = v_min_dist <= 10
192
+ # room_corners_metric.append(v_tp)
193
+
194
+ for v in gt_poly:
195
+ v_dists = np.linalg.norm(v[None,:] - p_poly, axis=1, ord=2)
196
+ v_min_dist_ind = np.argmin(v_dists)
197
+ v_min_dist = v_dists[v_min_dist_ind]
198
+
199
+ if not p_poly_corner_metrics[v_min_dist_ind]:
200
+ v_tp = v_min_dist <= corner_metric_thresh
201
+ p_poly_corner_metrics[v_min_dist_ind] = v_tp
202
+
203
+ room_corners_metric += p_poly_corner_metrics
204
+
205
+ return room_corners_metric
206
+
207
+ def get_angle_metric():
208
+
209
+ def get_line_vector(p1, p2):
210
+ p1 = np.concatenate((p1, np.array([1])))
211
+ p2 = np.concatenate((p2, np.array([1])))
212
+
213
+ line_vector = -np.cross(p1, p2)
214
+
215
+ return line_vector
216
+
217
+ def get_poly_orientation(my_poly):
218
+ angles_sum = 0
219
+ for v_ind, _ in enumerate(my_poly):
220
+ if v_ind < len(my_poly) - 1:
221
+ v_sides = my_poly[[v_ind - 1, v_ind, v_ind, v_ind + 1], :]
222
+ else:
223
+ v_sides = my_poly[[v_ind - 1, v_ind, v_ind, 0], :]
224
+
225
+ v1_vector = get_line_vector(v_sides[0], v_sides[1])
226
+ v1_vector = v1_vector / (np.linalg.norm(v1_vector, ord=2) + 1e-4)
227
+ v2_vector = get_line_vector(v_sides[2], v_sides[3])
228
+ v2_vector = v2_vector / (np.linalg.norm(v2_vector, ord=2) + 1e-4)
229
+
230
+ orientation = (v_sides[1, 1] - v_sides[0, 1]) * (v_sides[3, 0] - v_sides[1, 0]) - (
231
+ v_sides[3, 1] - v_sides[1, 1]) * (
232
+ v_sides[1, 0] - v_sides[0, 0])
233
+
234
+ v1_vector_2d = v1_vector[:2] / (v1_vector[2] + 1e-4)
235
+ v2_vector_2d = v2_vector[:2] / (v2_vector[2] + 1e-4)
236
+
237
+ v1_vector_2d = v1_vector_2d / (np.linalg.norm(v1_vector_2d, ord=2) + 1e-4)
238
+ v2_vector_2d = v2_vector_2d / (np.linalg.norm(v2_vector_2d, ord=2) + 1e-4)
239
+
240
+ angle_cos = v1_vector_2d.dot(v2_vector_2d)
241
+ angle_cos = np.clip(angle_cos, -1, 1)
242
+
243
+ # G.T. has clockwise orientation, remove minus in the equation
244
+
245
+ angle = np.sign(orientation) * np.abs(np.arccos(angle_cos))
246
+ angle_degree = angle * 180 / np.pi
247
+
248
+ angles_sum += angle_degree
249
+
250
+ return np.sign(angles_sum)
251
+
252
+ def get_angle_v_sides(inp_v_sides, poly_orient):
253
+ v1_vector = get_line_vector(inp_v_sides[0], inp_v_sides[1])
254
+ v1_vector = v1_vector / (np.linalg.norm(v1_vector, ord=2) + 1e-4)
255
+ v2_vector = get_line_vector(inp_v_sides[2], inp_v_sides[3])
256
+ v2_vector = v2_vector / (np.linalg.norm(v2_vector, ord=2) + 1e-4)
257
+
258
+ orientation = (inp_v_sides[1, 1] - inp_v_sides[0, 1]) * (inp_v_sides[3, 0] - inp_v_sides[1, 0]) - (
259
+ inp_v_sides[3, 1] - inp_v_sides[1, 1]) * (
260
+ inp_v_sides[1, 0] - inp_v_sides[0, 0])
261
+
262
+ v1_vector_2d = v1_vector[:2] / (v1_vector[2]+ 1e-4)
263
+ v2_vector_2d = v2_vector[:2] / (v2_vector[2]+ 1e-4)
264
+
265
+ v1_vector_2d = v1_vector_2d / (np.linalg.norm(v1_vector_2d, ord=2) + 1e-4)
266
+ v2_vector_2d = v2_vector_2d / (np.linalg.norm(v2_vector_2d, ord=2) + 1e-4)
267
+
268
+ angle_cos = v1_vector_2d.dot(v2_vector_2d)
269
+ angle_cos = np.clip(angle_cos, -1, 1)
270
+
271
+ angle = poly_orient * np.sign(orientation) * np.arccos(angle_cos)
272
+ angle_degree = angle * 180 / np.pi
273
+
274
+ return angle_degree
275
+
276
+ room_angles_metric = []
277
+ for pred_poly_ind, gt_poly_ind in enumerate(pred2gt_indices):
278
+ p_poly = pred_polys[pred_poly_ind][:-1] # Last vertex = First vertex
279
+
280
+ p_poly_angle_metrics = [False] * p_poly.shape[0]
281
+ if not room_metric[pred_poly_ind]:
282
+ room_angles_metric += p_poly_angle_metrics
283
+ continue
284
+
285
+ gt_poly = gt_polys[gt_poly_ind][:-1]
286
+
287
+ # for v in p_poly:
288
+ # v_dists = np.linalg.norm(v[None,:] - gt_poly, axis=1, ord=2)
289
+ # v_min_dist = np.min(v_dists)
290
+ #
291
+ # v_tp = v_min_dist <= 10
292
+ # room_corners_metric.append(v_tp)
293
+
294
+ gt_poly_orient = get_poly_orientation(gt_poly)
295
+ p_poly_orient = get_poly_orientation(p_poly)
296
+
297
+ for v_gt_ind, v in enumerate(gt_poly):
298
+ v_dists = np.linalg.norm(v[None,:] - p_poly, axis=1, ord=2)
299
+ v_ind = np.argmin(v_dists)
300
+ v_min_dist = v_dists[v_ind]
301
+
302
+ if v_min_dist > corner_metric_thresh:
303
+ # room_angles_metric.append(False)
304
+ continue
305
+
306
+ if v_ind < len(p_poly) - 1:
307
+ v_sides = p_poly[[v_ind - 1, v_ind, v_ind, v_ind + 1], :]
308
+ else:
309
+ v_sides = p_poly[[v_ind - 1, v_ind, v_ind, 0], :]
310
+
311
+ v_sides = v_sides.reshape((4,2))
312
+ pred_angle_degree = get_angle_v_sides(v_sides, p_poly_orient)
313
+
314
+ # Note: replacing some variables with values from the g.t. poly
315
+
316
+ if v_gt_ind < len(gt_poly) - 1:
317
+ v_sides = gt_poly[[v_gt_ind - 1, v_gt_ind, v_gt_ind, v_gt_ind + 1], :]
318
+ else:
319
+ v_sides = gt_poly[[v_gt_ind - 1, v_gt_ind, v_gt_ind, 0], :]
320
+
321
+ v_sides = v_sides.reshape((4, 2))
322
+ gt_angle_degree = get_angle_v_sides(v_sides, gt_poly_orient)
323
+
324
+ angle_metric = np.abs(pred_angle_degree - gt_angle_degree)
325
+
326
+ # room_angles_metric.append(angle_metric < 5)
327
+ p_poly_angle_metrics[v_ind] = angle_metric <= angle_metric_thresh
328
+
329
+ # if angle_metric > 5:
330
+ # print(v_gt_ind, angle_metric)
331
+ # print(pred_angle_degree, gt_angle_degree)
332
+ # input("?")
333
+
334
+
335
+ room_angles_metric += p_poly_angle_metrics
336
+
337
+ for am, cm in zip(room_angles_metric, corner_metric):
338
+ assert not (cm == False and am == True), "cm: %d am: %d" %(cm, am)
339
+
340
+ return room_angles_metric
341
+
342
+ def poly_map_sort_key(x):
343
+ return np.sum(x[1])
344
+
345
+ h, w = img_size
346
+
347
+ gt_room_map_list = []
348
+ for room_ind, poly in enumerate(gt_polys):
349
+ room_map = np.zeros((h, w))
350
+ cv2.fillPoly(room_map, [poly], color=1.)
351
+
352
+ gt_room_map_list.append(room_map)
353
+
354
+ gt_polys_sorted_indcs = [i[0] for i in sorted(enumerate(gt_room_map_list), key=poly_map_sort_key, reverse=True)]
355
+
356
+ gt_polys = [gt_polys[ind] for ind in gt_polys_sorted_indcs]
357
+ gt_room_map_list = [gt_room_map_list[ind] for ind in gt_polys_sorted_indcs]
358
+
359
+ if pred_polys is not None:
360
+ pred_room_map_list = []
361
+ for room_ind, poly in enumerate(pred_polys):
362
+ room_map = np.zeros((h, w))
363
+ cv2.fillPoly(room_map, [poly], color=1.)
364
+
365
+ pred_room_map_list.append(room_map)
366
+ else:
367
+ pred_room_map_list = masks_list
368
+
369
+ gt2pred_indices = [-1] * len(gt_polys)
370
+ gt2pred_exists = [False] * len(gt_polys)
371
+
372
+ for gt_ind, gt_map in enumerate(gt_room_map_list):
373
+
374
+ best_iou = 0.
375
+ best_ind = -1
376
+ for pred_ind, pred_map in enumerate(pred_room_map_list):
377
+
378
+ intersection = (1 - ignore_mask_region) * ((pred_map + gt_map) == 2)
379
+ union = (1 - ignore_mask_region) * ((pred_map + gt_map) >= 1)
380
+
381
+ iou = np.sum(intersection) / (np.sum(union) + 1)
382
+
383
+ if iou > best_iou and iou > 0.5:
384
+ best_iou = iou
385
+ best_ind = pred_ind
386
+
387
+ # plt.figure()
388
+ # plt.subplot(121)
389
+ # plt.imshow(pred_map)
390
+ # plt.subplot(122)
391
+ # plt.imshow(gt_map)
392
+ # plt.show()
393
+ # if best_ind == -1:
394
+ # plt.figure()
395
+ # plt.imshow(gt_map)
396
+ # plt.show()
397
+
398
+ gt2pred_indices[gt_ind] = best_ind
399
+ gt2pred_exists[gt_ind] = best_ind != -1
400
+
401
+ # if best_ind == -1:
402
+ # plt.figure()
403
+ # plt.imshow(gt_map)
404
+ # plt.show()
405
+
406
+ pred2gt_exists = [True if pred_ind in gt2pred_indices else False for pred_ind, _ in enumerate(pred_polys)]
407
+ pred2gt_indices = [gt2pred_indices.index(pred_ind) if pred_ind in gt2pred_indices else -1 for pred_ind, _ in enumerate(pred_polys)]
408
+
409
+ # print(gt2pred_indices)
410
+ # print(pred2gt_indices)
411
+ # assert False
412
+
413
+ # import pdb; pdb.set_trace()
414
+ room_metric = get_room_metric()
415
+ if len(pred_polys) == 0:
416
+ room_metric_prec = 0
417
+ else:
418
+ room_metric_prec = sum(room_metric) / float(len(pred_polys))
419
+ room_metric_rec = sum(room_metric) / float(len(gt_polys))
420
+
421
+
422
+ corner_metric = get_corner_metric()
423
+ pred_corners_n = sum([poly.shape[0] - 1 for poly in pred_polys])
424
+ gt_corners_n = sum([poly.shape[0] - 1 for poly in gt_polys])
425
+
426
+ if pred_corners_n > 0:
427
+ corner_metric_prec = sum(corner_metric) / float(pred_corners_n)
428
+ else:
429
+ corner_metric_prec = 0
430
+ corner_metric_rec = sum(corner_metric) / float(gt_corners_n)
431
+
432
+
433
+ angles_metric = get_angle_metric()
434
+
435
+ if pred_corners_n > 0:
436
+ angles_metric_prec = sum(angles_metric) / float(pred_corners_n)
437
+ else:
438
+ angles_metric_prec = 0
439
+ angles_metric_rec = sum(angles_metric) / float(gt_corners_n)
440
+
441
+ assert room_metric_prec <= 1
442
+ assert room_metric_rec <= 1
443
+ assert corner_metric_prec <= 1
444
+ assert corner_metric_rec <= 1
445
+ assert angles_metric_prec <= 1
446
+ assert angles_metric_rec <= 1
447
+
448
+ result_dict = {
449
+ 'room_prec': room_metric_prec,
450
+ 'room_rec': room_metric_rec,
451
+ 'corner_prec': corner_metric_prec,
452
+ 'corner_rec': corner_metric_rec,
453
+ 'angles_prec': angles_metric_prec,
454
+ 'angles_rec': angles_metric_rec,
455
+ }
456
+
457
+ return result_dict
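`get_quantitative` returns paired precision/recall values for rooms, corners, and angles; a common scalar summary is the F1 of each pair. A small helper sketch, not part of the evaluator itself:

    def f1_scores(result_dict):
        # combine each precision/recall pair returned by get_quantitative into F1
        out = {}
        for key in ('room', 'corner', 'angles'):
            p, r = result_dict['%s_prec' % key], result_dict['%s_rec' % key]
            out['%s_f1' % key] = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
        return out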