File size: 1,512 Bytes
9ace58a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import torch.nn.functional as F


def bbox_size_loss(pred_size, gt_size):
    """
    Bounding box size loss. Only compute loss where there is a bounding box.

    A location counts as "having a box" when its target size is strictly
    positive. Both the prediction and the target are masked, so entries with a
    non-positive target (e.g. a negative "ignore" marker) contribute nothing.

    pred_size: predicted sizes; expected to match ``gt_size``'s shape.
    gt_size:   target sizes; entries <= 0 are treated as "no box".

    Returns the summed L1 error over box entries divided by the number of box
    entries (plus 1e-5, so an input with no boxes yields 0 instead of NaN).
    """
    box_mask = (gt_size > 0).float()
    # Mask the target as well as the prediction: otherwise a negative target
    # at a masked-out location would still add |gt| to the summed error.
    abs_err_sum = F.l1_loss(pred_size * box_mask, gt_size * box_mask, reduction='sum')
    return abs_err_sum / (box_mask.sum() + 1e-5)


def focal_loss(pred, gt, weights=None, valid_mask=None, *, alpha=2, beta=4, eps=1e-5):
    """
    Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints

    Locations where ``gt == 1`` are positives. Every location with ``gt < 1``
    is a negative whose penalty is down-weighted by ``(1 - gt) ** beta``.

    pred  (batch x c x h x w) predicted heatmap; expected to lie in [0, 1]
          (``log(1 - pred + eps)`` is NaN for ``pred > 1 + eps``).
    gt    (batch x c x h x w) target heatmap
    weights:    optional tensor broadcastable against ``pred``; it scales only
                the positive-location terms. Negative terms are unweighted.
    valid_mask: optional tensor broadcastable against ``pred``; it scales both
                positive and negative terms. Positives it zeroes out are also
                left out of the normalizer.
    alpha:      focusing exponent applied to ``(1 - pred)`` / ``pred``.
    beta:       exponent of the ``(1 - gt)`` penalty reduction on negatives.
    eps:        added inside the logs for numerical stability.

    Returns a scalar tensor: the negated summed loss divided by the number of
    (valid) positives, or the negated summed negative loss when there are none.
    """
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()

    pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
    neg_loss = (torch.log(1 - pred + eps) * torch.pow(pred, alpha)
                * torch.pow(1 - gt, beta) * neg_inds)

    if weights is not None:
        pos_loss = pos_loss * weights

    if valid_mask is not None:
        pos_loss = pos_loss * valid_mask
        neg_loss = neg_loss * valid_mask
        # Normalize only by positives that can contribute to pos_loss;
        # counting masked-out positives would shrink the loss.
        pos_inds = pos_inds * valid_mask

    num_pos = pos_inds.sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        return -neg_loss
    return -(pos_loss + neg_loss) / num_pos


def mse_loss(pred, gt, weights=None, valid_mask=None):
    """
    Mean squared error loss, optionally restricted to a valid region.

    pred, gt:   tensors of broadcast-compatible shapes.
    weights:    accepted for signature parity with ``focal_loss`` but currently
                ignored.
    valid_mask: optional tensor broadcastable against the squared error. The
                loss is then the mask-weighted mean of the squared error over
                the broadcast shape.

    Returns a scalar tensor. An all-zero ``valid_mask`` yields 0 rather than
    NaN.
    """
    sq_err = (gt - pred) ** 2
    if valid_mask is None:
        return sq_err.mean()

    # Broadcast before counting: a mask with fewer dims/channels than the
    # error (e.g. b x 1 x h x w) must be normalized by the number of masked
    # elements, so an all-ones mask reproduces the unmasked mean.
    mask, sq_err = torch.broadcast_tensors(valid_mask, sq_err)
    count = mask.sum()
    # Empty mask: numerator is exactly 0, so dividing by 1 returns 0, not NaN.
    count = torch.where(count == 0, torch.ones_like(count), count)
    return (mask * sq_err).sum() / count