Commit 424188c
Parent(s): init
This view is limited to 50 files because it contains too many changes.
- .gitignore +18 -0
- HEAT.py +460 -0
- LICENSE +674 -0
- README.md +21 -0
- app.py +33 -0
- arguments.py +33 -0
- assets/img/pipeline.png +0 -0
- assets/img/problem_description.png +0 -0
- datasets/__init__.py +0 -0
- datasets/corners.py +183 -0
- datasets/data_utils.py +57 -0
- datasets/outdoor_buildings.py +183 -0
- datasets/s3d_floorplans.py +187 -0
- images/test.jpg +0 -0
- infer.py +455 -0
- metrics/get_metric.py +219 -0
- metrics/new_utils.py +2100 -0
- models/__init__.py +0 -0
- models/corner_models.py +275 -0
- models/corner_to_edge.py +232 -0
- models/deformable_transformer.py +236 -0
- models/edge_models.py +314 -0
- models/loss.py +63 -0
- models/mlp.py +21 -0
- models/ops/functions/__init__.py +10 -0
- models/ops/functions/ms_deform_attn_func.py +61 -0
- models/ops/make.sh +10 -0
- models/ops/modules/__init__.py +9 -0
- models/ops/modules/ms_deform_attn.py +115 -0
- models/ops/setup.py +71 -0
- models/ops/src/cpu/ms_deform_attn_cpu.cpp +41 -0
- models/ops/src/cpu/ms_deform_attn_cpu.h +33 -0
- models/ops/src/cuda/ms_deform_attn_cuda.cu +153 -0
- models/ops/src/cuda/ms_deform_attn_cuda.h +30 -0
- models/ops/src/cuda/ms_deform_im2col_cuda.cuh +1327 -0
- models/ops/src/ms_deform_attn.h +62 -0
- models/ops/src/vision.cpp +16 -0
- models/ops/test.py +89 -0
- models/resnet.py +167 -0
- models/stacked_hg.py +246 -0
- predict.py +33 -0
- qualitative_outdoor/generate_html.py +64 -0
- qualitative_outdoor/plot_utils.py +43 -0
- qualitative_outdoor/visualize_gt.py +46 -0
- qualitative_outdoor/visualize_npy.py +46 -0
- requirements.txt +27 -0
- s3d_floorplan_eval/DataRW/DataRW.py +4 -0
- s3d_floorplan_eval/DataRW/S3DRW.py +142 -0
- s3d_floorplan_eval/DataRW/wrong_annotatios.py +1 -0
- s3d_floorplan_eval/Evaluator/Evaluator.py +457 -0
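app.py and predict.py are thin entry points around the HEAT class added in HEAT.py below; their exact contents are not shown in this truncated view. As a hedged sketch of how the pieces fit together, based only on the HEAT class shown below, a minimal driver would look roughly like this (the output filename is arbitrary, not taken from the repo):

from PIL import Image
from HEAT import HEAT  # wrapper class defined in HEAT.py, shown in full below

# Any key in HEAT._defaults (corner_thresh, patch_size, infer_times, ...) can be
# overridden through keyword arguments; generate() downloads the checkpoint from
# the Hugging Face Hub and builds the backbone, corner and edge models.
heat = HEAT(corner_thresh=0.01, infer_times=3)

# detect_one_image takes a PIL image and returns a PIL image with the predicted
# corners and edges drawn on it; for unsliced inputs it also writes
# shpfile/result.shp as a side effect.
image = Image.open("images/test.jpg")
result = heat.detect_one_image(image)
result.save("result.jpg")  # arbitrary output path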
.gitignore
ADDED
@@ -0,0 +1,18 @@
./data
.DS_Store
viz/*
*.tar
*.pdf
*.zip
svg_*
*.html
models/ops/build
models/ops/dist
models/ops/*egg-info
__pycache__
results
montefloor_data
.idea
model_data/checkpoints
model_data/heat_checkpoints
shpfile/
HEAT.py
ADDED
@@ -0,0 +1,460 @@
'''
Author: [egrt]
Date: 2022-08-23 11:44:15
LastEditors: Egrt
LastEditTime: 2022-11-23 15:25:35
Description: HEAT model loading and prediction
'''
import torch
import torch.nn as nn
from models.resnet import ResNetBackbone
from models.corner_models import HeatCorner
from models.edge_models import HeatEdge
from models.corner_to_edge import get_infer_edge_pairs
from datasets.data_utils import get_pixel_features
from huggingface_hub import hf_hub_download
from PIL import Image
from utils import image_utils
from osgeo import gdal, ogr, osr
from tqdm import tqdm
import os
import scipy.ndimage
import numpy as np
import cv2
import skimage

class HEAT(object):
    #-----------------------------------------#
    #   Note: update model_path as needed
    #-----------------------------------------#
    _defaults = {
        #-----------------------------------------------#
        #   model_data points to the full network checkpoint
        #-----------------------------------------------#
        "model_data"        : 'model_data/heat_checkpoints/checkpoints/ckpts_heat_outdoor_256/checkpoint.pth',
        #-----------------------------------------------#
        #   image_size is the pixel size the model predicts at
        #-----------------------------------------------#
        "image_size"        : [256, 256],
        #-----------------------------------------------#
        #   patch_size is the size of each slice
        #-----------------------------------------------#
        "patch_size"        : 512,
        #-----------------------------------------------#
        #   patch_overlap is the overlap between slices, in pixels
        #-----------------------------------------------#
        "patch_overlap"     : 0,
        #-----------------------------------------------#
        #   corner_thresh is the confidence threshold for predicted corners
        #-----------------------------------------------#
        "corner_thresh"     : 0.01,
        #-----------------------------------------------#
        #   maximum number of edges per corner candidate (must not exceed 6)
        #-----------------------------------------------#
        "corner_to_edge_multiplier": 3,
        #-----------------------------------------------#
        #   number of iterations of edge-filtering inference
        #-----------------------------------------------#
        "infer_times"       : 3,
        #-------------------------------#
        #   Whether to use CUDA
        #   Can be set to False if there is no GPU
        #-------------------------------#
        "cuda"              : False,
    }

    #---------------------------------------------------#
    #   Initialize HEAT
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.generate()

    def generate(self):
        # Load the full network checkpoint from the Hugging Face Hub
        filepath = hf_hub_download(repo_id="Egrt/HEAT", filename="checkpoint.pth")
        self.model = torch.load(filepath)
        # Build the backbone
        self.backbone = ResNetBackbone()
        strides = self.backbone.strides
        num_channels = self.backbone.num_channels
        self.backbone = nn.DataParallel(self.backbone)
        self.backbone = self.backbone.cuda()
        self.backbone.eval()
        # Build the corner detection model
        self.corner_model = HeatCorner(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
                                       backbone_num_channels=num_channels)
        self.corner_model = nn.DataParallel(self.corner_model)
        self.corner_model = self.corner_model.cuda()
        self.corner_model.eval()
        # Build the edge detection model
        self.edge_model = HeatEdge(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
                                   backbone_num_channels=num_channels)
        self.edge_model = nn.DataParallel(self.edge_model)
        self.edge_model = self.edge_model.cuda()
        self.edge_model.eval()
        # Load the weights of each sub-model
        self.backbone.load_state_dict(self.model['backbone'])
        self.corner_model.load_state_dict(self.model['corner_model'])
        self.edge_model.load_state_dict(self.model['edge_model'])

    def detect_one_image(self, image):
        #---------------------------------------------------------#
        #   Convert the image to RGB here so that grayscale inputs
        #   do not fail at prediction time. The code only supports
        #   RGB prediction; every other image type is converted to RGB.
        #---------------------------------------------------------#
        image = cvtColor(image)
        # Decide whether the image needs to be split into multiple patches
        if image.size[0] < self.patch_size or image.size[1] < self.patch_size:
            is_slice = False
        else:
            is_slice = True
        if is_slice:
            # Convert the image to a NumPy array
            image = np.array(image, dtype=np.uint8)
            # Keep a copy of the original input for visualization
            viz_image = image.copy()
            height, width = image.shape[0], image.shape[1]
            # Scale factor between the patch size and the model input size
            scale = self.patch_size / self.image_size[0]
            # Initialize the corner and edge lists
            pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np = [], [], [], [], []
            # Start slicing
            stride = self.patch_size - self.patch_overlap
            patch_boundingboxes = image_utils.compute_patch_boundingboxes((height, width),
                                                                          stride=stride,
                                                                          patch_res=self.patch_size)
            edge_len = 0
            # Iterate over the sliced patches
            for bbox in tqdm(patch_boundingboxes, desc="Predicting with patches", leave=False):
                # Crop the patch
                crop_image = image[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
                # Convert the NumPy array back to a PIL Image
                crop_image = Image.fromarray(crop_image)
                try:
                    pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, _ = self.predict_no_patching(crop_image)
                except RuntimeError as e:
                    print("ERROR: " + str(e))
                    print("INFO: reduce patch_size until it fits in memory")
                    raise e
                # Map patch-local corner coordinates back to the full image
                pred_corners[:, 0] = pred_corners[:, 0] * scale + bbox[0]
                pred_corners[:, 1] = pred_corners[:, 1] * scale + bbox[1]
                pred_corners_viz = pred_corners
                viz_image = visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, edges=pos_edges,
                                                      edge_confs=edge_confs, shpfile=False)

            hr_image = Image.fromarray(np.uint8(viz_image))
        else:
            pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, viz_image = self.predict_no_patching(image)
            #---------------------------------------------------------#
            #   Inference ends here.
            #   Draw the corners and edges on the original image using
            #   the predicted corner coordinates.
            #---------------------------------------------------------#
            pred_corners_viz = pred_corners
            image_result = visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, edges=pos_edges,
                                                     edge_confs=edge_confs, shpfile=True)
            hr_image = Image.fromarray(np.uint8(image_result))
        return hr_image

    #---------------------------------------------------------#
    #   Predict an image without slicing.
    #   Returns the predicted corner coordinates and edges.
    #---------------------------------------------------------#
    def predict_no_patching(self, image):
        image = image.resize(tuple(self.image_size), Image.BICUBIC)
        # Convert the PIL Image to a NumPy array
        image = np.array(image, dtype=np.uint8)
        # Keep a copy of the resized input for visualization
        viz_image = image.copy()
        # preprocess image numpy->tensor
        image = process_image(image)
        # Get the positional encoding for all pixels; the default image scale is 256
        pixels, pixel_features = get_pixel_features(image_size=self.image_size[0])
        # Run the models
        with torch.no_grad():

            image_feats, feat_mask, all_image_feats = self.backbone(image)
            pixel_features = pixel_features.unsqueeze(0).repeat(image.shape[0], 1, 1, 1)
            preds_s1 = self.corner_model(image_feats, feat_mask, pixel_features, pixels, all_image_feats)

            c_outputs = preds_s1
            # Corner heatmap predicted by the model
            c_outputs_np = c_outputs[0].detach().cpu().numpy()
            # Keep the coordinates of corners whose score exceeds the threshold
            pos_indices = np.where(c_outputs_np >= self.corner_thresh)
            pred_corners = pixels[pos_indices]
            # Confidence of each predicted corner
            pred_confs = c_outputs_np[pos_indices]
            # Non-maximum suppression based on the corner confidences
            pred_corners, pred_confs = corner_nms(pred_corners, pred_confs, image_size=c_outputs.shape[1])
            # Enumerate all pairwise corner combinations to get the candidate edges
            pred_corners, pred_confs, edge_coords, edge_mask, edge_ids = get_infer_edge_pairs(pred_corners, pred_confs)
            # Number of corners
            corner_nums = torch.tensor([len(pred_corners)]).to(image.device)
            max_candidates = torch.stack([corner_nums.max() * self.corner_to_edge_multiplier] * len(corner_nums), dim=0)
            # Unordered set of confirmed edge ids (no duplicates)
            all_pos_ids = set()
            # Dictionary of edge confidences
            all_edge_confs = dict()
            # Iterative edge inference (3 iterations by default)
            for tt in range(self.infer_times):
                if tt == 0:
                    # gt_values has the same shape as the edge mask and is initialized to 0
                    gt_values = torch.zeros_like(edge_mask).long()
                    # Set every entry to 2 ("undecided")
                    gt_values[:, :] = 2

                # Predict the edges
                s1_logits, s2_logits_hb, s2_logits_rel, selected_ids, s2_mask, s2_gt_values = self.edge_model(image_feats,
                    feat_mask, pixel_features, edge_coords, edge_mask, gt_values, corner_nums, max_candidates, True)
                num_total = s1_logits.shape[2]
                num_selected = selected_ids.shape[1]
                num_filtered = num_total - num_selected
                # Turn the logits into probabilities in (0, 1)
                s1_preds = s1_logits.squeeze().softmax(0)
                s2_preds_rel = s2_logits_rel.squeeze().softmax(0)
                s2_preds_hb = s2_logits_hb.squeeze().softmax(0)
                s1_preds_np = s1_preds[1, :].detach().cpu().numpy()
                s2_preds_rel_np = s2_preds_rel[1, :].detach().cpu().numpy()
                s2_preds_hb_np = s2_preds_hb[1, :].detach().cpu().numpy()

                selected_ids = selected_ids.squeeze().detach().cpu().numpy()
                # Filter: scores in (0.9, 1) are marked True, (0.01, 0.9) undecided, and (0, 0.01) False
                if tt != self.infer_times - 1:
                    s2_preds_np = s2_preds_hb_np

                    pos_edge_ids = np.where(s2_preds_np >= 0.9)
                    neg_edge_ids = np.where(s2_preds_np <= 0.01)
                    for pos_id in pos_edge_ids[0]:
                        actual_id = selected_ids[pos_id]
                        if gt_values[0, actual_id] != 2:
                            continue
                        all_pos_ids.add(actual_id)
                        all_edge_confs[actual_id] = s2_preds_np[pos_id]
                        gt_values[0, actual_id] = 1
                    for neg_id in neg_edge_ids[0]:
                        actual_id = selected_ids[neg_id]
                        if gt_values[0, actual_id] != 2:
                            continue
                        gt_values[0, actual_id] = 0
                    num_to_pred = (gt_values == 2).sum()
                    if num_to_pred <= num_filtered:
                        break
                else:
                    s2_preds_np = s2_preds_hb_np

                    pos_edge_ids = np.where(s2_preds_np >= 0.5)
                    for pos_id in pos_edge_ids[0]:
                        actual_id = selected_ids[pos_id]
                        if s2_mask[0][pos_id] is True or gt_values[0, actual_id] != 2:
                            continue
                        all_pos_ids.add(actual_id)
                        all_edge_confs[actual_id] = s2_preds_np[pos_id]
            pos_edge_ids = list(all_pos_ids)
            edge_confs = [all_edge_confs[idx] for idx in pos_edge_ids]
            pos_edges = edge_ids[pos_edge_ids].cpu().numpy()
            edge_confs = np.array(edge_confs)

        if self.image_size[0] != 256:
            pred_corners = pred_corners / (self.image_size[0] / 256)

        return pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, viz_image

#---------------------------------------------------------#
#   Convert the image to RGB so that grayscale inputs do not
#   fail at prediction time. The code only supports RGB
#   prediction; every other image type is converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image
    else:
        image = image.convert('RGB')
        return image

#---------------------------------------------------------#
#   Sort corners by confidence and keep only those above the
#   threshold that are local maxima (non-maximum suppression)
#---------------------------------------------------------#
def corner_nms(preds, confs, image_size):
    data = np.zeros([image_size, image_size])
    neighborhood_size = 5
    threshold = 0

    for i in range(len(preds)):
        data[preds[i, 1], preds[i, 0]] = confs[i]

    data_max = scipy.ndimage.maximum_filter(data, neighborhood_size)
    maxima = (data == data_max)
    data_min = scipy.ndimage.minimum_filter(data, neighborhood_size)
    diff = ((data_max - data_min) > threshold)
    maxima[diff == 0] = 0

    results = np.where(maxima > 0)
    filtered_preds = np.stack([results[1], results[0]], axis=-1)

    new_confs = list()
    for i, pred in enumerate(filtered_preds):
        new_confs.append(data[pred[1], pred[0]])
    new_confs = np.array(new_confs)

    return filtered_preds, new_confs

def process_image(img):
    # Normalize with ImageNet statistics and convert to a CHW tensor
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    img = skimage.img_as_float(img)
    img = img.transpose((2, 0, 1))
    img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
    img = torch.Tensor(img).cuda()
    img = img.unsqueeze(0)
    return img

def postprocess_preds(corners, confs, edges):
    # Drop corners that are not used by any edge and re-index the edges
    corner_degrees = dict()
    for edge_i, edge_pair in enumerate(edges):
        corner_degrees[edge_pair[0]] = corner_degrees.setdefault(edge_pair[0], 0) + 1
        corner_degrees[edge_pair[1]] = corner_degrees.setdefault(edge_pair[1], 0) + 1
    good_ids = [i for i in range(len(corners)) if i in corner_degrees]
    if len(good_ids) == len(corners):
        return corners, confs, edges
    else:
        good_corners = corners[good_ids]
        good_confs = confs[good_ids]
        id_mapping = {value: idx for idx, value in enumerate(good_ids)}
        new_edges = list()
        for edge_pair in edges:
            new_pair = (id_mapping[edge_pair[0]], id_mapping[edge_pair[1]])
            new_edges.append(new_pair)
        new_edges = np.array(new_edges)
        return good_corners, good_confs, new_edges

#---------------------------------------------------------#
#   Visualize the corner coordinates on the input image.
#   Unlike the original code, we return the image object
#   directly instead of saving it to a given path.
#---------------------------------------------------------#
def visualize_cond_generation(positive_pixels, confs, image, gt_corners=None, prec=None, recall=None,
                              image_masks=None, edges=None, edge_confs=None, shpfile=False):
    # Work on a copy of the image
    image = image.copy()
    if confs is not None:
        viz_confs = confs

    if edges is not None:
        preds = positive_pixels.astype(int)
        c_degrees = dict()
        for edge_i, edge_pair in enumerate(edges):
            conf = (edge_confs[edge_i] * 2) - 1
            cv2.line(image, tuple(preds[edge_pair[0]]), tuple(preds[edge_pair[1]]), (255 * conf, 255 * conf, 0), 2)
            c_degrees[edge_pair[0]] = c_degrees.setdefault(edge_pair[0], 0) + 1
            c_degrees[edge_pair[1]] = c_degrees.setdefault(edge_pair[1], 0) + 1

    for idx, c in enumerate(positive_pixels):
        if edges is not None and idx not in c_degrees:
            continue
        if confs is None:
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
        else:
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255 * viz_confs[idx]), -1)
        # if edges is not None:
        #     cv2.putText(image, '{}'.format(c_degrees[idx]), (int(c[0]), int(c[1] - 5)), cv2.FONT_HERSHEY_SIMPLEX,
        #                 0.5, (255, 0, 0), 1, cv2.LINE_AA)

    if gt_corners is not None:
        for c in gt_corners:
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 255, 0), -1)

    if image_masks is not None:
        mask_ids = np.where(image_masks == 1)[0]
        for mask_id in mask_ids:
            y_idx = mask_id // 64
            x_idx = (mask_id - y_idx * 64)
            x_coord = x_idx * 4
            y_coord = y_idx * 4
            cv2.rectangle(image, (x_coord, y_coord), (x_coord + 3, y_coord + 3), (127, 127, 0), thickness=-1)

    # if confs is not None:
    #     cv2.putText(image, 'max conf: {:.3f}'.format(confs.max()), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
    #                 0.5, (255, 255, 0), 1, cv2.LINE_AA)
    if prec is not None:
        if isinstance(prec, tuple):
            cv2.putText(image, 'edge p={:.2f}, edge r={:.2f}'.format(prec[0], recall[0]), (20, 20),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)
            cv2.putText(image, 'region p={:.2f}, region r={:.2f}'.format(prec[1], recall[1]), (20, 40),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)
        else:
            cv2.putText(image, 'prec={:.2f}, recall={:.2f}'.format(prec, recall), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)

    # Optionally write the result to a shapefile
    if shpfile:
        preds = positive_pixels.astype(int)
        # Collect the polyline point list
        Polyline = []
        for edge_i, edge_pair in enumerate(edges):
            Polyline.append([preds[edge_pair[0]], preds[edge_pair[1]]])
        Polyline = np.array(Polyline, dtype=np.int32)
        # Write the shapefile
        writeShp(save_file_dir="shpfile", Polyline=Polyline)

    return image

def writeShp(save_file_dir="shpfile", Polyline=None):
    # Create the output directory
    if not os.path.exists(save_file_dir):
        os.makedirs(save_file_dir)
    # Support non-ASCII (e.g. Chinese) file paths
    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
    # Allow non-ASCII characters in attribute table fields
    gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
    # Register all drivers
    ogr.RegisterAll()
    # Create the shapefile driver
    strDriverName = "ESRI Shapefile"
    oDriver = ogr.GetDriverByName(strDriverName)
    if oDriver is None:
        return "Driver not available: " + strDriverName
    # Create the data source
    file_path = os.path.join(save_file_dir, "result.shp")
    oDS = oDriver.CreateDataSource(file_path)
    if oDS is None:
        return "Failed to create file: result.shp"
    if Polyline is not None:
        # Create a polyline layer with the WGS84 coordinate system
        papszLCO = []
        geosrs = osr.SpatialReference()
        geosrs.SetWellKnownGeogCS("WGS84")
        # Line: ogr_type = ogr.wkbLineString
        # Point: ogr_type = ogr.wkbPoint
        ogr_type = ogr.wkbMultiLineString
        # Polygons use the Polygon type, lines use Polyline, points use Point
        oLayer = oDS.CreateLayer("Polyline", geosrs, ogr_type, papszLCO)
        if oLayer is None:
            return "Failed to create the layer!"
        # Build the attribute table
        # id field
        oId = ogr.FieldDefn("id", ogr.OFTInteger)
        oLayer.CreateField(oId, 1)
        # name field
        oName = ogr.FieldDefn("name", ogr.OFTString)
        oLayer.CreateField(oName, 1)
        oDefn = oLayer.GetLayerDefn()
        # Create the features
        # Dataset format: wkt_geom id name
        point_str_list = ['({} {},{} {})'.format(row[0, 0], row[0, 1], row[1, 0], row[1, 1]) for row in Polyline]
        Polyline_Wkt = ','.join(point_str_list)
        features = ['Polyline0;MULTILINESTRING({})'.format(Polyline_Wkt)]
        for index, f in enumerate(features):
            oFeaturePolygon = ogr.Feature(oDefn)
            oFeaturePolygon.SetField("id", index)
            oFeaturePolygon.SetField("name", f.split(";")[0])
            geomPolygon = ogr.CreateGeometryFromWkt(f.split(";")[1])
            oFeaturePolygon.SetGeometry(geomPolygon)
            oLayer.CreateFeature(oFeaturePolygon)
    # Close the data source when done
    oDS.Destroy()
    return "Dataset created successfully!"
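The corner_nms helper above suppresses duplicate corner detections by scattering the per-corner confidences onto an image-sized grid and keeping only the cells that equal their local maximum over a 5-pixel neighborhood. The following is a minimal, self-contained sketch of that step (a condensed restatement for illustration only, assuming just NumPy and scipy.ndimage; the deprecated scipy.ndimage.filters path used upstream is replaced by the current scipy.ndimage functions):

import numpy as np
import scipy.ndimage

def nms_sketch(preds, confs, image_size, neighborhood_size=5):
    # Scatter confidences onto an (image_size x image_size) grid at (y, x).
    heat = np.zeros((image_size, image_size))
    heat[preds[:, 1], preds[:, 0]] = confs
    # A cell survives if it equals the local maximum and the local window
    # is not completely flat (local_max - local_min > 0).
    local_max = scipy.ndimage.maximum_filter(heat, neighborhood_size)
    local_min = scipy.ndimage.minimum_filter(heat, neighborhood_size)
    keep = (heat == local_max) & ((local_max - local_min) > 0)
    ys, xs = np.where(keep)
    return np.stack([xs, ys], axis=-1), heat[ys, xs]

# Two candidates 1 px apart collapse to the stronger one; the isolated corner survives.
preds = np.array([[10, 10], [11, 10], [40, 40]])
confs = np.array([0.90, 0.30, 0.70])
kept, kept_confs = nms_sketch(preds, confs, image_size=64)
print(kept)        # [[10 10] [40 40]]
print(kept_confs)  # [0.9 0.7]

This is why closely spaced corner candidates do not both survive into get_infer_edge_pairs, which would otherwise generate near-duplicate edge pairs for the edge model to reject.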
LICENSE
ADDED
@@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007

Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.

Preamble

The GNU General Public License is a free, copyleft license for
software and other kinds of works.

The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.

When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.

For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.

Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.

For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.

Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.

Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.

The precise terms and conditions for copying, distribution and
modification follow.

TERMS AND CONDITIONS

0. Definitions.

"This License" refers to version 3 of the GNU General Public License.

"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.

To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

A "covered work" means either the unmodified Program or a work based
on the Program.

To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

1. Source Code.

The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.

A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

The Corresponding Source for a work in source code form is that
same work.

2. Basic Permissions.

All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.

Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.

3. Protecting Users' Legal Rights From Anti-Circumvention Law.

No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.

When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

4. Conveying Verbatim Copies.

You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

5. Conveying Modified Source Versions.

You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

  a) The work must carry prominent notices stating that you modified
  it, and giving a relevant date.

  b) The work must carry prominent notices stating that it is
  released under this License and any conditions added under section
  7. This requirement modifies the requirement in section 4 to
  "keep intact all notices".

  c) You must license the entire work, as a whole, under this
  License to anyone who comes into possession of a copy. This
  License will therefore apply, along with any applicable section 7
  additional terms, to the whole of the work, and all its parts,
  regardless of how they are packaged. This License gives no
  permission to license the work in any other way, but it does not
  invalidate such permission if you have separately received it.

  d) If the work has interactive user interfaces, each must display
  Appropriate Legal Notices; however, if the Program has interactive
  interfaces that do not display Appropriate Legal Notices, your
  work need not make them do so.

A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

6. Conveying Non-Source Forms.

You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

  a) Convey the object code in, or embodied in, a physical product
  (including a physical distribution medium), accompanied by the
  Corresponding Source fixed on a durable physical medium
  customarily used for software interchange.

  b) Convey the object code in, or embodied in, a physical product
  (including a physical distribution medium), accompanied by a
  written offer, valid for at least three years and valid for as
  long as you offer spare parts or customer support for that product
  model, to give anyone who possesses the object code either (1) a
  copy of the Corresponding Source for all the software in the
  product that is covered by this License, on a durable physical
  medium customarily used for software interchange, for a price no
  more than your reasonable cost of physically performing this
  conveying of source, or (2) access to copy the
  Corresponding Source from a network server at no charge.

  c) Convey individual copies of the object code with a copy of the
  written offer to provide the Corresponding Source. This
  alternative is allowed only occasionally and noncommercially, and
  only if you received the object code with such an offer, in accord
  with subsection 6b.

  d) Convey the object code by offering access from a designated
  place (gratis or for a charge), and offer equivalent access to the
  Corresponding Source in the same way through the same place at no
  further charge. You need not require recipients to copy the
  Corresponding Source along with the object code. If the place to
  copy the object code is a network server, the Corresponding Source
  may be on a different server (operated by you or a third party)
  that supports equivalent copying facilities, provided you maintain
  clear directions next to the object code saying where to find the
  Corresponding Source. Regardless of what server hosts the
  Corresponding Source, you remain obligated to ensure that it is
  available for as long as needed to satisfy these requirements.

  e) Convey the object code using peer-to-peer transmission, provided
  you inform other peers where the object code and Corresponding
  Source of the work are being offered to the general public at no
  charge under subsection 6d.

A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.

If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).

The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.

Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.

7. Additional Terms.

"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.

When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.

Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:

  a) Disclaiming warranty or limiting liability differently from the
  terms of sections 15 and 16 of this License; or

  b) Requiring preservation of specified reasonable legal notices or
  author attributions in that material or in the Appropriate Legal
  Notices displayed by works containing it; or

  c) Prohibiting misrepresentation of the origin of that material, or
  requiring that modified versions of such material be marked in
  reasonable ways as different from the original version; or

  d) Limiting the use for publicity purposes of names of licensors or
  authors of the material; or

  e) Declining to grant rights under trademark law for use of some
  trade names, trademarks, or service marks; or

  f) Requiring indemnification of licensors and authors of that
  material by anyone who conveys the material (or modified versions of
  it) with contractual assumptions of liability to the recipient, for
  any liability that these contractual assumptions directly impose on
  those licensors and authors.

All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.

If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.

Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.

8. Termination.

You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).

However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.

Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.

Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.

9. Acceptance Not Required for Having Copies.

You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.

10. Automatic Licensing of Downstream Recipients.

Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.

An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.

You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.

11. Patents.

A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".

A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.

Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.

In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.

If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.

If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.

A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.

Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.

12. No Surrender of Others' Freedom.

If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
|
548 |
+
to collect a royalty for further conveying from those to whom you convey
|
549 |
+
the Program, the only way you could satisfy both those terms and this
|
550 |
+
License would be to refrain entirely from conveying the Program.
|
551 |
+
|
552 |
+
13. Use with the GNU Affero General Public License.
|
553 |
+
|
554 |
+
Notwithstanding any other provision of this License, you have
|
555 |
+
permission to link or combine any covered work with a work licensed
|
556 |
+
under version 3 of the GNU Affero General Public License into a single
|
557 |
+
combined work, and to convey the resulting work. The terms of this
|
558 |
+
License will continue to apply to the part which is the covered work,
|
559 |
+
but the special requirements of the GNU Affero General Public License,
|
560 |
+
section 13, concerning interaction through a network will apply to the
|
561 |
+
combination as such.
|
562 |
+
|
563 |
+
14. Revised Versions of this License.
|
564 |
+
|
565 |
+
The Free Software Foundation may publish revised and/or new versions of
|
566 |
+
the GNU General Public License from time to time. Such new versions will
|
567 |
+
be similar in spirit to the present version, but may differ in detail to
|
568 |
+
address new problems or concerns.
|
569 |
+
|
570 |
+
Each version is given a distinguishing version number. If the
|
571 |
+
Program specifies that a certain numbered version of the GNU General
|
572 |
+
Public License "or any later version" applies to it, you have the
|
573 |
+
option of following the terms and conditions either of that numbered
|
574 |
+
version or of any later version published by the Free Software
|
575 |
+
Foundation. If the Program does not specify a version number of the
|
576 |
+
GNU General Public License, you may choose any version ever published
|
577 |
+
by the Free Software Foundation.
|
578 |
+
|
579 |
+
If the Program specifies that a proxy can decide which future
|
580 |
+
versions of the GNU General Public License can be used, that proxy's
|
581 |
+
public statement of acceptance of a version permanently authorizes you
|
582 |
+
to choose that version for the Program.
|
583 |
+
|
584 |
+
Later license versions may give you additional or different
|
585 |
+
permissions. However, no additional obligations are imposed on any
|
586 |
+
author or copyright holder as a result of your choosing to follow a
|
587 |
+
later version.
|
588 |
+
|
589 |
+
15. Disclaimer of Warranty.
|
590 |
+
|
591 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
592 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
593 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
594 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
595 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
596 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
597 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
598 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
599 |
+
|
600 |
+
16. Limitation of Liability.
|
601 |
+
|
602 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
603 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
604 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
605 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
606 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
607 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
608 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
609 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
610 |
+
SUCH DAMAGES.
|
611 |
+
|
612 |
+
17. Interpretation of Sections 15 and 16.
|
613 |
+
|
614 |
+
If the disclaimer of warranty and limitation of liability provided
|
615 |
+
above cannot be given local legal effect according to their terms,
|
616 |
+
reviewing courts shall apply local law that most closely approximates
|
617 |
+
an absolute waiver of all civil liability in connection with the
|
618 |
+
Program, unless a warranty or assumption of liability accompanies a
|
619 |
+
copy of the Program in return for a fee.
|
620 |
+
|
621 |
+
END OF TERMS AND CONDITIONS
|
622 |
+
|
623 |
+
How to Apply These Terms to Your New Programs
|
624 |
+
|
625 |
+
If you develop a new program, and you want it to be of the greatest
|
626 |
+
possible use to the public, the best way to achieve this is to make it
|
627 |
+
free software which everyone can redistribute and change under these terms.
|
628 |
+
|
629 |
+
To do so, attach the following notices to the program. It is safest
|
630 |
+
to attach them to the start of each source file to most effectively
|
631 |
+
state the exclusion of warranty; and each file should have at least
|
632 |
+
the "copyright" line and a pointer to where the full notice is found.
|
633 |
+
|
634 |
+
<one line to give the program's name and a brief idea of what it does.>
|
635 |
+
Copyright (C) <year> <name of author>
|
636 |
+
|
637 |
+
This program is free software: you can redistribute it and/or modify
|
638 |
+
it under the terms of the GNU General Public License as published by
|
639 |
+
the Free Software Foundation, either version 3 of the License, or
|
640 |
+
(at your option) any later version.
|
641 |
+
|
642 |
+
This program is distributed in the hope that it will be useful,
|
643 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
644 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
645 |
+
GNU General Public License for more details.
|
646 |
+
|
647 |
+
You should have received a copy of the GNU General Public License
|
648 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
649 |
+
|
650 |
+
Also add information on how to contact you by electronic and paper mail.
|
651 |
+
|
652 |
+
If the program does terminal interaction, make it output a short
|
653 |
+
notice like this when it starts in an interactive mode:
|
654 |
+
|
655 |
+
<program> Copyright (C) <year> <name of author>
|
656 |
+
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
657 |
+
This is free software, and you are welcome to redistribute it
|
658 |
+
under certain conditions; type `show c' for details.
|
659 |
+
|
660 |
+
The hypothetical commands `show w' and `show c' should show the appropriate
|
661 |
+
parts of the General Public License. Of course, your program's commands
|
662 |
+
might be different; for a GUI interface, you would use an "about box".
|
663 |
+
|
664 |
+
You should also get your employer (if you work as a programmer) or school,
|
665 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
666 |
+
For more information on this, and how to apply and follow the GNU GPL, see
|
667 |
+
<https://www.gnu.org/licenses/>.
|
668 |
+
|
669 |
+
The GNU General Public License does not permit incorporating your program
|
670 |
+
into proprietary programs. If your program is a subroutine library, you
|
671 |
+
may consider it more useful to permit linking proprietary applications with
|
672 |
+
the library. If this is what you want to do, use the GNU Lesser General
|
673 |
+
Public License instead of this License. But first, please read
|
674 |
+
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
README.md
ADDED
@@ -0,0 +1,21 @@
<!--
 * @Description:
 * @Author: Egrt
 * @Date: 2022-11-23 15:20:00
 * @LastEditors: Egrt
 * @LastEditTime: 2022-11-23 15:27:41
-->
---
title: HEAT
emoji: 📈
colorFrom: indigo
colorTo: yellow
sdk: gradio
sdk_version: 3.11.0
app_file: app.py
pinned: false
license: apache-2.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,33 @@
'''
Author: Egrt
Date: 2022-01-13 13:34:10
LastEditors: [egrt]
LastEditTime: 2022-08-15 19:40:32
FilePath: \MaskGAN\app.py
'''
from HEAT import HEAT
import gradio as gr
import os
heat = HEAT()

# -------- Model inference -------- #
def inference(img):
    image_result = heat.detect_one_image(img)
    return image_result

# -------- Web page info -------- #
title = "HEAT"
description = "HEAT: Holistic Edge Attention Transformer for Structured Reconstruction @Luuuu"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.10257' target='_blank'>HEAT: Holistic Edge Attention Transformer for Structured Reconstruction </a> | <a href='https://github.com/JingyunLiang/SwinIR' target='_blank'>Github Repo</a></p>"
example_img_dir = 'images/'
example_img_name = os.listdir(example_img_dir)
examples = [[os.path.join(example_img_dir, image_path)] for image_path in example_img_name if image_path.endswith(('.jpg', '.jpeg', '.png'))]
gr.Interface(
    inference,
    [gr.inputs.Image(type="pil", label="Input")],
    gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=examples
).launch()
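The Gradio app above wires heat.detect_one_image into a web UI. The same entry point can also be exercised directly; a minimal sketch, assuming the HEAT class from HEAT.py is importable and that detect_one_image accepts and returns a PIL image, as implied by the type="pil" interface settings above:

from PIL import Image
from HEAT import HEAT

heat = HEAT()
# 'images/test.jpg' is the example image shipped with this commit.
img = Image.open('images/test.jpg')
result = heat.detect_one_image(img)  # assumed: PIL image with the reconstructed corners/edges drawn
result.save('heat_result.png')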
arguments.py
ADDED
@@ -0,0 +1,33 @@
import argparse


def get_args_parser():
    parser = argparse.ArgumentParser('Holistic edge attention transformer', add_help=False)
    parser.add_argument('--exp_dataset', default='outdoor',
                        help='the dataset for experiments, outdoor/s3d_floorplan')
    parser.add_argument('--lr', default=2e-4, type=float)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--weight_decay', default=1e-5, type=float)
    parser.add_argument('--epochs', default=800, type=int)
    parser.add_argument('--lr_drop', default=600, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')
    parser.add_argument('--print_freq', default=40, type=int)
    parser.add_argument('--output_dir', default='./checkpoints/ckpts_heat_outdoor_256',
                        help='path where to save, empty for no saving')
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--image_size', default=256, type=int)
    parser.add_argument('--max_corner_num', default=150, type=int,
                        help='the max number of corners allowed in the experiments')
    parser.add_argument('--corner_to_edge_multiplier', default=3, type=int,
                        help='the max number of edges based on the number of corner candidates (assuming the '
                             'average degree never greater than 6)')
    parser.add_argument('--lambda_corner', default=0.05, type=float,
                        help='the weight of the corner loss term')
    parser.add_argument('--run_validation', action='store_true',
                        help='whether to run validation or not, default: False')
    return parser
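Since the parser is created with add_help=False, it is meant to be plugged into a top-level parser as a parent. A minimal sketch of that composition (the wrapper parser and the sample flags are illustrative, not part of this commit):

import argparse
from arguments import get_args_parser

# Compose the HEAT argument definitions into a full command-line parser.
parser = argparse.ArgumentParser('HEAT training script', parents=[get_args_parser()])
args = parser.parse_args(['--exp_dataset', 's3d_floorplan', '--batch_size', '8'])
print(args.exp_dataset, args.batch_size, args.lr, args.image_size)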
assets/img/pipeline.png
ADDED
assets/img/problem_description.png
ADDED
datasets/__init__.py
ADDED
File without changes
datasets/corners.py
ADDED
@@ -0,0 +1,183 @@
import numpy as np
from torch.utils.data import Dataset
from scipy.ndimage import gaussian_filter
import cv2

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


class CornersDataset(Dataset):
    def __init__(self, image_size=256, inference=False):
        super(CornersDataset, self).__init__()
        self.image_size = image_size
        self.inference = inference
        self._data_names = []

    def __len__(self):
        # fixed: this was `raise len(self._data_names)`, which would error out
        return len(self._data_names)

    def __getitem__(self, idx):
        raise NotImplementedError

    def process_data(self, data):
        img = data['image']
        corners = data['corners']
        annot = data['annot']

        # pre-process the image to use ImageNet-pretrained backbones
        img = img.transpose((2, 0, 1))
        raw_img = img.copy()
        img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
        img = img.astype(np.float32)

        corners = np.array(corners)

        all_data = {
            "annot": annot,
            "name": data['name'],
            'img': img,
            'annot_path': data['annot_path'],
            'img_path': data['img_path'],
            'det_path': data['det_path'],
            'raw_img': raw_img,
        }

        # corner labels for training
        if not self.inference:
            pixel_labels, gauss_labels = self.get_corner_labels(corners)
            all_data['pixel_labels'] = pixel_labels
            all_data['gauss_labels'] = gauss_labels

        return all_data

    def get_corner_labels(self, corners):
        labels = np.zeros((self.image_size, self.image_size))
        corners = corners.round()
        xint, yint = corners[:, 0].astype(np.int), corners[:, 1].astype(np.int)
        labels[yint, xint] = 1

        gauss_labels = gaussian_filter(labels, sigma=2)
        gauss_labels = gauss_labels / gauss_labels.max()
        return labels, gauss_labels

    def resize_data(self, image, annot, det_corners):
        new_image = cv2.resize(image, (self.image_size, self.image_size))
        new_annot = {}
        r = self.image_size / 256
        for c, connections in annot.items():
            new_c = tuple(np.array(c) * r)
            new_connections = [other_c * r for other_c in connections]
            new_annot[new_c] = new_connections
        new_dets = det_corners * r
        return new_image, new_annot, new_dets

    def random_aug_annot(self, img, annot, det_corners=None):
        # do random flipping
        img, annot, det_corners = self.random_flip(img, annot, det_corners)

        # prepare random augmentation parameters (only do random rotation for now)
        theta = np.random.randint(0, 360) / 360 * np.pi * 2
        r = self.image_size / 256
        origin = [127 * r, 127 * r]
        p1_new = [127 * r + 100 * np.sin(theta) * r, 127 * r - 100 * np.cos(theta) * r]
        p2_new = [127 * r + 100 * np.cos(theta) * r, 127 * r + 100 * np.sin(theta) * r]
        p1_old = [127 * r, 127 * r - 100 * r]  # y_axis
        p2_old = [127 * r + 100 * r, 127 * r]  # x_axis
        pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32)
        pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32)
        M_rot = cv2.getAffineTransform(pts1, pts2)

        # Combine annotation corners and detection corners
        all_corners = list(annot.keys())
        if det_corners is not None:
            for i in range(det_corners.shape[0]):
                all_corners.append(tuple(det_corners[i]))
        all_corners_ = np.array(all_corners)

        # Do the corner transform within a big matrix transformation
        corner_mapping = dict()
        ones = np.ones([all_corners_.shape[0], 1])
        all_corners_ = np.concatenate([all_corners_, ones], axis=-1)
        aug_corners = np.matmul(M_rot, all_corners_.T).T

        for idx, corner in enumerate(all_corners):
            corner_mapping[corner] = aug_corners[idx]

        # If the transformed geometry goes beyond image boundary, we simply re-do the augmentation
        new_corners = np.array(list(corner_mapping.values()))
        if new_corners.min() <= 0 or new_corners.max() >= (self.image_size - 1):
            # return self.random_aug_annot(img, annot, det_corners)
            return img, annot, None, det_corners

        # build the new annot dict
        aug_annot = dict()
        for corner, connections in annot.items():
            new_corner = corner_mapping[corner]
            tuple_new_corner = tuple(new_corner)
            aug_annot[tuple_new_corner] = list()
            for to_corner in connections:
                aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])

        # Also transform the image correspondingly
        rows, cols, ch = img.shape
        new_img = cv2.warpAffine(img, M_rot, (cols, rows), borderValue=(255, 255, 255))

        y_start = (new_img.shape[0] - self.image_size) // 2
        x_start = (new_img.shape[1] - self.image_size) // 2
        aug_img = new_img[y_start:y_start + self.image_size, x_start:x_start + self.image_size, :]

        if det_corners is None:
            return aug_img, aug_annot, corner_mapping, None
        else:
            aug_det_corners = list()
            for corner in det_corners:
                new_corner = corner_mapping[tuple(corner)]
                aug_det_corners.append(new_corner)
            aug_det_corners = np.array(aug_det_corners)
            return aug_img, aug_annot, corner_mapping, aug_det_corners

    def random_flip(self, img, annot, det_corners):
        height, width, _ = img.shape
        rand_int = np.random.randint(0, 4)
        if rand_int == 0:
            return img, annot, det_corners

        all_corners = list(annot.keys())
        if det_corners is not None:
            for i in range(det_corners.shape[0]):
                all_corners.append(tuple(det_corners[i]))
        new_corners = np.array(all_corners)

        if rand_int == 1:
            img = img[:, ::-1, :]
            new_corners[:, 0] = width - new_corners[:, 0]
        elif rand_int == 2:
            img = img[::-1, :, :]
            new_corners[:, 1] = height - new_corners[:, 1]
        else:
            img = img[::-1, ::-1, :]
            new_corners[:, 0] = width - new_corners[:, 0]
            new_corners[:, 1] = height - new_corners[:, 1]

        new_corners = np.clip(new_corners, 0, self.image_size - 1)  # clip into [0, 255]
        corner_mapping = dict()
        for idx, corner in enumerate(all_corners):
            corner_mapping[corner] = new_corners[idx]

        aug_annot = dict()
        for corner, connections in annot.items():
            new_corner = corner_mapping[corner]
            tuple_new_corner = tuple(new_corner)
            aug_annot[tuple_new_corner] = list()
            for to_corner in connections:
                aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])

        if det_corners is not None:
            aug_det_corners = list()
            for corner in det_corners:
                new_corner = corner_mapping[tuple(corner)]
                aug_det_corners.append(new_corner)
            det_corners = np.array(aug_det_corners)

        return img, aug_annot, det_corners
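get_corner_labels builds a binary corner map plus a Gaussian-blurred, max-normalized heatmap that the corner model is trained against. A small standalone sketch of the same recipe on toy corners (independent of the dataset class; the corner values are synthetic):

import numpy as np
from scipy.ndimage import gaussian_filter

image_size = 256
corners = np.array([[30.2, 40.7], [100.0, 180.4]])  # (x, y) corner coordinates

labels = np.zeros((image_size, image_size))
rounded = corners.round()
xint, yint = rounded[:, 0].astype(int), rounded[:, 1].astype(int)
labels[yint, xint] = 1                                # binary per-pixel corner map
gauss_labels = gaussian_filter(labels, sigma=2)
gauss_labels = gauss_labels / gauss_labels.max()      # normalized heatmap target
print(int(labels.sum()), gauss_labels.max())          # -> 2 1.0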
datasets/data_utils.py
ADDED
@@ -0,0 +1,57 @@
from PIL import ImageFilter
from torchvision import transforms
import numpy as np
from utils.nn_utils import positional_encoding_2d
from torch.utils.data.dataloader import default_collate


def RandomBlur(radius=2.):
    blur = GaussianBlur(radius=radius)
    full_transform = transforms.RandomApply([blur], p=.3)
    return full_transform


class ImageFilterTransform(object):

    def __init__(self):
        raise NotImplementedError

    def __call__(self, img):
        return img.filter(self.filter)


class GaussianBlur(ImageFilterTransform):

    def __init__(self, radius=2.):
        self.filter = ImageFilter.GaussianBlur(radius=radius)


def collate_fn(data):
    batched_data = {}
    for field in data[0].keys():
        if field in ['annot', 'rec_mat']:
            batch_values = [item[field] for item in data]
        else:
            batch_values = default_collate([d[field] for d in data])
            if field in ['pixel_features', 'pixel_labels', 'gauss_labels']:
                batch_values = batch_values.float()
        batched_data[field] = batch_values

    return batched_data


def get_pixel_features(image_size, d_pe=128):
    all_pe = positional_encoding_2d(d_pe, image_size, image_size)
    pixels_x = np.arange(0, image_size)
    pixels_y = np.arange(0, image_size)

    xv, yv = np.meshgrid(pixels_x, pixels_y)
    all_pixels = list()
    for i in range(xv.shape[0]):
        pixs = np.stack([xv[i], yv[i]], axis=-1)
        all_pixels.append(pixs)
    pixels = np.stack(all_pixels, axis=0)

    pixel_features = all_pe[:, pixels[:, :, 1], pixels[:, :, 0]]
    pixel_features = pixel_features.permute(1, 2, 0)
    return pixels, pixel_features
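collate_fn keeps the variable-structure fields (annot, rec_mat) as plain Python lists and routes everything else through default_collate, casting the label maps to float tensors. A minimal sketch with dummy samples (the field values here are synthetic):

import numpy as np
from datasets.data_utils import collate_fn

samples = [
    {'annot': {(1.0, 2.0): []}, 'rec_mat': None, 'pixel_labels': np.zeros((4, 4))},
    {'annot': {(3.0, 4.0): []}, 'rec_mat': None, 'pixel_labels': np.ones((4, 4))},
]
batch = collate_fn(samples)
print(type(batch['annot']))          # list: kept per-sample, not stacked
print(batch['pixel_labels'].shape)   # torch.Size([2, 4, 4]), cast to float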
datasets/outdoor_buildings.py
ADDED
@@ -0,0 +1,183 @@
import numpy as np
from datasets.corners import CornersDataset
import os
import skimage
import cv2
from torchvision import transforms
from PIL import Image
from datasets.data_utils import RandomBlur

class OutdoorBuildingDataset(CornersDataset):
    def __init__(self, data_path, det_path, phase='train', image_size=256, rand_aug=True,
                 inference=False):
        super(OutdoorBuildingDataset, self).__init__(image_size, inference)
        self.data_path = data_path
        self.det_path = det_path
        self.phase = phase
        self.rand_aug = rand_aug
        self.image_size = image_size
        self.inference = inference

        blur_transform = RandomBlur()
        self.train_transform = transforms.Compose([
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.3),
            blur_transform])

        if phase == 'train':
            datalistfile = os.path.join(data_path, 'train_list.txt')
            self.training = True
        else:
            datalistfile = os.path.join(data_path, 'valid_list.txt')
            self.training = False
        with open(datalistfile, 'r') as f:
            _data_names = f.readlines()
        if phase == 'train':
            self._data_names = _data_names
        else:
            # based on the data split rule from previous works
            if phase == 'valid':
                self._data_names = _data_names[:50]
            elif phase == 'test':
                self._data_names = _data_names[50:]
            else:
                raise ValueError('Invalid phase {}'.format(phase))

    def __len__(self):
        return len(self._data_names)

    def __getitem__(self, idx):
        data_name = self._data_names[idx][:-1]
        annot_path = os.path.join(self.data_path, 'annot', data_name + '.npy')
        annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()
        det_path = os.path.join(self.det_path, data_name + '.npy')
        det_corners = np.array(np.load(det_path, allow_pickle=True))  # [N, 2]
        det_corners = det_corners[:, ::-1]  # turn into x,y format

        img_path = os.path.join(self.data_path, 'rgb', data_name + '.jpg')
        rgb = cv2.imread(img_path)

        if self.image_size != 256:
            rgb, annot, det_corners = self.resize_data(rgb, annot, det_corners)

        if self.rand_aug:
            image, annot, corner_mapping, det_corners = self.random_aug_annot(rgb, annot, det_corners=det_corners)
        else:
            image = rgb
        rec_mat = None

        corners = np.array(list(annot.keys()))[:, [1, 0]]

        if not self.inference and len(corners) > 100:
            new_idx = np.random.randint(0, len(self))
            return self.__getitem__(new_idx)

        if self.training:
            # Add some randomness for g.t. corners
            corners += np.random.normal(0, 0, size=corners.shape)
            pil_img = Image.fromarray(image)
            image = self.train_transform(pil_img)
            image = np.array(image)
        image = skimage.img_as_float(image)

        # sort by the second value and then the first value, here the corners are in the format of (y, x)
        sort_idx = np.lexsort(corners.T)
        corners = corners[sort_idx]

        corner_list = []
        for corner_i in range(corners.shape[0]):
            corner_list.append((corners[corner_i][1], corners[corner_i][0]))  # to (x, y) format

        raw_data = {
            'name': data_name,
            'corners': corner_list,
            'annot': annot,
            'image': image,
            'rec_mat': rec_mat,
            'annot_path': annot_path,
            'det_path': det_path,
            'img_path': img_path,
        }

        return self.process_data(raw_data)

    def random_aug_annot(self, img, annot, det_corners=None):
        # do random flipping
        img, annot, det_corners = self.random_flip(img, annot, det_corners)

        # prepare random augmentation parameters (only do random rotation for now)
        theta = np.random.randint(0, 360) / 360 * np.pi * 2
        r = self.image_size / 256
        origin = [127 * r, 127 * r]
        p1_new = [127 * r + 100 * np.sin(theta) * r, 127 * r - 100 * np.cos(theta) * r]
        p2_new = [127 * r + 100 * np.cos(theta) * r, 127 * r + 100 * np.sin(theta) * r]
        p1_old = [127 * r, 127 * r - 100 * r]  # y_axis
        p2_old = [127 * r + 100 * r, 127 * r]  # x_axis
        pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32)
        pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32)
        M_rot = cv2.getAffineTransform(pts1, pts2)

        # Combine annotation corners and detection corners
        all_corners = list(annot.keys())
        if det_corners is not None:
            for i in range(det_corners.shape[0]):
                all_corners.append(tuple(det_corners[i]))
        all_corners_ = np.array(all_corners)

        # Do the corner transform within a big matrix transformation
        corner_mapping = dict()
        ones = np.ones([all_corners_.shape[0], 1])
        all_corners_ = np.concatenate([all_corners_, ones], axis=-1)
        aug_corners = np.matmul(M_rot, all_corners_.T).T

        for idx, corner in enumerate(all_corners):
            corner_mapping[corner] = aug_corners[idx]

        # If the transformed geometry goes beyond image boundary, we simply re-do the augmentation
        new_corners = np.array(list(corner_mapping.values()))
        if new_corners.min() <= 0 or new_corners.max() >= (self.image_size - 1):
            # return self.random_aug_annot(img, annot, det_corners)
            return img, annot, None, det_corners

        # build the new annot dict
        aug_annot = dict()
        for corner, connections in annot.items():
            new_corner = corner_mapping[corner]
            tuple_new_corner = tuple(new_corner)
            aug_annot[tuple_new_corner] = list()
            for to_corner in connections:
                aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])

        # Also transform the image correspondingly
        rows, cols, ch = img.shape
        new_img = cv2.warpAffine(img, M_rot, (cols, rows), borderValue=(255, 255, 255))

        y_start = (new_img.shape[0] - self.image_size) // 2
        x_start = (new_img.shape[1] - self.image_size) // 2
        aug_img = new_img[y_start:y_start + self.image_size, x_start:x_start + self.image_size, :]

        if det_corners is None:
            return aug_img, aug_annot, corner_mapping, None
        else:
            aug_det_corners = list()
            for corner in det_corners:
                new_corner = corner_mapping[tuple(corner)]
                aug_det_corners.append(new_corner)
            aug_det_corners = np.array(aug_det_corners)
            return aug_img, aug_annot, corner_mapping, aug_det_corners


if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from datasets.data_utils import collate_fn  # fixed: collate_fn was used below but never imported

    DATAPATH = './data/cities_dataset'
    DET_PATH = './data/det_final'
    train_dataset = OutdoorBuildingDataset(DATAPATH, DET_PATH, phase='train')
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0,
                                  collate_fn=collate_fn)
    for i, item in enumerate(train_dataloader):
        import pdb;
        pdb.set_trace()
        print(item)
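The rotation augmentation in random_aug_annot never builds a rotation matrix explicitly; it maps three reference points (the image centre and one point on each axis) onto their rotated positions and lets cv2.getAffineTransform recover the affine matrix, which is then applied to all corners in homogeneous coordinates. A small sketch of that construction at a fixed angle (the angle and corner values are purely illustrative):

import numpy as np
import cv2

theta = np.deg2rad(30)
origin = [127., 127.]
p1_old = [127., 27.]     # 100 px along the old y-axis
p2_old = [227., 127.]    # 100 px along the old x-axis
p1_new = [127 + 100 * np.sin(theta), 127 - 100 * np.cos(theta)]
p2_new = [127 + 100 * np.cos(theta), 127 + 100 * np.sin(theta)]
M_rot = cv2.getAffineTransform(np.float32([origin, p1_old, p2_old]),
                               np.float32([origin, p1_new, p2_new]))

corner = np.array([200., 127., 1.])   # a corner in homogeneous coordinates
print(M_rot @ corner)                 # the corner rotated by 30 degrees about (127, 127)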
datasets/s3d_floorplans.py
ADDED
@@ -0,0 +1,187 @@
import numpy as np
from datasets.corners import CornersDataset
import os
import skimage
import cv2
import itertools


mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

all_combibations = dict()
for length in range(2, 351):
    ids = np.arange(length)
    combs = np.array(list(itertools.combinations(ids, 2)))
    all_combibations[length] = combs


class S3DFloorplanDataset(CornersDataset):
    def __init__(self, data_path, phase='train', image_size=256, rand_aug=True, inference=False):
        super(S3DFloorplanDataset, self).__init__(image_size, inference)
        self.data_path = data_path
        self.phase = phase
        self.rand_aug = rand_aug

        if phase == 'train':
            datalistfile = os.path.join(data_path, 'train_list.txt')
            self.training = True
        elif phase == 'valid':
            datalistfile = os.path.join(data_path, 'valid_list.txt')
            self.training = False
        else:
            datalistfile = os.path.join(data_path, 'test_list.txt')
            self.training = False
        with open(datalistfile, 'r') as f:
            self._data_names = f.readlines()

    def __len__(self):
        return len(self._data_names)

    def __getitem__(self, idx):
        data_name = self._data_names[idx][:-1]
        annot_path = os.path.join(self.data_path, 'annot', data_name + '.npy')
        annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()

        density_path = os.path.join(self.data_path, 'density', data_name + '.png')
        normal_path = os.path.join(self.data_path, 'normals', data_name + '.png')

        density = cv2.imread(density_path)
        normal = cv2.imread(normal_path)
        rgb = np.maximum(density, normal)

        if self.image_size != 256:
            rgb, annot, det_corners = self.resize_data(rgb, annot, None)

        if self.rand_aug:
            image, annot, _ = self.random_aug_annot(rgb, annot, det_corners=None)
        else:
            image = rgb
        rec_mat = None

        corners = np.array(list(annot.keys()))[:, [1, 0]]

        if not self.inference and len(corners) > 150:
            new_idx = np.random.randint(0, len(self))
            return self.__getitem__(new_idx)

        if self.training:
            # Add some randomness for g.t. corners
            corners += np.random.normal(0, 0, size=corners.shape)

        image = skimage.img_as_float(image)

        # sort by the second value and then the first value, here the corners are in the format of (y, x)
        sort_idx = np.lexsort(corners.T)
        corners = corners[sort_idx]

        corner_list = []
        for corner_i in range(corners.shape[0]):
            corner_list.append((corners[corner_i][1], corners[corner_i][0]))  # to (x, y) format

        raw_data = {
            'name': data_name,
            'corners': corner_list,
            'annot': annot,
            'image': image,
            'rec_mat': rec_mat,
            'annot_path': annot_path,
            'img_path': density_path,
        }

        return self.process_data(raw_data)

    def process_data(self, data):
        img = data['image']
        corners = data['corners']
        annot = data['annot']

        # pre-process the image to use ImageNet-pretrained backbones
        img = img.transpose((2, 0, 1))
        raw_img = img.copy()
        img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
        img = img.astype(np.float32)

        corners = np.array(corners)

        all_data = {
            "annot": annot,
            "name": data['name'],
            'img': img,
            'annot_path': data['annot_path'],
            'img_path': data['img_path'],
            'raw_img': raw_img,
        }

        # corner labels
        if not self.inference:
            pixel_labels, gauss_labels = self.get_corner_labels(corners)
            all_data['pixel_labels'] = pixel_labels
            all_data['gauss_labels'] = gauss_labels

        return all_data

    def random_aug_annot(self, img, annot, det_corners=None):
        # do random flipping
        img, annot, det_corners = self.random_flip(img, annot, det_corners)
        # return img, annot, None

        # prepare random augmentation parameters (only do random rotation for now)
        theta = np.random.randint(0, 360) / 360 * np.pi * 2
        r = self.image_size / 256
        origin = [127 * r, 127 * r]
        p1_new = [127 * r + 100 * np.sin(theta) * r, 127 * r - 100 * np.cos(theta) * r]
        p2_new = [127 * r + 100 * np.cos(theta) * r, 127 * r + 100 * np.sin(theta) * r]
        p1_old = [127 * r, 127 * r - 100 * r]  # y_axis
        p2_old = [127 * r + 100 * r, 127 * r]  # x_axis
        pts1 = np.array([origin, p1_old, p2_old]).astype(np.float32)
        pts2 = np.array([origin, p1_new, p2_new]).astype(np.float32)
        M_rot = cv2.getAffineTransform(pts1, pts2)

        # Combine annotation corners and detection corners
        all_corners = list(annot.keys())
        if det_corners is not None:
            for i in range(det_corners.shape[0]):
                all_corners.append(tuple(det_corners[i]))
        all_corners_ = np.array(all_corners)

        # Do the per-corner transform
        # Done in a big matrix transformation to save processing time.
        corner_mapping = dict()
        ones = np.ones([all_corners_.shape[0], 1])
        all_corners_ = np.concatenate([all_corners_, ones], axis=-1)
        aug_corners = np.matmul(M_rot, all_corners_.T).T

        for idx, corner in enumerate(all_corners):
            corner_mapping[corner] = aug_corners[idx]

        # If the transformed geometry goes beyond image boundary, we simply re-do the augmentation
        new_corners = np.array(list(corner_mapping.values()))
        if new_corners.min() <= 0 or new_corners.max() >= (self.image_size - 1):
            # return self.random_aug_annot(img, annot, det_corners)
            return img, annot, None

        # build the new annot dict
        aug_annot = dict()
        for corner, connections in annot.items():
            new_corner = corner_mapping[corner]
            tuple_new_corner = tuple(new_corner)
            aug_annot[tuple_new_corner] = list()
            for to_corner in connections:
                aug_annot[tuple_new_corner].append(corner_mapping[tuple(to_corner)])

        # Also transform the image correspondingly
        rows, cols, ch = img.shape
        new_img = cv2.warpAffine(img, M_rot, (cols, rows), borderValue=(255, 255, 255))

        y_start = (new_img.shape[0] - self.image_size) // 2
        x_start = (new_img.shape[1] - self.image_size) // 2
        aug_img = new_img[y_start:y_start + self.image_size, x_start:x_start + self.image_size, :]

        return aug_img, aug_annot, None
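For Structured3D floorplans, the network input is formed by fusing the point-density and normal renderings with an element-wise maximum before the usual float conversion and ImageNet normalization. A toy sketch of that fusion step (the random arrays stand in for the density/*.png and normals/*.png images read with cv2.imread):

import numpy as np
import skimage

density = np.random.randint(0, 255, size=(256, 256, 3), dtype=np.uint8)
normal = np.random.randint(0, 255, size=(256, 256, 3), dtype=np.uint8)

rgb = np.maximum(density, normal)   # per-pixel max of the two renderings
image = skimage.img_as_float(rgb)   # to [0, 1] floats, as in __getitem__
print(image.shape, float(image.min()), float(image.max()))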
images/test.jpg
ADDED
infer.py
ADDED
@@ -0,0 +1,455 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets.outdoor_buildings import OutdoorBuildingDataset
from datasets.s3d_floorplans import S3DFloorplanDataset
from datasets.data_utils import collate_fn, get_pixel_features
from models.resnet import ResNetBackbone
from models.corner_models import HeatCorner
from models.edge_models import HeatEdge
from models.corner_to_edge import get_infer_edge_pairs
from utils.geometry_utils import corner_eval
import numpy as np
import cv2
import os
import scipy.ndimage.filters as filters
import matplotlib.pyplot as plt
from metrics.get_metric import compute_metrics, get_recall_and_precision
import skimage
import argparse


def visualize_cond_generation(positive_pixels, confs, image, save_path, gt_corners=None, prec=None, recall=None,
                              image_masks=None, edges=None, edge_confs=None):
    image = image.copy()  # get a new copy of the original image
    if confs is not None:
        viz_confs = confs

    if edges is not None:
        preds = positive_pixels.astype(int)
        c_degrees = dict()
        for edge_i, edge_pair in enumerate(edges):
            conf = (edge_confs[edge_i] * 2) - 1
            cv2.line(image, tuple(preds[edge_pair[0]]), tuple(preds[edge_pair[1]]), (255 * conf, 255 * conf, 0), 2)
            c_degrees[edge_pair[0]] = c_degrees.setdefault(edge_pair[0], 0) + 1
            c_degrees[edge_pair[1]] = c_degrees.setdefault(edge_pair[1], 0) + 1

    for idx, c in enumerate(positive_pixels):
        if edges is not None and idx not in c_degrees:
            continue
        if confs is None:
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
        else:
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255 * viz_confs[idx]), -1)
        # if edges is not None:
        #     cv2.putText(image, '{}'.format(c_degrees[idx]), (int(c[0]), int(c[1] - 5)), cv2.FONT_HERSHEY_SIMPLEX,
        #                 0.5, (255, 0, 0), 1, cv2.LINE_AA)

    if gt_corners is not None:
        for c in gt_corners:
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 255, 0), -1)

    if image_masks is not None:
        mask_ids = np.where(image_masks == 1)[0]
        for mask_id in mask_ids:
            y_idx = mask_id // 64
            x_idx = (mask_id - y_idx * 64)
            x_coord = x_idx * 4
            y_coord = y_idx * 4
            cv2.rectangle(image, (x_coord, y_coord), (x_coord + 3, y_coord + 3), (127, 127, 0), thickness=-1)

    # if confs is not None:
    #     cv2.putText(image, 'max conf: {:.3f}'.format(confs.max()), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
    #                 0.5, (255, 255, 0), 1, cv2.LINE_AA)
    if prec is not None:
        if isinstance(prec, tuple):
            cv2.putText(image, 'edge p={:.2f}, edge r={:.2f}'.format(prec[0], recall[0]), (20, 20),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)
            cv2.putText(image, 'region p={:.2f}, region r={:.2f}'.format(prec[1], recall[1]), (20, 40),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)
        else:
            cv2.putText(image, 'prec={:.2f}, recall={:.2f}'.format(prec, recall), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)
    cv2.imwrite(save_path, image)


def corner_nms(preds, confs, image_size):
    data = np.zeros([image_size, image_size])
    neighborhood_size = 5
    threshold = 0

    for i in range(len(preds)):
        data[preds[i, 1], preds[i, 0]] = confs[i]

    data_max = filters.maximum_filter(data, neighborhood_size)
    maxima = (data == data_max)
    data_min = filters.minimum_filter(data, neighborhood_size)
    diff = ((data_max - data_min) > threshold)
    maxima[diff == 0] = 0

    results = np.where(maxima > 0)
    filtered_preds = np.stack([results[1], results[0]], axis=-1)

    new_confs = list()
    for i, pred in enumerate(filtered_preds):
        new_confs.append(data[pred[1], pred[0]])
    new_confs = np.array(new_confs)

    return filtered_preds, new_confs


def main(dataset, ckpt_path, image_size, viz_base, save_base, infer_times):
    ckpt = torch.load(ckpt_path)
    print('Load from ckpts of epoch {}'.format(ckpt['epoch']))
    ckpt_args = ckpt['args']
    if dataset == 'outdoor':
        data_path = './data/outdoor/cities_dataset'
        det_path = './data/outdoor/det_final'
        test_dataset = OutdoorBuildingDataset(data_path, det_path, phase='test', image_size=image_size, rand_aug=False,
                                              inference=True)
    elif dataset == 's3d_floorplan':
        data_path = './data/s3d_floorplan'
        test_dataset = S3DFloorplanDataset(data_path, phase='test', rand_aug=False, inference=True)
    else:
        raise ValueError('Unknown dataset type: {}'.format(dataset))

    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0,
                                 collate_fn=collate_fn)

    backbone = ResNetBackbone()
    strides = backbone.strides
    num_channels = backbone.num_channels
    backbone = nn.DataParallel(backbone)
    backbone = backbone.cuda()
    backbone.eval()
    corner_model = HeatCorner(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
                              backbone_num_channels=num_channels)
    corner_model = nn.DataParallel(corner_model)
    corner_model = corner_model.cuda()
    corner_model.eval()

    edge_model = HeatEdge(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
                          backbone_num_channels=num_channels)
    edge_model = nn.DataParallel(edge_model)
    edge_model = edge_model.cuda()
    edge_model.eval()

    backbone.load_state_dict(ckpt['backbone'])
    corner_model.load_state_dict(ckpt['corner_model'])
    edge_model.load_state_dict(ckpt['edge_model'])
    print('Loaded saved model from {}'.format(ckpt_path))

    if not os.path.exists(viz_base):
        os.makedirs(viz_base)
    if not os.path.exists(save_base):
        os.makedirs(save_base)

    all_prec = list()
    all_recall = list()

    corner_tp = 0.0
    corner_fp = 0.0
    corner_length = 0.0
    edge_tp = 0.0
    edge_fp = 0.0
    edge_length = 0.0
    region_tp = 0.0
    region_fp = 0.0
    region_length = 0.0

    # get the positional encodings for all pixels
    pixels, pixel_features = get_pixel_features(image_size=image_size)

    for data_i, data in enumerate(test_dataloader):
        image = data['img'].cuda()
        img_path = data['img_path'][0]
        annot_path = data['annot_path'][0]
        annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()

        with torch.no_grad():
            pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np = get_results(image, annot, backbone,
                                                                                        corner_model,
                                                                                        edge_model,
                                                                                        pixels, pixel_features,
                                                                                        ckpt_args, infer_times,
                                                                                        corner_thresh=0.01,
                                                                                        image_size=image_size)

        # viz_image = cv2.imread(img_path)
        positive_pixels = np.array(list(annot.keys())).round()

        viz_image = data['raw_img'][0].cpu().numpy().transpose(1, 2, 0)
        viz_image = (viz_image * 255).astype(np.uint8)

        # visualize G.T.
        gt_path = os.path.join(viz_base, '{}_gt.png'.format(data_i))
        visualize_cond_generation(positive_pixels, None, viz_image, gt_path, gt_corners=None, image_masks=None)

        if len(pred_corners) > 0:
            prec, recall = corner_eval(positive_pixels, pred_corners)
        else:
            prec = recall = 0
        all_prec.append(prec)
        all_recall.append(recall)

        if pred_confs.shape[0] == 0:
            pred_confs = None

        if image_size != 256:
            pred_corners_viz = pred_corners * (image_size / 256)
        else:
            pred_corners_viz = pred_corners
        recon_path = os.path.join(viz_base, '{}_pred_corner.png'.format(data_i))
        visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, recon_path, gt_corners=None, prec=prec,
                                  recall=recall)

        pred_corners, pred_confs, pos_edges = postprocess_preds(pred_corners, pred_confs, pos_edges)

        pred_data = {
            'corners': pred_corners,
            'edges': pos_edges,
        }

        if dataset == 's3d_floorplan':
            save_filename = os.path.basename(annot_path)
            save_npy_path = os.path.join(save_base, save_filename)
            np.save(save_npy_path, pred_data)
        else:
            save_results = {
                'corners': pred_corners,
                'edges': pos_edges,
                'image_path': img_path,
            }
            save_path = os.path.join(save_base, '{}_results.npy'.format(data_i))
            np.save(save_path, save_results)

        gt_data = convert_annot(annot)

        score = compute_metrics(gt_data, pred_data)

        edge_recall, edge_prec = get_recall_and_precision(score['edge_tp'], score['edge_fp'], score['edge_length'])
        region_recall, region_prec = get_recall_and_precision(score['region_tp'], score['region_fp'],
                                                              score['region_length'])
        er_recall = (edge_recall, region_recall)
        er_prec = (edge_prec, region_prec)

        if image_size != 256:
            pred_corners_viz = pred_corners * (image_size / 256)
        else:
            pred_corners_viz = pred_corners
        recon_path = os.path.join(viz_base, '{}_pred_edge.png'.format(data_i))
        visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, recon_path, gt_corners=None, prec=er_prec,
                                  recall=er_recall, edges=pos_edges, edge_confs=edge_confs)
        corner_tp += score['corner_tp']
        corner_fp += score['corner_fp']
        corner_length += score['corner_length']
        edge_tp += score['edge_tp']
        edge_fp += score['edge_fp']
        edge_length += score['edge_length']
        region_tp += score['region_tp']
        region_fp += score['region_fp']
        region_length += score['region_length']

        print('Finish inference for sample No.{}'.format(data_i))
    avg_prec = np.array(all_prec).mean()
    avg_recall = np.array(all_recall).mean()

    recall, precision = get_recall_and_precision(corner_tp, corner_fp, corner_length)
    f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
    print('corners - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))

    # edge
    recall, precision = get_recall_and_precision(edge_tp, edge_fp, edge_length)
    f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
    print('edges - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))

    # region
    recall, precision = get_recall_and_precision(region_tp, region_fp, region_length)
    f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
    print('regions - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))

    print('Avg prec: {}, Avg recall: {}'.format(avg_prec, avg_recall))


def get_results(image, annot, backbone, corner_model, edge_model, pixels, pixel_features,
                args, infer_times, corner_thresh=0.5, image_size=256):
    image_feats, feat_mask, all_image_feats = backbone(image)
    pixel_features = pixel_features.unsqueeze(0).repeat(image.shape[0], 1, 1, 1)
    preds_s1 = corner_model(image_feats, feat_mask, pixel_features, pixels, all_image_feats)

    c_outputs = preds_s1
    # get predicted corners
    c_outputs_np = c_outputs[0].detach().cpu().numpy()
    pos_indices = np.where(c_outputs_np >= corner_thresh)
    pred_corners = pixels[pos_indices]
    pred_confs = c_outputs_np[pos_indices]
    pred_corners, pred_confs = corner_nms(pred_corners, pred_confs, image_size=c_outputs.shape[1])

    pred_corners, pred_confs, edge_coords, edge_mask, edge_ids = get_infer_edge_pairs(pred_corners, pred_confs)

    corner_nums = torch.tensor([len(pred_corners)]).to(image.device)
    max_candidates = torch.stack([corner_nums.max() * args.corner_to_edge_multiplier] * len(corner_nums), dim=0)

    all_pos_ids = set()
    all_edge_confs = dict()

    for tt in range(infer_times):
        if tt == 0:
            gt_values = torch.zeros_like(edge_mask).long()
            gt_values[:, :] = 2

        # run the edge model
        s1_logits, s2_logits_hb, s2_logits_rel, selected_ids, s2_mask, s2_gt_values = edge_model(image_feats, feat_mask,
                                                                                                 pixel_features,
                                                                                                 edge_coords, edge_mask,
                                                                                                 gt_values, corner_nums,
                                                                                                 max_candidates,
                                                                                                 True)
        # do_inference=True)

        num_total = s1_logits.shape[2]
        num_selected = selected_ids.shape[1]
        num_filtered = num_total - num_selected

        s1_preds = s1_logits.squeeze().softmax(0)
        s2_preds_rel = s2_logits_rel.squeeze().softmax(0)
        s2_preds_hb = s2_logits_hb.squeeze().softmax(0)
        s1_preds_np = s1_preds[1, :].detach().cpu().numpy()
        s2_preds_rel_np = s2_preds_rel[1, :].detach().cpu().numpy()
        s2_preds_hb_np = s2_preds_hb[1, :].detach().cpu().numpy()

        selected_ids = selected_ids.squeeze().detach().cpu().numpy()
        if tt != infer_times - 1:
            s2_preds_np = s2_preds_hb_np

            pos_edge_ids = np.where(s2_preds_np >= 0.9)
            neg_edge_ids = np.where(s2_preds_np <= 0.01)
            for pos_id in pos_edge_ids[0]:
                actual_id = selected_ids[pos_id]
                if gt_values[0, actual_id] != 2:
                    continue
                all_pos_ids.add(actual_id)
                all_edge_confs[actual_id] = s2_preds_np[pos_id]
                gt_values[0, actual_id] = 1
            for neg_id in neg_edge_ids[0]:
                actual_id = selected_ids[neg_id]
                if gt_values[0, actual_id] != 2:
                    continue
                gt_values[0, actual_id] = 0
            num_to_pred = (gt_values == 2).sum()
            if num_to_pred <= num_filtered:
                break
        else:
            s2_preds_np = s2_preds_hb_np

            pos_edge_ids = np.where(s2_preds_np >= 0.5)
            for pos_id in pos_edge_ids[0]:
                actual_id = selected_ids[pos_id]
                if s2_mask[0][pos_id] is True or gt_values[0, actual_id] != 2:
                    continue
                all_pos_ids.add(actual_id)
                all_edge_confs[actual_id] = s2_preds_np[pos_id]

    # print('Inference time {}'.format(tt+1))
    pos_edge_ids = list(all_pos_ids)
    edge_confs = [all_edge_confs[idx] for idx in pos_edge_ids]
    pos_edges = edge_ids[pos_edge_ids].cpu().numpy()
    edge_confs = np.array(edge_confs)
|
360 |
+
|
361 |
+
if image_size != 256:
|
362 |
+
pred_corners = pred_corners / (image_size / 256)
|
363 |
+
|
364 |
+
return pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np
|
365 |
+
|
366 |
+
|
367 |
+
def postprocess_preds(corners, confs, edges):
|
368 |
+
corner_degrees = dict()
|
369 |
+
for edge_i, edge_pair in enumerate(edges):
|
370 |
+
corner_degrees[edge_pair[0]] = corner_degrees.setdefault(edge_pair[0], 0) + 1
|
371 |
+
corner_degrees[edge_pair[1]] = corner_degrees.setdefault(edge_pair[1], 0) + 1
|
372 |
+
good_ids = [i for i in range(len(corners)) if i in corner_degrees]
|
373 |
+
if len(good_ids) == len(corners):
|
374 |
+
return corners, confs, edges
|
375 |
+
else:
|
376 |
+
good_corners = corners[good_ids]
|
377 |
+
good_confs = confs[good_ids]
|
378 |
+
id_mapping = {value: idx for idx, value in enumerate(good_ids)}
|
379 |
+
new_edges = list()
|
380 |
+
for edge_pair in edges:
|
381 |
+
new_pair = (id_mapping[edge_pair[0]], id_mapping[edge_pair[1]])
|
382 |
+
new_edges.append(new_pair)
|
383 |
+
new_edges = np.array(new_edges)
|
384 |
+
return good_corners, good_confs, new_edges
|
385 |
+
|
386 |
+
|
387 |
+
def process_image(img):
|
388 |
+
mean = [0.485, 0.456, 0.406]
|
389 |
+
std = [0.229, 0.224, 0.225]
|
390 |
+
img = skimage.img_as_float(img)
|
391 |
+
img = img.transpose((2, 0, 1))
|
392 |
+
img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
|
393 |
+
img = torch.Tensor(img).cuda()
|
394 |
+
img = img.unsqueeze(0)
|
395 |
+
return img
|
396 |
+
|
397 |
+
|
398 |
+
def plot_heatmap(results, filename):
|
399 |
+
# generate 2 2d grids for the x & y bounds
|
400 |
+
# import pdb; pdb.set_trace()
|
401 |
+
y, x = np.meshgrid(np.linspace(0, 255, 256), np.linspace(0, 255, 256))
|
402 |
+
|
403 |
+
z = results[::-1, :]
|
404 |
+
# x and y are bounds, so z should be the value *inside* those bounds.
|
405 |
+
# Therefore, remove the last value from the z array.
|
406 |
+
z = z[:-1, :-1]
|
407 |
+
|
408 |
+
fig, ax = plt.subplots()
|
409 |
+
|
410 |
+
c = ax.pcolormesh(y, x, z, cmap='RdBu', vmin=0, vmax=1)
|
411 |
+
# set the limits of the plot to the limits of the data
|
412 |
+
ax.axis([x.min(), x.max(), y.min(), y.max()])
|
413 |
+
fig.colorbar(c, ax=ax)
|
414 |
+
fig.savefig(filename)
|
415 |
+
plt.close()
|
416 |
+
|
417 |
+
|
418 |
+
def convert_annot(annot):
|
419 |
+
corners = np.array(list(annot.keys()))
|
420 |
+
corners_mapping = {tuple(c): idx for idx, c in enumerate(corners)}
|
421 |
+
edges = set()
|
422 |
+
for corner, connections in annot.items():
|
423 |
+
idx_c = corners_mapping[tuple(corner)]
|
424 |
+
for other_c in connections:
|
425 |
+
idx_other_c = corners_mapping[tuple(other_c)]
|
426 |
+
if (idx_c, idx_other_c) not in edges and (idx_other_c, idx_c) not in edges:
|
427 |
+
edges.add((idx_c, idx_other_c))
|
428 |
+
edges = np.array(list(edges))
|
429 |
+
gt_data = {
|
430 |
+
'corners': corners,
|
431 |
+
'edges': edges
|
432 |
+
}
|
433 |
+
return gt_data
|
434 |
+
|
435 |
+
|
436 |
+
def get_args_parser():
|
437 |
+
parser = argparse.ArgumentParser('Holistic edge attention transformer', add_help=False)
|
438 |
+
parser.add_argument('--dataset', default='outdoor',
|
439 |
+
help='the dataset for experiments, outdoor/s3d_floorplan')
|
440 |
+
parser.add_argument('--checkpoint_path', default='',
|
441 |
+
help='path to the checkpoints of the model')
|
442 |
+
parser.add_argument('--image_size', default=256, type=int)
|
443 |
+
parser.add_argument('--viz_base', default='./results/viz',
|
444 |
+
help='path to save the intermediate visualizations')
|
445 |
+
parser.add_argument('--save_base', default='./results/npy',
|
446 |
+
help='path to save the prediction results in npy files')
|
447 |
+
parser.add_argument('--infer_times', default=3, type=int)
|
448 |
+
return parser
|
449 |
+
|
450 |
+
|
451 |
+
if __name__ == '__main__':
|
452 |
+
parser = argparse.ArgumentParser('HEAT inference', parents=[get_args_parser()])
|
453 |
+
args = parser.parse_args()
|
454 |
+
main(args.dataset, args.checkpoint_path, args.image_size, args.viz_base, args.save_base,
|
455 |
+
infer_times=args.infer_times)
|
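For orientation, a minimal usage sketch of the post-processing step defined above follows; the `infer` module name and the dummy arrays are assumptions for illustration only, not part of the commit.

# --- usage sketch (illustrative, not part of infer.py) ---
# postprocess_preds drops corners that no predicted edge references and remaps edge indices.
import numpy as np
from infer import postprocess_preds  # assumed import path for this file

corners = np.array([[10, 10], [50, 80], [200, 120]])  # candidate corners as (y, x)
confs = np.array([0.90, 0.80, 0.95])
edges = np.array([[0, 2]])  # only corners 0 and 2 are referenced by an edge

corners, confs, edges = postprocess_preds(corners, confs, edges)
print(corners.shape[0], edges.tolist())  # 2 [[0, 1]]: unused corner removed, edge ids remapped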
metrics/get_metric.py
ADDED
@@ -0,0 +1,219 @@
import os
import numpy as np
import pickle
import cv2
from metrics.new_utils import *


class Metric():
    def calc(self, gt_data, conv_data, thresh=8.0, iou_thresh=0.7):
        ### compute corners precision/recall
        gts = gt_data['corners']
        dets = conv_data['corners']

        per_sample_corner_tp = 0.0
        per_sample_corner_fp = 0.0
        per_sample_corner_length = gts.shape[0]
        found = [False] * gts.shape[0]
        c_det_annot = {}

        # for each corner detection
        for i, det in enumerate(dets):
            # get closest gt
            near_gt = [0, 999999.0, (0.0, 0.0)]
            for k, gt in enumerate(gts):
                dist = np.linalg.norm(gt - det)
                if dist < near_gt[1]:
                    near_gt = [k, dist, gt]
            if near_gt[1] <= thresh and not found[near_gt[0]]:
                per_sample_corner_tp += 1.0
                found[near_gt[0]] = True
                c_det_annot[i] = near_gt[0]
            else:
                per_sample_corner_fp += 1.0

        per_corner_score = {
            'recall': per_sample_corner_tp / gts.shape[0],
            'precision': per_sample_corner_tp / (per_sample_corner_tp + per_sample_corner_fp + 1e-8)
        }

        ### compute edges precision/recall
        per_sample_edge_tp = 0.0
        per_sample_edge_fp = 0.0
        edge_corner_annots = gt_data['edges']
        per_sample_edge_length = edge_corner_annots.shape[0]

        false_edge_ids = []
        match_gt_ids = set()

        for l, e_det in enumerate(conv_data['edges']):
            c1, c2 = e_det

            # check if corners are mapped
            if (c1 not in c_det_annot.keys()) or (c2 not in c_det_annot.keys()):
                per_sample_edge_fp += 1.0
                false_edge_ids.append(l)
                continue
            # check hit
            c1_prime = c_det_annot[c1]
            c2_prime = c_det_annot[c2]
            is_hit = False

            for k, e_annot in enumerate(edge_corner_annots):
                c3, c4 = e_annot
                if ((c1_prime == c3) and (c2_prime == c4)) or ((c1_prime == c4) and (c2_prime == c3)):
                    is_hit = True
                    match_gt_ids.add(k)
                    break

            # hit
            if is_hit:
                per_sample_edge_tp += 1.0
            else:
                per_sample_edge_fp += 1.0
                false_edge_ids.append(l)

        per_edge_score = {
            'recall': per_sample_edge_tp / edge_corner_annots.shape[0],
            'precision': per_sample_edge_tp / (per_sample_edge_tp + per_sample_edge_fp + 1e-8)
        }

        # compute regions precision/recall
        conv_mask = render(corners=conv_data['corners'], edges=conv_data['edges'], render_pad=0, edge_linewidth=1)[0]
        conv_mask = 1 - conv_mask
        conv_mask = conv_mask.astype(np.uint8)
        labels, region_mask = cv2.connectedComponents(conv_mask, connectivity=4)

        # cv2.imwrite('mask-pred.png', region_mask.astype(np.uint8) * 20)

        background_label = region_mask[0, 0]
        all_conv_masks = []
        for region_i in range(1, labels):
            if region_i == background_label:
                continue
            the_region = region_mask == region_i
            if the_region.sum() < 20:
                continue
            all_conv_masks.append(the_region)

        gt_mask = render(corners=gt_data['corners'], edges=gt_data['edges'], render_pad=0, edge_linewidth=1)[0]
        gt_mask = 1 - gt_mask
        gt_mask = gt_mask.astype(np.uint8)
        labels, region_mask = cv2.connectedComponents(gt_mask, connectivity=4)

        # cv2.imwrite('mask-gt.png', region_mask.astype(np.uint8) * 20)

        background_label = region_mask[0, 0]
        all_gt_masks = []
        for region_i in range(1, labels):
            if region_i == background_label:
                continue
            the_region = region_mask == region_i
            if the_region.sum() < 20:
                continue
            all_gt_masks.append(the_region)

        per_sample_region_tp = 0.0
        per_sample_region_fp = 0.0
        per_sample_region_length = len(all_gt_masks)
        found = [False] * len(all_gt_masks)
        for i, r_det in enumerate(all_conv_masks):
            # get closest gt
            near_gt = [0, 0, None]
            for k, r_gt in enumerate(all_gt_masks):
                iou = np.logical_and(r_gt, r_det).sum() / float(np.logical_or(r_gt, r_det).sum())
                if iou > near_gt[1]:
                    near_gt = [k, iou, r_gt]
            if near_gt[1] >= iou_thresh and not found[near_gt[0]]:
                per_sample_region_tp += 1.0
                found[near_gt[0]] = True
            else:
                per_sample_region_fp += 1.0

        per_region_score = {
            'recall': per_sample_region_tp / len(all_gt_masks),
            'precision': per_sample_region_tp / (per_sample_region_tp + per_sample_region_fp + 1e-8)
        }

        return {
            'corner_tp': per_sample_corner_tp,
            'corner_fp': per_sample_corner_fp,
            'corner_length': per_sample_corner_length,
            'edge_tp': per_sample_edge_tp,
            'edge_fp': per_sample_edge_fp,
            'edge_length': per_sample_edge_length,
            'region_tp': per_sample_region_tp,
            'region_fp': per_sample_region_fp,
            'region_length': per_sample_region_length,
            'corner': per_corner_score,
            'edge': per_edge_score,
            'region': per_region_score
        }


def compute_metrics(gt_data, pred_data):
    metric = Metric()
    score = metric.calc(gt_data, pred_data)
    return score


def get_recall_and_precision(tp, fp, length):
    recall = tp / (length + 1e-8)
    precision = tp / (tp + fp + 1e-8)
    return recall, precision


if __name__ == '__main__':
    base_path = './'
    gt_datapath = '../data/cities_dataset/annot'
    metric = Metric()
    corner_tp = 0.0
    corner_fp = 0.0
    corner_length = 0.0
    edge_tp = 0.0
    edge_fp = 0.0
    edge_length = 0.0
    region_tp = 0.0
    region_fp = 0.0
    region_length = 0.0
    for file_name in os.listdir(base_path):
        if len(file_name) < 10:
            continue
        f = open(os.path.join(base_path, file_name), 'rb')
        gt_data = np.load(os.path.join(gt_datapath, file_name + '.npy'), allow_pickle=True).tolist()
        candidate = pickle.load(f)
        conv_corners = candidate.graph.getCornersArray()
        conv_edges = candidate.graph.getEdgesArray()
        conv_data = {'corners': conv_corners, 'edges': conv_edges}
        score = metric.calc(gt_data, conv_data)
        corner_tp += score['corner_tp']
        corner_fp += score['corner_fp']
        corner_length += score['corner_length']
        edge_tp += score['edge_tp']
        edge_fp += score['edge_fp']
        edge_length += score['edge_length']
        region_tp += score['region_tp']
        region_fp += score['region_fp']
        region_length += score['region_length']

    f = open(os.path.join(base_path, 'score.txt'), 'w')
    # corner
    recall, precision = get_recall_and_precision(corner_tp, corner_fp, corner_length)
    f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
    print('corners - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
    f.write('corners - precision: %.3f recall: %.3f f_score: %.3f\n' % (precision, recall, f_score))

    # edge
    recall, precision = get_recall_and_precision(edge_tp, edge_fp, edge_length)
    f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
    print('edges - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
    f.write('edges - precision: %.3f recall: %.3f f_score: %.3f\n' % (precision, recall, f_score))

    # region
    recall, precision = get_recall_and_precision(region_tp, region_fp, region_length)
    f_score = 2.0 * precision * recall / (recall + precision + 1e-8)
    print('regions - precision: %.3f recall: %.3f f_score: %.3f' % (precision, recall, f_score))
    f.write('regions - precision: %.3f recall: %.3f f_score: %.3f\n' % (precision, recall, f_score))

    f.close()
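As a quick illustration of how the accumulated counts become the reported scores, a small sketch follows; `f1_from_counts` is a hypothetical helper that simply restates get_recall_and_precision plus the F-score formula used above.

# --- usage sketch (illustrative, not part of get_metric.py) ---
def f1_from_counts(tp, fp, length):
    # same formulas as get_recall_and_precision and the f_score lines above
    recall = tp / (length + 1e-8)
    precision = tp / (tp + fp + 1e-8)
    return 2.0 * precision * recall / (recall + precision + 1e-8)

print('corners - f_score: %.3f' % f1_from_counts(tp=90.0, fp=10.0, length=100.0))  # ~0.900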
metrics/new_utils.py
ADDED
@@ -0,0 +1,2100 @@
import numpy as np
import matplotlib.pyplot as plt
import cv2
import threading
import os
import skimage
import random
import time

TWO_CORNER_MINIMUM_DISTANCE = 5
SAFE_NUM = 3
score_weights = (1., 2., 100.)


#########################################################################################
################################# General Functions #####################################
#########################################################################################
def swap_two_corner_place(corners, edges, id1, id2):
    for edge_i in range(edges.shape[0]):
        if edges[edge_i, 0] == id1:
            edges[edge_i, 0] = id2
        elif edges[edge_i, 0] == id2:
            edges[edge_i, 0] = id1
        if edges[edge_i, 1] == id1:
            edges[edge_i, 1] = id2
        elif edges[edge_i, 1] == id2:
            edges[edge_i, 1] = id1
    temp = corners[id1].copy()
    corners[id1] = corners[id2]
    corners[id2] = temp
    return corners, edges


def get_neighbor_corner_id(corner_id, edges):
    where = np.where(edges == corner_id)
    return edges[where[0], 1 - where[1]]


def swap_two_edge_place(edges, id1, id2):
    temp = edges[id1].copy()
    edges[id1] = edges[id2]
    edges[id2] = temp
    return edges


def degree_of_three_corners(cornerA, cornerB, cornerM):
    # cornerM is middle corner
    AM_length = l2_distance(cornerA, cornerM)
    BM_length = l2_distance(cornerB, cornerM)
    dot = np.dot((cornerA[0] - cornerM[0], cornerA[1] - cornerM[1]),
                 (cornerB[0] - cornerM[0], cornerB[1] - cornerM[1]))
    cos = dot / (AM_length + 1e-8) / (BM_length + 1e-8)
    cos = min(1, max(-1, cos))
    degree = np.arccos(cos)
    return degree / np.pi * 180


def sort_graph(corners, edges):
    corners = corners.copy()
    edges = edges.copy()
    for corner_i in range(corners.shape[0]):
        min_id = -1
        min_pos = corners[corner_i]
        for corner_j in range(corner_i + 1, corners.shape[0]):
            if (corners[corner_j, 0] < min_pos[0]) or \
                    (corners[corner_j, 0] == min_pos[0] and corners[corner_j, 1] < min_pos[1]):
                min_pos = corners[corner_j]
                min_id = corner_j
        if min_id != -1:
            corners, edges = swap_two_corner_place(corners, edges, corner_i, min_id)

    for edge_i in range(edges.shape[0]):
        if edges[edge_i, 0] > edges[edge_i, 1]:
            temp = edges[edge_i, 0]
            edges[edge_i, 0] = edges[edge_i, 1]
            edges[edge_i, 1] = temp

    for edge_i in range(edges.shape[0]):
        min_id = -1
        min_pos = edges[edge_i]
        for edge_j in range(edge_i + 1, edges.shape[0]):
            if (edges[edge_j, 0] < min_pos[0]) or \
                    (edges[edge_j, 0] == min_pos[0] and edges[edge_j, 1] < min_pos[1]):
                min_pos = edges[edge_j]
                min_id = edge_j
        if min_id != -1:
            edges = swap_two_edge_place(edges, edge_i, min_id)

    return corners, edges


def IOU(maskA, maskB):
    return np.logical_and(maskA, maskB).sum() / np.logical_or(maskA, maskB).sum()


def render(corners, edges, render_pad=0, edge_linewidth=2, corner_size=3, scale=1.):
    size = int(256 * scale)
    mask = np.ones((2, size, size)) * render_pad

    corners = np.round(corners.copy() * scale).astype(np.int)
    for edge_i in range(edges.shape[0]):
        a = edges[edge_i, 0]
        b = edges[edge_i, 1]
        mask[0] = cv2.line(mask[0], (int(corners[a, 1]), int(corners[a, 0])),
                           (int(corners[b, 1]), int(corners[b, 0])), 1.0, thickness=edge_linewidth)
    for corner_i in range(corners.shape[0]):
        mask[1] = cv2.circle(mask[1], (int(corners[corner_i, 1]), int(corners[corner_i, 0])), corner_size, 1.0, -1)

    return mask


def patch_samples(edge_num, batch_size):
    num = edge_num // batch_size
    patchs = []
    for i in range(num):
        patchs.append([i * batch_size + j for j in range(batch_size)])

    if edge_num % batch_size != 0:
        patchs.append([j for j in range(batch_size * num, edge_num)])

    return patchs


def l2_distance(x1, x2):
    return np.sqrt((x1[0] - x2[0]) ** 2 + (x1[1] - x2[1]) ** 2)


def triangle_region(A, B, C):
    l1 = np.linalg.norm(np.array(A) - np.array(B))
    l2 = np.linalg.norm(np.array(A) - np.array(C))
    l3 = np.linalg.norm(np.array(B) - np.array(C))
    p = (l1 + l2 + l3) / 2
    area = np.sqrt(np.abs(p * (p - l1) * (p - l2) * (p - l3)))
    return area

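A quick sanity check for the two geometric helpers above, assuming it is run in the same module; a 3-4-5 right triangle should give a hypotenuse of 5 and an area of 6 via Heron's formula.

# --- sanity check (illustrative, not part of new_utils.py) ---
print(l2_distance((0, 0), (3, 4)))              # 5.0, Euclidean distance
print(triangle_region((0, 0), (3, 0), (0, 4)))  # 6.0, triangle area from Heron's formula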
137 |
+
def remove_intersection_and_duplicate(corners, edges, name):
|
138 |
+
over_all_flag = False
|
139 |
+
ori_corners = corners.copy()
|
140 |
+
ori_edges = edges.copy()
|
141 |
+
while True:
|
142 |
+
flag = False
|
143 |
+
for edge_i in range(edges.shape[0]):
|
144 |
+
for edge_j in range(edge_i + 1, edges.shape[0]):
|
145 |
+
corner11 = corners[edges[edge_i, 0]]
|
146 |
+
corner12 = corners[edges[edge_i, 1]]
|
147 |
+
corner21 = corners[edges[edge_j, 0]]
|
148 |
+
corner22 = corners[edges[edge_j, 1]]
|
149 |
+
|
150 |
+
y1 = corner11[0]
|
151 |
+
x1 = corner11[1]
|
152 |
+
y2 = corner12[0]
|
153 |
+
x2 = corner12[1]
|
154 |
+
a1 = y1 - y2
|
155 |
+
b1 = x2 - x1
|
156 |
+
c1 = x1 * y2 - x2 * y1
|
157 |
+
flag1 = (a1 * corner21[1] + b1 * corner21[0] + c1) * (a1 * corner22[1] + b1 * corner22[0] + c1)
|
158 |
+
|
159 |
+
y1 = corner21[0]
|
160 |
+
x1 = corner21[1]
|
161 |
+
y2 = corner22[0]
|
162 |
+
x2 = corner22[1]
|
163 |
+
a2 = y1 - y2
|
164 |
+
b2 = x2 - x1
|
165 |
+
c2 = x1 * y2 - x2 * y1
|
166 |
+
flag2 = (a2 * corner11[1] + b2 * corner11[0] + c2) * (a2 * corner12[1] + b2 * corner12[0] + c2)
|
167 |
+
|
168 |
+
if flag1 < -1e-5 and flag2 < -1e-5:
|
169 |
+
# intersection!
|
170 |
+
over_all_flag = True
|
171 |
+
flag = True
|
172 |
+
|
173 |
+
new_x = (c2 * b1 - c1 * b2) / (a1 * b2 - a2 * b1)
|
174 |
+
new_y = (a2 * c1 - a1 * c2) / (a1 * b2 - a2 * b1)
|
175 |
+
|
176 |
+
temp_d = 3
|
177 |
+
temp_id = -1
|
178 |
+
if l2_distance((new_y, new_x), corner11) < temp_d:
|
179 |
+
temp_id = edges[edge_i, 0]
|
180 |
+
temp_d = l2_distance((new_y, new_x), corner11)
|
181 |
+
if l2_distance((new_y, new_x), corner12) < temp_d:
|
182 |
+
temp_id = edges[edge_i, 1]
|
183 |
+
temp_d = l2_distance((new_y, new_x), corner12)
|
184 |
+
if l2_distance((new_y, new_x), corner21) < temp_d:
|
185 |
+
temp_id = edges[edge_j, 0]
|
186 |
+
temp_d = l2_distance((new_y, new_x), corner21)
|
187 |
+
if l2_distance((new_y, new_x), corner22) < temp_d:
|
188 |
+
temp_id = edges[edge_j, 1]
|
189 |
+
temp_d = l2_distance((new_y, new_x), corner22)
|
190 |
+
if temp_id != -1:
|
191 |
+
if edges[edge_i, 0] != temp_id and edges[edge_i, 1] != temp_id:
|
192 |
+
tt = edges[edge_i, 0]
|
193 |
+
edges[edge_i, 0] = temp_id
|
194 |
+
edges = np.append(edges, np.array([(temp_id, tt)]), 0)
|
195 |
+
if edges[edge_j, 0] != temp_id and edges[edge_j, 1] != temp_id:
|
196 |
+
tt = edges[edge_j, 0]
|
197 |
+
edges[edge_j, 0] = temp_id
|
198 |
+
edges = np.append(edges, np.array([(temp_id, tt)]), 0)
|
199 |
+
else:
|
200 |
+
corners = np.append(corners, np.array([(new_y, new_x)]), 0)
|
201 |
+
edge_id1 = edges[edge_i, 1]
|
202 |
+
edge_id2 = edges[edge_j, 1]
|
203 |
+
edges[edge_i, 1] = corners.shape[0] - 1
|
204 |
+
edges[edge_j, 1] = corners.shape[0] - 1
|
205 |
+
edges = np.append(edges, np.array([(edge_id1, corners.shape[0] - 1)]), 0)
|
206 |
+
edges = np.append(edges, np.array([(edge_id2, corners.shape[0] - 1)]), 0)
|
207 |
+
break
|
208 |
+
if flag:
|
209 |
+
break
|
210 |
+
if flag:
|
211 |
+
continue
|
212 |
+
break
|
213 |
+
|
214 |
+
# remove duplicate and zero degree
|
215 |
+
graph = Graph(np.round(corners), edges)
|
216 |
+
for corner_i in reversed(range(len(graph.getCorners()))):
|
217 |
+
corner_ele1 = graph.getCorners()[corner_i]
|
218 |
+
for corner_j in reversed(range(corner_i)):
|
219 |
+
corner_ele2 = graph.getCorners()[corner_j]
|
220 |
+
if l2_distance(corner_ele1.x, corner_ele2.x) < 3:
|
221 |
+
connected_edge = graph.getEdgeConnected(corner_ele1)
|
222 |
+
for edge_ele in connected_edge:
|
223 |
+
if edge_ele.x[0] == corner_ele1:
|
224 |
+
another = edge_ele.x[1]
|
225 |
+
else:
|
226 |
+
another = edge_ele.x[0]
|
227 |
+
if another == corner_ele2:
|
228 |
+
graph.remove(edge_ele)
|
229 |
+
edge_ele.x = (another, corner_ele2)
|
230 |
+
graph.remove(corner_ele1)
|
231 |
+
for corner_ele in graph.getCorners():
|
232 |
+
if graph.getCornerDegree(corner_ele) == 0:
|
233 |
+
graph.remove(corner_ele)
|
234 |
+
|
235 |
+
corners = graph.getCornersArray()
|
236 |
+
edges = graph.getEdgesArray()
|
237 |
+
# if over_all_flag:
|
238 |
+
# plt.subplot(121)
|
239 |
+
# ori = render(ori_corners, ori_edges, edge_linewidth=1, corner_size=1)
|
240 |
+
# temp = np.concatenate((ori.transpose((1,2,0)), np.zeros((ori.shape[1],ori.shape[2],1))),2)
|
241 |
+
# plt.imshow(temp)
|
242 |
+
# plt.subplot(122)
|
243 |
+
# new_ = render(corners, edges, edge_linewidth=1, corner_size=1)
|
244 |
+
# temp = np.concatenate((new_.transpose((1,2,0)), np.zeros((new_.shape[1],new_.shape[2],1))),2)
|
245 |
+
# plt.imshow(temp)
|
246 |
+
# plt.show()
|
247 |
+
|
248 |
+
return corners, edges
|
249 |
+
|
250 |
+
|
251 |
+
def get_two_edge_intersection_location(corner11, corner12, corner21, corner22):
|
252 |
+
y1 = corner11[0]
|
253 |
+
x1 = corner11[1]
|
254 |
+
y2 = corner12[0]
|
255 |
+
x2 = corner12[1]
|
256 |
+
a1 = y1 - y2
|
257 |
+
b1 = x2 - x1
|
258 |
+
c1 = x1 * y2 - x2 * y1
|
259 |
+
|
260 |
+
y1 = corner21[0]
|
261 |
+
x1 = corner21[1]
|
262 |
+
y2 = corner22[0]
|
263 |
+
x2 = corner22[1]
|
264 |
+
a2 = y1 - y2
|
265 |
+
b2 = x2 - x1
|
266 |
+
c2 = x1 * y2 - x2 * y1
|
267 |
+
|
268 |
+
l = a1 * b2 - a2 * b1
|
269 |
+
if l == 0:
|
270 |
+
l = 1e-5
|
271 |
+
|
272 |
+
new_x = (c2 * b1 - c1 * b2) / l
|
273 |
+
new_y = (a2 * c1 - a1 * c2) / l
|
274 |
+
|
275 |
+
return round(new_y), round(new_x)
|
276 |
+
|
277 |
+
|
278 |
+
def get_distance_of_corner_and_edge(corner1, corner2, corner):
|
279 |
+
x = corner[0]
|
280 |
+
y = corner[1]
|
281 |
+
x1 = corner1[0]
|
282 |
+
y1 = corner1[1]
|
283 |
+
x2 = corner2[0]
|
284 |
+
y2 = corner2[1]
|
285 |
+
|
286 |
+
cross = (x2 - x1) * (x - x1) + (y2 - y1) * (y - y1)
|
287 |
+
if cross <= 0:
|
288 |
+
# dist to corner1
|
289 |
+
return np.linalg.norm((x - x1, y - y1))
|
290 |
+
|
291 |
+
d2 = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)
|
292 |
+
if cross >= d2:
|
293 |
+
# dist to corner2
|
294 |
+
return np.linalg.norm((x - x2, y - y2))
|
295 |
+
|
296 |
+
r = cross / d2
|
297 |
+
px = x1 + (x2 - x1) * r
|
298 |
+
py = y1 + (y2 - y1) * r
|
299 |
+
return np.linalg.norm((x - px, y - py))
|
300 |
+
|
301 |
+
|
302 |
+
#########################################################################################
|
303 |
+
################################# Dataset Functions #####################################
|
304 |
+
#########################################################################################
|
305 |
+
def EuclideanDistance(A, B):
|
306 |
+
BT = B.transpose()
|
307 |
+
vecProd = np.dot(A, BT)
|
308 |
+
|
309 |
+
SqA = A ** 2
|
310 |
+
sumSqA = np.matrix(np.sum(SqA, axis=1))
|
311 |
+
sumSqAEx = np.tile(sumSqA.transpose(), (1, vecProd.shape[1]))
|
312 |
+
|
313 |
+
SqB = B ** 2
|
314 |
+
sumSqB = np.sum(SqB, axis=1)
|
315 |
+
sumSqBEx = np.tile(sumSqB, (vecProd.shape[0], 1))
|
316 |
+
SqED = sumSqBEx + sumSqAEx - 2 * vecProd
|
317 |
+
SqED[SqED < 0] = 0.0
|
318 |
+
ED = np.sqrt(SqED)
|
319 |
+
return ED
|
320 |
+
|
321 |
+
|
322 |
+
def samedirection(conv_corner_id, gt_corner_id, conv_corners, gt_corners, conv_edges, gt_edges):
|
323 |
+
# degree
|
324 |
+
if np.where(conv_edges == conv_corner_id)[0].shape[0] != np.where(gt_edges == gt_corner_id)[0].shape[0]:
|
325 |
+
return False
|
326 |
+
|
327 |
+
# direction
|
328 |
+
place = np.where(conv_edges == conv_corner_id)
|
329 |
+
neighbor_id = conv_edges[place[0], 1 - place[1]]
|
330 |
+
|
331 |
+
distance = conv_corners[conv_corner_id] - conv_corners[neighbor_id]
|
332 |
+
direction = np.arctan2(distance[:, 0], distance[:, 1]) * 180 / np.pi / 15
|
333 |
+
direction = (direction + 24) % 24
|
334 |
+
|
335 |
+
conv_dir = np.sort(direction)
|
336 |
+
|
337 |
+
place = np.where(gt_edges == gt_corner_id)
|
338 |
+
neighbor_id = gt_edges[place[0], 1 - place[1]]
|
339 |
+
|
340 |
+
distance = gt_corners[gt_corner_id] - gt_corners[neighbor_id]
|
341 |
+
direction = np.arctan2(distance[:, 0], distance[:, 1]) * 180 / np.pi / 15
|
342 |
+
direction = (direction + 24) % 24
|
343 |
+
|
344 |
+
gt_dir = np.sort(direction)
|
345 |
+
|
346 |
+
conv_dir = list(conv_dir)
|
347 |
+
gt_dir = list(gt_dir)
|
348 |
+
for angle in gt_dir:
|
349 |
+
temp = sorted(conv_dir, key=lambda x: min(np.abs(x - angle), 24 - np.abs(x - angle)))
|
350 |
+
if min(np.abs(temp[0] - angle), 24 - np.abs(temp[0] - angle)) <= 1.3:
|
351 |
+
conv_dir.remove(temp[0])
|
352 |
+
else:
|
353 |
+
return False
|
354 |
+
return True
|
355 |
+
|
356 |
+
|
357 |
+
def simplify_gt(gt_match_location, gt_corner, gt_edge):
|
358 |
+
graph = Graph(np.round(gt_corner), gt_edge)
|
359 |
+
for idx, corner in enumerate(graph.getCorners()):
|
360 |
+
# use score to store the matching info
|
361 |
+
corner.store_score(gt_match_location[idx])
|
362 |
+
|
363 |
+
for idx, corner in enumerate(graph.getCorners()):
|
364 |
+
if corner.get_score() is None:
|
365 |
+
connected_edges = graph.getEdgeConnected(corner)
|
366 |
+
neighbor_corners = []
|
367 |
+
for edge in connected_edges:
|
368 |
+
if edge.x[0] != corner:
|
369 |
+
neighbor_corners.append(edge.x[0])
|
370 |
+
continue
|
371 |
+
if edge.x[1] != corner:
|
372 |
+
neighbor_corners.append(edge.x[1])
|
373 |
+
continue
|
374 |
+
raise BaseException()
|
375 |
+
neighbor_corners = sorted(neighbor_corners, key=lambda ele: l2_distance(ele.x, corner.x))
|
376 |
+
for neighbor_ele in neighbor_corners:
|
377 |
+
if l2_distance(neighbor_ele.x, corner.x) > 8:
|
378 |
+
break
|
379 |
+
if neighbor_ele.get_score() is None:
|
380 |
+
continue
|
381 |
+
# find the suitable neighbor that replace corner
|
382 |
+
for ele in neighbor_corners:
|
383 |
+
if ele == neighbor_ele:
|
384 |
+
continue
|
385 |
+
graph.add_edge(ele, neighbor_ele)
|
386 |
+
neighbor_ele.x = (0.7 * neighbor_ele.x[0] + 0.3 * corner.x[0],
|
387 |
+
0.7 * neighbor_ele.x[1] + 0.3 * corner.x[1])
|
388 |
+
graph.remove(corner)
|
389 |
+
break
|
390 |
+
return graph.getCornersArray(), graph.getEdgesArray()
|
391 |
+
|
392 |
+
|
393 |
+
def get_wrong_corners(corners, gt_corners, edges, gt_edges):
|
394 |
+
corners = corners.copy()
|
395 |
+
gt_corners = gt_corners.copy()
|
396 |
+
edges = edges.copy()
|
397 |
+
gt_edges = gt_edges.copy()
|
398 |
+
dist_matrix = EuclideanDistance(gt_corners, corners)
|
399 |
+
assigned_id = set()
|
400 |
+
gt_match_same_degree = []
|
401 |
+
gt_match_location = []
|
402 |
+
for gt_i in range(gt_corners.shape[0]):
|
403 |
+
sort_id = np.argsort(dist_matrix[gt_i]).__array__()[0]
|
404 |
+
flag = True
|
405 |
+
for id_ in sort_id:
|
406 |
+
if dist_matrix[gt_i, id_] > 7:
|
407 |
+
break
|
408 |
+
temete = samedirection(id_, gt_i, corners, gt_corners, edges, gt_edges)
|
409 |
+
if temete == False:
|
410 |
+
break
|
411 |
+
elif id_ not in assigned_id:
|
412 |
+
assigned_id.add(id_)
|
413 |
+
gt_match_same_degree.append(id_)
|
414 |
+
flag = False
|
415 |
+
break
|
416 |
+
if flag:
|
417 |
+
gt_match_same_degree.append(None)
|
418 |
+
|
419 |
+
matched = []
|
420 |
+
gt_match_location = [None for _ in range(gt_corners.shape[0])]
|
421 |
+
for gt_i in sorted(list(range(gt_corners.shape[0])), key=lambda i: np.min(dist_matrix[i])):
|
422 |
+
sort_id = np.argsort(dist_matrix[gt_i]).__array__()[0]
|
423 |
+
if dist_matrix[gt_i, sort_id[0]] > 7:
|
424 |
+
gt_match_location[gt_i] = None
|
425 |
+
else:
|
426 |
+
for c_i in sort_id:
|
427 |
+
if c_i in matched:
|
428 |
+
continue
|
429 |
+
if dist_matrix[gt_i, c_i] > 7:
|
430 |
+
gt_match_location[gt_i] = None
|
431 |
+
break
|
432 |
+
else:
|
433 |
+
gt_match_location[gt_i] = c_i
|
434 |
+
matched.append(c_i)
|
435 |
+
break
|
436 |
+
|
437 |
+
return set(range(corners.shape[0])) - assigned_id, gt_match_same_degree, gt_match_location
|
438 |
+
|
439 |
+
|
440 |
+
def get_wrong_edges(corners, gt_corners, edges, gt_edges, gt_match):
|
441 |
+
edges = edges.copy()
|
442 |
+
gt_edges = gt_edges.copy()
|
443 |
+
|
444 |
+
all_possible_good_edges = []
|
445 |
+
for edge_i in range(gt_edges.shape[0]):
|
446 |
+
if gt_match[gt_edges[edge_i, 0]] is None or gt_match[gt_edges[edge_i, 1]] is None:
|
447 |
+
continue
|
448 |
+
all_possible_good_edges.append((gt_match[gt_edges[edge_i, 0]], gt_match[gt_edges[edge_i, 1]]))
|
449 |
+
false_edge_id = []
|
450 |
+
for edge_i in range(edges.shape[0]):
|
451 |
+
id1 = edges[edge_i][0]
|
452 |
+
id2 = edges[edge_i][1]
|
453 |
+
if (id1, id2) not in all_possible_good_edges and (id2, id1) not in all_possible_good_edges:
|
454 |
+
false_edge_id.append(edge_i)
|
455 |
+
continue
|
456 |
+
|
457 |
+
return false_edge_id
|
458 |
+
|
459 |
+
|
460 |
+
def get_corner_bin_map(corners, corner_list_for_each_bin, bin_size=10):
|
461 |
+
bin_map = np.zeros((bin_size, 256, 256))
|
462 |
+
for bin_i in range(bin_size):
|
463 |
+
bin_map[bin_i] = render(corners[corner_list_for_each_bin[bin_i]], np.array([]), render_pad=0)[1]
|
464 |
+
return bin_map
|
465 |
+
|
466 |
+
|
467 |
+
#########################################################################################
|
468 |
+
################################ Searching Functions ####################################
|
469 |
+
#########################################################################################
|
470 |
+
def visualization(candidate, show=True):
|
471 |
+
corners = candidate.graph.getCornersArray()
|
472 |
+
edges = candidate.graph.getEdgesArray()
|
473 |
+
mask = render(corners, edges)
|
474 |
+
mask = np.transpose(np.concatenate((mask, np.zeros((1, 256, 256))), 0), (1, 2, 0))
|
475 |
+
plt.imshow(mask)
|
476 |
+
if show:
|
477 |
+
plt.show()
|
478 |
+
|
479 |
+
|
480 |
+
def check_intersection(edge1, edge2):
|
481 |
+
corner11 = edge1.x[0].x
|
482 |
+
corner12 = edge1.x[1].x
|
483 |
+
corner21 = edge2.x[0].x
|
484 |
+
corner22 = edge2.x[1].x
|
485 |
+
|
486 |
+
y1 = corner11[0]
|
487 |
+
x1 = corner11[1]
|
488 |
+
y2 = corner12[0]
|
489 |
+
x2 = corner12[1]
|
490 |
+
a = y1 - y2
|
491 |
+
b = x2 - x1
|
492 |
+
c = x1 * y2 - x2 * y1
|
493 |
+
flag1 = (a * corner21[1] + b * corner21[0] + c) * (a * corner22[1] + b * corner22[0] + c)
|
494 |
+
|
495 |
+
y1 = corner21[0]
|
496 |
+
x1 = corner21[1]
|
497 |
+
y2 = corner22[0]
|
498 |
+
x2 = corner22[1]
|
499 |
+
a = y1 - y2
|
500 |
+
b = x2 - x1
|
501 |
+
c = x1 * y2 - x2 * y1
|
502 |
+
flag2 = (a * corner11[1] + b * corner11[0] + c) * (a * corner12[1] + b * corner12[0] + c)
|
503 |
+
|
504 |
+
if flag1 < -1e-6 and flag2 < -1e-6:
|
505 |
+
return True
|
506 |
+
|
507 |
+
return False
|
508 |
+
|
509 |
+
|
510 |
+
def adding_a_corner_by_triangle_operation(candidate):
|
511 |
+
new_candidates = []
|
512 |
+
name = candidate.name
|
513 |
+
gt_mask = region_cache.get_region(name)
|
514 |
+
gt_mask = gt_mask > 0.4
|
515 |
+
gt_mask_grow = cv2.dilate(gt_mask.astype(np.float64), np.ones((3, 3), np.uint8), iterations=6) > 0
|
516 |
+
|
517 |
+
# get the current candidate region mask
|
518 |
+
conv_mask = render(corners=candidate.graph.getCornersArray(), edges=candidate.graph.getEdgesArray(),
|
519 |
+
render_pad=0, edge_linewidth=1)[0]
|
520 |
+
conv_mask = 1 - conv_mask
|
521 |
+
conv_mask = conv_mask.astype(np.uint8)
|
522 |
+
labels, region_mask = cv2.connectedComponents(conv_mask, connectivity=4)
|
523 |
+
|
524 |
+
background_label = region_mask[0, 0]
|
525 |
+
all_masks = []
|
526 |
+
for region_i in range(1, labels):
|
527 |
+
if region_i == background_label:
|
528 |
+
continue
|
529 |
+
the_region = region_mask == region_i
|
530 |
+
if the_region.sum() < 20:
|
531 |
+
continue
|
532 |
+
all_masks.append(the_region)
|
533 |
+
|
534 |
+
candidate_mask = (np.sum(all_masks, 0) + (1 - conv_mask)) > 0
|
535 |
+
|
536 |
+
final_mask = np.logical_xor(gt_mask_grow, np.logical_and(candidate_mask, gt_mask_grow))
|
537 |
+
|
538 |
+
for corner_i in range(random.randint(0, 16), 256, 16):
|
539 |
+
for corner_j in range(random.randint(0, 16), 256, 16):
|
540 |
+
if candidate.addable((corner_i, corner_j)):
|
541 |
+
if final_mask[corner_i, corner_j] == True: # inside the region
|
542 |
+
new_corner = Element((corner_i, corner_j))
|
543 |
+
new_candidate = candidate.generate_new_candidate_add_a_corner(new_corner)
|
544 |
+
new_graph = new_candidate.graph
|
545 |
+
corners = new_graph.getCorners()
|
546 |
+
|
547 |
+
# find two suitable existed corners to make into a triangle (no intersection and no colinear)
|
548 |
+
for id_A in range(len(corners)):
|
549 |
+
ele_A = corners[id_A]
|
550 |
+
if ele_A == new_corner:
|
551 |
+
continue
|
552 |
+
for id_B in range(id_A + 1, len(corners)):
|
553 |
+
ele_B = corners[id_B]
|
554 |
+
if ele_B == new_corner:
|
555 |
+
continue
|
556 |
+
if new_graph.has_edge(new_corner, ele_A) is not None:
|
557 |
+
raise BaseException('should not have edge in this case')
|
558 |
+
if new_graph.has_edge(new_corner, ele_B) is not None:
|
559 |
+
raise BaseException('should not have edge in this case')
|
560 |
+
temp_edge1 = Element((new_corner, ele_A))
|
561 |
+
temp_edge2 = Element((new_corner, ele_B))
|
562 |
+
|
563 |
+
# check if addable
|
564 |
+
if new_candidate.addable(temp_edge1) is False:
|
565 |
+
continue
|
566 |
+
if new_candidate.addable(temp_edge2) is False:
|
567 |
+
continue
|
568 |
+
|
569 |
+
# avoid intersection
|
570 |
+
if new_graph.checkIntersectionEdge(temp_edge1):
|
571 |
+
continue
|
572 |
+
if new_graph.checkIntersectionEdge(temp_edge2):
|
573 |
+
continue
|
574 |
+
|
575 |
+
# avoid too small triangle
|
576 |
+
if triangle_region(new_corner.x, ele_A.x, ele_B.x) < 20:
|
577 |
+
continue
|
578 |
+
|
579 |
+
### avoid colinear edge (only when fold case)
|
580 |
+
# for edge1
|
581 |
+
neighbor_edges = new_graph.getEdgeConnected(temp_edge1)
|
582 |
+
flag_ = True
|
583 |
+
for neighbor in neighbor_edges:
|
584 |
+
if new_corner in neighbor.x:
|
585 |
+
raise BaseException('new corner should not in any edge')
|
586 |
+
elif ele_A in neighbor.x:
|
587 |
+
shared_corner = ele_A
|
588 |
+
else:
|
589 |
+
raise BaseException('error.')
|
590 |
+
two_neighbor = {neighbor.x[0], neighbor.x[1], ele_A, new_corner}
|
591 |
+
two_neighbor.remove(shared_corner)
|
592 |
+
assert len(two_neighbor) == 2
|
593 |
+
two_neighbor = tuple(two_neighbor)
|
594 |
+
|
595 |
+
line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
|
596 |
+
line2 = np.array(shared_corner.x) - np.array(two_neighbor[1].x)
|
597 |
+
cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
|
598 |
+
cos = min(1, max(-1, cos))
|
599 |
+
if np.arccos(cos) < np.pi / 9: # 20 degree
|
600 |
+
flag_ = False
|
601 |
+
break
|
602 |
+
if flag_ is False:
|
603 |
+
continue
|
604 |
+
# for edge2
|
605 |
+
neighbor_edges = new_graph.getEdgeConnected(temp_edge2)
|
606 |
+
flag_ = True
|
607 |
+
for neighbor in neighbor_edges:
|
608 |
+
if new_corner in neighbor.x:
|
609 |
+
raise BaseException('new corner should not in any edge')
|
610 |
+
elif ele_B in neighbor.x:
|
611 |
+
shared_corner = ele_B
|
612 |
+
else:
|
613 |
+
raise BaseException('error.')
|
614 |
+
two_neighbor = {neighbor.x[0], neighbor.x[1], ele_B, new_corner}
|
615 |
+
two_neighbor.remove(shared_corner)
|
616 |
+
assert len(two_neighbor) == 2
|
617 |
+
two_neighbor = tuple(two_neighbor)
|
618 |
+
|
619 |
+
line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
|
620 |
+
line2 = np.array(shared_corner.x) - np.array(two_neighbor[1].x)
|
621 |
+
cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
|
622 |
+
cos = min(1, max(-1, cos))
|
623 |
+
if np.arccos(cos) < np.pi / 9: # 20 degree
|
624 |
+
flag_ = False
|
625 |
+
break
|
626 |
+
if flag_ is False:
|
627 |
+
continue
|
628 |
+
|
629 |
+
# make new candidate
|
630 |
+
try:
|
631 |
+
new_ = new_candidate.generate_new_candidate_add_an_edge(new_corner, ele_A)
|
632 |
+
new_ = new_.generate_new_candidate_add_an_edge(new_corner, ele_B)
|
633 |
+
new_candidates.append(new_)
|
634 |
+
except:
|
635 |
+
continue
|
636 |
+
# plt.subplot(151)
|
637 |
+
# visualization(candidate, show=False)
|
638 |
+
# plt.subplot(152)
|
639 |
+
# plt.imshow(final_mask)
|
640 |
+
# plt.subplot(153)
|
641 |
+
# plt.imshow(candidate_mask)
|
642 |
+
# plt.subplot(154)
|
643 |
+
# plt.imshow(gt_mask_grow)
|
644 |
+
# plt.subplot(155)
|
645 |
+
# visualization(new_, show=False)
|
646 |
+
# plt.show()
|
647 |
+
|
648 |
+
return new_candidates
|
649 |
+
|
650 |
+
|
651 |
+
def adding_an_edge_from_new_corner_operation(candidate):
|
652 |
+
new_candidates = []
|
653 |
+
name = candidate.name
|
654 |
+
gt_mask = region_cache.get_region(name)
|
655 |
+
gt_mask = gt_mask > 0.4
|
656 |
+
gt_mask_grow = cv2.dilate(gt_mask.astype(np.float64), np.ones((3, 3), np.uint8), iterations=6) > 0
|
657 |
+
|
658 |
+
# get the current candidate region mask
|
659 |
+
conv_mask = render(corners=candidate.graph.getCornersArray(), edges=candidate.graph.getEdgesArray(),
|
660 |
+
render_pad=0, edge_linewidth=1)[0]
|
661 |
+
conv_mask = 1 - conv_mask
|
662 |
+
conv_mask = conv_mask.astype(np.uint8)
|
663 |
+
labels, region_mask = cv2.connectedComponents(conv_mask, connectivity=4)
|
664 |
+
background_label = region_mask[0, 0]
|
665 |
+
all_masks = []
|
666 |
+
for region_i in range(1, labels):
|
667 |
+
if region_i == background_label:
|
668 |
+
continue
|
669 |
+
the_region = region_mask == region_i
|
670 |
+
if the_region.sum() < 20:
|
671 |
+
continue
|
672 |
+
all_masks.append(the_region)
|
673 |
+
candidate_mask = (np.sum(all_masks, 0) + (1 - conv_mask)) > 0
|
674 |
+
|
675 |
+
final_mask = np.logical_xor(gt_mask_grow, np.logical_and(candidate_mask, gt_mask_grow))
|
676 |
+
for corner_i in range(random.randint(0, 16), 256, 16):
|
677 |
+
for corner_j in range(random.randint(0, 16), 256, 16):
|
678 |
+
if candidate.addable((corner_i, corner_j)):
|
679 |
+
if final_mask[corner_i, corner_j] == True:
|
680 |
+
# inside the region
|
681 |
+
new_corner = Element((corner_i, corner_j))
|
682 |
+
new_candidate = candidate.generate_new_candidate_add_a_corner(new_corner)
|
683 |
+
new_graph = new_candidate.graph
|
684 |
+
corners = new_graph.getCorners()
|
685 |
+
|
686 |
+
# find a suitable existed corner that can make
|
687 |
+
# a new edge with new_corner (no intersection and colinear)
|
688 |
+
for corner_ele in corners:
|
689 |
+
if corner_ele == new_corner:
|
690 |
+
continue
|
691 |
+
if new_graph.has_edge(new_corner, corner_ele) is not None:
|
692 |
+
raise BaseException('should not have edge in this case')
|
693 |
+
temp_edge = Element((new_corner, corner_ele))
|
694 |
+
|
695 |
+
# check if addable
|
696 |
+
if new_candidate.addable(temp_edge) is False:
|
697 |
+
continue
|
698 |
+
|
699 |
+
# avoid intersection
|
700 |
+
if new_graph.checkIntersectionEdge(temp_edge):
|
701 |
+
continue
|
702 |
+
|
703 |
+
# avoid colinear edge
|
704 |
+
neighbor_edges = new_graph.getEdgeConnected(temp_edge)
|
705 |
+
flag_ = True
|
706 |
+
for neighbor in neighbor_edges:
|
707 |
+
if new_corner in neighbor.x:
|
708 |
+
raise BaseException('new corner should not in any edge')
|
709 |
+
elif corner_ele in neighbor.x:
|
710 |
+
shared_corner = corner_ele
|
711 |
+
else:
|
712 |
+
raise BaseException('error.')
|
713 |
+
two_neighbor = {neighbor.x[0], neighbor.x[1], corner_ele, new_corner}
|
714 |
+
two_neighbor.remove(shared_corner)
|
715 |
+
assert len(two_neighbor) == 2
|
716 |
+
two_neighbor = tuple(two_neighbor)
|
717 |
+
|
718 |
+
line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
|
719 |
+
line2 = np.array(shared_corner.x) - np.array(two_neighbor[1].x)
|
720 |
+
cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
|
721 |
+
cos = min(1, max(-1, cos))
|
722 |
+
if np.arccos(cos) < np.pi / 9: # 20 degree
|
723 |
+
flag_ = False
|
724 |
+
break
|
725 |
+
if flag_ is False:
|
726 |
+
continue
|
727 |
+
|
728 |
+
# make new candidate
|
729 |
+
try:
|
730 |
+
new_ = new_candidate.generate_new_candidate_add_an_edge(new_corner, corner_ele)
|
731 |
+
new_candidates.append(new_)
|
732 |
+
except:
|
733 |
+
continue
|
734 |
+
|
735 |
+
return new_candidates
|
736 |
+
|
737 |
+
|
738 |
+
def removing_a_corner_operation(candidate):
|
739 |
+
new_candidates = []
|
740 |
+
graph = candidate.graph
|
741 |
+
corners = graph.getCorners()
|
742 |
+
for the_corner in corners:
|
743 |
+
if candidate.removable(the_corner):
|
744 |
+
try:
|
745 |
+
new_ = candidate.generate_new_candidate_remove_a_corner(the_corner)
|
746 |
+
new_candidates.append(new_)
|
747 |
+
except:
|
748 |
+
continue
|
749 |
+
|
750 |
+
return new_candidates
|
751 |
+
|
752 |
+
|
753 |
+
def removing_a_colinear_corner_operation(candidate):
|
754 |
+
new_candidates = []
|
755 |
+
graph = candidate.graph
|
756 |
+
corners = graph.getCorners()
|
757 |
+
for the_corner in corners:
|
758 |
+
if candidate.removable(the_corner): # NO NEED TO CHECK IF COLINEAR and graph.checkColinearCorner(the_corner):
|
759 |
+
try:
|
760 |
+
new_ = candidate.generate_new_candidate_remove_a_colinear_corner(the_corner)
|
761 |
+
|
762 |
+
if new_.graph.checkIntersectionEdge():
|
763 |
+
continue
|
764 |
+
new_candidates.append(new_)
|
765 |
+
except:
|
766 |
+
continue
|
767 |
+
|
768 |
+
return new_candidates
|
769 |
+
|
770 |
+
|
771 |
+
def adding_an_edge_operation(candidate):
|
772 |
+
new_candidates = []
|
773 |
+
graph = candidate.graph
|
774 |
+
corners = graph.getCorners()
|
775 |
+
for corner_i in range(len(corners)):
|
776 |
+
cornerA = corners[corner_i]
|
777 |
+
for corner_j in range(corner_i + 1, len(corners)):
|
778 |
+
cornerB = corners[corner_j]
|
779 |
+
if graph.has_edge(cornerA, cornerB) is not None:
|
780 |
+
continue
|
781 |
+
|
782 |
+
temp_edge = Element((cornerA, cornerB))
|
783 |
+
# check if addable (not in existed_before dict)
|
784 |
+
if candidate.addable(temp_edge) is False:
|
785 |
+
continue
|
786 |
+
|
787 |
+
if graph.checkIntersectionEdge(temp_edge):
|
788 |
+
continue
|
789 |
+
|
790 |
+
# avoid adding a colinear edge
|
791 |
+
neighbor_edges = graph.getEdgeConnected(temp_edge)
|
792 |
+
flag_ = True
|
793 |
+
for neighbor in neighbor_edges:
|
794 |
+
if cornerA in neighbor.x:
|
795 |
+
shared_corner = cornerA
|
796 |
+
elif cornerB in neighbor.x:
|
797 |
+
shared_corner = cornerB
|
798 |
+
else:
|
799 |
+
raise BaseException('error.')
|
800 |
+
two_neighbor = {neighbor.x[0], neighbor.x[1], cornerA, cornerB}
|
801 |
+
two_neighbor.remove(shared_corner)
|
802 |
+
assert len(two_neighbor) == 2
|
803 |
+
two_neighbor = tuple(two_neighbor)
|
804 |
+
|
805 |
+
line1 = np.array(shared_corner.x) - np.array(two_neighbor[0].x)
|
806 |
+
line2 = np.array(two_neighbor[1].x) - np.array(shared_corner.x)
|
807 |
+
cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
|
808 |
+
cos = min(1, max(-1, cos))
|
809 |
+
if np.arccos(cos) < np.pi / 18 or np.arccos(cos) > np.pi - np.pi / 18:  # within 10 degrees of colinear
|
810 |
+
flag_ = False
|
811 |
+
break
|
812 |
+
if flag_ is False:
|
813 |
+
continue
|
814 |
+
|
815 |
+
# make new candidate
|
816 |
+
try:
|
817 |
+
new_ = candidate.generate_new_candidate_add_an_edge(cornerA, cornerB)
|
818 |
+
new_candidates.append(new_)
|
819 |
+
except:
|
820 |
+
continue
|
821 |
+
|
822 |
+
return new_candidates
|
823 |
+
|
824 |
+
|
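For intuition, the colinearity guard in adding_an_edge_operation boils down to an angle test between the candidate edge and each neighboring edge. The following is a minimal sketch of that test with invented points; the helper name near_colinear is hypothetical and not part of the repository:

import numpy as np

def near_colinear(shared, p1, p2, threshold_deg=10):
    # angle between (shared -> p1) and (p2 -> shared); reject when close to 0 or 180 degrees,
    # mirroring the np.pi / 18 thresholds used above
    line1 = np.array(shared) - np.array(p1)
    line2 = np.array(p2) - np.array(shared)
    cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
    angle = np.degrees(np.arccos(np.clip(cos, -1, 1)))
    return angle < threshold_deg or angle > 180 - threshold_deg

print(near_colinear((0, 0), (10, 0), (-10, 1)))   # True: nearly a straight continuation, so the edge is skipped
print(near_colinear((0, 0), (10, 0), (0, 10)))    # False: a right angle passes the check
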
825 |
+
def removing_an_edge_operation(candidate):
|
826 |
+
new_candidates = []
|
827 |
+
graph = candidate.graph
|
828 |
+
edges = graph.getEdges()
|
829 |
+
for edge_ele in edges:
|
830 |
+
if candidate.removable(edge_ele):
|
831 |
+
try:
|
832 |
+
new_ = candidate.generate_new_candidate_remove_an_edge(edge_ele)
|
833 |
+
new_candidates.append(new_)
|
834 |
+
except:
|
835 |
+
continue
|
836 |
+
|
837 |
+
return new_candidates
|
838 |
+
|
839 |
+
|
840 |
+
def adding_an_edge_from_gt(candidate, gt_data):
|
841 |
+
new_candidates = []
|
842 |
+
corners_array = candidate.graph.getCornersArray()
|
843 |
+
edges_array = candidate.graph.getEdgesArray()
|
844 |
+
|
845 |
+
gt_corners = gt_data['corners'].copy()
|
846 |
+
gt_edges = gt_data['edges'].copy()
|
847 |
+
|
848 |
+
_, _, map_same_location = get_wrong_corners(
|
849 |
+
corners_array, gt_corners, edges_array, gt_edges)
|
850 |
+
|
851 |
+
gt_corners, gt_edges = simplify_gt(map_same_location, gt_corners, gt_edges)
|
852 |
+
|
853 |
+
_, _, map_same_location = get_wrong_corners(
|
854 |
+
corners_array, gt_corners, edges_array, gt_edges)
|
855 |
+
|
856 |
+
for corner_i in range(gt_corners.shape[0]):
|
857 |
+
if map_same_location[corner_i] is None:
|
858 |
+
# doesn't exist in candidate
|
859 |
+
neighbor_id = get_neighbor_corner_id(corner_i, gt_edges)
|
860 |
+
for corner_j in neighbor_id:
|
861 |
+
if map_same_location[corner_j] is not None:
|
862 |
+
# a corner in the candidate maps to this neighbor corner
|
863 |
+
new_candidate = candidate.copy()
|
864 |
+
new_corner = Element(
|
865 |
+
(
|
866 |
+
int(np.round(gt_corners[corner_i, 0])), int(np.round(gt_corners[corner_i, 1]))
|
867 |
+
)
|
868 |
+
)
|
869 |
+
if new_candidate.addable(new_corner) is False:
|
870 |
+
continue
|
871 |
+
# skip if the new corner is too close to an existing edge
|
872 |
+
flag = False
|
873 |
+
for edge_ele in new_candidate.graph.getEdges():
|
874 |
+
if get_distance_of_corner_and_edge(edge_ele.x[0].x, edge_ele.x[1].x, new_corner.x) < 7:
|
875 |
+
flag = True
|
876 |
+
break
|
877 |
+
if flag:
|
878 |
+
continue
|
879 |
+
|
880 |
+
new_corner = new_candidate.addCorner(new_corner)
|
881 |
+
neighbor_index = map_same_location[corner_j]
|
882 |
+
neighbor_corner = new_candidate.graph.getCorners()[neighbor_index]
|
883 |
+
new_edge = new_candidate.addEdge(new_corner, neighbor_corner)
|
884 |
+
if new_candidate.graph.checkIntersectionEdge(new_edge):
|
885 |
+
continue
|
886 |
+
new_candidates.append(new_candidate)
|
887 |
+
|
888 |
+
return new_candidates
|
889 |
+
|
890 |
+
|
891 |
+
def adding_a_corner_from_two_edges_extension(candidate):
|
892 |
+
new_candidates = []
|
893 |
+
graph = candidate.graph
|
894 |
+
edges = candidate.graph.getEdges()
|
895 |
+
for edge_i in range(len(edges)):
|
896 |
+
for edge_j in range(edge_i + 1, len(edges)):
|
897 |
+
edgeA = edges[edge_i]
|
898 |
+
edgeB = edges[edge_j]
|
899 |
+
if graph.isNeighbor(edgeA, edgeB):
|
900 |
+
continue
|
901 |
+
intersection_loc = get_two_edge_intersection_location(edgeA.x[0].x, edgeA.x[1].x, edgeB.x[0].x,
|
902 |
+
edgeB.x[1].x)
|
903 |
+
if intersection_loc[0] >= 255 or intersection_loc[1] >= 255 or \
|
904 |
+
intersection_loc[0] <= 0 or intersection_loc[1] <= 0:
|
905 |
+
continue
|
906 |
+
# intersection point can not be too close to an edge
|
907 |
+
flag = False
|
908 |
+
for edge_ele in graph.getEdges():
|
909 |
+
if get_distance_of_corner_and_edge(edge_ele.x[0].x, edge_ele.x[1].x, intersection_loc) < 7:
|
910 |
+
flag = True
|
911 |
+
break
|
912 |
+
if flag:
|
913 |
+
continue
|
914 |
+
new_candidate = candidate.copy()
|
915 |
+
new_graph = new_candidate.graph
|
916 |
+
new_edgeA = new_graph.getRealElement(edgeA)
|
917 |
+
new_edgeB = new_graph.getRealElement(edgeB)
|
918 |
+
new_corner = Element(intersection_loc)
|
919 |
+
if new_candidate.addable(new_corner) is False:
|
920 |
+
continue
|
921 |
+
new_corner = new_candidate.addCorner_v2(new_corner)
|
922 |
+
# get cornerA and cornerB from edgeA, edgeB
|
923 |
+
if l2_distance(new_corner.x, new_edgeA.x[0].x) < l2_distance(new_corner.x, new_edgeA.x[1].x):
|
924 |
+
cornerA = new_edgeA.x[0]
|
925 |
+
else:
|
926 |
+
cornerA = new_edgeA.x[1]
|
927 |
+
if l2_distance(new_corner.x, new_edgeB.x[0].x) < l2_distance(new_corner.x, new_edgeB.x[1].x):
|
928 |
+
cornerB = new_edgeB.x[0]
|
929 |
+
else:
|
930 |
+
cornerB = new_edgeB.x[1]
|
931 |
+
|
932 |
+
# new edge can not be too short
|
933 |
+
if l2_distance(cornerA.x, new_corner.x) < 7:
|
934 |
+
continue
|
935 |
+
if l2_distance(cornerB.x, new_corner.x) < 7:
|
936 |
+
continue
|
937 |
+
|
938 |
+
# new intersection cannot be too flat
|
939 |
+
if degree_of_three_corners(cornerA.x, cornerB.x, new_corner.x) > 165:
|
940 |
+
continue
|
941 |
+
|
942 |
+
flag = False
|
943 |
+
for edge_ele in new_graph.getEdges():
|
944 |
+
if new_corner in edge_ele.x and cornerA in edge_ele.x:
|
945 |
+
flag = True
|
946 |
+
break
|
947 |
+
if edge_ele.x[0] not in (new_corner, cornerA):
|
948 |
+
l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[0].x)
|
949 |
+
if l <= 7:
|
950 |
+
flag = True
|
951 |
+
break
|
952 |
+
if edge_ele.x[1] not in (new_corner, cornerA):
|
953 |
+
l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[1].x)
|
954 |
+
if l <= 7:
|
955 |
+
flag = True
|
956 |
+
break
|
957 |
+
if flag:
|
958 |
+
continue
|
959 |
+
add_edgeA = new_candidate.addEdge(new_corner, cornerA)
|
960 |
+
if new_graph.checkIntersectionEdge(add_edgeA):
|
961 |
+
continue
|
962 |
+
|
963 |
+
flag = False
|
964 |
+
for edge_ele in new_graph.getEdges():
|
965 |
+
if new_corner in edge_ele.x and cornerB in edge_ele.x:
|
966 |
+
flag = True
|
967 |
+
break
|
968 |
+
if edge_ele.x[0] not in (new_corner, cornerB):
|
969 |
+
l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[0].x)
|
970 |
+
if l <= 7:
|
971 |
+
flag = True
|
972 |
+
break
|
973 |
+
if edge_ele.x[1] not in (new_corner, cornerB):
|
974 |
+
l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[1].x)
|
975 |
+
if l <= 7:
|
976 |
+
flag = True
|
977 |
+
break
|
978 |
+
if flag:
|
979 |
+
continue
|
980 |
+
add_edgeB = new_candidate.addEdge(new_corner, cornerB)
|
981 |
+
if new_graph.checkIntersectionEdge(add_edgeB):
|
982 |
+
continue
|
983 |
+
|
984 |
+
# make real new candidate
|
985 |
+
# new_candidate = candidate.copy()
|
986 |
+
# new_graph = new_candidate.graph
|
987 |
+
# new_corner = Element(intersection_loc)
|
988 |
+
# new_corner = new_graph.add_corner_v2(new_corner)
|
989 |
+
# new_candidate = new_candidate.generate_new_candidate_add_an_edge(new_corner, cornerA)
|
990 |
+
# new_candidate = new_candidate.generate_new_candidate_add_an_edge(new_corner, cornerB)
|
991 |
+
|
992 |
+
new_candidates.append(new_candidate)
|
993 |
+
return new_candidates
|
994 |
+
|
995 |
+
|
996 |
+
def adding_a_corner_from_parallel(candidate):
|
997 |
+
new_candidates = []
|
998 |
+
graph = candidate.graph
|
999 |
+
edges = candidate.graph.getEdges()
|
1000 |
+
for edge_i in range(len(edges)):
|
1001 |
+
for edge_j in range(edge_i + 1, len(edges)):
|
1002 |
+
edgeA = edges[edge_i]
|
1003 |
+
edgeB = edges[edge_j]
|
1004 |
+
# get intersection loc
|
1005 |
+
if graph.isNeighbor(edgeA, edgeB):
|
1006 |
+
shared_corner = edgeA.x[0] if edgeA.x[0] in edgeB.x else edgeA.x[1]
|
1007 |
+
intersection_loc = shared_corner.x
|
1008 |
+
else:
|
1009 |
+
intersection_loc = get_two_edge_intersection_location(
|
1010 |
+
edgeA.x[0].x, edgeA.x[1].x, edgeB.x[0].x, edgeB.x[1].x)
|
1011 |
+
if intersection_loc[0] >= 255 or intersection_loc[1] >= 255 or \
|
1012 |
+
intersection_loc[0] <= 0 or intersection_loc[1] <= 0:
|
1013 |
+
continue
|
1014 |
+
|
1015 |
+
# get another two loc
|
1016 |
+
locA = edgeA.x[1].x if \
|
1017 |
+
l2_distance(edgeA.x[0].x, intersection_loc) < l2_distance(edgeA.x[1].x, intersection_loc) else \
|
1018 |
+
edgeA.x[0].x
|
1019 |
+
locB = edgeB.x[1].x if \
|
1020 |
+
l2_distance(edgeB.x[0].x, intersection_loc) < l2_distance(edgeB.x[1].x, intersection_loc) else \
|
1021 |
+
edgeB.x[0].x
|
1022 |
+
|
1023 |
+
# get new loc
|
1024 |
+
new_loc = (locA[0] + locB[0] - intersection_loc[0], locA[1] + locB[1] - intersection_loc[1])
|
1025 |
+
if new_loc[0] >= 255 or new_loc[1] >= 255 or \
|
1026 |
+
new_loc[0] <= 0 or new_loc[1] <= 0:
|
1027 |
+
continue
|
1028 |
+
|
1029 |
+
new_corner = Element(new_loc)
|
1030 |
+
new_candidate = candidate.copy()
|
1031 |
+
new_graph = new_candidate.graph
|
1032 |
+
edgeA = new_graph.getRealElement(edgeA)
|
1033 |
+
edgeB = new_graph.getRealElement(edgeB)
|
1034 |
+
if new_candidate.addable(new_corner) is False:
|
1035 |
+
continue
|
1036 |
+
new_corner = new_candidate.addCorner_v2(new_corner)
|
1037 |
+
# get cornerA and cornerB from edgeA, edgeB
|
1038 |
+
cornerA = edgeA.x[1] if l2_distance(edgeA.x[0].x, intersection_loc) < l2_distance(edgeA.x[1].x, intersection_loc) else edgeA.x[0]
|
1041 |
+
cornerB = edgeB.x[1] if l2_distance(edgeB.x[0].x, intersection_loc) < l2_distance(edgeB.x[1].x, intersection_loc) else edgeB.x[0]
|
1044 |
+
|
1045 |
+
# new edge can not be too short
|
1046 |
+
if l2_distance(cornerA.x, new_corner.x) < 12:
|
1047 |
+
continue
|
1048 |
+
if l2_distance(cornerB.x, new_corner.x) < 12:
|
1049 |
+
continue
|
1050 |
+
|
1051 |
+
flag = False
|
1052 |
+
for edge_ele in new_graph.getEdges():
|
1053 |
+
if new_corner in edge_ele.x and cornerA in edge_ele.x:
|
1054 |
+
flag = True
|
1055 |
+
break
|
1056 |
+
if edge_ele.x[0] not in (new_corner, cornerA):
|
1057 |
+
l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[0].x)
|
1058 |
+
if l <= 7:
|
1059 |
+
flag = True
|
1060 |
+
break
|
1061 |
+
if edge_ele.x[1] not in (new_corner, cornerA):
|
1062 |
+
l = get_distance_of_corner_and_edge(new_corner.x, cornerA.x, edge_ele.x[1].x)
|
1063 |
+
if l <= 7:
|
1064 |
+
flag = True
|
1065 |
+
break
|
1066 |
+
if flag:
|
1067 |
+
continue
|
1068 |
+
add_edgeA = new_candidate.addEdge(new_corner, cornerA)
|
1069 |
+
if new_graph.checkIntersectionEdge(add_edgeA):
|
1070 |
+
continue
|
1071 |
+
|
1072 |
+
flag = False
|
1073 |
+
for edge_ele in new_graph.getEdges():
|
1074 |
+
if new_corner in edge_ele.x and cornerB in edge_ele.x:
|
1075 |
+
flag = True
|
1076 |
+
break
|
1077 |
+
if edge_ele.x[0] not in (new_corner, cornerB):
|
1078 |
+
l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[0].x)
|
1079 |
+
if l <= 7:
|
1080 |
+
flag = True
|
1081 |
+
break
|
1082 |
+
if edge_ele.x[1] not in (new_corner, cornerB):
|
1083 |
+
l = get_distance_of_corner_and_edge(new_corner.x, cornerB.x, edge_ele.x[1].x)
|
1084 |
+
if l <= 7:
|
1085 |
+
flag = True
|
1086 |
+
break
|
1087 |
+
if flag:
|
1088 |
+
continue
|
1089 |
+
add_edgeB = new_candidate.addEdge(new_corner, cornerB)
|
1090 |
+
if new_graph.checkIntersectionEdge(add_edgeB):
|
1091 |
+
continue
|
1092 |
+
|
1093 |
+
new_candidates.append(new_candidate)
|
1094 |
+
return new_candidates
|
1095 |
+
|
1096 |
+
|
1097 |
+
def adding_a_orthogonal_edge(candidate):
|
1098 |
+
new_candidates = []
|
1099 |
+
graph = candidate.graph
|
1100 |
+
edges = candidate.graph.getEdges()
|
1101 |
+
for edge in edges:
|
1102 |
+
cornerA = edge.x[0]
|
1103 |
+
cornerB = edge.x[1]
|
1104 |
+
|
1105 |
+
# get orthogonal direction
|
1106 |
+
dir_ = (cornerA.x[1] - cornerB.x[1], cornerB.x[0] - cornerA.x[0])
|
1107 |
+
|
1108 |
+
for the_corner in edge.x:
|
1109 |
+
temp_orth_loc = (the_corner.x[0] - dir_[0], the_corner.x[1] - dir_[1])
|
1110 |
+
for inter_edge in edges:
|
1111 |
+
if inter_edge == edge:
|
1112 |
+
continue
|
1113 |
+
if the_corner in inter_edge.x:
|
1114 |
+
continue
|
1115 |
+
intersection_loc = get_two_edge_intersection_location(
|
1116 |
+
the_corner.x, temp_orth_loc, inter_edge.x[0].x, inter_edge.x[1].x
|
1117 |
+
)
|
1118 |
+
if intersection_loc[0] >= 255 or intersection_loc[1] >= 255 or \
|
1119 |
+
intersection_loc[0] <= 0 or intersection_loc[1] <= 0:
|
1120 |
+
continue
|
1121 |
+
if np.dot((inter_edge.x[0].x[0] - intersection_loc[0], inter_edge.x[0].x[1] - intersection_loc[1]),
|
1122 |
+
(inter_edge.x[1].x[0] - intersection_loc[0], inter_edge.x[1].x[1] - intersection_loc[1])) > 0:
|
1123 |
+
# i.e., the intersection lies on the extension of inter_edge, not inside it
|
1124 |
+
continue
|
1125 |
+
if l2_distance(intersection_loc, inter_edge.x[0].x) < 5 or \
|
1126 |
+
l2_distance(intersection_loc, inter_edge.x[1].x) < 5:
|
1127 |
+
continue
|
1128 |
+
|
1129 |
+
# avoid a near-degenerate (very thin or very flat) angle with any neighboring edge
|
1130 |
+
flag = False
|
1131 |
+
neighbor_corners = graph.getNeighborCorner(the_corner)
|
1132 |
+
for corner_ele in neighbor_corners:
|
1133 |
+
if corner_ele in edge.x:
|
1134 |
+
continue
|
1135 |
+
if degree_of_three_corners(corner_ele.x, intersection_loc, the_corner.x) < 15:
|
1136 |
+
flag = True
|
1137 |
+
break
|
1138 |
+
if degree_of_three_corners(corner_ele.x, intersection_loc, the_corner.x) > 165:
|
1139 |
+
flag = True
|
1140 |
+
break
|
1141 |
+
if flag:
|
1142 |
+
continue
|
1143 |
+
|
1144 |
+
new_candidate = candidate.copy()
|
1145 |
+
new_graph = new_candidate.graph
|
1146 |
+
new_corner = Element(intersection_loc)
|
1147 |
+
if new_candidate.addable(new_corner) is False:
|
1148 |
+
continue
|
1149 |
+
new_corner = new_candidate.addCorner_v2(new_corner)
|
1150 |
+
|
1151 |
+
# new edge can not be too short
|
1152 |
+
if l2_distance(new_corner.x, the_corner.x) < 7:
|
1153 |
+
continue
|
1154 |
+
|
1155 |
+
add_edge = new_candidate.addEdge(new_corner, new_graph.getRealElement(the_corner))
|
1156 |
+
if new_graph.checkIntersectionEdge(add_edge):
|
1157 |
+
continue
|
1158 |
+
|
1159 |
+
new_candidates.append(new_candidate)
|
1160 |
+
return new_candidates
|
1161 |
+
|
1162 |
+
|
1163 |
+
class _thread(threading.Thread):
|
1164 |
+
def __init__(self, threadID, name, candidate, lock, result_list, func):
|
1165 |
+
threading.Thread.__init__(self)
|
1166 |
+
self.threadID = threadID
|
1167 |
+
self.name = name
|
1168 |
+
self.candidate = candidate
|
1169 |
+
self.lock = lock
|
1170 |
+
self.result_list = result_list
|
1171 |
+
self.func = func
|
1172 |
+
|
1173 |
+
def run(self):
|
1174 |
+
print('running id: ', self.name)
|
1175 |
+
start_time = time.time()
|
1176 |
+
candidates = self.func(self.candidate)
|
1177 |
+
print('test: =================================', self.name, len(candidates))
|
1178 |
+
self.lock.acquire()
|
1179 |
+
self.result_list.extend(candidates)
|
1180 |
+
self.lock.release()
|
1181 |
+
print(self.name, "spend time: {}s".format(time.time() - start_time))
|
1182 |
+
|
1183 |
+
|
1184 |
+
def candidate_enumerate_training(candidate, gt):
|
1185 |
+
new_candidates = []
|
1186 |
+
# remove a corner
|
1187 |
+
try:
|
1188 |
+
new_ = removing_a_corner_operation(candidate)
|
1189 |
+
if len(new_) > 0:
|
1190 |
+
new_candidates.append(random.choice(new_))
|
1191 |
+
except:
|
1192 |
+
print('something wrong with remove a corner !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
1193 |
+
|
1194 |
+
# remove a colinear corner
|
1195 |
+
try:
|
1196 |
+
new_ = removing_a_colinear_corner_operation(candidate)
|
1197 |
+
if len(new_) > 0:
|
1198 |
+
new_candidates.append(random.choice(new_))
|
1199 |
+
except:
|
1200 |
+
print('something wrong with remove a colinear corner !!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
1201 |
+
|
1202 |
+
# remove an edge
|
1203 |
+
try:
|
1204 |
+
new_ = removing_an_edge_operation(candidate)
|
1205 |
+
if len(new_) > 0:
|
1206 |
+
new_candidates.append(random.choice(new_))
|
1207 |
+
except:
|
1208 |
+
print('something wrong with remove an edge !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
1209 |
+
|
1210 |
+
# add an edge between existing corners
|
1211 |
+
try:
|
1212 |
+
new_ = adding_an_edge_operation(candidate)
|
1213 |
+
if len(new_) > 0:
|
1214 |
+
new_candidates.append(random.choice(new_))
|
1215 |
+
except:
|
1216 |
+
print('something wrong with add an edge from existing corners !!!!!!!!!!!!!!!!!!!!')
|
1217 |
+
|
1218 |
+
# add a corner from two edges
|
1219 |
+
try:
|
1220 |
+
new_ = adding_a_corner_from_two_edges_extension(candidate)
|
1221 |
+
if len(new_) > 0:
|
1222 |
+
new_candidates.append(random.choice(new_))
|
1223 |
+
except:
|
1224 |
+
print('something wrong with add a corner from two edges !!!!!!!!!!!!!!!!!!!!!!!!')
|
1225 |
+
|
1226 |
+
try:
|
1227 |
+
new_ = adding_a_corner_from_parallel(candidate)
|
1228 |
+
if len(new_) > 0:
|
1229 |
+
new_candidates.append(random.choice(new_))
|
1230 |
+
except:
|
1231 |
+
print('something wrong with add a corner from parallel !!!!!!!!!!!!!!!!!!!!!!!!')
|
1232 |
+
|
1233 |
+
# add an edge from gt
|
1234 |
+
try:
|
1235 |
+
new_ = adding_an_edge_from_gt(candidate, gt)
|
1236 |
+
if len(new_) > 0:
|
1237 |
+
new_candidates.append(random.choice(new_))
|
1238 |
+
except:
|
1239 |
+
print('something wrong with add an edge from gt !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
1240 |
+
|
1241 |
+
# add an orthogonal edge
|
1242 |
+
try:
|
1243 |
+
new_ = adding_a_orthogonal_edge(candidate)
|
1244 |
+
if len(new_) > 0:
|
1245 |
+
new_candidates.append(random.choice(new_))
|
1246 |
+
except:
|
1247 |
+
print('something wrong with add an orthogonal edge !!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
1248 |
+
return new_candidates
|
1249 |
+
|
1250 |
+
|
1251 |
+
def candidate_enumerate(candidate):
|
1252 |
+
new_candidates = []
|
1253 |
+
new_candidates.extend(removing_a_corner_operation(candidate))
|
1254 |
+
new_candidates.extend(removing_a_colinear_corner_operation(candidate))
|
1255 |
+
new_candidates.extend(removing_an_edge_operation(candidate))
|
1256 |
+
new_candidates.extend(adding_an_edge_operation(candidate))
|
1257 |
+
new_candidates.extend(adding_a_corner_from_two_edges_extension(candidate))
|
1258 |
+
new_candidates.extend(adding_a_corner_from_parallel(candidate))
|
1259 |
+
new_candidates.extend(adding_a_orthogonal_edge(candidate))
|
1260 |
+
|
1261 |
+
return new_candidates
|
1262 |
+
|
1263 |
+
|
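As an illustration of how candidate_enumerate is meant to be consumed, a single greedy search step could look like the sketch below. The score_candidate argument is a hypothetical stand-in for the learned corner/edge/region evaluators used in the actual pipeline:

def greedy_step(candidate, score_candidate):
    # expand the current candidate with every editing operation, drop duplicates,
    # and keep the highest-scoring proposal; fall back to the input if nothing improves
    proposals = reduce_duplicate_candidate(candidate_enumerate(candidate))
    if not proposals:
        return candidate
    best = max(proposals, key=score_candidate)
    return best if score_candidate(best) > score_candidate(candidate) else candidate
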
1264 |
+
def candidate_enumerate_thread(candidate):
|
1265 |
+
new_candidates = []
|
1266 |
+
lock = threading.Lock()
|
1267 |
+
|
1268 |
+
thread1 = _thread(1, 'remove_a_corner', candidate, lock, new_candidates, removing_a_corner_operation)
|
1269 |
+
thread2 = _thread(2, 'remove_a_colinear_corner', candidate, lock, new_candidates,
|
1270 |
+
removing_a_colinear_corner_operation)
|
1271 |
+
thread3 = _thread(3, 'add_an_edge', candidate, lock, new_candidates, adding_an_edge_operation)
|
1272 |
+
thread4 = _thread(4, 'remove_an_edge', candidate, lock, new_candidates, removing_an_edge_operation)
|
1273 |
+
|
1274 |
+
thread1.start()
|
1275 |
+
thread2.start()
|
1276 |
+
thread3.start()
|
1277 |
+
thread4.start()
|
1278 |
+
|
1279 |
+
threads = []
|
1280 |
+
threads.append(thread1)
|
1281 |
+
threads.append(thread2)
|
1282 |
+
threads.append(thread3)
|
1283 |
+
threads.append(thread4)
|
1284 |
+
|
1285 |
+
for t in threads:
|
1286 |
+
t.join()
|
1287 |
+
|
1288 |
+
return new_candidates
|
1289 |
+
|
1290 |
+
|
1291 |
+
def reduce_duplicate_candidate(candidates):
|
1292 |
+
i = 0
|
1293 |
+
while i < len(candidates):
|
1294 |
+
for j in reversed(range(i + 1, len(candidates))):
|
1295 |
+
if candidates[i].equal(candidates[j]):
|
1296 |
+
del candidates[j]
|
1297 |
+
i = i + 1
|
1298 |
+
return candidates
|
1299 |
+
|
1300 |
+
|
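For reference, a small sketch of how reduce_duplicate_candidate behaves; DummyCandidate and its equal method are hypothetical stand-ins used purely for illustration:

class DummyCandidate:
    def __init__(self, key):
        self.key = key

    def equal(self, other):
        return self.key == other.key

cands = [DummyCandidate(k) for k in (1, 2, 1, 3, 2)]
cands = reduce_duplicate_candidate(cands)
assert [c.key for c in cands] == [1, 2, 3]  # later duplicates are deleted in place
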
1301 |
+
def save_candidate_image(candidate, base_path, base_name):
|
1302 |
+
corners = candidate.graph.getCornersArray()
|
1303 |
+
edges = candidate.graph.getEdgesArray()
|
1304 |
+
# graph svg
|
1305 |
+
svg = svg_generate(corners, edges, base_name, samecolor=True)
|
1306 |
+
svg.saveas(os.path.join(base_path, base_name + '.svg'))
|
1307 |
+
# corner image
|
1308 |
+
temp_mask = np.zeros((256, 256))
|
1309 |
+
for ele in candidate.graph.getCorners():
|
1310 |
+
if ele.get_score() < 0:
|
1311 |
+
temp_mask = cv2.circle(temp_mask, ele.x[::-1], 3, 1, -1)
|
1312 |
+
fig = plt.figure(frameon=False)
|
1313 |
+
fig.set_size_inches(1, 1)
|
1314 |
+
ax = plt.Axes(fig, [0., 0., 1., 1.])
|
1315 |
+
ax.set_axis_off()
|
1316 |
+
fig.add_axes(ax)
|
1317 |
+
ax.imshow(temp_mask, aspect='auto')
|
1318 |
+
fig.savefig(os.path.join(base_path, base_name + '_corner.png'), dpi=256)
|
1319 |
+
# edges image
|
1320 |
+
temp_mask = np.zeros((256, 256))
|
1321 |
+
for ele in candidate.graph.getEdges():
|
1322 |
+
if ele.get_score() < 0:
|
1323 |
+
A = ele.x[0]
|
1324 |
+
B = ele.x[1]
|
1325 |
+
temp_mask = cv2.line(temp_mask, A.x[::-1], B.x[::-1], 1, thickness=1)
|
1326 |
+
ax.imshow(temp_mask, aspect='auto')
|
1327 |
+
fig.savefig(os.path.join(base_path, base_name + '_edge.png'), dpi=256)
|
1328 |
+
# regions: no separate figure needed
|
1329 |
+
plt.close()
|
1330 |
+
|
1331 |
+
|
1332 |
+
#########################################################################################
|
1333 |
+
###################################### Class ############################################
|
1334 |
+
#########################################################################################
|
1335 |
+
|
1336 |
+
class Element:
|
1337 |
+
def __init__(self, x, safe_count=0):
|
1338 |
+
assert type(x) is tuple
|
1339 |
+
assert type(x[0]) == int or type(x[0]) == Element
|
1340 |
+
assert type(x[1]) == int or type(x[1]) == Element
|
1341 |
+
self.x = x
|
1342 |
+
self.__score = None
|
1343 |
+
self.safe_count = safe_count
|
1344 |
+
|
1345 |
+
def store_score(self, score):
|
1346 |
+
self.__score = score
|
1347 |
+
|
1348 |
+
def get_score(self):
|
1349 |
+
return self.__score
|
1350 |
+
|
1351 |
+
def equal(self, ele):
|
1352 |
+
if type(self.x[0]) != type(ele.x[0]):
|
1353 |
+
return False
|
1354 |
+
if type(self.x[0]) == int:
|
1355 |
+
# corner
|
1356 |
+
return True if self.x[0] == ele.x[0] and self.x[1] == ele.x[1] else False
|
1357 |
+
if type(self.x[0]) == Element:
|
1358 |
+
# edge
|
1359 |
+
if self.x[0].equal(ele.x[0]) and self.x[1].equal(ele.x[1]):
|
1360 |
+
return True
|
1361 |
+
if self.x[1].equal(ele.x[0]) and self.x[0].equal(ele.x[1]):
|
1362 |
+
return True
|
1363 |
+
return False
|
1364 |
+
raise BaseException('unimplemented element type')
|
1365 |
+
|
1366 |
+
|
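To make the dual role of Element concrete, here is a small illustrative example (the coordinates are invented): a corner wraps an (x, y) tuple of ints, while an edge wraps a tuple of two corner Elements.

cornerA = Element((10, 20))
cornerB = Element((30, 40))
edgeAB = Element((cornerA, cornerB))

assert cornerA.equal(Element((10, 20)))            # corners compare by location
assert edgeAB.equal(Element((cornerB, cornerA)))   # edge equality ignores endpoint order
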
1367 |
+
class regionCache():
|
1368 |
+
def __init__(self, datapath):
|
1369 |
+
self.cache = {}
|
1370 |
+
self.datapath = datapath
|
1371 |
+
|
1372 |
+
def get_region(self, name):
|
1373 |
+
if name in self.cache.keys():
|
1374 |
+
return self.cache[name]
|
1375 |
+
gt_mask = np.load(os.path.join(self.datapath, name + '.npy'))
|
1376 |
+
if len(self.cache) == 5:
|
1377 |
+
self.cache.pop(list(self.cache.keys())[0])
|
1378 |
+
self.cache[name] = gt_mask
|
1379 |
+
return gt_mask
|
1380 |
+
|
1381 |
+
|
1382 |
+
class imgCache():
|
1383 |
+
def __init__(self, datapath):
|
1384 |
+
self.cache = {}
|
1385 |
+
self.datapath = datapath
|
1386 |
+
|
1387 |
+
def get_image(self, name):
|
1388 |
+
if name in self.cache.keys():
|
1389 |
+
return self.cache[name]
|
1390 |
+
img = skimage.img_as_float(plt.imread(os.path.join(self.datapath, 'rgb', name + '.jpg')))
|
1391 |
+
if len(self.cache) == 5:
|
1392 |
+
self.cache.pop(list(self.cache.keys())[0])
|
1393 |
+
self.cache[name] = img
|
1394 |
+
return img
|
1395 |
+
|
1396 |
+
|
1397 |
+
class Graph:
|
1398 |
+
def __init__(self, corners, edges):
|
1399 |
+
corners, edges = sort_graph(corners, edges)
|
1400 |
+
|
1401 |
+
self.__corners = []
|
1402 |
+
for corner_i in range(corners.shape[0]):
|
1403 |
+
self.__corners.append(
|
1404 |
+
Element(
|
1405 |
+
tuple(
|
1406 |
+
(int(corners[corner_i, 0]), int(corners[corner_i, 1]))
|
1407 |
+
)
|
1408 |
+
)
|
1409 |
+
)
|
1410 |
+
self.__edges = []
|
1411 |
+
for edge_i in range(edges.shape[0]):
|
1412 |
+
self.__edges.append(Element((self.__corners[edges[edge_i, 0]], self.__corners[edges[edge_i, 1]])))
|
1413 |
+
self.__regions = []
|
1414 |
+
self.__regions.append(Element((0, 0))) # we use entire region here
|
1415 |
+
|
1416 |
+
@classmethod
|
1417 |
+
def initialFromTuple(cls, corners, edges):
|
1418 |
+
edge_index = []
|
1419 |
+
for item in edges:
|
1420 |
+
a = corners.index(item[0])
|
1421 |
+
b = corners.index(item[1])
|
1422 |
+
edge_index.append((a, b))
|
1423 |
+
edge_index = np.array(edge_index)
|
1424 |
+
corners = np.array(corners)
|
1425 |
+
return cls(corners, edge_index)
|
1426 |
+
|
1427 |
+
def store_score(self, corner_score=None, edge_score=None, region_score=None):
|
1428 |
+
'''
|
1429 |
+
:param corner_score: np array size: len(corners)
|
1430 |
+
:param edge_score: np array size: len(edges)
|
1431 |
+
:param region_score: np.array size: len(regions)
|
1432 |
+
:return:
|
1433 |
+
'''
|
1434 |
+
if corner_score is not None:
|
1435 |
+
for idx, element in enumerate(self.__corners):
|
1436 |
+
element.store_score(corner_score[idx])
|
1437 |
+
if edge_score is not None:
|
1438 |
+
for idx, element in enumerate(self.__edges):
|
1439 |
+
element.store_score(edge_score[idx])
|
1440 |
+
if region_score is not None:
|
1441 |
+
for idx, element in enumerate(self.__regions):
|
1442 |
+
element.store_score(region_score[idx])
|
1443 |
+
return
|
1444 |
+
|
1445 |
+
def getCornersArray(self):
|
1446 |
+
c = []
|
1447 |
+
for ele in self.__corners:
|
1448 |
+
c.append(ele.x)
|
1449 |
+
return np.array(c)
|
1450 |
+
|
1451 |
+
def getEdgesArray(self):
|
1452 |
+
c = []
|
1453 |
+
for ele in self.__edges:
|
1454 |
+
corner1 = ele.x[0]
|
1455 |
+
corner2 = ele.x[1]
|
1456 |
+
idx1 = self.__corners.index(corner1)
|
1457 |
+
idx2 = self.__corners.index(corner2)
|
1458 |
+
c.append([idx1, idx2])
|
1459 |
+
return np.array(c)
|
1460 |
+
|
1461 |
+
def getCorners(self):
|
1462 |
+
return self.__corners
|
1463 |
+
|
1464 |
+
def getRegions(self):
|
1465 |
+
return self.__regions
|
1466 |
+
|
1467 |
+
def getEdges(self):
|
1468 |
+
return self.__edges
|
1469 |
+
|
1470 |
+
def graph_score(self):
|
1471 |
+
corner_score = 0
|
1472 |
+
for ele in self.__corners:
|
1473 |
+
corner_score += ele.get_score()
|
1474 |
+
edge_score = 0
|
1475 |
+
for ele in self.__edges:
|
1476 |
+
edge_score += ele.get_score()
|
1477 |
+
region_score = 0
|
1478 |
+
for ele in self.__regions:
|
1479 |
+
region_score += ele.get_score()
|
1480 |
+
return score_weights[0] * corner_score + score_weights[1] * edge_score + score_weights[2] * region_score
|
1481 |
+
|
1482 |
+
def corner_score(self):
|
1483 |
+
corner_score = 0
|
1484 |
+
for ele in self.__corners:
|
1485 |
+
corner_score += ele.get_score()
|
1486 |
+
return corner_score
|
1487 |
+
|
1488 |
+
def edge_score(self):
|
1489 |
+
edge_score = 0
|
1490 |
+
for ele in self.__edges:
|
1491 |
+
edge_score += ele.get_score()
|
1492 |
+
return edge_score
|
1493 |
+
|
1494 |
+
def region_score(self):
|
1495 |
+
region_score = 0
|
1496 |
+
for ele in self.__regions:
|
1497 |
+
region_score += ele.get_score()
|
1498 |
+
return region_score
|
1499 |
+
|
1500 |
+
def remove(self, ele):
|
1501 |
+
'''
|
1502 |
+
:param ele: remove eles as well as some other related elements
|
1503 |
+
:return: set() of removed elements
|
1504 |
+
'''
|
1505 |
+
# corner
|
1506 |
+
removed = set()
|
1507 |
+
if ele in self.__corners:
|
1508 |
+
self.__corners.remove(ele)
|
1509 |
+
removed.add(ele)
|
1510 |
+
# remove edge that has the corner
|
1511 |
+
for idx in reversed(range(len(self.__edges))):
|
1512 |
+
edge_ele = self.__edges[idx]
|
1513 |
+
if ele in edge_ele.x:
|
1514 |
+
removed = removed.union(self.remove(edge_ele))
|
1515 |
+
# edge
|
1516 |
+
elif ele in self.__edges:
|
1517 |
+
self.__edges.remove(ele)
|
1518 |
+
removed.add(ele)
|
1519 |
+
corner1 = ele.x[0]
|
1520 |
+
corner2 = ele.x[1]
|
1521 |
+
if corner1.safe_count == 0:
|
1522 |
+
# can be deleted
|
1523 |
+
_count = 0
|
1524 |
+
for edge_ele in self.__edges:
|
1525 |
+
if corner1 in edge_ele.x:
|
1526 |
+
_count += 1
|
1527 |
+
if _count == 0:
|
1528 |
+
removed = removed.union(self.remove(corner1))
|
1529 |
+
if corner2.safe_count == 0:
|
1530 |
+
# can be deleted
|
1531 |
+
_count = 0
|
1532 |
+
for edge_ele in self.__edges:
|
1533 |
+
if corner2 in edge_ele.x:
|
1534 |
+
_count += 1
|
1535 |
+
if _count == 0:
|
1536 |
+
removed = removed.union(self.remove(corner2))
|
1537 |
+
return removed
|
1538 |
+
|
1539 |
+
def has_edge(self, ele1, ele2):
|
1540 |
+
"""
|
1541 |
+
:param ele1: corner1
|
1542 |
+
:param ele2: corner2
|
1543 |
+
:return: edge or none
|
1544 |
+
"""
|
1545 |
+
for edge_ele in self.__edges:
|
1546 |
+
if ele1 in edge_ele.x and ele2 in edge_ele.x:
|
1547 |
+
return edge_ele
|
1548 |
+
return None
|
1549 |
+
|
1550 |
+
def add_edge(self, ele1, ele2):
|
1551 |
+
temp = self.has_edge(ele1, ele2)
|
1552 |
+
if temp is not None:
|
1553 |
+
temp.safe_count = SAFE_NUM
|
1554 |
+
return temp
|
1555 |
+
new_ele = Element((ele1, ele2), safe_count=SAFE_NUM)
|
1556 |
+
self.__edges.append(new_ele)
|
1557 |
+
return new_ele
|
1558 |
+
|
1559 |
+
def add_corner(self, ele):
|
1560 |
+
for corner in self.__corners:
|
1561 |
+
if corner.x == ele.x:
|
1562 |
+
corner.safe_count = SAFE_NUM
|
1563 |
+
return corner
|
1564 |
+
ele.safe_count = SAFE_NUM
|
1565 |
+
self.__corners.append(ele)
|
1566 |
+
return ele
|
1567 |
+
|
1568 |
+
def add_corner_v2(self, ele):
|
1569 |
+
# if the new corner is near an existing corner, return that existing corner
|
1570 |
+
# if new corner is on an edge, split edge
|
1571 |
+
for corner in self.__corners:
|
1572 |
+
if l2_distance(corner.x, ele.x) < 5:
|
1573 |
+
corner.safe_count = SAFE_NUM
|
1574 |
+
return corner
|
1575 |
+
min_d = 256
|
1576 |
+
the_edge = None
|
1577 |
+
for edge in self.__edges:
|
1578 |
+
temp = get_distance_of_corner_and_edge(edge.x[0].x, edge.x[1].x, ele.x)
|
1579 |
+
if temp < min_d:
|
1580 |
+
min_d = temp
|
1581 |
+
the_edge = edge
|
1582 |
+
if min_d < 3:
|
1583 |
+
# split edge
|
1584 |
+
corner1 = the_edge.x[0]
|
1585 |
+
corner2 = the_edge.x[1]
|
1586 |
+
new_ele = Element((corner1, ele), safe_count=the_edge.safe_count)
|
1587 |
+
self.__edges.append(new_ele)
|
1588 |
+
new_ele = Element((corner2, ele), safe_count=the_edge.safe_count)
|
1589 |
+
self.__edges.append(new_ele)
|
1590 |
+
self.__edges.remove(the_edge)
|
1591 |
+
ele.safe_count = SAFE_NUM
|
1592 |
+
self.__corners.append(ele)
|
1593 |
+
return ele
|
1594 |
+
|
1595 |
+
def checkColinearCorner(self, ele):
|
1596 |
+
if self.getCornerDegree(ele) != 2:
|
1597 |
+
return False
|
1598 |
+
edge_in = []
|
1599 |
+
for edge_ele in self.__edges:
|
1600 |
+
if ele in edge_ele.x:
|
1601 |
+
edge_in.append(edge_ele)
|
1602 |
+
if len(edge_in) == 2:
|
1603 |
+
break
|
1604 |
+
two_neighbor = {edge_in[0].x[0], edge_in[0].x[1], edge_in[1].x[0], edge_in[1].x[1]}
|
1605 |
+
two_neighbor.remove(ele)
|
1606 |
+
two_neighbor = tuple(two_neighbor)
|
1607 |
+
if self.has_edge(two_neighbor[0], two_neighbor[1]) is not None:
|
1608 |
+
return False
|
1609 |
+
|
1610 |
+
line1 = np.array(ele.x) - np.array(two_neighbor[0].x)
|
1611 |
+
line2 = np.array(two_neighbor[1].x) - np.array(ele.x)
|
1612 |
+
cos = np.dot(line1, line2) / (np.linalg.norm(line1) * np.linalg.norm(line2))
|
1613 |
+
cos = min(1, max(-1, cos))
|
1614 |
+
if np.arccos(cos) < np.pi / 9:  # 20 degrees
|
1615 |
+
return True
|
1616 |
+
return False
|
1617 |
+
|
1618 |
+
def checkIntersectionEdge(self, ele=None):
|
1619 |
+
if ele is None:
|
1620 |
+
for edge_i in range(len(self.__edges)):
|
1621 |
+
for edge_j in range(edge_i + 1, len(self.__edges)):
|
1622 |
+
if check_intersection(self.__edges[edge_i], self.__edges[edge_j]):
|
1623 |
+
return True
|
1624 |
+
return False
|
1625 |
+
for edge_ele in self.__edges:
|
1626 |
+
if ele == edge_ele:
|
1627 |
+
continue
|
1628 |
+
if check_intersection(edge_ele, ele):
|
1629 |
+
return True
|
1630 |
+
return False
|
1631 |
+
|
1632 |
+
def getCornerDegree(self, ele):
|
1633 |
+
degree = 0
|
1634 |
+
for edge_ele in self.__edges:
|
1635 |
+
if ele in edge_ele.x:
|
1636 |
+
degree += 1
|
1637 |
+
return degree
|
1638 |
+
|
1639 |
+
def getEdgeConnected(self, ele):
|
1640 |
+
out_ = set()
|
1641 |
+
if type(ele.x[0]) == int:
|
1642 |
+
# corner
|
1643 |
+
for edge_ele in self.__edges:
|
1644 |
+
if ele in edge_ele.x:
|
1645 |
+
out_.add(edge_ele)
|
1646 |
+
return out_
|
1647 |
+
if type(ele.x[0]) == Element:
|
1648 |
+
# Edge
|
1649 |
+
out_ = out_.union(self.getEdgeConnected(ele.x[0]))
|
1650 |
+
out_ = out_.union(self.getEdgeConnected(ele.x[1]))
|
1651 |
+
if ele in out_:
|
1652 |
+
out_.remove(ele)
|
1653 |
+
return out_
|
1654 |
+
|
1655 |
+
def getNeighborCorner(self, ele):
|
1656 |
+
out_ = set()
|
1657 |
+
for edge_ele in self.__edges:
|
1658 |
+
if ele == edge_ele.x[0]:
|
1659 |
+
out_.add(edge_ele.x[1])
|
1660 |
+
if ele == edge_ele.x[1]:
|
1661 |
+
out_.add(edge_ele.x[0])
|
1662 |
+
return out_
|
1663 |
+
|
1664 |
+
def getRealElement(self, ele):
|
1665 |
+
# edge
|
1666 |
+
if type(ele.x[0]) == Element:
|
1667 |
+
for e in self.__edges:
|
1668 |
+
if (e.x[0].x == ele.x[0].x and e.x[1].x == ele.x[1].x) or \
|
1669 |
+
(e.x[1].x == ele.x[0].x and e.x[0].x == ele.x[1].x):
|
1670 |
+
return e
|
1671 |
+
raise BaseException("no same edge exists.")
|
1672 |
+
# corner
|
1673 |
+
elif type(ele.x[0]) == int:
|
1674 |
+
for c in self.__corners:
|
1675 |
+
if c.x == ele.x:
|
1676 |
+
return c
|
1677 |
+
raise BaseException("no same corner exists.")
|
1678 |
+
|
1679 |
+
def copy(self):
|
1680 |
+
corners = self.getCornersArray()
|
1681 |
+
edges = self.getEdgesArray()
|
1682 |
+
new_graph = Graph(corners, edges)
|
1683 |
+
for idx, ele in enumerate(self.__corners):
|
1684 |
+
new_graph.__corners[idx].store_score(self.__corners[idx].get_score())
|
1685 |
+
for idx, ele in enumerate(self.__edges):
|
1686 |
+
new_graph.__edges[idx].store_score(self.__edges[idx].get_score())
|
1687 |
+
for idx, ele in enumerate(self.__regions):
|
1688 |
+
new_graph.__regions[idx].store_score(self.__regions[idx].get_score())
|
1689 |
+
return new_graph
|
1690 |
+
|
1691 |
+
def update_safe_count(self):
|
1692 |
+
for ele in self.__corners:
|
1693 |
+
if ele.safe_count > 0:
|
1694 |
+
ele.safe_count -= 1
|
1695 |
+
for ele in self.__edges:
|
1696 |
+
if ele.safe_count > 0:
|
1697 |
+
ele.safe_count -= 1
|
1698 |
+
|
1699 |
+
def isNeighbor(self, element1, element2):
|
1700 |
+
'''
|
1701 |
+
:param element1:
|
1702 |
+
:param element2:
|
1703 |
+
:return: True / False
|
1704 |
+
'''
|
1705 |
+
if element1 == element2:
|
1706 |
+
return False
|
1707 |
+
if type(element1.x[0]) != type(element2.x[0]):
|
1708 |
+
# corner and edge
|
1709 |
+
return False
|
1710 |
+
if type(element1.x[0]) == int:
|
1711 |
+
# both are corner type
|
1712 |
+
for edge_ele in self.__edges:
|
1713 |
+
if edge_ele.x[0] == element1 and edge_ele.x[1] == element2:
|
1714 |
+
return True
|
1715 |
+
if edge_ele.x[0] == element2 and edge_ele.x[1] == element1:
|
1716 |
+
return True
|
1717 |
+
return False
|
1718 |
+
if type(element1.x[0]) == Element:
|
1719 |
+
# both are edge type
|
1720 |
+
if len({element1.x[0], element1.x[1], element2.x[0], element2.x[1]}) < 4:
|
1721 |
+
return True
|
1722 |
+
return False
|
1723 |
+
|
1724 |
+
def equal(self, graph):
|
1725 |
+
if len(self.__corners) != len(graph.__corners) or \
|
1726 |
+
len(self.__edges) != len(graph.__edges):
|
1727 |
+
return False
|
1728 |
+
for corner_i in range(len(self.__corners)):
|
1729 |
+
if self.__corners[corner_i].equal(graph.__corners[corner_i]) is False:
|
1730 |
+
return False
|
1731 |
+
for edge_i in range(len(self.__edges)):
|
1732 |
+
if self.__edges[edge_i].equal(graph.__edges[edge_i]) is False:
|
1733 |
+
return False
|
1734 |
+
|
1735 |
+
return True
|
1736 |
+
|
1737 |
+
|
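A minimal usage sketch for the Graph class above, with invented coordinates and dummy scores (real scores come from the learned evaluators):

import numpy as np

corners = np.array([[10, 10], [10, 100], [100, 100]])
edges = np.array([[0, 1], [1, 2]])
g = Graph(corners, edges)

# dummy scores just to exercise the bookkeeping
g.store_score(corner_score=np.ones(len(g.getCorners())),
              edge_score=np.ones(len(g.getEdges())),
              region_score=np.ones(len(g.getRegions())))
print(g.corner_score(), g.edge_score(), g.region_score())  # 3.0 2.0 1.0
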
1738 |
+
class Candidate:
|
1739 |
+
def __init__(self, graph, name, corner_existed_before, edge_existed_before):
|
1740 |
+
'''
|
1741 |
+
:param graph: Class graph
|
1742 |
+
:param name: string, data name
|
1743 |
+
:param corner_existed_before: dict {(x_i, y_i): c_i, ...} mapping corner locations to counts; after each search step
|
1744 |
+
every count is decremented by 1, and entries whose count reaches 0 are removed from the dict.
|
1745 |
+
:param edge_existed_before: dict {((x_i1,y_i1),(x_i2,y_i2)):ci}
|
1746 |
+
'''
|
1747 |
+
self.graph = graph
|
1748 |
+
self.name = name
|
1749 |
+
self.corner_existed_before = corner_existed_before
|
1750 |
+
self.edge_existed_before = edge_existed_before
|
1751 |
+
|
1752 |
+
@classmethod
|
1753 |
+
def initial(cls, graph, name):
|
1754 |
+
return cls(graph, name, {}, {})
|
1755 |
+
|
1756 |
+
def update(self):
|
1757 |
+
# decrement the count of every existed_before entry
|
1758 |
+
for key in self.corner_existed_before.keys():
|
1759 |
+
self.corner_existed_before[key] -= 1
|
1760 |
+
for key in self.edge_existed_before.keys():
|
1761 |
+
self.edge_existed_before[key] -= 1
|
1762 |
+
|
1763 |
+
# remove entries whose count has reached zero from the existed_before dicts
|
1764 |
+
for key in list(self.corner_existed_before.keys()):
|
1765 |
+
if self.corner_existed_before[key] == 0:
|
1766 |
+
self.corner_existed_before.pop(key)
|
1767 |
+
|
1768 |
+
for key in list(self.edge_existed_before.keys()):
|
1769 |
+
if self.edge_existed_before[key] == 0:
|
1770 |
+
self.edge_existed_before.pop(key)
|
1771 |
+
|
1772 |
+
# update graph
|
1773 |
+
self.graph.update_safe_count()
|
1774 |
+
|
1775 |
+
def copy(self):
|
1776 |
+
corner_existed_before = self.corner_existed_before.copy()
|
1777 |
+
edge_existed_before = self.edge_existed_before.copy()
|
1778 |
+
new_graph = self.graph.copy()
|
1779 |
+
return Candidate(new_graph, self.name, corner_existed_before, edge_existed_before)
|
1780 |
+
|
1781 |
+
def removable(self, ele):
|
1782 |
+
'''
|
1783 |
+
:param ele: the Element (corner or edge) to check
|
1784 |
+
:return:
|
1785 |
+
'''
|
1786 |
+
assert type(ele) == Element
|
1787 |
+
# edge
|
1788 |
+
return True if ele.safe_count == 0 else False
|
1789 |
+
|
1790 |
+
def addable(self, ele):
|
1791 |
+
if type(ele) == Element:
|
1792 |
+
if type(ele.x[0]) == Element:
|
1793 |
+
# edge
|
1794 |
+
for edge in self.graph.getEdges():
|
1795 |
+
c1 = edge.x[0]
|
1796 |
+
c2 = edge.x[1]
|
1797 |
+
if (ele.x[0].x == c1.x and ele.x[1].x == c2.x) or \
|
1798 |
+
(ele.x[1].x == c1.x and ele.x[0].x == c2.x):
|
1799 |
+
# already existed
|
1800 |
+
return False
|
1801 |
+
corner1_loc = ele.x[0].x
|
1802 |
+
corner2_loc = ele.x[1].x
|
1803 |
+
if (corner1_loc, corner2_loc) in self.edge_existed_before.keys() or \
|
1804 |
+
(corner2_loc, corner1_loc) in self.edge_existed_before.keys():
|
1805 |
+
return False
|
1806 |
+
return True
|
1807 |
+
else:
|
1808 |
+
# corner
|
1809 |
+
for corner in self.graph.getCorners():
|
1810 |
+
if l2_distance(ele.x, corner.x) < TWO_CORNER_MINIMUM_DISTANCE:
|
1811 |
+
# already existed
|
1812 |
+
return False
|
1813 |
+
if ele.x in self.corner_existed_before.keys():
|
1814 |
+
return False
|
1815 |
+
return True
|
1816 |
+
else: # (x,y) or ((x1,y1),(x2,y2))
|
1817 |
+
if type(ele[0]) == tuple:
|
1818 |
+
# edge
|
1819 |
+
corner1_loc = ele[0]
|
1820 |
+
corner2_loc = ele[1]
|
1821 |
+
for edge in self.graph.getEdges():
|
1822 |
+
c1 = edge.x[0]
|
1823 |
+
c2 = edge.x[1]
|
1824 |
+
if (corner1_loc == c1.x and corner2_loc == c2.x) or \
|
1825 |
+
(corner2_loc == c1.x and corner1_loc == c2.x):
|
1826 |
+
# already existed
|
1827 |
+
return False
|
1828 |
+
if (corner1_loc, corner2_loc) in self.edge_existed_before.keys() or \
|
1829 |
+
(corner2_loc, corner1_loc) in self.edge_existed_before.keys():
|
1830 |
+
return False
|
1831 |
+
return True
|
1832 |
+
else:
|
1833 |
+
# corner
|
1834 |
+
for corner in self.graph.getCorners():
|
1835 |
+
if l2_distance(ele, corner.x) < TWO_CORNER_MINIMUM_DISTANCE:
|
1836 |
+
# already existed
|
1837 |
+
return False
|
1838 |
+
if ele in self.corner_existed_before.keys():
|
1839 |
+
return False
|
1840 |
+
return True
|
1841 |
+
|
1842 |
+
def addCorner(self, ele):
|
1843 |
+
if ele.x in self.corner_existed_before.keys():
|
1844 |
+
raise BaseException('cannot add the corner')
|
1845 |
+
new_ele = self.graph.add_corner(ele)  # may return an existing corner instead of ele
|
1846 |
+
return new_ele
|
1847 |
+
|
1848 |
+
def addCorner_v2(self, ele):
|
1849 |
+
if ele.x in self.corner_existed_before.keys():
|
1850 |
+
raise BaseException('cannot add the corner')
|
1851 |
+
new_ele = self.graph.add_corner_v2(ele)
|
1852 |
+
return new_ele
|
1853 |
+
|
1854 |
+
def addEdge(self, ele1, ele2):
|
1855 |
+
corner1 = ele1
|
1856 |
+
corner2 = ele2
|
1857 |
+
assert corner1 in self.graph.getCorners()
|
1858 |
+
assert corner2 in self.graph.getCorners()
|
1859 |
+
if (corner1.x, corner2.x) in self.edge_existed_before.keys() or \
|
1860 |
+
(corner2.x, corner1.x) in self.edge_existed_before.keys():
|
1861 |
+
raise BaseException('cannot add the edge')
|
1862 |
+
new_ele = self.graph.add_edge(corner1, corner2)
|
1863 |
+
return new_ele
|
1864 |
+
|
1865 |
+
def removeCorner(self, ele):
|
1866 |
+
if ele.x in self.corner_existed_before.keys():
|
1867 |
+
raise BaseException('already existed.')
|
1868 |
+
self.corner_existed_before[ele.x] = SAFE_NUM
|
1869 |
+
|
1870 |
+
def removeEdge(self, ele):
|
1871 |
+
corner1 = ele.x[0]
|
1872 |
+
corner2 = ele.x[1]
|
1873 |
+
loc1 = corner1.x
|
1874 |
+
loc2 = corner2.x
|
1875 |
+
if (loc1[0] > loc2[0]) or (loc1[0] == loc2[0] and loc1[1] > loc2[1]):
|
1876 |
+
loc1 = corner2.x
|
1877 |
+
loc2 = corner1.x
|
1878 |
+
if (loc1, loc2) in self.edge_existed_before.keys():
|
1879 |
+
raise BaseException('already existed.')
|
1880 |
+
self.edge_existed_before[(loc1, loc2)] = SAFE_NUM
|
1881 |
+
|
1882 |
+
def generate_new_candidate_remove_a_colinear_corner(self, ele):
|
1883 |
+
# caller must verify that ele is a colinear corner before calling this method
|
1884 |
+
new_candidate = self.copy()
|
1885 |
+
new_graph = new_candidate.graph
|
1886 |
+
ele = new_graph.getRealElement(ele)
|
1887 |
+
|
1888 |
+
# find two neighbor corners
|
1889 |
+
temp = set()
|
1890 |
+
for element in new_graph.getEdgeConnected(ele):
|
1891 |
+
# edge
|
1892 |
+
if type(element.x[0]) == Element:
|
1893 |
+
temp.add(element.x[0])
|
1894 |
+
temp.add(element.x[1])
|
1895 |
+
temp.remove(ele)
|
1896 |
+
temp = tuple(temp)
|
1897 |
+
assert len(temp) == 2
|
1898 |
+
|
1899 |
+
# add edge to two neighbor corners
|
1900 |
+
# (add before removing, otherwise the neighbor corners could be dropped for reaching zero degree)
|
1901 |
+
# special case: no need to check existed_before; instead, drop the entry from existed_before if present
|
1902 |
+
added = new_graph.add_edge(temp[0], temp[1])
|
1903 |
+
if (temp[0].x, temp[1].x) in self.edge_existed_before.keys():
|
1904 |
+
self.edge_existed_before.pop((temp[0].x, temp[1].x))
|
1905 |
+
if (temp[1].x, temp[0].x) in self.edge_existed_before.keys():
|
1906 |
+
self.edge_existed_before.pop((temp[1].x, temp[0].x))
|
1907 |
+
|
1908 |
+
# remove
|
1909 |
+
removed = new_graph.remove(ele)
|
1910 |
+
|
1911 |
+
# add removed elements into existed before
|
1912 |
+
for element in removed:
|
1913 |
+
# edge
|
1914 |
+
if type(element.x[0]) == Element:
|
1915 |
+
new_candidate.removeEdge(element)
|
1916 |
+
# corner
|
1917 |
+
elif type(element.x[0]) == int:
|
1918 |
+
new_candidate.removeCorner(element)
|
1919 |
+
else:
|
1920 |
+
raise BaseException('wrong type.')
|
1921 |
+
|
1922 |
+
# modify scores that need to be recounted
|
1923 |
+
# all corners are recounted
|
1924 |
+
for element in new_graph.getCorners():
|
1925 |
+
element.store_score(None)
|
1926 |
+
|
1927 |
+
# edges that are neighbors to the removed edges OR new edges will be recounted
|
1928 |
+
for element in new_graph.getEdges():
|
1929 |
+
for modified_ele in removed.union({added}):
|
1930 |
+
if new_graph.isNeighbor(element, modified_ele):
|
1931 |
+
element.store_score(None)
|
1932 |
+
break
|
1933 |
+
|
1934 |
+
# all regions are recounted
|
1935 |
+
for element in new_graph.getRegions():
|
1936 |
+
element.store_score(None)
|
1937 |
+
|
1938 |
+
return new_candidate
|
1939 |
+
|
1940 |
+
def generate_new_candidate_remove_a_corner(self, ele):
|
1941 |
+
# caller must check that ele is removable before calling this method
|
1942 |
+
new_candidate = self.copy()
|
1943 |
+
new_graph = new_candidate.graph
|
1944 |
+
ele = new_graph.getRealElement(ele)
|
1945 |
+
removed = new_graph.remove(ele)
|
1946 |
+
|
1947 |
+
# add removed elements into existed before
|
1948 |
+
for element in removed:
|
1949 |
+
# edge
|
1950 |
+
if type(element.x[0]) == Element:
|
1951 |
+
corner1 = element.x[0]
|
1952 |
+
corner2 = element.x[1]
|
1953 |
+
loc1 = corner1.x
|
1954 |
+
loc2 = corner2.x
|
1955 |
+
if (loc1[0] > loc2[0]) or (loc1[0] == loc2[0] and loc1[1] > loc2[1]):
|
1956 |
+
loc1 = corner2.x
|
1957 |
+
loc2 = corner1.x
|
1958 |
+
if (loc1, loc2) in self.edge_existed_before.keys():
|
1959 |
+
raise BaseException('already existed.')
|
1960 |
+
new_candidate.edge_existed_before[(loc1, loc2)] = SAFE_NUM
|
1961 |
+
# corner
|
1962 |
+
elif type(element.x[0]) == int:
|
1963 |
+
if element.x in self.corner_existed_before.keys():
|
1964 |
+
raise BaseException('already existed.')
|
1965 |
+
new_candidate.corner_existed_before[element.x] = SAFE_NUM
|
1966 |
+
else:
|
1967 |
+
raise BaseException('wrong type.')
|
1968 |
+
|
1969 |
+
# modify scores that need to be recounted
|
1970 |
+
# all corners are recounted
|
1971 |
+
for element in new_graph.getCorners():
|
1972 |
+
element.store_score(None)
|
1973 |
+
|
1974 |
+
# edges that are neighbors to the removed edges will be recounted
|
1975 |
+
for element in new_graph.getEdges():
|
1976 |
+
for removed_ele in removed:
|
1977 |
+
if new_graph.isNeighbor(element, removed_ele):
|
1978 |
+
element.store_score(None)
|
1979 |
+
break
|
1980 |
+
|
1981 |
+
# all regions are recounted
|
1982 |
+
for element in new_graph.getRegions():
|
1983 |
+
element.store_score(None)
|
1984 |
+
|
1985 |
+
return new_candidate
|
1986 |
+
|
1987 |
+
def generate_new_candidate_add_an_edge(self, ele1, ele2):
|
1988 |
+
# caller must check addable() before calling this method
|
1989 |
+
new_candidate = self.copy()
|
1990 |
+
new_graph = new_candidate.graph
|
1991 |
+
ele1 = new_graph.getRealElement(ele1)
|
1992 |
+
ele2 = new_graph.getRealElement(ele2)
|
1993 |
+
|
1994 |
+
# add edge
|
1995 |
+
new_ele = new_candidate.addEdge(ele1, ele2)
|
1996 |
+
|
1997 |
+
# modify scores that need to be recounted
|
1998 |
+
# all corners are recounted
|
1999 |
+
for element in new_graph.getCorners():
|
2000 |
+
element.store_score(None)
|
2001 |
+
|
2002 |
+
# edges that are neighbors to the added edges will be recounted
|
2003 |
+
for element in new_graph.getEdges():
|
2004 |
+
if new_graph.isNeighbor(element, new_ele):
|
2005 |
+
element.store_score(None)
|
2006 |
+
|
2007 |
+
# all regions are recounted
|
2008 |
+
for element in new_graph.getRegions():
|
2009 |
+
element.store_score(None)
|
2010 |
+
|
2011 |
+
return new_candidate
|
2012 |
+
|
2013 |
+
def generate_new_candidate_remove_an_edge(self, ele):
|
2014 |
+
# caller must check that ele is removable before calling this method
|
2015 |
+
new_candidate = self.copy()
|
2016 |
+
new_graph = new_candidate.graph
|
2017 |
+
ele = new_graph.getRealElement(ele)
|
2018 |
+
removed = new_graph.remove(ele)
|
2019 |
+
|
2020 |
+
# add removed elements into existed before
|
2021 |
+
for element in removed:
|
2022 |
+
# edge
|
2023 |
+
if type(element.x[0]) == Element:
|
2024 |
+
corner1 = element.x[0]
|
2025 |
+
corner2 = element.x[1]
|
2026 |
+
loc1 = corner1.x
|
2027 |
+
loc2 = corner2.x
|
2028 |
+
if (loc1[0] > loc2[0]) or (loc1[0] == loc2[0] and loc1[1] > loc2[1]):
|
2029 |
+
loc1 = corner2.x
|
2030 |
+
loc2 = corner1.x
|
2031 |
+
if (loc1, loc2) in self.edge_existed_before.keys():
|
2032 |
+
raise BaseException('already existed.')
|
2033 |
+
new_candidate.edge_existed_before[(loc1, loc2)] = SAFE_NUM
|
2034 |
+
# corner
|
2035 |
+
elif type(element.x[0]) == int:
|
2036 |
+
if element.x in self.corner_existed_before.keys():
|
2037 |
+
raise BaseException('already existed.')
|
2038 |
+
new_candidate.corner_existed_before[element.x] = SAFE_NUM
|
2039 |
+
else:
|
2040 |
+
raise BaseException('wrong type.')
|
2041 |
+
|
2042 |
+
# modify scores that need to be recounted
|
2043 |
+
# all corners are recounted
|
2044 |
+
for element in new_graph.getCorners():
|
2045 |
+
element.store_score(None)
|
2046 |
+
|
2047 |
+
# edges that are neighbors to the removed edges will be recounted
|
2048 |
+
for element in new_graph.getEdges():
|
2049 |
+
for removed_ele in removed:
|
2050 |
+
if new_graph.isNeighbor(element, removed_ele):
|
2051 |
+
element.store_score(None)
|
2052 |
+
break
|
2053 |
+
|
2054 |
+
# all regions are recounted
|
2055 |
+
for element in new_graph.getRegions():
|
2056 |
+
element.store_score(None)
|
2057 |
+
|
2058 |
+
return new_candidate
|
2059 |
+
|
2060 |
+
def generate_new_candidate_add_a_new_triangle(self, ele_new, ele1, ele2):
|
2061 |
+
# this method is to add a new corner as well as two new edges into the graph
|
2062 |
+
# caller must check that ele_new is addable before calling this method
|
2063 |
+
new_candidate = self.copy()
|
2064 |
+
new_graph = new_candidate.graph
|
2065 |
+
ele1 = new_graph.getRealElement(ele1)
|
2066 |
+
ele2 = new_graph.getRealElement(ele2)
|
2067 |
+
|
2068 |
+
# add corner
|
2069 |
+
ele_new = new_candidate.addCorner(ele_new)  # ele_new may be replaced by an existing corner
|
2070 |
+
|
2071 |
+
# no score need to be recounted in current situation
|
2072 |
+
|
2073 |
+
# add two_new edge (ele1, ele_new) and (ele2, ele_new)
|
2074 |
+
new_candidate = new_candidate.generate_new_candidate_add_an_edge(ele_new, ele1)
|
2075 |
+
new_candidate = new_candidate.generate_new_candidate_add_an_edge(ele_new, ele2)
|
2076 |
+
|
2077 |
+
return new_candidate
|
2078 |
+
|
2079 |
+
def generate_new_candidate_add_a_corner(self, ele):
|
2080 |
+
# caller must check that ele is addable before calling this method
|
2081 |
+
new_candidate = self.copy()
|
2082 |
+
new_graph = new_candidate.graph
|
2083 |
+
|
2084 |
+
# add corner
|
2085 |
+
ele = new_candidate.addCorner(ele)
|
2086 |
+
|
2087 |
+
# modify scores that need to be recounted
|
2088 |
+
# all corners are recounted
|
2089 |
+
for element in new_graph.getCorners():
|
2090 |
+
element.store_score(None)
|
2091 |
+
|
2092 |
+
# no edge need to be recounted
|
2093 |
+
# all regions are recounted
|
2094 |
+
for element in new_graph.getRegions():
|
2095 |
+
element.store_score(None)
|
2096 |
+
|
2097 |
+
return new_candidate
|
2098 |
+
|
2099 |
+
def equal(self, candidate):
|
2100 |
+
return self.graph.equal(candidate.graph)
|
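To tie the search structures together, here is a hedged end-to-end sketch of the Candidate bookkeeping, assuming this module's names (Graph, Candidate, SAFE_NUM) are in scope; the corner coordinates are invented:

import numpy as np

corners = np.array([[20, 20], [20, 200], [200, 200]])
edges = np.array([[0, 1], [1, 2]])
cand = Candidate.initial(Graph(corners, edges), name='example_scene')

target = next(c for c in cand.graph.getCorners() if c.x == (20, 20))
print(cand.removable(target))            # True: nothing is protected by a safe count yet

variant = cand.generate_new_candidate_remove_a_corner(target)
print(len(variant.graph.getCorners()))   # 2: the corner and its only edge were removed
print(variant.addable(target))           # False: the removed corner is tracked in corner_existed_before
for _ in range(SAFE_NUM):
    variant.update()                     # ages the existed_before counts and safe counts
print(variant.addable(target))           # True again once the count expires
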
models/__init__.py
ADDED
File without changes
|
models/corner_models.py
ADDED
@@ -0,0 +1,275 @@
|
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import math
from typing import Dict
|
6 |
+
from models.deformable_transformer import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \
|
7 |
+
DeformableTransformerDecoder, DeformableAttnDecoderLayer
|
8 |
+
from models.ops.modules import MSDeformAttn
|
9 |
+
from models.resnet import convrelu
|
10 |
+
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
|
11 |
+
from einops.layers.torch import Rearrange
|
12 |
+
from utils.misc import NestedTensor
|
13 |
+
|
14 |
+
|
15 |
+
class HeatCorner(nn.Module):
|
16 |
+
"""
|
17 |
+
The corner model of HEAT mirrors the edge model up to the edge-filtering part, i.e. per-candidate prediction without
|
18 |
+
relational modeling.
|
19 |
+
"""
|
20 |
+
def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels, ):
|
21 |
+
super(HeatCorner, self).__init__()
|
22 |
+
self.input_dim = input_dim
|
23 |
+
self.hidden_dim = hidden_dim
|
24 |
+
self.num_feature_levels = num_feature_levels
|
25 |
+
|
26 |
+
if num_feature_levels > 1:
|
27 |
+
num_backbone_outs = len(backbone_strides)
|
28 |
+
input_proj_list = []
|
29 |
+
for _ in range(num_backbone_outs):
|
30 |
+
in_channels = backbone_num_channels[_]
|
31 |
+
input_proj_list.append(nn.Sequential(
|
32 |
+
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
|
33 |
+
nn.GroupNorm(32, hidden_dim),
|
34 |
+
))
|
35 |
+
for _ in range(num_feature_levels - num_backbone_outs):
|
36 |
+
input_proj_list.append(nn.Sequential(
|
37 |
+
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
|
38 |
+
nn.GroupNorm(32, hidden_dim),
|
39 |
+
))
|
40 |
+
in_channels = hidden_dim
|
41 |
+
self.input_proj = nn.ModuleList(input_proj_list)
|
42 |
+
else:
|
43 |
+
self.input_proj = nn.ModuleList([
|
44 |
+
nn.Sequential(
|
45 |
+
nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1),
|
46 |
+
nn.GroupNorm(32, hidden_dim),
|
47 |
+
)])
|
48 |
+
|
49 |
+
self.patch_size = 4
|
50 |
+
patch_dim = (self.patch_size ** 2) * input_dim
|
51 |
+
self.to_patch_embedding = nn.Sequential(
|
52 |
+
Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size),
|
53 |
+
nn.Linear(patch_dim, input_dim),
|
54 |
+
nn.Linear(input_dim, hidden_dim),
|
55 |
+
)
|
56 |
+
|
57 |
+
self.pixel_pe_fc = nn.Linear(input_dim, hidden_dim)
|
58 |
+
self.transformer = CornerTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1,
|
59 |
+
dim_feedforward=1024, dropout=0.1)
|
60 |
+
|
61 |
+
self.img_pos = PositionEmbeddingSine(hidden_dim // 2)
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def get_ms_feat(xs, img_mask):
|
65 |
+
out: Dict[str, NestedTensor] = {}
|
66 |
+
for name, x in sorted(xs.items()):
|
67 |
+
m = img_mask
|
68 |
+
assert m is not None
|
69 |
+
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
70 |
+
out[name] = NestedTensor(x, mask)
|
71 |
+
return out
|
72 |
+
|
73 |
+
@staticmethod
|
74 |
+
def get_decoder_reference_points(height, width, device):
|
75 |
+
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
|
76 |
+
torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device))
|
77 |
+
ref_y = ref_y.reshape(-1)[None] / height
|
78 |
+
ref_x = ref_x.reshape(-1)[None] / width
|
79 |
+
ref = torch.stack((ref_x, ref_y), -1)
|
80 |
+
return ref
|
81 |
+
|
82 |
+
def forward(self, image_feats, feat_mask, pixels_feat, pixels, all_image_feats):
|
83 |
+
# process image features
|
84 |
+
features = self.get_ms_feat(image_feats, feat_mask)
|
85 |
+
|
86 |
+
srcs = []
|
87 |
+
masks = []
|
88 |
+
all_pos = []
|
89 |
+
|
90 |
+
new_features = list()
|
91 |
+
for name, x in sorted(features.items()):
|
92 |
+
new_features.append(x)
|
93 |
+
features = new_features
|
94 |
+
|
95 |
+
for l, feat in enumerate(features):
|
96 |
+
src, mask = feat.decompose()
|
97 |
+
mask = mask.to(src.device)
|
98 |
+
srcs.append(self.input_proj[l](src))
|
99 |
+
pos = self.img_pos(src).to(src.dtype)
|
100 |
+
all_pos.append(pos)
|
101 |
+
masks.append(mask)
|
102 |
+
assert mask is not None
|
103 |
+
|
104 |
+
if self.num_feature_levels > len(srcs):
|
105 |
+
_len_srcs = len(srcs)
|
106 |
+
for l in range(_len_srcs, self.num_feature_levels):
|
107 |
+
if l == _len_srcs:
|
108 |
+
src = self.input_proj[l](features[-1].tensors)
|
109 |
+
else:
|
110 |
+
src = self.input_proj[l](srcs[-1])
|
111 |
+
m = feat_mask
|
112 |
+
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device)
|
113 |
+
pos_l = self.img_pos(src).to(src.dtype)
|
114 |
+
srcs.append(src)
|
115 |
+
masks.append(mask)
|
116 |
+
all_pos.append(pos_l)
|
117 |
+
|
118 |
+
sp_inputs = self.to_patch_embedding(pixels_feat)
|
119 |
+
|
120 |
+
# compute the reference points
|
121 |
+
H_tgt = W_tgt = int(np.sqrt(sp_inputs.shape[1]))
|
122 |
+
reference_points_s1 = self.get_decoder_reference_points(H_tgt, W_tgt, sp_inputs.device)
|
123 |
+
|
124 |
+
corner_logits = self.transformer(srcs, masks, all_pos, sp_inputs, reference_points_s1, all_image_feats)
|
125 |
+
return corner_logits
|
126 |
+
|
127 |
+
|
128 |
+
class PositionEmbeddingSine(nn.Module):
|
129 |
+
"""
|
130 |
+
This is a more standard version of the position embedding, very similar to the one
|
131 |
+
used by the Attention is all you need paper, generalized to work on images.
|
132 |
+
"""
|
133 |
+
|
134 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
135 |
+
super().__init__()
|
136 |
+
self.num_pos_feats = num_pos_feats
|
137 |
+
self.temperature = temperature
|
138 |
+
self.normalize = normalize
|
139 |
+
if scale is not None and normalize is False:
|
140 |
+
raise ValueError("normalize should be True if scale is passed")
|
141 |
+
if scale is None:
|
142 |
+
scale = 2 * math.pi
|
143 |
+
self.scale = scale
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
mask = torch.zeros([x.shape[0], x.shape[2], x.shape[3]]).bool().to(x.device)
|
147 |
+
not_mask = ~mask
|
148 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
149 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
150 |
+
if self.normalize:
|
151 |
+
eps = 1e-6
|
152 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
153 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
154 |
+
|
155 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
156 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
157 |
+
|
158 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
159 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
160 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
161 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
162 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
163 |
+
return pos
|
164 |
+
|
165 |
+
|
166 |
+
class CornerTransformer(nn.Module):
|
167 |
+
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
168 |
+
dim_feedforward=1024, dropout=0.1,
|
169 |
+
activation="relu", return_intermediate_dec=False,
|
170 |
+
num_feature_levels=4, dec_n_points=4, enc_n_points=4,
|
171 |
+
):
|
172 |
+
super(CornerTransformer, self).__init__()
|
173 |
+
|
174 |
+
encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
|
175 |
+
dropout, activation,
|
176 |
+
num_feature_levels, nhead, enc_n_points)
|
177 |
+
self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
|
178 |
+
|
179 |
+
decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward,
|
180 |
+
dropout, activation,
|
181 |
+
num_feature_levels, nhead, dec_n_points)
|
182 |
+
self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False)
|
183 |
+
|
184 |
+
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
|
185 |
+
|
186 |
+
# upconv layers
|
187 |
+
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
188 |
+
self.conv_up1 = convrelu(256 + 256, 256, 3, 1)
|
189 |
+
self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
|
190 |
+
self.conv_original_size2 = convrelu(64 + 128, d_model, 3, 1)
|
191 |
+
self.output_fc_1 = nn.Linear(d_model, 1)
|
192 |
+
self.output_fc_2 = nn.Linear(d_model, 1)
|
193 |
+
|
194 |
+
self._reset_parameters()
|
195 |
+
|
196 |
+
def _reset_parameters(self):
|
197 |
+
for p in self.parameters():
|
198 |
+
if p.dim() > 1:
|
199 |
+
nn.init.xavier_uniform_(p)
|
200 |
+
for m in self.modules():
|
201 |
+
if isinstance(m, MSDeformAttn):
|
202 |
+
m._reset_parameters()
|
203 |
+
normal_(self.level_embed)
|
204 |
+
|
205 |
+
def get_valid_ratio(self, mask):
|
206 |
+
_, H, W = mask.shape
|
207 |
+
valid_H = torch.sum(~mask[:, :, 0], 1)
|
208 |
+
valid_W = torch.sum(~mask[:, 0, :], 1)
|
209 |
+
valid_ratio_h = valid_H.float() / H
|
210 |
+
valid_ratio_w = valid_W.float() / W
|
211 |
+
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
|
212 |
+
return valid_ratio
|
213 |
+
|
214 |
+
def forward(self, srcs, masks, pos_embeds, query_embed, reference_points, all_image_feats):
|
215 |
+
# prepare input for encoder
|
216 |
+
src_flatten = []
|
217 |
+
mask_flatten = []
|
218 |
+
lvl_pos_embed_flatten = []
|
219 |
+
spatial_shapes = []
|
220 |
+
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
|
221 |
+
bs, c, h, w = src.shape
|
222 |
+
spatial_shape = (h, w)
|
223 |
+
spatial_shapes.append(spatial_shape)
|
224 |
+
src = src.flatten(2).transpose(1, 2)
|
225 |
+
mask = mask.flatten(1)
|
226 |
+
pos_embed = pos_embed.flatten(2).transpose(1, 2)
|
227 |
+
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
|
228 |
+
lvl_pos_embed_flatten.append(lvl_pos_embed)
|
229 |
+
src_flatten.append(src)
|
230 |
+
mask_flatten.append(mask)
|
231 |
+
src_flatten = torch.cat(src_flatten, 1)
|
232 |
+
mask_flatten = torch.cat(mask_flatten, 1)
|
233 |
+
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
|
234 |
+
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
|
235 |
+
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
236 |
+
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
|
237 |
+
|
238 |
+
# encoder
|
239 |
+
memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten,
|
240 |
+
mask_flatten)
|
241 |
+
|
242 |
+
# prepare input for decoder
|
243 |
+
bs, _, c = memory.shape
|
244 |
+
|
245 |
+
tgt = query_embed
|
246 |
+
|
247 |
+
# relational decoder
|
248 |
+
hs_pixels_s1, _ = self.per_edge_decoder(tgt, reference_points, memory,
|
249 |
+
spatial_shapes, level_start_index, valid_ratios, query_embed,
|
250 |
+
mask_flatten)
|
251 |
+
|
252 |
+
feats_s1, preds_s1 = self.generate_corner_preds(hs_pixels_s1, all_image_feats)
|
253 |
+
|
254 |
+
return preds_s1
|
255 |
+
|
256 |
+
def generate_corner_preds(self, outputs, conv_outputs):
|
257 |
+
B, L, C = outputs.shape
|
258 |
+
side = int(np.sqrt(L))
|
259 |
+
outputs = outputs.view(B, side, side, C)
|
260 |
+
outputs = outputs.permute(0, 3, 1, 2)
|
261 |
+
outputs = torch.cat([outputs, conv_outputs['layer1']], dim=1)
|
262 |
+
x = self.conv_up1(outputs)
|
263 |
+
|
264 |
+
x = self.upsample(x)
|
265 |
+
x = torch.cat([x, conv_outputs['layer0']], dim=1)
|
266 |
+
x = self.conv_up0(x)
|
267 |
+
|
268 |
+
x = self.upsample(x)
|
269 |
+
x = torch.cat([x, conv_outputs['x_original']], dim=1)
|
270 |
+
x = self.conv_original_size2(x)
|
271 |
+
|
272 |
+
logits = x.permute(0, 2, 3, 1)
|
273 |
+
preds = self.output_fc_1(logits)
|
274 |
+
preds = preds.squeeze(-1).sigmoid()
|
275 |
+
return logits, preds
|
models/corner_to_edge.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import scipy.ndimage.filters as filters
|
4 |
+
import cv2
|
5 |
+
import itertools
|
6 |
+
|
7 |
+
NEIGHBOUR_SIZE = 5
|
8 |
+
MATCH_THRESH = 5
|
9 |
+
LOCAL_MAX_THRESH = 0.01
|
10 |
+
viz_count = 0
|
11 |
+
|
12 |
+
# pre-compute all combinations to generate edge candidates faster
|
13 |
+
all_combibations = dict()
|
14 |
+
for length in range(2, 351):
|
15 |
+
ids = np.arange(length)
|
16 |
+
combs = np.array(list(itertools.combinations(ids, 2)))
|
17 |
+
all_combibations[length] = combs
|
18 |
+
|
19 |
+
|
20 |
+
def prepare_edge_data(c_outputs, annots, images, max_corner_num):
|
21 |
+
bs = c_outputs.shape[0]
|
22 |
+
# prepares parameters for each sample of the batch
|
23 |
+
all_results = list()
|
24 |
+
|
25 |
+
for b_i in range(bs):
|
26 |
+
annot = annots[b_i]
|
27 |
+
output = c_outputs[b_i]
|
28 |
+
results = process_each_sample({'annot': annot, 'output': output, 'viz_img': images[b_i]}, max_corner_num)
|
29 |
+
all_results.append(results)
|
30 |
+
|
31 |
+
processed_corners = [item['corners'] for item in all_results]
|
32 |
+
edge_coords = [item['edges'] for item in all_results]
|
33 |
+
edge_labels = [item['labels'] for item in all_results]
|
34 |
+
|
35 |
+
edge_info = {
|
36 |
+
'edge_coords': edge_coords,
|
37 |
+
'edge_labels': edge_labels,
|
38 |
+
'processed_corners': processed_corners
|
39 |
+
}
|
40 |
+
|
41 |
+
edge_data = collate_edge_info(edge_info)
|
42 |
+
return edge_data
|
43 |
+
|
44 |
+
|
45 |
+
def process_annot(annot, do_round=True):
|
46 |
+
corners = np.array(list(annot.keys()))
|
47 |
+
ind = np.lexsort(corners.T) # sort the g.t. corners to fix the order for the matching later
|
48 |
+
corners = corners[ind] # sorted by y, then x
|
49 |
+
corner_mapping = {tuple(k): v for v, k in enumerate(corners)}
|
50 |
+
|
51 |
+
edges = list()
|
52 |
+
for c, connections in annot.items():
|
53 |
+
for other_c in connections:
|
54 |
+
edge_pair = (corner_mapping[c], corner_mapping[tuple(other_c)])
|
55 |
+
edges.append(edge_pair)
|
56 |
+
corner_degrees = [len(annot[tuple(c)]) for c in corners]
|
57 |
+
if do_round:
|
58 |
+
corners = corners.round()
|
59 |
+
return corners, edges, corner_degrees
|
60 |
+
|
61 |
+
|
62 |
+
def process_each_sample(data, max_corner_num):
|
63 |
+
annot = data['annot']
|
64 |
+
output = data['output']
|
65 |
+
|
66 |
+
preds = output.detach().cpu().numpy()
|
67 |
+
|
68 |
+
data_max = filters.maximum_filter(preds, NEIGHBOUR_SIZE)
|
69 |
+
maxima = (preds == data_max)
|
70 |
+
data_min = filters.minimum_filter(preds, NEIGHBOUR_SIZE)
|
71 |
+
diff = ((data_max - data_min) > 0)
|
72 |
+
maxima[diff == 0] = 0
|
73 |
+
local_maximas = np.where((maxima > 0) & (preds > LOCAL_MAX_THRESH))
|
74 |
+
pred_corners = np.stack(local_maximas, axis=-1)[:, [1, 0]] # to (x, y format)
|
75 |
+
|
76 |
+
# produce edge labels labels from pred corners here
|
77 |
+
|
78 |
+
processed_corners, edges, labels = get_edge_label_mix_gt(pred_corners, annot, max_corner_num)
|
79 |
+
# global viz_count
|
80 |
+
# viz_img = data['viz_img']
|
81 |
+
#output_path = './viz_training/{}_example_gt.png'.format(viz_count)
|
82 |
+
#_visualize_edge_training_data(processed_corners, edges, labels, viz_img, output_path)
|
83 |
+
#viz_count += 1
|
84 |
+
|
85 |
+
results = {
|
86 |
+
'corners': processed_corners,
|
87 |
+
'edges': edges,
|
88 |
+
'labels': labels,
|
89 |
+
}
|
90 |
+
return results
|
91 |
+
|
92 |
+
|
93 |
+
def get_edge_label_mix_gt(pred_corners, annot, max_corner_num):
|
94 |
+
ind = np.lexsort(pred_corners.T) # sort the pred corners to fix the order for matching
|
95 |
+
pred_corners = pred_corners[ind] # sorted by y, then x
|
96 |
+
gt_corners, edge_pairs, corner_degrees = process_annot(annot)
|
97 |
+
|
98 |
+
output_to_gt = dict()
|
99 |
+
gt_to_output = dict()
|
100 |
+
diff = np.sqrt(((pred_corners[:, None] - gt_corners) ** 2).sum(-1))
|
101 |
+
diff = diff.T
|
102 |
+
|
103 |
+
if len(pred_corners) > 0:
|
104 |
+
for target_i, target in enumerate(gt_corners):
|
105 |
+
dist = diff[target_i]
|
106 |
+
if len(output_to_gt) > 0:
|
107 |
+
dist[list(output_to_gt.keys())] = 1000 # ignore already matched pred corners
|
108 |
+
min_dist = dist.min()
|
109 |
+
min_idx = dist.argmin()
|
110 |
+
if min_dist < MATCH_THRESH and min_idx not in output_to_gt: # a positive match
|
111 |
+
output_to_gt[min_idx] = (target_i, min_dist)
|
112 |
+
gt_to_output[target_i] = min_idx
|
113 |
+
|
114 |
+
all_corners = gt_corners.copy()
|
115 |
+
|
116 |
+
# replace matched g.t. corners with pred corners
|
117 |
+
for gt_i in range(len(gt_corners)):
|
118 |
+
if gt_i in gt_to_output:
|
119 |
+
all_corners[gt_i] = pred_corners[gt_to_output[gt_i]]
|
120 |
+
|
121 |
+
nm_pred_ids = [i for i in range(len(pred_corners)) if i not in output_to_gt]
|
122 |
+
nm_pred_ids = np.random.permutation(nm_pred_ids)
|
123 |
+
if len(nm_pred_ids) > 0:
|
124 |
+
nm_pred_corners = pred_corners[nm_pred_ids]
|
125 |
+
#if len(nm_pred_ids) + len(all_corners) <= 150:
|
126 |
+
if len(nm_pred_ids) + len(all_corners) <= max_corner_num:
|
127 |
+
all_corners = np.concatenate([all_corners, nm_pred_corners], axis=0)
|
128 |
+
else:
|
129 |
+
#all_corners = np.concatenate([all_corners, nm_pred_corners[:(150 - len(gt_corners)), :]], axis=0)
|
130 |
+
all_corners = np.concatenate([all_corners, nm_pred_corners[:(max_corner_num - len(gt_corners)), :]], axis=0)
|
131 |
+
|
132 |
+
processed_corners, edges, edge_ids, labels = _get_edges(all_corners, edge_pairs)
|
133 |
+
|
134 |
+
return processed_corners, edges, labels
|
135 |
+
|
136 |
+
|
137 |
+
def _get_edges(corners, edge_pairs):
|
138 |
+
ind = np.lexsort(corners.T)
|
139 |
+
corners = corners[ind] # sorted by y, then x
|
140 |
+
corners = corners.round()
|
141 |
+
id_mapping = {old: new for new, old in enumerate(ind)}
|
142 |
+
|
143 |
+
all_ids = all_combibations[len(corners)]
|
144 |
+
edges = corners[all_ids]
|
145 |
+
labels = np.zeros(edges.shape[0])
|
146 |
+
|
147 |
+
N = len(corners)
|
148 |
+
edge_pairs = [(id_mapping[p[0]], id_mapping[p[1]]) for p in edge_pairs]
|
149 |
+
edge_pairs = [p for p in edge_pairs if p[0] < p[1]]
|
150 |
+
pos_ids = [int((2 * N - 1 - p[0]) * p[0] / 2 + p[1] - p[0] - 1) for p in edge_pairs]
|
151 |
+
labels[pos_ids] = 1
|
152 |
+
|
153 |
+
edge_ids = np.array(all_ids)
|
154 |
+
return corners, edges, edge_ids, labels
|
155 |
+
|
156 |
+
|
157 |
+
def collate_edge_info(data):
|
158 |
+
batched_data = {}
|
159 |
+
lengths_info = {}
|
160 |
+
for field in data.keys():
|
161 |
+
batch_values = data[field]
|
162 |
+
all_lens = [len(value) for value in batch_values]
|
163 |
+
max_len = max(all_lens)
|
164 |
+
pad_value = 0
|
165 |
+
batch_values = [pad_sequence(value, max_len, pad_value) for value in batch_values]
|
166 |
+
batch_values = np.stack(batch_values, axis=0)
|
167 |
+
|
168 |
+
if field in ['edge_coords', 'edge_labels', 'gt_values']:
|
169 |
+
batch_values = torch.Tensor(batch_values).long()
|
170 |
+
if field in ['processed_corners', 'edge_coords']:
|
171 |
+
lengths_info[field] = all_lens
|
172 |
+
batched_data[field] = batch_values
|
173 |
+
|
174 |
+
# Add length and mask into the data, the mask if for Transformers' input format, True means padding
|
175 |
+
for field, lengths in lengths_info.items():
|
176 |
+
lengths_str = field + '_lengths'
|
177 |
+
batched_data[lengths_str] = torch.Tensor(lengths).long()
|
178 |
+
mask = torch.arange(max(lengths))
|
179 |
+
mask = mask.unsqueeze(0).repeat(batched_data[field].shape[0], 1)
|
180 |
+
mask = mask >= batched_data[lengths_str].unsqueeze(-1)
|
181 |
+
mask_str = field + '_mask'
|
182 |
+
batched_data[mask_str] = mask
|
183 |
+
|
184 |
+
return batched_data
|
185 |
+
|
186 |
+
|
187 |
+
def pad_sequence(seq, length, pad_value=0):
|
188 |
+
if len(seq) == length:
|
189 |
+
return seq
|
190 |
+
else:
|
191 |
+
pad_len = length - len(seq)
|
192 |
+
if len(seq.shape) == 1:
|
193 |
+
if pad_value == 0:
|
194 |
+
paddings = np.zeros([pad_len, ])
|
195 |
+
else:
|
196 |
+
paddings = np.ones([pad_len, ]) * pad_value
|
197 |
+
else:
|
198 |
+
if pad_value == 0:
|
199 |
+
paddings = np.zeros([pad_len, ] + list(seq.shape[1:]))
|
200 |
+
else:
|
201 |
+
paddings = np.ones([pad_len, ] + list(seq.shape[1:])) * pad_value
|
202 |
+
padded_seq = np.concatenate([seq, paddings], axis=0)
|
203 |
+
return padded_seq
|
204 |
+
|
205 |
+
|
206 |
+
def get_infer_edge_pairs(corners, confs):
|
207 |
+
ind = np.lexsort(corners.T)
|
208 |
+
corners = corners[ind] # sorted by y, then x
|
209 |
+
confs = confs[ind]
|
210 |
+
|
211 |
+
edge_ids = all_combibations[len(corners)]
|
212 |
+
edge_coords = corners[edge_ids]
|
213 |
+
|
214 |
+
edge_coords = torch.tensor(np.array(edge_coords)).unsqueeze(0).long()
|
215 |
+
mask = torch.zeros([edge_coords.shape[0], edge_coords.shape[1]]).bool()
|
216 |
+
edge_ids = torch.tensor(np.array(edge_ids))
|
217 |
+
return corners, confs, edge_coords, mask, edge_ids
|
218 |
+
|
219 |
+
|
220 |
+
def _visualize_edge_training_data(corners, edges, edge_labels, image, save_path):
|
221 |
+
image = image.transpose([1, 2, 0])
|
222 |
+
image = (image * 255).astype(np.uint8)
|
223 |
+
image = np.ascontiguousarray(image)
|
224 |
+
|
225 |
+
for edge, label in zip(edges, edge_labels):
|
226 |
+
if label == 1:
|
227 |
+
cv2.line(image, tuple(edge[0].astype(np.int)), tuple(edge[1].astype(np.int)), (255, 255, 0), 2)
|
228 |
+
|
229 |
+
for c in corners:
|
230 |
+
cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
|
231 |
+
|
232 |
+
cv2.imwrite(save_path, image)
|
models/deformable_transformer.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import torch
|
3 |
+
from torch import nn, Tensor
|
4 |
+
from models.ops.modules import MSDeformAttn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class DeformableTransformerEncoderLayer(nn.Module):
|
9 |
+
def __init__(self,
|
10 |
+
d_model=256, d_ffn=1024,
|
11 |
+
dropout=0.1, activation="relu",
|
12 |
+
n_levels=4, n_heads=8, n_points=4):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
# self attention
|
16 |
+
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
|
17 |
+
self.dropout1 = nn.Dropout(dropout)
|
18 |
+
self.norm1 = nn.LayerNorm(d_model)
|
19 |
+
|
20 |
+
# ffn
|
21 |
+
self.linear1 = nn.Linear(d_model, d_ffn)
|
22 |
+
self.activation = _get_activation_fn(activation)
|
23 |
+
self.dropout2 = nn.Dropout(dropout)
|
24 |
+
self.linear2 = nn.Linear(d_ffn, d_model)
|
25 |
+
self.dropout3 = nn.Dropout(dropout)
|
26 |
+
self.norm2 = nn.LayerNorm(d_model)
|
27 |
+
|
28 |
+
@staticmethod
|
29 |
+
def with_pos_embed(tensor, pos):
|
30 |
+
return tensor if pos is None else tensor + pos
|
31 |
+
|
32 |
+
def forward_ffn(self, src):
|
33 |
+
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
|
34 |
+
src = src + self.dropout3(src2)
|
35 |
+
src = self.norm2(src)
|
36 |
+
return src
|
37 |
+
|
38 |
+
def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
|
39 |
+
# self attention
|
40 |
+
src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index,
|
41 |
+
padding_mask)
|
42 |
+
src = src + self.dropout1(src2)
|
43 |
+
src = self.norm1(src)
|
44 |
+
|
45 |
+
# ffn
|
46 |
+
src = self.forward_ffn(src)
|
47 |
+
|
48 |
+
return src
|
49 |
+
|
50 |
+
|
51 |
+
class DeformableTransformerEncoder(nn.Module):
|
52 |
+
def __init__(self, encoder_layer, num_layers):
|
53 |
+
super().__init__()
|
54 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
55 |
+
self.num_layers = num_layers
|
56 |
+
|
57 |
+
@staticmethod
|
58 |
+
def get_reference_points(spatial_shapes, valid_ratios, device):
|
59 |
+
reference_points_list = []
|
60 |
+
for lvl, (H_, W_) in enumerate(spatial_shapes):
|
61 |
+
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
|
62 |
+
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
|
63 |
+
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
|
64 |
+
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
|
65 |
+
ref = torch.stack((ref_x, ref_y), -1)
|
66 |
+
reference_points_list.append(ref)
|
67 |
+
reference_points = torch.cat(reference_points_list, 1)
|
68 |
+
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
|
69 |
+
return reference_points
|
70 |
+
|
71 |
+
def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
|
72 |
+
output = src
|
73 |
+
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
|
74 |
+
for _, layer in enumerate(self.layers):
|
75 |
+
output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
|
76 |
+
|
77 |
+
return output
|
78 |
+
|
79 |
+
|
80 |
+
class DeformableAttnDecoderLayer(nn.Module):
|
81 |
+
def __init__(self, d_model=256, d_ffn=1024,
|
82 |
+
dropout=0.1, activation="relu",
|
83 |
+
n_levels=4, n_heads=8, n_points=4):
|
84 |
+
super().__init__()
|
85 |
+
# cross attention
|
86 |
+
self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
|
87 |
+
self.dropout1 = nn.Dropout(dropout)
|
88 |
+
self.norm1 = nn.LayerNorm(d_model)
|
89 |
+
|
90 |
+
# ffn
|
91 |
+
self.linear1 = nn.Linear(d_model, d_ffn)
|
92 |
+
self.activation = _get_activation_fn(activation)
|
93 |
+
self.dropout3 = nn.Dropout(dropout)
|
94 |
+
self.linear2 = nn.Linear(d_ffn, d_model)
|
95 |
+
self.dropout4 = nn.Dropout(dropout)
|
96 |
+
self.norm3 = nn.LayerNorm(d_model)
|
97 |
+
|
98 |
+
@staticmethod
|
99 |
+
def with_pos_embed(tensor, pos):
|
100 |
+
return tensor if pos is None else tensor + pos
|
101 |
+
|
102 |
+
def forward_ffn(self, tgt):
|
103 |
+
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
|
104 |
+
tgt = tgt + self.dropout4(tgt2)
|
105 |
+
tgt = self.norm3(tgt)
|
106 |
+
return tgt
|
107 |
+
|
108 |
+
def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index,
|
109 |
+
src_padding_mask=None,
|
110 |
+
key_padding_mask=None):
|
111 |
+
# cross attention
|
112 |
+
tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
|
113 |
+
reference_points,
|
114 |
+
src, src_spatial_shapes, level_start_index, src_padding_mask)
|
115 |
+
tgt = tgt + self.dropout1(tgt2)
|
116 |
+
tgt = self.norm1(tgt)
|
117 |
+
|
118 |
+
# ffn
|
119 |
+
tgt = self.forward_ffn(tgt)
|
120 |
+
|
121 |
+
return tgt
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
class DeformableTransformerDecoderLayer(nn.Module):
|
126 |
+
def __init__(self, d_model=256, d_ffn=1024,
|
127 |
+
dropout=0.1, activation="relu",
|
128 |
+
n_levels=4, n_heads=8, n_points=4):
|
129 |
+
super().__init__()
|
130 |
+
# cross attention
|
131 |
+
self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
|
132 |
+
self.dropout1 = nn.Dropout(dropout)
|
133 |
+
self.norm1 = nn.LayerNorm(d_model)
|
134 |
+
|
135 |
+
# self attention
|
136 |
+
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
|
137 |
+
self.dropout2 = nn.Dropout(dropout)
|
138 |
+
self.norm2 = nn.LayerNorm(d_model)
|
139 |
+
|
140 |
+
# ffn
|
141 |
+
self.linear1 = nn.Linear(d_model, d_ffn)
|
142 |
+
self.activation = _get_activation_fn(activation)
|
143 |
+
self.dropout3 = nn.Dropout(dropout)
|
144 |
+
self.linear2 = nn.Linear(d_ffn, d_model)
|
145 |
+
self.dropout4 = nn.Dropout(dropout)
|
146 |
+
self.norm3 = nn.LayerNorm(d_model)
|
147 |
+
|
148 |
+
@staticmethod
|
149 |
+
def with_pos_embed(tensor, pos):
|
150 |
+
return tensor if pos is None else tensor + pos
|
151 |
+
|
152 |
+
def forward_ffn(self, tgt):
|
153 |
+
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
|
154 |
+
tgt = tgt + self.dropout4(tgt2)
|
155 |
+
tgt = self.norm3(tgt)
|
156 |
+
return tgt
|
157 |
+
|
158 |
+
def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index,
|
159 |
+
src_padding_mask=None,
|
160 |
+
key_padding_mask=None,
|
161 |
+
get_image_feat=True):
|
162 |
+
# self attention
|
163 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
164 |
+
tgt2 = \
|
165 |
+
self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), key_padding_mask=key_padding_mask)[
|
166 |
+
0].transpose(0, 1)
|
167 |
+
tgt = tgt + self.dropout2(tgt2)
|
168 |
+
tgt = self.norm2(tgt)
|
169 |
+
|
170 |
+
if get_image_feat:
|
171 |
+
# cross attention
|
172 |
+
tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
|
173 |
+
reference_points,
|
174 |
+
src, src_spatial_shapes, level_start_index, src_padding_mask)
|
175 |
+
tgt = tgt + self.dropout1(tgt2)
|
176 |
+
tgt = self.norm1(tgt)
|
177 |
+
|
178 |
+
# ffn
|
179 |
+
tgt = self.forward_ffn(tgt)
|
180 |
+
|
181 |
+
return tgt
|
182 |
+
|
183 |
+
|
184 |
+
class DeformableTransformerDecoder(nn.Module):
|
185 |
+
def __init__(self, decoder_layer, num_layers, return_intermediate=False, with_sa=True):
|
186 |
+
super().__init__()
|
187 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
188 |
+
self.num_layers = num_layers
|
189 |
+
self.return_intermediate = return_intermediate
|
190 |
+
# hack implementation for iterative bounding box refinement and two-stage Deformable DETR
|
191 |
+
self.with_sa = with_sa
|
192 |
+
|
193 |
+
def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
|
194 |
+
query_pos=None, src_padding_mask=None, key_padding_mask=None, get_image_feat=True):
|
195 |
+
output = tgt
|
196 |
+
|
197 |
+
intermediate = []
|
198 |
+
intermediate_reference_points = []
|
199 |
+
for lid, layer in enumerate(self.layers):
|
200 |
+
if reference_points.shape[-1] == 4:
|
201 |
+
reference_points_input = reference_points[:, :, None] \
|
202 |
+
* torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
|
203 |
+
else:
|
204 |
+
assert reference_points.shape[-1] == 2
|
205 |
+
reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
|
206 |
+
if self.with_sa:
|
207 |
+
output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index,
|
208 |
+
src_padding_mask, key_padding_mask, get_image_feat)
|
209 |
+
else:
|
210 |
+
output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes,
|
211 |
+
src_level_start_index,
|
212 |
+
src_padding_mask, key_padding_mask)
|
213 |
+
|
214 |
+
if self.return_intermediate:
|
215 |
+
intermediate.append(output)
|
216 |
+
intermediate_reference_points.append(reference_points)
|
217 |
+
|
218 |
+
if self.return_intermediate:
|
219 |
+
return torch.stack(intermediate), torch.stack(intermediate_reference_points)
|
220 |
+
|
221 |
+
return output, reference_points
|
222 |
+
|
223 |
+
|
224 |
+
def _get_clones(module, N):
|
225 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
226 |
+
|
227 |
+
|
228 |
+
def _get_activation_fn(activation):
|
229 |
+
"""Return an activation function given a string"""
|
230 |
+
if activation == "relu":
|
231 |
+
return F.relu
|
232 |
+
if activation == "gelu":
|
233 |
+
return F.gelu
|
234 |
+
if activation == "glu":
|
235 |
+
return F.glu
|
236 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
models/edge_models.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
from models.mlp import MLP
|
6 |
+
from models.deformable_transformer import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \
|
7 |
+
DeformableTransformerDecoder, DeformableTransformerDecoderLayer, DeformableAttnDecoderLayer
|
8 |
+
from models.ops.modules import MSDeformAttn
|
9 |
+
from models.corner_models import PositionEmbeddingSine
|
10 |
+
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from utils.misc import NestedTensor
|
13 |
+
|
14 |
+
|
15 |
+
class HeatEdge(nn.Module):
|
16 |
+
def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels, ):
|
17 |
+
super(HeatEdge, self).__init__()
|
18 |
+
self.input_dim = input_dim
|
19 |
+
self.hidden_dim = hidden_dim
|
20 |
+
self.num_feature_levels = num_feature_levels
|
21 |
+
|
22 |
+
if num_feature_levels > 1:
|
23 |
+
num_backbone_outs = len(backbone_strides)
|
24 |
+
input_proj_list = []
|
25 |
+
for _ in range(num_backbone_outs):
|
26 |
+
in_channels = backbone_num_channels[_]
|
27 |
+
input_proj_list.append(nn.Sequential(
|
28 |
+
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
|
29 |
+
nn.GroupNorm(32, hidden_dim),
|
30 |
+
))
|
31 |
+
for _ in range(num_feature_levels - num_backbone_outs):
|
32 |
+
input_proj_list.append(nn.Sequential(
|
33 |
+
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
|
34 |
+
nn.GroupNorm(32, hidden_dim),
|
35 |
+
))
|
36 |
+
in_channels = hidden_dim
|
37 |
+
self.input_proj = nn.ModuleList(input_proj_list)
|
38 |
+
else:
|
39 |
+
self.input_proj = nn.ModuleList([
|
40 |
+
nn.Sequential(
|
41 |
+
nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1),
|
42 |
+
nn.GroupNorm(32, hidden_dim),
|
43 |
+
)])
|
44 |
+
|
45 |
+
self.img_pos = PositionEmbeddingSine(hidden_dim // 2)
|
46 |
+
|
47 |
+
self.edge_input_fc = nn.Linear(input_dim * 2, hidden_dim)
|
48 |
+
self.output_fc = MLP(input_dim=hidden_dim, hidden_dim=hidden_dim // 2, output_dim=2, num_layers=2)
|
49 |
+
|
50 |
+
self.transformer = EdgeTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1,
|
51 |
+
num_decoder_layers=6, dim_feedforward=1024, dropout=0.1)
|
52 |
+
|
53 |
+
@staticmethod
|
54 |
+
def get_ms_feat(xs, img_mask):
|
55 |
+
out: Dict[str, NestedTensor] = {}
|
56 |
+
for name, x in sorted(xs.items()):
|
57 |
+
m = img_mask
|
58 |
+
assert m is not None
|
59 |
+
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
60 |
+
out[name] = NestedTensor(x, mask)
|
61 |
+
return out
|
62 |
+
|
63 |
+
def forward(self, image_feats, feat_mask, corner_outputs, edge_coords, edge_masks, gt_values, corner_nums,
|
64 |
+
max_candidates, do_inference=False):
|
65 |
+
# Prepare ConvNet features
|
66 |
+
features = self.get_ms_feat(image_feats, feat_mask)
|
67 |
+
|
68 |
+
srcs = []
|
69 |
+
masks = []
|
70 |
+
all_pos = []
|
71 |
+
|
72 |
+
new_features = list()
|
73 |
+
for name, x in sorted(features.items()):
|
74 |
+
new_features.append(x)
|
75 |
+
features = new_features
|
76 |
+
|
77 |
+
for l, feat in enumerate(features):
|
78 |
+
src, mask = feat.decompose()
|
79 |
+
mask = mask.to(src.device)
|
80 |
+
srcs.append(self.input_proj[l](src))
|
81 |
+
pos = self.img_pos(src).to(src.dtype)
|
82 |
+
all_pos.append(pos)
|
83 |
+
masks.append(mask)
|
84 |
+
assert mask is not None
|
85 |
+
|
86 |
+
if self.num_feature_levels > len(srcs):
|
87 |
+
_len_srcs = len(srcs)
|
88 |
+
for l in range(_len_srcs, self.num_feature_levels):
|
89 |
+
if l == _len_srcs:
|
90 |
+
src = self.input_proj[l](features[-1].tensors)
|
91 |
+
else:
|
92 |
+
src = self.input_proj[l](srcs[-1])
|
93 |
+
m = feat_mask
|
94 |
+
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device)
|
95 |
+
pos_l = self.img_pos(src).to(src.dtype)
|
96 |
+
srcs.append(src)
|
97 |
+
masks.append(mask)
|
98 |
+
all_pos.append(pos_l)
|
99 |
+
|
100 |
+
bs = edge_masks.size(0)
|
101 |
+
num_edges = edge_masks.size(1)
|
102 |
+
|
103 |
+
corner_feats = corner_outputs
|
104 |
+
edge_feats = list()
|
105 |
+
for b_i in range(bs):
|
106 |
+
feats = corner_feats[b_i, edge_coords[b_i, :, :, 1], edge_coords[b_i, :, :, 0], :]
|
107 |
+
edge_feats.append(feats)
|
108 |
+
edge_feats = torch.stack(edge_feats, dim=0)
|
109 |
+
edge_feats = edge_feats.view(bs, num_edges, -1)
|
110 |
+
|
111 |
+
edge_inputs = self.edge_input_fc(edge_feats.view(bs * num_edges, -1))
|
112 |
+
edge_inputs = edge_inputs.view(bs, num_edges, -1)
|
113 |
+
|
114 |
+
edge_center = (edge_coords[:, :, 0, :].float() + edge_coords[:, :, 1, :].float()) / 2
|
115 |
+
edge_center = edge_center / feat_mask.shape[1]
|
116 |
+
|
117 |
+
logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values = self.transformer(srcs,
|
118 |
+
masks,
|
119 |
+
all_pos,
|
120 |
+
edge_inputs,
|
121 |
+
edge_center,
|
122 |
+
gt_values,
|
123 |
+
edge_masks,
|
124 |
+
corner_nums,
|
125 |
+
max_candidates,
|
126 |
+
do_inference)
|
127 |
+
|
128 |
+
return logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values
|
129 |
+
|
130 |
+
|
131 |
+
class EdgeTransformer(nn.Module):
|
132 |
+
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
133 |
+
num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
|
134 |
+
activation="relu", return_intermediate_dec=False,
|
135 |
+
num_feature_levels=4, dec_n_points=4, enc_n_points=4,
|
136 |
+
):
|
137 |
+
super(EdgeTransformer, self).__init__()
|
138 |
+
|
139 |
+
encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
|
140 |
+
dropout, activation,
|
141 |
+
num_feature_levels, nhead, enc_n_points)
|
142 |
+
self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
|
143 |
+
|
144 |
+
decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward,
|
145 |
+
dropout, activation,
|
146 |
+
num_feature_levels, nhead, dec_n_points)
|
147 |
+
# one-layer decoder, without self-attention layers
|
148 |
+
self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False)
|
149 |
+
|
150 |
+
decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
|
151 |
+
dropout, activation,
|
152 |
+
num_feature_levels, nhead, dec_n_points)
|
153 |
+
|
154 |
+
# edge decoder w/ self-attention layers (image-aware decoder and geom-only decoder)
|
155 |
+
self.relational_decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers,
|
156 |
+
return_intermediate_dec, with_sa=True)
|
157 |
+
|
158 |
+
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
|
159 |
+
|
160 |
+
self.gt_label_embed = nn.Embedding(3, d_model)
|
161 |
+
|
162 |
+
self.input_fc_hb = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2)
|
163 |
+
self.input_fc_rel = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2)
|
164 |
+
|
165 |
+
self.output_fc_1 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
|
166 |
+
self.output_fc_2 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
|
167 |
+
self.output_fc_3 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2)
|
168 |
+
self._reset_parameters()
|
169 |
+
|
170 |
+
def _reset_parameters(self):
|
171 |
+
for p in self.parameters():
|
172 |
+
if p.dim() > 1:
|
173 |
+
nn.init.xavier_uniform_(p)
|
174 |
+
for m in self.modules():
|
175 |
+
if isinstance(m, MSDeformAttn):
|
176 |
+
m._reset_parameters()
|
177 |
+
normal_(self.level_embed)
|
178 |
+
|
179 |
+
def get_valid_ratio(self, mask):
|
180 |
+
_, H, W = mask.shape
|
181 |
+
valid_H = torch.sum(~mask[:, :, 0], 1)
|
182 |
+
valid_W = torch.sum(~mask[:, 0, :], 1)
|
183 |
+
valid_ratio_h = valid_H.float() / H
|
184 |
+
valid_ratio_w = valid_W.float() / W
|
185 |
+
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
|
186 |
+
return valid_ratio
|
187 |
+
|
188 |
+
def forward(self, srcs, masks, pos_embeds, query_embed, reference_points, labels, key_padding_mask, corner_nums,
|
189 |
+
max_candidates, do_inference=False):
|
190 |
+
# prepare input for encoder
|
191 |
+
src_flatten = []
|
192 |
+
mask_flatten = []
|
193 |
+
lvl_pos_embed_flatten = []
|
194 |
+
spatial_shapes = []
|
195 |
+
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
|
196 |
+
bs, c, h, w = src.shape
|
197 |
+
spatial_shape = (h, w)
|
198 |
+
spatial_shapes.append(spatial_shape)
|
199 |
+
src = src.flatten(2).transpose(1, 2)
|
200 |
+
mask = mask.flatten(1)
|
201 |
+
pos_embed = pos_embed.flatten(2).transpose(1, 2)
|
202 |
+
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
|
203 |
+
lvl_pos_embed_flatten.append(lvl_pos_embed)
|
204 |
+
src_flatten.append(src)
|
205 |
+
mask_flatten.append(mask)
|
206 |
+
src_flatten = torch.cat(src_flatten, 1)
|
207 |
+
mask_flatten = torch.cat(mask_flatten, 1)
|
208 |
+
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
|
209 |
+
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
|
210 |
+
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
211 |
+
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
|
212 |
+
|
213 |
+
# encoder
|
214 |
+
memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten,
|
215 |
+
mask_flatten)
|
216 |
+
|
217 |
+
# prepare input for decoder
|
218 |
+
bs, _, c = memory.shape
|
219 |
+
|
220 |
+
tgt = query_embed
|
221 |
+
|
222 |
+
# per-edge filtering with single-layer decoder (no self-attn)
|
223 |
+
hs_per_edge, _ = self.per_edge_decoder(tgt, reference_points, memory,
|
224 |
+
spatial_shapes, level_start_index, valid_ratios, query_embed,
|
225 |
+
mask_flatten)
|
226 |
+
logits_per_edge = self.output_fc_1(hs_per_edge).permute(0, 2, 1)
|
227 |
+
filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids = self.candidate_filtering(
|
228 |
+
logits_per_edge,
|
229 |
+
hs_per_edge, query_embed, reference_points,
|
230 |
+
labels,
|
231 |
+
key_padding_mask, corner_nums, max_candidates)
|
232 |
+
|
233 |
+
# generate the info for masked training
|
234 |
+
if not do_inference:
|
235 |
+
filtered_gt_values = self.generate_gt_masking(filtered_labels, filtered_mask)
|
236 |
+
else:
|
237 |
+
filtered_gt_values = filtered_labels
|
238 |
+
gt_info = self.gt_label_embed(filtered_gt_values)
|
239 |
+
|
240 |
+
# relational decoder with image feature (image-aware decoder)
|
241 |
+
hybrid_prim_hs = self.input_fc_hb(torch.cat([filtered_hs, gt_info], dim=-1))
|
242 |
+
|
243 |
+
hs, inter_references = self.relational_decoder(hybrid_prim_hs, filtered_rp, memory,
|
244 |
+
spatial_shapes, level_start_index, valid_ratios, filtered_query,
|
245 |
+
mask_flatten,
|
246 |
+
key_padding_mask=filtered_mask, get_image_feat=True)
|
247 |
+
|
248 |
+
logits_final_hb = self.output_fc_2(hs).permute(0, 2, 1)
|
249 |
+
|
250 |
+
# relational decoder without image feature (geom-only decoder)
|
251 |
+
rel_prim_hs = self.input_fc_rel(torch.cat([filtered_query, gt_info], dim=-1))
|
252 |
+
|
253 |
+
hs_rel, _ = self.relational_decoder(rel_prim_hs, filtered_rp, memory,
|
254 |
+
spatial_shapes, level_start_index, valid_ratios, filtered_query,
|
255 |
+
mask_flatten,
|
256 |
+
key_padding_mask=filtered_mask, get_image_feat=False)
|
257 |
+
|
258 |
+
logits_final_rel = self.output_fc_3(hs_rel).permute(0, 2, 1)
|
259 |
+
|
260 |
+
return logits_per_edge, logits_final_hb, logits_final_rel, selected_ids, filtered_mask, filtered_gt_values
|
261 |
+
|
262 |
+
@staticmethod
|
263 |
+
def candidate_filtering(logits, hs, query, rp, labels, key_padding_mask, corner_nums, max_candidates):
|
264 |
+
"""
|
265 |
+
Filter out the easy-negatives from the edge candidates, and update the edge information correspondingly
|
266 |
+
"""
|
267 |
+
B, L, _ = hs.shape
|
268 |
+
preds = logits.detach().softmax(1)[:, 1, :] # BxL
|
269 |
+
preds[key_padding_mask == True] = -1 # ignore the masking parts
|
270 |
+
sorted_ids = torch.argsort(preds, dim=-1, descending=True)
|
271 |
+
filtered_hs = list()
|
272 |
+
filtered_mask = list()
|
273 |
+
filtered_query = list()
|
274 |
+
filtered_rp = list()
|
275 |
+
filtered_labels = list()
|
276 |
+
selected_ids = list()
|
277 |
+
for b_i in range(B):
|
278 |
+
num_candidates = corner_nums[b_i] * 3
|
279 |
+
ids = sorted_ids[b_i, :max_candidates[b_i]]
|
280 |
+
filtered_hs.append(hs[b_i][ids])
|
281 |
+
new_mask = key_padding_mask[b_i][ids]
|
282 |
+
new_mask[num_candidates:] = True
|
283 |
+
filtered_mask.append(new_mask)
|
284 |
+
filtered_query.append(query[b_i][ids])
|
285 |
+
filtered_rp.append(rp[b_i][ids])
|
286 |
+
filtered_labels.append(labels[b_i][ids])
|
287 |
+
selected_ids.append(ids)
|
288 |
+
filtered_hs = torch.stack(filtered_hs, dim=0)
|
289 |
+
filtered_mask = torch.stack(filtered_mask, dim=0)
|
290 |
+
filtered_query = torch.stack(filtered_query, dim=0)
|
291 |
+
filtered_rp = torch.stack(filtered_rp, dim=0)
|
292 |
+
filtered_labels = torch.stack(filtered_labels, dim=0)
|
293 |
+
selected_ids = torch.stack(selected_ids, dim=0)
|
294 |
+
|
295 |
+
return filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids
|
296 |
+
|
297 |
+
@staticmethod
|
298 |
+
def generate_gt_masking(labels, mask):
|
299 |
+
"""
|
300 |
+
Generate the info for masked training on-the-fly with ratio=0.5
|
301 |
+
"""
|
302 |
+
bs = labels.shape[0]
|
303 |
+
gt_values = torch.zeros_like(mask).long()
|
304 |
+
for b_i in range(bs):
|
305 |
+
edge_length = (mask[b_i] == 0).sum()
|
306 |
+
rand_ratio = np.random.rand() * 0.5 + 0.5
|
307 |
+
gt_rand = torch.rand(edge_length)
|
308 |
+
gt_flag = torch.zeros(edge_length)
|
309 |
+
gt_flag[torch.where(gt_rand >= rand_ratio)] = 1
|
310 |
+
gt_idx = torch.where(gt_flag == 1)
|
311 |
+
pred_idx = torch.where(gt_flag == 0)
|
312 |
+
gt_values[b_i, gt_idx[0]] = labels[b_i, gt_idx[0]]
|
313 |
+
gt_values[b_i, pred_idx[0]] = 2 # use 2 to represent unknown value, need to predict
|
314 |
+
return gt_values
|
models/loss.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from utils.geometry_utils import edge_acc
|
5 |
+
|
6 |
+
|
7 |
+
class CornerCriterion(nn.Module):
|
8 |
+
def __init__(self, image_size):
|
9 |
+
super().__init__()
|
10 |
+
self.loss_rate = 9
|
11 |
+
|
12 |
+
def forward(self, outputs_s1, targets, gauss_targets, epoch=0):
|
13 |
+
# Compute the acc first, use the acc to guide the setup of loss weight
|
14 |
+
preds_s1 = (outputs_s1 >= 0.5).float()
|
15 |
+
pos_target_ids = torch.where(targets == 1)
|
16 |
+
correct = (preds_s1[pos_target_ids] == targets[pos_target_ids]).float().sum()
|
17 |
+
recall_s1 = correct / len(pos_target_ids[0])
|
18 |
+
|
19 |
+
rate = self.loss_rate
|
20 |
+
|
21 |
+
loss_weight = (gauss_targets > 0.5).float() * rate + 1
|
22 |
+
loss_s1 = F.binary_cross_entropy(outputs_s1, gauss_targets, weight=loss_weight, reduction='none')
|
23 |
+
loss_s1 = loss_s1.sum(-1).sum(-1).mean()
|
24 |
+
|
25 |
+
return loss_s1, recall_s1
|
26 |
+
|
27 |
+
|
28 |
+
class EdgeCriterion(nn.Module):
|
29 |
+
def __init__(self):
|
30 |
+
super().__init__()
|
31 |
+
self.edge_loss = nn.CrossEntropyLoss(weight=torch.tensor([0.33, 1.0]).cuda(), reduction='none')
|
32 |
+
|
33 |
+
def forward(self, logits_s1, logits_s2_hybrid, logits_s2_rel, s2_ids, s2_edge_mask, edge_labels, edge_lengths,
|
34 |
+
edge_mask, s2_gt_values):
|
35 |
+
# loss for edge filtering
|
36 |
+
s1_losses = self.edge_loss(logits_s1, edge_labels)
|
37 |
+
s1_losses[torch.where(edge_mask == True)] = 0
|
38 |
+
s1_losses = s1_losses[torch.where(s1_losses > 0)].sum() / edge_mask.shape[0]
|
39 |
+
gt_values = torch.ones_like(edge_mask).long() * 2
|
40 |
+
s1_acc = edge_acc(logits_s1, edge_labels, edge_lengths, gt_values)
|
41 |
+
|
42 |
+
# loss for stage-2
|
43 |
+
s2_labels = torch.gather(edge_labels, 1, s2_ids)
|
44 |
+
|
45 |
+
# the image-aware decoder
|
46 |
+
s2_losses_hybrid = self.edge_loss(logits_s2_hybrid, s2_labels)
|
47 |
+
s2_losses_hybrid[torch.where((s2_edge_mask == True) | (s2_gt_values != 2))] = 0
|
48 |
+
# aggregate the loss into the final scalar
|
49 |
+
s2_losses_hybrid = s2_losses_hybrid[torch.where(s2_losses_hybrid > 0)].sum() / s2_edge_mask.shape[0]
|
50 |
+
s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
|
51 |
+
# compute edge-level acc
|
52 |
+
s2_acc_hybrid = edge_acc(logits_s2_hybrid, s2_labels, s2_edge_lengths, s2_gt_values)
|
53 |
+
|
54 |
+
# the geom-only decoder
|
55 |
+
s2_losses_rel = self.edge_loss(logits_s2_rel, s2_labels)
|
56 |
+
s2_losses_rel[torch.where((s2_edge_mask == True) | (s2_gt_values != 2))] = 0
|
57 |
+
# aggregate the loss into the final scalar
|
58 |
+
s2_losses_rel = s2_losses_rel[torch.where(s2_losses_rel > 0)].sum() / s2_edge_mask.shape[0]
|
59 |
+
s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
|
60 |
+
# compute edge-level f1-score
|
61 |
+
s2_acc_rel = edge_acc(logits_s2_rel, s2_labels, s2_edge_lengths, s2_gt_values)
|
62 |
+
|
63 |
+
return s1_losses, s1_acc, s2_losses_hybrid, s2_acc_hybrid, s2_losses_rel, s2_acc_rel
|
models/mlp.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
class MLP(nn.Module):
|
6 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
7 |
+
|
8 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
9 |
+
super(MLP, self).__init__()
|
10 |
+
self.output_dim = output_dim
|
11 |
+
self.num_layers = num_layers
|
12 |
+
h = [hidden_dim] * (num_layers - 1)
|
13 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
B, N, D = x.size()
|
17 |
+
x = x.reshape(B*N, D)
|
18 |
+
for i, layer in enumerate(self.layers):
|
19 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
20 |
+
x = x.view(B, N, self.output_dim)
|
21 |
+
return x
|
models/ops/functions/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------------------
|
2 |
+
# Deformable DETR
|
3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------------------
|
6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
7 |
+
# ------------------------------------------------------------------------------------------------
|
8 |
+
|
9 |
+
from .ms_deform_attn_func import MSDeformAttnFunction
|
10 |
+
|
models/ops/functions/ms_deform_attn_func.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------------------
|
2 |
+
# Deformable DETR
|
3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------------------
|
6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
7 |
+
# ------------------------------------------------------------------------------------------------
|
8 |
+
|
9 |
+
from __future__ import absolute_import
|
10 |
+
from __future__ import print_function
|
11 |
+
from __future__ import division
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch.autograd import Function
|
16 |
+
from torch.autograd.function import once_differentiable
|
17 |
+
|
18 |
+
import MultiScaleDeformableAttention as MSDA
|
19 |
+
|
20 |
+
|
21 |
+
class MSDeformAttnFunction(Function):
|
22 |
+
@staticmethod
|
23 |
+
def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
|
24 |
+
ctx.im2col_step = im2col_step
|
25 |
+
output = MSDA.ms_deform_attn_forward(
|
26 |
+
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
|
27 |
+
ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
|
28 |
+
return output
|
29 |
+
|
30 |
+
@staticmethod
|
31 |
+
@once_differentiable
|
32 |
+
def backward(ctx, grad_output):
|
33 |
+
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
|
34 |
+
grad_value, grad_sampling_loc, grad_attn_weight = \
|
35 |
+
MSDA.ms_deform_attn_backward(
|
36 |
+
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
|
37 |
+
|
38 |
+
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
|
39 |
+
|
40 |
+
|
41 |
+
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
|
42 |
+
# for debug and test only,
|
43 |
+
# need to use cuda version instead
|
44 |
+
N_, S_, M_, D_ = value.shape
|
45 |
+
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
|
46 |
+
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
|
47 |
+
sampling_grids = 2 * sampling_locations - 1
|
48 |
+
sampling_value_list = []
|
49 |
+
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
|
50 |
+
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
|
51 |
+
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
|
52 |
+
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
|
53 |
+
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
|
54 |
+
# N_*M_, D_, Lq_, P_
|
55 |
+
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
|
56 |
+
mode='bilinear', padding_mode='zeros', align_corners=False)
|
57 |
+
sampling_value_list.append(sampling_value_l_)
|
58 |
+
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
|
59 |
+
attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
|
60 |
+
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
|
61 |
+
return output.transpose(1, 2).contiguous()
|
models/ops/make.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
# ------------------------------------------------------------------------------------------------
|
3 |
+
# Deformable DETR
|
4 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
# ------------------------------------------------------------------------------------------------
|
7 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
# ------------------------------------------------------------------------------------------------
|
9 |
+
|
10 |
+
python3 setup.py build install --user
|
models/ops/modules/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------------------
|
2 |
+
# Deformable DETR
|
3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------------------
|
6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
7 |
+
# ------------------------------------------------------------------------------------------------
|
8 |
+
|
9 |
+
from .ms_deform_attn import MSDeformAttn
|
models/ops/modules/ms_deform_attn.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------------------
|
2 |
+
# Deformable DETR
|
3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------------------
|
6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
7 |
+
# ------------------------------------------------------------------------------------------------
|
8 |
+
|
9 |
+
from __future__ import absolute_import
|
10 |
+
from __future__ import print_function
|
11 |
+
from __future__ import division
|
12 |
+
|
13 |
+
import warnings
|
14 |
+
import math
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch.nn.init import xavier_uniform_, constant_
|
20 |
+
|
21 |
+
from ..functions import MSDeformAttnFunction
|
22 |
+
|
23 |
+
|
24 |
+
def _is_power_of_2(n):
|
25 |
+
if (not isinstance(n, int)) or (n < 0):
|
26 |
+
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
|
27 |
+
return (n & (n-1) == 0) and n != 0
|
28 |
+
|
29 |
+
|
30 |
+
class MSDeformAttn(nn.Module):
|
31 |
+
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
|
32 |
+
"""
|
33 |
+
Multi-Scale Deformable Attention Module
|
34 |
+
:param d_model hidden dimension
|
35 |
+
:param n_levels number of feature levels
|
36 |
+
:param n_heads number of attention heads
|
37 |
+
:param n_points number of sampling points per attention head per feature level
|
38 |
+
"""
|
39 |
+
super().__init__()
|
40 |
+
if d_model % n_heads != 0:
|
41 |
+
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
|
42 |
+
_d_per_head = d_model // n_heads
|
43 |
+
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
|
44 |
+
if not _is_power_of_2(_d_per_head):
|
45 |
+
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
|
46 |
+
"which is more efficient in our CUDA implementation.")
|
47 |
+
|
48 |
+
self.im2col_step = 64
|
49 |
+
|
50 |
+
self.d_model = d_model
|
51 |
+
self.n_levels = n_levels
|
52 |
+
self.n_heads = n_heads
|
53 |
+
self.n_points = n_points
|
54 |
+
|
55 |
+
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
|
56 |
+
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
|
57 |
+
self.value_proj = nn.Linear(d_model, d_model)
|
58 |
+
self.output_proj = nn.Linear(d_model, d_model)
|
59 |
+
|
60 |
+
self._reset_parameters()
|
61 |
+
|
62 |
+
def _reset_parameters(self):
|
63 |
+
constant_(self.sampling_offsets.weight.data, 0.)
|
64 |
+
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
|
65 |
+
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
66 |
+
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
|
67 |
+
for i in range(self.n_points):
|
68 |
+
grid_init[:, :, i, :] *= i + 1
|
69 |
+
with torch.no_grad():
|
70 |
+
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
|
71 |
+
constant_(self.attention_weights.weight.data, 0.)
|
72 |
+
constant_(self.attention_weights.bias.data, 0.)
|
73 |
+
xavier_uniform_(self.value_proj.weight.data)
|
74 |
+
constant_(self.value_proj.bias.data, 0.)
|
75 |
+
xavier_uniform_(self.output_proj.weight.data)
|
76 |
+
constant_(self.output_proj.bias.data, 0.)
|
77 |
+
|
78 |
+
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
|
79 |
+
"""
|
80 |
+
:param query (N, Length_{query}, C)
|
81 |
+
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
|
82 |
+
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
|
83 |
+
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
|
84 |
+
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
|
85 |
+
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
|
86 |
+
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
|
87 |
+
|
88 |
+
:return output (N, Length_{query}, C)
|
89 |
+
"""
|
90 |
+
N, Len_q, _ = query.shape
|
91 |
+
N, Len_in, _ = input_flatten.shape
|
92 |
+
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
|
93 |
+
|
94 |
+
value = self.value_proj(input_flatten)
|
95 |
+
if input_padding_mask is not None:
|
96 |
+
value = value.masked_fill(input_padding_mask[..., None], float(0))
|
97 |
+
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
|
98 |
+
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
|
99 |
+
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
|
100 |
+
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
|
101 |
+
# N, Len_q, n_heads, n_levels, n_points, 2
|
102 |
+
if reference_points.shape[-1] == 2:
|
103 |
+
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
|
104 |
+
sampling_locations = reference_points[:, :, None, :, None, :] \
|
105 |
+
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
|
106 |
+
elif reference_points.shape[-1] == 4:
|
107 |
+
sampling_locations = reference_points[:, :, None, :, None, :2] \
|
108 |
+
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
|
109 |
+
else:
|
110 |
+
raise ValueError(
|
111 |
+
'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
|
112 |
+
output = MSDeformAttnFunction.apply(
|
113 |
+
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
|
114 |
+
output = self.output_proj(output)
|
115 |
+
return output
|
models/ops/setup.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------------------
|
2 |
+
# Deformable DETR
|
3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------------------
|
6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
7 |
+
# ------------------------------------------------------------------------------------------------
|
8 |
+
|
9 |
+
import os
|
10 |
+
import glob
|
11 |
+
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from torch.utils.cpp_extension import CUDA_HOME
|
15 |
+
from torch.utils.cpp_extension import CppExtension
|
16 |
+
from torch.utils.cpp_extension import CUDAExtension
|
17 |
+
|
18 |
+
from setuptools import find_packages
|
19 |
+
from setuptools import setup
|
20 |
+
|
21 |
+
requirements = ["torch", "torchvision"]
|
22 |
+
|
23 |
+
def get_extensions():
|
24 |
+
this_dir = os.path.dirname(os.path.abspath(__file__))
|
25 |
+
extensions_dir = os.path.join(this_dir, "src")
|
26 |
+
|
27 |
+
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
|
28 |
+
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
|
29 |
+
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
|
30 |
+
|
31 |
+
sources = main_file + source_cpu
|
32 |
+
extension = CppExtension
|
33 |
+
extra_compile_args = {"cxx": []}
|
34 |
+
define_macros = []
|
35 |
+
|
36 |
+
if torch.cuda.is_available() and CUDA_HOME is not None:
|
37 |
+
extension = CUDAExtension
|
38 |
+
sources += source_cuda
|
39 |
+
define_macros += [("WITH_CUDA", None)]
|
40 |
+
extra_compile_args["nvcc"] = [
|
41 |
+
"-DCUDA_HAS_FP16=1",
|
42 |
+
"-D__CUDA_NO_HALF_OPERATORS__",
|
43 |
+
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
44 |
+
"-D__CUDA_NO_HALF2_OPERATORS__",
|
45 |
+
]
|
46 |
+
else:
|
47 |
+
raise NotImplementedError('Cuda is not availabel')
|
48 |
+
|
49 |
+
sources = [os.path.join(extensions_dir, s) for s in sources]
|
50 |
+
include_dirs = [extensions_dir]
|
51 |
+
ext_modules = [
|
52 |
+
extension(
|
53 |
+
"MultiScaleDeformableAttention",
|
54 |
+
sources,
|
55 |
+
include_dirs=include_dirs,
|
56 |
+
define_macros=define_macros,
|
57 |
+
extra_compile_args=extra_compile_args,
|
58 |
+
)
|
59 |
+
]
|
60 |
+
return ext_modules
|
61 |
+
|
62 |
+
setup(
|
63 |
+
name="MultiScaleDeformableAttention",
|
64 |
+
version="1.0",
|
65 |
+
author="Weijie Su",
|
66 |
+
url="https://github.com/fundamentalvision/Deformable-DETR",
|
67 |
+
description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
|
68 |
+
packages=find_packages(exclude=("configs", "tests",)),
|
69 |
+
ext_modules=get_extensions(),
|
70 |
+
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
|
71 |
+
)
|
models/ops/src/cpu/ms_deform_attn_cpu.cpp
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
#include <vector>
|
12 |
+
|
13 |
+
#include <ATen/ATen.h>
|
14 |
+
#include <ATen/cuda/CUDAContext.h>
|
15 |
+
|
16 |
+
|
17 |
+
at::Tensor
|
18 |
+
ms_deform_attn_cpu_forward(
|
19 |
+
const at::Tensor &value,
|
20 |
+
const at::Tensor &spatial_shapes,
|
21 |
+
const at::Tensor &level_start_index,
|
22 |
+
const at::Tensor &sampling_loc,
|
23 |
+
const at::Tensor &attn_weight,
|
24 |
+
const int im2col_step)
|
25 |
+
{
|
26 |
+
AT_ERROR("Not implement on cpu");
|
27 |
+
}
|
28 |
+
|
29 |
+
std::vector<at::Tensor>
|
30 |
+
ms_deform_attn_cpu_backward(
|
31 |
+
const at::Tensor &value,
|
32 |
+
const at::Tensor &spatial_shapes,
|
33 |
+
const at::Tensor &level_start_index,
|
34 |
+
const at::Tensor &sampling_loc,
|
35 |
+
const at::Tensor &attn_weight,
|
36 |
+
const at::Tensor &grad_output,
|
37 |
+
const int im2col_step)
|
38 |
+
{
|
39 |
+
AT_ERROR("Not implement on cpu");
|
40 |
+
}
|
41 |
+
|
models/ops/src/cpu/ms_deform_attn_cpu.h
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
#pragma once
|
12 |
+
#include <torch/extension.h>
|
13 |
+
|
14 |
+
at::Tensor
|
15 |
+
ms_deform_attn_cpu_forward(
|
16 |
+
const at::Tensor &value,
|
17 |
+
const at::Tensor &spatial_shapes,
|
18 |
+
const at::Tensor &level_start_index,
|
19 |
+
const at::Tensor &sampling_loc,
|
20 |
+
const at::Tensor &attn_weight,
|
21 |
+
const int im2col_step);
|
22 |
+
|
23 |
+
std::vector<at::Tensor>
|
24 |
+
ms_deform_attn_cpu_backward(
|
25 |
+
const at::Tensor &value,
|
26 |
+
const at::Tensor &spatial_shapes,
|
27 |
+
const at::Tensor &level_start_index,
|
28 |
+
const at::Tensor &sampling_loc,
|
29 |
+
const at::Tensor &attn_weight,
|
30 |
+
const at::Tensor &grad_output,
|
31 |
+
const int im2col_step);
|
32 |
+
|
33 |
+
|
models/ops/src/cuda/ms_deform_attn_cuda.cu
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
#include <vector>
|
12 |
+
#include "cuda/ms_deform_im2col_cuda.cuh"
|
13 |
+
|
14 |
+
#include <ATen/ATen.h>
|
15 |
+
#include <ATen/cuda/CUDAContext.h>
|
16 |
+
#include <cuda.h>
|
17 |
+
#include <cuda_runtime.h>
|
18 |
+
|
19 |
+
|
20 |
+
at::Tensor ms_deform_attn_cuda_forward(
|
21 |
+
const at::Tensor &value,
|
22 |
+
const at::Tensor &spatial_shapes,
|
23 |
+
const at::Tensor &level_start_index,
|
24 |
+
const at::Tensor &sampling_loc,
|
25 |
+
const at::Tensor &attn_weight,
|
26 |
+
const int im2col_step)
|
27 |
+
{
|
28 |
+
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
|
29 |
+
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
|
30 |
+
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
|
31 |
+
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
32 |
+
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
33 |
+
|
34 |
+
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
35 |
+
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
36 |
+
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
37 |
+
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
38 |
+
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
39 |
+
|
40 |
+
const int batch = value.size(0);
|
41 |
+
const int spatial_size = value.size(1);
|
42 |
+
const int num_heads = value.size(2);
|
43 |
+
const int channels = value.size(3);
|
44 |
+
|
45 |
+
const int num_levels = spatial_shapes.size(0);
|
46 |
+
|
47 |
+
const int num_query = sampling_loc.size(1);
|
48 |
+
const int num_point = sampling_loc.size(4);
|
49 |
+
|
50 |
+
const int im2col_step_ = std::min(batch, im2col_step);
|
51 |
+
|
52 |
+
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
53 |
+
|
54 |
+
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
|
55 |
+
|
56 |
+
const int batch_n = im2col_step_;
|
57 |
+
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
58 |
+
auto per_value_size = spatial_size * num_heads * channels;
|
59 |
+
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
60 |
+
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
61 |
+
for (int n = 0; n < batch/im2col_step_; ++n)
|
62 |
+
{
|
63 |
+
auto columns = output_n.select(0, n);
|
64 |
+
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
65 |
+
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
66 |
+
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
67 |
+
spatial_shapes.data<int64_t>(),
|
68 |
+
level_start_index.data<int64_t>(),
|
69 |
+
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
70 |
+
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
|
71 |
+
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
|
72 |
+
columns.data<scalar_t>());
|
73 |
+
|
74 |
+
}));
|
75 |
+
}
|
76 |
+
|
77 |
+
output = output.view({batch, num_query, num_heads*channels});
|
78 |
+
|
79 |
+
return output;
|
80 |
+
}
|
81 |
+
|
82 |
+
|
83 |
+
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
84 |
+
const at::Tensor &value,
|
85 |
+
const at::Tensor &spatial_shapes,
|
86 |
+
const at::Tensor &level_start_index,
|
87 |
+
const at::Tensor &sampling_loc,
|
88 |
+
const at::Tensor &attn_weight,
|
89 |
+
const at::Tensor &grad_output,
|
90 |
+
const int im2col_step)
|
91 |
+
{
|
92 |
+
|
93 |
+
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
|
94 |
+
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
|
95 |
+
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
|
96 |
+
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
97 |
+
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
98 |
+
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
99 |
+
|
100 |
+
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
101 |
+
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
102 |
+
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
103 |
+
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
104 |
+
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
105 |
+
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
|
106 |
+
|
107 |
+
const int batch = value.size(0);
|
108 |
+
const int spatial_size = value.size(1);
|
109 |
+
const int num_heads = value.size(2);
|
110 |
+
const int channels = value.size(3);
|
111 |
+
|
112 |
+
const int num_levels = spatial_shapes.size(0);
|
113 |
+
|
114 |
+
const int num_query = sampling_loc.size(1);
|
115 |
+
const int num_point = sampling_loc.size(4);
|
116 |
+
|
117 |
+
const int im2col_step_ = std::min(batch, im2col_step);
|
118 |
+
|
119 |
+
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
120 |
+
|
121 |
+
auto grad_value = at::zeros_like(value);
|
122 |
+
auto grad_sampling_loc = at::zeros_like(sampling_loc);
|
123 |
+
auto grad_attn_weight = at::zeros_like(attn_weight);
|
124 |
+
|
125 |
+
const int batch_n = im2col_step_;
|
126 |
+
auto per_value_size = spatial_size * num_heads * channels;
|
127 |
+
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
128 |
+
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
129 |
+
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
130 |
+
|
131 |
+
for (int n = 0; n < batch/im2col_step_; ++n)
|
132 |
+
{
|
133 |
+
auto grad_output_g = grad_output_n.select(0, n);
|
134 |
+
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
135 |
+
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
136 |
+
grad_output_g.data<scalar_t>(),
|
137 |
+
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
138 |
+
spatial_shapes.data<int64_t>(),
|
139 |
+
level_start_index.data<int64_t>(),
|
140 |
+
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
141 |
+
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
|
142 |
+
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
|
143 |
+
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
144 |
+
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
145 |
+
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
|
146 |
+
|
147 |
+
}));
|
148 |
+
}
|
149 |
+
|
150 |
+
return {
|
151 |
+
grad_value, grad_sampling_loc, grad_attn_weight
|
152 |
+
};
|
153 |
+
}
|
models/ops/src/cuda/ms_deform_attn_cuda.h
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
#pragma once
|
12 |
+
#include <torch/extension.h>
|
13 |
+
|
14 |
+
at::Tensor ms_deform_attn_cuda_forward(
|
15 |
+
const at::Tensor &value,
|
16 |
+
const at::Tensor &spatial_shapes,
|
17 |
+
const at::Tensor &level_start_index,
|
18 |
+
const at::Tensor &sampling_loc,
|
19 |
+
const at::Tensor &attn_weight,
|
20 |
+
const int im2col_step);
|
21 |
+
|
22 |
+
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
23 |
+
const at::Tensor &value,
|
24 |
+
const at::Tensor &spatial_shapes,
|
25 |
+
const at::Tensor &level_start_index,
|
26 |
+
const at::Tensor &sampling_loc,
|
27 |
+
const at::Tensor &attn_weight,
|
28 |
+
const at::Tensor &grad_output,
|
29 |
+
const int im2col_step);
|
30 |
+
|
models/ops/src/cuda/ms_deform_im2col_cuda.cuh
ADDED
@@ -0,0 +1,1327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
**************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************
|
7 |
+
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
|
8 |
+
* Copyright (c) 2018 Microsoft
|
9 |
+
**************************************************************************
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include <cstdio>
|
13 |
+
#include <algorithm>
|
14 |
+
#include <cstring>
|
15 |
+
|
16 |
+
#include <ATen/ATen.h>
|
17 |
+
#include <ATen/cuda/CUDAContext.h>
|
18 |
+
|
19 |
+
#include <THC/THCAtomics.cuh>
|
20 |
+
|
21 |
+
#define CUDA_KERNEL_LOOP(i, n) \
|
22 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
|
23 |
+
i < (n); \
|
24 |
+
i += blockDim.x * gridDim.x)
|
25 |
+
|
26 |
+
const int CUDA_NUM_THREADS = 1024;
|
27 |
+
inline int GET_BLOCKS(const int N, const int num_threads)
|
28 |
+
{
|
29 |
+
return (N + num_threads - 1) / num_threads;
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
template <typename scalar_t>
|
34 |
+
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
|
35 |
+
const int &height, const int &width, const int &nheads, const int &channels,
|
36 |
+
const scalar_t &h, const scalar_t &w, const int &m, const int &c)
|
37 |
+
{
|
38 |
+
const int h_low = floor(h);
|
39 |
+
const int w_low = floor(w);
|
40 |
+
const int h_high = h_low + 1;
|
41 |
+
const int w_high = w_low + 1;
|
42 |
+
|
43 |
+
const scalar_t lh = h - h_low;
|
44 |
+
const scalar_t lw = w - w_low;
|
45 |
+
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
46 |
+
|
47 |
+
const int w_stride = nheads * channels;
|
48 |
+
const int h_stride = width * w_stride;
|
49 |
+
const int h_low_ptr_offset = h_low * h_stride;
|
50 |
+
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
51 |
+
const int w_low_ptr_offset = w_low * w_stride;
|
52 |
+
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
53 |
+
const int base_ptr = m * channels + c;
|
54 |
+
|
55 |
+
scalar_t v1 = 0;
|
56 |
+
if (h_low >= 0 && w_low >= 0)
|
57 |
+
{
|
58 |
+
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
59 |
+
v1 = bottom_data[ptr1];
|
60 |
+
}
|
61 |
+
scalar_t v2 = 0;
|
62 |
+
if (h_low >= 0 && w_high <= width - 1)
|
63 |
+
{
|
64 |
+
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
65 |
+
v2 = bottom_data[ptr2];
|
66 |
+
}
|
67 |
+
scalar_t v3 = 0;
|
68 |
+
if (h_high <= height - 1 && w_low >= 0)
|
69 |
+
{
|
70 |
+
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
71 |
+
v3 = bottom_data[ptr3];
|
72 |
+
}
|
73 |
+
scalar_t v4 = 0;
|
74 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
75 |
+
{
|
76 |
+
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
77 |
+
v4 = bottom_data[ptr4];
|
78 |
+
}
|
79 |
+
|
80 |
+
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
81 |
+
|
82 |
+
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
83 |
+
return val;
|
84 |
+
}
|
85 |
+
|
86 |
+
|
87 |
+
template <typename scalar_t>
|
88 |
+
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
|
89 |
+
const int &height, const int &width, const int &nheads, const int &channels,
|
90 |
+
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
|
91 |
+
const scalar_t &top_grad,
|
92 |
+
const scalar_t &attn_weight,
|
93 |
+
scalar_t* &grad_value,
|
94 |
+
scalar_t* grad_sampling_loc,
|
95 |
+
scalar_t* grad_attn_weight)
|
96 |
+
{
|
97 |
+
const int h_low = floor(h);
|
98 |
+
const int w_low = floor(w);
|
99 |
+
const int h_high = h_low + 1;
|
100 |
+
const int w_high = w_low + 1;
|
101 |
+
|
102 |
+
const scalar_t lh = h - h_low;
|
103 |
+
const scalar_t lw = w - w_low;
|
104 |
+
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
105 |
+
|
106 |
+
const int w_stride = nheads * channels;
|
107 |
+
const int h_stride = width * w_stride;
|
108 |
+
const int h_low_ptr_offset = h_low * h_stride;
|
109 |
+
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
110 |
+
const int w_low_ptr_offset = w_low * w_stride;
|
111 |
+
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
112 |
+
const int base_ptr = m * channels + c;
|
113 |
+
|
114 |
+
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
115 |
+
const scalar_t top_grad_value = top_grad * attn_weight;
|
116 |
+
scalar_t grad_h_weight = 0, grad_w_weight = 0;
|
117 |
+
|
118 |
+
scalar_t v1 = 0;
|
119 |
+
if (h_low >= 0 && w_low >= 0)
|
120 |
+
{
|
121 |
+
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
122 |
+
v1 = bottom_data[ptr1];
|
123 |
+
grad_h_weight -= hw * v1;
|
124 |
+
grad_w_weight -= hh * v1;
|
125 |
+
atomicAdd(grad_value+ptr1, w1*top_grad_value);
|
126 |
+
}
|
127 |
+
scalar_t v2 = 0;
|
128 |
+
if (h_low >= 0 && w_high <= width - 1)
|
129 |
+
{
|
130 |
+
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
131 |
+
v2 = bottom_data[ptr2];
|
132 |
+
grad_h_weight -= lw * v2;
|
133 |
+
grad_w_weight += hh * v2;
|
134 |
+
atomicAdd(grad_value+ptr2, w2*top_grad_value);
|
135 |
+
}
|
136 |
+
scalar_t v3 = 0;
|
137 |
+
if (h_high <= height - 1 && w_low >= 0)
|
138 |
+
{
|
139 |
+
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
140 |
+
v3 = bottom_data[ptr3];
|
141 |
+
grad_h_weight += hw * v3;
|
142 |
+
grad_w_weight -= lh * v3;
|
143 |
+
atomicAdd(grad_value+ptr3, w3*top_grad_value);
|
144 |
+
}
|
145 |
+
scalar_t v4 = 0;
|
146 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
147 |
+
{
|
148 |
+
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
149 |
+
v4 = bottom_data[ptr4];
|
150 |
+
grad_h_weight += lw * v4;
|
151 |
+
grad_w_weight += lh * v4;
|
152 |
+
atomicAdd(grad_value+ptr4, w4*top_grad_value);
|
153 |
+
}
|
154 |
+
|
155 |
+
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
156 |
+
*grad_attn_weight = top_grad * val;
|
157 |
+
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
|
158 |
+
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
|
159 |
+
}
|
160 |
+
|
161 |
+
|
162 |
+
template <typename scalar_t>
|
163 |
+
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
|
164 |
+
const int &height, const int &width, const int &nheads, const int &channels,
|
165 |
+
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
|
166 |
+
const scalar_t &top_grad,
|
167 |
+
const scalar_t &attn_weight,
|
168 |
+
scalar_t* &grad_value,
|
169 |
+
scalar_t* grad_sampling_loc,
|
170 |
+
scalar_t* grad_attn_weight)
|
171 |
+
{
|
172 |
+
const int h_low = floor(h);
|
173 |
+
const int w_low = floor(w);
|
174 |
+
const int h_high = h_low + 1;
|
175 |
+
const int w_high = w_low + 1;
|
176 |
+
|
177 |
+
const scalar_t lh = h - h_low;
|
178 |
+
const scalar_t lw = w - w_low;
|
179 |
+
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
180 |
+
|
181 |
+
const int w_stride = nheads * channels;
|
182 |
+
const int h_stride = width * w_stride;
|
183 |
+
const int h_low_ptr_offset = h_low * h_stride;
|
184 |
+
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
185 |
+
const int w_low_ptr_offset = w_low * w_stride;
|
186 |
+
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
187 |
+
const int base_ptr = m * channels + c;
|
188 |
+
|
189 |
+
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
190 |
+
const scalar_t top_grad_value = top_grad * attn_weight;
|
191 |
+
scalar_t grad_h_weight = 0, grad_w_weight = 0;
|
192 |
+
|
193 |
+
scalar_t v1 = 0;
|
194 |
+
if (h_low >= 0 && w_low >= 0)
|
195 |
+
{
|
196 |
+
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
197 |
+
v1 = bottom_data[ptr1];
|
198 |
+
grad_h_weight -= hw * v1;
|
199 |
+
grad_w_weight -= hh * v1;
|
200 |
+
atomicAdd(grad_value+ptr1, w1*top_grad_value);
|
201 |
+
}
|
202 |
+
scalar_t v2 = 0;
|
203 |
+
if (h_low >= 0 && w_high <= width - 1)
|
204 |
+
{
|
205 |
+
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
206 |
+
v2 = bottom_data[ptr2];
|
207 |
+
grad_h_weight -= lw * v2;
|
208 |
+
grad_w_weight += hh * v2;
|
209 |
+
atomicAdd(grad_value+ptr2, w2*top_grad_value);
|
210 |
+
}
|
211 |
+
scalar_t v3 = 0;
|
212 |
+
if (h_high <= height - 1 && w_low >= 0)
|
213 |
+
{
|
214 |
+
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
215 |
+
v3 = bottom_data[ptr3];
|
216 |
+
grad_h_weight += hw * v3;
|
217 |
+
grad_w_weight -= lh * v3;
|
218 |
+
atomicAdd(grad_value+ptr3, w3*top_grad_value);
|
219 |
+
}
|
220 |
+
scalar_t v4 = 0;
|
221 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
222 |
+
{
|
223 |
+
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
224 |
+
v4 = bottom_data[ptr4];
|
225 |
+
grad_h_weight += lw * v4;
|
226 |
+
grad_w_weight += lh * v4;
|
227 |
+
atomicAdd(grad_value+ptr4, w4*top_grad_value);
|
228 |
+
}
|
229 |
+
|
230 |
+
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
231 |
+
atomicAdd(grad_attn_weight, top_grad * val);
|
232 |
+
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
|
233 |
+
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
|
234 |
+
}
|
235 |
+
|
236 |
+
|
237 |
+
template <typename scalar_t>
|
238 |
+
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
|
239 |
+
const scalar_t *data_value,
|
240 |
+
const int64_t *data_spatial_shapes,
|
241 |
+
const int64_t *data_level_start_index,
|
242 |
+
const scalar_t *data_sampling_loc,
|
243 |
+
const scalar_t *data_attn_weight,
|
244 |
+
const int batch_size,
|
245 |
+
const int spatial_size,
|
246 |
+
const int num_heads,
|
247 |
+
const int channels,
|
248 |
+
const int num_levels,
|
249 |
+
const int num_query,
|
250 |
+
const int num_point,
|
251 |
+
scalar_t *data_col)
|
252 |
+
{
|
253 |
+
CUDA_KERNEL_LOOP(index, n)
|
254 |
+
{
|
255 |
+
int _temp = index;
|
256 |
+
const int c_col = _temp % channels;
|
257 |
+
_temp /= channels;
|
258 |
+
const int sampling_index = _temp;
|
259 |
+
const int m_col = _temp % num_heads;
|
260 |
+
_temp /= num_heads;
|
261 |
+
const int q_col = _temp % num_query;
|
262 |
+
_temp /= num_query;
|
263 |
+
const int b_col = _temp;
|
264 |
+
|
265 |
+
scalar_t *data_col_ptr = data_col + index;
|
266 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
267 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
268 |
+
const int qid_stride = num_heads * channels;
|
269 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
270 |
+
scalar_t col = 0;
|
271 |
+
|
272 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
273 |
+
{
|
274 |
+
const int level_start_id = data_level_start_index[l_col];
|
275 |
+
const int spatial_h_ptr = l_col << 1;
|
276 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
277 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
278 |
+
const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
|
279 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
280 |
+
{
|
281 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
282 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
283 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
284 |
+
|
285 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
286 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
287 |
+
|
288 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
289 |
+
{
|
290 |
+
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
|
291 |
+
}
|
292 |
+
|
293 |
+
data_weight_ptr += 1;
|
294 |
+
data_loc_w_ptr += 2;
|
295 |
+
}
|
296 |
+
}
|
297 |
+
*data_col_ptr = col;
|
298 |
+
}
|
299 |
+
}
|
300 |
+
|
301 |
+
template <typename scalar_t, unsigned int blockSize>
|
302 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
|
303 |
+
const scalar_t *grad_col,
|
304 |
+
const scalar_t *data_value,
|
305 |
+
const int64_t *data_spatial_shapes,
|
306 |
+
const int64_t *data_level_start_index,
|
307 |
+
const scalar_t *data_sampling_loc,
|
308 |
+
const scalar_t *data_attn_weight,
|
309 |
+
const int batch_size,
|
310 |
+
const int spatial_size,
|
311 |
+
const int num_heads,
|
312 |
+
const int channels,
|
313 |
+
const int num_levels,
|
314 |
+
const int num_query,
|
315 |
+
const int num_point,
|
316 |
+
scalar_t *grad_value,
|
317 |
+
scalar_t *grad_sampling_loc,
|
318 |
+
scalar_t *grad_attn_weight)
|
319 |
+
{
|
320 |
+
CUDA_KERNEL_LOOP(index, n)
|
321 |
+
{
|
322 |
+
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
|
323 |
+
__shared__ scalar_t cache_grad_attn_weight[blockSize];
|
324 |
+
unsigned int tid = threadIdx.x;
|
325 |
+
int _temp = index;
|
326 |
+
const int c_col = _temp % channels;
|
327 |
+
_temp /= channels;
|
328 |
+
const int sampling_index = _temp;
|
329 |
+
const int m_col = _temp % num_heads;
|
330 |
+
_temp /= num_heads;
|
331 |
+
const int q_col = _temp % num_query;
|
332 |
+
_temp /= num_query;
|
333 |
+
const int b_col = _temp;
|
334 |
+
|
335 |
+
const scalar_t top_grad = grad_col[index];
|
336 |
+
|
337 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
338 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
339 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
340 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
341 |
+
grad_attn_weight += grad_sampling_ptr;
|
342 |
+
const int grad_weight_stride = 1;
|
343 |
+
const int grad_loc_stride = 2;
|
344 |
+
const int qid_stride = num_heads * channels;
|
345 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
346 |
+
|
347 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
348 |
+
{
|
349 |
+
const int level_start_id = data_level_start_index[l_col];
|
350 |
+
const int spatial_h_ptr = l_col << 1;
|
351 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
352 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
353 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
354 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
355 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
356 |
+
|
357 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
358 |
+
{
|
359 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
360 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
361 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
362 |
+
|
363 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
364 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
365 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
366 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
367 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
368 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
369 |
+
{
|
370 |
+
ms_deform_attn_col2im_bilinear(
|
371 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
372 |
+
top_grad, weight, grad_value_ptr,
|
373 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
374 |
+
}
|
375 |
+
|
376 |
+
__syncthreads();
|
377 |
+
if (tid == 0)
|
378 |
+
{
|
379 |
+
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
|
380 |
+
int sid=2;
|
381 |
+
for (unsigned int tid = 1; tid < blockSize; ++tid)
|
382 |
+
{
|
383 |
+
_grad_w += cache_grad_sampling_loc[sid];
|
384 |
+
_grad_h += cache_grad_sampling_loc[sid + 1];
|
385 |
+
_grad_a += cache_grad_attn_weight[tid];
|
386 |
+
sid += 2;
|
387 |
+
}
|
388 |
+
|
389 |
+
|
390 |
+
*grad_sampling_loc = _grad_w;
|
391 |
+
*(grad_sampling_loc + 1) = _grad_h;
|
392 |
+
*grad_attn_weight = _grad_a;
|
393 |
+
}
|
394 |
+
__syncthreads();
|
395 |
+
|
396 |
+
data_weight_ptr += 1;
|
397 |
+
data_loc_w_ptr += 2;
|
398 |
+
grad_attn_weight += grad_weight_stride;
|
399 |
+
grad_sampling_loc += grad_loc_stride;
|
400 |
+
}
|
401 |
+
}
|
402 |
+
}
|
403 |
+
}
|
404 |
+
|
405 |
+
|
406 |
+
template <typename scalar_t, unsigned int blockSize>
|
407 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
|
408 |
+
const scalar_t *grad_col,
|
409 |
+
const scalar_t *data_value,
|
410 |
+
const int64_t *data_spatial_shapes,
|
411 |
+
const int64_t *data_level_start_index,
|
412 |
+
const scalar_t *data_sampling_loc,
|
413 |
+
const scalar_t *data_attn_weight,
|
414 |
+
const int batch_size,
|
415 |
+
const int spatial_size,
|
416 |
+
const int num_heads,
|
417 |
+
const int channels,
|
418 |
+
const int num_levels,
|
419 |
+
const int num_query,
|
420 |
+
const int num_point,
|
421 |
+
scalar_t *grad_value,
|
422 |
+
scalar_t *grad_sampling_loc,
|
423 |
+
scalar_t *grad_attn_weight)
|
424 |
+
{
|
425 |
+
CUDA_KERNEL_LOOP(index, n)
|
426 |
+
{
|
427 |
+
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
|
428 |
+
__shared__ scalar_t cache_grad_attn_weight[blockSize];
|
429 |
+
unsigned int tid = threadIdx.x;
|
430 |
+
int _temp = index;
|
431 |
+
const int c_col = _temp % channels;
|
432 |
+
_temp /= channels;
|
433 |
+
const int sampling_index = _temp;
|
434 |
+
const int m_col = _temp % num_heads;
|
435 |
+
_temp /= num_heads;
|
436 |
+
const int q_col = _temp % num_query;
|
437 |
+
_temp /= num_query;
|
438 |
+
const int b_col = _temp;
|
439 |
+
|
440 |
+
const scalar_t top_grad = grad_col[index];
|
441 |
+
|
442 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
443 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
444 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
445 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
446 |
+
grad_attn_weight += grad_sampling_ptr;
|
447 |
+
const int grad_weight_stride = 1;
|
448 |
+
const int grad_loc_stride = 2;
|
449 |
+
const int qid_stride = num_heads * channels;
|
450 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
451 |
+
|
452 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
453 |
+
{
|
454 |
+
const int level_start_id = data_level_start_index[l_col];
|
455 |
+
const int spatial_h_ptr = l_col << 1;
|
456 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
457 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
458 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
459 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
460 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
461 |
+
|
462 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
463 |
+
{
|
464 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
465 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
466 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
467 |
+
|
468 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
469 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
470 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
471 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
472 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
473 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
474 |
+
{
|
475 |
+
ms_deform_attn_col2im_bilinear(
|
476 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
477 |
+
top_grad, weight, grad_value_ptr,
|
478 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
479 |
+
}
|
480 |
+
|
481 |
+
__syncthreads();
|
482 |
+
|
483 |
+
for (unsigned int s=blockSize/2; s>0; s>>=1)
|
484 |
+
{
|
485 |
+
if (tid < s) {
|
486 |
+
const unsigned int xid1 = tid << 1;
|
487 |
+
const unsigned int xid2 = (tid + s) << 1;
|
488 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
489 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
490 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
|
491 |
+
}
|
492 |
+
__syncthreads();
|
493 |
+
}
|
494 |
+
|
495 |
+
if (tid == 0)
|
496 |
+
{
|
497 |
+
*grad_sampling_loc = cache_grad_sampling_loc[0];
|
498 |
+
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
|
499 |
+
*grad_attn_weight = cache_grad_attn_weight[0];
|
500 |
+
}
|
501 |
+
__syncthreads();
|
502 |
+
|
503 |
+
data_weight_ptr += 1;
|
504 |
+
data_loc_w_ptr += 2;
|
505 |
+
grad_attn_weight += grad_weight_stride;
|
506 |
+
grad_sampling_loc += grad_loc_stride;
|
507 |
+
}
|
508 |
+
}
|
509 |
+
}
|
510 |
+
}
|
511 |
+
|
512 |
+
|
513 |
+
template <typename scalar_t>
|
514 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
|
515 |
+
const scalar_t *grad_col,
|
516 |
+
const scalar_t *data_value,
|
517 |
+
const int64_t *data_spatial_shapes,
|
518 |
+
const int64_t *data_level_start_index,
|
519 |
+
const scalar_t *data_sampling_loc,
|
520 |
+
const scalar_t *data_attn_weight,
|
521 |
+
const int batch_size,
|
522 |
+
const int spatial_size,
|
523 |
+
const int num_heads,
|
524 |
+
const int channels,
|
525 |
+
const int num_levels,
|
526 |
+
const int num_query,
|
527 |
+
const int num_point,
|
528 |
+
scalar_t *grad_value,
|
529 |
+
scalar_t *grad_sampling_loc,
|
530 |
+
scalar_t *grad_attn_weight)
|
531 |
+
{
|
532 |
+
CUDA_KERNEL_LOOP(index, n)
|
533 |
+
{
|
534 |
+
extern __shared__ int _s[];
|
535 |
+
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
|
536 |
+
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
537 |
+
unsigned int tid = threadIdx.x;
|
538 |
+
int _temp = index;
|
539 |
+
const int c_col = _temp % channels;
|
540 |
+
_temp /= channels;
|
541 |
+
const int sampling_index = _temp;
|
542 |
+
const int m_col = _temp % num_heads;
|
543 |
+
_temp /= num_heads;
|
544 |
+
const int q_col = _temp % num_query;
|
545 |
+
_temp /= num_query;
|
546 |
+
const int b_col = _temp;
|
547 |
+
|
548 |
+
const scalar_t top_grad = grad_col[index];
|
549 |
+
|
550 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
551 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
552 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
553 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
554 |
+
grad_attn_weight += grad_sampling_ptr;
|
555 |
+
const int grad_weight_stride = 1;
|
556 |
+
const int grad_loc_stride = 2;
|
557 |
+
const int qid_stride = num_heads * channels;
|
558 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
559 |
+
|
560 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
561 |
+
{
|
562 |
+
const int level_start_id = data_level_start_index[l_col];
|
563 |
+
const int spatial_h_ptr = l_col << 1;
|
564 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
565 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
566 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
567 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
568 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
569 |
+
|
570 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
571 |
+
{
|
572 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
573 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
574 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
575 |
+
|
576 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
577 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
578 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
579 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
580 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
581 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
582 |
+
{
|
583 |
+
ms_deform_attn_col2im_bilinear(
|
584 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
585 |
+
top_grad, weight, grad_value_ptr,
|
586 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
587 |
+
}
|
588 |
+
|
589 |
+
__syncthreads();
|
590 |
+
if (tid == 0)
|
591 |
+
{
|
592 |
+
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
|
593 |
+
int sid=2;
|
594 |
+
for (unsigned int tid = 1; tid < blockDim.x; ++tid)
|
595 |
+
{
|
596 |
+
_grad_w += cache_grad_sampling_loc[sid];
|
597 |
+
_grad_h += cache_grad_sampling_loc[sid + 1];
|
598 |
+
_grad_a += cache_grad_attn_weight[tid];
|
599 |
+
sid += 2;
|
600 |
+
}
|
601 |
+
|
602 |
+
|
603 |
+
*grad_sampling_loc = _grad_w;
|
604 |
+
*(grad_sampling_loc + 1) = _grad_h;
|
605 |
+
*grad_attn_weight = _grad_a;
|
606 |
+
}
|
607 |
+
__syncthreads();
|
608 |
+
|
609 |
+
data_weight_ptr += 1;
|
610 |
+
data_loc_w_ptr += 2;
|
611 |
+
grad_attn_weight += grad_weight_stride;
|
612 |
+
grad_sampling_loc += grad_loc_stride;
|
613 |
+
}
|
614 |
+
}
|
615 |
+
}
|
616 |
+
}
|
617 |
+
|
618 |
+
template <typename scalar_t>
|
619 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
|
620 |
+
const scalar_t *grad_col,
|
621 |
+
const scalar_t *data_value,
|
622 |
+
const int64_t *data_spatial_shapes,
|
623 |
+
const int64_t *data_level_start_index,
|
624 |
+
const scalar_t *data_sampling_loc,
|
625 |
+
const scalar_t *data_attn_weight,
|
626 |
+
const int batch_size,
|
627 |
+
const int spatial_size,
|
628 |
+
const int num_heads,
|
629 |
+
const int channels,
|
630 |
+
const int num_levels,
|
631 |
+
const int num_query,
|
632 |
+
const int num_point,
|
633 |
+
scalar_t *grad_value,
|
634 |
+
scalar_t *grad_sampling_loc,
|
635 |
+
scalar_t *grad_attn_weight)
|
636 |
+
{
|
637 |
+
CUDA_KERNEL_LOOP(index, n)
|
638 |
+
{
|
639 |
+
extern __shared__ int _s[];
|
640 |
+
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
|
641 |
+
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
642 |
+
unsigned int tid = threadIdx.x;
|
643 |
+
int _temp = index;
|
644 |
+
const int c_col = _temp % channels;
|
645 |
+
_temp /= channels;
|
646 |
+
const int sampling_index = _temp;
|
647 |
+
const int m_col = _temp % num_heads;
|
648 |
+
_temp /= num_heads;
|
649 |
+
const int q_col = _temp % num_query;
|
650 |
+
_temp /= num_query;
|
651 |
+
const int b_col = _temp;
|
652 |
+
|
653 |
+
const scalar_t top_grad = grad_col[index];
|
654 |
+
|
655 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
656 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
657 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
658 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
659 |
+
grad_attn_weight += grad_sampling_ptr;
|
660 |
+
const int grad_weight_stride = 1;
|
661 |
+
const int grad_loc_stride = 2;
|
662 |
+
const int qid_stride = num_heads * channels;
|
663 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
664 |
+
|
665 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
666 |
+
{
|
667 |
+
const int level_start_id = data_level_start_index[l_col];
|
668 |
+
const int spatial_h_ptr = l_col << 1;
|
669 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
670 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
671 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
672 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
673 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
674 |
+
|
675 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
676 |
+
{
|
677 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
678 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
679 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
680 |
+
|
681 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
682 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
683 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
684 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
685 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
686 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
687 |
+
{
|
688 |
+
ms_deform_attn_col2im_bilinear(
|
689 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
690 |
+
top_grad, weight, grad_value_ptr,
|
691 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
692 |
+
}
|
693 |
+
|
694 |
+
__syncthreads();
|
695 |
+
|
696 |
+
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
|
697 |
+
{
|
698 |
+
if (tid < s) {
|
699 |
+
const unsigned int xid1 = tid << 1;
|
700 |
+
const unsigned int xid2 = (tid + s) << 1;
|
701 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
702 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
703 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
|
704 |
+
if (tid + (s << 1) < spre)
|
705 |
+
{
|
706 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
|
707 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
|
708 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
|
709 |
+
}
|
710 |
+
}
|
711 |
+
__syncthreads();
|
712 |
+
}
|
713 |
+
|
714 |
+
if (tid == 0)
|
715 |
+
{
|
716 |
+
*grad_sampling_loc = cache_grad_sampling_loc[0];
|
717 |
+
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
|
718 |
+
*grad_attn_weight = cache_grad_attn_weight[0];
|
719 |
+
}
|
720 |
+
__syncthreads();
|
721 |
+
|
722 |
+
data_weight_ptr += 1;
|
723 |
+
data_loc_w_ptr += 2;
|
724 |
+
grad_attn_weight += grad_weight_stride;
|
725 |
+
grad_sampling_loc += grad_loc_stride;
|
726 |
+
}
|
727 |
+
}
|
728 |
+
}
|
729 |
+
}
|
730 |
+
|
731 |
+
template <typename scalar_t>
|
732 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
|
733 |
+
const scalar_t *grad_col,
|
734 |
+
const scalar_t *data_value,
|
735 |
+
const int64_t *data_spatial_shapes,
|
736 |
+
const int64_t *data_level_start_index,
|
737 |
+
const scalar_t *data_sampling_loc,
|
738 |
+
const scalar_t *data_attn_weight,
|
739 |
+
const int batch_size,
|
740 |
+
const int spatial_size,
|
741 |
+
const int num_heads,
|
742 |
+
const int channels,
|
743 |
+
const int num_levels,
|
744 |
+
const int num_query,
|
745 |
+
const int num_point,
|
746 |
+
scalar_t *grad_value,
|
747 |
+
scalar_t *grad_sampling_loc,
|
748 |
+
scalar_t *grad_attn_weight)
|
749 |
+
{
|
750 |
+
CUDA_KERNEL_LOOP(index, n)
|
751 |
+
{
|
752 |
+
extern __shared__ int _s[];
|
753 |
+
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
|
754 |
+
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
755 |
+
unsigned int tid = threadIdx.x;
|
756 |
+
int _temp = index;
|
757 |
+
const int c_col = _temp % channels;
|
758 |
+
_temp /= channels;
|
759 |
+
const int sampling_index = _temp;
|
760 |
+
const int m_col = _temp % num_heads;
|
761 |
+
_temp /= num_heads;
|
762 |
+
const int q_col = _temp % num_query;
|
763 |
+
_temp /= num_query;
|
764 |
+
const int b_col = _temp;
|
765 |
+
|
766 |
+
const scalar_t top_grad = grad_col[index];
|
767 |
+
|
768 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
769 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
770 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
771 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
772 |
+
grad_attn_weight += grad_sampling_ptr;
|
773 |
+
const int grad_weight_stride = 1;
|
774 |
+
const int grad_loc_stride = 2;
|
775 |
+
const int qid_stride = num_heads * channels;
|
776 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
777 |
+
|
778 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
779 |
+
{
|
780 |
+
const int level_start_id = data_level_start_index[l_col];
|
781 |
+
const int spatial_h_ptr = l_col << 1;
|
782 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
783 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
784 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
785 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
786 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
787 |
+
|
788 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
789 |
+
{
|
790 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
791 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
792 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
793 |
+
|
794 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
795 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
796 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
797 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
798 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
799 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
800 |
+
{
|
801 |
+
ms_deform_attn_col2im_bilinear(
|
802 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
803 |
+
top_grad, weight, grad_value_ptr,
|
804 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
805 |
+
}
|
806 |
+
|
807 |
+
__syncthreads();
|
808 |
+
|
809 |
+
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
|
810 |
+
{
|
811 |
+
if (tid < s) {
|
812 |
+
const unsigned int xid1 = tid << 1;
|
813 |
+
const unsigned int xid2 = (tid + s) << 1;
|
814 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
815 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
816 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
|
817 |
+
if (tid + (s << 1) < spre)
|
818 |
+
{
|
819 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
|
820 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
|
821 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
|
822 |
+
}
|
823 |
+
}
|
824 |
+
__syncthreads();
|
825 |
+
}
|
826 |
+
|
827 |
+
if (tid == 0)
|
828 |
+
{
|
829 |
+
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
|
830 |
+
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
|
831 |
+
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
|
832 |
+
}
|
833 |
+
__syncthreads();
|
834 |
+
|
835 |
+
data_weight_ptr += 1;
|
836 |
+
data_loc_w_ptr += 2;
|
837 |
+
grad_attn_weight += grad_weight_stride;
|
838 |
+
grad_sampling_loc += grad_loc_stride;
|
839 |
+
}
|
840 |
+
}
|
841 |
+
}
|
842 |
+
}
|
843 |
+
|
844 |
+
|
845 |
+
template <typename scalar_t>
|
846 |
+
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
|
847 |
+
const scalar_t *grad_col,
|
848 |
+
const scalar_t *data_value,
|
849 |
+
const int64_t *data_spatial_shapes,
|
850 |
+
const int64_t *data_level_start_index,
|
851 |
+
const scalar_t *data_sampling_loc,
|
852 |
+
const scalar_t *data_attn_weight,
|
853 |
+
const int batch_size,
|
854 |
+
const int spatial_size,
|
855 |
+
const int num_heads,
|
856 |
+
const int channels,
|
857 |
+
const int num_levels,
|
858 |
+
const int num_query,
|
859 |
+
const int num_point,
|
860 |
+
scalar_t *grad_value,
|
861 |
+
scalar_t *grad_sampling_loc,
|
862 |
+
scalar_t *grad_attn_weight)
|
863 |
+
{
|
864 |
+
CUDA_KERNEL_LOOP(index, n)
|
865 |
+
{
|
866 |
+
int _temp = index;
|
867 |
+
const int c_col = _temp % channels;
|
868 |
+
_temp /= channels;
|
869 |
+
const int sampling_index = _temp;
|
870 |
+
const int m_col = _temp % num_heads;
|
871 |
+
_temp /= num_heads;
|
872 |
+
const int q_col = _temp % num_query;
|
873 |
+
_temp /= num_query;
|
874 |
+
const int b_col = _temp;
|
875 |
+
|
876 |
+
const scalar_t top_grad = grad_col[index];
|
877 |
+
|
878 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
879 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
880 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
881 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
882 |
+
grad_attn_weight += grad_sampling_ptr;
|
883 |
+
const int grad_weight_stride = 1;
|
884 |
+
const int grad_loc_stride = 2;
|
885 |
+
const int qid_stride = num_heads * channels;
|
886 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
887 |
+
|
888 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
889 |
+
{
|
890 |
+
const int level_start_id = data_level_start_index[l_col];
|
891 |
+
const int spatial_h_ptr = l_col << 1;
|
892 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
893 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
894 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
895 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
896 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
897 |
+
|
898 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
899 |
+
{
|
900 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
901 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
902 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
903 |
+
|
904 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
905 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
906 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
907 |
+
{
|
908 |
+
ms_deform_attn_col2im_bilinear_gm(
|
909 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
910 |
+
top_grad, weight, grad_value_ptr,
|
911 |
+
grad_sampling_loc, grad_attn_weight);
|
912 |
+
}
|
913 |
+
data_weight_ptr += 1;
|
914 |
+
data_loc_w_ptr += 2;
|
915 |
+
grad_attn_weight += grad_weight_stride;
|
916 |
+
grad_sampling_loc += grad_loc_stride;
|
917 |
+
}
|
918 |
+
}
|
919 |
+
}
|
920 |
+
}
|
921 |
+
|
922 |
+
|
923 |
+
template <typename scalar_t>
|
924 |
+
void ms_deformable_im2col_cuda(cudaStream_t stream,
|
925 |
+
const scalar_t* data_value,
|
926 |
+
const int64_t* data_spatial_shapes,
|
927 |
+
const int64_t* data_level_start_index,
|
928 |
+
const scalar_t* data_sampling_loc,
|
929 |
+
const scalar_t* data_attn_weight,
|
930 |
+
const int batch_size,
|
931 |
+
const int spatial_size,
|
932 |
+
const int num_heads,
|
933 |
+
const int channels,
|
934 |
+
const int num_levels,
|
935 |
+
const int num_query,
|
936 |
+
const int num_point,
|
937 |
+
scalar_t* data_col)
|
938 |
+
{
|
939 |
+
const int num_kernels = batch_size * num_query * num_heads * channels;
|
940 |
+
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
|
941 |
+
const int num_threads = CUDA_NUM_THREADS;
|
942 |
+
ms_deformable_im2col_gpu_kernel<scalar_t>
|
943 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
944 |
+
0, stream>>>(
|
945 |
+
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
|
946 |
+
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
|
947 |
+
|
948 |
+
cudaError_t err = cudaGetLastError();
|
949 |
+
if (err != cudaSuccess)
|
950 |
+
{
|
951 |
+
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
|
952 |
+
}
|
953 |
+
|
954 |
+
}
|
955 |
+
|
956 |
+
template <typename scalar_t>
|
957 |
+
void ms_deformable_col2im_cuda(cudaStream_t stream,
|
958 |
+
const scalar_t* grad_col,
|
959 |
+
const scalar_t* data_value,
|
960 |
+
const int64_t * data_spatial_shapes,
|
961 |
+
const int64_t * data_level_start_index,
|
962 |
+
const scalar_t * data_sampling_loc,
|
963 |
+
const scalar_t * data_attn_weight,
|
964 |
+
const int batch_size,
|
965 |
+
const int spatial_size,
|
966 |
+
const int num_heads,
|
967 |
+
const int channels,
|
968 |
+
const int num_levels,
|
969 |
+
const int num_query,
|
970 |
+
const int num_point,
|
971 |
+
scalar_t* grad_value,
|
972 |
+
scalar_t* grad_sampling_loc,
|
973 |
+
scalar_t* grad_attn_weight)
|
974 |
+
{
|
975 |
+
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
|
976 |
+
const int num_kernels = batch_size * num_query * num_heads * channels;
|
977 |
+
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
|
978 |
+
if (channels > 1024)
|
979 |
+
{
|
980 |
+
if ((channels & 1023) == 0)
|
981 |
+
{
|
982 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
|
983 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
984 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
985 |
+
num_kernels,
|
986 |
+
grad_col,
|
987 |
+
data_value,
|
988 |
+
data_spatial_shapes,
|
989 |
+
data_level_start_index,
|
990 |
+
data_sampling_loc,
|
991 |
+
data_attn_weight,
|
992 |
+
batch_size,
|
993 |
+
spatial_size,
|
994 |
+
num_heads,
|
995 |
+
channels,
|
996 |
+
num_levels,
|
997 |
+
num_query,
|
998 |
+
num_point,
|
999 |
+
grad_value,
|
1000 |
+
grad_sampling_loc,
|
1001 |
+
grad_attn_weight);
|
1002 |
+
}
|
1003 |
+
else
|
1004 |
+
{
|
1005 |
+
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
|
1006 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1007 |
+
0, stream>>>(
|
1008 |
+
num_kernels,
|
1009 |
+
grad_col,
|
1010 |
+
data_value,
|
1011 |
+
data_spatial_shapes,
|
1012 |
+
data_level_start_index,
|
1013 |
+
data_sampling_loc,
|
1014 |
+
data_attn_weight,
|
1015 |
+
batch_size,
|
1016 |
+
spatial_size,
|
1017 |
+
num_heads,
|
1018 |
+
channels,
|
1019 |
+
num_levels,
|
1020 |
+
num_query,
|
1021 |
+
num_point,
|
1022 |
+
grad_value,
|
1023 |
+
grad_sampling_loc,
|
1024 |
+
grad_attn_weight);
|
1025 |
+
}
|
1026 |
+
}
|
1027 |
+
else{
|
1028 |
+
switch(channels)
|
1029 |
+
{
|
1030 |
+
case 1:
|
1031 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
|
1032 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1033 |
+
0, stream>>>(
|
1034 |
+
num_kernels,
|
1035 |
+
grad_col,
|
1036 |
+
data_value,
|
1037 |
+
data_spatial_shapes,
|
1038 |
+
data_level_start_index,
|
1039 |
+
data_sampling_loc,
|
1040 |
+
data_attn_weight,
|
1041 |
+
batch_size,
|
1042 |
+
spatial_size,
|
1043 |
+
num_heads,
|
1044 |
+
channels,
|
1045 |
+
num_levels,
|
1046 |
+
num_query,
|
1047 |
+
num_point,
|
1048 |
+
grad_value,
|
1049 |
+
grad_sampling_loc,
|
1050 |
+
grad_attn_weight);
|
1051 |
+
break;
|
1052 |
+
case 2:
|
1053 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
|
1054 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1055 |
+
0, stream>>>(
|
1056 |
+
num_kernels,
|
1057 |
+
grad_col,
|
1058 |
+
data_value,
|
1059 |
+
data_spatial_shapes,
|
1060 |
+
data_level_start_index,
|
1061 |
+
data_sampling_loc,
|
1062 |
+
data_attn_weight,
|
1063 |
+
batch_size,
|
1064 |
+
spatial_size,
|
1065 |
+
num_heads,
|
1066 |
+
channels,
|
1067 |
+
num_levels,
|
1068 |
+
num_query,
|
1069 |
+
num_point,
|
1070 |
+
grad_value,
|
1071 |
+
grad_sampling_loc,
|
1072 |
+
grad_attn_weight);
|
1073 |
+
break;
|
1074 |
+
case 4:
|
1075 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
|
1076 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1077 |
+
0, stream>>>(
|
1078 |
+
num_kernels,
|
1079 |
+
grad_col,
|
1080 |
+
data_value,
|
1081 |
+
data_spatial_shapes,
|
1082 |
+
data_level_start_index,
|
1083 |
+
data_sampling_loc,
|
1084 |
+
data_attn_weight,
|
1085 |
+
batch_size,
|
1086 |
+
spatial_size,
|
1087 |
+
num_heads,
|
1088 |
+
channels,
|
1089 |
+
num_levels,
|
1090 |
+
num_query,
|
1091 |
+
num_point,
|
1092 |
+
grad_value,
|
1093 |
+
grad_sampling_loc,
|
1094 |
+
grad_attn_weight);
|
1095 |
+
break;
|
1096 |
+
case 8:
|
1097 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
|
1098 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1099 |
+
0, stream>>>(
|
1100 |
+
num_kernels,
|
1101 |
+
grad_col,
|
1102 |
+
data_value,
|
1103 |
+
data_spatial_shapes,
|
1104 |
+
data_level_start_index,
|
1105 |
+
data_sampling_loc,
|
1106 |
+
data_attn_weight,
|
1107 |
+
batch_size,
|
1108 |
+
spatial_size,
|
1109 |
+
num_heads,
|
1110 |
+
channels,
|
1111 |
+
num_levels,
|
1112 |
+
num_query,
|
1113 |
+
num_point,
|
1114 |
+
grad_value,
|
1115 |
+
grad_sampling_loc,
|
1116 |
+
grad_attn_weight);
|
1117 |
+
break;
|
1118 |
+
case 16:
|
1119 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
|
1120 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1121 |
+
0, stream>>>(
|
1122 |
+
num_kernels,
|
1123 |
+
grad_col,
|
1124 |
+
data_value,
|
1125 |
+
data_spatial_shapes,
|
1126 |
+
data_level_start_index,
|
1127 |
+
data_sampling_loc,
|
1128 |
+
data_attn_weight,
|
1129 |
+
batch_size,
|
1130 |
+
spatial_size,
|
1131 |
+
num_heads,
|
1132 |
+
channels,
|
1133 |
+
num_levels,
|
1134 |
+
num_query,
|
1135 |
+
num_point,
|
1136 |
+
grad_value,
|
1137 |
+
grad_sampling_loc,
|
1138 |
+
grad_attn_weight);
|
1139 |
+
break;
|
1140 |
+
case 32:
|
1141 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
|
1142 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1143 |
+
0, stream>>>(
|
1144 |
+
num_kernels,
|
1145 |
+
grad_col,
|
1146 |
+
data_value,
|
1147 |
+
data_spatial_shapes,
|
1148 |
+
data_level_start_index,
|
1149 |
+
data_sampling_loc,
|
1150 |
+
data_attn_weight,
|
1151 |
+
batch_size,
|
1152 |
+
spatial_size,
|
1153 |
+
num_heads,
|
1154 |
+
channels,
|
1155 |
+
num_levels,
|
1156 |
+
num_query,
|
1157 |
+
num_point,
|
1158 |
+
grad_value,
|
1159 |
+
grad_sampling_loc,
|
1160 |
+
grad_attn_weight);
|
1161 |
+
break;
|
1162 |
+
case 64:
|
1163 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
|
1164 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1165 |
+
0, stream>>>(
|
1166 |
+
num_kernels,
|
1167 |
+
grad_col,
|
1168 |
+
data_value,
|
1169 |
+
data_spatial_shapes,
|
1170 |
+
data_level_start_index,
|
1171 |
+
data_sampling_loc,
|
1172 |
+
data_attn_weight,
|
1173 |
+
batch_size,
|
1174 |
+
spatial_size,
|
1175 |
+
num_heads,
|
1176 |
+
channels,
|
1177 |
+
num_levels,
|
1178 |
+
num_query,
|
1179 |
+
num_point,
|
1180 |
+
grad_value,
|
1181 |
+
grad_sampling_loc,
|
1182 |
+
grad_attn_weight);
|
1183 |
+
break;
|
1184 |
+
case 128:
|
1185 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
|
1186 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1187 |
+
0, stream>>>(
|
1188 |
+
num_kernels,
|
1189 |
+
grad_col,
|
1190 |
+
data_value,
|
1191 |
+
data_spatial_shapes,
|
1192 |
+
data_level_start_index,
|
1193 |
+
data_sampling_loc,
|
1194 |
+
data_attn_weight,
|
1195 |
+
batch_size,
|
1196 |
+
spatial_size,
|
1197 |
+
num_heads,
|
1198 |
+
channels,
|
1199 |
+
num_levels,
|
1200 |
+
num_query,
|
1201 |
+
num_point,
|
1202 |
+
grad_value,
|
1203 |
+
grad_sampling_loc,
|
1204 |
+
grad_attn_weight);
|
1205 |
+
break;
|
1206 |
+
case 256:
|
1207 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
|
1208 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1209 |
+
0, stream>>>(
|
1210 |
+
num_kernels,
|
1211 |
+
grad_col,
|
1212 |
+
data_value,
|
1213 |
+
data_spatial_shapes,
|
1214 |
+
data_level_start_index,
|
1215 |
+
data_sampling_loc,
|
1216 |
+
data_attn_weight,
|
1217 |
+
batch_size,
|
1218 |
+
spatial_size,
|
1219 |
+
num_heads,
|
1220 |
+
channels,
|
1221 |
+
num_levels,
|
1222 |
+
num_query,
|
1223 |
+
num_point,
|
1224 |
+
grad_value,
|
1225 |
+
grad_sampling_loc,
|
1226 |
+
grad_attn_weight);
|
1227 |
+
break;
|
1228 |
+
case 512:
|
1229 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
|
1230 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1231 |
+
0, stream>>>(
|
1232 |
+
num_kernels,
|
1233 |
+
grad_col,
|
1234 |
+
data_value,
|
1235 |
+
data_spatial_shapes,
|
1236 |
+
data_level_start_index,
|
1237 |
+
data_sampling_loc,
|
1238 |
+
data_attn_weight,
|
1239 |
+
batch_size,
|
1240 |
+
spatial_size,
|
1241 |
+
num_heads,
|
1242 |
+
channels,
|
1243 |
+
num_levels,
|
1244 |
+
num_query,
|
1245 |
+
num_point,
|
1246 |
+
grad_value,
|
1247 |
+
grad_sampling_loc,
|
1248 |
+
grad_attn_weight);
|
1249 |
+
break;
|
1250 |
+
case 1024:
|
1251 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
|
1252 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1253 |
+
0, stream>>>(
|
1254 |
+
num_kernels,
|
1255 |
+
grad_col,
|
1256 |
+
data_value,
|
1257 |
+
data_spatial_shapes,
|
1258 |
+
data_level_start_index,
|
1259 |
+
data_sampling_loc,
|
1260 |
+
data_attn_weight,
|
1261 |
+
batch_size,
|
1262 |
+
spatial_size,
|
1263 |
+
num_heads,
|
1264 |
+
channels,
|
1265 |
+
num_levels,
|
1266 |
+
num_query,
|
1267 |
+
num_point,
|
1268 |
+
grad_value,
|
1269 |
+
grad_sampling_loc,
|
1270 |
+
grad_attn_weight);
|
1271 |
+
break;
|
1272 |
+
default:
|
1273 |
+
if (channels < 64)
|
1274 |
+
{
|
1275 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
|
1276 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1277 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
1278 |
+
num_kernels,
|
1279 |
+
grad_col,
|
1280 |
+
data_value,
|
1281 |
+
data_spatial_shapes,
|
1282 |
+
data_level_start_index,
|
1283 |
+
data_sampling_loc,
|
1284 |
+
data_attn_weight,
|
1285 |
+
batch_size,
|
1286 |
+
spatial_size,
|
1287 |
+
num_heads,
|
1288 |
+
channels,
|
1289 |
+
num_levels,
|
1290 |
+
num_query,
|
1291 |
+
num_point,
|
1292 |
+
grad_value,
|
1293 |
+
grad_sampling_loc,
|
1294 |
+
grad_attn_weight);
|
1295 |
+
}
|
1296 |
+
else
|
1297 |
+
{
|
1298 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
|
1299 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1300 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
1301 |
+
num_kernels,
|
1302 |
+
grad_col,
|
1303 |
+
data_value,
|
1304 |
+
data_spatial_shapes,
|
1305 |
+
data_level_start_index,
|
1306 |
+
data_sampling_loc,
|
1307 |
+
data_attn_weight,
|
1308 |
+
batch_size,
|
1309 |
+
spatial_size,
|
1310 |
+
num_heads,
|
1311 |
+
channels,
|
1312 |
+
num_levels,
|
1313 |
+
num_query,
|
1314 |
+
num_point,
|
1315 |
+
grad_value,
|
1316 |
+
grad_sampling_loc,
|
1317 |
+
grad_attn_weight);
|
1318 |
+
}
|
1319 |
+
}
|
1320 |
+
}
|
1321 |
+
cudaError_t err = cudaGetLastError();
|
1322 |
+
if (err != cudaSuccess)
|
1323 |
+
{
|
1324 |
+
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
|
1325 |
+
}
|
1326 |
+
|
1327 |
+
}
|
models/ops/src/ms_deform_attn.h
ADDED
@@ -0,0 +1,62 @@
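The CUDA kernels above implement multi-scale deformable attention: an im2col-style forward pass and a col2im-style backward pass that accumulates gradients either through a shared-memory tree reduction (templated on the channel count) or through global-memory atomicAdd as a fallback, with ms_deformable_col2im_cuda picking the kernel from the number of channels. As a worked illustration of the computation only (not the library's implementation), here is a minimal pure-PyTorch sketch assuming the tensor layout used in models/ops/test.py: value (N, S, M, D), sampling_locations (N, Lq, M, L, P, 2) in [0, 1], attention_weights (N, Lq, M, L, P).

import torch
import torch.nn.functional as F

def deform_attn_reference(value, spatial_shapes, sampling_locations, attention_weights):
    # value: (N, S, M, D) with S = sum of H*W over the L feature levels
    N, S, M, D = value.shape
    _, Lq, _, L, P, _ = sampling_locations.shape
    value_list = value.split([int(H * W) for H, W in spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1      # [0, 1] -> [-1, 1] for grid_sample
    sampled_per_level = []
    for lid, (H, W) in enumerate(spatial_shapes):
        # (N, H*W, M, D) -> (N*M, D, H, W)
        value_l = value_list[lid].flatten(2).transpose(1, 2).reshape(N * M, D, int(H), int(W))
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
        grid_l = sampling_grids[:, :, :, lid].transpose(1, 2).flatten(0, 1)
        sampled_per_level.append(F.grid_sample(value_l, grid_l, mode='bilinear',
                                               padding_mode='zeros', align_corners=False))
    # attention weights: (N, Lq, M, L, P) -> (N*M, 1, Lq, L*P)
    attn = attention_weights.transpose(1, 2).reshape(N * M, 1, Lq, L * P)
    out = (torch.stack(sampled_per_level, dim=-2).flatten(-2) * attn).sum(-1)   # (N*M, D, Lq)
    return out.view(N, M * D, Lq).transpose(1, 2).contiguous()                  # (N, Lq, M*D)

The header added below dispatches the public forward/backward entry points to these kernels whenever the value tensor lives on the GPU.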
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
#pragma once
|
12 |
+
|
13 |
+
#include "cpu/ms_deform_attn_cpu.h"
|
14 |
+
|
15 |
+
#ifdef WITH_CUDA
|
16 |
+
#include "cuda/ms_deform_attn_cuda.h"
|
17 |
+
#endif
|
18 |
+
|
19 |
+
|
20 |
+
at::Tensor
|
21 |
+
ms_deform_attn_forward(
|
22 |
+
const at::Tensor &value,
|
23 |
+
const at::Tensor &spatial_shapes,
|
24 |
+
const at::Tensor &level_start_index,
|
25 |
+
const at::Tensor &sampling_loc,
|
26 |
+
const at::Tensor &attn_weight,
|
27 |
+
const int im2col_step)
|
28 |
+
{
|
29 |
+
if (value.type().is_cuda())
|
30 |
+
{
|
31 |
+
#ifdef WITH_CUDA
|
32 |
+
return ms_deform_attn_cuda_forward(
|
33 |
+
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
|
34 |
+
#else
|
35 |
+
AT_ERROR("Not compiled with GPU support");
|
36 |
+
#endif
|
37 |
+
}
|
38 |
+
AT_ERROR("Not implemented on the CPU");
|
39 |
+
}
|
40 |
+
|
41 |
+
std::vector<at::Tensor>
|
42 |
+
ms_deform_attn_backward(
|
43 |
+
const at::Tensor &value,
|
44 |
+
const at::Tensor &spatial_shapes,
|
45 |
+
const at::Tensor &level_start_index,
|
46 |
+
const at::Tensor &sampling_loc,
|
47 |
+
const at::Tensor &attn_weight,
|
48 |
+
const at::Tensor &grad_output,
|
49 |
+
const int im2col_step)
|
50 |
+
{
|
51 |
+
if (value.type().is_cuda())
|
52 |
+
{
|
53 |
+
#ifdef WITH_CUDA
|
54 |
+
return ms_deform_attn_cuda_backward(
|
55 |
+
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
|
56 |
+
#else
|
57 |
+
AT_ERROR("Not compiled with GPU support");
|
58 |
+
#endif
|
59 |
+
}
|
60 |
+
AT_ERROR("Not implemented on the CPU");
|
61 |
+
}
|
62 |
+
|
models/ops/src/vision.cpp
ADDED
@@ -0,0 +1,16 @@
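vision.cpp below only registers ms_deform_attn_forward and ms_deform_attn_backward with pybind11. A hedged sketch of how the two bindings are typically driven from Python through torch.autograd.Function; the module name MultiScaleDeformableAttention is the one pinned in requirements.txt, and the repo ships its own wrapper (see the import in models/ops/test.py), so treat this as illustrative rather than the project's exact code:

import torch
import MultiScaleDeformableAttention as MSDA   # compiled extension built from models/ops

class MSDeformAttnFunctionSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, value, spatial_shapes, level_start_index,
                sampling_loc, attn_weight, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, spatial_shapes, level_start_index,
            sampling_loc, attn_weight, im2col_step)
        ctx.save_for_backward(value, spatial_shapes, level_start_index,
                              sampling_loc, attn_weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        value, spatial_shapes, level_start_index, sampling_loc, attn_weight = ctx.saved_tensors
        # grad_output is assumed contiguous here
        grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
            value, spatial_shapes, level_start_index,
            sampling_loc, attn_weight, grad_output, ctx.im2col_step)
        # one gradient slot per forward input; index tensors and the integer step get None
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None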
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
#include "ms_deform_attn.h"
|
12 |
+
|
13 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
14 |
+
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
|
15 |
+
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
|
16 |
+
}
|
models/ops/test.py
ADDED
@@ -0,0 +1,89 @@
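test.py below compares the compiled op against a pure-PyTorch reference in double and float precision and runs torch.autograd.gradcheck over several channel counts, which covers both the templated shared-memory reductions and the global-memory fallback of the backward kernels. Once the extension is built (see models/ops/make.sh), the checks can also be invoked selectively; a hypothetical snippet, run from inside models/ops:

from test import check_forward_equal_with_pytorch_float, check_gradient_numerical

check_forward_equal_with_pytorch_float()
check_gradient_numerical(channels=64)   # hits the blocksize-aware shared-memory path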
1 |
+
# ------------------------------------------------------------------------------------------------
|
2 |
+
# Deformable DETR
|
3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------------------
|
6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
7 |
+
# ------------------------------------------------------------------------------------------------
|
8 |
+
|
9 |
+
from __future__ import absolute_import
|
10 |
+
from __future__ import print_function
|
11 |
+
from __future__ import division
|
12 |
+
|
13 |
+
import time
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.autograd import gradcheck
|
17 |
+
|
18 |
+
from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
|
19 |
+
|
20 |
+
|
21 |
+
N, M, D = 1, 2, 2
|
22 |
+
Lq, L, P = 2, 2, 2
|
23 |
+
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
|
24 |
+
level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
|
25 |
+
S = sum([(H*W).item() for H, W in shapes])
|
26 |
+
|
27 |
+
|
28 |
+
torch.manual_seed(3)
|
29 |
+
|
30 |
+
|
31 |
+
@torch.no_grad()
|
32 |
+
def check_forward_equal_with_pytorch_double():
|
33 |
+
value = torch.rand(N, S, M, D).cuda() * 0.01
|
34 |
+
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
35 |
+
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
36 |
+
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
37 |
+
im2col_step = 2
|
38 |
+
output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
|
39 |
+
output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
|
40 |
+
fwdok = torch.allclose(output_cuda, output_pytorch)
|
41 |
+
max_abs_err = (output_cuda - output_pytorch).abs().max()
|
42 |
+
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
|
43 |
+
|
44 |
+
print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
|
45 |
+
|
46 |
+
|
47 |
+
@torch.no_grad()
|
48 |
+
def check_forward_equal_with_pytorch_float():
|
49 |
+
value = torch.rand(N, S, M, D).cuda() * 0.01
|
50 |
+
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
51 |
+
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
52 |
+
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
53 |
+
im2col_step = 2
|
54 |
+
output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
|
55 |
+
output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
|
56 |
+
fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
|
57 |
+
max_abs_err = (output_cuda - output_pytorch).abs().max()
|
58 |
+
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
|
59 |
+
|
60 |
+
print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
|
61 |
+
|
62 |
+
|
63 |
+
def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
|
64 |
+
|
65 |
+
value = torch.rand(N, S, M, channels).cuda() * 0.01
|
66 |
+
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
67 |
+
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
68 |
+
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
69 |
+
im2col_step = 2
|
70 |
+
func = MSDeformAttnFunction.apply
|
71 |
+
|
72 |
+
value.requires_grad = grad_value
|
73 |
+
sampling_locations.requires_grad = grad_sampling_loc
|
74 |
+
attention_weights.requires_grad = grad_attn_weight
|
75 |
+
|
76 |
+
gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
|
77 |
+
|
78 |
+
print(f'* {gradok} check_gradient_numerical(D={channels})')
|
79 |
+
|
80 |
+
|
81 |
+
if __name__ == '__main__':
|
82 |
+
check_forward_equal_with_pytorch_double()
|
83 |
+
check_forward_equal_with_pytorch_float()
|
84 |
+
|
85 |
+
for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
|
86 |
+
check_gradient_numerical(channels, True, True, True)
|
87 |
+
|
88 |
+
|
89 |
+
|
models/resnet.py
ADDED
@@ -0,0 +1,167 @@
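resnet.py below wraps torchvision's ResNet-50 into two backbones: ResNetBackbone exposes the stride-8/16/32 feature maps plus a padding mask, and ResNetUNet adds a U-Net style decoder with a sigmoid heatmap head (and optional multi-scale/feature outputs). A minimal, hypothetical usage sketch; the input size only needs to be divisible by 32, and torchvision downloads the pretrained weights on first use:

import torch
from models.resnet import ResNetUNet   # import path assumes the repo root is on sys.path

model = ResNetUNet(n_class=1)          # heatmap-only mode (ms_feat=False)
model.eval()
with torch.no_grad():
    x = torch.randn(2, 3, 256, 256)    # NCHW input
    heatmap = model(x)                 # sigmoid output of shape (2, 256, 256)
print(heatmap.shape)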
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torchvision import models
|
4 |
+
|
5 |
+
|
6 |
+
def convrelu(in_channels, out_channels, kernel, padding):
|
7 |
+
return nn.Sequential(
|
8 |
+
nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
|
9 |
+
nn.ReLU(inplace=True),
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
class ResNetBackbone(nn.Module):
|
14 |
+
def __init__(self):
|
15 |
+
super().__init__()
|
16 |
+
self.base_model = models.resnet50(pretrained=False)
|
17 |
+
self.base_layers = list(self.base_model.children())
|
18 |
+
|
19 |
+
self.conv_original_size0 = convrelu(3, 64, 3, 1)
|
20 |
+
self.conv_original_size1 = convrelu(64, 64, 3, 1)
|
21 |
+
self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
|
22 |
+
self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
|
23 |
+
self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8)
|
24 |
+
self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16)
|
25 |
+
self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32)
|
26 |
+
|
27 |
+
self.strides = [8, 16, 32]
|
28 |
+
self.num_channels = [512, 1024, 2048]
|
29 |
+
|
30 |
+
def forward(self, inputs):
|
31 |
+
x_original = self.conv_original_size0(inputs)
|
32 |
+
x_original = self.conv_original_size1(x_original)
|
33 |
+
layer0 = self.layer0(inputs)
|
34 |
+
layer1 = self.layer1(layer0)
|
35 |
+
layer2 = self.layer2(layer1)
|
36 |
+
layer3 = self.layer3(layer2)
|
37 |
+
layer4 = self.layer4(layer3)
|
38 |
+
|
39 |
+
xs = {"0": layer2, "1": layer3, "2": layer4}
|
40 |
+
all_feats = {'layer0': layer0, 'layer1': layer1, 'layer2': layer2,
|
41 |
+
'layer3': layer3, 'layer4': layer4, 'x_original': x_original}
|
42 |
+
|
43 |
+
mask = torch.zeros(inputs.shape)[:, 0, :, :].to(layer4.device)
|
44 |
+
return xs, mask, all_feats
|
45 |
+
|
46 |
+
def train(self, mode=True):
|
47 |
+
# Override train so that the training mode is set as we want
|
48 |
+
nn.Module.train(self, mode)
|
49 |
+
if mode:
|
50 |
+
# fix all bn layers
|
51 |
+
def set_bn_eval(m):
|
52 |
+
classname = m.__class__.__name__
|
53 |
+
if classname.find('BatchNorm') != -1:
|
54 |
+
m.eval()
|
55 |
+
|
56 |
+
self.apply(set_bn_eval)
|
57 |
+
|
58 |
+
|
59 |
+
class ResNetUNet(nn.Module):
|
60 |
+
def __init__(self, n_class, out_dim=None, ms_feat=False):
|
61 |
+
super().__init__()
|
62 |
+
|
63 |
+
self.return_ms_feat = ms_feat
|
64 |
+
self.out_dim = out_dim
|
65 |
+
|
66 |
+
self.base_model = models.resnet50(pretrained=True)
|
67 |
+
self.base_layers = list(self.base_model.children())
|
68 |
+
|
69 |
+
self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
|
70 |
+
# self.layer0_1x1 = convrelu(64, 64, 1, 0)
|
71 |
+
self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
|
72 |
+
# self.layer1_1x1 = convrelu(256, 256, 1, 0)
|
73 |
+
self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8)
|
74 |
+
# self.layer2_1x1 = convrelu(512, 512, 1, 0)
|
75 |
+
self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16)
|
76 |
+
# self.layer3_1x1 = convrelu(1024, 1024, 1, 0)
|
77 |
+
self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32)
|
78 |
+
# self.layer4_1x1 = convrelu(2048, 2048, 1, 0)
|
79 |
+
|
80 |
+
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
81 |
+
|
82 |
+
self.conv_up3 = convrelu(1024 + 2048, 1024, 3, 1)
|
83 |
+
self.conv_up2 = convrelu(512 + 1024, 512, 3, 1)
|
84 |
+
self.conv_up1 = convrelu(256 + 512, 256, 3, 1)
|
85 |
+
self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
|
86 |
+
# self.conv_up1 = convrelu(512, 256, 3, 1)
|
87 |
+
# self.conv_up0 = convrelu(256, 128, 3, 1)
|
88 |
+
|
89 |
+
self.conv_original_size0 = convrelu(3, 64, 3, 1)
|
90 |
+
self.conv_original_size1 = convrelu(64, 64, 3, 1)
|
91 |
+
self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
|
92 |
+
# self.conv_last = nn.Conv2d(128, n_class, 1)
|
93 |
+
self.conv_last = nn.Conv2d(64, n_class, 1)
|
94 |
+
if out_dim:
|
95 |
+
self.conv_out = nn.Conv2d(64, out_dim, 1)
|
96 |
+
# self.conv_out = nn.Conv2d(128, out_dim, 1)
|
97 |
+
|
98 |
+
# return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
|
99 |
+
self.strides = [8, 16, 32]
|
100 |
+
self.num_channels = [512, 1024, 2048]
|
101 |
+
|
102 |
+
def forward(self, inputs):
|
103 |
+
x_original = self.conv_original_size0(inputs)
|
104 |
+
x_original = self.conv_original_size1(x_original)
|
105 |
+
|
106 |
+
layer0 = self.layer0(inputs)
|
107 |
+
layer1 = self.layer1(layer0)
|
108 |
+
layer2 = self.layer2(layer1)
|
109 |
+
layer3 = self.layer3(layer2)
|
110 |
+
layer4 = self.layer4(layer3)
|
111 |
+
|
112 |
+
# layer4 = self.layer4_1x1(layer4)
|
113 |
+
x = self.upsample(layer4)
|
114 |
+
# layer3 = self.layer3_1x1(layer3)
|
115 |
+
x = torch.cat([x, layer3], dim=1)
|
116 |
+
x = self.conv_up3(x)
|
117 |
+
layer3_up = x
|
118 |
+
|
119 |
+
x = self.upsample(x)
|
120 |
+
# layer2 = self.layer2_1x1(layer2)
|
121 |
+
x = torch.cat([x, layer2], dim=1)
|
122 |
+
x = self.conv_up2(x)
|
123 |
+
layer2_up = x
|
124 |
+
|
125 |
+
x = self.upsample(x)
|
126 |
+
# layer1 = self.layer1_1x1(layer1)
|
127 |
+
x = torch.cat([x, layer1], dim=1)
|
128 |
+
x = self.conv_up1(x)
|
129 |
+
|
130 |
+
x = self.upsample(x)
|
131 |
+
# layer0 = self.layer0_1x1(layer0)
|
132 |
+
x = torch.cat([x, layer0], dim=1)
|
133 |
+
x = self.conv_up0(x)
|
134 |
+
|
135 |
+
x = self.upsample(x)
|
136 |
+
x = torch.cat([x, x_original], dim=1)
|
137 |
+
x = self.conv_original_size2(x)
|
138 |
+
|
139 |
+
out = self.conv_last(x)
|
140 |
+
out = out.sigmoid().squeeze(1)
|
141 |
+
|
142 |
+
# xs = {"0": layer2, "1": layer3, "2": layer4}
|
143 |
+
xs = {"0": layer2_up, "1": layer3_up, "2": layer4}
|
144 |
+
mask = torch.zeros(inputs.shape)[:, 0, :, :].to(layer4.device)
|
145 |
+
# ms_feats = self.ms_feat(xs, mask)
|
146 |
+
|
147 |
+
if self.return_ms_feat:
|
148 |
+
if self.out_dim:
|
149 |
+
out_feat = self.conv_out(x)
|
150 |
+
out_feat = out_feat.permute(0, 2, 3, 1)
|
151 |
+
return xs, mask, out, out_feat
|
152 |
+
else:
|
153 |
+
return xs, mask, out
|
154 |
+
else:
|
155 |
+
return out
|
156 |
+
|
157 |
+
def train(self, mode=True):
|
158 |
+
# Override train so that the training mode is set as we want
|
159 |
+
nn.Module.train(self, mode)
|
160 |
+
if mode:
|
161 |
+
# fix all bn layers
|
162 |
+
def set_bn_eval(m):
|
163 |
+
classname = m.__class__.__name__
|
164 |
+
if classname.find('BatchNorm') != -1:
|
165 |
+
m.eval()
|
166 |
+
|
167 |
+
self.apply(set_bn_eval)
|
models/stacked_hg.py
ADDED
@@ -0,0 +1,246 @@
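stacked_hg.py below is a stacked hourglass backbone (2 stacks, depth 4, a multi-task head made of two 2-channel branches) assembled by build_hg(). A hypothetical smoke test; the stride-2 stem plus the max-pool put all outputs at 1/4 of the input resolution:

import torch
from models.stacked_hg import build_hg   # import path assumes the repo root is on sys.path

hg = build_hg()
hg.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)
    heatmaps, feats = hg(x)   # per-stack predictions (last stack first) and 256-channel features
print([h.shape for h in heatmaps], feats.shape)   # [(1, 4, 64, 64), (1, 4, 64, 64)] and (1, 256, 64, 64)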
1 |
+
"""
|
2 |
+
Hourglass network inserted in the pre-activated Resnet
|
3 |
+
Use lr=0.01 for current version
|
4 |
+
(c) Nan Xue (HAWP)
|
5 |
+
(c) Yichao Zhou (LCNN)
|
6 |
+
(c) YANG, Wei
|
7 |
+
"""
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
__all__ = ["HourglassNet", "hg"]
|
13 |
+
|
14 |
+
|
15 |
+
class Bottleneck2D(nn.Module):
|
16 |
+
expansion = 2
|
17 |
+
|
18 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
19 |
+
super(Bottleneck2D, self).__init__()
|
20 |
+
|
21 |
+
self.bn1 = nn.BatchNorm2d(inplanes)
|
22 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
|
23 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
24 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
|
25 |
+
self.bn3 = nn.BatchNorm2d(planes)
|
26 |
+
self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
|
27 |
+
self.relu = nn.ReLU(inplace=True)
|
28 |
+
self.downsample = downsample
|
29 |
+
self.stride = stride
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
residual = x
|
33 |
+
|
34 |
+
out = self.bn1(x)
|
35 |
+
out = self.relu(out)
|
36 |
+
out = self.conv1(out)
|
37 |
+
|
38 |
+
out = self.bn2(out)
|
39 |
+
out = self.relu(out)
|
40 |
+
out = self.conv2(out)
|
41 |
+
|
42 |
+
out = self.bn3(out)
|
43 |
+
out = self.relu(out)
|
44 |
+
out = self.conv3(out)
|
45 |
+
|
46 |
+
if self.downsample is not None:
|
47 |
+
residual = self.downsample(x)
|
48 |
+
|
49 |
+
out += residual
|
50 |
+
|
51 |
+
return out
|
52 |
+
|
53 |
+
|
54 |
+
class Hourglass(nn.Module):
|
55 |
+
def __init__(self, block, num_blocks, planes, depth):
|
56 |
+
super(Hourglass, self).__init__()
|
57 |
+
self.depth = depth
|
58 |
+
self.block = block
|
59 |
+
self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
|
60 |
+
|
61 |
+
def _make_residual(self, block, num_blocks, planes):
|
62 |
+
layers = []
|
63 |
+
for i in range(0, num_blocks):
|
64 |
+
layers.append(block(planes * block.expansion, planes))
|
65 |
+
return nn.Sequential(*layers)
|
66 |
+
|
67 |
+
def _make_hour_glass(self, block, num_blocks, planes, depth):
|
68 |
+
hg = []
|
69 |
+
for i in range(depth):
|
70 |
+
res = []
|
71 |
+
for j in range(3):
|
72 |
+
res.append(self._make_residual(block, num_blocks, planes))
|
73 |
+
if i == 0:
|
74 |
+
res.append(self._make_residual(block, num_blocks, planes))
|
75 |
+
hg.append(nn.ModuleList(res))
|
76 |
+
return nn.ModuleList(hg)
|
77 |
+
|
78 |
+
def _hour_glass_forward(self, n, x):
|
79 |
+
up1 = self.hg[n - 1][0](x)
|
80 |
+
low1 = F.max_pool2d(x, 2, stride=2)
|
81 |
+
low1 = self.hg[n - 1][1](low1)
|
82 |
+
|
83 |
+
if n > 1:
|
84 |
+
low2 = self._hour_glass_forward(n - 1, low1)
|
85 |
+
else:
|
86 |
+
low2 = self.hg[n - 1][3](low1)
|
87 |
+
low3 = self.hg[n - 1][2](low2)
|
88 |
+
up2 = F.interpolate(low3, scale_factor=2)
|
89 |
+
out = up1 + up2
|
90 |
+
return out
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
return self._hour_glass_forward(self.depth, x)
|
94 |
+
|
95 |
+
|
96 |
+
class HourglassNet(nn.Module):
|
97 |
+
"""Hourglass model from Newell et al ECCV 2016"""
|
98 |
+
|
99 |
+
def __init__(self, inplanes, num_feats, block, head, depth, num_stacks, num_blocks, num_classes):
|
100 |
+
super(HourglassNet, self).__init__()
|
101 |
+
|
102 |
+
self.inplanes = inplanes
|
103 |
+
self.num_feats = num_feats
|
104 |
+
self.num_stacks = num_stacks
|
105 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
|
106 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes)
|
107 |
+
self.relu = nn.ReLU(inplace=True)
|
108 |
+
self.layer1 = self._make_residual(block, self.inplanes, 1)
|
109 |
+
self.layer2 = self._make_residual(block, self.inplanes, 1)
|
110 |
+
self.layer3 = self._make_residual(block, self.num_feats, 1)
|
111 |
+
self.maxpool = nn.MaxPool2d(2, stride=2)
|
112 |
+
|
113 |
+
# build hourglass modules
|
114 |
+
ch = self.num_feats * block.expansion
|
115 |
+
# vpts = []
|
116 |
+
hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
|
117 |
+
for i in range(num_stacks):
|
118 |
+
hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
|
119 |
+
res.append(self._make_residual(block, self.num_feats, num_blocks))
|
120 |
+
fc.append(self._make_fc(ch, ch))
|
121 |
+
score.append(head(ch, num_classes))
|
122 |
+
# vpts.append(VptsHead(ch))
|
123 |
+
# vpts.append(nn.Linear(ch, 9))
|
124 |
+
# score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
|
125 |
+
# score[i].bias.data[0] += 4.6
|
126 |
+
# score[i].bias.data[2] += 4.6
|
127 |
+
if i < num_stacks - 1:
|
128 |
+
fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
|
129 |
+
score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
|
130 |
+
self.hg = nn.ModuleList(hg)
|
131 |
+
self.res = nn.ModuleList(res)
|
132 |
+
self.fc = nn.ModuleList(fc)
|
133 |
+
self.score = nn.ModuleList(score)
|
134 |
+
# self.vpts = nn.ModuleList(vpts)
|
135 |
+
self.fc_ = nn.ModuleList(fc_)
|
136 |
+
self.score_ = nn.ModuleList(score_)
|
137 |
+
|
138 |
+
def _make_residual(self, block, planes, blocks, stride=1):
|
139 |
+
downsample = None
|
140 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
141 |
+
downsample = nn.Sequential(
|
142 |
+
nn.Conv2d(
|
143 |
+
self.inplanes,
|
144 |
+
planes * block.expansion,
|
145 |
+
kernel_size=1,
|
146 |
+
stride=stride,
|
147 |
+
)
|
148 |
+
)
|
149 |
+
|
150 |
+
layers = []
|
151 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
152 |
+
self.inplanes = planes * block.expansion
|
153 |
+
for i in range(1, blocks):
|
154 |
+
layers.append(block(self.inplanes, planes))
|
155 |
+
|
156 |
+
return nn.Sequential(*layers)
|
157 |
+
|
158 |
+
def _make_fc(self, inplanes, outplanes):
|
159 |
+
bn = nn.BatchNorm2d(inplanes)
|
160 |
+
conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
|
161 |
+
return nn.Sequential(conv, bn, self.relu)
|
162 |
+
|
163 |
+
def forward(self, x):
|
164 |
+
out = []
|
165 |
+
x = self.conv1(x)
|
166 |
+
x = self.bn1(x)
|
167 |
+
x = self.relu(x)
|
168 |
+
|
169 |
+
x = self.layer1(x)
|
170 |
+
x = self.maxpool(x)
|
171 |
+
x = self.layer2(x)
|
172 |
+
x = self.layer3(x)
|
173 |
+
|
174 |
+
for i in range(self.num_stacks):
|
175 |
+
y = self.hg[i](x)
|
176 |
+
y = self.res[i](y)
|
177 |
+
y = self.fc[i](y)
|
178 |
+
score = self.score[i](y)
|
179 |
+
out.append(score)
|
180 |
+
|
181 |
+
if i < self.num_stacks - 1:
|
182 |
+
fc_ = self.fc_[i](y)
|
183 |
+
score_ = self.score_[i](score)
|
184 |
+
x = x + fc_ + score_
|
185 |
+
|
186 |
+
return out[::-1], y
|
187 |
+
|
188 |
+
def train(self, mode=True):
|
189 |
+
# Override train so that the training mode is set as we want
|
190 |
+
nn.Module.train(self, mode)
|
191 |
+
if mode:
|
192 |
+
# fix all bn layers
|
193 |
+
def set_bn_eval(m):
|
194 |
+
classname = m.__class__.__name__
|
195 |
+
if classname.find('BatchNorm') != -1:
|
196 |
+
m.eval()
|
197 |
+
|
198 |
+
self.apply(set_bn_eval)
|
199 |
+
|
200 |
+
|
201 |
+
class MultitaskHead(nn.Module):
|
202 |
+
def __init__(self, input_channels, num_class, head_size):
|
203 |
+
super(MultitaskHead, self).__init__()
|
204 |
+
|
205 |
+
m = int(input_channels / 4)
|
206 |
+
heads = []
|
207 |
+
for output_channels in sum(head_size, []):
|
208 |
+
heads.append(
|
209 |
+
nn.Sequential(
|
210 |
+
nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
|
211 |
+
nn.ReLU(inplace=True),
|
212 |
+
nn.Conv2d(m, output_channels, kernel_size=1),
|
213 |
+
)
|
214 |
+
)
|
215 |
+
self.heads = nn.ModuleList(heads)
|
216 |
+
assert num_class == sum(sum(head_size, []))
|
217 |
+
|
218 |
+
def forward(self, x):
|
219 |
+
return torch.cat([head(x) for head in self.heads], dim=1)
|
220 |
+
|
221 |
+
|
222 |
+
def build_hg():
|
223 |
+
inplanes = 64
|
224 |
+
num_feats = 256 //2
|
225 |
+
depth = 4
|
226 |
+
num_stacks = 2
|
227 |
+
num_blocks = 1
|
228 |
+
head_size = [[2], [2]]
|
229 |
+
|
230 |
+
out_feature_channels = 256
|
231 |
+
|
232 |
+
num_class = sum(sum(head_size, []))
|
233 |
+
model = HourglassNet(
|
234 |
+
block=Bottleneck2D,
|
235 |
+
inplanes = inplanes,
|
236 |
+
num_feats= num_feats,
|
237 |
+
depth=depth,
|
238 |
+
head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size),
|
239 |
+
num_stacks = num_stacks,
|
240 |
+
num_blocks = num_blocks,
|
241 |
+
num_classes = num_class)
|
242 |
+
|
243 |
+
model.out_feature_channels = out_feature_channels
|
244 |
+
|
245 |
+
return model
|
246 |
+
|
predict.py
ADDED
@@ -0,0 +1,33 @@
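predict.py below is an interactive single-image demo built on the HEAT wrapper class. A non-interactive variant of the same loop, with illustrative file paths:

from PIL import Image
from HEAT import HEAT

heat = HEAT()
image = Image.open("images/test.jpg")    # input path is illustrative
result = heat.detect_one_image(image)
result.save("assets/test_out.jpg")       # same default save path as in the script below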
1 |
+
'''
|
2 |
+
Author: [egrt]
|
3 |
+
Date: 2022-08-23 13:21:27
|
4 |
+
LastEditors: [egrt]
|
5 |
+
LastEditTime: 2022-08-23 13:45:21
|
6 |
+
Description:
|
7 |
+
'''
|
8 |
+
#--------------------------------------------------------------#
|
9 |
+
# Run prediction on a single image; the result is saved under the project root
|
10 |
+
# The default output file is results/predict_out/predict_srgan.png
|
11 |
+
#--------------------------------------------------------------#
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
from HEAT import HEAT
|
15 |
+
|
16 |
+
if __name__ == "__main__":
|
17 |
+
heat = HEAT()
|
18 |
+
#----------------------------#
|
19 |
+
# Save path for the single output image
|
20 |
+
#----------------------------#
|
21 |
+
save_path = "assets/test_out.jpg"
|
22 |
+
|
23 |
+
while True:
|
24 |
+
img = input('Input image filename:')
|
25 |
+
try:
|
26 |
+
image = Image.open(img)
|
27 |
+
except:
|
28 |
+
print('Open Error! Try again!')
|
29 |
+
continue
|
30 |
+
else:
|
31 |
+
r_image = heat.detect_one_image(image)
|
32 |
+
r_image.save(save_path)
|
33 |
+
r_image.show()
|
qualitative_outdoor/generate_html.py
ADDED
@@ -0,0 +1,64 @@
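generate_html.py below assembles a side-by-side qualitative comparison page from per-method PNG folders; note that the column header row inside writeHTML is hard-coded for the six methods listed in __main__, so the directories should be passed in that order. A hypothetical call that writes the page to a different location:

from generate_html import writeHTML   # assumes qualitative_outdoor/ is the working directory

method_dirs = ['svg_images_256/convmpn', 'svg_images_256/exp_cls', 'svg_images_256/hawp',
               'svg_images_256/letr', 'svg_images_256/heat', 'svg_images_256/gt']
writeHTML(out_path='./outdoor_qual_rerun.html', results_dirs=method_dirs)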
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
head = '''
|
7 |
+
<html>
|
8 |
+
<head>
|
9 |
+
<style>
|
10 |
+
td {text-align: center;}
|
11 |
+
</style>
|
12 |
+
</head>
|
13 |
+
<p>
|
14 |
+
</p>
|
15 |
+
<br>
|
16 |
+
<table border="1">
|
17 |
+
'''
|
18 |
+
|
19 |
+
end = '''
|
20 |
+
</table>
|
21 |
+
<br>
|
22 |
+
</html>
|
23 |
+
'''
|
24 |
+
|
25 |
+
def writeHTML(out_path, results_dirs):
|
26 |
+
f = open(out_path, 'w')
|
27 |
+
f.write(head + '\n')
|
28 |
+
f.write('<tr>'
|
29 |
+
'<td style="background-color:#FFFFFF"> ID </td> '
|
30 |
+
'<td style="background-color:#FFFFFF"> Input </td> '
|
31 |
+
'<td style="background-color:#FFFFFF"> ConvMPN </td> '
|
32 |
+
'<td style="background-color:#FFFFFF"> Exp-cls </td> '
|
33 |
+
'<td style="background-color:#FFFFFF"> HAWP </td> '
|
34 |
+
'<td style="background-color:#FFFFFF"> LETR </td> '
|
35 |
+
'<td style="background-color:#FFFFFF"> HEAT (Ours) </td> '
|
36 |
+
'<td style="background-color:#FFFFFF"> G.T. </td> '
|
37 |
+
'</tr>')
|
38 |
+
|
39 |
+
fileids_path = '../data/cities_dataset/valid_list.txt'
|
40 |
+
img_base = '../data/cities_dataset/rgb'
|
41 |
+
with open(fileids_path) as ff:
|
42 |
+
file_ids = ff.readlines()
|
43 |
+
file_ids = file_ids[50:]
|
44 |
+
file_ids = [file_id.strip() for file_id in file_ids]
|
45 |
+
permuted_ids = np.random.permutation(file_ids)
|
46 |
+
file_ids = permuted_ids[:100]
|
47 |
+
|
48 |
+
for file_id in file_ids:
|
49 |
+
row_str = '<tr>'
|
50 |
+
row_str += '<td> {} </td>'.format(file_id)
|
51 |
+
row_str += '<td> <img src="{}" width="180"> </td>'.format(os.path.join(img_base, file_id + '.jpg'))
|
52 |
+
for dir_idx, result_dir in enumerate(results_dirs):
|
53 |
+
pred_filepath = osp.join(result_dir, '{}.png'.format(file_id))
|
54 |
+
row_str += '<td> <img src="{}" width="180"> </td>'.format(pred_filepath)
|
55 |
+
row_str += '</tr>'
|
56 |
+
f.write(row_str + '\n')
|
57 |
+
|
58 |
+
f.write(end + '\n')
|
59 |
+
|
60 |
+
|
61 |
+
if __name__ == '__main__':
|
62 |
+
results_dirs = ['svg_images_256/convmpn', 'svg_images_256/exp_cls', 'svg_images_256/hawp', 'svg_images_256/letr', 'svg_images_256/heat', 'svg_images_256/gt']
|
63 |
+
|
64 |
+
writeHTML(out_path='./outdoor_qual.html', results_dirs=results_dirs)
|
qualitative_outdoor/plot_utils.py
ADDED
@@ -0,0 +1,43 @@
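plot_utils.py below draws predicted corners and edges either directly onto an OpenCV image (plot_preds) or as an SVG overlay (svg_generate); edge endpoints are given in 256x256 pixel coordinates and scaled up to the requested canvas size. A toy example with a single square footprint (all coordinates and paths are made up):

import numpy as np
from plot_utils import svg_generate   # assumes qualitative_outdoor/ is on sys.path

corners = np.array([[60, 60], [200, 60], [200, 200], [60, 200]])
edges = np.array([[60, 60, 200, 60], [200, 60, 200, 200],
                  [200, 200, 60, 200], [60, 200, 60, 60]])
dwg = svg_generate('rgb/demo.jpg', corners, edges, name='demo', size=512)
dwg.saveas('demo.svg')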
1 |
+
import cv2
|
2 |
+
import svgwrite
|
3 |
+
import colorsys
|
4 |
+
|
5 |
+
|
6 |
+
def plot_preds(image, corners, edges):
|
7 |
+
for line in edges:
|
8 |
+
cv2.line(image, tuple(line[:2]), tuple(line[2:]), (255, 255, 0), 2)
|
9 |
+
for c in corners:
|
10 |
+
cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
|
11 |
+
return image
|
12 |
+
|
13 |
+
|
14 |
+
def random_colors(N, bright=True, same=False, colors=None):
|
15 |
+
brightness = 1.0 if bright else 0.7
|
16 |
+
if colors is None or same:
|
17 |
+
if same:
|
18 |
+
hsv = [(0, 1, brightness) for i in range(N)]
|
19 |
+
else:
|
20 |
+
hsv = [(i / N, 1, brightness) for i in range(N)]
|
21 |
+
else:
|
22 |
+
hsv = [(colors[i], 1, brightness) for i in range(N)]
|
23 |
+
colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
|
24 |
+
return colors
|
25 |
+
|
26 |
+
|
27 |
+
def svg_generate(image_link, corners, edges, name, size=512):
|
28 |
+
dwg = svgwrite.Drawing(name + '.svg', size=('{}'.format(size), '{}'.format(size)))
|
29 |
+
shapes = dwg.add(dwg.g(id='shape', fill='black'))
|
30 |
+
# colors = random_colors(len(edges), same=True)
|
31 |
+
shapes.add(dwg.image(href=image_link, size=(size, size)))
|
32 |
+
|
33 |
+
scale = size / 256
|
34 |
+
for i, edge in enumerate(edges):
|
35 |
+
x = edge[:2] * scale
|
36 |
+
y = edge[2:] * scale
|
37 |
+
shapes.add(dwg.line((int(x[0]), int(x[1])), (int(y[0]), int(y[1])),
|
38 |
+
stroke="#EE6507", stroke_width=3*scale, opacity=0.7))
|
39 |
+
|
40 |
+
for i, corner in enumerate(corners):
|
41 |
+
shapes.add(dwg.circle((int(corners[i][0] * scale), int(corners[i][1]) * scale), r=4*scale,
|
42 |
+
stroke='green', fill='white', stroke_width=2*scale, opacity=0.8))
|
43 |
+
return dwg
|
qualitative_outdoor/visualize_gt.py
ADDED
@@ -0,0 +1,46 @@
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
from plot_utils import plot_preds, svg_generate
|
6 |
+
import cairosvg
|
7 |
+
|
8 |
+
image_base = '../data/cities_dataset/rgb/'
|
9 |
+
annot_base = '../data/cities_dataset/annot/'
|
10 |
+
data_filename = '../data/cities_dataset/valid_list.txt'
|
11 |
+
with open(data_filename) as f:
|
12 |
+
filenames = f.readlines()
|
13 |
+
|
14 |
+
filenames = filenames[50:]
|
15 |
+
filenames = [filename.strip() for filename in filenames]
|
16 |
+
|
17 |
+
|
18 |
+
for filename in filenames:
|
19 |
+
image_path = os.path.join(image_base, filename + '.jpg')
|
20 |
+
# image = cv2.imread(image_path)
|
21 |
+
annot_path = os.path.join(annot_base, filename + '.npy')
|
22 |
+
|
23 |
+
annot = np.load(annot_path, allow_pickle=True, encoding='latin1').tolist()
|
24 |
+
corners = np.array(list(annot.keys())).astype(np.int)
|
25 |
+
|
26 |
+
edges = set()
|
27 |
+
for c, others in annot.items():
|
28 |
+
for other_c in others:
|
29 |
+
edge = (c[0], c[1], other_c[0], other_c[1])
|
30 |
+
edge_2 = (other_c[0], other_c[1], c[0], c[1])
|
31 |
+
if edge not in edges and edge_2 not in edges:
|
32 |
+
edges.add(edge)
|
33 |
+
|
34 |
+
edges = np.array(list(edges)).astype(np.int)
|
35 |
+
|
36 |
+
# image = plot_preds(image, corners, edges)
|
37 |
+
# out_path = os.path.join(out_base, filename + '.png')
|
38 |
+
# cv2.imwrite(out_path, image)
|
39 |
+
|
40 |
+
svg = svg_generate(image_path, corners, edges, name='temp', size=256)
|
41 |
+
svg_path = './svg_results/' + 'tmp.svg'
|
42 |
+
svg.saveas(svg_path)
|
43 |
+
svg_img_path = './svg_images_256/gt/' + '{}.png'.format(filename)
|
44 |
+
cairosvg.svg2png(url=svg_path, write_to=svg_img_path)
|
45 |
+
|
46 |
+
|
qualitative_outdoor/visualize_npy.py
ADDED
@@ -0,0 +1,46 @@
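visualize_npy.py below converts per-scene .npy predictions into SVG overlays rendered to PNG. Each result file is a dict with 'corners' (K x 2 pixel coordinates) and 'edges' (E x 2 corner-index pairs); a minimal sketch of unpacking one file (the file name is illustrative):

import numpy as np

results = np.load('0_results.npy', allow_pickle=True).tolist()
corners = results['corners'].astype(int)
edge_ids = results['edges']
edges = corners[edge_ids].reshape(edge_ids.shape[0], -1)   # (E, 4) endpoint pairs, as used below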
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import cairosvg
|
6 |
+
from plot_utils import plot_preds, svg_generate
|
7 |
+
|
8 |
+
image_base = '../../data/outdoor/cities_dataset/rgb/'
|
9 |
+
svg_base = './svg_results'
|
10 |
+
|
11 |
+
if not os.path.exists(svg_base):
|
12 |
+
os.makedirs(svg_base)
|
13 |
+
|
14 |
+
data_filename = '../data/outdoor/cities_dataset/valid_list.txt'
|
15 |
+
with open(data_filename) as f:
|
16 |
+
filenames = f.readlines()
|
17 |
+
|
18 |
+
filenames = filenames[50:] # according to previous works, the testing samples are the last 350 samples of the val split
|
19 |
+
filenames = [filename.strip() for filename in filenames]
|
20 |
+
idx_to_filename = {idx: filename for idx, filename in enumerate(filenames)}
|
21 |
+
|
22 |
+
method_name = 'heat'
|
23 |
+
results_base = '../results/npy_outdoor_test_256/'
|
24 |
+
|
25 |
+
svg_method_base = os.path.join(svg_base, method_name)
|
26 |
+
if not os.path.exists(svg_method_base):
|
27 |
+
os.makedirs(svg_method_base)
|
28 |
+
|
29 |
+
for result_filename in sorted(os.listdir(results_base)):
|
30 |
+
file_idx = int(result_filename[:-12])
|
31 |
+
filename = idx_to_filename[file_idx]
|
32 |
+
|
33 |
+
image_path = os.path.join(image_base, filename + '.jpg')
|
34 |
+
|
35 |
+
results_path = os.path.join(results_base, result_filename)
|
36 |
+
results = np.load(results_path, allow_pickle=True).tolist()
|
37 |
+
corners = results['corners'].astype(np.int)
|
38 |
+
edge_ids = results['edges']
|
39 |
+
edges = corners[edge_ids].reshape(edge_ids.shape[0], -1)
|
40 |
+
|
41 |
+
svg = svg_generate(image_path, corners, edges, name='temp', size=256)
|
42 |
+
svg_path = os.path.join(svg_base, 'tmp.svg')
|
43 |
+
svg.saveas(svg_path) # save the svg file temporarily
|
44 |
+
|
45 |
+
svg_img_path = os.path.join(svg_method_base, '{}.png'.format(filename))
|
46 |
+
cairosvg.svg2png(url=svg_path, write_to=svg_img_path)
|
requirements.txt
ADDED
@@ -0,0 +1,27 @@
1 |
+
Cython==0.29.22
|
2 |
+
defusedxml==0.6.0
|
3 |
+
einops==0.4.1
|
4 |
+
future==0.18.2
|
5 |
+
imageio==2.16.1
|
6 |
+
matplotlib==3.3.4
|
7 |
+
MultiScaleDeformableAttention==1.0
|
8 |
+
numpy==1.20.1
|
9 |
+
opencv-python==4.4.0.44
|
10 |
+
packaging==20.9
|
11 |
+
Pillow==9.0.1
|
12 |
+
prometheus-client==0.9.0
|
13 |
+
prompt-toolkit==3.0.16
|
14 |
+
ptyprocess==0.7.0
|
15 |
+
pycparser==2.20
|
16 |
+
Pygments==2.8.0
|
17 |
+
python-dateutil==2.8.1
|
18 |
+
scikit-image==0.19.2
|
19 |
+
scikit-learn==1.0
|
20 |
+
scipy==1.6.1
|
21 |
+
six==1.15.0
|
22 |
+
torch==1.5.1
|
23 |
+
torchvision==0.6.1
|
24 |
+
cairosvg==2.5.2
|
25 |
+
svgwrite==1.4.2
|
26 |
+
shapely==1.8.2
|
27 |
+
gradio==2.5.3
|
s3d_floorplan_eval/DataRW/DataRW.py
ADDED
@@ -0,0 +1,4 @@
1 |
+
|
2 |
+
class DataRW:
|
3 |
+
def __init__(self, options):
|
4 |
+
pass
|
s3d_floorplan_eval/DataRW/S3DRW.py
ADDED
@@ -0,0 +1,142 @@
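S3DRW below wraps the Structured3D floorplan ground truth for one scene: it maps a scene id such as "scene_03250" to a loader index with a split-dependent offset, then tensorizes the density, room and wall maps. The index mapping, pulled out for clarity (offsets are taken from the constructor that follows):

def scene_index(scene_id: str, mode: str) -> int:
    number = int(scene_id[6:])        # strip the "scene_" prefix
    if mode == "online_eval":
        return number - 3000
    if mode == "test":
        return number - 3250
    if mode == "train":
        return number
    raise ValueError("unknown mode: " + mode)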
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
|
7 |
+
from DataRW.DataRW import DataRW
|
8 |
+
from S3DLoader.S3DLoader import S3DLoader
|
9 |
+
|
10 |
+
class S3DRW(DataRW):
|
11 |
+
def __init__(self, options):
|
12 |
+
"""
|
13 |
+
Class for accessing FloorNet dataset related data
|
14 |
+
|
15 |
+
:param options:
|
16 |
+
"""
|
17 |
+
# initialize the base class variables
|
18 |
+
super(DataRW, self).__init__()
|
19 |
+
|
20 |
+
self.options = options
|
21 |
+
|
22 |
+
self.dataset_path = options.dataset_path
|
23 |
+
self.scene_id = options.scene_id
|
24 |
+
|
25 |
+
self.mcts_path = options.mcts_path
|
26 |
+
self.creation_time = int(time.time())
|
27 |
+
|
28 |
+
self.device = torch.device("cpu")
|
29 |
+
|
30 |
+
# mode = "train"
|
31 |
+
# mode = "online_eval"
|
32 |
+
mode = "test"
|
33 |
+
# For validation only
|
34 |
+
# self.loader = S3DLoader(options, 'online_eval').dataset
|
35 |
+
self.loader = S3DLoader(options, mode).dataset
|
36 |
+
|
37 |
+
# gt_sample = iter(floornet_loader.dataset[int(self.scene_id)])
|
38 |
+
# self.gt_sample = floornet_loader.load_sample(list(iter(floornet_loader.dataset))[int(self.scene_id)])
|
39 |
+
|
40 |
+
if mode == "online_eval":
|
41 |
+
scene_ind = int(self.scene_id[6:]) - 3000
|
42 |
+
elif mode == "test":
|
43 |
+
scene_ind = int(self.scene_id[6:]) - 3250
|
44 |
+
elif mode == "train":
|
45 |
+
scene_ind = int(self.scene_id[6:])
|
46 |
+
else:
|
47 |
+
assert False
|
48 |
+
|
49 |
+
# print(len(list(iter(self.s3d_loader.data))))
|
50 |
+
self.gt_sample = gt_sample = self.loader[scene_ind]
|
51 |
+
self.gt_sample["density_map"] = torch.tensor(self.gt_sample["density_map"][None], device=self.device)
|
52 |
+
self.gt_sample["room_map"] = torch.tensor(self.gt_sample["room_map"][None,:,:,None], device=self.device)
|
53 |
+
self.gt_sample["wall_map"] = torch.tensor(self.gt_sample["wall_map"][None,:,:,None], device=self.device)
|
54 |
+
|
55 |
+
|
56 |
+
self.density_map = self.gt_sample['density_map'][:,:,:,None]
|
57 |
+
|
58 |
+
self.h, self.w = self.density_map.shape[1], self.density_map.shape[2]
|
59 |
+
|
60 |
+
self.generate_input_map_from_props = self.generate_input_dict_from_room_props
|
61 |
+
|
62 |
+
def get_gt_solution(self):
|
63 |
+
"""
|
64 |
+
Read top-view density map of the scene
|
65 |
+
|
66 |
+
:return:
|
67 |
+
"""
|
68 |
+
img_path = os.path.join(self.dataset_path, str(self.scene_id) + "_density.png")
|
69 |
+
density_map = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)[:,:, 0][None,:,:,None]
|
70 |
+
|
71 |
+
density_map = torch.from_numpy(density_map).to(self.device)
|
72 |
+
|
73 |
+
dm_min = torch.min(density_map)
|
74 |
+
dm_max = torch.max(density_map)
|
75 |
+
|
76 |
+
density_map = (density_map - dm_min) / (dm_max - dm_min)
|
77 |
+
|
78 |
+
return density_map.type(torch.cuda.FloatTensor)
|
79 |
+
|
80 |
+
def polygonize_mask(self, pm, return_mask=True):
|
81 |
+
pm_np = pm.cpu().numpy()
|
82 |
+
|
83 |
+
room_mask = 255 * (pm_np == 1)
|
84 |
+
room_mask = room_mask.astype(np.uint8)
|
85 |
+
room_mask_inv = 255 - room_mask
|
86 |
+
|
87 |
+
ret, thresh = cv2.threshold(room_mask_inv, 250, 255, cv2.THRESH_BINARY_INV)
|
88 |
+
|
89 |
+
contours, hierarchy = cv2.findContours(thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
|
90 |
+
|
91 |
+
cnt = contours[0]
|
92 |
+
max_area = cv2.contourArea(cnt)
|
93 |
+
|
94 |
+
for cont in contours:
|
95 |
+
if cv2.contourArea(cont) > max_area:
|
96 |
+
cnt = cont
|
97 |
+
max_area = cv2.contourArea(cont)
|
98 |
+
|
99 |
+
# define main island contour approx. and hull
|
100 |
+
perimeter = cv2.arcLength(cnt, True)
|
101 |
+
epsilon = 0.01 * cv2.arcLength(cnt, True)
|
102 |
+
approx = cv2.approxPolyDP(cnt, epsilon, True)
|
103 |
+
|
104 |
+
# approx = np.concatenate([approx, approx[0][None]], axis=0)
|
105 |
+
approx = approx.astype(np.int32).reshape((1, -1, 2))
|
106 |
+
|
107 |
+
if return_mask:
|
108 |
+
room_filled_map = np.zeros((self.h, self.w))
|
109 |
+
cv2.fillPoly(room_filled_map, approx, color=1.)
|
110 |
+
|
111 |
+
room_filled_map = torch.tensor(room_filled_map[:,:], dtype=torch.float32, device=self.device)
|
112 |
+
|
113 |
+
return room_filled_map
|
114 |
+
else:
|
115 |
+
approx_tensor = torch.tensor(approx, device=self.device)
|
116 |
+
return approx_tensor
|
117 |
+
|
118 |
+
def generate_input_dict_from_room_props(self, room_prop_list, score_function, use_thresh=False):
|
119 |
+
"""
|
120 |
+
|
121 |
+
:param room_prop_list:
|
122 |
+
:type room_prop_list: list of FloorPlanRoomProp
|
123 |
+
:param score_function:
|
124 |
+
:return:
|
125 |
+
"""
|
126 |
+
|
127 |
+
if score_function == "room_maskrcnn_iou":
|
128 |
+
inputs = self.generate_input_dict_for_room_maskrcnn_iou(room_prop_list)
|
129 |
+
elif score_function == "room_iou":
|
130 |
+
inputs = self.generate_input_dict_for_room_iou(room_prop_list, use_thresh=use_thresh)
|
131 |
+
else:
|
132 |
+
assert "generate_input_dict_from_room_props for %s not implemented" % score_function
|
133 |
+
|
134 |
+
return inputs
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
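A minimal usage sketch for S3DRW, under stated assumptions: the options namespace only fills in the attributes read by __init__ above, the paths are placeholders, and the 'scene_03250'-style id is inferred from the scene_id[6:] slicing and the test-split offset of 3250. The construction itself is left commented out because it needs the S3DLoader package and the dataset on disk.

from types import SimpleNamespace

# Hypothetical options object; attribute names mirror what S3DRW.__init__ reads.
options = SimpleNamespace(
    dataset_path='./montefloor_data/test',  # assumed dataset location
    scene_id='scene_03250',                 # 'scene_' prefix assumed from scene_id[6:]
    mcts_path='./results/mcts',             # assumed output path
)

# data_rw = S3DRW(options)                  # requires S3DLoader and the dataset on disk
# gt_polys = data_rw.gt_sample['polygons_list']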
s3d_floorplan_eval/DataRW/wrong_annotatios.py
ADDED
@@ -0,0 +1 @@
wrong_s3d_annotations_list = [3261, 3271, 3276, 3296, 3342, 3387, 3398, 3466, 3496]
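A hedged sketch of how this list is presumably consumed, skipping scenes whose numeric id is flagged above; the import path assumes the s3d_floorplan_eval folder is on the Python path, and the scene id format is the same assumption as before.

from DataRW.wrong_annotatios import wrong_s3d_annotations_list

scene_id = 'scene_03261'  # hypothetical scene id
if int(scene_id[6:]) in wrong_s3d_annotations_list:
    print('skipping scene with known bad annotations:', scene_id)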
s3d_floorplan_eval/Evaluator/Evaluator.py
ADDED
@@ -0,0 +1,457 @@
import os
import torch
import matplotlib.pyplot as plt
import cv2
import numpy as np
from scipy.spatial import Delaunay
import os
import shapely
from shapely.geometry import Polygon, MultiPolygon, LineString, MultiLineString

corner_metric_thresh = 10
angle_metric_thresh = 5


# colormap_255 = [[i, i, i] for i in range(40)]

class Evaluator():
    def __init__(self, data_rw, options):
        self.data_rw = data_rw
        self.options = options

        self.device = torch.device("cuda")

    def polygonize_mask(self, mask, degree, return_mask=True):
        h, w = mask.shape[0], mask.shape[1]
        mask = mask

        room_mask = 255 * (mask == 1)
        room_mask = room_mask.astype(np.uint8)
        room_mask_inv = 255 - room_mask

        ret, thresh = cv2.threshold(room_mask_inv, 250, 255, cv2.THRESH_BINARY_INV)

        contours, hierarchy = cv2.findContours(thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)

        cnt = contours[0]
        max_area = cv2.contourArea(cnt)

        for cont in contours:
            if cv2.contourArea(cont) > max_area:
                cnt = cont
                max_area = cv2.contourArea(cont)

        perimeter = cv2.arcLength(cnt, True)
        # epsilon = 0.01 * cv2.arcLength(cnt, True)
        epsilon = degree * cv2.arcLength(cnt, True)
        approx = cv2.approxPolyDP(cnt, epsilon, True)

        # approx = np.concatenate([approx, approx[0][None]], axis=0)
        approx = approx.astype(np.int32).reshape((-1, 2))

        # approx_tensor = torch.tensor(approx, device=self.device)

        # return approx_tensor
        if return_mask:
            room_filled_map = np.zeros((h, w))
            cv2.fillPoly(room_filled_map, [approx], color=1.)

            return approx, room_filled_map
        else:
            return approx

    def print_res_str_for_latex(self, quant_result_dict):

        str_fields = ""
        str_values = ""

        avg_value_prec = 0
        avg_value_rec = 0
        for k_ind, k in enumerate(quant_result_dict.keys()):
            str_fields += " & " + k
            str_values += " & %.2f " % quant_result_dict[k]

            if k_ind % 2 == 0:
                avg_value_prec += quant_result_dict[k] / 3
            else:
                avg_value_rec += quant_result_dict[k] / 3

        str_fields += "tm_prec & tm_rec"

        str_values += " & %.2f " % avg_value_prec
        str_values += " & %.2f " % avg_value_rec

        str_fields += " \\\\"
        str_values += " \\\\"

        print(str_fields)
        print(str_values)

    def calc_gradient(self, room_map):
        grad_x = np.abs(room_map[:, 1:] - room_map[:, :-1])
        grad_y = np.abs(room_map[1:] - room_map[:-1])

        grad_xy = np.zeros_like(room_map)
        grad_xy[1:] = grad_y
        grad_xy[:, 1:] = np.maximum(grad_x, grad_xy[:, 1:])

        plt.figure()
        plt.axis("off")
        plt.imshow(grad_xy, cmap="gray")
        # plt.show()
        plt.savefig("grad.png", bbox_inches='tight')

        plt.figure()
        plt.axis("off")
        plt.imshow(room_map, cmap="gray")
        # plt.show()
        plt.savefig("joint_mask.png", bbox_inches='tight')
        assert False

    def evaluate_scene(self, room_polys, show=False, name="ours", dataset_type="s3d"):

        with torch.no_grad():
            joint_room_map = np.zeros((self.options.height, self.options.width))

            edge_map = np.zeros_like(joint_room_map)
            room_filled_map = np.ones([joint_room_map.shape[0], joint_room_map.shape[1], 3])

            density_map = self.data_rw.density_map.cpu().numpy()[0]
            img_size = (density_map.shape[0], density_map.shape[0])

            for room_ind, poly in enumerate(room_polys):
                cv2.polylines(edge_map, [poly], isClosed=True, color=1.)
                cv2.fillPoly(joint_room_map, [poly], color=1.)

            joint_room_map_vis = np.ones([joint_room_map.shape[0], joint_room_map.shape[1], 3])

            # Ground Truth

            gt_polys_list = self.data_rw.gt_sample["polygons_list"]
            gt_polys_list = [np.concatenate([poly, poly[None, 0]]) for poly in gt_polys_list]

            ignore_mask_region = self.data_rw.gt_sample["wall_map"].cpu().numpy()[0, :, :, 0]

            img_size = (joint_room_map.shape[0], joint_room_map.shape[1])
            quant_result_dict = self.get_quantitative(gt_polys_list, ignore_mask_region, room_polys, img_size, dataset_type=dataset_type)

            return quant_result_dict

    def get_quantitative(self, gt_polys, ignore_mask_region, pred_polys=None, masks_list=None, img_size=(256, 256), dataset_type="s3d"):
        def get_room_metric():
            pred_overlaps = [False] * len(pred_room_map_list)

            for pred_ind1 in range(len(pred_room_map_list) - 1):
                pred_map1 = pred_room_map_list[pred_ind1]

                for pred_ind2 in range(pred_ind1 + 1, len(pred_room_map_list)):
                    pred_map2 = pred_room_map_list[pred_ind2]

                    if dataset_type == "s3d":
                        kernel = np.ones((5, 5), np.uint8)
                    else:
                        kernel = np.ones((3, 3), np.uint8)

                    # todo: for our method, the rooms share corners and edges, need to check here
                    pred_map1_er = cv2.erode(pred_map1, kernel)
                    pred_map2_er = cv2.erode(pred_map2, kernel)

                    intersection = (pred_map1_er + pred_map2_er) == 2
                    # intersection = (pred_map1 + pred_map2) == 2

                    intersection_area = np.sum(intersection)

                    if intersection_area >= 1:
                        pred_overlaps[pred_ind1] = True
                        pred_overlaps[pred_ind2] = True

            # import pdb; pdb.set_trace()
            room_metric = [np.bool((1 - pred_overlaps[ind]) * pred2gt_exists[ind]) for ind in range(len(pred_polys))]

            return room_metric

        def get_corner_metric():

            room_corners_metric = []
            for pred_poly_ind, gt_poly_ind in enumerate(pred2gt_indices):
                p_poly = pred_polys[pred_poly_ind][:-1]  # Last vertex = First vertex

                p_poly_corner_metrics = [False] * p_poly.shape[0]
                if not room_metric[pred_poly_ind]:
                    room_corners_metric += p_poly_corner_metrics
                    continue

                gt_poly = gt_polys[gt_poly_ind][:-1]

                # for v in p_poly:
                #     v_dists = np.linalg.norm(v[None, :] - gt_poly, axis=1, ord=2)
                #     v_min_dist = np.min(v_dists)
                #
                #     v_tp = v_min_dist <= 10
                #     room_corners_metric.append(v_tp)

                for v in gt_poly:
                    v_dists = np.linalg.norm(v[None, :] - p_poly, axis=1, ord=2)
                    v_min_dist_ind = np.argmin(v_dists)
                    v_min_dist = v_dists[v_min_dist_ind]

                    if not p_poly_corner_metrics[v_min_dist_ind]:
                        v_tp = v_min_dist <= corner_metric_thresh
                        p_poly_corner_metrics[v_min_dist_ind] = v_tp

                room_corners_metric += p_poly_corner_metrics

            return room_corners_metric

        def get_angle_metric():

            def get_line_vector(p1, p2):
                p1 = np.concatenate((p1, np.array([1])))
                p2 = np.concatenate((p2, np.array([1])))

                line_vector = -np.cross(p1, p2)

                return line_vector

            def get_poly_orientation(my_poly):
                angles_sum = 0
                for v_ind, _ in enumerate(my_poly):
                    if v_ind < len(my_poly) - 1:
                        v_sides = my_poly[[v_ind - 1, v_ind, v_ind, v_ind + 1], :]
                    else:
                        v_sides = my_poly[[v_ind - 1, v_ind, v_ind, 0], :]

                    v1_vector = get_line_vector(v_sides[0], v_sides[1])
                    v1_vector = v1_vector / (np.linalg.norm(v1_vector, ord=2) + 1e-4)
                    v2_vector = get_line_vector(v_sides[2], v_sides[3])
                    v2_vector = v2_vector / (np.linalg.norm(v2_vector, ord=2) + 1e-4)

                    orientation = (v_sides[1, 1] - v_sides[0, 1]) * (v_sides[3, 0] - v_sides[1, 0]) - (
                            v_sides[3, 1] - v_sides[1, 1]) * (
                            v_sides[1, 0] - v_sides[0, 0])

                    v1_vector_2d = v1_vector[:2] / (v1_vector[2] + 1e-4)
                    v2_vector_2d = v2_vector[:2] / (v2_vector[2] + 1e-4)

                    v1_vector_2d = v1_vector_2d / (np.linalg.norm(v1_vector_2d, ord=2) + 1e-4)
                    v2_vector_2d = v2_vector_2d / (np.linalg.norm(v2_vector_2d, ord=2) + 1e-4)

                    angle_cos = v1_vector_2d.dot(v2_vector_2d)
                    angle_cos = np.clip(angle_cos, -1, 1)

                    # G.T. has clockwise orientation, remove minus in the equation

                    angle = np.sign(orientation) * np.abs(np.arccos(angle_cos))
                    angle_degree = angle * 180 / np.pi

                    angles_sum += angle_degree

                return np.sign(angles_sum)

            def get_angle_v_sides(inp_v_sides, poly_orient):
                v1_vector = get_line_vector(inp_v_sides[0], inp_v_sides[1])
                v1_vector = v1_vector / (np.linalg.norm(v1_vector, ord=2) + 1e-4)
                v2_vector = get_line_vector(inp_v_sides[2], inp_v_sides[3])
                v2_vector = v2_vector / (np.linalg.norm(v2_vector, ord=2) + 1e-4)

                orientation = (inp_v_sides[1, 1] - inp_v_sides[0, 1]) * (inp_v_sides[3, 0] - inp_v_sides[1, 0]) - (
                        inp_v_sides[3, 1] - inp_v_sides[1, 1]) * (
                        inp_v_sides[1, 0] - inp_v_sides[0, 0])

                v1_vector_2d = v1_vector[:2] / (v1_vector[2] + 1e-4)
                v2_vector_2d = v2_vector[:2] / (v2_vector[2] + 1e-4)

                v1_vector_2d = v1_vector_2d / (np.linalg.norm(v1_vector_2d, ord=2) + 1e-4)
                v2_vector_2d = v2_vector_2d / (np.linalg.norm(v2_vector_2d, ord=2) + 1e-4)

                angle_cos = v1_vector_2d.dot(v2_vector_2d)
                angle_cos = np.clip(angle_cos, -1, 1)

                angle = poly_orient * np.sign(orientation) * np.arccos(angle_cos)
                angle_degree = angle * 180 / np.pi

                return angle_degree

            room_angles_metric = []
            for pred_poly_ind, gt_poly_ind in enumerate(pred2gt_indices):
                p_poly = pred_polys[pred_poly_ind][:-1]  # Last vertex = First vertex

                p_poly_angle_metrics = [False] * p_poly.shape[0]
                if not room_metric[pred_poly_ind]:
                    room_angles_metric += p_poly_angle_metrics
                    continue

                gt_poly = gt_polys[gt_poly_ind][:-1]

                # for v in p_poly:
                #     v_dists = np.linalg.norm(v[None, :] - gt_poly, axis=1, ord=2)
                #     v_min_dist = np.min(v_dists)
                #
                #     v_tp = v_min_dist <= 10
                #     room_corners_metric.append(v_tp)

                gt_poly_orient = get_poly_orientation(gt_poly)
                p_poly_orient = get_poly_orientation(p_poly)

                for v_gt_ind, v in enumerate(gt_poly):
                    v_dists = np.linalg.norm(v[None, :] - p_poly, axis=1, ord=2)
                    v_ind = np.argmin(v_dists)
                    v_min_dist = v_dists[v_ind]

                    if v_min_dist > corner_metric_thresh:
                        # room_angles_metric.append(False)
                        continue

                    if v_ind < len(p_poly) - 1:
                        v_sides = p_poly[[v_ind - 1, v_ind, v_ind, v_ind + 1], :]
                    else:
                        v_sides = p_poly[[v_ind - 1, v_ind, v_ind, 0], :]

                    v_sides = v_sides.reshape((4, 2))
                    pred_angle_degree = get_angle_v_sides(v_sides, p_poly_orient)

                    # Note: replacing some variables with values from the g.t. poly

                    if v_gt_ind < len(gt_poly) - 1:
                        v_sides = gt_poly[[v_gt_ind - 1, v_gt_ind, v_gt_ind, v_gt_ind + 1], :]
                    else:
                        v_sides = gt_poly[[v_gt_ind - 1, v_gt_ind, v_gt_ind, 0], :]

                    v_sides = v_sides.reshape((4, 2))
                    gt_angle_degree = get_angle_v_sides(v_sides, gt_poly_orient)

                    angle_metric = np.abs(pred_angle_degree - gt_angle_degree)

                    # room_angles_metric.append(angle_metric < 5)
                    p_poly_angle_metrics[v_ind] = angle_metric <= angle_metric_thresh

                    # if angle_metric > 5:
                    #     print(v_gt_ind, angle_metric)
                    #     print(pred_angle_degree, gt_angle_degree)
                    #     input("?")

                room_angles_metric += p_poly_angle_metrics

            for am, cm in zip(room_angles_metric, corner_metric):
                assert not (cm == False and am == True), "cm: %d am: %d" % (cm, am)

            return room_angles_metric

        def poly_map_sort_key(x):
            return np.sum(x[1])

        h, w = img_size

        gt_room_map_list = []
        for room_ind, poly in enumerate(gt_polys):
            room_map = np.zeros((h, w))
            cv2.fillPoly(room_map, [poly], color=1.)

            gt_room_map_list.append(room_map)

        gt_polys_sorted_indcs = [i[0] for i in sorted(enumerate(gt_room_map_list), key=poly_map_sort_key, reverse=True)]

        gt_polys = [gt_polys[ind] for ind in gt_polys_sorted_indcs]
        gt_room_map_list = [gt_room_map_list[ind] for ind in gt_polys_sorted_indcs]

        if pred_polys is not None:
            pred_room_map_list = []
            for room_ind, poly in enumerate(pred_polys):
                room_map = np.zeros((h, w))
                cv2.fillPoly(room_map, [poly], color=1.)

                pred_room_map_list.append(room_map)
        else:
            pred_room_map_list = masks_list

        gt2pred_indices = [-1] * len(gt_polys)
        gt2pred_exists = [False] * len(gt_polys)

        for gt_ind, gt_map in enumerate(gt_room_map_list):

            best_iou = 0.
            best_ind = -1
            for pred_ind, pred_map in enumerate(pred_room_map_list):

                intersection = (1 - ignore_mask_region) * ((pred_map + gt_map) == 2)
                union = (1 - ignore_mask_region) * ((pred_map + gt_map) >= 1)

                iou = np.sum(intersection) / (np.sum(union) + 1)

                if iou > best_iou and iou > 0.5:
                    best_iou = iou
                    best_ind = pred_ind

                # plt.figure()
                # plt.subplot(121)
                # plt.imshow(pred_map)
                # plt.subplot(122)
                # plt.imshow(gt_map)
                # plt.show()
            # if best_ind == -1:
            #     plt.figure()
            #     plt.imshow(gt_map)
            #     plt.show()

            gt2pred_indices[gt_ind] = best_ind
            gt2pred_exists[gt_ind] = best_ind != -1

            # if best_ind == -1:
            #     plt.figure()
            #     plt.imshow(gt_map)
            #     plt.show()

        pred2gt_exists = [True if pred_ind in gt2pred_indices else False for pred_ind, _ in enumerate(pred_polys)]
        pred2gt_indices = [gt2pred_indices.index(pred_ind) if pred_ind in gt2pred_indices else -1 for pred_ind, _ in enumerate(pred_polys)]

        # print(gt2pred_indices)
        # print(pred2gt_indices)
        # assert False

        # import pdb; pdb.set_trace()
        room_metric = get_room_metric()
        if len(pred_polys) == 0:
            room_metric_prec = 0
        else:
            room_metric_prec = sum(room_metric) / float(len(pred_polys))
        room_metric_rec = sum(room_metric) / float(len(gt_polys))

        corner_metric = get_corner_metric()
        pred_corners_n = sum([poly.shape[0] - 1 for poly in pred_polys])
        gt_corners_n = sum([poly.shape[0] - 1 for poly in gt_polys])

        if pred_corners_n > 0:
            corner_metric_prec = sum(corner_metric) / float(pred_corners_n)
        else:
            corner_metric_prec = 0
        corner_metric_rec = sum(corner_metric) / float(gt_corners_n)

        angles_metric = get_angle_metric()

        if pred_corners_n > 0:
            angles_metric_prec = sum(angles_metric) / float(pred_corners_n)
        else:
            angles_metric_prec = 0
        angles_metric_rec = sum(angles_metric) / float(gt_corners_n)

        assert room_metric_prec <= 1
        assert room_metric_rec <= 1
        assert corner_metric_prec <= 1
        assert corner_metric_rec <= 1
        assert angles_metric_prec <= 1
        assert angles_metric_rec <= 1

        result_dict = {
            'room_prec': room_metric_prec,
            'room_rec': room_metric_rec,
            'corner_prec': corner_metric_prec,
            'corner_rec': corner_metric_rec,
            'angles_prec': angles_metric_prec,
            'angles_rec': angles_metric_rec,
        }

        return result_dict
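To make the room-matching rule concrete, here is a small self-contained sketch (an illustration, not the class above) of the criterion used inside get_quantitative: a predicted room is matched to a ground-truth room only when their IoU, computed outside the ignored wall region, exceeds 0.5. The toy polygons and the 256x256 canvas are illustrative assumptions.

import numpy as np
import cv2

def room_iou(pred_poly, gt_poly, ignore_mask, size=(256, 256)):
    # Rasterize both polygons and measure IoU outside the ignored region,
    # mirroring the intersection/union computation in get_quantitative.
    pred_map = np.zeros(size)
    gt_map = np.zeros(size)
    cv2.fillPoly(pred_map, [pred_poly], color=1.)
    cv2.fillPoly(gt_map, [gt_poly], color=1.)
    intersection = (1 - ignore_mask) * ((pred_map + gt_map) == 2)
    union = (1 - ignore_mask) * ((pred_map + gt_map) >= 1)
    return np.sum(intersection) / (np.sum(union) + 1)

# Toy example: a prediction shifted by a few pixels still matches (IoU > 0.5).
gt = np.array([[40, 40], [200, 40], [200, 160], [40, 160]], dtype=np.int32)
pred = gt + np.array([[5, 5]], dtype=np.int32)
iou = room_iou(pred, gt, ignore_mask=np.zeros((256, 256)))
print('IoU = %.3f -> matched: %s' % (iou, iou > 0.5))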