|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class LabelSmoothingCrossEntropy(nn.Module):
    """
    NLL loss with label smoothing.

    The loss is a weighted sum of two terms, averaged over every target
    position:

    * the negative log-likelihood of the target class, weighted by
      ``confidence = 1 - smoothing``;
    * the mean negative log-probability over all classes (i.e. the
      cross-entropy against a uniform distribution), weighted by
      ``smoothing``.
    """

    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.

        :param smoothing: label smoothing factor, must be strictly below 1.0
        :raises AssertionError: if ``smoothing >= 1.0``
        """
        super().__init__()
        # Kept as an assert (not a ValueError) so the exception type seen by
        # existing callers does not change.
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        """
        Compute the label-smoothed loss.

        :param x: unnormalized logits with the class dimension last, e.g.
            ``(N, C)`` or ``(N, T, C)``
        :param target: integer class indices shaped like ``x`` without its
            last dimension, e.g. ``(N,)`` or ``(N, T)`` (``torch.gather``
            requires an int64 index tensor)
        :return: scalar tensor, the loss averaged over all target positions
        """
        logprobs = F.log_softmax(x, dim=-1)
        # Index along the *last* axis so the gather lines up with the class
        # dimension for any number of leading dims. Using axis 1 here is only
        # correct for 2-D logits and silently picks wrong entries otherwise.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1))
        nll_loss = nll_loss.squeeze(-1)
        # Cross-entropy against the uniform distribution over classes.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
|
|
|
|
|
class SoftTargetCrossEntropy(nn.Module):
    """
    Cross-entropy between logits and soft (per-class weighted) targets.

    For each position the loss is ``sum(-target * log_softmax(x))`` over the
    last dimension; the result is then averaged over all remaining positions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, target):
        """
        :param x: unnormalized logits, class dimension last
        :param target: per-class target weights, multiplied elementwise with
            the log-probabilities of ``x``
        :return: scalar tensor, the mean of the per-position losses
        """
        log_probs = F.log_softmax(x, dim=-1)
        weighted = -target * log_probs
        per_position = weighted.sum(dim=-1)
        return per_position.mean()
|
|