Spaces:
Build error
Build error
File size: 1,745 Bytes
6d314be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
import torch.nn as nn
from torch.nn import functional as F
# Shared mean-squared-error criterion; reduction="mean" averages the squared
# error over all elements of the input.
mse_loss = nn.MSELoss(reduction="mean")
def custom_loss(x, y, mask=None, loss_type="l2", include_bkgd=True):
    """
    Compute an L1 or L2 reconstruction loss between ``x`` and ``y``.

    Args:
        x, y: tensors of shape [N, C, H, W].
        mask: tensor of shape [N, Cm, Hm, Wm] selecting the elements that
            contribute to the loss. It is passed through ``prepare_mask``
            to match ``x``. Required when ``include_bkgd`` is False and
            ignored otherwise.
        loss_type: "l2" for squared error, "l1" for absolute error.
        include_bkgd: if True, return the plain mean L1/L2 loss over all
            elements. If False, multiply the element-wise error by the
            mask and normalize each sample by the sum of its mask.

    Returns:
        Scalar tensor holding the loss.

    Raises:
        ValueError: if ``loss_type`` is neither "l1" nor "l2", or if
            ``mask`` is None while ``include_bkgd`` is False.
    """
    # Fail early with a clear message instead of hitting an unbound local
    # further down when an unsupported loss type is requested.
    if loss_type not in ("l1", "l2"):
        raise ValueError(
            f'loss_type must be "l1" or "l2", got {loss_type!r}'
        )

    if include_bkgd:
        # Plain element-wise mean loss over the whole tensor.
        # F.mse_loss defaults to reduction="mean", matching F.l1_loss.
        if loss_type == "l2":
            return F.mse_loss(x, y)
        return F.l1_loss(x, y)

    if mask is None:
        raise ValueError("mask is required when include_bkgd is False")

    # Bring the mask to the channel count and spatial size of x.
    mask = prepare_mask(x, mask)

    # Flatten everything except the batch dimension: [N, C * H * W].
    batch = x.shape[0]
    x_flat = torch.reshape(x, [batch, -1])
    y_flat = torch.reshape(y, [batch, -1])
    mask_flat = torch.reshape(mask, [batch, -1])

    err = x_flat - y_flat
    diff = err ** 2 if loss_type == "l2" else torch.abs(err)

    # Zero out the masked-off elements, then sum the error per sample.
    sum_diff = torch.sum(diff * mask_flat, dim=-1)

    # Per-sample normalizer: sum of the mask entries. The +1.0 keeps the
    # division finite when a sample's mask is entirely zero.
    norm_count = torch.sum(mask_flat, dim=-1)
    diff_norm = sum_diff / (norm_count + 1.0)

    return torch.mean(diff_norm)
def prepare_mask(x, mask):
    """
    Make ``mask`` match ``x`` in channel count and spatial size.

    Args:
        x: reference tensor of shape [N, Cx, Hx, Wx].
        mask: tensor of shape [N, Cm, Hm, Wm] with values in [0, 1].

    Returns:
        A new tensor resized to (Hx, Wx) with nearest-neighbour
        interpolation. A single-channel mask is repeated to Cx channels;
        a mask with any other channel count keeps its channels unchanged.
    """
    _, Cx, Hx, Wx = x.shape
    Cm = mask.shape[1]

    # Resize to the explicit target size rather than a single height-based
    # scale factor: this handles differing height/width ratios and avoids
    # off-by-one sizes caused by flooring a floating-point scale.
    mask = F.interpolate(mask, size=(Hx, Wx), mode="nearest")

    # Nearest-neighbour resizing acts per channel, so repeating afterwards
    # gives the same result while interpolating fewer channels.
    if Cm == 1:
        mask = mask.repeat(1, Cx, 1, 1)
    return mask
|