import torch
import torch.nn as nn
from torch.nn import functional as F

from losses.style.custom_loss import custom_loss, prepare_mask
from losses.style.vgg_activations import VGG16_Activations


class StyleLoss(nn.Module):
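    """Gram-matrix style loss computed on selected VGG-16 activations.

    Features are extracted from a target image ``x`` and a reconstruction
    ``x_hat``, optionally restricted to masked regions, and compared through
    the Gram matrix of each selected layer.
    """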
    def __init__(self, VGG16_ACTIVATIONS_LIST=None, normalize=False, distance="l2"):
        super(StyleLoss, self).__init__()

        # avoid a mutable default argument
        if VGG16_ACTIVATIONS_LIST is None:
            VGG16_ACTIVATIONS_LIST = [21]

        self.vgg16_act = VGG16_Activations(VGG16_ACTIVATIONS_LIST)
        self.vgg16_act.eval()

        # the feature extractor is fixed; its weights never need gradients
        for param in self.vgg16_act.parameters():
            param.requires_grad_(False)

        self.normalize = normalize
        self.distance = distance

    def get_features(self, model, x):

        return model(x)

    def mask_features(self, x, mask):
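        """Zero out activations outside the region selected by ``mask``.

        ``prepare_mask`` is expected to resize/broadcast ``mask`` to match
        ``x`` before the elementwise product.
        """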

        mask = prepare_mask(x, mask)
        return x * mask

    def gram_matrix(self, x):
        """Compute the normalized Gram matrix of an activation tensor.

        The batch and channel axes are flattened together, so for N > 1
        channels from different samples are mixed into one matrix.
        """
        N, C, H, W = x.shape
        x = x.view(N * C, H * W)
        G = torch.mm(x, x.t())

        # normalize by the total number of elements in the activation tensor
        return G.div(N * C * H * W)
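
    # Note: the flattened formulation above mixes batch samples when N > 1.
    # A per-sample alternative (a sketch, not used elsewhere in this file):
    #
    #     feats = x.view(N, C, H * W)
    #     G = torch.bmm(feats, feats.transpose(1, 2)) / (C * H * W)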

    def cal_style(self, model, x, x_hat, mask1=None, mask2=None):
        # target features are constants: extract them without building a graph
        with torch.no_grad():
            act_x = self.get_features(model, x)
        for layer in range(len(act_x)):
            act_x[layer].detach_()

        # x_hat features keep their graph so the loss can backpropagate
        act_x_hat = self.get_features(model, x_hat)

        loss = 0.0
        for layer in range(len(act_x)):

            # mask features if present
            if mask1 is not None:
                feat_x = self.mask_features(act_x[layer], mask1)
            else:
                feat_x = act_x[layer]
            if mask2 is not None:
                feat_x_hat = self.mask_features(act_x_hat[layer], mask2)
            else:
                feat_x_hat = act_x_hat[layer]

            """
            import ipdb; ipdb.set_trace()
            fx = feat_x[0, ...].detach().cpu().numpy()
            fx = (fx - fx.min()) / (fx.max() - fx.min())
            fx = fx * 255.
            fxhat = feat_x_hat[0, ...].detach().cpu().numpy()
            fxhat = (fxhat - fxhat.min()) / (fxhat.max() - fxhat.min())
            fxhat = fxhat * 255
            from PIL import Image
            import numpy as np
            for idx, img in enumerate(fx):
                img = fx[idx, ...]
                img = img.astype(np.uint8)
                img = Image.fromarray(img)
                img.save('plot/feat_x/{}.png'.format(str(idx)))
                img = fxhat[idx, ...]
                img = img.astype(np.uint8)
                img = Image.fromarray(img)
                img.save('plot/feat_x_hat/{}.png'.format(str(idx)))
            import ipdb; ipdb.set_trace()
            """

            # compute Gram matrix for x and x_hat
            G_x = self.gram_matrix(feat_x)
            G_x_hat = self.gram_matrix(feat_x_hat)

            # compute layer wise loss and aggregate
            loss += custom_loss(
                G_x, G_x_hat, mask=None, loss_type=self.distance, include_bkgd=True
            )

        # average over the selected layers
        loss = loss / len(act_x)

        return loss

    def forward(self, x, x_hat, mask1=None, mask2=None):
        x = x.cuda()
        x_hat = x_hat.cuda()

        # resize both images to 256x256 before feature extraction; an explicit
        # target size is more robust than scale_factor=256 / H, which can miss
        # 256 by one pixel when H does not divide it evenly
        x = F.interpolate(x, size=(256, 256), mode="bilinear", align_corners=True)
        x_hat = F.interpolate(x_hat, size=(256, 256), mode="bilinear", align_corners=True)

        loss = self.cal_style(self.vgg16_act, x, x_hat, mask1=mask1, mask2=mask2)

        return loss
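

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline: assumes a CUDA
    # device and that the pretrained VGG-16 weights used by VGG16_Activations
    # are available. The random 512x512 inputs are illustrative stand-ins.
    criterion = StyleLoss(normalize=False, distance="l2").cuda()
    x = torch.rand(1, 3, 512, 512)
    x_hat = torch.rand(1, 3, 512, 512)
    loss = criterion(x, x_hat)
    print(f"style loss: {loss.item():.6f}")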