Spaces:
Sleeping
Sleeping
""" | |
ALPNet | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .alpmodule import MultiProtoAsConv | |
from .backbone.torchvision_backbones import TVDeeplabRes101Encoder | |
from util.consts import DEFAULT_FEATURE_SIZE | |
from util.lora import inject_trainable_lora | |
# from util.utils import load_config_from_url, plot_dinov2_fts | |
import math | |
# Specify a local path to the repository (or use installed package instead) | |
# Prototype modes handed to the MultiProtoAsConv classifier as ``mode``:
#   'gridconv+' -> local (grid) prototypes together with a global prototype
#   'gridconv'  -> local (grid) prototypes only
#   'mask'      -> a single global prototype only (as done in vanilla PANet)
FG_PROT_MODE = 'gridconv+'
BG_PROT_MODE = 'gridconv'

# Thresholds handed to the classifier as ``thresh``. FG_THRESH is also the
# minimum average foreground fraction a pooling window must reach before the
# foreground branch uses FG_PROT_MODE instead of falling back to 'mask'.
FG_THRESH = 0.95
BG_THRESH = 0.95
class FewShotSeg(nn.Module):
    """
    ALPNet few-shot segmentation model.

    Support and query images are encoded by a configurable backbone; query
    features are then scored against background / foreground prototypes by a
    ``MultiProtoAsConv`` classifier (``self.cls_unit``).

    Args:
        image_size: side length of the (square) input images; used to derive
            the feature-map size and the DINOv2 input resolution.
        pretrained_path: optional path to a full-model state dict, loaded with
            ``strict=True`` after the network has been built.
        cfg: model configuration dict. Keys read include ``which_model``,
            ``use_coco_init`` (ResNet only), ``lora`` (optional), ``cls_name``,
            ``proto_grid_size`` and ``align``. ``feature_hw`` is written back
            into this dict (the caller's dict is mutated).
    """

    # ``which_model`` config value -> torch.hub entry point of the DINOv2 backbone
    _DINOV2_HUB_MODELS = {
        'dinov2_l14': 'dinov2_vitl14',
        'dinov2_l14_reg': 'dinov2_vitl14_reg',
        'dinov2_b14': 'dinov2_vitb14',
    }

    def __init__(self, image_size, pretrained_path=None, cfg=None):
        super().__init__()
        self.image_size = image_size
        self.pretrained_path = pretrained_path
        print(f'###### Pre-trained path: {self.pretrained_path} ######')
        # NOTE(review): this fallback dict lacks 'which_model' and the other keys
        # read by get_encoder/get_cls, so cfg=None currently fails there.
        self.config = cfg or {'align': False, 'debug': False}
        self.get_encoder()
        self.get_cls()
        if self.pretrained_path:
            # The freshly built model lives on the CPU and load_state_dict copies
            # values onto each parameter's own device, so mapping to CPU is safe
            # and lets GPU-saved checkpoints load on CPU-only machines.
            state_dict = torch.load(self.pretrained_path, map_location='cpu')
            self.load_state_dict(state_dict, strict=True)
            print(
                f'###### Pre-trained model {self.pretrained_path} has been loaded ######')

    def get_encoder(self):
        """
        Build ``self.encoder`` from ``config['which_model']`` and store the
        expected feature-map size in ``config['feature_hw']``.

        When ``config['lora']`` is a positive rank, the encoder is frozen and
        LoRA adapters are injected into it.

        Raises:
            NotImplementedError: for an unknown ``which_model``.
        """
        which_model = self.config['which_model']
        # default feature map size; overridden per backbone below
        self.config['feature_hw'] = [DEFAULT_FEATURE_SIZE, DEFAULT_FEATURE_SIZE]
        if which_model in ('dlfcn_res101', 'default'):
            self.encoder = TVDeeplabRes101Encoder(self.config['use_coco_init'])
            # this encoder's feature map is 1/8 of the input side (rounded up)
            side = math.ceil(self.image_size / 8)
            self.config['feature_hw'] = [side, side]
        elif which_model in self._DINOV2_HUB_MODELS:
            self.encoder = self._load_dinov2(self._DINOV2_HUB_MODELS[which_model])
            # one token per 14x14 patch; get_features upsamples token grids
            # smaller than DEFAULT_FEATURE_SIZE, hence the max()
            side = max(self.image_size // 14, DEFAULT_FEATURE_SIZE)
            self.config['feature_hw'] = [side, side]
        else:
            raise NotImplementedError(
                f'Backbone network {which_model} not implemented')

        lora_rank = self.config.get('lora', 0)
        if lora_rank > 0:
            # freeze the backbone before injecting the adapters (the injected
            # LoRA weights are presumably the only trainable encoder params)
            self.encoder.requires_grad_(False)
            print(f'Injecting LoRA with rank:{lora_rank}')
            inject_trainable_lora(self.encoder, r=lora_rank)

    @staticmethod
    def _load_dinov2(hub_name):
        """Load DINOv2 backbone ``hub_name`` from torch.hub, retrying once with a
        forced re-download if loading raises a RuntimeError."""
        try:
            return torch.hub.load('facebookresearch/dinov2', hub_name)
        except RuntimeError:
            # presumably a stale cached checkout that predates ``hub_name``; the
            # retry must use the same dinov2 repo (the v1 'dino' repo has no
            # dinov2 entry points)
            return torch.hub.load('facebookresearch/dinov2', hub_name,
                                  force_reload=True)

    def get_features(self, imgs_concat):
        """
        Encode an image batch into a dense feature map.

        Args:
            imgs_concat: image batch, B x 3 x H x W.
        Returns:
            For DINOv2 backbones a B x C x H' x W' tensor (square token grid,
            at least DEFAULT_FEATURE_SIZE per side); for the ResNet encoder
            whatever the encoder returns.
        Raises:
            NotImplementedError: for an unknown ``which_model``.
        """
        which_model = self.config['which_model']
        if which_model in ('dlfcn_res101', 'default'):
            # 'default' builds the same ResNet-101 encoder in get_encoder
            return self.encoder(imgs_concat, low_level=False)
        if 'dino' in which_model:
            # resize to the closest size that is divisible by the patch size 14
            side = self.image_size // 14 * 14
            imgs_concat = F.interpolate(imgs_concat, size=(side, side), mode='bilinear')
            dino_fts = self.encoder.forward_features(imgs_concat)
            img_fts = dino_fts["x_norm_patchtokens"]  # B, HW, C
            img_fts = img_fts.permute(0, 2, 1)  # B, C, HW
            C, HW = img_fts.shape[-2:]
            grid = int(HW ** 0.5)
            img_fts = img_fts.view(-1, C, grid, grid)  # B, C, H, W
            if HW < DEFAULT_FEATURE_SIZE ** 2:
                # upsample token grids smaller than the minimum feature size
                img_fts = F.interpolate(
                    img_fts, size=(DEFAULT_FEATURE_SIZE, DEFAULT_FEATURE_SIZE),
                    mode='bilinear')
            return img_fts
        raise NotImplementedError(
            f'Backbone network {which_model} not implemented')

    def get_cls(self):
        """
        Build the similarity-based classifier ``self.cls_unit``.

        The embedding width is chosen from the backbone: 768 for DINOv2 ViT-B/14,
        1024 for DINOv2 ViT-L/14 (with or without registers), 256 otherwise.

        Raises:
            NotImplementedError: if ``config['cls_name']`` is not 'grid_proto'.
        """
        proto_hw = self.config["proto_grid_size"]
        if self.config['cls_name'] != 'grid_proto':
            # NOTE(review): forward/alignLoss also branch on 'grid_proto_3d',
            # which this method never accepts.
            raise NotImplementedError(
                f'Classifier {self.config["cls_name"]} not implemented')
        which_model = self.config['which_model']
        if 'dinov2_b14' in which_model:
            embed_dim = 768
        elif 'dinov2_l14' in which_model:  # also matches 'dinov2_l14_reg'
            embed_dim = 1024
        else:
            embed_dim = 256
        # when treating it as ordinary prototype
        self.cls_unit = MultiProtoAsConv(
            proto_grid=[proto_hw, proto_hw],
            feature_hw=self.config["feature_hw"],
            embed_dim=embed_dim)
        print(f"cls unit feature hw: {self.cls_unit.feature_hw}")

    @staticmethod
    def _resize_episode(supp_imgs, fore_mask, back_mask, qry_imgs, size):
        """
        Bilinearly resize every image and mask of an episode to size x size.

        Every way and every shot is kept; tensors whose last two dims already
        equal (size, size) are returned as-is.

        Args:
            supp_imgs: way x shot x [B x 3 x H x W], list of lists of tensors
            fore_mask, back_mask: way x shot x [B x H x W], lists of lists
            qry_imgs: N x [B x 3 x H x W], list of tensors
            size: target side length
        Returns:
            (supp_imgs, fore_mask, back_mask, qry_imgs) with the same nesting.
        """
        def resize(x):
            if tuple(x.shape[-2:]) == (size, size):
                return x
            if x.dim() == 3:
                # masks: treat B as channels so interpolate sees a 4D input
                return F.interpolate(x.unsqueeze(0), size=(size, size),
                                     mode='bilinear')[0]
            return F.interpolate(x, size=(size, size), mode='bilinear')

        def resize_nested(ways):
            return [[resize(shot) for shot in way] for way in ways]

        return (resize_nested(supp_imgs), resize_nested(fore_mask),
                resize_nested(back_mask), [resize(qry) for qry in qry_imgs])

    def forward_resolutions(self, resolutions, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None):
        """
        Run ``forward`` once per resolution, resizing the episode each time.

        Args:
            resolutions: iterable of square side lengths.
            remaining args: as for ``forward``.
        Returns:
            List holding the first output of ``forward`` (the prediction) for
            each resolution, in the given order.
        """
        predictions = []
        for res in resolutions:
            episode = self._resize_episode(
                supp_imgs, fore_mask, back_mask, qry_imgs, res)
            pred = self.forward(*episode, isval, val_wsize, show_viz, supp_fts)[0]
            predictions.append(pred)
        return predictions

    def resize_inputs_to_image_size(self, supp_imgs, fore_mask, back_mask, qry_imgs):
        """Resize all episode images and masks (every way and shot) to
        ``self.image_size``; returns (supp_imgs, fore_mask, back_mask, qry_imgs)."""
        return self._resize_episode(
            supp_imgs, fore_mask, back_mask, qry_imgs, self.image_size)

    def forward(self, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None):
        """
        Args:
            supp_imgs: support images
                way x shot x [B x 3 x H x W], list of lists of tensors
            fore_mask: foreground masks for support images
                way x shot x [B x H x W], list of lists of tensors
            back_mask: background masks for support images
                way x shot x [B x H x W], list of lists of tensors
            qry_imgs: query images
                N x [B x 3 x H x W], list of tensors
            isval, val_wsize: forwarded to the classifier unit
            show_viz: return the visualization maps
            supp_fts: optional precomputed support features
                (Wa x Sh x B x C x H' x W', e.g. from a previous call); when
                given, only the query images are encoded
        Returns:
            (output, align_loss / sup_bsize, [bg_sim_maps, fg_sim_maps],
             assign_maps, proto_grid, supp_fts, qry_fts)
        """
        n_ways = len(supp_imgs)
        n_shots = len(supp_imgs[0])
        n_queries = len(qry_imgs)
        # NOTE: actual shot in support goes in batch dimension
        assert n_ways == 1, "Multi-way has not been implemented yet"
        assert n_queries == 1

        sup_bsize = supp_imgs[0][0].shape[0]
        img_size = supp_imgs[0][0].shape[-2:]
        if self.config["cls_name"] == 'grid_proto_3d':
            img_size = supp_imgs[0][0].shape[-3:]
        qry_bsize = qry_imgs[0].shape[0]

        encode_support = supp_fts is None
        if encode_support:
            imgs_concat = torch.cat([torch.cat(way, dim=0) for way in supp_imgs]
                                    + [torch.cat(qry_imgs, dim=0)], dim=0)
            n_supp = n_ways * n_shots * sup_bsize
        else:
            # Support features were supplied: encode the queries only. Encoding
            # the support images too would fold them into the query channels.
            imgs_concat = torch.cat(qry_imgs, dim=0)
            n_supp = 0
        img_fts = self.get_features(imgs_concat)
        # 5D features come from a 3D backbone
        fts_size = img_fts.shape[-3:] if img_fts.dim() == 5 else img_fts.shape[-2:]
        if encode_support:
            supp_fts = img_fts[:n_supp].view(
                n_ways, n_shots, sup_bsize, -1, *fts_size)  # wa x sh x b x c x h' x w'
        qry_fts = img_fts[n_supp:].view(
            n_queries, qry_bsize, -1, *fts_size)  # N x B x C x H' x W'

        fore_mask = torch.stack([torch.stack(way, dim=0)
                                 for way in fore_mask], dim=0)  # Wa x Sh x B x H x W
        # NOTE(review): legacy Variable wrapper kept as-is; its purpose here is
        # unclear - confirm before removing.
        fore_mask = torch.autograd.Variable(fore_mask, requires_grad=True)
        back_mask = torch.stack([torch.stack(way, dim=0)
                                 for way in back_mask], dim=0)  # Wa x Sh x B x H x W

        ###### Compute loss ######
        align_loss = 0
        outputs = []
        k_size = self.cls_unit.kernel_size
        avg_pool = F.avg_pool3d if self.config["cls_name"] == 'grid_proto_3d' else F.avg_pool2d
        for epi in range(1):  # batch dimension, fixed to 1
            # re-interpolate support masks to the same size as support features
            if len(fts_size) == 3:  # 3D: only the first way/shot (TODO make more generic)
                res_fg_msk = torch.stack([F.interpolate(fore_mask[0][0].unsqueeze(
                    0), size=fts_size, mode='nearest')], dim=0)  # [nway, ns, nb, nd', nh', nw']
                res_bg_msk = torch.stack([F.interpolate(back_mask[0][0].unsqueeze(
                    0), size=fts_size, mode='nearest')], dim=0)  # [nway, ns, nb, nd', nh', nw']
            else:
                res_fg_msk = torch.stack([F.interpolate(way_msk, size=fts_size, mode='nearest')
                                          for way_msk in fore_mask], dim=0)  # [nway, ns, nb, nh', nw']
                res_bg_msk = torch.stack([F.interpolate(way_msk, size=fts_size, mode='nearest')
                                          for way_msk in back_mask], dim=0)  # [nway, ns, nb, nh', nw']

            scores = []
            assign_maps = []
            bg_sim_maps = []  # never filled; see the show_viz handling below
            fg_sim_maps = []

            # background: scored against all support features at once
            _raw_score, _, aux_attr, _ = self.cls_unit(
                qry_fts, supp_fts, res_bg_msk, mode=BG_PROT_MODE, thresh=BG_THRESH,
                isval=isval, val_wsize=val_wsize, vis_sim=show_viz)
            scores.append(_raw_score)
            assign_maps.append(aux_attr['proto_assign'])

            # foreground: scored shot by shot, then max over shots
            for way, _msks in enumerate(res_fg_msk):
                raw_scores = []
                for i, _msk in enumerate(_msks):
                    _msk = _msk.unsqueeze(0)
                    supp_ft = supp_fts[:, i].unsqueeze(0)
                    # use the local-prototype mode only if some pooling window
                    # is (almost) entirely foreground; else the global 'mask' one
                    use_local = (avg_pool(_msk, k_size).max() >= FG_THRESH
                                 and FG_PROT_MODE != 'mask')
                    fg_mode = FG_PROT_MODE if use_local else 'mask'
                    _raw_score, _, aux_attr, proto_grid = self.cls_unit(
                        qry_fts, supp_ft, _msk.unsqueeze(0), mode=fg_mode, thresh=FG_THRESH,
                        isval=isval, val_wsize=val_wsize, vis_sim=show_viz)
                    raw_scores.append(_raw_score)
                _raw_score = torch.stack(raw_scores, dim=1).max(dim=1)[0]
                scores.append(_raw_score)
                assign_maps.append(aux_attr['proto_assign'])
                if show_viz:
                    fg_sim_maps.append(aux_attr['raw_local_sims'])

            pred = torch.cat(scores, dim=1)  # N x (1 + Wa) x H' x W'
            outputs.append(F.interpolate(pred, size=img_size, mode='bilinear'))

            ###### Prototype alignment loss ######
            if self.config['align'] and self.training:
                align_loss += self.alignLoss(qry_fts[:, epi], pred, supp_fts[:, :, epi],
                                             fore_mask[:, :, epi], back_mask[:, :, epi])

        output = torch.stack(outputs, dim=1)  # N x B x (1 + Wa) x H x W
        output = output.view(-1, *output.shape[2:])
        assign_maps = torch.stack(assign_maps, dim=1) if show_viz else None
        # torch.stack raises on an empty list; bg_sim_maps is never populated,
        # so fall back to None instead of crashing when show_viz is set.
        bg_sim_maps = torch.stack(bg_sim_maps, dim=1) if show_viz and bg_sim_maps else None
        fg_sim_maps = torch.stack(fg_sim_maps, dim=1) if show_viz and fg_sim_maps else None
        return output, align_loss / sup_bsize, [bg_sim_maps, fg_sim_maps], assign_maps, proto_grid, supp_fts, qry_fts

    def alignLoss(self, qry_fts, pred, supp_fts, fore_mask, back_mask):
        """
        Compute the loss for the prototype alignment branch: the query
        prediction is used as a mask to segment the support images back.

        Args:
            qry_fts: embedding features for query images
                expect shape: N x C x H' x W'
            pred: predicted segmentation score
                expect shape: N x (1 + Wa) x H x W
            supp_fts: embedding features for support images
                expect shape: Wa x Sh x C x H' x W'
            fore_mask: foreground masks for support images
                expect shape: way x shot x H x W
            back_mask: background masks for support images
                expect shape: way x shot x H x W
        Returns:
            Scalar: sum over ways/shots of cross-entropy / (n_shots * n_ways).
        """
        n_ways, n_shots = len(fore_mask), len(fore_mask[0])
        is_3d = self.config["cls_name"] == 'grid_proto_3d'
        interp_mode = 'trilinear' if is_3d else 'bilinear'
        avg_pool = F.avg_pool3d if is_3d else F.avg_pool2d

        # Masks for getting query prototype
        pred_mask = pred.argmax(dim=1).unsqueeze(0)  # 1 x N x H' x W'
        binary_masks = [pred_mask == i for i in range(1 + n_ways)]
        # FIXME: no way is ever skipped; this assumes a positive class is always
        # present, to avoid under-segmentation / laziness
        skip_ways = []
        # add the way (1) and batch (1) dims to match the new data format
        qry_fts = qry_fts.unsqueeze(0).unsqueeze(2)

        loss = []
        for way in range(n_ways):
            if way in skip_ways:
                continue
            for shot in range(n_shots):
                # actual local query [way(1), nb(1, nb is now nshot), nc, h, w]
                img_fts = supp_fts[way: way + 1, shot: shot + 1]
                size = img_fts.shape[-3:] if is_3d else img_fts.shape[-2:]
                qry_pred_fg_msk = F.interpolate(
                    binary_masks[way + 1].float(), size=size, mode=interp_mode)  # [1 (way), n (shot), h, w]
                qry_pred_bg_msk = F.interpolate(
                    binary_masks[0].float(), size=size, mode=interp_mode)  # 1, n, h, w

                _raw_score_bg, _, _, _ = self.cls_unit(
                    qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_bg_msk.unsqueeze(-3),
                    mode=BG_PROT_MODE, thresh=BG_THRESH)
                # NOTE(review): pooling window fixed at 4 here, whereas forward
                # uses self.cls_unit.kernel_size
                use_local = (avg_pool(qry_pred_fg_msk, 4).max() >= FG_THRESH
                             and FG_PROT_MODE != 'mask')
                fg_mode = FG_PROT_MODE if use_local else 'mask'
                _raw_score_fg, _, _, _ = self.cls_unit(
                    qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_fg_msk.unsqueeze(2),
                    mode=fg_mode, thresh=FG_THRESH)

                supp_pred = torch.cat([_raw_score_bg, _raw_score_fg], dim=1)  # N x (1 + Wa) x H' x W'
                out_size = fore_mask.shape[-3:] if is_3d else fore_mask.shape[-2:]
                supp_pred = F.interpolate(supp_pred, size=out_size, mode=interp_mode)

                # support ground truth: 1 = foreground, 0 = background, 255 = ignored
                supp_label = torch.full_like(fore_mask[way, shot], 255,
                                             device=img_fts.device).long()
                supp_label[fore_mask[way, shot] == 1] = 1
                supp_label[back_mask[way, shot] == 1] = 0
                loss.append(F.cross_entropy(
                    supp_pred.float(), supp_label[None, ...], ignore_index=255) / n_shots / n_ways)
        return torch.sum(torch.stack(loss))

    def dino_cls_loss(self, teacher_cls_tokens, student_cls_tokens):
        """
        DINO-style cross-entropy between Sinkhorn-Knopp-normalised teacher
        tokens and log-softmaxed student tokens, weighted by 0.1.

        Returns:
            Scalar: -0.1 * mean over rows of sum(teacher * log_softmax(student)).
        """
        cls_loss_weight = 0.1
        student_temp = 1
        teacher_probs = self.sinkhorn_knopp_teacher(teacher_cls_tokens)
        student_log_probs = F.log_softmax(student_cls_tokens / student_temp, dim=-1)
        cls_loss = torch.sum(teacher_probs * student_log_probs, dim=-1)
        return -cls_loss.mean() * cls_loss_weight

    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp=1, n_iterations=3):
        """
        Single-process Sinkhorn-Knopp normalisation of a 2D score matrix
        (rows = samples, columns = prototypes).

        Returns:
            Float tensor of the input's shape whose rows each sum to 1.
        """
        teacher_output = teacher_output.float()
        # Q is K-by-B (prototypes x samples)
        Q = torch.exp(teacher_output / teacher_temp).t()
        B = Q.shape[1]  # number of samples to assign
        K = Q.shape[0]  # how many prototypes
        # make the matrix sum to 1
        Q /= torch.sum(Q)
        for _ in range(n_iterations):
            # normalize each row: total weight per prototype must be 1/K
            Q /= torch.sum(Q, dim=1, keepdim=True)
            Q /= K
            # normalize each column: total weight per sample must be 1/B
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= B
        Q *= B  # the columns must sum to 1 so that Q is an assignment
        return Q.t()

    def dino_patch_loss(self, features, masked_features, masks):
        """
        DINO-style patch loss between teacher ``features`` and student
        ``masked_features`` over the masked patches of each sample.

        Args:
            features: teacher patch features, presumably B x P x C.
            masked_features: student patch features, same layout.
            masks: boolean B x P mask selecting the patches to compare.
        Returns:
            Scalar: -0.1 / B * sum over samples of
            sum(teacher * log_softmax(student)) / (number of masked patches).
            Samples with no masked patch contribute 0.
        """
        weight = 0.1
        B = features.shape[0]
        loss = features.new_zeros(())
        for f, mf, mask in zip(features, masked_features, masks):
            n_masked = mask.sum()
            if n_masked == 0:
                continue  # nothing to compare; avoids 0 / 0 = NaN
            teacher_probs = self.sinkhorn_knopp_teacher(f[mask])
            student_log_probs = F.log_softmax(mf[mask], dim=-1)  # student temperature 1
            # Reduce to a scalar per sample: samples with different numbers of
            # masked patches must not be broadcast against each other.
            loss = loss + torch.sum(teacher_probs * student_log_probs) / n_masked
        return -loss * weight / B