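# Spaces: Runtime error / Runtime error
# NOTE(review): the line above looks like page-status residue from where this file was copied,
# not part of the module -- TODO confirm and drop.
""" Binary Cross Entropy w/ a few extras

Hacked together by / Copyright 2021 Ross Wightman
"""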
from typing import Optional | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class BinaryCrossEntropy(nn.Module):
    """ BCE with optional one-hot from dense targets, label smoothing, thresholding

    NOTE for experiments comparing CE to BCE /w label smoothing, may remove

    Args:
        smoothing: Label smoothing factor in ``[0, 1)``. Only used when ``forward`` has to
            expand class-index targets into per-class targets.
        target_threshold: If not None, targets are binarized to 0/1 with
            ``target > target_threshold`` right before the loss is computed.
        weight: Optional tensor forwarded as ``weight`` to
            ``F.binary_cross_entropy_with_logits``. Registered as a buffer.
        reduction: Forwarded as ``reduction`` to ``F.binary_cross_entropy_with_logits``.
        pos_weight: Optional tensor forwarded as ``pos_weight`` to
            ``F.binary_cross_entropy_with_logits``. Registered as a buffer.
    """

    def __init__(
            self,
            smoothing: float = 0.1,
            target_threshold: Optional[float] = None,
            weight: Optional[torch.Tensor] = None,
            reduction: str = 'mean',
            pos_weight: Optional[torch.Tensor] = None,
    ):
        super(BinaryCrossEntropy, self).__init__()
        assert 0. <= smoothing < 1.0, 'smoothing must be in [0, 1)'
        self.smoothing = smoothing
        self.target_threshold = target_threshold
        self.reduction = reduction
        # Buffers (rather than plain attributes) so they follow the module in .to() / state_dict();
        # registering None is allowed and simply leaves the attribute as None.
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """ Compute BCE-with-logits between ``x`` and ``target``.

        Args:
            x: Logits; the last dimension is taken as the number of classes.
            target: Either a tensor with the same shape as ``x`` (used as per-class targets),
                or a tensor of class indices with one index per row of ``x`` that is expanded
                to smoothed one-hot rows here.

        Returns:
            The loss as returned by ``F.binary_cross_entropy_with_logits`` for the configured
            ``reduction``.
        """
        assert x.shape[0] == target.shape[0], 'x and target must have the same first dimension'
        if target.shape != x.shape:
            # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
            num_classes = x.shape[-1]
            # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
            off_value = self.smoothing / num_classes
            on_value = 1. - self.smoothing + off_value
            target = target.long().view(-1, 1)
            # Start from a matrix filled with the "off" value, then write the "on" value at each
            # row's class index.
            target = torch.full(
                (target.size()[0], num_classes),
                off_value,
                device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
        elif not target.is_floating_point():
            # Same-shape integer/bool targets (e.g. multi-hot labels) must be floating point for
            # BCE-with-logits; floating targets are left in their own dtype.
            target = target.to(dtype=x.dtype)

        if self.target_threshold is not None:
            # Make target 0, or 1 if threshold set
            target = target.gt(self.target_threshold).to(dtype=target.dtype)

        return F.binary_cross_entropy_with_logits(
            x, target,
            self.weight,
            pos_weight=self.pos_weight,
            reduction=self.reduction)