|
|
|
import numpy as np |
|
from typing import Any, List |
|
|
|
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY |
|
from detectron2.modeling.roi_heads.mask_head import MaskRCNNConvUpsampleHead, mask_rcnn_inference |
|
from detectron2.projects.point_rend import ImplicitPointRendMaskHead |
|
from detectron2.projects.point_rend.point_features import point_sample |
|
from detectron2.projects.point_rend.point_head import roi_mask_point_loss |
|
from detectron2.structures import Instances |
|
|
|
from .point_utils import get_point_coords_from_point_annotation |
|
|
|
# Public API of this module: the two point-supervised mask heads defined below.
__all__ = ["ImplicitPointRendPointSupHead", "MaskRCNNConvUpsamplePointSupHead"]
|
|
|
|
|
@ROI_MASK_HEAD_REGISTRY.register()
class MaskRCNNConvUpsamplePointSupHead(MaskRCNNConvUpsampleHead):
    """
    A mask head with several conv layers, plus an upsample layer (with `ConvTranspose2d`).
    Predictions are made with a final 1x1 conv layer.

    The difference with `MaskRCNNConvUpsampleHead` is that this head is trained
    with point supervision. Please use the `MaskRCNNConvUpsampleHead` if you want
    to train the model with mask supervision.
    """

    def forward(self, x, instances: List[Instances]) -> Any:
        """
        Run the mask head and either compute the point-supervised loss
        (training) or attach mask predictions to ``instances`` (inference).

        Args:
            x: input region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): contains the boxes & labels corresponding
                to the input features.
                Exact format is up to its caller to decide.
                Typically, this is the foreground instances in training, with
                "proposal_boxes" field and other gt annotations.
                In inference, it contains boxes that are already predicted.

        Returns:
            A dict of losses in training. The predicted "instances" in inference.
        """
        x = self.layers(x)

        if not self.training:
            # Inference: mask_rcnn_inference is called for its effect on
            # `instances`, which are then returned to the caller.
            mask_rcnn_inference(x, instances)
            return instances

        # The head output is unpacked as a 4-D tensor with a square spatial map.
        num_rois, _, height, width = x.shape
        assert height == width

        proposal_boxes = [inst.proposal_boxes for inst in instances]
        # Use the builtin `sum`: passing a generator to `np.sum` is deprecated
        # in NumPy (it only works through a fallback that emits a
        # DeprecationWarning and is documented to change behavior).
        assert num_rois == sum(len(boxes) for boxes in proposal_boxes)

        if num_rois == 0:
            # No regions to supervise: return a zero-valued loss computed from
            # `x` (presumably so the loss stays tied to the head output -- TODO confirm).
            return {"loss_mask": x.sum() * 0}

        # Training with point supervision: sample the predicted logits at the
        # annotated point locations and compare them with the point labels.
        point_coords, point_labels = get_point_coords_from_point_annotation(instances)
        mask_logits = point_sample(x, point_coords, align_corners=False)

        return {"loss_mask": roi_mask_point_loss(mask_logits, instances, point_labels)}
|
|
|
|
|
@ROI_MASK_HEAD_REGISTRY.register()
class ImplicitPointRendPointSupHead(ImplicitPointRendMaskHead):
    """
    An `ImplicitPointRendMaskHead` whose training points are taken from the
    point annotations of the given instances.
    """

    def _uniform_sample_train_points(self, instances):
        """
        Return the training point coordinates and labels for ``instances``.

        Both values come straight from
        ``get_point_coords_from_point_annotation``. Must only be called in
        training mode.

        Args:
            instances: the instances passed on to
                ``get_point_coords_from_point_annotation``.

        Returns:
            A ``(point_coords, point_labels)`` tuple.
        """
        assert self.training

        coords, labels = get_point_coords_from_point_annotation(instances)
        return coords, labels
|
|