import torch
import torch.nn as nn


class TaikoLoss(nn.Module):
    def __init__(
        self,
        reduction="mean",
        nps_penalty_weight_alpha=0.3,
        nps_penalty_weight_beta=1.0,
    ):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction
        self.nps_penalty_weight_alpha = nps_penalty_weight_alpha
        self.nps_penalty_weight_beta = nps_penalty_weight_beta

    def forward(self, outputs, batch):
        """
        Calculates the MSE loss for energy-based predictions, with a two-level
        penalty based on sliding NPS (notes-per-second) values:
        - a heavy penalty (beta) where sliding_nps is 0, and
        - a continuous penalty, alpha * (1 - sliding_nps), where sliding_nps > 0.

        Args:
            outputs (dict): Model output containing a 'presence' tensor.
                outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
            batch (dict): Batch data from collate_fn containing the true labels,
                sequence lengths, and sliding_nps_labels.
                batch['sliding_nps_labels'] shape: (B, T)

        Returns:
            torch.Tensor: The calculated loss.
        """
        pred_energies = outputs["presence"]  # (B, T, 3)
        true_don = batch["don_labels"]  # (B, T)
        true_ka = batch["ka_labels"]  # (B, T)
        true_drumroll = batch["drumroll_labels"]  # (B, T)
        true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2).to(
            pred_energies.device
        )  # (B, T, 3)

        B, T, _ = pred_energies.shape

        # Create a mask from batch['lengths'] to ignore padded parts of sequences.
        # batch['lengths'] gives the actual length of each sequence in the batch.
        # mask shape: (B, T)
        mask_2d = torch.arange(T, device=pred_energies.device).expand(B, T) < batch[
            "lengths"
        ].to(pred_energies.device).unsqueeze(1)

        # Expand the mask to (B, T, 1) so it broadcasts across the 3 energy channels
        mask_3d = mask_2d.unsqueeze(2)  # (B, T, 1)

        # Element-wise MSE loss
        mse_loss_elementwise = self.mse_loss(pred_energies, true_energies)  # (B, T, 3)

        # Two-level sliding-NPS penalty. sliding_nps is assumed to be normalized
        # to [0, 1]; values above 1 would make the continuous penalty negative.
        sliding_nps = batch["sliding_nps_labels"].to(pred_energies.device)  # (B, T)
        penalty_coefficients = torch.zeros_like(sliding_nps)  # (B, T)

        is_zero_nps = sliding_nps == 0.0
        is_not_zero_nps = ~is_zero_nps

        # Heavy penalty where sliding_nps is 0
        penalty_coefficients[is_zero_nps] = self.nps_penalty_weight_beta

        # Continuous penalty where sliding_nps > 0
        penalty_coefficients[is_not_zero_nps] = self.nps_penalty_weight_alpha * (
            1 - sliding_nps[is_not_zero_nps]
        )

        # Scale the MSE loss by the penalty factor
        loss_elementwise = mse_loss_elementwise * (
            1 + penalty_coefficients.unsqueeze(2)
        )

        # Zero out the loss at padded positions
        masked_loss = loss_elementwise * mask_3d

        if self.reduction == "mean":
            # Sum the loss over all valid (unmasked) elements and divide by the
            # number of valid elements. mask_2d counts valid time steps, and each
            # step contributes one value per energy channel, so scale by the
            # channel count (mask_3d.sum() alone would undercount by that factor).
            total_loss = masked_loss.sum()
            num_valid_elements = mask_2d.sum() * pred_energies.size(-1)
            if num_valid_elements > 0:
                return total_loss / num_valid_elements
            else:
                # Avoid division by zero if there are no valid elements
                # (e.g., an empty batch, or all lengths are 0)
                return torch.tensor(
                    0.0, device=pred_energies.device, requires_grad=True
                )
        elif self.reduction == "sum":
            return masked_loss.sum()
        else:  # 'none' or any other value
            return masked_loss
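

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of any training pipeline: it fabricates
    # a random batch matching the shapes documented in forward() (the key names
    # and shapes are assumptions taken from the docstring, not from a real
    # collate_fn) and checks that all three reduction modes run and that
    # gradients flow through the masked, penalized loss.
    torch.manual_seed(0)
    B, T = 2, 8

    outputs = {"presence": torch.rand(B, T, 3, requires_grad=True)}
    sliding_nps = torch.rand(B, T)
    sliding_nps[:, :2] = 0.0  # exercise the heavy zero-NPS penalty branch
    batch = {
        "don_labels": torch.rand(B, T),
        "ka_labels": torch.rand(B, T),
        "drumroll_labels": torch.rand(B, T),
        "sliding_nps_labels": sliding_nps,  # assumed normalized to [0, 1]
        "lengths": torch.tensor([8, 5]),  # second sequence is padded after step 5
    }

    criterion = TaikoLoss(reduction="mean")
    loss = criterion(outputs, batch)
    loss.backward()  # gradients should reach outputs['presence']
    print(f"mean loss: {loss.item():.6f}")
    print(f"sum loss:  {TaikoLoss(reduction='sum')(outputs, batch).item():.6f}")
    print("none shape:", tuple(TaikoLoss(reduction="none")(outputs, batch).shape))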