import numpy as np
from typing import Callable, Dict, List, Union

import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.layers import Conv2d, DepthwiseSeparableConv2d, ShapeSpec, get_norm
from detectron2.modeling import (
    META_ARCH_REGISTRY,
    SEM_SEG_HEADS_REGISTRY,
    build_backbone,
    build_sem_seg_head,
)
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.projects.deeplab import DeepLabV3PlusHead
from detectron2.projects.deeplab.loss import DeepLabCE
from detectron2.structures import BitMasks, ImageList, Instances
from detectron2.utils.registry import Registry

from .post_processing import get_panoptic_segmentation

__all__ = ["PanopticDeepLab", "INS_EMBED_BRANCHES_REGISTRY", "build_ins_embed_branch"]

INS_EMBED_BRANCHES_REGISTRY = Registry("INS_EMBED_BRANCHES")
INS_EMBED_BRANCHES_REGISTRY.__doc__ = """
Registry for instance embedding branches, which make instance embedding
predictions from feature maps. The registered object will be called with
`obj(cfg, input_shape)`.
"""


@META_ARCH_REGISTRY.register()
class PanopticDeepLab(nn.Module):
    """
    Main class for panoptic segmentation architectures.
    """

    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        self.sem_seg_head = build_sem_seg_head(cfg, self.backbone.output_shape())
        self.ins_embed_head = build_ins_embed_branch(cfg, self.backbone.output_shape())
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1), False)
        self.meta = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
        # Parameters for panoptic post-processing (see get_panoptic_segmentation).
        self.stuff_area = cfg.MODEL.PANOPTIC_DEEPLAB.STUFF_AREA
        self.threshold = cfg.MODEL.PANOPTIC_DEEPLAB.CENTER_THRESHOLD
        self.nms_kernel = cfg.MODEL.PANOPTIC_DEEPLAB.NMS_KERNEL
        self.top_k = cfg.MODEL.PANOPTIC_DEEPLAB.TOP_K_INSTANCE
        self.predict_instances = cfg.MODEL.PANOPTIC_DEEPLAB.PREDICT_INSTANCES
        self.use_depthwise_separable_conv = cfg.MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV
        assert (
            cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV
            == cfg.MODEL.PANOPTIC_DEEPLAB.USE_DEPTHWISE_SEPARABLE_CONV
        ), "Depthwise-separable conv must be enabled or disabled for both heads."
        self.size_divisibility = cfg.MODEL.PANOPTIC_DEEPLAB.SIZE_DIVISIBILITY
        self.benchmark_network_speed = cfg.MODEL.PANOPTIC_DEEPLAB.BENCHMARK_NETWORK_SPEED

    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * "image": Tensor, image in (C, H, W) format.
                * "sem_seg": semantic segmentation ground truth.
                * "center": center points heatmap ground truth.
                * "offset": pixel offsets to center points ground truth.
                * Other information that's included in the original dicts, such as
                  "height", "width" (int): the output resolution of the model (may be
                  different from input resolution), used in inference.

        Returns:
            list[dict]:
                each dict is the results for one image. The dict contains the following keys:

                * "panoptic_seg", "sem_seg": see documentation
                  :doc:`/tutorials/models` for the standard output format.
                * "instances": available if ``predict_instances is True``. see documentation
                  :doc:`/tutorials/models` for the standard output format.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        # Normalize with the per-channel mean/std registered as buffers in __init__.
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]

        size_divisibility = (
            self.size_divisibility
            if self.size_divisibility > 0
            else self.backbone.size_divisibility
        )
        images = ImageList.from_tensors(images, size_divisibility)

        features = self.backbone(images.tensor)

        losses = {}
        if "sem_seg" in batched_inputs[0]:
            targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
            targets = ImageList.from_tensors(
                targets, size_divisibility, self.sem_seg_head.ignore_value
            ).tensor
            if "sem_seg_weights" in batched_inputs[0]:
                # Per-pixel loss weights are optional: the default DatasetMapper
                # does not provide them.
                weights = [x["sem_seg_weights"].to(self.device) for x in batched_inputs]
                weights = ImageList.from_tensors(weights, size_divisibility).tensor
            else:
                weights = None
        else:
            targets = None
            weights = None
        sem_seg_results, sem_seg_losses = self.sem_seg_head(features, targets, weights)
        losses.update(sem_seg_losses)
if "center" in batched_inputs[0] and "offset" in batched_inputs[0]: |
|
center_targets = [x["center"].to(self.device) for x in batched_inputs] |
|
center_targets = ImageList.from_tensors( |
|
center_targets, size_divisibility |
|
).tensor.unsqueeze(1) |
|
center_weights = [x["center_weights"].to(self.device) for x in batched_inputs] |
|
center_weights = ImageList.from_tensors(center_weights, size_divisibility).tensor |
|
|
|
offset_targets = [x["offset"].to(self.device) for x in batched_inputs] |
|
offset_targets = ImageList.from_tensors(offset_targets, size_divisibility).tensor |
|
offset_weights = [x["offset_weights"].to(self.device) for x in batched_inputs] |
|
offset_weights = ImageList.from_tensors(offset_weights, size_divisibility).tensor |
|
else: |
|
center_targets = None |
|
center_weights = None |
|
|
|
offset_targets = None |
|
offset_weights = None |
|
|
|
center_results, offset_results, center_losses, offset_losses = self.ins_embed_head( |
|
features, center_targets, center_weights, offset_targets, offset_weights |
|
) |
|
losses.update(center_losses) |
|
losses.update(offset_losses) |
|
|
|
if self.training: |
|
return losses |
|
|
|
if self.benchmark_network_speed: |
|
return [] |
|
|
|

        processed_results = []
        for sem_seg_result, center_result, offset_result, input_per_image, image_size in zip(
            sem_seg_results, center_results, offset_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
            c = sem_seg_postprocess(center_result, image_size, height, width)
            o = sem_seg_postprocess(offset_result, image_size, height, width)
            # Fuse semantic, center and offset predictions into a panoptic label
            # map that encodes class_id * label_divisor + instance_id per pixel.
            panoptic_image, _ = get_panoptic_segmentation(
                r.argmax(dim=0, keepdim=True),
                c,
                o,
                thing_ids=self.meta.thing_dataset_id_to_contiguous_id.values(),
                label_divisor=self.meta.label_divisor,
                stuff_area=self.stuff_area,
                void_label=-1,
                threshold=self.threshold,
                nms_kernel=self.nms_kernel,
                top_k=self.top_k,
            )
            # For semantic segmentation evaluation.
            processed_results.append({"sem_seg": r})
            panoptic_image = panoptic_image.squeeze(0)
            semantic_prob = F.softmax(r, dim=0)
            # For panoptic segmentation evaluation.
            processed_results[-1]["panoptic_seg"] = (panoptic_image, None)
            # For instance segmentation evaluation.
            if self.predict_instances:
                instances = []
                panoptic_image_cpu = panoptic_image.cpu().numpy()
                for panoptic_label in np.unique(panoptic_image_cpu):
                    # Skip the void label.
                    if panoptic_label == -1:
                        continue
                    pred_class = panoptic_label // self.meta.label_divisor
                    isthing = pred_class in list(
                        self.meta.thing_dataset_id_to_contiguous_id.values()
                    )
                    # Only "thing" classes produce instances.
                    if isthing:
                        instance = Instances((height, width))
                        instance.pred_classes = torch.tensor(
                            [pred_class], device=panoptic_image.device
                        )
                        mask = panoptic_image == panoptic_label
                        instance.pred_masks = mask.unsqueeze(0)
                        # Average semantic probability over the instance mask.
                        sem_scores = semantic_prob[pred_class, ...]
                        sem_scores = torch.mean(sem_scores[mask])
                        # Center heatmap score at the mask centroid.
                        mask_indices = torch.nonzero(mask).float()
                        center_y, center_x = (
                            torch.mean(mask_indices[:, 0]),
                            torch.mean(mask_indices[:, 1]),
                        )
                        center_scores = c[0, int(center_y.item()), int(center_x.item())]
                        # Confidence is the product of semantic and center scores.
                        instance.scores = torch.tensor(
                            [sem_scores * center_scores], device=panoptic_image.device
                        )
                        # Derive boxes from the binary instance masks.
                        instance.pred_boxes = BitMasks(instance.pred_masks).get_bounding_boxes()
                        instances.append(instance)
                if len(instances) > 0:
                    processed_results[-1]["instances"] = Instances.cat(instances)

        return processed_results
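
# A minimal inference sketch (assumes a ``cfg`` already populated with this
# project's Panoptic-DeepLab defaults and trained weights loaded elsewhere;
# ``img`` is a (C, H, W) tensor in the channel order the config expects):
#
#   model = PanopticDeepLab(cfg).eval()
#   with torch.no_grad():
#       outputs = model([{"image": img, "height": img.shape[1], "width": img.shape[2]}])
#   panoptic_seg, _ = outputs[0]["panoptic_seg"]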


@SEM_SEG_HEADS_REGISTRY.register()
class PanopticDeepLabSemSegHead(DeepLabV3PlusHead):
    """
    A semantic segmentation head described in :paper:`Panoptic-DeepLab`.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        decoder_channels: List[int],
        norm: Union[str, Callable],
        head_channels: int,
        loss_weight: float,
        loss_type: str,
        loss_top_k: float,
        ignore_value: int,
        num_classes: int,
        **kwargs,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (dict[str, ShapeSpec]): shapes of the input features.
            decoder_channels (list[int]): a list of output channels of each
                decoder stage. It should have the same length as "input_shape"
                (each element in "input_shape" corresponds to one decoder stage).
            norm (str or callable): normalization for all conv layers.
            head_channels (int): the output channels of extra convolutions
                between decoder and predictor.
            loss_weight (float): loss weight.
            loss_top_k (float): fraction of the hardest pixels kept by the
                "hard_pixel_mining" loss.
            loss_type, ignore_value, num_classes: the same as the base class.
        """
        super().__init__(
            input_shape,
            decoder_channels=decoder_channels,
            norm=norm,
            ignore_value=ignore_value,
            **kwargs,
        )
        assert self.decoder_only

        self.loss_weight = loss_weight
        use_bias = norm == ""

        if self.use_depthwise_separable_conv:
            # A single 5x5 DepthwiseSeparableConv2d replaces the two 3x3 Conv2d
            # layers below; both stacks have the same receptive field.
            self.head = DepthwiseSeparableConv2d(
                decoder_channels[0],
                head_channels,
                kernel_size=5,
                padding=2,
                norm1=norm,
                activation1=F.relu,
                norm2=norm,
                activation2=F.relu,
            )
        else:
            self.head = nn.Sequential(
                Conv2d(
                    decoder_channels[0],
                    decoder_channels[0],
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, decoder_channels[0]),
                    activation=F.relu,
                ),
                Conv2d(
                    decoder_channels[0],
                    head_channels,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, head_channels),
                    activation=F.relu,
                ),
            )
            weight_init.c2_xavier_fill(self.head[0])
            weight_init.c2_xavier_fill(self.head[1])
        self.predictor = Conv2d(head_channels, num_classes, kernel_size=1)
        nn.init.normal_(self.predictor.weight, 0, 0.001)
        nn.init.constant_(self.predictor.bias, 0)

        if loss_type == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(reduction="mean", ignore_index=ignore_value)
        elif loss_type == "hard_pixel_mining":
            self.loss = DeepLabCE(ignore_label=ignore_value, top_k_percent_pixels=loss_top_k)
        else:
            raise ValueError("Unexpected loss type: %s" % loss_type)

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg, input_shape)
        ret["head_channels"] = cfg.MODEL.SEM_SEG_HEAD.HEAD_CHANNELS
        ret["loss_top_k"] = cfg.MODEL.SEM_SEG_HEAD.LOSS_TOP_K
        return ret

    def forward(self, features, targets=None, weights=None):
        """
        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        y = self.layers(features)
        if self.training:
            return None, self.losses(y, targets, weights)
        else:
            # Upsample logits from the decoder stride back to input resolution.
            y = F.interpolate(
                y, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return y, {}

    def layers(self, features):
        assert self.decoder_only
        y = super().layers(features)
        y = self.head(y)
        y = self.predictor(y)
        return y

    def losses(self, predictions, targets, weights=None):
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        loss = self.loss(predictions, targets, weights)
        losses = {"loss_sem_seg": loss * self.loss_weight}
        return losses
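
# "hard_pixel_mining" averages the cross-entropy over only the hardest
# ``loss_top_k`` fraction of pixels. A rough sketch of that idea in plain
# PyTorch (an illustration of the technique, not the DeepLabCE implementation):
#
#   pixel_losses = F.cross_entropy(
#       logits, target, reduction="none", ignore_index=ignore_value
#   ).flatten()
#   k = int(loss_top_k * pixel_losses.numel())
#   loss = torch.topk(pixel_losses, k).values.mean()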


def build_ins_embed_branch(cfg, input_shape):
    """
    Build an instance embedding branch from `cfg.MODEL.INS_EMBED_HEAD.NAME`.
    """
    name = cfg.MODEL.INS_EMBED_HEAD.NAME
    return INS_EMBED_BRANCHES_REGISTRY.get(name)(cfg, input_shape)
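
# With this project's default config, cfg.MODEL.INS_EMBED_HEAD.NAME is
# "PanopticDeepLabInsEmbedHead", which resolves to the class below.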


@INS_EMBED_BRANCHES_REGISTRY.register()
class PanopticDeepLabInsEmbedHead(DeepLabV3PlusHead):
    """
    An instance embedding head described in :paper:`Panoptic-DeepLab`.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        decoder_channels: List[int],
        norm: Union[str, Callable],
        head_channels: int,
        center_loss_weight: float,
        offset_loss_weight: float,
        **kwargs,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (dict[str, ShapeSpec]): shapes of the input features.
            decoder_channels (list[int]): a list of output channels of each
                decoder stage. It should have the same length as "input_shape"
                (each element in "input_shape" corresponds to one decoder stage).
            norm (str or callable): normalization for all conv layers.
            head_channels (int): the output channels of extra convolutions
                between decoder and predictor.
            center_loss_weight (float): loss weight for center point prediction.
            offset_loss_weight (float): loss weight for center offset prediction.
        """
        super().__init__(input_shape, decoder_channels=decoder_channels, norm=norm, **kwargs)
        assert self.decoder_only

        self.center_loss_weight = center_loss_weight
        self.offset_loss_weight = offset_loss_weight
        use_bias = norm == ""

        # Center prediction branch: two 3x3 convs, then a 1x1 predictor that
        # outputs a single-channel center heatmap.
        self.center_head = nn.Sequential(
            Conv2d(
                decoder_channels[0],
                decoder_channels[0],
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm=get_norm(norm, decoder_channels[0]),
                activation=F.relu,
            ),
            Conv2d(
                decoder_channels[0],
                head_channels,
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm=get_norm(norm, head_channels),
                activation=F.relu,
            ),
        )
        weight_init.c2_xavier_fill(self.center_head[0])
        weight_init.c2_xavier_fill(self.center_head[1])
        self.center_predictor = Conv2d(head_channels, 1, kernel_size=1)
        nn.init.normal_(self.center_predictor.weight, 0, 0.001)
        nn.init.constant_(self.center_predictor.bias, 0)

        # Offset prediction branch: outputs a 2-channel (dy, dx) offset map.
        if self.use_depthwise_separable_conv:
            # A single 5x5 DepthwiseSeparableConv2d replaces the two 3x3 Conv2d
            # layers below; both stacks have the same receptive field.
            self.offset_head = DepthwiseSeparableConv2d(
                decoder_channels[0],
                head_channels,
                kernel_size=5,
                padding=2,
                norm1=norm,
                activation1=F.relu,
                norm2=norm,
                activation2=F.relu,
            )
        else:
            self.offset_head = nn.Sequential(
                Conv2d(
                    decoder_channels[0],
                    decoder_channels[0],
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, decoder_channels[0]),
                    activation=F.relu,
                ),
                Conv2d(
                    decoder_channels[0],
                    head_channels,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                    norm=get_norm(norm, head_channels),
                    activation=F.relu,
                ),
            )
            weight_init.c2_xavier_fill(self.offset_head[0])
            weight_init.c2_xavier_fill(self.offset_head[1])
        self.offset_predictor = Conv2d(head_channels, 2, kernel_size=1)
        nn.init.normal_(self.offset_predictor.weight, 0, 0.001)
        nn.init.constant_(self.offset_predictor.bias, 0)

        self.center_loss = nn.MSELoss(reduction="none")
        self.offset_loss = nn.L1Loss(reduction="none")

    @classmethod
    def from_config(cls, cfg, input_shape):
        if cfg.INPUT.CROP.ENABLED:
            assert cfg.INPUT.CROP.TYPE == "absolute"
            train_size = cfg.INPUT.CROP.SIZE
        else:
            train_size = None
        decoder_channels = [cfg.MODEL.INS_EMBED_HEAD.CONVS_DIM] * (
            len(cfg.MODEL.INS_EMBED_HEAD.IN_FEATURES) - 1
        ) + [cfg.MODEL.INS_EMBED_HEAD.ASPP_CHANNELS]
        ret = dict(
            input_shape={
                k: v for k, v in input_shape.items() if k in cfg.MODEL.INS_EMBED_HEAD.IN_FEATURES
            },
            project_channels=cfg.MODEL.INS_EMBED_HEAD.PROJECT_CHANNELS,
            aspp_dilations=cfg.MODEL.INS_EMBED_HEAD.ASPP_DILATIONS,
            aspp_dropout=cfg.MODEL.INS_EMBED_HEAD.ASPP_DROPOUT,
            decoder_channels=decoder_channels,
            common_stride=cfg.MODEL.INS_EMBED_HEAD.COMMON_STRIDE,
            norm=cfg.MODEL.INS_EMBED_HEAD.NORM,
            train_size=train_size,
            head_channels=cfg.MODEL.INS_EMBED_HEAD.HEAD_CHANNELS,
            center_loss_weight=cfg.MODEL.INS_EMBED_HEAD.CENTER_LOSS_WEIGHT,
            offset_loss_weight=cfg.MODEL.INS_EMBED_HEAD.OFFSET_LOSS_WEIGHT,
            use_depthwise_separable_conv=cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV,
        )
        return ret

    def forward(
        self,
        features,
        center_targets=None,
        center_weights=None,
        offset_targets=None,
        offset_weights=None,
    ):
        """
        Returns:
            In training, returns (None, None, dict of center losses, dict of
            offset losses).
            In inference, returns (1xHxW center heatmap, 2xHxW offsets, {}, {}).
        """
        center, offset = self.layers(features)
        if self.training:
            return (
                None,
                None,
                self.center_losses(center, center_targets, center_weights),
                self.offset_losses(offset, offset_targets, offset_weights),
            )
        else:
            center = F.interpolate(
                center, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            # Offsets are measured in pixels, so their values are rescaled by
            # the same factor as the spatial upsampling.
            offset = (
                F.interpolate(
                    offset, scale_factor=self.common_stride, mode="bilinear", align_corners=False
                )
                * self.common_stride
            )
            return center, offset, {}, {}

    def layers(self, features):
        assert self.decoder_only
        # Both branches share the same decoder output.
        y = super().layers(features)

        center = self.center_head(y)
        center = self.center_predictor(center)

        offset = self.offset_head(y)
        offset = self.offset_predictor(offset)
        return center, offset

    def center_losses(self, predictions, targets, weights):
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        # Weighted MSE, averaged over pixels with non-zero weight.
        loss = self.center_loss(predictions, targets) * weights
        if weights.sum() > 0:
            loss = loss.sum() / weights.sum()
        else:
            loss = loss.sum() * 0
        losses = {"loss_center": loss * self.center_loss_weight}
        return losses

    def offset_losses(self, predictions, targets, weights):
        # Offsets are predicted at decoder stride but supervised in input-
        # resolution pixels, so the upsampled values are rescaled by the stride.
        predictions = (
            F.interpolate(
                predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            * self.common_stride
        )
        # Weighted L1, averaged over pixels with non-zero weight.
        loss = self.offset_loss(predictions, targets) * weights
        if weights.sum() > 0:
            loss = loss.sum() / weights.sum()
        else:
            loss = loss.sum() * 0
        losses = {"loss_offset": loss * self.offset_loss_weight}
        return losses
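
# Both reductions above are masked weighted means: with per-pixel losses ``l``
# and weights ``w``, the result is sum(l * w) / sum(w), so pixels with w == 0
# contribute nothing. A tiny worked sketch:
#
#   l = torch.tensor([1.0, 3.0, 5.0])
#   w = torch.tensor([1.0, 1.0, 0.0])
#   (l * w).sum() / w.sum()  # -> 2.0; the third pixel is masked out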