# Description: RDD model
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

from .utils import NestedTensor, nested_tensor_from_tensor_list, to_pixel_coords, read_config
from .models.detector import build_detector
from .models.descriptor import build_descriptor
from .models.soft_detect import SoftDetect
from .models.interpolator import InterpolateSparse2d

class RDD(nn.Module):
    def __init__(self, detector, descriptor, detection_threshold=0.5, top_k=4096, train_detector=False, device='cuda'):
        super().__init__()
        self.detector = detector
        self.descriptor = descriptor
        self.interpolator = InterpolateSparse2d('bicubic')
        self.detection_threshold = detection_threshold
        self.top_k = top_k
        self.device = device

        # Train either the detector or the descriptor branch, never both at once.
        if train_detector:
            for p in self.detector.parameters():
                p.requires_grad = True
            for p in self.descriptor.parameters():
                p.requires_grad = False
        else:
            for p in self.detector.parameters():
                p.requires_grad = False
            for p in self.descriptor.parameters():
                p.requires_grad = True

        self.softdetect = None
        self.stride = descriptor.stride

    def train(self, mode=True):
        super().train(mode)
        # Training uses a smaller keypoint budget and a higher score threshold.
        self.set_softdetect(top_k=500, scores_th=0.2)

    def eval(self):
        super().eval()
        self.set_softdetect(top_k=self.top_k, scores_th=0.01)

    def forward(self, samples: NestedTensor):
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        scoremap = self.detector(samples)
        feats, matchibility = self.descriptor(samples)
        return feats, scoremap, matchibility

    def set_softdetect(self, top_k=4096, scores_th=0.01):
        self.softdetect = SoftDetect(radius=2, top_k=top_k, scores_th=scores_th)

    def filter(self, matchibility):
        # Filter out keypoints on the border
        B, _, H, W = matchibility.shape
        frame = torch.zeros(B, H, W, device=matchibility.device)
        frame[:, self.stride:-self.stride, self.stride:-self.stride] = 1
        matchibility = matchibility * frame
        return matchibility

    def extract(self, x):
        if self.softdetect is None:
            self.eval()

        x, rh1, rw1 = self.preprocess_tensor(x)
        x = x.to(self.device).float()
        B, _, _H1, _W1 = x.shape

        M1, K1, H1 = self.forward(x)
        M1 = F.normalize(M1, dim=1)

        keypoints, kptscores, scoredispersitys = self.softdetect(K1)
        keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)])
        kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)])
        keypoints = to_pixel_coords(keypoints, _H1, _W1)

        feats = self.interpolator(M1, keypoints, H=_H1, W=_W1)
        feats = F.normalize(feats, dim=-1)

        # Correct kpt scale back to the original image resolution
        keypoints = keypoints * torch.tensor([rw1, rh1], device=keypoints.device).view(1, -1)

        valid = kptscores > self.detection_threshold
        return [
            {'keypoints': keypoints[b][valid[b]],
             'scores': kptscores[b][valid[b]],
             'descriptors': feats[b][valid[b]]} for b in range(B)
        ]
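
    # Usage sketch (illustrative only; assumes `img` is an HxWx3 numpy array or a
    # (B, 3, H, W) float tensor that the detector/descriptor accept):
    #   preds = model.extract(img)
    #   kpts = preds[0]['keypoints']    # (N, 2) pixel coords at the original resolution
    #   desc = preds[0]['descriptors']  # (N, C) L2-normalized descriptors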

    def extract_3rd_party(self, x, model='aliked'):
        """
        Detect keypoints with a third-party detector (one image per batch),
        then describe them with the RDD descriptor.
        """
        x, rh1, rw1 = self.preprocess_tensor(x)
        B, _, _H1, _W1 = x.shape

        if model == 'aliked':
            from third_party import extract_aliked_kpts
            img = x
            mkpts, scores = extract_aliked_kpts(img, self.device)
        else:
            raise ValueError('Unknown model')

        M1, _ = self.descriptor(x)
        M1 = F.normalize(M1, dim=1)

        # Keep only the top_k highest-scoring keypoints.
        if mkpts.shape[1] > self.top_k:
            idx = torch.argsort(scores, descending=True)[0][:self.top_k]
            mkpts = mkpts[:, idx]
            scores = scores[:, idx]

        feats = self.interpolator(M1, mkpts, H=_H1, W=_W1)
        feats = F.normalize(feats, dim=-1)

        mkpts = mkpts * torch.tensor([rw1, rh1], device=mkpts.device).view(1, 1, -1)

        return [
            {'keypoints': mkpts[b],
             'scores': scores[b],
             'descriptors': feats[b]} for b in range(B)
        ]

    def extract_dense(self, x, n_limit=30000, thr=0.01):
        self.set_softdetect(top_k=n_limit, scores_th=-1)

        x, rh1, rw1 = self.preprocess_tensor(x)
        B, _, _H1, _W1 = x.shape

        M1, K1, H1 = self.forward(x)
        M1 = F.normalize(M1, dim=1)

        # Sparse keypoints from the detector score map.
        keypoints, kptscores, scoredispersitys = self.softdetect(K1)
        keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)])
        kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)])
        keypoints = to_pixel_coords(keypoints, _H1, _W1)

        feats = self.interpolator(M1, keypoints, H=_H1, W=_W1)
        feats = F.normalize(feats, dim=-1)

        # Dense keypoints sampled from the border-filtered matchability map.
        H1 = self.filter(H1)
        dense_kpts, dense_scores, inds = self.sample_dense_kpts(H1, n_limit=n_limit)
        dense_keypoints = to_pixel_coords(dense_kpts, _H1, _W1)
        dense_feats = self.interpolator(M1, dense_keypoints, H=_H1, W=_W1)
        dense_feats = F.normalize(dense_feats, dim=-1)

        # Correct kpt scale back to the original image resolution.
        keypoints = keypoints * torch.tensor([rw1, rh1], device=keypoints.device).view(1, -1)
        dense_keypoints = dense_keypoints * torch.tensor([rw1, rh1], device=dense_keypoints.device).view(1, -1)

        valid = kptscores > self.detection_threshold
        valid_dense = dense_scores > thr
        return [
            {'keypoints': keypoints[b][valid[b]],
             'scores': kptscores[b][valid[b]],
             'descriptors': feats[b][valid[b]],
             'keypoints_dense': dense_keypoints[b][valid_dense[b]],
             'scores_dense': dense_scores[b][valid_dense[b]],
             'descriptors_dense': dense_feats[b][valid_dense[b]]} for b in range(B)
        ]

    def sample_dense_kpts(self, keypoint_logits, threshold=0.01, n_limit=30000, force_kpts=True):
        B, K, H, W = keypoint_logits.shape

        if n_limit < 0 or n_limit > H * W:
            n_limit = min(H * W - 1, n_limit)

        scoremap = keypoint_logits.permute(0, 2, 3, 1)
        scoremap = scoremap.reshape(B, H, W)

        # Zero out a one-pixel border so no keypoints are sampled there.
        frame = torch.zeros(B, H, W, device=keypoint_logits.device)
        frame[:, 1:-1, 1:-1] = 1
        scoremap = scoremap * frame
        scoremap = scoremap.reshape(B, H * W)

        grid = self.get_grid(B, H, W, device=keypoint_logits.device)

        # Take the n_limit highest-scoring locations.
        inds = torch.topk(scoremap, n_limit, dim=1).indices
        # inds = torch.multinomial(scoremap, top_k, replacement=False)
        kpts = torch.gather(grid, 1, inds[..., None].expand(B, n_limit, 2))
        scoremap = torch.gather(scoremap, 1, inds)

        if force_kpts:
            # Drop low-scoring keypoints (note: this flattening assumes B == 1).
            valid = scoremap > threshold
            kpts = kpts[valid][None]
            scoremap = scoremap[valid][None]

        return kpts, scoremap, inds

    def preprocess_tensor(self, x):
        """ Guarantee that the image size is divisible by 32 to avoid aliasing artifacts. """
        if isinstance(x, np.ndarray) and len(x.shape) == 3:
            x = torch.tensor(x).permute(2, 0, 1)[None]
        x = x.to(self.device).float()

        H, W = x.shape[-2:]
        _H, _W = (H // 32) * 32, (W // 32) * 32
        rh, rw = H / _H, W / _W

        x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
        return x, rh, rw

    def get_grid(self, B, H, W, device=None):
        # Normalized (x, y) coordinates of every pixel centre, shaped (B, H*W, 2).
        x1_n = torch.meshgrid(
            *[
                torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
                for n in (B, H, W)
            ],
            indexing='ij'
        )
        x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
        return x1_n


def build(config=None, weights=None):
    if config is None:
        config = read_config('./configs/default.yaml')
    if weights is not None:
        config['weights'] = weights

    device = torch.device(config['device'])
    print('config', config)

    detector = build_detector(config)
    descriptor = build_descriptor(config)

    model = RDD(
        detector,
        descriptor,
        detection_threshold=config['detection_threshold'],
        top_k=config['top_k'],
        train_detector=config['train_detector'],
        device=device
    )

    if 'weights' in config and config['weights'] is not None:
        model.load_state_dict(torch.load(config['weights'], map_location='cpu'))

    model.to(device)
    return model
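

# Minimal usage sketch: assumes ./configs/default.yaml defines at least 'device',
# 'detection_threshold', 'top_k' and 'train_detector' as read by build() above;
# the input size below is arbitrary (any H, W divisible by 32).
if __name__ == '__main__':
    model = build()
    model.eval()

    # One dummy RGB image, values in [0, 1], shaped (B, 3, H, W).
    img = torch.rand(1, 3, 480, 640, device=model.device)

    with torch.no_grad():
        preds = model.extract(img)

    print(preds[0]['keypoints'].shape,    # (N, 2) keypoint coordinates
          preds[0]['scores'].shape,       # (N,)   detection scores
          preds[0]['descriptors'].shape)  # (N, C) L2-normalized descriptors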