import torch
import torch.nn as nn


class TaikoLoss(nn.Module):
    def __init__(
        self,
        reduction="mean",
        nps_penalty_weight_alpha=0.3,
        nps_penalty_weight_beta=1.0,
    ):
        """
        Args:
            reduction (str): 'mean', 'sum', or 'none'.
            nps_penalty_weight_alpha (float): Weight of the continuous penalty
                applied where the sliding NPS is non-zero.
            nps_penalty_weight_beta (float): Weight of the flat penalty applied
                where the sliding NPS is zero.
        """
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction
        self.nps_penalty_weight_alpha = nps_penalty_weight_alpha
        self.nps_penalty_weight_beta = nps_penalty_weight_beta
    def forward(self, outputs, batch):
        """
        Calculates the MSE loss for energy-based predictions, with a two-level
        penalty based on sliding notes-per-second (NPS) values:

        - a heavy constant penalty (beta) where sliding_nps is 0, and
        - a continuous penalty, alpha * (1 - sliding_nps), where sliding_nps > 0.

        Args:
            outputs (dict): Model output containing the 'presence' tensor.
                outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
            batch (dict): Batch data from collate_fn, containing the true labels,
                sequence lengths, and sliding_nps_labels.
                batch['sliding_nps_labels'] shape: (B, T)

        Returns:
            torch.Tensor: The calculated loss.
        """
        pred_energies = outputs["presence"]  # (B, T, 3)
        true_don = batch["don_labels"]  # (B, T)
        true_ka = batch["ka_labels"]  # (B, T)
        true_drumroll = batch["drumroll_labels"]  # (B, T)
        true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2).to(
            pred_energies.device
        )  # (B, T, 3)

        B, T, _ = pred_energies.shape

        # Build a mask from batch['lengths'] to ignore the padded tail of each
        # sequence; batch['lengths'] gives the actual length of every sequence.
        # mask shape: (B, T)
        mask_2d = torch.arange(T, device=pred_energies.device).expand(B, T) < batch[
            "lengths"
        ].to(pred_energies.device).unsqueeze(1)
        # Expand the mask to (B, T, 1) so it broadcasts across the 3 energy channels.
        mask_3d = mask_2d.unsqueeze(2)  # (B, T, 1)
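        # Illustration (not in the original): with T=4 and lengths=[2, 4],
        # mask_2d is [[True, True, False, False],
        #             [True, True, True,  True ]].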
        # Element-wise MSE loss (reduction='none' keeps the full (B, T, 3) grid).
        mse_loss_elementwise = self.mse_loss(pred_energies, true_energies)  # (B, T, 3)

        # Two-level sliding-NPS penalty.
        sliding_nps = batch["sliding_nps_labels"].to(pred_energies.device)  # (B, T)
        penalty_coefficients = torch.zeros_like(sliding_nps)  # (B, T)
        is_zero_nps = sliding_nps == 0.0
        is_not_zero_nps = ~is_zero_nps

        # Heavy penalty where sliding_nps is 0.
        penalty_coefficients[is_zero_nps] = self.nps_penalty_weight_beta
        # Continuous penalty where sliding_nps > 0.
        penalty_coefficients[is_not_zero_nps] = self.nps_penalty_weight_alpha * (
            1 - sliding_nps[is_not_zero_nps]
        )
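        # Worked example (illustrative), with the defaults alpha=0.3, beta=1.0
        # and sliding_nps assumed normalized to [0, 1]: sliding_nps=0.0 gives a
        # coefficient of 1.0, sliding_nps=0.5 gives 0.3 * (1 - 0.5) = 0.15, and
        # sliding_nps=1.0 gives 0.0.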
        # Scale the MSE loss by (1 + penalty coefficient).
        loss_elementwise = mse_loss_elementwise * (
            1 + penalty_coefficients.unsqueeze(2)
        )
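        # So with the defaults, a zero-NPS frame's loss is doubled (factor
        # 1 + beta = 2.0), while the densest frames are left almost unscaled.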
        # Mask out the padded positions.
        masked_loss = loss_elementwise * mask_3d

        if self.reduction == "mean":
            # Mean over all valid (unmasked) loss values.
            total_loss = masked_loss.sum()
            # Number of unmasked loss values: valid frames times 3 channels
            # (mask_3d.sum() alone counts frames, not values).
            num_valid_elements = mask_3d.sum() * pred_energies.size(-1)
            if num_valid_elements > 0:
                return total_loss / num_valid_elements
            else:
                # Avoid division by zero when there are no valid elements
                # (e.g., an empty batch or all lengths are 0).
                return torch.tensor(
                    0.0, device=pred_energies.device, requires_grad=True
                )
        elif self.reduction == "sum":
            return masked_loss.sum()
        else:  # 'none' or any other case
            return masked_loss
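
# Minimal usage sketch (an addition, not part of the original Space code): the
# dict keys and tensor shapes follow the forward() docstring above; the random
# tensors stand in for a real model's outputs and a real collate_fn batch.
if __name__ == "__main__":
    B, T = 2, 8
    criterion = TaikoLoss(reduction="mean")
    outputs = {"presence": torch.rand(B, T, 3)}
    batch = {
        "don_labels": torch.rand(B, T),
        "ka_labels": torch.rand(B, T),
        "drumroll_labels": torch.rand(B, T),
        "lengths": torch.tensor([8, 5]),  # second sequence has 3 padded frames
        "sliding_nps_labels": torch.rand(B, T),  # assumed normalized to [0, 1]
    }
    loss = criterion(outputs, batch)
    print(f"loss: {loss.item():.4f}")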