from pathlib import Path
from types import SimpleNamespace
import warnings
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from typing import Optional, List, Callable

try:
    from flash_attn.modules.mha import FlashCrossAttention
except ModuleNotFoundError:
    FlashCrossAttention = None

if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
    FLASH_AVAILABLE = True
else:
    FLASH_AVAILABLE = False

torch.backends.cudnn.deterministic = True


def normalize_keypoints(kpts: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
    if isinstance(size, torch.Size):
        size = torch.tensor(size)[None]
    shift = size.float().to(kpts) / 2
    scale = size.max(1).values.float().to(kpts) / 2
    kpts = (kpts - shift[:, None]) / scale[:, None, None]
    return kpts
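
# Worked example: for a 640x480 image both axes are shifted by half the image
# size and scaled by half the longer side (320), so the image center maps to
# (0, 0) and the top-left corner to (-1, -0.75):
#   kpts = torch.tensor([[[320.0, 240.0], [0.0, 0.0]]])  # [B=1, M=2, 2], (x, y)
#   size = torch.tensor([[640.0, 480.0]])                # [B=1, 2], (W, H)
#   normalize_keypoints(kpts, size)  # -> [[[0.0, 0.0], [-1.0, -0.75]]]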


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x = x.unflatten(-1, (-1, 2))
    x1, x2 = x.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)


def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return (t * freqs[0]) + (rotate_half(t) * freqs[1])
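
# Note: freqs[0] holds cosines and freqs[1] holds sines, each repeated per
# channel pair, so for every pair of channels (t1, t2) the result is the 2-D
# rotation (t1*cos - t2*sin, t1*sin + t2*cos); this is the standard rotary
# positional embedding applied to cached angles.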


class LearnableFourierPositionalEncoding(nn.Module):
    def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
        super().__init__()
        F_dim = F_dim if F_dim is not None else dim
        self.gamma = gamma
        self.Wr = nn.Linear(M, F_dim // 2, bias=False)
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """encode position vector"""
        projected = self.Wr(x)
        cosines, sines = torch.cos(projected), torch.sin(projected)
        emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
        return emb.repeat_interleave(2, dim=-1)
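
# Shape sketch (assuming 2-D keypoints, i.e. M=2, and F_dim equal to the
# attention head dimension): an input of shape [B, N, 2] is projected to
# [B, N, F_dim // 2] and the returned tensor has shape [2, B, 1, N, F_dim],
# stacking (cosines, sines) so it broadcasts over attention heads inside
# apply_cached_rotary_emb:
#   posenc = LearnableFourierPositionalEncoding(2, 64, 64)
#   posenc(torch.rand(1, 512, 2)).shape  # -> torch.Size([2, 1, 1, 512, 64])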


class TokenConfidence(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """get confidence tokens"""
        return (
            self.token(desc0.detach().float()).squeeze(-1),
            self.token(desc1.detach().float()).squeeze(-1),
        )


class Attention(nn.Module):
    def __init__(self, allow_flash: bool) -> None:
        super().__init__()
        if allow_flash and not FLASH_AVAILABLE:
            warnings.warn(
                "FlashAttention is not available. For optimal speed, "
                "consider installing torch >= 2.0 or flash-attn.",
                stacklevel=2,
            )
        self.enable_flash = allow_flash and FLASH_AVAILABLE
        if allow_flash and FlashCrossAttention:
            self.flash_ = FlashCrossAttention()

    def forward(self, q, k, v) -> torch.Tensor:
        if self.enable_flash and q.device.type == "cuda":
            if FlashCrossAttention:
                q, k, v = [x.transpose(-2, -3) for x in [q, k, v]]
                m = self.flash_(q.half(), torch.stack([k, v], 2).half())
                return m.transpose(-2, -3).to(q.dtype)
            else:  # use torch 2.0 scaled_dot_product_attention with flash
                args = [x.half().contiguous() for x in [q, k, v]]
                with torch.backends.cuda.sdp_kernel(enable_flash=True):
                    return F.scaled_dot_product_attention(*args).to(q.dtype)
        elif hasattr(F, "scaled_dot_product_attention"):
            args = [x.contiguous() for x in [q, k, v]]
            return F.scaled_dot_product_attention(*args).to(q.dtype)
        else:
            s = q.shape[-1] ** -0.5
            attn = F.softmax(torch.einsum("...id,...jd->...ij", q, k) * s, -1)
            return torch.einsum("...ij,...jd->...id", attn, v)


class Transformer(nn.Module):
    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0
        self.head_dim = self.embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.inner_attn = Attention(flash)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def _forward(self, x: torch.Tensor, encoding: Optional[torch.Tensor] = None):
        qkv = self.Wqkv(x)
        qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
        q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
        if encoding is not None:
            q = apply_cached_rotary_emb(encoding, q)
            k = apply_cached_rotary_emb(encoding, k)
        context = self.inner_attn(q, k, v)
        message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
        return x + self.ffn(torch.cat([x, message], -1))

    def forward(self, x0, x1, encoding0=None, encoding1=None):
        return self._forward(x0, encoding0), self._forward(x1, encoding1)


class CrossTransformer(nn.Module):
    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        super().__init__()
        self.heads = num_heads
        dim_head = embed_dim // num_heads
        self.scale = dim_head**-0.5
        inner_dim = dim_head * num_heads
        self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )
        if flash and FLASH_AVAILABLE:
            self.flash = Attention(True)
        else:
            self.flash = None

    def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
        return func(x0), func(x1)

    def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> List[torch.Tensor]:
        qk0, qk1 = self.map_(self.to_qk, x0, x1)
        v0, v1 = self.map_(self.to_v, x0, x1)
        qk0, qk1, v0, v1 = map(
            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
            (qk0, qk1, v0, v1),
        )
        if self.flash is not None:
            m0 = self.flash(qk0, qk1, v1)
            m1 = self.flash(qk1, qk0, v0)
        else:
            qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
            sim = torch.einsum("b h i d, b h j d -> b h i j", qk0, qk1)
            attn01 = F.softmax(sim, dim=-1)
            attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
            m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
            m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
        m0, m1 = self.map_(self.to_out, m0, m1)
        x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
        x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
        return x0, x1
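
# Note: Transformer is the self-attention unit (rotary-encoded, one image at a
# time), while CrossTransformer exchanges messages between the two images. In
# the non-flash branch a single projection serves as both query and key, and
# one similarity matrix is softmax-normalized along each direction to produce
# both messages. Usage sketch (assumed sizes: embed_dim=256, 4 heads):
#   self_layer, cross_layer = Transformer(256, 4), CrossTransformer(256, 4)
#   d0, d1 = torch.rand(1, 512, 256), torch.rand(1, 400, 256)
#   d0, d1 = self_layer(d0, d1)   # positional encodings are optional here
#   d0, d1 = cross_layer(d0, d1)  # shapes preserved: [1, 512, 256], [1, 400, 256]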


def sigmoid_log_double_softmax(
    sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
) -> torch.Tensor:
    """create the log assignment matrix from logits and similarity"""
    b, m, n = sim.shape
    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
    scores0 = F.log_softmax(sim, 2)
    scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
    scores = sim.new_full((b, m + 1, n + 1), 0)
    scores[:, :m, :n] = scores0 + scores1 + certainties
    scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
    scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
    return scores
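
# Written out, the inner M x N block of the returned [B, M+1, N+1] matrix is
#   log P_ij = log sigmoid(z0_i) + log sigmoid(z1_j)
#              + log softmax_j(sim_i,:) + log softmax_i(sim_:,j),
# while the last column stores log sigmoid(-z0_i) and the last row stores
# log sigmoid(-z1_j), i.e. the log-probability that a keypoint is unmatched.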


class MatchAssignment(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        self.matchability = nn.Linear(dim, 1, bias=True)
        self.final_proj = nn.Linear(dim, dim, bias=True)

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """build assignment matrix from descriptors"""
        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
        _, _, d = mdesc0.shape
        mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
        sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
        z0 = self.matchability(desc0)
        z1 = self.matchability(desc1)
        scores = sigmoid_log_double_softmax(sim, z0, z1)
        return scores, sim

    def scores(self, desc0: torch.Tensor, desc1: torch.Tensor):
        m0 = torch.sigmoid(self.matchability(desc0)).squeeze(-1)
        m1 = torch.sigmoid(self.matchability(desc1)).squeeze(-1)
        return m0, m1


def filter_matches(scores: torch.Tensor, th: float):
    """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
    max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
    m0, m1 = max0.indices, max1.indices
    mutual0 = torch.arange(m0.shape[1]).to(m0)[None] == m1.gather(1, m0)
    mutual1 = torch.arange(m1.shape[1]).to(m1)[None] == m0.gather(1, m1)
    max0_exp = max0.values.exp()
    zero = max0_exp.new_tensor(0)
    mscores0 = torch.where(mutual0, max0_exp, zero)
    mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
    if th is not None:
        valid0 = mutual0 & (mscores0 > th)
    else:
        valid0 = mutual0
    valid1 = mutual1 & valid0.gather(1, m1)
    m0 = torch.where(valid0, m0, m0.new_tensor(-1))
    m1 = torch.where(valid1, m1, m1.new_tensor(-1))
    return m0, m1, mscores0, mscores1
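
# Output semantics (derived from the code above): m0[b, i] is the index of the
# keypoint in image 1 matched to keypoint i of image 0, or -1 if there is no
# mutual nearest neighbour in the assignment matrix or if its confidence
# mscores0[b, i] (the exponentiated log-assignment score) falls below `th`;
# m1 and mscores1 are the symmetric quantities for image 1.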


class LightGlue(nn.Module):
    default_conf = {
        "name": "lightglue",  # just for interfacing
        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 256,
        "n_layers": 9,
        "num_heads": 4,
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
        "depth_confidence": 0.95,  # early stopping, disable with -1
        "width_confidence": 0.99,  # point pruning, disable with -1
        "filter_threshold": 0.1,  # match threshold
        "weights": None,
    }

    required_data_keys = ["image0", "image1"]

    version = "v0.1_arxiv"
    url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"

    features = {
        "superpoint": ("superpoint_lightglue", 256),
        "disk": ("disk_lightglue", 128),
    }

    def __init__(self, features="superpoint", **conf) -> None:
        super().__init__()
        self.conf = {**self.default_conf, **conf}
        if features is not None:
            assert features in list(self.features.keys())
            self.conf["weights"], self.conf["input_dim"] = self.features[features]
        self.conf = conf = SimpleNamespace(**self.conf)

        if conf.input_dim != conf.descriptor_dim:
            self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()

        head_dim = conf.descriptor_dim // conf.num_heads
        self.posenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim)

        h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim

        self.self_attn = nn.ModuleList(
            [Transformer(d, h, conf.flash) for _ in range(n)]
        )
        self.cross_attn = nn.ModuleList(
            [CrossTransformer(d, h, conf.flash) for _ in range(n)]
        )
        self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
        self.token_confidence = nn.ModuleList(
            [TokenConfidence(d) for _ in range(n - 1)]
        )

        if features is not None:
            fname = f"{conf.weights}_{self.version}.pth".replace(".", "-")
            state_dict = torch.hub.load_state_dict_from_url(
                self.url.format(self.version, features), file_name=fname
            )
            self.load_state_dict(state_dict, strict=False)
        elif conf.weights is not None:
            path = Path(__file__).parent
            path = path / "weights/{}.pth".format(self.conf.weights)
            state_dict = torch.load(str(path), map_location="cpu")
            self.load_state_dict(state_dict, strict=False)

        print("Loaded LightGlue model")

    def forward(self, data: dict) -> dict:
        """
        Match keypoints and descriptors between two images

        Input (dict):
            image0: dict
                keypoints: [B x M x 2]
                descriptors: [B x M x D]
                image: [B x C x H x W] or image_size: [B x 2]
            image1: dict
                keypoints: [B x N x 2]
                descriptors: [B x N x D]
                image: [B x C x H x W] or image_size: [B x 2]
        Output (dict):
            log_assignment: [B x M+1 x N+1]
            matches0: [B x M]
            matching_scores0: [B x M]
            matches1: [B x N]
            matching_scores1: [B x N]
            matches: List[[Si x 2]], scores: List[[Si]]
        """
        with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
            return self._forward(data)

    def _forward(self, data: dict) -> dict:
        for key in self.required_data_keys:
            assert key in data, f"Missing key {key} in data"

        data0, data1 = data["image0"], data["image1"]
        kpts0_, kpts1_ = data0["keypoints"], data1["keypoints"]
        b, m, _ = kpts0_.shape
        b, n, _ = kpts1_.shape

        size0, size1 = data0.get("image_size"), data1.get("image_size")
        size0 = size0 if size0 is not None else data0["image"].shape[-2:][::-1]
        size1 = size1 if size1 is not None else data1["image"].shape[-2:][::-1]
        kpts0 = normalize_keypoints(kpts0_, size=size0)
        kpts1 = normalize_keypoints(kpts1_, size=size1)

        assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
        assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)

        desc0 = data0["descriptors"].detach()
        desc1 = data1["descriptors"].detach()

        assert desc0.shape[-1] == self.conf.input_dim
        assert desc1.shape[-1] == self.conf.input_dim

        if torch.is_autocast_enabled():
            desc0 = desc0.half()
            desc1 = desc1.half()

        desc0 = self.input_proj(desc0)
        desc1 = self.input_proj(desc1)

        # cache positional embeddings
        encoding0 = self.posenc(kpts0)
        encoding1 = self.posenc(kpts1)

        # GNN + final_proj + assignment
        ind0 = torch.arange(0, m).to(device=kpts0.device)[None]
        ind1 = torch.arange(0, n).to(device=kpts0.device)[None]
        prune0 = torch.ones_like(ind0)  # store layer where pruning is detected
        prune1 = torch.ones_like(ind1)
        dec, wic = self.conf.depth_confidence, self.conf.width_confidence
        token0, token1 = None, None
        for i in range(self.conf.n_layers):
            # self+cross attention
            desc0, desc1 = self.self_attn[i](desc0, desc1, encoding0, encoding1)
            desc0, desc1 = self.cross_attn[i](desc0, desc1)
            if i == self.conf.n_layers - 1:
                continue  # no early stopping or adaptive width at last layer
            if dec > 0:  # early stopping
                token0, token1 = self.token_confidence[i](desc0, desc1)
                if self.stop(token0, token1, self.conf_th(i), dec, m + n):
                    break
            if wic > 0:  # point pruning
                match0, match1 = self.log_assignment[i].scores(desc0, desc1)
                mask0 = self.get_mask(token0, match0, self.conf_th(i), 1 - wic)
                mask1 = self.get_mask(token1, match1, self.conf_th(i), 1 - wic)
                ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
                desc0, desc1 = desc0[mask0][None], desc1[mask1][None]
                if desc0.shape[-2] == 0 or desc1.shape[-2] == 0:
                    break
                encoding0 = encoding0[:, :, mask0][:, None]
                encoding1 = encoding1[:, :, mask1][:, None]
            prune0[:, ind0] += 1
            prune1[:, ind1] += 1

        if wic > 0:  # scatter with indices after pruning
            scores_, _ = self.log_assignment[i](desc0, desc1)
            dt, dev = scores_.dtype, scores_.device
            scores = torch.zeros(b, m + 1, n + 1, dtype=dt, device=dev)
            scores[:, :-1, :-1] = -torch.inf
            scores[:, ind0[0], -1] = scores_[:, :-1, -1]
            scores[:, -1, ind1[0]] = scores_[:, -1, :-1]
            x, y = torch.meshgrid(ind0[0], ind1[0], indexing="ij")
            scores[:, x, y] = scores_[:, :-1, :-1]
        else:
            scores, _ = self.log_assignment[i](desc0, desc1)

        m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)

        matches, mscores = [], []
        for k in range(b):
            valid = m0[k] > -1
            matches.append(torch.stack([torch.where(valid)[0], m0[k][valid]], -1))
            mscores.append(mscores0[k][valid])

        return {
            "log_assignment": scores,
            "matches0": m0,
            "matches1": m1,
            "matching_scores0": mscores0,
            "matching_scores1": mscores1,
            "stop": i + 1,
            "prune0": prune0,
            "prune1": prune1,
            "matches": matches,
            "scores": mscores,
        }

    def conf_th(self, i: int) -> float:
        """scaled confidence threshold"""
        return np.clip(0.8 + 0.1 * np.exp(-4.0 * i / self.conf.n_layers), 0, 1)

    def get_mask(
        self,
        confidence: torch.Tensor,
        match: torch.Tensor,
        conf_th: float,
        match_th: float,
    ) -> torch.Tensor:
        """mask points which should be removed"""
        if conf_th and confidence is not None:
            mask = (
                torch.where(confidence > conf_th, match, match.new_tensor(1.0))
                > match_th
            )
        else:
            mask = match > match_th
        return mask

    def stop(
        self,
        token0: torch.Tensor,
        token1: torch.Tensor,
        conf_th: float,
        inl_th: float,
        seql: int,
    ) -> torch.Tensor:
        """evaluate stopping condition"""
        tokens = torch.cat([token0, token1], -1)
        if conf_th:
            pos = 1.0 - (tokens < conf_th).float().sum() / seql
            return pos > inl_th
        else:
            return tokens.mean() > inl_th
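

# ----------------------------------------------------------------------------
# Minimal usage sketch (assumptions: the random keypoints and descriptors below
# stand in for a real local feature extractor such as SuperPoint, and
# `features=None` keeps the weights random so no download happens; the matches
# are therefore not meaningful, this only exercises the documented interface).
if __name__ == "__main__":
    torch.manual_seed(0)
    B, M, N, D = 1, 128, 100, 256  # batch, keypoints in image 0 / image 1, descriptor dim
    W, H = 640, 480
    data = {
        "image0": {
            "keypoints": torch.rand(B, M, 2) * torch.tensor([W, H], dtype=torch.float),
            "descriptors": torch.randn(B, M, D),
            "image_size": torch.tensor([[W, H]], dtype=torch.float),
        },
        "image1": {
            "keypoints": torch.rand(B, N, 2) * torch.tensor([W, H], dtype=torch.float),
            "descriptors": torch.randn(B, N, D),
            "image_size": torch.tensor([[W, H]], dtype=torch.float),
        },
    }
    # disable early stopping and point pruning for a deterministic code path
    matcher = LightGlue(features=None, depth_confidence=-1, width_confidence=-1).eval()
    with torch.no_grad():
        pred = matcher(data)
    print(pred["matches0"].shape, pred["log_assignment"].shape)  # [1, 128], [1, 129, 101]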