Spaces:
Sleeping
Sleeping
File size: 9,786 Bytes
427d150 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
"""
ALPModule
"""
import torch
import time
import math
from torch import nn
from torch.nn import functional as F
import numpy as np
from pdb import set_trace
import matplotlib.pyplot as plt
# for unit test from spatial_similarity_module import NONLocalBlock2D, LayerNorm
def safe_norm(x, p=2, dim=1, eps=1e-4):
    """Normalize ``x`` to unit p-norm along ``dim`` without dividing by zero.

    Each vector along ``dim`` is divided by ``max(||v||_p, eps)``, so
    vectors whose norm is below ``eps`` are scaled by ``1 / eps`` instead of
    producing inf/nan (an all-zero vector stays all-zero).

    Args:
        x: Input tensor.
        p: Order of the norm.
        dim: Dimension along which the norm is computed.
        eps: Lower bound applied to the norm before dividing.

    Returns:
        A tensor with the same shape, dtype and device as ``x``.
    """
    # keepdim=True lets the division broadcast along whichever `dim` was
    # requested; clamp_min works on the tensor's own device, so no explicit
    # (and CPU-breaking) `.cuda()` transfer is needed.
    x_norm = torch.norm(x, p=p, dim=dim, keepdim=True)
    return x / x_norm.clamp_min(eps)
class MultiProtoAsConv(nn.Module):
    """ALPModule: builds prototypes from support features and scores a query.

    Prototypes are extracted from the support feature map / mask pair and the
    query feature map is compared against them with a scaled cosine
    similarity. Three modes are supported by the methods below:

    * ``'mask'``      - one masked-average prototype per support shot.
    * ``'gridconv'``  - one prototype per pooled grid cell whose pooled mask
      value exceeds ``thresh``.
    * ``'gridconv+'`` - the grid-cell prototypes plus the per-shot
      masked-average prototype.
    """

    # Factor applied to every similarity before it is returned / soft-maxed.
    SIM_SCALE = 20.0
    # Lower bound on the L2 norm used when normalising features/prototypes.
    NORM_EPS = 1e-4

    def __init__(self, proto_grid, feature_hw, embed_dim=768, use_attention=False, upsample_mode='bilinear'):
        """
        Args:
            proto_grid: Grid size when doing multi-prototyping. For a 32-by-32
                feature map, a size of 16-by-16 leads to a pooling window of
                2-by-2.
            feature_hw: Spatial size of input feature map.
            embed_dim: Channel count used by the optional attention layers
                and mask projection heads.
            use_attention: When true, the attention layers and projection
                heads are created.
            upsample_mode: Stored on the instance; not used inside this class.
        """
        super().__init__()
        self.feature_hw = feature_hw
        self.proto_grid = proto_grid
        self.upsample_mode = upsample_mode
        kernel_size = [ft_l // grid_l for ft_l, grid_l in zip(feature_hw, proto_grid)]
        self.kernel_size = kernel_size
        print(f"MultiProtoAsConv: kernel_size: {kernel_size}")
        self.avg_pool_op = nn.AvgPool2d(kernel_size)
        if use_attention:
            num_heads = 12 if embed_dim == 768 else 8
            # The misspelled attribute names are kept on purpose: they are
            # part of the state_dict keys of existing checkpoints.
            self.proto_fg_attnetion = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
            self.proto_bg_attnetion = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
            # NOTE(review): the projection heads are assumed to be gated by
            # `use_attention` together with the attention layers - confirm
            # against existing checkpoints.
            self.fg_mask_projection = self._build_mask_projection(embed_dim)
            self.bg_mask_projection = self._build_mask_projection(embed_dim)

    @staticmethod
    def _build_mask_projection(embed_dim):
        """Return the 1x1-conv stack mapping ``embed_dim`` channels to one."""
        return nn.Sequential(
            nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True),
        )

    def get_prediction_from_prototypes(self, prototypes, query, mode, vis_sim=False):
        """Score ``query`` against ``prototypes``.

        Args:
            prototypes: Prototype matrix, indexed as ``[n_proto, nc]``.
            query: Query feature map, indexed as ``[1, nc, h, w]``.
            mode: ``'mask'``, ``'gridconv'`` or ``'gridconv+'``.
            vis_sim: When true, the raw similarities are added to the
                visualisation dict under ``'raw_local_sims'``.

        Returns:
            ``(pred, [assign], vis_dict)`` where ``pred`` is the scaled
            similarity map, ``assign`` is the per-pixel prototype assignment
            (in ``'mask'`` mode it is just the similarity map, used as a
            placeholder) and ``vis_dict`` holds tensors for visualisation.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        if mode == 'mask':
            pred_mask = F.cosine_similarity(
                query, prototypes[..., None, None], dim=1, eps=self.NORM_EPS
            ) * self.SIM_SCALE  # [n_proto, h, w]
            # In case there is more than one prototype for the same
            # location, keep the strongest response.
            pred_mask = pred_mask.max(dim=0)[0].unsqueeze(0)
            vis_dict = {'proto_assign': pred_mask}
            if vis_sim:
                vis_dict['raw_local_sims'] = pred_mask
            # pred_mask is returned as [1, way(1), h, w].
            return pred_mask.unsqueeze(1), [pred_mask], vis_dict

        if mode in ('gridconv', 'gridconv+'):
            # A 1x1 convolution with the prototypes as filters is the dot
            # product of every query pixel with every prototype.
            dists = F.conv2d(query, prototypes[..., None, None]) * self.SIM_SCALE
            # Softmax-weighted sum of the similarities over prototypes.
            pred_grid = torch.sum(F.softmax(dists, dim=1) * dists, dim=1, keepdim=True)
            debug_assign = dists.argmax(dim=1).float().detach()
            vis_dict = {'proto_assign': debug_assign}
            if vis_sim:
                vis_dict['raw_local_sims'] = dists.clone().detach()
            return pred_grid, [debug_assign], vis_dict

        raise ValueError(f"Invalid mode: {mode}. Expected 'mask', 'gridconv', or 'gridconv+'.")

    @staticmethod
    def _resize_proto_grid(proto_grid, val_wsize):
        """Blow the pooled grid of shot 0 back up by ``val_wsize`` per axis.

        Every grid cell becomes a ``val_wsize x val_wsize`` block holding the
        cell's value. The result is allocated with ``torch.zeros`` defaults
        (CPU, default dtype), independent of where ``proto_grid`` lives.
        """
        cells = proto_grid[0, 0]
        enlarged = cells.repeat_interleave(val_wsize, dim=0).repeat_interleave(val_wsize, dim=1)
        resized = torch.zeros([1, 1, cells.shape[0] * val_wsize, cells.shape[1] * val_wsize])
        # copy_ takes care of any device / dtype difference.
        resized[0, 0].copy_(enlarged)
        return resized

    def _local_prototypes(self, sup_x, sup_y, val_wsize, thresh, isval):
        """Pool the support pair into grid cells and pick foreground cells.

        Returns:
            ``(protos, proto_grid, non_zero)``: the un-normalised features of
            the cells whose pooled mask value is ``> thresh`` (``[npro, nc]``),
            the pooled mask with values ``< thresh`` zeroed, and the indices
            of the non-zero entries of that grid.
        """
        nshot, nch = sup_x.shape[0], sup_x.shape[1]
        if isval:
            pooled_x = F.avg_pool2d(sup_x, val_wsize)
            pooled_y = F.avg_pool2d(sup_y, val_wsize)
        else:
            pooled_x = self.avg_pool_op(sup_x)
            pooled_y = self.avg_pool_op(sup_y)

        # Flatten shots and grid cells into one axis: [nshot * cells, nc].
        cell_feats = pooled_x.reshape(nshot, nch, -1).permute(0, 2, 1).reshape(-1, nch)
        cell_fg = pooled_y.reshape(-1)

        proto_grid = pooled_y.clone().detach()
        proto_grid[proto_grid < thresh] = 0
        non_zero = torch.nonzero(proto_grid)

        protos = cell_feats[cell_fg > thresh]  # npro, nc
        return protos, proto_grid, non_zero

    def get_prototypes(self, sup_x, sup_y, mode, val_wsize, thresh, isval=False):
        """Extract prototypes from the support features and mask.

        Args:
            sup_x: Support features, indexed as ``[nshot, nc, h, w]``.
            sup_y: Support mask, indexed as ``[nshot, 1, h, w]``.
            mode: ``'mask'``, ``'gridconv'`` or ``'gridconv+'``.
            val_wsize: Integer pooling window; used for pooling when
                ``isval`` is true and for re-expanding the prototype grid.
            thresh: Pooled-mask threshold a cell must exceed to become a
                local prototype (unused in ``'mask'`` mode).
            isval: Pool with ``val_wsize`` instead of the module's own
                pooling layer.

        Returns:
            ``(pro_n, resized_proto_grid, non_zero)``: the prototypes
            (L2-normalised in the grid modes, raw in ``'mask'`` mode), a
            full-resolution map of the prototype cells, and the indices of
            the non-zero grid entries.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        if mode == 'mask':
            # One masked-average prototype per shot: nb x C.
            pro_n = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
                / (sup_y.sum(dim=(-1, -2)) + 1e-5)
            # A single prototype for the whole image, so the mask itself
            # serves as the prototype grid.
            resized_proto_grid = sup_y.clone().detach()
            non_zero = torch.nonzero(resized_proto_grid)
        elif mode == 'gridconv':
            protos, proto_grid, non_zero = self._local_prototypes(sup_x, sup_y, val_wsize, thresh, isval)
            resized_proto_grid = self._resize_proto_grid(proto_grid, val_wsize)
            pro_n = F.normalize(protos, p=2, dim=1, eps=self.NORM_EPS)
        elif mode == 'gridconv+':
            protos, proto_grid, non_zero = self._local_prototypes(sup_x, sup_y, val_wsize, thresh, isval)
            # Replace each kept cell's value by a 1-based running index so
            # the resized grid shows which cell is which.
            # NOTE(review): writes always go to batch index 0, which looks
            # like a single-shot assumption - confirm for nshot > 1.
            for i, idx in enumerate(non_zero):
                proto_grid[0, idx[1], idx[2], idx[3]] = i + 1
            resized_proto_grid = self._resize_proto_grid(proto_grid, val_wsize)
            glb_proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
                / (sup_y.sum(dim=(-1, -2)) + 1e-5)
            pro_n = F.normalize(torch.cat([protos, glb_proto], dim=0), p=2, dim=1, eps=self.NORM_EPS)
        else:
            raise ValueError(f"Invalid mode: {mode}. Expected 'mask', 'gridconv', or 'gridconv+'.")
        return pro_n, resized_proto_grid, non_zero

    def forward(self, qry, sup_x, sup_y, mode, thresh, isval=False, val_wsize=None, vis_sim=False, get_prototypes=False, **kwargs):
        """Predict a similarity map for ``qry`` from the support pair.

        Args:
            qry: [way(1), nb(1), nc, h, w]
            sup_x: [way(1), shot, nb(1), nc, h, w]
            sup_y: [way(1), shot, nb(1), h, w]
            mode: 'mask' / 'gridconv' / 'gridconv+'. If 'mask', works as
                original (single masked-average) prototyping.
            thresh: Pooled-mask threshold for selecting local prototypes.
            isval: Pool with ``val_wsize`` instead of the module's own
                pooling layer.
            val_wsize: Pooling window; defaults to the module's kernel size.
                A tuple/list is reduced to its first element.
            vis_sim: visualize raw similarities or not.
            get_prototypes: Accepted for interface compatibility; not used
                by this method.
            **kwargs: Ignored.

        Returns:
            ``(pred_grid, debug_assign, vis_dict, proto_grid)`` - the first
            three come from ``get_prediction_from_prototypes`` and the last
            is the full-resolution prototype grid from ``get_prototypes``.
        """
        qry = qry.squeeze(1)                   # -> [way(1), nc, h, w]
        sup_x = sup_x.squeeze(0).squeeze(1)    # -> [nshot, nc, h, w]
        sup_y = sup_y.squeeze(0)

        if val_wsize is None:
            val_wsize = self.avg_pool_op.kernel_size
        if isinstance(val_wsize, (tuple, list)):
            val_wsize = val_wsize[0]

        sup_y = sup_y.reshape(sup_x.shape[0], 1, sup_x.shape[-2], sup_x.shape[-1])  # [nshot, 1, h, w]
        pro_n, proto_grid, proto_indices = self.get_prototypes(sup_x, sup_y, mode, val_wsize, thresh, isval)
        if 0 in pro_n.shape:
            print("failed to find prototypes")

        # 'mask' mode relies on cosine_similarity to normalise; the grid
        # modes use a plain convolution and need unit-norm query features.
        if mode == 'mask':
            qry_n = qry
        else:
            qry_n = F.normalize(qry, p=2, dim=1, eps=self.NORM_EPS)
        pred_grid, debug_assign, vis_dict = self.get_prediction_from_prototypes(pro_n, qry_n, mode, vis_sim=vis_sim)
        return pred_grid, debug_assign, vis_dict, proto_grid
|