"""This file contains the loss functions for MaX-DeepLab models. |
|
|
|
Reference: |
|
MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers", |
|
CVPR 2021. https://arxiv.org/abs/2012.00759 |
|
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen. |
|
""" |
|
from typing import Text, Dict, Tuple, List |
|
|
|
import tensorflow as tf |
|
from deeplab2 import common |
|
from deeplab2 import config_pb2 |
|
from deeplab2.model import utils |
|
from deeplab2.model.loss import base_loss |
|
from deeplab2.model.loss import matchers_ops |
|
|
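# Constants used to pad or mask the Hungarian matching weights, so that the
# padded entries and the void ground truth masks are matched as intended.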
_MATCHING_NEGATIVE_CONSTANT = -999.0 |
|
_MATCHING_POSITIVE_CONSTANT = 999.0 |
|
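# A large negative constant added to masked-out logits, so that they receive
# (numerically) zero probability after a softmax.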
_SOFTMAX_MASKING_CONSTANT = -99999.0 |
|
|
|
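# Internal keys used to pass the dynamically generated ground truth,
# prediction, and weight tensors to the base loss functions.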
_GT_KEY = 'gt_key' |
|
_PRED_KEY = 'pred_key' |
|
_WEIGHT_KEY = 'weight_key' |
|
|
|
|
|
def _generate_mask_slot_semantic_one_hot( |
|
matched_mask_slot_indices: tf.Tensor, |
|
mask_gt_semantic_map: tf.Tensor, |
|
num_mask_slots: int, |
|
thing_stuff_class_ids: List[int]): |
|
"""Generates the ground truth for transformer_class_logits. |
|
|
|
This function generates a pseudo ground truth that we will use to train the |
|
transformer class head logits. The input tensors, matched_mask_slot_indices |
|
  and mask_gt_semantic_map, are obtained by (Hungarian) matching the ground
|
truth masks with the predicted masks. Note that this function generates the |
|
positive one hot encodings only, i.e., the void class is not included in the |
|
output tensor but will be generated outside the function. |
|
|
|
Args: |
|
matched_mask_slot_indices: An int32 tf.Tensor of shape [batch_size, |
|
num_ground_truth_masks] that encodes the matched mask slot id for each |
|
ground truth mask. |
|
mask_gt_semantic_map: An int32 tf.Tensor of shape [batch_size, |
|
num_ground_truth_masks] that encodes the semantic label for each ground |
|
truth mask. A padded mask (or void, or no object) will have the label -1. |
|
num_mask_slots: An integer, the number of mask slots for the MaX-DeepLab |
|
model. |
|
thing_stuff_class_ids: A list of integers of length [num_thing_classes + |
|
num_stuff_classes] that encodes the class IDs for all thing and stuff |
|
classes. It is a concatenation of the thing_class_ids list and the |
|
stuff_class_ids list. |
|
|
|
Returns: |
|
mask_slot_semantic_one_hot: An output tf.Tensor with shape [batch_size, |
|
num_mask_slots, num_thing_classes + num_stuff_classes]. |
|
""" |
|
semantic_map_shape = mask_gt_semantic_map.get_shape().as_list() |
|
batch_size = semantic_map_shape[0] |
|
num_ground_truth_masks = semantic_map_shape[-1] |
|
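  # Build the batch indices for each (batch, ground truth mask) pair, to be
  # used in tf.scatter_nd below.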
batch_indices = tf.expand_dims(tf.range(batch_size), axis=-1) |
|
batch_indices = tf.tile(batch_indices, [1, num_ground_truth_masks]) |
|
batch_indices = tf.reshape(batch_indices, [-1, 1]) |
|
matched_mask_slot_indices = tf.reshape(matched_mask_slot_indices, [-1, 1]) |
|
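  # Shift the semantic labels by one, so that the void label (-1) maps to
  # index 0, which is dropped by the gather at the end of this function.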
semantic_indices = tf.reshape(mask_gt_semantic_map, [-1, 1]) + 1 |
|
indices = tf.concat([batch_indices, |
|
matched_mask_slot_indices, |
|
semantic_indices], axis=-1) |
|
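  # Scatter 1.0 to the (batch, matched mask slot, shifted class) positions.
  # The class dimension has max(thing_stuff_class_ids) + 2 channels in order
  # to host the shifted labels, including the shifted void label at index 0.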
updates = tf.ones([batch_size * num_ground_truth_masks], dtype=tf.float32) |
|
mask_slot_semantic_one_hot = tf.scatter_nd( |
|
indices, updates, |
|
shape=[batch_size, num_mask_slots, max(thing_stuff_class_ids) + 2]) |
|
|
|
|
|
thing_stuff_tensor = tf.cast(thing_stuff_class_ids, tf.int32) |
|
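  # Gather the semantic channels in the order of thing_stuff_class_ids. The
  # +1 compensates for the earlier label shift and drops the void channel.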
mask_slot_semantic_one_hot = tf.gather(mask_slot_semantic_one_hot, |
|
thing_stuff_tensor + 1, axis=2) |
|
return mask_slot_semantic_one_hot |
|
|
|
|
|
def nonsquare_hungarian_matching( |
|
weights: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: |
|
"""Hungarian matching with arbitrary shape. |
|
|
|
  The matchers_ops.hungarian_matching supports only square weight matrices.
  This function generalizes the Hungarian matching to nonsquare cases by
  padding the weights to a square matrix and running the square matching.
  The property of Hungarian matching ensures that the solutions are
  equivalent for the padded square problem and the original nonsquare
  problem.
|
|
|
Args: |
|
weights: A [batch, shape1, shape2] float32 tf.Tensor. |
|
|
|
Returns: |
|
square_permutation: A [batch, max(shape1, shape2), max(shape1, shape2)] |
|
float32 tf.Tensor that is the permutation matrix that achieves the minimum |
|
      total weight. Note that a permutation matrix contains only values 0.0
      and 1.0, with each row and each column summing to 1.0.
|
nonsquare_permutation: A [batch, shape1, shape2] float32 tf.Tensor. The |
|
nonsquare part of the permutation matrix. |
|
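  Example (a minimal sketch with a hypothetical [1, 2, 3] weight tensor):

    weights = tf.constant([[[0., -1., 0.],
                            [-1., 0., 0.]]])
    _, nonsquare = nonsquare_hungarian_matching(weights)
    # nonsquare selects the two -1.0 entries, which achieve the minimum
    # total weight.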
""" |
|
_, height, width = weights.get_shape().as_list() |
|
max_height_width = max(height, width) |
|
|
|
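  # Pad the weights to a square matrix with a constant value. Since the
  # padded entries share one value, they do not change the optimal
  # assignment of the original rows and columns.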
weights = tf.pad(weights, |
|
[[0, 0], |
|
[0, max_height_width - height], |
|
[0, max_height_width - width]], |
|
constant_values=_MATCHING_NEGATIVE_CONSTANT) |
|
square_permutation = matchers_ops.hungarian_matching(weights) |
|
|
|
square_permutation = tf.cast(square_permutation, tf.float32) |
|
return square_permutation, square_permutation[:, :height, :width] |
|
|
|
|
|
def _mask_similarity(gt_mask: tf.Tensor, pred_mask: tf.Tensor, |
|
metric: str = 'dice') -> tf.Tensor: |
|
"""Computes mask similarity between gt_masks and pred_masks. |
|
|
|
Args: |
|
    gt_mask: A [batch, height * width, num_gt_masks] float32 tf.Tensor that
      contains only values 0.0 and 1.0. Each 1.0 indicates that the pixel
      belongs to the ground truth mask. Note that panoptic segmentation
      enforces that ground truth masks do not overlap.
    pred_mask: A [batch, height * width, num_pred_masks] float32 tf.Tensor
      with non-negative values. For each batch_id and pixel_id, the
      [num_pred_masks] vector encodes whether each pixel belongs to each
      mask. The sum of each vector is less than or equal to one.
|
metric: A string, the mask similarity metric that we will compute. Supports |
|
'dice' (default), 'iou', 'intersection_over_ground_truth', and |
|
'intersection_over_prediction'. |
|
|
|
Returns: |
|
mask_similarity: A float32 [batch, num_gt_masks, num_pred_masks] tf.Tensor |
|
that contains the mask similarity between all ground truth masks and all |
|
predicted masks. |
|
|
|
Raises: |
|
ValueError: If the mask similarity metric is not one of 'dice', 'iou', |
|
'intersection_over_ground_truth', or 'intersection_over_prediction'. |
|
""" |
|
denominator_epsilon = 1e-5 |
|
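  # intersection[b, i, j] is the (soft) number of pixels shared by ground
  # truth mask i and predicted mask j.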
intersection = tf.einsum('bpi,bpj->bij', gt_mask, pred_mask) |
|
if metric.lower() == 'dice': |
|
denominator = (tf.expand_dims(tf.reduce_sum(gt_mask, axis=1), axis=2) + |
|
tf.reduce_sum(pred_mask, axis=1, keepdims=True)) / 2 |
|
elif metric.lower() == 'iou': |
|
denominator = (tf.expand_dims(tf.reduce_sum(gt_mask, axis=1), axis=2) + |
|
tf.reduce_sum(pred_mask, axis=1, keepdims=True) - |
|
intersection) |
|
elif metric.lower() == 'intersection_over_ground_truth': |
|
denominator = tf.expand_dims(tf.reduce_sum(gt_mask, axis=1), axis=2) |
|
elif metric.lower() == 'intersection_over_prediction': |
|
denominator = tf.reduce_sum(pred_mask, axis=1, keepdims=True) |
|
else: |
|
raise ValueError('The mask similarity metric is not supported.') |
|
return intersection / (denominator + denominator_epsilon) |
|
|
|
|
|
class MaXDeepLabLoss(tf.keras.layers.Layer): |
|
"""This class contains code for MaX-DeepLab losses.""" |
|
|
|
def __init__(self, |
|
loss_options: config_pb2.LossOptions, |
|
ignore_label: int, |
|
thing_class_ids: Tuple[int], |
|
focal_loss_alpha: float = 0.75, |
|
instance_discrimination_temperature: float = 0.3): |
|
"""Initializes a MaX-DeepLab loss. |
|
|
|
This class supports PQ-style loss, mask id cross entropy loss, and instance |
|
discrimination loss, proposed in MaX-DeepLab. The PQ-style loss can be |
|
    further decomposed into a classification term and a mask dice term.
|
|
|
Reference: |
|
MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers", |
|
CVPR 2021. https://arxiv.org/abs/2012.00759 |
|
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen. |
|
|
|
Args: |
|
loss_options: Loss options as defined by config_pb2.LossOptions. |
|
ignore_label: An integer specifying the ignore label. |
|
thing_class_ids: A tuple of length [N] containing N thing indices. |
|
focal_loss_alpha: An optional float specifying the coefficient that |
|
weights between positive (matched) and negative (unmatched) masks in |
|
focal loss. The positives are weighted by alpha, while the negatives |
|
are weighted by (1. - alpha). Note that we do not use a focal loss |
|
gamma here, i.e., the gamma is set to zero which is equivalent to the |
|
        normal cross-entropy loss, except for the alpha weighting. Defaults to
|
0.75. |
|
instance_discrimination_temperature: An optional float specifying the |
|
temperature for the instance discrimination loss. |
|
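    Example (a minimal sketch with hypothetical config values):

      loss_layer = MaXDeepLabLoss(loss_options, ignore_label=255,
                                  thing_class_ids=(1, 2, 3))
      # Maps each enabled loss name to a [batch] tf.Tensor.
      loss_dict = loss_layer((y_true, y_pred))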
""" |
|
super(MaXDeepLabLoss, self).__init__(name='MaXDeepLabLoss') |
|
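    # The names of the loss terms that this layer returns. The PQ-style loss
    # is evaluated as two separate terms: a classification term and a mask
    # dice term.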
self.loss_terms = [] |
|
|
|
|
|
self._pq_style_loss_weight = 0.0 |
|
if loss_options.HasField(common.PQ_STYLE_LOSS): |
|
self._pq_style_loss_weight = loss_options.pq_style_loss.weight |
|
self.loss_terms.append(common.PQ_STYLE_LOSS_CLASS_TERM) |
|
self.loss_terms.append(common.PQ_STYLE_LOSS_MASK_DICE_TERM) |
|
|
|
|
|
self._mask_id_cross_entropy_loss_weight = 0.0 |
|
if loss_options.HasField(common.MASK_ID_CROSS_ENTROPY_LOSS): |
|
self._mask_id_cross_entropy_loss_weight = ( |
|
loss_options.mask_id_cross_entropy_loss.weight) |
|
self.loss_terms.append(common.MASK_ID_CROSS_ENTROPY_LOSS) |
|
|
|
|
|
self._instance_discrimination_loss_weight = 0.0 |
|
if loss_options.HasField(common.INSTANCE_DISCRIMINATION_LOSS): |
|
self._instance_discrimination_loss_weight = ( |
|
loss_options.instance_discrimination_loss.weight) |
|
self.loss_terms.append(common.INSTANCE_DISCRIMINATION_LOSS) |
|
|
|
self._ignore_label = ignore_label |
|
self._thing_class_ids = list(thing_class_ids) |
|
self._focal_loss_alpha = focal_loss_alpha |
|
self._instance_discrimination_temperature = ( |
|
instance_discrimination_temperature) |
|
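    # Build the underlying loss functions with internal keys, since their
    # inputs are generated on the fly from the matching results rather than
    # read directly from the ground truth and prediction dicts.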
self._pq_style_loss_class_term = base_loss.FocalCrossEntropyLoss( |
|
gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY, |
|
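        # num_classes and ignore_label are set to None, since this loss is
        # computed on the one hot ground truth generated from the matching.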
num_classes=None, ignore_label=None, |
|
focal_loss_alpha=focal_loss_alpha, |
|
focal_loss_gamma=0.0, background_channel_index=-1, |
|
dynamic_weight=True) |
|
self._pq_style_loss_mask_dice_term = base_loss.MaskDiceLoss( |
|
gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY, |
|
prediction_activation='softmax') |
|
self._mask_id_cross_entropy_loss = base_loss.TopKCrossEntropyLoss( |
|
gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY, |
|
num_classes=None, ignore_label=None, |
|
top_k_percent_pixels=1.0, dynamic_weight=True) |
|
self._instance_discrimination_loss = base_loss.TopKCrossEntropyLoss( |
|
gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY, |
|
num_classes=None, ignore_label=None, |
|
top_k_percent_pixels=1.0, dynamic_weight=True) |
|
|
|
def build(self, |
|
input_shapes: Tuple[Dict[Text, tf.Tensor], Dict[Text, tf.Tensor]]): |
|
"""Extracts useful constants that depend on the input shapes.""" |
|
y_true_shapes = input_shapes[0] |
|
self._max_thing_id = int(y_true_shapes[common.GT_THING_ID_CLASS_KEY][-1]) |
|
y_pred_shapes = input_shapes[1] |
|
transformer_class_logits_shape = y_pred_shapes[ |
|
common.PRED_TRANSFORMER_CLASS_LOGITS_KEY] |
|
self._num_mask_slots = int(transformer_class_logits_shape[1]) |
|
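    # The transformer class logits predict num_thing_stuff_classes real
    # classes plus one void class, hence the minus one.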
self._num_thing_stuff_classes = int(transformer_class_logits_shape[2]) - 1 |
|
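    # Modifier that rescales the mask dice term by
    # focal_loss_alpha / num_mask_slots, balancing it against the class term
    # of the PQ-style loss.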
self._mask_dice_term_modifier = ( |
|
self._focal_loss_alpha / self._num_mask_slots) |
|
|
|
self._stuff_class_ids = utils.get_stuff_class_ids( |
|
self._num_thing_stuff_classes, |
|
self._thing_class_ids, |
|
self._ignore_label) |
|
self._num_stuff_classes = len(self._stuff_class_ids) |
|
self._thing_stuff_class_ids = self._thing_class_ids + self._stuff_class_ids |
|
self._pixel_gt_num_mask_id = self._max_thing_id + self._num_stuff_classes |
|
|
|
def _pre_process_ground_truth( |
|
self, y_true: Dict[Text, tf.Tensor], output_height: int, output_width: int |
|
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, |
|
tf.Tensor]: |
|
"""Pre-processes the ground truth before we compute the losses. |
|
|
|
This function generates tensors that do not depend on the prediction of the |
|
    model but are useful for computing the losses. The function mainly
    downsamples the pixel space ground truth to the model output resolution,
    and combines (or concatenates) the thing masks and the stuff masks. The
    number of output masks is pixel_gt_num_mask_id = max_thing_id +
    num_stuff_classes, i.e., the output masks include both thing masks and
    stuff masks.
|
|
|
Args: |
|
y_true: A dict of tensors providing ground-truth information, containing |
|
- common.GT_SEMANTIC_KEY: A [batch, height, width] int32 tf.Tensor, the |
|
semantic label map. |
|
- common.GT_THING_ID_MASK_KEY: A [batch, height, width] int32 tf.Tensor. |
|
It assigns each non-crowd thing instance a unique mask-ID label, |
|
starting from 0. Unassigned pixels are set to -1. |
|
- common.GT_THING_ID_CLASS_KEY: A [batch, max_thing_id] int32 tf.Tensor. |
|
It contains semantic ID of each instance assigned to thing_id_mask. The |
|
remaining (max_thing_id - num_things) elements are set to -1. |
|
output_height: An integer, the height of the model output. |
|
output_width: An integer, the width of the model output. |
|
|
|
Returns: |
|
      pixel_gt_thing_mask: A [batch, output_height * output_width] float32
        tensor, with values 0.0 and 1.0 only, indicating whether a pixel
        belongs to a 'thing' class.
      pixel_gt_non_void_mask: A [batch, output_height * output_width] float32
        tensor, with values 0.0 and 1.0 only, indicating whether a pixel
        belongs to a non-void class.
      pixel_gt_mask_id_one_hot: A [batch, output_height * output_width,
        pixel_gt_num_mask_id] float32 tensor, with values 0.0 and 1.0 only,
        indicating the mask ID each pixel belongs to.
|
mask_gt_semantic_map: A [batch, pixel_gt_num_mask_id] int32 tensor, the |
|
semantic class of each ground truth mask. |
|
      mask_gt_non_void_mask: A [batch, pixel_gt_num_mask_id] float32 tensor,
        with values 0.0 and 1.0 only, indicating whether the ground truth
        mask is a valid mask rather than a padded mask. The masks are padded
        because TPU does not support dynamic shapes except in the batch axis.
        We pad all ground truth thing masks to a large enough constant
        max_thing_id. Similarly, stuff classes that are not present in the
        current image are set to void masks too.
|
mask_gt_semantic_one_hot: A [batch, pixel_gt_num_mask_id, |
|
        num_thing_stuff_classes] float32 tensor, with values 0.0 and 1.0 only,
|
containing the one hot encodings of the ground truth mask classes. The |
|
last dimension contains concatenated thing classes and stuff classes, |
|
which is different from the dataset class IDs in mask_gt_semantic_map. |
|
      mask_gt_area: A [batch, pixel_gt_num_mask_id, 1] float32 tensor, the
        area of each ground truth mask. Padded masks have an area of 0.0.
|
""" |
|
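    # Depth of the one hot encoding, large enough to cover every thing and
    # stuff class ID of the dataset.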
one_hot_depth = max(self._thing_stuff_class_ids) + 1 |
|
batch_size = y_true[common.GT_SEMANTIC_KEY].get_shape().as_list()[0] |
|
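    # Downsample the ground truth semantic map to the model output resolution
    # with strided sampling, and flatten the spatial dimensions.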
pixel_gt_semantic_map = utils.strided_downsample( |
|
y_true[common.GT_SEMANTIC_KEY], |
|
target_size=[output_height, output_width]) |
|
pixel_gt_semantic_map = tf.reshape( |
|
pixel_gt_semantic_map, |
|
[batch_size, output_height * output_width]) |
|
|
|
|
|
pixel_gt_non_void_mask = tf.cast( |
|
tf.not_equal(pixel_gt_semantic_map, self._ignore_label), tf.float32) |
|
pixel_gt_non_void_mask = tf.ensure_shape( |
|
pixel_gt_non_void_mask, |
|
[batch_size, output_height * output_width]) |
|
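    # One hot encoding of the semantic map, from which one (possibly empty)
    # ground truth mask per stuff class is extracted below.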
pixel_gt_semantic_one_hot = tf.one_hot(pixel_gt_semantic_map, one_hot_depth) |
|
|
|
|
|
pixel_gt_stuff_id_one_hot = tf.gather(pixel_gt_semantic_one_hot, |
|
self._stuff_class_ids, axis=-1) |
|
pixel_gt_stuff_id_one_hot = tf.ensure_shape( |
|
pixel_gt_stuff_id_one_hot, |
|
[batch_size, output_height * output_width, self._num_stuff_classes]) |
|
|
|
|
|
pixel_gt_thing_id_map = utils.strided_downsample( |
|
y_true[common.GT_THING_ID_MASK_KEY], |
|
target_size=[output_height, output_width]) |
|
pixel_gt_thing_id_map = tf.reshape( |
|
pixel_gt_thing_id_map, shape=[batch_size, output_height * output_width]) |
|
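    # Thing pixels are the pixels that are assigned to a thing instance,
    # i.e., whose thing ID is not -1.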
pixel_gt_thing_mask = tf.cast( |
|
tf.not_equal(pixel_gt_thing_id_map, -1), tf.float32) |
|
pixel_gt_thing_id_one_hot = tf.one_hot(pixel_gt_thing_id_map, |
|
self._max_thing_id) |
|
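    # Concatenate the thing masks and the stuff masks into the final
    # pixel_gt_num_mask_id ground truth masks.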
pixel_gt_mask_id_one_hot = tf.concat([pixel_gt_thing_id_one_hot, |
|
pixel_gt_stuff_id_one_hot], axis=-1) |
|
pixel_gt_mask_id_one_hot = tf.ensure_shape( |
|
pixel_gt_mask_id_one_hot, |
|
[batch_size, output_height * output_width, self._pixel_gt_num_mask_id]) |
|
|
|
|
|
mask_gt_area = tf.expand_dims( |
|
tf.reduce_sum(pixel_gt_mask_id_one_hot, axis=1), axis=-1) |
|
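    # A ground truth mask is kept as valid only if it still covers at least
    # one pixel after the downsampling.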
mask_gt_area_mask = tf.reshape(mask_gt_area > 0.5, |
|
[batch_size, self._pixel_gt_num_mask_id]) |
|
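    # The semantic classes of the thing masks come from the ground truth,
    # while the semantic classes of the stuff masks are the stuff class IDs
    # themselves.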
thing_id_gt_semantic_map = tf.reshape( |
|
tf.cast(y_true[common.GT_THING_ID_CLASS_KEY], tf.int32), |
|
[batch_size, self._max_thing_id]) |
|
|
|
stuff_id_gt_semantic_map = tf.tile( |
|
tf.reshape( |
|
tf.cast(self._stuff_class_ids, tf.int32), |
|
[1, self._num_stuff_classes]), [batch_size, 1]) |
|
mask_gt_semantic_map = tf.concat( |
|
[thing_id_gt_semantic_map, stuff_id_gt_semantic_map], axis=-1) |
|
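    # Set the semantic labels of the masks with zero area to void (-1).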
mask_gt_semantic_map = ( |
|
(mask_gt_semantic_map + 1) * tf.cast(mask_gt_area_mask, tf.int32) - 1) |
|
|
|
mask_gt_semantic_one_hot = tf.one_hot(mask_gt_semantic_map, one_hot_depth) |
|
mask_gt_semantic_one_hot = tf.gather( |
|
mask_gt_semantic_one_hot, self._thing_stuff_class_ids, axis=-1) |
|
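    # Non-void masks are the masks whose semantic label is not -1.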
mask_gt_non_void_mask = tf.cast(mask_gt_semantic_map > -1, tf.float32) |
|
mask_gt_non_void_mask = tf.ensure_shape( |
|
mask_gt_non_void_mask, [batch_size, self._pixel_gt_num_mask_id]) |
|
|
|
return (pixel_gt_thing_mask, pixel_gt_non_void_mask, |
|
pixel_gt_mask_id_one_hot, mask_gt_semantic_map, |
|
mask_gt_non_void_mask, mask_gt_semantic_one_hot, mask_gt_area) |
|
|
|
def call( |
|
self, inputs: Tuple[Dict[Text, tf.Tensor], Dict[Text, tf.Tensor]] |
|
) -> Dict[Text, tf.Tensor]: |
|
"""Computes the MaX-DeepLab losses. |
|
|
|
Args: |
|
inputs: A tuple of two dicts (y_true, y_pred): |
|
- y_true: A dict of tensors providing ground-truth information, containing |
|
- common.GT_SEMANTIC_KEY: A [batch, height, width] int32 tf.Tensor, the |
|
semantic label map. |
|
- common.GT_THING_ID_MASK_KEY: A [batch, height, width] int32 |
|
tf.Tensor. It assigns each non-crowd thing instance a unique mask-ID |
|
label, starting from 0. Unassigned pixels are set to -1. |
|
- common.GT_THING_ID_CLASS_KEY: A [batch, max_thing_id] int32 |
|
tf.Tensor. It contains semantic ID of each instance assigned to |
|
thing_id_mask. The remaining (max_thing_id - num_things) elements are |
|
set to -1. |
|
- y_pred: A dict of tensors providing predictions. |
|
- common.PRED_PIXEL_SPACE_NORMALIZED_FEATURE_KEY: A [batch_size, |
|
output_height, output_width, channels] float32 tensor. |
|
- common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY: A [batch_size, |
|
output_height, output_width, num_mask_slots] float32 tensor, the |
|
logits that a pixel belongs to a mask slot. |
|
- common.PRED_TRANSFORMER_CLASS_LOGITS_KEY: A [batch_size, |
|
num_mask_slots, num_thing_stuff_classes + 1] float32 tensor, the |
|
logits that a mask belongs to a semantic class (including thing, |
|
          stuff, and void).
|
|
|
Returns: |
|
The loss as a dict of tf.Tensor, optionally containing the following: |
|
- common.PQ_STYLE_LOSS_CLASS_TERM: [batch]. |
|
- common.PQ_STYLE_LOSS_MASK_DICE_TERM: [batch]. |
|
- common.MASK_ID_CROSS_ENTROPY_LOSS: [batch]. |
|
- common.INSTANCE_DISCRIMINATION_LOSS: [batch]. |
|
""" |
|
y_true, y_pred = inputs |
|
resulting_dict = {} |
|
|
|
pixel_feature = y_pred[common.PRED_PIXEL_SPACE_NORMALIZED_FEATURE_KEY] |
|
batch_size, output_height, output_width, _ = ( |
|
pixel_feature.get_shape().as_list()) |
|
|
|
|
|
(pixel_gt_thing_mask, pixel_gt_non_void_mask, pixel_gt_mask_id_one_hot, |
|
mask_gt_semantic_map, mask_gt_non_void_mask, mask_gt_semantic_one_hot, |
|
mask_gt_area) = self._pre_process_ground_truth(y_true, |
|
output_height, output_width) |
|
pixel_gt_non_void_mask_expanded = tf.expand_dims( |
|
pixel_gt_non_void_mask, axis=-1) |
|
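    # Compute the average feature of each ground truth mask, dividing by the
    # mask area.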
pixel_feature = tf.reshape( |
|
pixel_feature, [batch_size, output_height * output_width, -1]) |
|
mask_average_feature = tf.einsum( |
|
'bpd,bpi->bid', |
|
pixel_feature, |
|
pixel_gt_mask_id_one_hot) / tf.maximum(mask_gt_area, 1.0) |
|
|
|
|
|
mask_average_feature = tf.math.l2_normalize(mask_average_feature, axis=-1) |
|
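    # The instance discrimination logits are the similarities between each
    # (normalized) pixel feature and each mask average feature, scaled by
    # the temperature.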
instance_discrimination_similarity = tf.einsum( |
|
'bpd,bid->bpi', pixel_feature, mask_average_feature) |
|
instance_discrimination_similarity /= ( |
|
self._instance_discrimination_temperature) |
|
mask_gt_non_void_mask_expanded_1 = tf.expand_dims( |
|
mask_gt_non_void_mask, axis=1) |
|
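    # Mask out the void ground truth masks with a large negative constant,
    # so that they receive (numerically) zero probability in the softmax.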
instance_discrimination_similarity = ( |
|
mask_gt_non_void_mask_expanded_1 * instance_discrimination_similarity + |
|
(1.0 - mask_gt_non_void_mask_expanded_1) * _SOFTMAX_MASKING_CONSTANT) |
|
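    # The instance discrimination loss is supervised on thing pixels only.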
if self._instance_discrimination_loss_weight > 0.0: |
|
resulting_dict[common.INSTANCE_DISCRIMINATION_LOSS] = ( |
|
self._instance_discrimination_loss( |
|
{_GT_KEY: pixel_gt_mask_id_one_hot}, |
|
{_PRED_KEY: instance_discrimination_similarity, |
|
_WEIGHT_KEY: pixel_gt_thing_mask}) * |
|
self._instance_discrimination_loss_weight) |
|
|
|
|
|
pixel_space_mask_logits = y_pred[common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY] |
|
pixel_space_mask_logits = tf.reshape( |
|
pixel_space_mask_logits, |
|
[batch_size, output_height * output_width, self._num_mask_slots]) |
|
pixel_space_mask_probs = tf.nn.softmax(pixel_space_mask_logits, axis=-1) |
|
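    # Dice similarity between all ground truth masks and all predicted
    # masks, excluding the void pixels from the predictions.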
mask_similarity = _mask_similarity( |
|
pixel_gt_mask_id_one_hot, |
|
pixel_space_mask_probs * pixel_gt_non_void_mask_expanded, |
|
metric='dice') |
|
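    # Class similarity between each ground truth mask and each mask slot,
    # i.e., the predicted probability of the ground truth class. The void
    # channel (the last channel) is dropped.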
transformer_class_logits = y_pred[common.PRED_TRANSFORMER_CLASS_LOGITS_KEY] |
|
transformer_class_probs = tf.nn.softmax( |
|
transformer_class_logits, axis=-1)[:, :, :-1] |
|
class_similarity = tf.einsum( |
|
'bij,bkj->bik', mask_gt_semantic_one_hot, transformer_class_probs) |
|
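    # The matching weight is the negative product of the mask similarity and
    # the class similarity, since the Hungarian matching below minimizes the
    # total weight.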
hungarian_weights = - mask_similarity * class_similarity |
|
mask_gt_non_void_mask_expanded_2 = tf.expand_dims( |
|
mask_gt_non_void_mask, axis=2) |
|
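    # Adjust the matching weights of the void ground truth masks.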
if self._num_mask_slots >= self._pixel_gt_num_mask_id: |
|
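      # There are enough mask slots for all ground truth masks. Giving the
      # void masks a constant weight keeps the assignment of the non-void
      # masks unbiased, since any mask slot is equally good for a void mask.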
hungarian_weights = ( |
|
hungarian_weights * mask_gt_non_void_mask_expanded_2 + |
|
(1 - mask_gt_non_void_mask_expanded_2) * _MATCHING_NEGATIVE_CONSTANT) |
|
else: |
|
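      # There are fewer mask slots than ground truth masks, so some masks
      # cannot be matched. A large positive weight pushes the void masks
      # towards the padded columns, leaving the real mask slots for the
      # non-void ground truth masks.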
hungarian_weights = ( |
|
hungarian_weights * mask_gt_non_void_mask_expanded_2 + |
|
(1 - mask_gt_non_void_mask_expanded_2) * _MATCHING_POSITIVE_CONSTANT) |
|
|
|
|
|
full_permutation, nonsquare_permutation = ( |
|
nonsquare_hungarian_matching(hungarian_weights)) |
|
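    # Keep only the matches of the non-void ground truth masks.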
matched_permutation = ( |
|
nonsquare_permutation * mask_gt_non_void_mask_expanded_2) |
|
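    # The mask dice of each matched mask slot. The gradient is stopped, so
    # it serves as a constant weight for the transformer class loss below.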
matched_mask_dice = tf.reduce_max( |
|
mask_similarity * matched_permutation, axis=-2) |
|
matched_mask_dice = tf.stop_gradient(matched_mask_dice) |
|
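    # The class probability of each matched ground truth mask, used as a
    # constant weight for the mask dice term below.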
matched_class_prob = tf.reduce_max( |
|
class_similarity * matched_permutation, axis=-1) |
|
matched_class_prob = tf.stop_gradient(matched_class_prob) |
|
|
|
|
|
matched_mask_slot_indices = tf.math.argmax( |
|
nonsquare_permutation, axis=-1, output_type=tf.dtypes.int32) |
|
|
|
full_num_mask_slots = full_permutation.get_shape().as_list()[-1] |
|
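    # Pad the pixel space mask logits to the full (square) permutation size
    # with a large negative constant, so that the padded slots receive
    # (numerically) zero probability after the softmax.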
full_pixel_space_mask_logits = tf.pad( |
|
pixel_space_mask_logits, |
|
[[0, 0], [0, 0], [0, full_num_mask_slots - self._num_mask_slots]], |
|
constant_values=_SOFTMAX_MASKING_CONSTANT) |
|
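    # Permute the mask slots so that they align with the ground truth mask
    # IDs.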
permuted_full_pixel_space_mask_logits = tf.einsum( |
|
'bpi,bji->bpj', full_pixel_space_mask_logits, full_permutation) |
|
|
|
|
|
full_matched_class_prob = tf.pad( |
|
matched_class_prob, |
|
[[0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]]) |
|
|
|
mask_dice_term_loss_weight = tf.pad( |
|
mask_gt_non_void_mask, |
|
[[0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]]) |
|
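    # Weight the mask dice term of each ground truth mask by its matched
    # class probability, clipped away from zero for numerical stability.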
mask_dice_term_loss_weight *= tf.maximum(full_matched_class_prob, 1e-5) |
|
|
|
|
|
full_pixel_gt_mask_id_one_hot = tf.pad( |
|
pixel_gt_mask_id_one_hot, |
|
[[0, 0], [0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]]) |
|
|
|
if self._pq_style_loss_weight > 0.0: |
|
|
|
|
|
resulting_dict[common.PQ_STYLE_LOSS_MASK_DICE_TERM] = ( |
|
self._pq_style_loss_mask_dice_term( |
|
{_GT_KEY: full_pixel_gt_mask_id_one_hot}, |
|
{_PRED_KEY: permuted_full_pixel_space_mask_logits, |
|
_WEIGHT_KEY: mask_dice_term_loss_weight}) * |
|
(self._pq_style_loss_weight * self._mask_dice_term_modifier)) |
|
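    # The mask ID cross entropy loss is supervised on all non-void pixels.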
if self._mask_id_cross_entropy_loss_weight > 0.0: |
|
resulting_dict[common.MASK_ID_CROSS_ENTROPY_LOSS] = ( |
|
self._mask_id_cross_entropy_loss( |
|
{_GT_KEY: full_pixel_gt_mask_id_one_hot}, |
|
{_PRED_KEY: permuted_full_pixel_space_mask_logits, |
|
_WEIGHT_KEY: pixel_gt_non_void_mask}) * |
|
self._mask_id_cross_entropy_loss_weight) |
|
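    # Generate the positive one hot ground truth for the transformer class
    # logits from the matching results.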
mask_slot_semantic_one_hot = _generate_mask_slot_semantic_one_hot( |
|
matched_mask_slot_indices, mask_gt_semantic_map, |
|
self._num_mask_slots, self._thing_stuff_class_ids) |
|
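    # A mask slot is positive if it is matched to a non-void ground truth
    # mask, and negative otherwise.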
mask_slot_positive_mask = tf.cast(tf.equal(tf.reduce_max( |
|
mask_slot_semantic_one_hot, axis=-1), 1.0), tf.float32) |
|
mask_slot_negative_mask = 1.0 - mask_slot_positive_mask |
|
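    # The portion of each predicted mask that falls on void pixels.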
mask_void_ratio = tf.stop_gradient(_mask_similarity( |
|
1.0 - pixel_gt_non_void_mask_expanded, |
|
pixel_space_mask_probs, |
|
'intersection_over_prediction')) |
|
mask_void_ratio = tf.squeeze(mask_void_ratio, axis=1) |
|
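    # Positive mask slots are weighted by their matched mask dice, and
    # negative mask slots by their void ratio, both clipped away from zero.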
transformer_class_loss_weight = ( |
|
mask_slot_positive_mask * tf.maximum(matched_mask_dice, 1e-5) + |
|
mask_slot_negative_mask * tf.maximum(mask_void_ratio, 1e-5)) |
|
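    # Append a void channel to the ground truth: the negative mask slots are
    # labeled as void.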
transformer_class_one_hot = tf.concat( |
|
[mask_slot_semantic_one_hot, |
|
tf.expand_dims(mask_slot_negative_mask, axis=-1)], axis=-1) |
|
|
|
|
|
if self._pq_style_loss_weight > 0.0: |
|
resulting_dict[common.PQ_STYLE_LOSS_CLASS_TERM] = ( |
|
self._pq_style_loss_class_term( |
|
{_GT_KEY: transformer_class_one_hot}, |
|
{_PRED_KEY: transformer_class_logits, |
|
_WEIGHT_KEY: transformer_class_loss_weight}) * |
|
self._pq_style_loss_weight) |
|
|
|
return resulting_dict |
|
|