"""
ALPModule
"""
import torch
from torch import nn
from torch.nn import functional as F
# for unit test: from spatial_similarity_module import NONLocalBlock2D, LayerNorm
def safe_norm(x, p=2, dim=1, eps=1e-4):
    """L2-normalize x along `dim`, clamping the norm away from zero."""
    x_norm = torch.norm(x, p=p, dim=dim)
    x_norm = torch.max(x_norm, torch.ones_like(x_norm) * eps)  # device-agnostic (no hard-coded .cuda())
    x = x.div(x_norm.unsqueeze(1).expand_as(x))
    return x
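# Usage sketch (illustrative shapes, not from the original source):
#   protos = torch.randn(5, 768)   # npro x C prototype vectors
#   protos = safe_norm(protos)     # each row now has (near-)unit L2 norm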
class MultiProtoAsConv(nn.Module):
    def __init__(self, proto_grid, feature_hw, embed_dim=768, use_attention=False, upsample_mode='bilinear'):
        """
        ALPModule
        Args:
            proto_grid: grid size used for multi-prototyping. For a 32-by-32 feature map,
                a 16-by-16 grid leads to a pooling window of 2-by-2
            feature_hw: spatial size (h, w) of the input feature map
            embed_dim: channel dimension of the features, used by the optional attention heads
            use_attention: if True, build the prototype attention and mask-projection heads
            upsample_mode: interpolation mode, kept for upsampling (not used in this file)
        """
        super(MultiProtoAsConv, self).__init__()
        self.feature_hw = feature_hw
        self.proto_grid = proto_grid
        self.upsample_mode = upsample_mode
        kernel_size = [ft_l // grid_l for ft_l, grid_l in zip(feature_hw, proto_grid)]
        self.kernel_size = kernel_size
        print(f"MultiProtoAsConv: kernel_size: {kernel_size}")
        self.avg_pool_op = nn.AvgPool2d(kernel_size)
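        # e.g. feature_hw=(32, 32) with proto_grid=(16, 16) gives kernel_size=[2, 2],
        # so each local prototype summarizes a 2-by-2 window of the feature map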
        if use_attention:
            self.proto_fg_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True)
            self.proto_bg_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True)
            self.fg_mask_projection = nn.Sequential(
                nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True),
            )
            self.bg_mask_projection = nn.Sequential(
                nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True),
            )
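            # Note: these attention and mask-projection heads are not referenced
            # elsewhere in this file; they are presumably consumed by an external caller.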
    def get_prediction_from_prototypes(self, prototypes, query, mode, vis_sim=False):
        if mode == 'mask':
            pred_mask = F.cosine_similarity(query, prototypes[..., None, None], dim=1, eps=1e-4) * 20.0  # [npro, h, w]
            # in case there is more than one prototype, take the max at each location
            pred_mask = pred_mask.max(dim=0)[0].unsqueeze(0)
            vis_dict = {'proto_assign': pred_mask}  # things to visualize
            if vis_sim:
                vis_dict['raw_local_sims'] = pred_mask
            return pred_mask.unsqueeze(1), [pred_mask], vis_dict  # pred_mask returned as [1, way(1), h, w]
        elif mode == 'gridconv':
            dists = F.conv2d(query, prototypes[..., None, None]) * 20
            pred_grid = torch.sum(F.softmax(dists, dim=1) * dists, dim=1, keepdim=True)
            debug_assign = dists.argmax(dim=1).float().detach()
            vis_dict = {'proto_assign': debug_assign}  # things to visualize
            if vis_sim:  # return the similarity for visualization
                vis_dict['raw_local_sims'] = dists.clone().detach()
            return pred_grid, [debug_assign], vis_dict
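        # In the 'gridconv' branch above (and 'gridconv+' below), the 1x1 convolution
        # is a per-location dot product: with L2-normalized query features and
        # prototypes it yields a cosine-similarity map per prototype, and the
        # softmax-weighted sum fuses those maps into one score map.
        # Shape sketch (sizes are illustrative):
        #   query [1, C, h, w] x prototypes [npro, C, 1, 1] -> dists [1, npro, h, w]
        #   pred_grid: [1, 1, h, w]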
        elif mode == 'gridconv+':
            dists = F.conv2d(query, prototypes[..., None, None]) * 20
            pred_grid = torch.sum(F.softmax(dists, dim=1) * dists, dim=1, keepdim=True)
            debug_assign = dists.argmax(dim=1).float()
            vis_dict = {'proto_assign': debug_assign}
            if vis_sim:
                vis_dict['raw_local_sims'] = dists.clone().detach()
            return pred_grid, [debug_assign], vis_dict
        else:
            raise ValueError(f"Invalid mode: {mode}. Expected 'mask', 'gridconv', or 'gridconv+'.")
    def get_prototypes(self, sup_x, sup_y, mode, val_wsize, thresh, isval=False):
        if mode == 'mask':
            proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
                / (sup_y.sum(dim=(-1, -2)) + 1e-5)  # nb x C, masked average pooling per shot
            pro_n = proto  # one prototype per shot; the max over prototypes is taken downstream
            proto_grid = sup_y.clone().detach()  # a single prototype for the whole image
            resized_proto_grid = proto_grid
            non_zero = torch.nonzero(proto_grid)
        elif mode == 'gridconv':
            nch = sup_x.shape[1]
            sup_nshot = sup_x.shape[0]
            n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op(sup_x)
            n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0, 2, 1).unsqueeze(0)  # way(1), nb, hw, nc
            n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0)
            sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y)
            # get a grid of prototypes
            proto_grid = sup_y_g.clone().detach()
            proto_grid[proto_grid < thresh] = 0
            # paint each surviving pooled cell back onto a full-resolution grid
            non_zero = torch.nonzero(proto_grid)
            resized_proto_grid = torch.zeros(
                [1, 1, proto_grid.shape[2] * val_wsize, proto_grid.shape[3] * val_wsize],
                device=proto_grid.device)
            for index in non_zero:
                resized_proto_grid[0, 0,
                                   index[2] * val_wsize: (index[2] + 1) * val_wsize,
                                   index[3] * val_wsize: (index[3] + 1) * val_wsize] = proto_grid[0, 0, index[2], index[3]]
            sup_y_g = sup_y_g.view(sup_nshot, 1, -1).permute(1, 0, 2).view(1, -1).unsqueeze(0)
            protos = n_sup_x[sup_y_g > thresh, :]  # npro, nc
            pro_n = safe_norm(protos)
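        # Shape bookkeeping for the branch above (sizes are illustrative): pooled
        # support features become [1, 1, nshot*h'*w', C] and the pooled mask
        # [1, 1, nshot*h'*w']; boolean indexing with `sup_y_g > thresh` keeps one
        # C-dim vector per confident grid cell, giving protos of shape [npro, C].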
        elif mode == 'gridconv+':
            nch = sup_x.shape[1]
            n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op(sup_x)
            sup_nshot = sup_x.shape[0]
            n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0, 2, 1).unsqueeze(0)
            n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0)
            sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y)
            # get a grid of prototypes
            proto_grid = sup_y_g.clone().detach()
            proto_grid[proto_grid < thresh] = 0
            non_zero = torch.nonzero(proto_grid)
            # relabel each kept cell with its (1-based) prototype index for visualization
            for i, idx in enumerate(non_zero):
                proto_grid[0, idx[1], idx[2], idx[3]] = i + 1
            resized_proto_grid = torch.zeros(
                [1, 1, proto_grid.shape[2] * val_wsize, proto_grid.shape[3] * val_wsize],
                device=proto_grid.device)
            for index in non_zero:
                resized_proto_grid[0, 0,
                                   index[2] * val_wsize: (index[2] + 1) * val_wsize,
                                   index[3] * val_wsize: (index[3] + 1) * val_wsize] = proto_grid[0, 0, index[2], index[3]]
            sup_y_g = sup_y_g.view(sup_nshot, 1, -1).permute(1, 0, 2).view(1, -1).unsqueeze(0)
            protos = n_sup_x[sup_y_g > thresh, :]
            # append the image-level (masked-average) prototype to the local ones
            glb_proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
                / (sup_y.sum(dim=(-1, -2)) + 1e-5)
            pro_n = safe_norm(torch.cat([protos, glb_proto], dim=0))
        else:
            raise ValueError(f"Invalid mode: {mode}. Expected 'mask', 'gridconv', or 'gridconv+'.")
        return pro_n, resized_proto_grid, non_zero
    def forward(self, qry, sup_x, sup_y, mode, thresh, isval=False, val_wsize=None, vis_sim=False, get_prototypes=False, **kwargs):
        """
        Args (original layout):
            mode: 'mask' / 'gridconv' / 'gridconv+'. If 'mask', works as the original prototyping
            qry: [way(1), nc, h, w]
            sup_x: [nb, nc, h, w]
            sup_y: [nb, 1, h, w]
            vis_sim: whether to visualize raw similarities
        Args (new layout, squeezed below to the original one):
            qry: [way(1), nb(1), nc, h, w]
            sup_x: [way(1), shot, nb(1), nc, h, w]
            sup_y: [way(1), shot, nb(1), h, w]
        """
        qry = qry.squeeze(1)  # [way(1), nb(1), nc, h, w] -> [way(1), nc, h, w]
        sup_x = sup_x.squeeze(0).squeeze(1)  # [nshot, nc, h, w]
        sup_y = sup_y.squeeze(0)  # [nshot, 1, h, w]
        if val_wsize is None:
            # default to the training-time pooling window (assumes a square kernel)
            val_wsize = self.avg_pool_op.kernel_size
            if isinstance(val_wsize, (tuple, list)):
                val_wsize = val_wsize[0]
        sup_y = sup_y.reshape(sup_x.shape[0], 1, sup_x.shape[-2], sup_x.shape[-1])
        pro_n, proto_grid, proto_indices = self.get_prototypes(sup_x, sup_y, mode, val_wsize, thresh, isval)
        if 0 in pro_n.shape:
            print("Warning: no prototype passed the threshold")
        qry_n = qry if mode == 'mask' else safe_norm(qry)
        pred_grid, debug_assign, vis_dict = self.get_prediction_from_prototypes(pro_n, qry_n, mode, vis_sim=vis_sim)
        return pred_grid, debug_assign, vis_dict, proto_grid
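

if __name__ == "__main__":
    # Minimal CPU smoke test; a sketch only. All sizes, the 0.95 threshold, and
    # the square foreground mask are illustrative assumptions, not from the source.
    way, shot, nb, nc, h, w = 1, 1, 1, 256, 32, 32
    module = MultiProtoAsConv(proto_grid=(8, 8), feature_hw=(h, w), embed_dim=nc)
    qry = torch.randn(way, nb, nc, h, w)
    sup_x = torch.randn(way, shot, nb, nc, h, w)
    sup_y = torch.zeros(way, shot, nb, h, w)
    sup_y[..., 8:24, 8:24] = 1.0  # square foreground so some pooled cells pass the threshold
    pred, assign, vis, grid = module(qry, sup_x, sup_y, mode='gridconv', thresh=0.95)
    print(pred.shape, grid.shape)  # expected: torch.Size([1, 1, 32, 32]) for both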