"""
ALPNet
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .alpmodule import MultiProtoAsConv
from .backbone.torchvision_backbones import TVDeeplabRes101Encoder
from util.consts import DEFAULT_FEATURE_SIZE
from util.lora import inject_trainable_lora
# from util.utils import load_config_from_url, plot_dinov2_fts
import math
# Specify a local path to the repository (or use installed package instead)
FG_PROT_MODE = 'gridconv+'  # use both local and global prototypes
# FG_PROT_MODE = 'mask'  # use the global prototype only (as done in vanilla PANet)
BG_PROT_MODE = 'gridconv'  # use local prototypes only
# thresholds for deciding class of prototypes
FG_THRESH = 0.95
BG_THRESH = 0.95
class FewShotSeg(nn.Module):
"""
ALPNet
Args:
in_channels: Number of input channels
cfg: Model configurations
"""
def __init__(self, image_size, pretrained_path=None, cfg=None):
super(FewShotSeg, self).__init__()
self.image_size = image_size
self.pretrained_path = pretrained_path
print(f'###### Pre-trained path: {self.pretrained_path} ######')
self.config = cfg or {
'align': False, 'debug': False}
self.get_encoder()
self.get_cls()
if self.pretrained_path:
self.load_state_dict(torch.load(self.pretrained_path), strict=True)
            print(
                f'###### Pre-trained model {self.pretrained_path} has been loaded ######')
def get_encoder(self):
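        """
        Build the feature encoder selected by cfg['which_model'] and set
        cfg['feature_hw'] to the expected spatial size of its feature maps.
        """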
self.config['feature_hw'] = [DEFAULT_FEATURE_SIZE,
DEFAULT_FEATURE_SIZE] # default feature map size
        if self.config['which_model'] in ('dlfcn_res101', 'default'):
use_coco_init = self.config['use_coco_init']
self.encoder = TVDeeplabRes101Encoder(use_coco_init)
self.config['feature_hw'] = [
math.ceil(self.image_size/8), math.ceil(self.image_size/8)]
elif self.config['which_model'] == 'dinov2_l14':
self.encoder = torch.hub.load(
'facebookresearch/dinov2', 'dinov2_vitl14')
self.config['feature_hw'] = [max(
self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
elif self.config['which_model'] == 'dinov2_l14_reg':
try:
self.encoder = torch.hub.load(
'facebookresearch/dinov2', 'dinov2_vitl14_reg')
            except RuntimeError:
                # retry with a forced re-download of the hub repository
                self.encoder = torch.hub.load(
                    'facebookresearch/dinov2', 'dinov2_vitl14_reg', force_reload=True)
self.config['feature_hw'] = [max(
self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
elif self.config['which_model'] == 'dinov2_b14':
self.encoder = torch.hub.load(
'facebookresearch/dinov2', 'dinov2_vitb14')
self.config['feature_hw'] = [max(
self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
else:
raise NotImplementedError(
f'Backbone network {self.config["which_model"]} not implemented')
if self.config['lora'] > 0:
self.encoder.requires_grad_(False)
print(f'Injecting LoRA with rank:{self.config["lora"]}')
encoder_lora_params = inject_trainable_lora(
self.encoder, r=self.config['lora'])
def get_features(self, imgs_concat):
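        """
        Extract feature maps [B, C, H', W'] from the concatenated support and
        query images using the configured backbone. For DINOv2 backbones the
        patch tokens are reshaped into a spatial grid and upsampled to at least
        DEFAULT_FEATURE_SIZE x DEFAULT_FEATURE_SIZE.
        """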
        if self.config['which_model'] in ('dlfcn_res101', 'default'):
img_fts = self.encoder(imgs_concat, low_level=False)
elif 'dino' in self.config['which_model']:
            # resize imgs_concat so its side length is divisible by 14 (the ViT patch size)
imgs_concat = F.interpolate(imgs_concat, size=(
self.image_size // 14 * 14, self.image_size // 14 * 14), mode='bilinear')
dino_fts = self.encoder.forward_features(imgs_concat)
img_fts = dino_fts["x_norm_patchtokens"] # B, HW, C
img_fts = img_fts.permute(0, 2, 1) # B, C, HW
C, HW = img_fts.shape[-2:]
img_fts = img_fts.view(-1, C, int(HW**0.5),
int(HW**0.5)) # B, C, H, W
            if HW < DEFAULT_FEATURE_SIZE ** 2:
                # upsample feature maps smaller than DEFAULT_FEATURE_SIZE x DEFAULT_FEATURE_SIZE
                img_fts = F.interpolate(img_fts, size=(
                    DEFAULT_FEATURE_SIZE, DEFAULT_FEATURE_SIZE), mode='bilinear')
else:
raise NotImplementedError(
f'Backbone network {self.config["which_model"]} not implemented')
return img_fts
def get_cls(self):
"""
Obtain the similarity-based classifier
"""
proto_hw = self.config["proto_grid_size"]
if self.config['cls_name'] == 'grid_proto':
embed_dim = 256
if 'dinov2_b14' in self.config['which_model']:
embed_dim = 768
elif 'dinov2_l14' in self.config['which_model']:
embed_dim = 1024
self.cls_unit = MultiProtoAsConv(proto_grid=[proto_hw, proto_hw], feature_hw=self.config["feature_hw"], embed_dim=embed_dim) # when treating it as ordinary prototype
print(f"cls unit feature hw: {self.cls_unit.feature_hw}")
else:
raise NotImplementedError(
f'Classifier {self.config["cls_name"]} not implemented')
def forward_resolutions(self, resolutions, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None):
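        """
        Run forward() at several input resolutions: support images, masks and
        query images are resized to each resolution in `resolutions` and the
        resulting predictions are collected, one entry per resolution.
        """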
predictions = []
for res in resolutions:
supp_imgs_resized = [[F.interpolate(supp_img[0], size=(
res, res), mode='bilinear') for supp_img in supp_imgs]] if supp_imgs[0][0].shape[-1] != res else supp_imgs
fore_mask_resized = [[F.interpolate(fore_mask_way[0].unsqueeze(0), size=(res, res), mode='bilinear')[
0] for fore_mask_way in fore_mask]] if fore_mask[0][0].shape[-1] != res else fore_mask
back_mask_resized = [[F.interpolate(back_mask_way[0].unsqueeze(0), size=(res, res), mode='bilinear')[
0] for back_mask_way in back_mask]] if back_mask[0][0].shape[-1] != res else back_mask
qry_imgs_resized = [F.interpolate(qry_img, size=(res, res), mode='bilinear')
for qry_img in qry_imgs] if qry_imgs[0][0].shape[-1] != res else qry_imgs
pred = self.forward(supp_imgs_resized, fore_mask_resized, back_mask_resized,
qry_imgs_resized, isval, val_wsize, show_viz, supp_fts)[0]
            predictions.append(pred)
        return predictions
def resize_inputs_to_image_size(self, supp_imgs, fore_mask, back_mask, qry_imgs):
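        """
        Bilinearly resize support images, support masks and query images to
        (self.image_size, self.image_size).
        """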
supp_imgs = [[F.interpolate(supp_img, size=(
self.image_size, self.image_size), mode='bilinear') for supp_img in supp_imgs_way] for supp_imgs_way in supp_imgs]
fore_mask = [[F.interpolate(fore_mask_way[0].unsqueeze(0), size=(self.image_size, self.image_size), mode='bilinear')[
0] for fore_mask_way in fore_mask]] if fore_mask[0][0].shape[-1] != self.image_size else fore_mask
back_mask = [[F.interpolate(back_mask_way[0].unsqueeze(0), size=(self.image_size, self.image_size), mode='bilinear')[
0] for back_mask_way in back_mask]] if back_mask[0][0].shape[-1] != self.image_size else back_mask
qry_imgs = [F.interpolate(qry_img, size=(self.image_size, self.image_size), mode='bilinear')
for qry_img in qry_imgs] if qry_imgs[0][0].shape[-1] != self.image_size else qry_imgs
return supp_imgs, fore_mask, back_mask, qry_imgs
def forward(self, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None):
"""
Args:
supp_imgs: support images
way x shot x [B x 3 x H x W], list of lists of tensors
fore_mask: foreground masks for support images
way x shot x [B x H x W], list of lists of tensors
back_mask: background masks for support images
way x shot x [B x H x W], list of lists of tensors
qry_imgs: query images
N x [B x 3 x H x W], list of tensors
show_viz: return the visualization dictionary
"""
# ('Please go through this piece of code carefully')
# supp_imgs, fore_mask, back_mask, qry_imgs = self.resize_inputs_to_image_size(
# supp_imgs, fore_mask, back_mask, qry_imgs)
n_ways = len(supp_imgs)
n_shots = len(supp_imgs[0])
n_queries = len(qry_imgs)
# NOTE: actual shot in support goes in batch dimension
        assert n_ways == 1, "Multi-way support has not been implemented yet"
assert n_queries == 1
sup_bsize = supp_imgs[0][0].shape[0]
img_size = supp_imgs[0][0].shape[-2:]
if self.config["cls_name"] == 'grid_proto_3d':
img_size = supp_imgs[0][0].shape[-3:]
qry_bsize = qry_imgs[0].shape[0]
imgs_concat = torch.cat([torch.cat(way, dim=0) for way in supp_imgs]
+ [torch.cat(qry_imgs, dim=0),], dim=0)
img_fts = self.get_features(imgs_concat)
if len(img_fts.shape) == 5: # for 3D
fts_size = img_fts.shape[-3:]
else:
fts_size = img_fts.shape[-2:]
if supp_fts is None:
supp_fts = img_fts[:n_ways * n_shots * sup_bsize].view(
n_ways, n_shots, sup_bsize, -1, *fts_size) # wa x sh x b x c x h' x w'
qry_fts = img_fts[n_ways * n_shots * sup_bsize:].view(
n_queries, qry_bsize, -1, *fts_size) # N x B x C x H' x W'
else:
# N x B x C x H' x W'
qry_fts = img_fts.view(n_queries, qry_bsize, -1, *fts_size)
        fore_mask = torch.stack([torch.stack(way, dim=0)
                                 for way in fore_mask], dim=0)  # Wa x Sh x B x H x W
        fore_mask = fore_mask.detach().requires_grad_(True)
        back_mask = torch.stack([torch.stack(way, dim=0)
                                 for way in back_mask], dim=0)  # Wa x Sh x B x H x W
###### Compute loss ######
align_loss = 0
outputs = []
visualizes = [] # the buffer for visualization
for epi in range(1): # batch dimension, fixed to 1
fg_masks = [] # keep the way part
'''
for way in range(n_ways):
# note: index of n_ways starts from 0
mean_sup_ft = supp_fts[way].mean(dim = 0) # [ nb, C, H, W]. Just assume batch size is 1 as pytorch only allows this
mean_sup_msk = F.interpolate(fore_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear')
fg_masks.append( mean_sup_msk )
mean_bg_msk = F.interpolate(back_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear') # [nb, C, H, W]
'''
# re-interpolate support mask to the same size as support feature
if len(fts_size) == 3: # TODO make more generic
res_fg_msk = torch.stack([F.interpolate(fore_mask[0][0].unsqueeze(
0), size=fts_size, mode='nearest')], dim=0) # [nway, ns, nb, nd', nh', nw'])
res_bg_msk = torch.stack([F.interpolate(back_mask[0][0].unsqueeze(
0), size=fts_size, mode='nearest')], dim=0) # [nway, ns, nb, nd', nh', nw'])
else:
res_fg_msk = torch.stack([F.interpolate(fore_mask_w, size=fts_size, mode='nearest')
for fore_mask_w in fore_mask], dim=0) # [nway, ns, nb, nh', nw']
res_bg_msk = torch.stack([F.interpolate(back_mask_w, size=fts_size, mode='nearest')
for back_mask_w in back_mask], dim=0) # [nway, ns, nb, nh', nw']
scores = []
assign_maps = []
bg_sim_maps = []
fg_sim_maps = []
bg_mode = BG_PROT_MODE
_raw_score, _, aux_attr, _ = self.cls_unit(
qry_fts, supp_fts, res_bg_msk, mode=bg_mode, thresh=BG_THRESH, isval=isval, val_wsize=val_wsize, vis_sim=show_viz)
scores.append(_raw_score)
assign_maps.append(aux_attr['proto_assign'])
for way, _msks in enumerate(res_fg_msk):
raw_scores = []
for i, _msk in enumerate(_msks):
_msk = _msk.unsqueeze(0)
supp_ft = supp_fts[:, i].unsqueeze(0)
                    # both branches share the same kernel size; only the pooling dimensionality differs
                    k_size = self.cls_unit.kernel_size  # TODO figure out kernel size
                    avg_pool = F.avg_pool3d if self.config["cls_name"] == 'grid_proto_3d' else F.avg_pool2d
                    fg_mode = FG_PROT_MODE if avg_pool(_msk, k_size).max(
                    ) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
_raw_score, _, aux_attr, proto_grid = self.cls_unit(qry_fts, supp_ft, _msk.unsqueeze(
0), mode=fg_mode, thresh=FG_THRESH, isval=isval, val_wsize=val_wsize, vis_sim=show_viz)
raw_scores.append(_raw_score)
                # take, at each spatial location, the maximum score over the support shots
                _raw_score = torch.stack(raw_scores, dim=1).max(dim=1)[0]
scores.append(_raw_score)
assign_maps.append(aux_attr['proto_assign'])
if show_viz:
fg_sim_maps.append(aux_attr['raw_local_sims'])
# print(f"Time for fg: {time.time() - start_time}")
pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W'
            interpolate_mode = 'trilinear' if self.config["cls_name"] == 'grid_proto_3d' else 'bilinear'
outputs.append(F.interpolate(
pred, size=img_size, mode=interpolate_mode))
###### Prototype alignment loss ######
if self.config['align'] and self.training:
align_loss_epi = self.alignLoss(qry_fts[:, epi], pred, supp_fts[:, :, epi],
fore_mask[:, :, epi], back_mask[:, :, epi])
align_loss += align_loss_epi
output = torch.stack(outputs, dim=1) # N x B x (1 + Wa) x H x W
        # collapse the leading episode dims: output becomes [N * B, 1 + Wa, H, W] (with D for 3D)
        output = output.view(-1, *output.shape[2:])
assign_maps = torch.stack(assign_maps, dim=1) if show_viz else None
bg_sim_maps = torch.stack(bg_sim_maps, dim=1) if show_viz else None
fg_sim_maps = torch.stack(fg_sim_maps, dim=1) if show_viz else None
return output, align_loss / sup_bsize, [bg_sim_maps, fg_sim_maps], assign_maps, proto_grid, supp_fts, qry_fts
def alignLoss(self, qry_fts, pred, supp_fts, fore_mask, back_mask):
"""
Compute the loss for the prototype alignment branch
Args:
qry_fts: embedding features for query images
expect shape: N x C x H' x W'
pred: predicted segmentation score
expect shape: N x (1 + Wa) x H x W
            supp_fts: embedding features for support images
expect shape: Wa x Sh x C x H' x W'
fore_mask: foreground masks for support images
expect shape: way x shot x H x W
back_mask: background masks for support images
expect shape: way x shot x H x W
"""
n_ways, n_shots = len(fore_mask), len(fore_mask[0])
# Masks for getting query prototype
pred_mask = pred.argmax(dim=1).unsqueeze(0) # 1 x N x H' x W'
binary_masks = [pred_mask == i for i in range(1 + n_ways)]
# skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0]
        # FIXME: for now we make the stronger assumption that every positive class is present in the query, to avoid under-segmentation / laziness
skip_ways = []
        # add singleton way and batch dims so qry_fts matches the support layout
        qry_fts = qry_fts.unsqueeze(0).unsqueeze(2)  # [1 (way), N, 1 (nb), C, H', W']
loss = []
for way in range(n_ways):
if way in skip_ways:
continue
# Get the query prototypes
for shot in range(n_shots):
# actual local query [way(1), nb(1, nb is now nshot), nc, h, w]
img_fts = supp_fts[way: way + 1, shot: shot + 1]
size = img_fts.shape[-2:]
mode = 'bilinear'
if self.config["cls_name"] == 'grid_proto_3d':
size = img_fts.shape[-3:]
mode = 'trilinear'
qry_pred_fg_msk = F.interpolate(
binary_masks[way + 1].float(), size=size, mode=mode) # [1 (way), n (shot), h, w]
# background
qry_pred_bg_msk = F.interpolate(
binary_masks[0].float(), size=size, mode=mode) # 1, n, h ,w
scores = []
bg_mode = BG_PROT_MODE
_raw_score_bg, _, _, _ = self.cls_unit(
qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_bg_msk.unsqueeze(-3), mode=bg_mode, thresh=BG_THRESH)
scores.append(_raw_score_bg)
if self.config["cls_name"] == 'grid_proto_3d':
fg_mode = FG_PROT_MODE if F.avg_pool3d(qry_pred_fg_msk, 4).max(
) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
else:
fg_mode = FG_PROT_MODE if F.avg_pool2d(qry_pred_fg_msk, 4).max(
) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
_raw_score_fg, _, _, _ = self.cls_unit(
qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_fg_msk.unsqueeze(2), mode=fg_mode, thresh=FG_THRESH)
scores.append(_raw_score_fg)
supp_pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W'
size = fore_mask.shape[-2:]
if self.config["cls_name"] == 'grid_proto_3d':
size = fore_mask.shape[-3:]
supp_pred = F.interpolate(supp_pred, size=size, mode=mode)
# Construct the support Ground-Truth segmentation
supp_label = torch.full_like(fore_mask[way, shot], 255,
device=img_fts.device).long()
supp_label[fore_mask[way, shot] == 1] = 1
supp_label[back_mask[way, shot] == 1] = 0
# Compute Loss
loss.append(F.cross_entropy(
supp_pred.float(), supp_label[None, ...], ignore_index=255) / n_shots / n_ways)
return torch.sum(torch.stack(loss))
def dino_cls_loss(self, teacher_cls_tokens, student_cls_tokens):
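        """
        DINO-style [CLS]-token loss: cross-entropy between the Sinkhorn-Knopp
        normalized teacher tokens and the log-softmax of the student tokens,
        scaled by a fixed weight.
        """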
cls_loss_weight = 0.1
student_temp = 1
teacher_cls_tokens = self.sinkhorn_knopp_teacher(teacher_cls_tokens)
lsm = F.log_softmax(student_cls_tokens / student_temp, dim=-1)
cls_loss = torch.sum(teacher_cls_tokens * lsm, dim=-1)
return -cls_loss.mean() * cls_loss_weight
@torch.no_grad()
def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp=1, n_iterations=3):
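        """
        Sinkhorn-Knopp normalization of the teacher output: iteratively rescale
        rows (prototypes) and columns (samples) so the result approximates a
        balanced soft assignment matrix; returned transposed back to [B, K].
        """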
teacher_output = teacher_output.float()
# world_size = dist.get_world_size() if dist.is_initialized() else 1
        # Q is K-by-B, following the notation of the SwAV/DINO papers
Q = torch.exp(teacher_output / teacher_temp).t()
# B = Q.shape[1] * world_size # number of samples to assign
B = Q.shape[1]
K = Q.shape[0] # how many prototypes
# make the matrix sums to 1
sum_Q = torch.sum(Q)
Q /= sum_Q
for it in range(n_iterations):
# normalize each row: total weight per prototype must be 1/K
sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
Q /= sum_of_rows
Q /= K
# normalize each column: total weight per sample must be 1/B
Q /= torch.sum(Q, dim=0, keepdim=True)
Q /= B
Q *= B # the columns must sum to 1 so that Q is an assignment
return Q.t()
def dino_patch_loss(self, features, masked_features, masks):
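        """
        Patch-wise distillation loss: for each image, compare the Sinkhorn-Knopp
        normalized `features` against the log-softmax of `masked_features` on the
        patches selected by `mask`, normalized by the number of selected patches
        and the batch size.
        """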
        # the patch-wise loss is applied to both support and query features
loss = 0.0
weight = 0.1
B = features.shape[0]
for (f, mf, mask) in zip(features, masked_features, masks):
# TODO sinkhorn knopp center features
f = f[mask]
f = self.sinkhorn_knopp_teacher(f)
mf = mf[mask]
            # sum the cross-entropy over masked patches and feature dims, normalized by the patch count
            loss += torch.sum(f * F.log_softmax(mf, dim=-1)) / mask.sum()
return -loss.sum() * weight / B
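

if __name__ == '__main__':
    # Minimal smoke-test sketch (not from the original repository): it only
    # illustrates the episode layout expected by FewShotSeg.forward for a
    # 1-way, 1-shot task. The config values below are illustrative assumptions
    # (they exercise the default DeepLab-ResNet101 branch without COCO init);
    # real experiments may use different settings, and running this requires
    # the repo's alpmodule / backbone / util packages to be importable.
    cfg = {
        'which_model': 'dlfcn_res101',
        'use_coco_init': False,
        'cls_name': 'grid_proto',
        'proto_grid_size': 8,
        'lora': 0,
        'align': False,
        'debug': False,
    }
    model = FewShotSeg(image_size=256, cfg=cfg).eval()
    supp_imgs = [[torch.randn(1, 3, 256, 256)]]   # way x shot x [B, 3, H, W]
    fore_mask = [[torch.zeros(1, 256, 256)]]      # way x shot x [B, H, W]
    fore_mask[0][0][:, 96:160, 96:160] = 1.0      # dummy square foreground
    back_mask = [[1.0 - fore_mask[0][0]]]
    qry_imgs = [torch.randn(1, 3, 256, 256)]      # N x [B, 3, H, W]
    # val_wsize is forwarded to the prototype module; 2 is an arbitrary choice here
    output, align_loss, *_ = model(
        supp_imgs, fore_mask, back_mask, qry_imgs, isval=True, val_wsize=2)
    print(output.shape)  # expected [N * B, 1 + n_ways, H, W] segmentation scores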