File size: 2,843 Bytes
812b01c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.nn as nn


class TaikoEnergyLoss(nn.Module):
    """Length-masked MSE loss over per-frame don / ka / drumroll energies.

    Padded frames (time index >= the sequence's entry in ``batch['lengths']``)
    contribute neither to the loss value nor to the gradient.

    Args:
        reduction: ``"mean"`` averages over every valid element (valid frames
            times channels), ``"sum"`` adds up all valid elements, and any
            other value returns the element-wise loss with padded positions
            set to zero.
    """

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        # Element-wise losses are needed so the reduction can be applied
        # manually after padding has been excluded.
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction

    def forward(self, outputs: dict, batch: dict) -> torch.Tensor:
        """Calculate the masked MSE loss for energy-based predictions.

        Args:
            outputs (dict): Model output, containing the 'presence' tensor.
                outputs['presence'] shape: (B, T, 3) for don, ka, drumroll
                energies.
            batch (dict): Batch data from collate_fn, containing true labels
                and lengths.
                batch['don_labels'], batch['ka_labels'],
                batch['drumroll_labels'] shape: (B, T)
                batch['lengths'] shape: (B,) - valid sequence lengths along
                the time dimension T.

        Returns:
            torch.Tensor: A scalar for the "mean" and "sum" reductions,
            otherwise the (B, T, 3) element-wise loss with zeros at padded
            positions. With "mean", a batch without any valid frame yields a
            zero that is still attached to the autograd graph.
        """
        pred_energies = outputs["presence"]  # (B, T, C)
        device = pred_energies.device

        # Stack the labels channel-last to mirror pred_energies, and align
        # device/dtype so integer or CPU-resident labels do not break the loss.
        true_energies = torch.stack(
            [batch["don_labels"], batch["ka_labels"], batch["drumroll_labels"]],
            dim=2,
        ).to(device=device, dtype=pred_energies.dtype)

        _, num_frames, num_channels = pred_energies.shape

        # mask_2d[b, t] is True while frame t lies inside sequence b.
        lengths = torch.as_tensor(batch["lengths"], device=device)
        frame_index = torch.arange(num_frames, device=device)
        mask_2d = frame_index.unsqueeze(0) < lengths.unsqueeze(1)  # (B, T)
        # (B, T, 1) so that it broadcasts across the energy channels.
        padding = ~mask_2d.unsqueeze(2)

        # Zero the padded inputs *before* the loss instead of multiplying the
        # loss by the mask afterwards: NaN/Inf sitting in padding would survive
        # a multiplication by zero, and this also keeps their gradient at 0.
        masked_loss = self.mse_loss(
            pred_energies.masked_fill(padding, 0.0),
            true_energies.masked_fill(padding, 0.0),
        )  # (B, T, C)

        if self.reduction == "mean":
            # Every valid frame contributes one value per channel, so the
            # element count is (valid frames) * C.
            num_valid_elements = mask_2d.sum() * num_channels
            # Clamping avoids a division by zero for an empty / all-padding
            # batch without a data-dependent Python branch (no host sync), and
            # keeps the returned zero connected to the model's graph.
            return masked_loss.sum() / num_valid_elements.clamp(min=1)
        if self.reduction == "sum":
            return masked_loss.sum()
        # 'none' or any other value: element-wise loss, padding zeroed.
        return masked_loss