File size: 1,482 Bytes
960b1a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightedCrossEntropyLoss(nn.Module):
    def __init__(self, class_weights=None):
        """
        Инициализация класса для кросс-энтропийной потери с возможностью взвешивания классов.

        :param class_weights: Вектор весов для классов (опционально)
        """
        super(WeightedCrossEntropyLoss, self).__init__()
        self.class_weights = class_weights

    def forward(self, y_pred, y_true):
        """
        Вычисление кросс-энтропийной потери с (или без) взвешиванием классов.

        :param y_true: Точные метки классов (вектор или одна метка)
        :param y_pred: Вероятностный вектор предсказаний
        :return: Значение потери
        """

        y_true = y_true.to(torch.long)  # Приводим метки к типу Long
        y_pred = y_pred.to(torch.float32)  # Приводим предсказания к типу Float32

        if self.class_weights is not None:
            class_weights = torch.tensor(self.class_weights).float().to(y_true.device)
            loss = F.cross_entropy(y_pred, y_true, weight=class_weights)
        else:
            loss = F.cross_entropy(y_pred, y_true)
        
        return loss