""" | |
ALPModule | |
""" | |
import torch | |
import time | |
import math | |
from torch import nn | |
from torch.nn import functional as F | |
import numpy as np | |
from pdb import set_trace | |
import matplotlib.pyplot as plt | |
# for unit test from spatial_similarity_module import NONLocalBlock2D, LayerNorm | |
def safe_norm(x, p = 2, dim = 1, eps = 1e-4): | |
x_norm = torch.norm(x, p = p, dim = dim) # .detach() | |
x_norm = torch.max(x_norm, torch.ones_like(x_norm).cuda() * eps) | |
x = x.div(x_norm.unsqueeze(1).expand_as(x)) | |
return x | |
class MultiProtoAsConv(nn.Module):
    def __init__(self, proto_grid, feature_hw, embed_dim=768, use_attention=False, upsample_mode='bilinear'):
        """
        ALPModule
        Args:
            proto_grid: grid size when doing multi-prototyping. For a 32-by-32 feature map,
                        a size of 16-by-16 leads to a pooling window of 2-by-2
            feature_hw: spatial size of the input feature map
        """
        super(MultiProtoAsConv, self).__init__()
        self.feature_hw = feature_hw
        self.proto_grid = proto_grid
        self.upsample_mode = upsample_mode
        kernel_size = [ft_l // grid_l for ft_l, grid_l in zip(feature_hw, proto_grid)]
        self.kernel_size = kernel_size
        print(f"MultiProtoAsConv: kernel_size: {kernel_size}")
        self.avg_pool_op = nn.AvgPool2d(kernel_size)
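        # Worked example of the docstring above (illustrative numbers only):
        # feature_hw = (32, 32) with proto_grid = (16, 16) gives kernel_size = [2, 2],
        # so the AvgPool2d collapses each 2x2 feature window into one local prototype candidate.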
        if use_attention:
            self.proto_fg_attnetion = nn.MultiheadAttention(
                embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True)
            self.proto_bg_attnetion = nn.MultiheadAttention(
                embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True)
            self.fg_mask_projection = nn.Sequential(
                nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True),
            )
            self.bg_mask_projection = nn.Sequential(
                nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True),
            )
    def get_prediction_from_prototypes(self, prototypes, query, mode, vis_sim=False):
        if mode == 'mask':
            pred_mask = F.cosine_similarity(query, prototypes[..., None, None], dim=1, eps=1e-4) * 20.0  # [1, h, w]
            # in case there is more than one prototype at the same location, take the max
            pred_mask = pred_mask.max(dim=0)[0].unsqueeze(0)
            vis_dict = {'proto_assign': pred_mask}  # things to visualize
            if vis_sim:
                vis_dict['raw_local_sims'] = pred_mask
            return pred_mask.unsqueeze(1), [pred_mask], vis_dict  # just a placeholder. pred_mask returned as [1, way(1), h, w]
        elif mode == 'gridconv':
            dists = F.conv2d(query, prototypes[..., None, None]) * 20
            pred_grid = torch.sum(F.softmax(dists, dim=1) * dists, dim=1, keepdim=True)
            debug_assign = dists.argmax(dim=1).float().detach()
            vis_dict = {'proto_assign': debug_assign}  # things to visualize
            if vis_sim:  # return the similarity for visualization
                vis_dict['raw_local_sims'] = dists.clone().detach()
            return pred_grid, [debug_assign], vis_dict
        elif mode == 'gridconv+':
            dists = F.conv2d(query, prototypes[..., None, None]) * 20
            pred_grid = torch.sum(F.softmax(dists, dim=1) * dists, dim=1, keepdim=True)
            # raw_local_sims = dists.detach()
            debug_assign = dists.argmax(dim=1).float()
            vis_dict = {'proto_assign': debug_assign}
            if vis_sim:
                vis_dict['raw_local_sims'] = dists.clone().detach()
            return pred_grid, [debug_assign], vis_dict
        else:
            raise ValueError(f"Invalid mode: {mode}. Expected 'mask', 'gridconv', or 'gridconv+'.")
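
    # Shape sketch for the prediction modes above (assuming a single query image, as in forward()):
    #   query:      [1, C, H, W]  -- L2-normalised features for 'gridconv' / 'gridconv+'
    #   prototypes: [P, C]        -- P local prototypes (plus one global prototype for 'gridconv+')
    #   dists:      [1, P, H, W]  -- 1x1 convolution with normalised weights, i.e. scaled cosine similarity
    #   pred_grid:  [1, 1, H, W]  -- softmax-over-P weighted sum of the similarity maps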
    def get_prototypes(self, sup_x, sup_y, mode, val_wsize, thresh, isval=False):
        if mode == 'mask':
            proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
                / (sup_y.sum(dim=(-1, -2)) + 1e-5)  # nb x C
            pro_n = proto.mean(dim=0, keepdim=True)  # 1 x C, take the mean of everything
            pro_n = proto  # NOTE: keeps the per-shot prototypes; the mean above is overwritten
            proto_grid = sup_y.clone().detach()  # a single prototype for the whole image
            resized_proto_grid = proto_grid
            non_zero = torch.nonzero(proto_grid)
        elif mode == 'gridconv':
            nch = sup_x.shape[1]
            sup_nshot = sup_x.shape[0]
            # if len(sup_x.shape) > 4:
            #     sup_x = sup_x.squeeze()
            n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op(sup_x)
            n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0, 2, 1).unsqueeze(0)  # way(1), nb, hw, nc
            n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0)
            sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y)
            # get a grid of prototypes
            proto_grid = sup_y_g.clone().detach()
            proto_grid[proto_grid < thresh] = 0
            # upsample the grid back to the original spatial size
            non_zero = torch.nonzero(proto_grid)
            resized_proto_grid = torch.zeros([1, 1, proto_grid.shape[2] * val_wsize, proto_grid.shape[3] * val_wsize])
            for index in non_zero:
                resized_proto_grid[0, 0,
                                   index[2] * val_wsize: index[2] * val_wsize + val_wsize,
                                   index[3] * val_wsize: index[3] * val_wsize + val_wsize] = proto_grid[0, 0, index[2], index[3]]
            sup_y_g = sup_y_g.view(sup_nshot, 1, -1).permute(1, 0, 2).view(1, -1).unsqueeze(0)
            protos = n_sup_x[sup_y_g > thresh, :]  # npro, nc
            pro_n = safe_norm(protos)
        elif mode == 'gridconv+':
            nch = sup_x.shape[1]
            n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op(sup_x)
            sup_nshot = sup_x.shape[0]
            n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0, 2, 1).unsqueeze(0)
            n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0)
            sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y)
            # get a grid of prototypes
            proto_grid = sup_y_g.clone().detach()
            proto_grid[proto_grid < thresh] = 0
            non_zero = torch.nonzero(proto_grid)
            for i, idx in enumerate(non_zero):
                proto_grid[0, idx[1], idx[2], idx[3]] = i + 1  # label each surviving cell with its prototype index
            resized_proto_grid = torch.zeros([1, 1, proto_grid.shape[2] * val_wsize, proto_grid.shape[3] * val_wsize])
            for index in non_zero:
                resized_proto_grid[0, 0,
                                   index[2] * val_wsize: index[2] * val_wsize + val_wsize,
                                   index[3] * val_wsize: index[3] * val_wsize + val_wsize] = proto_grid[0, 0, index[2], index[3]]
            sup_y_g = sup_y_g.view(sup_nshot, 1, -1).permute(1, 0, 2).view(1, -1).unsqueeze(0)
            protos = n_sup_x[sup_y_g > thresh, :]
            # global (masked-average) prototype, appended to the local ones
            glb_proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
                / (sup_y.sum(dim=(-1, -2)) + 1e-5)
            pro_n = safe_norm(torch.cat([protos, glb_proto], dim=0))
        return pro_n, resized_proto_grid, non_zero
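
    # Sketch of what get_prototypes returns (assuming way = 1, as used by forward below):
    #   pro_n:              prototype vectors [P, C]; L2-normalised for the 'gridconv' modes,
    #                       per-shot masked-average features for 'mask'
    #   resized_proto_grid: the thresholded pooled mask upsampled back to feature resolution
    #                       (relabelled with prototype indices for 'gridconv+'); the raw support mask for 'mask'
    #   non_zero:           indices of the grid cells that survived the threshold
    #                       (nonzero mask pixels for 'mask')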
    def forward(self, qry, sup_x, sup_y, mode, thresh, isval=False, val_wsize=None, vis_sim=False, get_prototypes=False, **kwargs):
        """
        Now supports
        Args:
            mode: 'mask' / 'gridconv' / 'gridconv+'. If 'mask', works as the original single-prototype model
            qry: [way(1), nc, h, w]
            sup_x: [nb, nc, h, w]
            sup_y: [nb, 1, h, w]
            vis_sim: visualize raw similarities or not
        New
            mode: 'mask' / 'gridconv' / 'gridconv+'. If 'mask', works as the original single-prototype model
            qry: [way(1), nb(1), nc, h, w]
            sup_x: [way(1), shot, nb(1), nc, h, w]
            sup_y: [way(1), shot, nb(1), h, w]
            vis_sim: visualize raw similarities or not
        (See the __main__ smoke test at the bottom of this file for an example call.)
        """
        qry = qry.squeeze(1)  # [way(1), nb(1), nc, h, w] -> [way(1), nc, h, w]
        sup_x = sup_x.squeeze(0).squeeze(1)  # [nshot, nc, h, w]
        sup_y = sup_y.squeeze(0)  # [nshot, 1, h, w]
        if val_wsize is None:
            val_wsize = self.avg_pool_op.kernel_size
            if isinstance(val_wsize, (tuple, list)):
                val_wsize = val_wsize[0]
        sup_y = sup_y.reshape(sup_x.shape[0], 1, sup_x.shape[-2], sup_x.shape[-1])
        pro_n, proto_grid, proto_indices = self.get_prototypes(sup_x, sup_y, mode, val_wsize, thresh, isval)
        if 0 in pro_n.shape:
            print("failed to find prototypes")
        qry_n = qry if mode == 'mask' else safe_norm(qry)
        pred_grid, debug_assign, vis_dict = self.get_prediction_from_prototypes(pro_n, qry_n, mode, vis_sim=vis_sim)
        return pred_grid, debug_assign, vis_dict, proto_grid
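

# ---------------------------------------------------------------------------
# Minimal CPU smoke test (a sketch, not part of the original training code).
# It only assumes the shapes documented in forward(); the feature size (32x32),
# prototype grid (8x8), channel count (16) and threshold (0.95) below are
# arbitrary illustrative values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    alp = MultiProtoAsConv(proto_grid=(8, 8), feature_hw=(32, 32))

    nc, h, w = 16, 32, 32
    qry = torch.randn(1, 1, nc, h, w)        # [way, nb, nc, h, w]
    sup_x = torch.randn(1, 1, 1, nc, h, w)   # [way, shot, nb, nc, h, w]
    sup_y = torch.zeros(1, 1, 1, h, w)       # [way, shot, nb, h, w]
    sup_y[..., 8:24, 8:24] = 1.0             # a square foreground mask

    pred, debug_assign, vis_dict, proto_grid = alp(
        qry, sup_x, sup_y, mode='gridconv+', thresh=0.95)

    print("pred:", pred.shape)               # expected: [1, 1, 32, 32]
    print("proto_grid:", proto_grid.shape)   # expected: [1, 1, 32, 32]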