File size: 3,833 Bytes
812b01c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn as nn


class TaikoLoss(nn.Module):
    """Masked MSE over per-frame (don, ka, drumroll) energies, weighted by sliding NPS.

    Every frame's squared error is multiplied by ``1 + penalty``, where

    - ``penalty = nps_penalty_weight_beta`` where the sliding NPS is exactly 0;
    - ``penalty = nps_penalty_weight_alpha * max(1 - sliding_nps, 0)`` elsewhere.

    Frames at positions ``>= lengths[b]`` (padding) contribute nothing.

    Args:
        reduction: ``"mean"`` (average over valid frame*channel elements),
            ``"sum"``, or anything else to return the unreduced ``(B, T, C)`` tensor.
        nps_penalty_weight_alpha: Scale of the continuous penalty for non-zero NPS.
        nps_penalty_weight_beta: Constant penalty for frames whose NPS is 0.
    """

    def __init__(
        self,
        reduction="mean",
        nps_penalty_weight_alpha=0.3,
        nps_penalty_weight_beta=1.0,
    ):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction
        self.nps_penalty_weight_alpha = nps_penalty_weight_alpha
        self.nps_penalty_weight_beta = nps_penalty_weight_beta

    def _penalty_coefficients(self, sliding_nps):
        """Return the per-frame penalty (same shape as ``sliding_nps``)."""
        # NOTE(review): the continuous term assumes NPS is roughly normalised to
        # [0, 1]. Clamping at 0 keeps values above 1 from producing a negative
        # weight (and hence a negative loss); it is a no-op inside [0, 1].
        continuous = self.nps_penalty_weight_alpha * (1 - sliding_nps).clamp(min=0)
        heavy = torch.full_like(sliding_nps, self.nps_penalty_weight_beta)
        return torch.where(sliding_nps == 0, heavy, continuous)

    def forward(self, outputs, batch):
        """
        Calculates the MSE loss for energy-based predictions, with a two-level penalty
        based on sliding NPS values.
        - A heavier penalty if sliding_nps is 0.
        - A continuous penalty if sliding_nps > 0.

        Args:
            outputs (dict): Model output, containing 'presence' tensor.
                            outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
            batch (dict): Batch data from collate_fn, containing 'don_labels',
                          'ka_labels', 'drumroll_labels' (each (B, T)), 'lengths' (B,)
                          and 'sliding_nps_labels' (B, T).
        Returns:
            torch.Tensor: A scalar for ``"mean"``/``"sum"`` reduction, otherwise the
            masked per-element loss of shape (B, T, 3). With ``"mean"`` and no valid
            frames, a zero that stays connected to the autograd graph.
        """
        pred_energies = outputs["presence"]  # (B, T, 3)
        device, dtype = pred_energies.device, pred_energies.dtype
        _, num_frames, num_channels = pred_energies.shape

        # Cast targets to the prediction dtype so MSE never sees integer or
        # mismatched-precision labels.
        true_energies = torch.stack(
            [batch["don_labels"], batch["ka_labels"], batch["drumroll_labels"]], dim=2
        ).to(device=device, dtype=dtype)  # (B, T, 3)

        # valid_2d[b, t] is True for t < lengths[b] (i.e. not padding).
        lengths = batch["lengths"].to(device)
        valid_2d = torch.arange(num_frames, device=device).unsqueeze(0) < lengths.unsqueeze(1)
        valid_3d = valid_2d.unsqueeze(2)  # (B, T, 1), broadcasts over channels

        # Float dtype here matters: an integer NPS tensor would otherwise truncate
        # the fractional alpha penalty.
        sliding_nps = batch["sliding_nps_labels"].to(device=device, dtype=dtype)
        penalty = self._penalty_coefficients(sliding_nps)  # (B, T)

        weighted_loss = self.mse_loss(pred_energies, true_energies) * (
            1 + penalty.unsqueeze(2)
        )  # (B, T, 3)
        # masked_fill (rather than multiplying by the mask) keeps NaN/inf in padded
        # frames from leaking into the reduction.
        masked_loss = weighted_loss.masked_fill(~valid_3d, 0.0)

        if self.reduction == "mean":
            # Each valid frame contributes one element per channel.
            num_valid_elements = valid_2d.sum() * num_channels
            # With no valid elements the numerator is exactly 0, so clamping the
            # denominator yields 0 without a host sync and without detaching the graph.
            return masked_loss.sum() / num_valid_elements.clamp(min=1)
        if self.reduction == "sum":
            return masked_loss.sum()
        # 'none' or any other value: return the unreduced masked loss.
        return masked_loss