# small constant to avoid division by zero / log(0)
EPS = 1e-7

import kornia
from typing import Dict, Iterator, List, Optional, Tuple, Union
import torchvision
from guided_diffusion import dist_util, logger
from pdb import set_trace as st
from torch.nn import functional as F
import numpy as np
import torch
import torch.nn as nn
import lpips

from . import *
from .sdfstudio_losses import ScaleAndShiftInvariantLoss
from ldm.util import default, instantiate_from_config
from .vqperceptual import hinge_d_loss, vanilla_d_loss

from torch.autograd import Variable
from math import exp


def gaussian(window_size, sigma):
    """Return a 1D Gaussian kernel of length `window_size`, normalized to sum to 1."""
    gauss = torch.Tensor([
        exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
        for x in range(window_size)
    ])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    """Build a (channel, 1, window_size, window_size) Gaussian window for grouped conv."""
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(
        _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(
        _2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute SSIM between two image batches using a pre-built Gaussian window."""
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(
        img1 * img1, window, padding=window_size // 2,
        groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(
        img2 * img2, window, padding=window_size // 2,
        groups=channel) - mu2_sq
    sigma12 = F.conv2d(
        img1 * img2, window, padding=window_size // 2,
        groups=channel) - mu1_mu2

    # stability constants from the original SSIM paper (dynamic range 1)
    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)
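

# Usage sketch (illustrative only, not part of the original module): SSIM
# between two random batches; `window_size=11` matches ssim_loss below.
def _demo_ssim():
    img1 = torch.rand(2, 3, 64, 64)
    img2 = torch.rand(2, 3, 64, 64)
    window = create_window(11, channel=3)
    score = _ssim(img1, img2, window, window_size=11, channel=3)
    print('SSIM:', score.item())  # ~0 for unrelated noise, 1 for identical images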


def weights_init(m):
    """DCGAN-style init: N(0, 0.02) for conv weights, N(1, 0.02) for BN scales."""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
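

# Usage sketch (illustrative only): recursively initialize a small conv net.
def _demo_weights_init():
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    net.apply(weights_init)  # dispatches on submodule class name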


# Main loss function used for ZoeDepth. Copied from the AdaBins repo
# (https://github.com/shariqfarooq123/AdaBins/blob/0952d91e9e762be310bb4cd055cbfe2448c0ce20/loss.py#L7)
def extract_key(prediction, key):
    if isinstance(prediction, dict):
        return prediction[key]
    return prediction


class SILogLoss(nn.Module):
    """Scale-invariant log (SILog) depth loss, computed pixel-wise."""

    def __init__(self, beta=0.15):
        super(SILogLoss, self).__init__()
        self.name = 'SILog'
        self.beta = beta

    def forward(self,
                input,
                target,
                mask=None,
                interpolate=True,
                return_interpolated=False):
        # input = extract_key(input, KEY_OUTPUT)
        if input.shape[-1] != target.shape[-1] and interpolate:
            input = nn.functional.interpolate(input,
                                              target.shape[-2:],
                                              mode='bilinear',
                                              align_corners=True)
        intr_input = input

        if target.ndim == 3:
            target = target.unsqueeze(1)

        if mask is not None:
            if mask.ndim == 3:
                mask = mask.unsqueeze(1)
            input = input[mask]
            target = target[mask]

        # with torch.amp.autocast(enabled=False):  # amp causes NaNs in this loss function
        alpha = 1e-7  # avoids log(0)
        g = torch.log(input + alpha) - torch.log(target + alpha)

        # n, c, h, w = g.shape
        # norm = 1/(h*w)
        # Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2
        Dg = torch.var(g) + self.beta * torch.pow(torch.mean(g), 2)
        loss = 10 * torch.sqrt(Dg)

        if torch.isnan(loss):
            print("Nan SILog loss")
            print("input:", input.shape)
            print("target:", target.shape)
            print("G", torch.sum(torch.isnan(g)))
            print("Input min max", torch.min(input), torch.max(input))
            print("Target min max", torch.min(target), torch.max(target))
            print("Dg", torch.isnan(Dg))
            print("loss", torch.isnan(loss))

        if not return_interpolated:
            return loss
        return loss, intr_input
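

# Usage sketch (illustrative only): SILog loss on random, strictly positive
# depth maps; the log terms require positive inputs.
def _demo_silog():
    criterion = SILogLoss(beta=0.15)
    pred = torch.rand(2, 1, 32, 32) + 0.1
    gt = torch.rand(2, 1, 32, 32) + 0.1
    print('SILog:', criterion(pred, gt).item())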


def get_outnorm(x: torch.Tensor, out_norm: str = '') -> torch.Tensor:
    """Common function to get a loss normalization value. Can
    normalize by either the batch size ('b'), the number of
    channels ('c'), the image size ('i') or combinations
    ('bi', 'bci', etc).
    """
    # b, c, h, w = x.size()
    img_shape = x.shape

    if not out_norm:
        return 1

    norm = 1
    if 'b' in out_norm:
        # normalize by batch size
        norm /= img_shape[0]
    if 'c' in out_norm:
        # normalize by the number of channels
        norm /= img_shape[-3]
    if 'i' in out_norm:
        # normalize by image/map size
        norm /= img_shape[-1] * img_shape[-2]

    return norm
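

# Usage sketch (illustrative only): 'bci' divides by batch * channels * pixels,
# turning a summed loss into a per-element mean.
def _demo_get_outnorm():
    x = torch.zeros(4, 3, 8, 8)
    print(get_outnorm(x, 'bci'))  # 1 / (4 * 3 * 64)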


class CharbonnierLoss(torch.nn.Module):
    """Charbonnier loss, a smooth differentiable variant of L1."""

    def __init__(self, eps=1e-6, out_norm: str = 'bci'):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps
        self.out_norm = out_norm

    def forward(self, x, y):
        norm = get_outnorm(x, self.out_norm)
        loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2))
        return loss * norm
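

# Usage sketch (illustrative only): with the default 'bci' normalization the
# result is a per-element Charbonnier mean.
def _demo_charbonnier():
    criterion = CharbonnierLoss()
    x, y = torch.rand(2, 3, 16, 16), torch.rand(2, 3, 16, 16)
    print('Charbonnier:', criterion(x, y).item())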


def feature_vae_loss(feature):
    # kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
    # KL divergence of the channel-wise feature statistics against N(0, I).
    # feature dim: B C H W
    mu = feature.mean(1)
    var = feature.var(1)
    log_var = torch.log(var)

    kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu**2 - var, dim=1),
                     dim=0)
    return kld
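

# Usage sketch (illustrative only): the KL term is near zero when the
# channel statistics already match a standard normal.
def _demo_feature_vae_loss():
    feature = torch.randn(2, 64, 8, 8)  # roughly zero mean, unit variance
    kld = feature_vae_loss(feature)
    print('feature KL (mean):', kld.mean().item())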


def kl_coeff(step, total_step, constant_step, min_kl_coeff, max_kl_coeff):
    # Linearly anneal the KL weight from min_kl_coeff to max_kl_coeff over
    # total_step steps, clamped at min_kl_coeff for the first constant_step steps.
    # return max(min(max_kl_coeff * (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff)
    kl_lambda = max(
        min(
            min_kl_coeff + (max_kl_coeff - min_kl_coeff) *
            (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff)
    return torch.tensor(kl_lambda, device=dist_util.dev())


def depth_smoothness_loss(alpha_pred, depth_pred):
    # From the PersonNeRF paper.
    # All tensors have shape B 1 H W; penalize depth differences between
    # adjacent pixels, gated by the alpha values of both pixels.
    geom_loss = (
        alpha_pred[..., :-1] * alpha_pred[..., 1:] * (
            depth_pred[..., :-1] - depth_pred[..., 1:]  # W dim
        ).square()).mean()  # mean of ([8, 1, 64, 63])

    geom_loss += (alpha_pred[..., :-1, :] * alpha_pred[..., 1:, :] *
                  (depth_pred[..., :-1, :] - depth_pred[..., 1:, :]).square()
                  ).mean()  # H dim, ([8, 1, 63, 64])

    return geom_loss
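

# Usage sketch (illustrative only): a constant depth map has zero smoothness
# penalty; adding noise makes it positive.
def _demo_depth_smoothness():
    alpha = torch.ones(2, 1, 16, 16)
    flat = torch.full((2, 1, 16, 16), 2.0)
    print(depth_smoothness_loss(alpha, flat).item())  # 0.0
    noisy = flat + 0.1 * torch.randn_like(flat)
    print(depth_smoothness_loss(alpha, noisy).item())  # > 0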


# https://github.com/elliottwu/unsup3d/blob/master/unsup3d/networks.py#L140
class LPIPSLoss(torch.nn.Module):

    def __init__(
        self,
        loss_weight=1.0,
        use_input_norm=True,
        range_norm=True,
    ):
        super(LPIPSLoss, self).__init__()
        self.perceptual = lpips.LPIPS(net="alex", spatial=False).eval()
        self.loss_weight = loss_weight
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        # if self.use_input_norm:
        #     # the mean is for image with range [0, 1]
        #     self.register_buffer(
        #         'mean',
        #         torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        #     # the std is for image with range [0, 1]
        #     self.register_buffer(
        #         'std',
        #         torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, pred, target, conf_sigma_percl=None):
        lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
        return self.loss_weight * lpips_loss.mean()
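

# Usage sketch (illustrative only): LPIPS expects inputs in [-1, 1]; note the
# lpips package downloads AlexNet weights on first use.
def _demo_lpips():
    criterion = LPIPSLoss(loss_weight=1.0)
    pred = torch.rand(1, 3, 64, 64) * 2 - 1
    target = torch.rand(1, 3, 64, 64) * 2 - 1
    print('LPIPS:', criterion(pred, target).item())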


# mask-aware perceptual loss
class PerceptualLoss(nn.Module):

    def __init__(self, requires_grad=False):
        super(PerceptualLoss, self).__init__()
        mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406])
        std_rgb = torch.FloatTensor([0.229, 0.224, 0.225])
        self.register_buffer('mean_rgb', mean_rgb)
        self.register_buffer('std_rgb', std_rgb)

        vgg_pretrained_features = torchvision.models.vgg16(
            pretrained=True).features
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()
        for x in range(4):  # relu1_2
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):  # relu2_2
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):  # relu3_3
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):  # relu4_3
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def normalize(self, x):
        # map from [-1, 1] to [0, 1], then apply ImageNet normalization
        out = x / 2 + 0.5
        out = (out - self.mean_rgb.view(1, 3, 1, 1)) / self.std_rgb.view(
            1, 3, 1, 1)
        return out

    def __call__(self, im1, im2, mask=None, conf_sigma=None):
        # batch both images through VGG in a single pass
        im = torch.cat([im1, im2], 0)
        im = self.normalize(im)  # normalize input

        # compute features
        feats = []
        f = self.slice1(im)
        feats += [torch.chunk(f, 2, dim=0)]
        f = self.slice2(f)
        feats += [torch.chunk(f, 2, dim=0)]
        f = self.slice3(f)
        feats += [torch.chunk(f, 2, dim=0)]
        f = self.slice4(f)
        feats += [torch.chunk(f, 2, dim=0)]

        losses = []
        for f1, f2 in feats[2:3]:  # use relu3_3 features only
            loss = (f1 - f2)**2
            if conf_sigma is not None:
                # Gaussian negative log-likelihood with per-pixel confidence
                loss = loss / (2 * conf_sigma**2 + EPS) + (conf_sigma +
                                                           EPS).log()
            if mask is not None:
                b, c, h, w = loss.shape
                _, _, hm, wm = mask.shape
                sh, sw = hm // h, wm // w
                mask0 = nn.functional.avg_pool2d(mask,
                                                 kernel_size=(sh, sw),
                                                 stride=(sh,
                                                         sw)).expand_as(loss)
                loss = (loss * mask0).sum() / mask0.sum()
            else:
                loss = loss.mean()
            losses += [loss]
        return sum(losses)
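

# Usage sketch (illustrative only): mask-aware VGG loss on [-1, 1] images;
# torchvision downloads VGG16 weights on first use.
def _demo_perceptual():
    criterion = PerceptualLoss()
    im1 = torch.rand(1, 3, 128, 128) * 2 - 1
    im2 = torch.rand(1, 3, 128, 128) * 2 - 1
    mask = torch.ones(1, 1, 128, 128)  # full-image mask
    print('VGG loss:', criterion(im1, im2, mask=mask).item())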


# add confidence support, unsup3d version
def photometric_loss_laplace(im1, im2, mask=None, conf_sigma=None):
    loss = (im1 - im2).abs()
    # loss = (im1 - im2).square()
    if conf_sigma is not None:
        # Laplace negative log-likelihood with per-pixel scale conf_sigma
        loss = loss * 2**0.5 / (conf_sigma + EPS) + (conf_sigma + EPS).log()
    if mask is not None:
        mask = mask.expand_as(loss)
        loss = (loss * mask).sum() / mask.sum()
    else:
        loss = loss.mean()
    return loss


# Gaussian-likelihood version, from "What Uncertainties Do We Need in Bayesian
# Deep Learning for Computer Vision?"; also used in the mask-aware VGG loss above.
def photometric_loss(im1, im2, mask=None, conf_sigma=None):
    # loss = torch.nn.functional.mse_loss(im1, im2, reduce='none')
    loss = (im1 - im2).square()
    if conf_sigma is not None:
        loss = loss / (2 * conf_sigma**2 + EPS) + (conf_sigma + EPS).log()
    if mask is not None:
        mask = mask.expand_as(loss)
        loss = (loss * mask).sum() / mask.sum()
    else:
        loss = loss.mean()
    return loss
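

# Usage sketch (illustrative only): with a confidence map, high-sigma pixels
# are down-weighted at the cost of a log-penalty.
def _demo_photometric():
    im1, im2 = torch.rand(2, 3, 32, 32), torch.rand(2, 3, 32, 32)
    conf_sigma = torch.full((2, 1, 32, 32), 0.5)  # broadcasts over channels
    print('plain:', photometric_loss(im1, im2).item())
    print('with confidence:',
          photometric_loss(im1, im2, conf_sigma=conf_sigma).item())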


class E3DGELossClass(torch.nn.Module):

    def __init__(self, device, opt) -> None:
        super().__init__()

        self.opt = opt
        self.device = device
        self.criterionImg = {
            'mse': torch.nn.MSELoss(),
            'l1': torch.nn.L1Loss(),
            'charbonnier': CharbonnierLoss(),
        }[opt.color_criterion]

        self.criterion_latent = {
            'mse': torch.nn.MSELoss(),
            'l1': torch.nn.L1Loss(),
            'vae': feature_vae_loss
        }[opt.latent_criterion]

        # self.criterionLPIPS = LPIPS(net_type='alex', device=device).eval()
        if opt.lpips_lambda > 0:
            self.criterionLPIPS = LPIPSLoss(loss_weight=opt.lpips_lambda)
            # self.criterionLPIPS = torch.nn.MSELoss()

        if opt.id_lambda > 0:
            self.criterionID = IDLoss(device=device).eval()
        self.id_loss_pool = torch.nn.AdaptiveAvgPool2d((256, 256))

        # define 3d rec loss, for occupancy
        # self.criterion3d_rec = torch.nn.SmoothL1Loss(reduction='none')
        # self.criterion_alpha = torch.nn.SmoothL1Loss()
        # self.criterion3d_rec = torch.nn.MSELoss(reduction='none')
        self.criterion_alpha = torch.nn.L1Loss()

        if self.opt.depth_lambda > 0:
            self.criterion3d_rec = ScaleAndShiftInvariantLoss(alpha=0.5,
                                                              scales=1)
        else:
            self.criterion3d_rec = torch.nn.SmoothL1Loss(reduction='none')

        # self.silog_loss = SILogLoss()

        logger.log('init loss class finished')

    def calc_scale_invariant_depth_loss(self, pred_depth: torch.Tensor,
                                        gt_depth: torch.Tensor,
                                        gt_depth_mask: torch.Tensor):
        """Supervise depth with a scale-and-shift-invariant loss over the masked region."""
        shape_loss_dict = {}
        assert gt_depth_mask is not None

        shape_loss = self.criterion3d_rec(pred_depth, gt_depth, gt_depth_mask)
        if shape_loss > 0.2:  # hinge: zero out out-of-distribution gradients
            shape_loss = torch.zeros_like(shape_loss)
        else:
            shape_loss *= self.opt.depth_lambda

        shape_loss_dict['loss_depth'] = shape_loss
        shape_loss_dict['depth_fgratio'] = gt_depth_mask.mean()

        return shape_loss, shape_loss_dict

    def calc_depth_loss(self, pred_depth: torch.Tensor,
                        gt_depth: torch.Tensor,
                        gt_depth_mask: torch.Tensor):
        """Supervise depth with SmoothL1, averaged over the foreground mask."""
        shape_loss_dict = {}

        shape_loss = self.criterion3d_rec(pred_depth, gt_depth)
        if gt_depth_mask is not None:
            shape_loss *= gt_depth_mask
            shape_loss = shape_loss.sum() / gt_depth_mask.sum()

        # l_si = self.silog_loss(pred_depth, gt_depth, mask=None, interpolate=True, return_interpolated=False)
        # clamp to avoid out-of-distribution gradients dominating training
        shape_loss_dict['loss_depth'] = shape_loss.clamp(
            min=0, max=0.1) * self.opt.depth_lambda

        return shape_loss, shape_loss_dict

    def calc_alpha_loss(self, pred_alpha, gt_depth_mask):
        if gt_depth_mask.ndim == 3:
            gt_depth_mask = gt_depth_mask.unsqueeze(1)

        if gt_depth_mask.shape[1] == 3:
            gt_depth_mask = gt_depth_mask[:, 0:1, ...]  # B 1 H W

        assert pred_alpha.shape == gt_depth_mask.shape
        alpha_loss = self.criterion_alpha(pred_alpha, gt_depth_mask)
        return alpha_loss

    def calc_mask_mse_loss(self,
                           input,
                           gt,
                           gt_depth_mask,
                           conf_sigma_l1=None,
                           use_fg_ratio=False):
        if gt_depth_mask.ndim == 3:
            gt_depth_mask = gt_depth_mask.unsqueeze(1).repeat_interleave(3, 1)
        else:
            assert gt_depth_mask.shape == input.shape
        gt_depth_mask = gt_depth_mask.float()

        if conf_sigma_l1 is None:
            rec_loss = torch.nn.functional.mse_loss(input.float(),
                                                    gt.float(),
                                                    reduction='none')
        else:
            # only calculate the likelihood on the foreground; otherwise low
            # background confidence produces large gradients
            rec_loss = photometric_loss(input, gt, gt_depth_mask,
                                        conf_sigma_l1)
            return rec_loss

        # average foreground and background errors separately, then re-weight
        fg_size = gt_depth_mask.sum()
        fg_loss = rec_loss * gt_depth_mask
        fg_loss = fg_loss.sum() / fg_size

        if self.opt.bg_lamdba > 0:  # (sic: option name is spelled 'bg_lamdba')
            bg_loss = rec_loss * (1 - gt_depth_mask)
            bg_loss = bg_loss.sum() / (1 - gt_depth_mask).sum()
            rec_loss = fg_loss + bg_loss * self.opt.bg_lamdba
        else:
            rec_loss = fg_loss

        return rec_loss

    def calc_2d_rec_loss(
        self,
        input,
        gt,
        depth_fg_mask,
        test_mode=True,
        step=1,
        ignore_lpips=False,
        conf_sigma_l1=None,
        conf_sigma_percl=None,
    ):
        opt = self.opt
        loss_dict = {}

        if test_mode or not opt.fg_mse:
            rec_loss = self.criterionImg(input, gt)
        else:
            rec_loss = self.calc_mask_mse_loss(input,
                                               gt,
                                               depth_fg_mask,
                                               conf_sigma_l1=conf_sigma_l1)

        # delay LPIPS by opt.lpips_delay_iter steps to avoid NaNs early in training
        if opt.lpips_lambda > 0 and step >= opt.lpips_delay_iter and not ignore_lpips:
            lpips_loss = self.criterionLPIPS(
                input,
                gt,
                conf_sigma_percl=conf_sigma_percl,
            )
        else:
            lpips_loss = torch.tensor(0., device=input.device)

        if opt.ssim_lambda > 0:
            loss_ssim = self.ssim_loss(input, gt)
        else:
            loss_ssim = torch.tensor(0., device=input.device)

        # PSNR assumes inputs in [-1, 1]; remap to [0, 1] first
        loss_psnr = self.psnr((input / 2 + 0.5), (gt / 2 + 0.5), 1.0)

        if opt.id_lambda > 0:
            loss_id = self._calc_loss_id(input, gt)
        else:
            loss_id = torch.tensor(0., device=input.device)

        if opt.l1_lambda > 0:
            loss_l1 = F.l1_loss(input, gt)
        else:
            loss_l1 = torch.tensor(0., device=input.device)

        # note: criterionLPIPS already folds opt.lpips_lambda in via loss_weight
        loss = rec_loss * opt.l2_lambda + lpips_loss + loss_id * opt.id_lambda + loss_ssim * opt.ssim_lambda + opt.l1_lambda * loss_l1

        loss_dict['loss_l2'] = rec_loss
        loss_dict['loss_id'] = loss_id
        loss_dict['loss_lpips'] = lpips_loss
        loss_dict['loss'] = loss
        loss_dict['loss_ssim'] = loss_ssim

        # metrics to report, not involved in training
        loss_dict['mae'] = loss_l1
        loss_dict['PSNR'] = loss_psnr
        loss_dict['SSIM'] = 1 - loss_ssim  # ssim_loss returns 1 - SSIM
        loss_dict['ID_SIM'] = 1 - loss_id

        return loss, loss_dict

    def calc_shape_rec_loss(
        self,
        pred_shape: dict,
        gt_shape: dict,
        device,
    ):
        """Apply 3D shape reconstruction supervision on sampled densities and depth.

        Args:
            pred_shape (dict): reconstructed shape information
            gt_shape (dict): ground-truth shape information

        Returns:
            tuple: total shape loss and a dict of its components
        """
        shape_loss_dict = {}
        shape_loss = 0

        # todo, add weights
        if self.opt.shape_uniform_lambda > 0:
            shape_loss_dict['coarse'] = self.criterion3d_rec(
                pred_shape['coarse_densities'].squeeze(),
                gt_shape['coarse_densities'].squeeze())
            shape_loss += shape_loss_dict[
                'coarse'] * self.opt.shape_uniform_lambda

        if self.opt.shape_importance_lambda > 0:
            shape_loss_dict['fine'] = self.criterion3d_rec(
                pred_shape['fine_densities'].squeeze(),  # ? how to supervise
                gt_shape['fine_densities'].squeeze())
            shape_loss += shape_loss_dict[
                'fine'] * self.opt.shape_importance_lambda

        loss_depth = self.criterion_alpha(pred_shape['image_depth'],
                                          gt_shape['image_depth'])
        shape_loss += loss_depth * self.opt.shape_depth_lambda
        shape_loss_dict.update(dict(loss_depth=loss_depth))

        # TODO: add supervision on surface points?
        return shape_loss, shape_loss_dict

    def psnr(self, input, target, max_val):
        return kornia.metrics.psnr(input, target, max_val)

    # @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
    def ssim_loss(self, img1, img2, window_size=11, size_average=True):
        """Return 1 - SSIM, so minimizing this maximizes structural similarity."""
        channel = img1.size(-3)
        window = create_window(window_size, channel)

        if img1.is_cuda:
            window = window.cuda(img1.get_device())
        window = window.type_as(img1)

        return 1 - _ssim(img1, img2, window, window_size, channel,
                         size_average)

    def forward(self,
                pred,
                gt,
                test_mode=True,
                step=1,
                return_fg_mask=False,
                conf_sigma_l1=None,
                conf_sigma_percl=None,
                *args,
                **kwargs):

        with torch.autocast(device_type='cuda',
                            dtype=torch.float16,
                            enabled=False):

            loss = torch.tensor(0., device=self.device)
            loss_dict = {}

            # derive the foreground mask, either online from the predicted
            # depth or from the ground-truth depth mask
            if self.opt.online_mask:
                # https://github.com/elliottwu/unsup3d/blob/dc961410d61684561f19525c2f7e9ee6f4dacb91/unsup3d/model.py#L193
                margin = (self.opt.max_depth - self.opt.min_depth) / 2
                fg_mask = (pred['image_depth']
                           < self.opt.max_depth + margin).float()  # B 1 H W
                fg_mask = fg_mask.repeat_interleave(3, 1).float()
            else:
                if 'depth_mask' in gt:
                    fg_mask = gt['depth_mask'].unsqueeze(
                        1).repeat_interleave(3, 1).float()
                else:
                    fg_mask = None

            loss_2d, loss_2d_dict = self.calc_2d_rec_loss(
                pred['image_raw'],
                gt['img'],
                fg_mask,
                test_mode=test_mode,
                step=step,
                ignore_lpips=False,
                conf_sigma_l1=conf_sigma_l1,
                conf_sigma_percl=conf_sigma_percl)

            if self.opt.kl_lambda > 0:
                assert 'posterior' in pred
                kl_loss = pred['posterior'].kl()
                kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

                if self.opt.kl_anneal:
                    kl_lambda = kl_coeff(
                        step=step,
                        constant_step=5e3,  # hold for the first 5k steps
                        total_step=25e3,  # anneal over 25k steps
                        min_kl_coeff=max(1e-9, self.opt.kl_lambda / 1e4),
                        max_kl_coeff=self.opt.kl_lambda)
                    loss_dict['kl_lambda'] = kl_lambda
                else:
                    loss_dict['kl_lambda'] = torch.tensor(
                        self.opt.kl_lambda, device=dist_util.dev())

                loss_dict['kl_loss'] = kl_loss * loss_dict['kl_lambda']
                loss += loss_dict['kl_loss']

                # nll_loss = loss_2d / torch.exp(pred['logvar']) + pred['logvar']  # nll_loss
                nll_loss = loss_2d
                loss += nll_loss
                loss_dict.update(dict(nll_loss=nll_loss))

                # track latent statistics, to keep the latent diffusion-friendly
                loss_dict['latent_mu'] = pred[
                    'latent_normalized_2Ddiffusion'].mean()
                loss_dict['latent_max'] = pred[
                    'latent_normalized_2Ddiffusion'].max()
                loss_dict['latent_min'] = pred[
                    'latent_normalized_2Ddiffusion'].min()
                loss_dict['latent_std'] = pred[
                    'latent_normalized_2Ddiffusion'].std()
            else:
                loss += loss_2d

            if 'image_sr' in pred:
                if 'depth_mask_sr' in gt:
                    depth_mask_sr = gt['depth_mask_sr'].unsqueeze(
                        1).repeat_interleave(3, 1).float()
                else:
                    depth_mask_sr = None

                loss_sr, loss_sr_dict = self.calc_2d_rec_loss(
                    pred['image_sr'],
                    gt['img_sr'],
                    depth_fg_mask=depth_mask_sr,
                    test_mode=True,
                    step=step)
                loss_sr_lambda = 1
                if step < self.opt.sr_delay_iter:
                    loss_sr_lambda = 0
                loss += loss_sr * loss_sr_lambda
                for k, v in loss_sr_dict.items():
                    loss_dict['sr_' + k] = v * loss_sr_lambda

            if self.opt.depth_lambda > 0:  # TODO, switch to scale-agnostic depth loss
                assert 'depth' in gt
                pred_depth = pred['image_depth']
                if pred_depth.ndim == 4:
                    pred_depth = pred_depth.squeeze(1)  # B H W

                _, shape_loss_dict = self.calc_scale_invariant_depth_loss(
                    pred_depth, gt['depth'], fg_mask[:, 0, ...])
                loss += shape_loss_dict['loss_depth']
                loss_dict.update(shape_loss_dict)

            if 'image_mask' in pred:
                pred_alpha = pred['image_mask']  # B 1 H W
            else:
                # reconstruct alpha from the accumulated sample weights
                N, _, H, W = pred['image_depth'].shape
                pred_alpha = pred['weights_samples'].permute(0, 2, 1).reshape(
                    N, 1, H, W)

            if self.opt.alpha_lambda > 0 and 'image_depth' in pred:
                loss_alpha = self.calc_alpha_loss(pred_alpha, fg_mask)
                loss_dict['loss_alpha'] = loss_alpha * self.opt.alpha_lambda
                loss += loss_alpha * self.opt.alpha_lambda

            if self.opt.depth_smoothness_lambda > 0:
                loss_depth_smoothness = depth_smoothness_loss(
                    pred_alpha,
                    pred['image_depth']) * self.opt.depth_smoothness_lambda
                loss_dict['loss_depth_smoothness'] = loss_depth_smoothness
                loss += loss_depth_smoothness

            loss_2d_dict['all_loss'] = loss
            loss_dict.update(loss_2d_dict)

            return loss, loss_dict, fg_mask

    def _calc_loss_id(self, input, gt):
        # the ArcFace-based ID loss expects 256x256 inputs; pool down if needed
        if input.shape[-1] != 256:
            arcface_input = self.id_loss_pool(input)
            id_loss_gt = self.id_loss_pool(gt)
        else:
            arcface_input = input
            id_loss_gt = gt

        loss_id, _, _ = self.criterionID(arcface_input, id_loss_gt,
                                         id_loss_gt)
        return loss_id

    def calc_2d_rec_loss_misaligned(self, input, gt):
        """ID loss + LPIPS loss; both tolerate pixel misalignment between images.

        Args:
            input: rendered image
            gt: target image
        """
        opt = self.opt
        loss_dict = {}

        if opt.lpips_lambda > 0:
            with torch.autocast(
                    device_type='cuda', dtype=torch.float16,
                    enabled=False):  # close AMP for lpips to avoid nan
                lpips_loss = self.criterionLPIPS(input, gt)
        else:
            lpips_loss = torch.tensor(0., device=input.device)

        if opt.id_lambda > 0:
            loss_id = self._calc_loss_id(input, gt)
        else:
            loss_id = torch.tensor(0., device=input.device)

        loss_dict['loss_id_real'] = loss_id
        loss_dict['loss_lpips_real'] = lpips_loss

        loss = lpips_loss * opt.lpips_lambda + loss_id * opt.id_lambda

        return loss, loss_dict


class E3DGE_with_AdvLoss(E3DGELossClass):
    # adapted from sgm/modules/autoencoding/losses/discriminator_loss.py

    def __init__(
        self,
        device,
        opt,
        discriminator_config: Optional[Dict] = None,
        disc_num_layers: int = 3,
        disc_in_channels: int = 3,
        disc_start: int = 0,
        disc_loss: str = "hinge",
        disc_factor: float = 1.0,
        disc_weight: float = 1.0,
        regularization_weights: Union[None, Dict[str, float]] = None,
    ) -> None:
        super().__init__(device, opt)

        # initialize the GAN discriminator
        discriminator_config = default(
            discriminator_config,
            {
                "target": "nsr.losses.disc.NLayerDiscriminator",
                "params": {
                    "input_nc": disc_in_channels,
                    "n_layers": disc_num_layers,
                    "use_actnorm": False,
                },
            },
        )

        self.discriminator = instantiate_from_config(
            discriminator_config).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        return self.discriminator.parameters()

    def forward(self,
                pred,
                gt,
                behaviour: str,
                test_mode=True,
                step=1,
                return_fg_mask=False,
                conf_sigma_l1=None,
                conf_sigma_percl=None,
                *args,
                **kwargs):

        # the GAN part operates on raw renderings vs. real images
        reconstructions = pred['image_raw']
        inputs = gt['img']

        if behaviour == 'g_step':

            nll_loss, loss_dict, fg_mask = super().forward(
                pred, gt, test_mode, step, return_fg_mask, conf_sigma_l1,
                conf_sigma_percl, *args, **kwargs)

            # generator update: fool the discriminator on reconstructions
            if step >= self.discriminator_iter_start or not self.training:
                logits_fake = self.discriminator(reconstructions.contiguous())
                g_loss = -torch.mean(logits_fake)
                if self.training:
                    d_weight = torch.tensor(self.discriminator_weight)
                    # d_weight = self.calculate_adaptive_weight(
                    #     nll_loss, g_loss, last_layer=last_layer)
                else:
                    d_weight = torch.tensor(1.0)
            else:
                d_weight = torch.tensor(0.0)
                g_loss = torch.tensor(0.0, requires_grad=True)

            loss = nll_loss + d_weight * self.disc_factor * g_loss

            loss_dict.update({
                "loss/g": g_loss.detach().mean(),
            })

            return loss, loss_dict, fg_mask

        elif behaviour == 'd_step':
            # second pass for the discriminator update, on detached tensors
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(
                reconstructions.contiguous().detach())

            if step >= self.discriminator_iter_start or not self.training:
                d_loss = self.disc_factor * self.disc_loss(
                    logits_real, logits_fake)
            else:
                d_loss = torch.tensor(0.0, requires_grad=True)

            loss_dict = {
                "loss/disc": d_loss.clone().detach().mean(),
                "logits/real": logits_real.detach().mean(),
                "logits/fake": logits_fake.detach().mean(),
            }

            return d_loss, loss_dict, None
        else:
            raise NotImplementedError(
                f"Unknown optimizer behaviour {behaviour}")
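

# Usage sketch (illustrative only; `loss_module`, `pred`, `gt`, and the two
# optimizers are hypothetical caller-provided objects): the adversarial loss
# is driven by alternating generator / discriminator passes.
def _demo_adv_loss_steps(loss_module, pred, gt, g_optimizer, d_optimizer,
                         step):
    # generator step: reconstruction + perceptual terms plus -E[D(fake)]
    g_loss, g_loss_dict, _ = loss_module(pred, gt, behaviour='g_step',
                                         step=step)
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    # discriminator step: hinge (or vanilla) loss on detached real/fake logits
    d_loss, d_loss_dict, _ = loss_module(pred, gt, behaviour='d_step',
                                         step=step)
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    return g_loss_dict, d_loss_dict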