Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class WeightedCrossEntropyLoss(nn.Module): | |
def __init__(self, class_weights=None): | |
""" | |
Инициализация класса для кросс-энтропийной потери с возможностью взвешивания классов. | |
:param class_weights: Вектор весов для классов (опционально) | |
""" | |
super(WeightedCrossEntropyLoss, self).__init__() | |
self.class_weights = class_weights | |
def forward(self, y_pred, y_true): | |
""" | |
Вычисление кросс-энтропийной потери с (или без) взвешиванием классов. | |
:param y_true: Точные метки классов (вектор или одна метка) | |
:param y_pred: Вероятностный вектор предсказаний | |
:return: Значение потери | |
""" | |
y_true = y_true.to(torch.long) # Приводим метки к типу Long | |
y_pred = y_pred.to(torch.float32) # Приводим предсказания к типу Float32 | |
if self.class_weights is not None: | |
class_weights = torch.tensor(self.class_weights).float().to(y_true.device) | |
loss = F.cross_entropy(y_pred, y_true, weight=class_weights) | |
else: | |
loss = F.cross_entropy(y_pred, y_true) | |
return loss |