# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from typing import Callable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Scale
from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d
from mmengine.config import ConfigDict
from mmengine.model import BaseModel
from mmengine.structures import InstanceData
from torch import Tensor

try:
    from transformers import BertConfig
except ImportError:
    BertConfig = None

from mmdet.registry import MODELS
from mmdet.structures.bbox import cat_boxes
from mmdet.utils import InstanceList, OptInstanceList, reduce_mean
from ..utils import (BertEncoderLayer, VLFuse, filter_scores_and_topk,
                     permute_and_flatten, select_single_mlvl,
                     unpack_gt_instances)
from ..utils.vlfuse_helper import MAX_CLAMP_VALUE
from .atss_head import ATSSHead


def convert_grounding_to_cls_scores(logits: Tensor,
                                    positive_maps: List[dict]) -> Tensor:
    """Convert logits to class scores."""
    assert len(positive_maps) == logits.shape[0]  # batch size
    scores = torch.zeros(logits.shape[0], logits.shape[1],
                         len(positive_maps[0])).to(logits.device)
    if positive_maps is not None:
        if all(x == positive_maps[0] for x in positive_maps):
            # only need to compute once
            positive_map = positive_maps[0]
            for label_j in positive_map:
                scores[:, :, label_j - 1] = logits[:, :, torch.LongTensor(
                    positive_map[label_j])].mean(-1)
        else:
            for i, positive_map in enumerate(positive_maps):
                for label_j in positive_map:
                    scores[i, :, label_j - 1] = logits[
                        i, :, torch.LongTensor(positive_map[label_j])].mean(-1)
    return scores


class Conv3x3Norm(nn.Module):
    """Conv3x3 and norm."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int,
                 groups: int = 1,
                 use_dcn: bool = False,
                 norm_type: Optional[Union[Sequence, str]] = None):
        super().__init__()

        if use_dcn:
            self.conv = ModulatedDeformConv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=groups)
        else:
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=groups)

        if isinstance(norm_type, Sequence):
            assert len(norm_type) == 2
            assert norm_type[0] == 'gn'
            gn_group = norm_type[1]
            norm_type = norm_type[0]

        if norm_type == 'bn':
            bn_op = nn.BatchNorm2d(out_channels)
        elif norm_type == 'gn':
            bn_op = nn.GroupNorm(
                num_groups=gn_group, num_channels=out_channels)
        if norm_type is not None:
            self.bn = bn_op
        else:
            self.bn = None

    def forward(self, x, **kwargs):
        x = self.conv(x, **kwargs)
        if self.bn:
            x = self.bn(x)
        return x


class DyReLU(nn.Module):
    """Dynamic ReLU."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 expand_ratio: int = 4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.expand_ratio = expand_ratio
        self.out_channels = out_channels

        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // expand_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // expand_ratio,
                      out_channels * self.expand_ratio),
            nn.Hardsigmoid(inplace=True))

    def forward(self, x) -> Tensor:
        x_out = x
        b, c, h, w = x.size()
        x = self.avg_pool(x).view(b, c)
        x = self.fc(x).view(b, -1, 1, 1)

        a1, b1, a2, b2 = torch.split(x, self.out_channels, dim=1)
        a1 = (a1 - 0.5) * 2 + 1.0
        a2 = (a2 - 0.5) * 2
        b1 = b1 - 0.5
        b2 = b2 - 0.5
        out = torch.max(x_out * a1 + b1, x_out * a2 + b2)
        return out


class DyConv(nn.Module):
    """Dynamic Convolution."""

    def __init__(self,
                 conv_func: Callable,
                 in_channels: int,
                 out_channels: int,
                 use_dyfuse: bool = True,
                 use_dyrelu: bool = False,
                 use_dcn: bool = False):
        super().__init__()
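        # dyconvs[1] operates on the current pyramid level, dyconvs[2]
        # (stride 2) downsamples the finer level below it, and dyconvs[0]
        # processes the coarser level before it is bilinearly upsampled in
        # forward().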
        self.dyconvs = nn.ModuleList()
        self.dyconvs.append(conv_func(in_channels, out_channels, 1))
        self.dyconvs.append(conv_func(in_channels, out_channels, 1))
        self.dyconvs.append(conv_func(in_channels, out_channels, 2))

        if use_dyfuse:
            self.attnconv = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels, 1, kernel_size=1),
                nn.ReLU(inplace=True))
            self.h_sigmoid = nn.Hardsigmoid(inplace=True)
        else:
            self.attnconv = None

        if use_dyrelu:
            self.relu = DyReLU(in_channels, out_channels)
        else:
            self.relu = nn.ReLU()

        if use_dcn:
            self.offset = nn.Conv2d(
                in_channels, 27, kernel_size=3, stride=1, padding=1)
        else:
            self.offset = None

        self.init_weights()

    def init_weights(self):
        for m in self.dyconvs.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, 0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
        if self.attnconv is not None:
            for m in self.attnconv.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight.data, 0, 0.01)
                    if m.bias is not None:
                        m.bias.data.zero_()

    def forward(self, inputs: dict) -> dict:
        visual_feats = inputs['visual']

        out_vis_feats = []
        for level, feature in enumerate(visual_feats):

            offset_conv_args = {}
            if self.offset is not None:
                offset_mask = self.offset(feature)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, 18:, :, :].sigmoid()
                offset_conv_args = dict(offset=offset, mask=mask)

            temp_feats = [self.dyconvs[1](feature, **offset_conv_args)]

            if level > 0:
                temp_feats.append(self.dyconvs[2](visual_feats[level - 1],
                                                  **offset_conv_args))
            if level < len(visual_feats) - 1:
                temp_feats.append(
                    F.upsample_bilinear(
                        self.dyconvs[0](visual_feats[level + 1],
                                        **offset_conv_args),
                        size=[feature.size(2), feature.size(3)]))
            mean_feats = torch.mean(
                torch.stack(temp_feats), dim=0, keepdim=False)

            if self.attnconv is not None:
                attn_feat = []
                res_feat = []
                for feat in temp_feats:
                    res_feat.append(feat)
                    attn_feat.append(self.attnconv(feat))

                res_feat = torch.stack(res_feat)
                spa_pyr_attn = self.h_sigmoid(torch.stack(attn_feat))

                mean_feats = torch.mean(
                    res_feat * spa_pyr_attn, dim=0, keepdim=False)

            out_vis_feats.append(mean_feats)

        out_vis_feats = [self.relu(item) for item in out_vis_feats]

        features_dict = {'visual': out_vis_feats, 'lang': inputs['lang']}

        return features_dict


class VLFusionModule(BaseModel):
    """Visual-lang Fusion Module."""

    def __init__(self,
                 in_channels: int,
                 feat_channels: int,
                 num_base_priors: int,
                 early_fuse: bool = False,
                 num_dyhead_blocks: int = 6,
                 lang_model_name: str = 'bert-base-uncased',
                 use_dyrelu: bool = True,
                 use_dyfuse: bool = True,
                 use_dcn: bool = True,
                 use_checkpoint: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        if BertConfig is None:
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.num_base_priors = num_base_priors
        self.early_fuse = early_fuse
        self.num_dyhead_blocks = num_dyhead_blocks
        self.use_dyrelu = use_dyrelu
        self.use_dyfuse = use_dyfuse
        self.use_dcn = use_dcn
        self.use_checkpoint = use_checkpoint

        self.lang_cfg = BertConfig.from_pretrained(lang_model_name)
        self.lang_dim = self.lang_cfg.hidden_size
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the model."""
        bias_value = -math.log((1 - 0.01) / 0.01)

        dyhead_tower = []
        for i in range(self.num_dyhead_blocks):
            if self.early_fuse:
                # cross-modality fusion
                dyhead_tower.append(VLFuse(use_checkpoint=self.use_checkpoint))
                # lang branch
                dyhead_tower.append(
                    BertEncoderLayer(
                        self.lang_cfg,
                        clamp_min_for_underflow=True,
                        clamp_max_for_overflow=True))

            # vision branch
            dyhead_tower.append(
                DyConv(
                    lambda i, o, s: Conv3x3Norm(
                        i, o, s, use_dcn=self.use_dcn, norm_type=['gn', 16]),
                    self.in_channels if i == 0 else self.feat_channels,
                    self.feat_channels,
                    use_dyrelu=(self.use_dyrelu
                                and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dyrelu,
                    use_dyfuse=(self.use_dyfuse
                                and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dyfuse,
                    use_dcn=(self.use_dcn
                             and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dcn,
                ))

        self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower))

        self.bbox_pred = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 4, kernel_size=1)
        self.centerness = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 1, kernel_size=1)
        self.dot_product_projection_text = nn.Linear(
            self.lang_dim,
            self.num_base_priors * self.feat_channels,
            bias=True)
        self.log_scale = nn.Parameter(torch.Tensor([0.0]), requires_grad=True)
        self.bias_lang = nn.Parameter(
            torch.zeros(self.lang_dim), requires_grad=True)
        self.bias0 = nn.Parameter(
            torch.Tensor([bias_value]), requires_grad=True)
        self.scales = nn.ModuleList([Scale(1.0) for _ in range(5)])

    def forward(self, visual_feats: Tuple[Tensor],
                language_feats: dict) -> Tuple:
        feat_inputs = {'visual': visual_feats, 'lang': language_feats}
        dyhead_tower = self.dyhead_tower(feat_inputs)

        if self.early_fuse:
            embedding = dyhead_tower['lang']['hidden']
        else:
            embedding = language_feats['embedded']

        embedding = F.normalize(embedding, p=2, dim=-1)
        dot_product_proj_tokens = self.dot_product_projection_text(
            embedding / 2.0)
        dot_product_proj_tokens_bias = torch.matmul(
            embedding, self.bias_lang) + self.bias0

        bbox_preds = []
        centerness = []
        cls_logits = []

        for i, feature in enumerate(visual_feats):
            visual = dyhead_tower['visual'][i]
            B, C, H, W = visual.shape
            bbox_pred = self.scales[i](self.bbox_pred(visual))
            bbox_preds.append(bbox_pred)
            centerness.append(self.centerness(visual))

            dot_product_proj_queries = permute_and_flatten(
                visual, B, self.num_base_priors, C, H, W)

            bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(
                1, self.num_base_priors, 1)
            dot_product_logit = (
                torch.matmul(dot_product_proj_queries,
                             dot_product_proj_tokens.transpose(-1, -2)) /
                self.log_scale.exp()) + bias
            dot_product_logit = torch.clamp(
                dot_product_logit, max=MAX_CLAMP_VALUE)
            dot_product_logit = torch.clamp(
                dot_product_logit, min=-MAX_CLAMP_VALUE)
            cls_logits.append(dot_product_logit)

        return bbox_preds, centerness, cls_logits


@MODELS.register_module()
class ATSSVLFusionHead(ATSSHead):
    """ATSS head with visual-language fusion module.

    Args:
        early_fuse (bool): Whether to fuse visual and language features.
            Defaults to False.
        use_checkpoint (bool): Whether to use checkpoint. Defaults to False.
        num_dyhead_blocks (int): Number of dynamic head blocks. Defaults to 6.
        lang_model_name (str): Name of the language model.
            Defaults to 'bert-base-uncased'.
    """

    def __init__(self,
                 *args,
                 early_fuse: bool = False,
                 use_checkpoint: bool = False,
                 num_dyhead_blocks: int = 6,
                 lang_model_name: str = 'bert-base-uncased',
                 init_cfg=None,
                 **kwargs):
        super().__init__(*args, **kwargs, init_cfg=init_cfg)
        self.head = VLFusionModule(
            in_channels=self.in_channels,
            feat_channels=self.feat_channels,
            num_base_priors=self.num_base_priors,
            early_fuse=early_fuse,
            use_checkpoint=use_checkpoint,
            num_dyhead_blocks=num_dyhead_blocks,
            lang_model_name=lang_model_name)
        self.text_masks = None

    def _init_layers(self) -> None:
        """No need to initialize the ATSS head layer."""
        pass

    def forward(self, visual_feats: Tuple[Tensor],
                language_feats: dict) -> Tuple[Tensor]:
        """Forward function."""
        bbox_preds, centerness, cls_logits = self.head(visual_feats,
                                                       language_feats)
        return cls_logits, bbox_preds, centerness

    def loss(self, visual_feats: Tuple[Tensor], language_feats: dict,
             batch_data_samples):
        """Calculate losses from a batch of inputs and data samples."""
        outputs = unpack_gt_instances(batch_data_samples)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = outputs

        outs = self(visual_feats, language_feats)
        self.text_masks = language_feats['masks']
        loss_inputs = outs + (batch_gt_instances, batch_img_metas,
                              batch_gt_instances_ignore)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            centernesses: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            centernesses (list[Tensor]): Centerness for each scale level with
                shape (N, num_anchors * 1, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)

        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, avg_factor) = cls_reg_targets
        avg_factor = reduce_mean(
            torch.tensor(avg_factor, dtype=torch.float, device=device)).item()

        anchors = torch.cat(anchor_list, dim=1)
        labels = torch.cat(labels_list, dim=1)
        label_weights = torch.cat(label_weights_list, dim=1)
        bbox_targets = torch.cat(bbox_targets_list, dim=1)
        cls_scores = torch.cat(cls_scores, dim=1)

        centernesses_ = []
        bbox_preds_ = []
        for bbox_pred, centerness in zip(bbox_preds, centernesses):
            centernesses_.append(
                centerness.permute(0, 2, 3, 1).reshape(
                    cls_scores.size(0), -1, 1))
            bbox_preds_.append(
                bbox_pred.permute(0, 2, 3, 1).reshape(
                    cls_scores.size(0), -1, 4))
        bbox_preds = torch.cat(bbox_preds_, dim=1)
        centernesses = torch.cat(centernesses_, dim=1)

        losses_cls, losses_bbox, loss_centerness, bbox_avg_factor = \
            self._loss_by_feat(
                anchors,
                cls_scores,
                bbox_preds,
                centernesses,
                labels,
                label_weights,
                bbox_targets,
                avg_factor=avg_factor)

        bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item()
        losses_bbox = losses_bbox / bbox_avg_factor
        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_centerness=loss_centerness)

    def _loss_by_feat(self, anchors: Tensor, cls_score: Tensor,
                      bbox_pred: Tensor, centerness: Tensor, labels: Tensor,
                      label_weights: Tensor, bbox_targets: Tensor,
                      avg_factor: float) -> dict:
        """Calculate the loss of all scale levels based on the features
        extracted by the detection head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        anchors = anchors.reshape(-1, 4)
        # ===== this change =====
        pos_inds = (labels.sum(-1) > 0).reshape(-1)

        # Loss is not computed for the padded regions of the text.
        assert (self.text_masks.dim() == 2)
        text_mask = (self.text_masks > 0).unsqueeze(1)
        text_mask = text_mask.repeat(1, cls_score.size(1), 1)
        cls_score = torch.masked_select(cls_score, text_mask).contiguous()
        labels = torch.masked_select(labels, text_mask)
        label_weights = label_weights[..., None].repeat(
            1, 1, text_mask.size(-1))
        label_weights = torch.masked_select(label_weights, text_mask)

        bbox_pred = bbox_pred.reshape(-1, 4)
        centerness = centerness.reshape(-1)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # classification loss
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=avg_factor)

        if pos_inds.sum() > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            pos_centerness = centerness[pos_inds]

            centerness_targets = self.centerness_target(
                pos_anchors, pos_bbox_targets)

            if torch.isnan(centerness_targets).any():
                print('=====Centerness includes NaN=====')
                mask = ~torch.isnan(centerness_targets)
                centerness_targets = centerness_targets[mask]
                pos_centerness = pos_centerness[mask]
                pos_anchors = pos_anchors[mask]
                pos_bbox_targets = pos_bbox_targets[mask]
                pos_bbox_pred = pos_bbox_pred[mask]

                if pos_bbox_targets.shape[0] == 0:
                    loss_bbox = bbox_pred.sum() * 0
                    loss_centerness = centerness.sum() * 0
                    centerness_targets = bbox_targets.new_tensor(0.)
                    # Every positive sample was dropped by the NaN filter
                    # above, so return zero-valued bbox and centerness losses
                    # that keep the computation graph well defined.
                    return loss_cls, loss_bbox, loss_centerness, \
                        centerness_targets.sum()

            # The decoding process takes the offset into consideration.
            pos_anchors[:, 2:] += 1
            pos_decode_bbox_pred = self.bbox_coder.decode(
                pos_anchors, pos_bbox_pred)

            # regression loss
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_bbox_targets,
                weight=centerness_targets,
                avg_factor=1.0)

            # centerness loss
            loss_centerness = self.loss_centerness(
                pos_centerness, centerness_targets, avg_factor=avg_factor)
        else:
            loss_bbox = bbox_pred.sum() * 0
            loss_centerness = centerness.sum() * 0
            centerness_targets = bbox_targets.new_tensor(0.)

        return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()

    def _get_targets_single(self,
                            flat_anchors: Tensor,
                            valid_flags: Tensor,
                            num_level_anchors: List[int],
                            gt_instances: InstanceData,
                            img_meta: dict,
                            gt_instances_ignore: Optional[InstanceData] = None,
                            unmap_outputs: bool = True) -> tuple:
        """Compute regression, classification targets for anchors in a single
        image.

        Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors, 4).
            valid_flags (Tensor): Multi-level valid flags of the image, which
                are concatenated into a single tensor of shape (num_anchors,).
            num_level_anchors (List[int]): Number of anchors of each scale
                level.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: N is the number of total anchors in the image.

                - labels (Tensor): Labels of all anchors in the image with
                  shape (N,).
                - label_weights (Tensor): Label weights of all anchors in the
                  image with shape (N,).
                - bbox_targets (Tensor): BBox targets of all anchors in the
                  image with shape (N, 4).
                - bbox_weights (Tensor): BBox weights of all anchors in the
                  image with shape (N, 4).
                - pos_inds (Tensor): Indices of positive anchors with shape
                  (num_pos,).
                - neg_inds (Tensor): Indices of negative anchors with shape
                  (num_neg,).
                - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        anchors = flat_anchors
        # Align the official implementation
        anchors[:, 2:] -= 1
        num_level_anchors_inside = num_level_anchors
        pred_instances = InstanceData(priors=anchors)
        assign_result = self.assigner.assign(pred_instances,
                                             num_level_anchors_inside,
                                             gt_instances,
                                             gt_instances_ignore)

        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)
        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)

        # ===== this change =====
        labels = anchors.new_full((num_valid_anchors, self.feat_channels),
                                  0,
                                  dtype=torch.float32)
        label_weights = anchors.new_zeros(
            num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if self.reg_decoded_bbox:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            else:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_priors, sampling_result.pos_gt_bboxes)

            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            # ===== this change =====
            labels[pos_inds] = gt_instances.positive_maps[
                sampling_result.pos_assigned_gt_inds]

            if self.train_cfg['pos_weight'] <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg['pos_weight']
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        return (anchors, labels, label_weights, bbox_targets, bbox_weights,
                pos_inds, neg_inds, sampling_result)

    def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor:
        """Calculate the centerness between anchors and gts.

        Only calculate pos centerness targets, otherwise there may be nan.

        Args:
            anchors (Tensor): Anchors with shape (N, 4), "xyxy" format.
            gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy"
                format.

        Returns:
            Tensor: Centerness between anchors and gts.
        """
        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
        l_ = anchors_cx - gts[:, 0]
        t_ = anchors_cy - gts[:, 1]
        r_ = gts[:, 2] - anchors_cx
        b_ = gts[:, 3] - anchors_cy

        left_right = torch.stack([l_, r_], dim=1)
        top_bottom = torch.stack([t_, b_], dim=1)
        centerness = torch.sqrt(
            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
            (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
        # assert not torch.isnan(centerness).any()
        return centerness

    def predict(self,
                visual_feats: Tuple[Tensor],
                language_feats: dict,
                batch_data_samples,
                rescale: bool = True):
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            visual_feats (tuple[Tensor]): Multi-level visual features from the
                upstream network, each is a 4D-tensor.
            language_feats (dict): Language features from the upstream
                network.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        batch_token_positive_maps = [
            data_samples.token_positive_map
            for data_samples in batch_data_samples
        ]
        outs = self(visual_feats, language_feats)

        predictions = self.predict_by_feat(
            *outs,
            batch_img_metas=batch_img_metas,
            batch_token_positive_maps=batch_token_positive_maps,
            rescale=rescale)
        return predictions

    def predict_by_feat(self,
                        cls_logits: List[Tensor],
                        bbox_preds: List[Tensor],
                        score_factors: List[Tensor],
                        batch_img_metas: Optional[List[dict]] = None,
                        batch_token_positive_maps: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Note: When score_factors is not None, the cls_scores are usually
        multiplied by it to obtain the real scores used in NMS, such as the
        centerness in FCOS or the IoU branch in ATSS.

        Args:
            cls_logits (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            score_factors (list[Tensor], optional): Score factors for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Defaults to None.
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Defaults to None.
            batch_token_positive_maps (list[dict], Optional): Batch token
                positive map. Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(bbox_preds) == len(score_factors)
        num_levels = len(bbox_preds)

        featmap_sizes = [bbox_preds[i].shape[-2:] for i in range(num_levels)]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)

        result_list = []

        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            token_positive_maps = batch_token_positive_maps[img_id]
            bbox_pred_list = select_single_mlvl(
                bbox_preds, img_id, detach=True)
            score_factor_list = select_single_mlvl(
                score_factors, img_id, detach=True)
            cls_logit_list = select_single_mlvl(
                cls_logits, img_id, detach=True)

            results = self._predict_by_feat_single(
                bbox_pred_list=bbox_pred_list,
                score_factor_list=score_factor_list,
                cls_logit_list=cls_logit_list,
                mlvl_priors=mlvl_priors,
                token_positive_maps=token_positive_maps,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms)
            result_list.append(results)
        return result_list

    def _predict_by_feat_single(self,
                                bbox_pred_list: List[Tensor],
                                score_factor_list: List[Tensor],
                                cls_logit_list: List[Tensor],
                                mlvl_priors: List[Tensor],
                                token_positive_maps: dict,
                                img_meta: dict,
                                cfg: ConfigDict,
                                rescale: bool = True,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image, each item has shape
                (num_priors * 1, H, W).
            cls_logit_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid. In all
                anchor-based methods, it has shape (num_priors, 4). In
                all anchor-free methods, it has shape (num_priors, 2)
                when `with_stride=True`, otherwise it still has shape
                (num_priors, 4).
            token_positive_maps (dict): Token positive map.
            img_meta (dict): Image meta info.
            cfg (mmengine.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image after the
            post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)
        score_thr = cfg.get('score_thr', 0)

        mlvl_bbox_preds = []
        mlvl_valid_priors = []
        mlvl_scores = []
        mlvl_labels = []

        for level_idx, (bbox_pred, score_factor, cls_logit, priors) in \
                enumerate(zip(bbox_pred_list, score_factor_list,
                              cls_logit_list, mlvl_priors)):

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(
                -1, self.bbox_coder.encode_size)
            score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid()

            scores = convert_grounding_to_cls_scores(
                logits=cls_logit.sigmoid()[None],
                positive_maps=[token_positive_maps])[0]

            results = filter_scores_and_topk(
                scores, score_thr, nms_pre,
                dict(bbox_pred=bbox_pred, priors=priors))
            scores, labels, keep_idxs, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            priors = filtered_results['priors']
            score_factor = score_factor[keep_idxs]
            scores = torch.sqrt(scores * score_factor)

            mlvl_bbox_preds.append(bbox_pred)
            mlvl_valid_priors.append(priors)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

        bbox_pred = torch.cat(mlvl_bbox_preds)
        priors = cat_boxes(mlvl_valid_priors)
        bboxes = self.bbox_coder.decode(
            priors, bbox_pred, max_shape=img_shape)

        results = InstanceData()
        results.bboxes = bboxes
        results.scores = torch.cat(mlvl_scores)
        results.labels = torch.cat(mlvl_labels)

        predictions = self._bbox_post_process(
            results=results,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)

        if len(predictions) > 0:
            # Note: GLIP uses an offset bbox decoding convention; if 1 is not
            # added back here, the results do not align with the official mAP.
            predictions.bboxes[:, 2:] = predictions.bboxes[:, 2:] + 1
        return predictions
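

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the library API): how grounding
# logits are mapped to per-class scores via `convert_grounding_to_cls_scores`.
# The shapes and the positive-map layout below are made-up example values;
# real maps come from the dataset's token annotations.
if __name__ == '__main__':
    # 2 images, 3 priors, 6 text tokens; class 1 owns tokens [0, 1] and
    # class 2 owns token [3] (class ids are 1-based, token ids 0-based).
    dummy_logits = torch.rand(2, 3, 6)
    dummy_positive_maps = [{1: [0, 1], 2: [3]}] * 2
    demo_scores = convert_grounding_to_cls_scores(dummy_logits,
                                                  dummy_positive_maps)
    # one score per class, averaged over the tokens that class owns
    assert demo_scores.shape == (2, 3, 2)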