RobinWZQ committed on
Commit c8c90c7 · 1 Parent(s): aa31f9f

Upload 6 files

Files changed (6)
  1. app.py +75 -0
  2. hist_loss.py +208 -0
  3. ics.jpg +0 -0
  4. net.py +281 -0
  5. style_trsfer.py +80 -0
  6. utils.py +168 -0
app.py ADDED
@@ -0,0 +1,75 @@
+ from diffusers import DiffusionPipeline, UniPCMultistepScheduler
+ import gradio as gr
+ import torch
+ import gc
+ from style_trsfer import style_transfer_method
+
+
+ def generate(style_image, text, negative_prompts, steps, guidance_scale):
+     pipeline = DiffusionPipeline.from_pretrained("./CCLAP")
+     pipeline.scheduler = UniPCMultistepScheduler.from_config(
+         pipeline.scheduler.config)
+     device = torch.device(
+         'cuda:0' if torch.cuda.is_available() else 'cpu')
+     if device.type == 'cuda':
+         pipeline.enable_xformers_memory_efficient_attention()
+     pipeline.to(device)
+     torch.cuda.empty_cache()
+     gc.collect()
+     content_image = pipeline(text,
+                              num_inference_steps=steps,
+                              negative_prompt=negative_prompts,
+                              guidance_scale=guidance_scale).images[0]
+     result = style_transfer_method(content_image, style_image)
+     return content_image, result
+
+
+ if __name__ == '__main__':
+
+     demo = gr.Interface(title="CCLAP",
+                         description=(
+                             "This is the demo of CCLAP to generate Chinese landscape painting."
+                         ),
+                         css="",
+                         fn=generate,
+                         inputs=[gr.Image(label="Style Image", shape=(512, 512)),
+                                 gr.Textbox(lines=3, placeholder="Input the prompt", label="Prompt"),
+                                 gr.Textbox(lines=3, placeholder="low quality", label="Negative prompt"),
+                                 gr.Slider(minimum=0, maximum=100, value=20, label='Steps'),
+                                 gr.Slider(minimum=0, maximum=30, value=7.5, label='Guidance_scale'),
+                                 ],
+                         outputs=[gr.Image(label="Content Output", shape=(256, 256)),
+                                  gr.Image(label="Final Output", shape=(256, 256))],
+                         examples=[
+                             [
+                                 'style_image/style1.jpg',
+                                 'A Chinese landscape painting of a mountain landscape with trees',
+                                 'low quality',
+                                 20,
+                                 7.5
+                             ],
+                             [
+                                 'style_image/style2.jpg',
+                                 'A Chinese landscape painting of a building with trees in front of it',
+                                 'low quality',
+                                 20,
+                                 7.5
+                             ],
+                             [
+                                 'style_image/style3.jpg',
+                                 'A Chinese landscape painting of a landscape with mountains in the background',
+                                 'low quality',
+                                 20,
+                                 7.5
+                             ],
+                             [
+                                 'style_image/style4.jpg',
+                                 'A Chinese landscape painting of a landscape with mountains and a river',
+                                 'low quality',
+                                 20,
+                                 7.5
+                             ],
+                         ],
+                         )
+
+     demo.launch()
hist_loss.py ADDED
@@ -0,0 +1,208 @@
+ """
+ Copyright 2021 Mahmoud Afifi.
+ Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN:
+ Controlling Colors of GAN-Generated and Real Images via Color Histograms."
+ In CVPR, 2021.
+
+ @inproceedings{afifi2021histogan,
+   title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via
+   Color Histograms},
+   author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.},
+   booktitle={CVPR},
+   year={2021}
+ }
+ """
+
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+ import numpy as np
+
+ EPS = 1e-6
+
+ class RGBuvHistBlock(nn.Module):
+     def __init__(self, h=64, insz=150, resizing='interpolation',
+                  method='inverse-quadratic', sigma=0.02, intensity_scale=True,
+                  device='cuda'):
+         """ Computes the RGB-uv histogram feature of a given image.
+         Args:
+           h: histogram dimension size (scalar). The default value is 64.
+           insz: maximum size of the input image; if it is larger than this size, the
+             image will be resized (scalar). Default value is 150 (i.e., 150 x 150
+             pixels).
+           resizing: resizing method if applicable. Options are: 'interpolation' or
+             'sampling'. Default is 'interpolation'.
+           method: the method used to count the number of pixels for each bin in the
+             histogram feature. Options are: 'thresholding', 'RBF' (radial basis
+             function), or 'inverse-quadratic'. Default value is 'inverse-quadratic'.
+           sigma: if the method value is 'RBF' or 'inverse-quadratic', then this is
+             the sigma parameter of the kernel function. The default value is 0.02.
+           intensity_scale: boolean variable to use the intensity scale (I_y in
+             Equation 2). Default value is True.
+
+         Methods:
+           forward: accepts input image and returns its histogram feature. Note that
+             unless the method is 'thresholding', this is a differentiable function
+             and can be easily integrated with the loss function. As mentioned in the
+             paper, the 'inverse-quadratic' was found more stable than 'RBF' in our
+             training.
+         """
+         super(RGBuvHistBlock, self).__init__()
+         self.h = h
+         self.insz = insz
+         self.device = device
+         self.resizing = resizing
+         self.method = method
+         self.intensity_scale = intensity_scale
+         if self.method == 'thresholding':
+             self.eps = 6.0 / h
+         else:
+             self.sigma = sigma
+
+     def forward(self, x):
+         x = torch.clamp(x, 0, 1)
+         if x.shape[2] > self.insz or x.shape[3] > self.insz:
+             if self.resizing == 'interpolation':
+                 x_sampled = F.interpolate(x, size=(self.insz, self.insz),
+                                           mode='bilinear', align_corners=False)
+             elif self.resizing == 'sampling':
+                 inds_1 = torch.LongTensor(
+                     np.linspace(0, x.shape[2], self.h, endpoint=False)).to(
+                     device=self.device)
+                 inds_2 = torch.LongTensor(
+                     np.linspace(0, x.shape[3], self.h, endpoint=False)).to(
+                     device=self.device)
+                 x_sampled = x.index_select(2, inds_1)
+                 x_sampled = x_sampled.index_select(3, inds_2)
+             else:
+                 raise Exception(
+                     f'Wrong resizing method. It should be: interpolation or sampling. '
+                     f'But the given value is {self.resizing}.')
+         else:
+             x_sampled = x
+
+         L = x_sampled.shape[0]  # size of mini-batch
+         if x_sampled.shape[1] > 3:
+             x_sampled = x_sampled[:, :3, :, :]
+         X = torch.unbind(x_sampled, dim=0)
+         hists = torch.zeros((x_sampled.shape[0], 3, self.h, self.h)).to(
+             device=self.device)
+         for l in range(L):
+             I = torch.t(torch.reshape(X[l], (3, -1)))
+             II = torch.pow(I, 2)
+             if self.intensity_scale:
+                 Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS),
+                                      dim=1)
+             else:
+                 Iy = 1
+
+             Iu0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 1] + EPS),
+                                   dim=1)
+             Iv0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 2] + EPS),
+                                   dim=1)
+             diff_u0 = abs(
+                 Iu0 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
+                                       dim=0).to(self.device))
+             diff_v0 = abs(
+                 Iv0 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
+                                       dim=0).to(self.device))
+             if self.method == 'thresholding':
+                 diff_u0 = torch.reshape(diff_u0, (-1, self.h)) <= self.eps / 2
+                 diff_v0 = torch.reshape(diff_v0, (-1, self.h)) <= self.eps / 2
+             elif self.method == 'RBF':
+                 diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_u0 = torch.exp(-diff_u0)  # Radial basis function
+                 diff_v0 = torch.exp(-diff_v0)
+             elif self.method == 'inverse-quadratic':
+                 diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_u0 = 1 / (1 + diff_u0)  # Inverse quadratic
+                 diff_v0 = 1 / (1 + diff_v0)
+             else:
+                 raise Exception(
+                     f'Wrong kernel method. It should be either thresholding, RBF,'
+                     f' inverse-quadratic. But the given value is {self.method}.')
+             diff_u0 = diff_u0.type(torch.float32)
+             diff_v0 = diff_v0.type(torch.float32)
+             a = torch.t(Iy * diff_u0)
+             hists[l, 0, :, :] = torch.mm(a, diff_v0)
+
+             Iu1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 0] + EPS),
+                                   dim=1)
+             Iv1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 2] + EPS),
+                                   dim=1)
+             diff_u1 = abs(
+                 Iu1 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
+                                       dim=0).to(self.device))
+             diff_v1 = abs(
+                 Iv1 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
+                                       dim=0).to(self.device))
+
+             if self.method == 'thresholding':
+                 diff_u1 = torch.reshape(diff_u1, (-1, self.h)) <= self.eps / 2
+                 diff_v1 = torch.reshape(diff_v1, (-1, self.h)) <= self.eps / 2
+             elif self.method == 'RBF':
+                 diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_u1 = torch.exp(-diff_u1)  # Gaussian
+                 diff_v1 = torch.exp(-diff_v1)
+             elif self.method == 'inverse-quadratic':
+                 diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_u1 = 1 / (1 + diff_u1)  # Inverse quadratic
+                 diff_v1 = 1 / (1 + diff_v1)
+
+             diff_u1 = diff_u1.type(torch.float32)
+             diff_v1 = diff_v1.type(torch.float32)
+             a = torch.t(Iy * diff_u1)
+             hists[l, 1, :, :] = torch.mm(a, diff_v1)
+
+             Iu2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 0] + EPS),
+                                   dim=1)
+             Iv2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 1] + EPS),
+                                   dim=1)
+             diff_u2 = abs(
+                 Iu2 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
+                                       dim=0).to(self.device))
+             diff_v2 = abs(
+                 Iv2 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
+                                       dim=0).to(self.device))
+             if self.method == 'thresholding':
+                 diff_u2 = torch.reshape(diff_u2, (-1, self.h)) <= self.eps / 2
+                 diff_v2 = torch.reshape(diff_v2, (-1, self.h)) <= self.eps / 2
+             elif self.method == 'RBF':
+                 diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_u2 = torch.exp(-diff_u2)  # Gaussian
+                 diff_v2 = torch.exp(-diff_v2)
+             elif self.method == 'inverse-quadratic':
+                 diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
+                                     2) / self.sigma ** 2
+                 diff_u2 = 1 / (1 + diff_u2)  # Inverse quadratic
+                 diff_v2 = 1 / (1 + diff_v2)
+             diff_u2 = diff_u2.type(torch.float32)
+             diff_v2 = diff_v2.type(torch.float32)
+             a = torch.t(Iy * diff_u2)
+             hists[l, 2, :, :] = torch.mm(a, diff_v2)
+
+         # normalization
+         hists_normalized = hists / (
+             ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)
+
+         return hists_normalized
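A minimal sketch of how the differentiable histogram block above can serve as a color loss (it mirrors calc_histogram_loss in utils.py); the batch size, image size, and device here are illustrative assumptions, while the block settings match the ones net.py uses:

import torch
from hist_loss import RGBuvHistBlock

# Illustrative inputs: two batches of RGB images in [0, 1], shape (N, 3, H, W).
fake = torch.rand(2, 3, 64, 64)
real = torch.rand(2, 3, 64, 64)

# Same block settings as net.py, but on CPU for the sketch.
hist_block = RGBuvHistBlock(insz=64, h=256, intensity_scale=True,
                            method='inverse-quadratic', device='cpu')
fake_hist = hist_block(fake)   # (N, 3, 256, 256) normalized RGB-uv histograms
real_hist = hist_block(real)

# Hellinger-style distance between the two histograms, as in utils.calc_histogram_loss.
loss = (1 / (2.0 ** 0.5)) * torch.sqrt(
    torch.sum((torch.sqrt(real_hist) - torch.sqrt(fake_hist)) ** 2)) / fake_hist.shape[0]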
ics.jpg ADDED
net.py ADDED
@@ -0,0 +1,281 @@
+ import torch
+ import torch.nn as nn
+ from utils import mean_variance_norm, DEVICE
+ from utils import calc_ss_loss, calc_remd_loss, calc_moment_loss, calc_mse_loss, calc_histogram_loss
+ from hist_loss import RGBuvHistBlock
+ import torch
+
+ class Net(nn.Module):
+     def __init__(self, args):
+         super(Net, self).__init__()
+         self.args = args
+         self.vgg = vgg19[:44]
+         self.vgg.load_state_dict(torch.load('./checkpoints/encoder.pth', map_location='cpu'), strict=False)
+         for param in self.vgg.parameters():
+             param.requires_grad = False
+
+         self.align1 = PAMA(512)
+         self.align2 = PAMA(512)
+         self.align3 = PAMA(512)
+
+         self.decoder = decoder
+         self.hist = RGBuvHistBlock(insz=64, h=256,
+                                    intensity_scale=True,
+                                    method='inverse-quadratic',
+                                    device=DEVICE)
+
+         if args.pretrained == True:
+             self.align1.load_state_dict(torch.load('./checkpoints/PAMA1.pth', map_location='cpu'), strict=True)
+             self.align2.load_state_dict(torch.load('./checkpoints/PAMA2.pth', map_location='cpu'), strict=True)
+             self.align3.load_state_dict(torch.load('./checkpoints/PAMA3.pth', map_location='cpu'), strict=True)
+             self.decoder.load_state_dict(torch.load('./checkpoints/decoder.pth', map_location='cpu'), strict=False)
+
+         if args.requires_grad == False:
+             for param in self.parameters():
+                 param.requires_grad = False
+
+
+     def forward(self, Ic, Is):
+         feat_c = self.forward_vgg(Ic)
+         feat_s = self.forward_vgg(Is)
+         Fc, Fs = feat_c[3], feat_s[3]
+
+         Fcs1 = self.align1(Fc, Fs)
+         Fcs2 = self.align2(Fcs1, Fs)
+         Fcs3 = self.align3(Fcs2, Fs)
+
+         Ics3 = self.decoder(Fcs3)
+
+         if self.args.training == True:
+             Ics1 = self.decoder(Fcs1)
+             Ics2 = self.decoder(Fcs2)
+             Irc = self.decoder(Fc)
+             Irs = self.decoder(Fs)
+             feat_cs1 = self.forward_vgg(Ics1)
+             feat_cs2 = self.forward_vgg(Ics2)
+             feat_cs3 = self.forward_vgg(Ics3)
+             feat_rc = self.forward_vgg(Irc)
+             feat_rs = self.forward_vgg(Irs)
+
+             content_loss1, remd_loss1, moment_loss1, color_loss1 = 0.0, 0.0, 0.0, 0.0
+             content_loss2, remd_loss2, moment_loss2, color_loss2 = 0.0, 0.0, 0.0, 0.0
+             content_loss3, remd_loss3, moment_loss3, color_loss3 = 0.0, 0.0, 0.0, 0.0
+             loss_rec = 0.0
+
+             for l in range(2, 5):
+                 content_loss1 += self.args.w_content1 * calc_ss_loss(feat_cs1[l], feat_c[l])
+                 remd_loss1 += self.args.w_remd1 * calc_remd_loss(feat_cs1[l], feat_s[l])
+                 moment_loss1 += self.args.w_moment1 * calc_moment_loss(feat_cs1[l], feat_s[l])
+
+                 content_loss2 += self.args.w_content2 * calc_ss_loss(feat_cs2[l], feat_c[l])
+                 remd_loss2 += self.args.w_remd2 * calc_remd_loss(feat_cs2[l], feat_s[l])
+                 moment_loss2 += self.args.w_moment2 * calc_moment_loss(feat_cs2[l], feat_s[l])
+
+                 content_loss3 += self.args.w_content3 * calc_ss_loss(feat_cs3[l], feat_c[l])
+                 remd_loss3 += self.args.w_remd3 * calc_remd_loss(feat_cs3[l], feat_s[l])
+                 moment_loss3 += self.args.w_moment3 * calc_moment_loss(feat_cs3[l], feat_s[l])
+
+                 loss_rec += 0.5 * calc_mse_loss(feat_rc[l], feat_c[l]) + 0.5 * calc_mse_loss(feat_rs[l], feat_s[l])
+             loss_rec += 25 * calc_mse_loss(Irc, Ic)
+             loss_rec += 25 * calc_mse_loss(Irs, Is)
+
+             if self.args.color_on:
+                 color_loss1 += self.args.w_color1 * calc_histogram_loss(Ics1, Is, self.hist)
+                 color_loss2 += self.args.w_color2 * calc_histogram_loss(Ics2, Is, self.hist)
+                 color_loss3 += self.args.w_color3 * calc_histogram_loss(Ics3, Is, self.hist)
+
+             loss1 = (content_loss1+remd_loss1+moment_loss1+color_loss1)/(self.args.w_content1+self.args.w_remd1+self.args.w_moment1+self.args.w_color1)
+             loss2 = (content_loss2+remd_loss2+moment_loss2+color_loss2)/(self.args.w_content2+self.args.w_remd2+self.args.w_moment2+self.args.w_color2)
+             loss3 = (content_loss3+remd_loss3+moment_loss3+color_loss3)/(self.args.w_content3+self.args.w_remd3+self.args.w_moment3+self.args.w_color3)
+             loss = loss1 + loss2 + loss3 + loss_rec
+             return loss
+         else:
+             return Ics3
+
+     def forward_vgg(self, x):
+         relu1_1 = self.vgg[:4](x)
+         relu2_1 = self.vgg[4:11](relu1_1)
+         relu3_1 = self.vgg[11:18](relu2_1)
+         relu4_1 = self.vgg[18:31](relu3_1)
+         relu5_1 = self.vgg[31:44](relu4_1)
+         return [relu1_1, relu2_1, relu3_1, relu4_1, relu5_1]
+
+     def save_ckpts(self):
+         torch.save(self.align1.state_dict(), "./checkpoints/PAMA1.pth")
+         torch.save(self.align2.state_dict(), "./checkpoints/PAMA2.pth")
+         torch.save(self.align3.state_dict(), "./checkpoints/PAMA3.pth")
+         torch.save(self.decoder.state_dict(), "./checkpoints/decoder.pth")
+
+ #---------------------------------------------------------------------------------------------------------------
+
+ vgg19 = nn.Sequential(
+     nn.Conv2d(3, 3, (1, 1)),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(3, 64, (3, 3)),
+     nn.ReLU(),  # relu1-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 64, (3, 3)),
+     nn.ReLU(),  # relu1-2
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 128, (3, 3)),
+     nn.ReLU(),  # relu2-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 128, (3, 3)),
+     nn.ReLU(),  # relu2-2
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 256, (3, 3)),
+     nn.ReLU(),  # relu3-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),  # relu3-2
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),  # relu3-3
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),  # relu3-4
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 512, (3, 3)),
+     nn.ReLU(),  # relu4-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu4-2
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu4-3
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu4-4
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu5-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu5-2
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu5-3
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU()  # relu5-4
+ )
+
+ #---------------------------------------------------------------------------------------------------------------
+
+ decoder = nn.Sequential(
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 256, (3, 3)),
+     nn.ReLU(),  # relu4_1
+     nn.Upsample(scale_factor=2, mode='nearest'),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 128, (3, 3)),
+     nn.ReLU(),  # relu3_1
+     nn.Upsample(scale_factor=2, mode='nearest'),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 128, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 64, (3, 3)),
+     nn.ReLU(),  # relu2_1
+     nn.Upsample(scale_factor=2, mode='nearest'),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 64, (3, 3)),
+     nn.ReLU(),  # relu1_1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 3, (3, 3)),
+ )
+
+ #---------------------------------------------------------------------------------------------------------------
+
+ class AttentionUnit(nn.Module):
+     def __init__(self, channels):
+         super(AttentionUnit, self).__init__()
+         self.relu6 = nn.ReLU6()
+         self.f = nn.Conv2d(channels, channels//2, (1, 1))
+         self.g = nn.Conv2d(channels, channels//2, (1, 1))
+         self.h = nn.Conv2d(channels, channels//2, (1, 1))
+
+         self.out_conv = nn.Conv2d(channels//2, channels, (1, 1))
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, Fc, Fs):
+         B, C, H, W = Fc.shape
+         f_Fc = self.relu6(self.f(mean_variance_norm(Fc)))
+         g_Fs = self.relu6(self.g(mean_variance_norm(Fs)))
+         h_Fs = self.relu6(self.h(Fs))
+         f_Fc = f_Fc.view(f_Fc.shape[0], f_Fc.shape[1], -1).permute(0, 2, 1)
+         g_Fs = g_Fs.view(g_Fs.shape[0], g_Fs.shape[1], -1)
+
+         Attention = self.softmax(torch.bmm(f_Fc, g_Fs))
+
+         h_Fs = h_Fs.view(h_Fs.shape[0], h_Fs.shape[1], -1)
+
+         Fcs = torch.bmm(h_Fs, Attention.permute(0, 2, 1))
+         Fcs = Fcs.view(B, C//2, H, W)
+         Fcs = self.relu6(self.out_conv(Fcs))
+
+         return Fcs
+
+ class FuseUnit(nn.Module):
+     def __init__(self, channels):
+         super(FuseUnit, self).__init__()
+         self.proj1 = nn.Conv2d(2*channels, channels, (1, 1))
+         self.proj2 = nn.Conv2d(channels, channels, (1, 1))
+         self.proj3 = nn.Conv2d(channels, channels, (1, 1))
+
+         self.fuse1x = nn.Conv2d(channels, 1, (1, 1), stride=1)
+         self.fuse3x = nn.Conv2d(channels, 1, (3, 3), stride=1)
+         self.fuse5x = nn.Conv2d(channels, 1, (5, 5), stride=1)
+
+         self.pad3x = nn.ReflectionPad2d((1, 1, 1, 1))
+         self.pad5x = nn.ReflectionPad2d((2, 2, 2, 2))
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, F1, F2):
+         Fcat = self.proj1(torch.cat((F1, F2), dim=1))
+         F1 = self.proj2(F1)
+         F2 = self.proj3(F2)
+
+         fusion1 = self.sigmoid(self.fuse1x(Fcat))
+         fusion3 = self.sigmoid(self.fuse3x(self.pad3x(Fcat)))
+         fusion5 = self.sigmoid(self.fuse5x(self.pad5x(Fcat)))
+         fusion = (fusion1 + fusion3 + fusion5) / 3
+
+         return torch.clamp(fusion, min=0, max=1.0)*F1 + torch.clamp(1 - fusion, min=0, max=1.0)*F2
+
+ class PAMA(nn.Module):
+     def __init__(self, channels):
+         super(PAMA, self).__init__()
+         self.conv_in = nn.Conv2d(channels, channels, (3, 3), stride=1)
+         self.attn = AttentionUnit(channels)
+         self.fuse = FuseUnit(channels)
+         self.conv_out = nn.Conv2d(channels, channels, (3, 3), stride=1)
+
+         self.pad = nn.ReflectionPad2d((1, 1, 1, 1))
+         self.relu6 = nn.ReLU6()
+
+     def forward(self, Fc, Fs):
+         Fc = self.relu6(self.conv_in(self.pad(Fc)))
+         Fs = self.relu6(self.conv_in(self.pad(Fs)))
+         Fcs = self.attn(Fc, Fs)
+         Fcs = self.relu6(self.conv_out(self.pad(Fcs)))
+         Fcs = self.fuse(Fc, Fcs)
+
+         return Fcs
+
+ #---------------------------------------------------------------------------------------------------------------
+
+
style_trsfer.py ADDED
@@ -0,0 +1,80 @@
+ import argparse
+ import torch
+ from torchvision.utils import make_grid
+ from PIL import Image, ImageFile
+ from net import Net
+ from utils import DEVICE, test_transform
+ Image.MAX_IMAGE_PIXELS = None
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+
+ def style_transfer_method(content_image, style_img):
+     main_parser = argparse.ArgumentParser(description="main parser")
+     subparsers = main_parser.add_subparsers(title="subcommands", dest="subcommand")
+
+     main_parser.add_argument("--pretrained", type=bool, default=True,
+                              help="whether to use the pre-trained checkpoints")
+     main_parser.add_argument("--requires_grad", type=bool, default=True,
+                              help="set to True if the model requires model gradient")
+
+     train_parser = subparsers.add_parser("train", help="training mode parser")
+     train_parser.add_argument("--training", type=bool, default=True)
+     train_parser.add_argument("--iterations", type=int, default=60000,
+                               help="total training epochs (default: 160000)")
+     train_parser.add_argument("--batch_size", type=int, default=2,
+                               help="training batch size (default: 8)")
+     train_parser.add_argument("--num_workers", type=int, default=2,
+                               help="iterator threads (default: 8)")
+     train_parser.add_argument("--lr", type=float, default=1e-4, help="the learning rate during training (default: 1e-4)")
+     train_parser.add_argument("--content_folder", type=str, required=True,
+                               help="the root of content images, the path should point to a folder")
+     train_parser.add_argument("--style_folder", type=str, required=True,
+                               help="the root of style images, the path should point to a folder")
+     train_parser.add_argument("--log_interval", type=int, default=10000,
+                               help="number of images after which the training loss is logged (default: 20000)")
+
+     train_parser.add_argument("--w_content1", type=float, default=12, help="the stage1 content loss weight")
+     train_parser.add_argument("--w_content2", type=float, default=9, help="the stage2 content loss weight")
+     train_parser.add_argument("--w_content3", type=float, default=7, help="the stage3 content loss weight")
+     train_parser.add_argument("--w_remd1", type=float, default=2, help="the stage1 remd loss weight")
+     train_parser.add_argument("--w_remd2", type=float, default=2, help="the stage2 remd loss weight")
+     train_parser.add_argument("--w_remd3", type=float, default=2, help="the stage3 remd loss weight")
+     train_parser.add_argument("--w_moment1", type=float, default=2, help="the stage1 moment loss weight")
+     train_parser.add_argument("--w_moment2", type=float, default=2, help="the stage2 moment loss weight")
+     train_parser.add_argument("--w_moment3", type=float, default=2, help="the stage3 moment loss weight")
+     train_parser.add_argument("--color_on", type=str, default=True, help="turn on the color loss")
+     train_parser.add_argument("--w_color1", type=float, default=0.25, help="the stage1 color loss weight")
+     train_parser.add_argument("--w_color2", type=float, default=0.5, help="the stage2 color loss weight")
+     train_parser.add_argument("--w_color3", type=float, default=1, help="the stage3 color loss weight")
+
+
+     eval_parser = subparsers.add_parser("eval", help="evaluation mode parser")
+     eval_parser.add_argument("--training", type=bool, default=False)
+     eval_parser.add_argument("--run_folder", type=bool, default=False)
+
+     args = main_parser.parse_args()
+
+     args.training = False
+
+     model = Net(args)
+     model.eval()
+     model = model.to(DEVICE)
+
+     tf = test_transform()
+
+     Ic = tf(content_image).to(DEVICE)
+     Is = tf(Image.fromarray(style_img)).to(DEVICE)
+
+     Ic = Ic.unsqueeze(dim=0)
+     Is = Is.unsqueeze(dim=0)
+
+     with torch.no_grad():
+         Ics = model(Ic, Is)
+
+     grid = make_grid(Ics[0])
+     # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
+     ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+     im = Image.fromarray(ndarr)
+
+     return im
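A hedged sketch of calling style_transfer_method outside Gradio; note it expects a PIL content image and a NumPy style image (the format gr.Image passes from app.py). The content path is an illustrative assumption, while the style path is one of the bundled examples, and the pretrained checkpoints must be present as in net.py:

import numpy as np
from PIL import Image
from style_trsfer import style_transfer_method

content = Image.open('content.jpg').convert('RGB')                      # PIL image, as the diffusion pipeline returns
style = np.array(Image.open('style_image/style1.jpg').convert('RGB'))   # NumPy array, as gr.Image passes it

stylized = style_transfer_method(content, style)  # returns a PIL.Image
stylized.save('stylized.jpg')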
utils.py ADDED
@@ -0,0 +1,168 @@
+ import os
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.data as data
+ from torchvision import transforms
+ import PIL.Image as Image
+
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is available
+ mse = nn.MSELoss()
+
+
+ def calc_histogram_loss(A, B, histogram_block):
+     input_hist = histogram_block(A)
+     target_hist = histogram_block(B)
+     histogram_loss = (1/np.sqrt(2.0) * (torch.sqrt(torch.sum(
+         torch.pow(torch.sqrt(target_hist) - torch.sqrt(input_hist), 2)))) /
+         input_hist.shape[0])
+
+     return histogram_loss
+
+ # B, C, H, W; mean var on HW
+ def calc_mean_std(feat, eps=1e-5):
+     # eps is a small value added to the variance to avoid divide-by-zero.
+     size = feat.size()
+     assert (len(size) == 4)
+     N, C = size[:2]
+     feat_var = feat.view(N, C, -1).var(dim=2) + eps
+     feat_std = feat_var.sqrt().view(N, C, 1, 1)
+     feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+     return feat_mean, feat_std
+
+ def mean_variance_norm(feat):
+     size = feat.size()
+     mean, std = calc_mean_std(feat)
+     normalized_feat = (feat - mean.expand(size)) / std.expand(size)
+     return normalized_feat
+
+ def train_transform():
+     transform_list = [
+         transforms.Resize(size=512),
+         transforms.RandomCrop(256),
+         transforms.ToTensor()
+     ]
+     return transforms.Compose(transform_list)
+
+ def test_transform():
+     transform_list = []
+     transform_list.append(transforms.Resize(size=(512)))
+     transform_list.append(transforms.ToTensor())
+     transform = transforms.Compose(transform_list)
+     return transform
+
+ # https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/7
+ def plot_grad_flow(named_parameters):
+     '''Plots the gradients flowing through different layers in the net during training.
+     Can be used for checking for possible gradient vanishing / exploding problems.
+
+     Usage: Plug this function in Trainer class after loss.backwards() as
+     "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
+     ave_grads = []
+     max_grads = []
+     layers = []
+     for n, p in named_parameters:
+         if (p.requires_grad) and ("bias" not in n):
+             layers.append(n)
+             ave_grads.append(p.grad.abs().mean())
+             max_grads.append(p.grad.abs().max())
+             print('-'*82)
+             print(n, p.grad.abs().mean(), p.grad.abs().max())
+     print('-'*82)
+
+ def InfiniteSampler(n):
+     # i = 0
+     i = n - 1
+     order = np.random.permutation(n)
+     while True:
+         yield order[i]
+         i += 1
+         if i >= n:
+             np.random.seed()
+             order = np.random.permutation(n)
+             i = 0
+
+ class InfiniteSamplerWrapper(data.sampler.Sampler):
+     def __init__(self, data_source):
+         self.num_samples = len(data_source)
+
+     def __iter__(self):
+         return iter(InfiniteSampler(self.num_samples))
+
+     def __len__(self):
+         return 2 ** 31
+
+ class FlatFolderDataset(data.Dataset):
+     def __init__(self, root, transform):
+         super(FlatFolderDataset, self).__init__()
+         self.root = root
+         self.paths = os.listdir(self.root)
+         self.transform = transform
+
+     def __getitem__(self, index):
+         path = self.paths[index]
+         img = Image.open(os.path.join(self.root, path)).convert('RGB')
+         img = self.transform(img)
+         return img
+
+     def __len__(self):
+         return len(self.paths)
+
+     def name(self):
+         return 'FlatFolderDataset'
+
+ def adjust_learning_rate(optimizer, iteration_count, args):
+     """Imitating the original implementation"""
+     lr = args.lr / (1.0 + 5e-5 * iteration_count)
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+ def cosine_dismat(A, B):
+     A = A.view(A.shape[0], A.shape[1], -1)
+     B = B.view(B.shape[0], B.shape[1], -1)
+
+     A_norm = torch.sqrt((A**2).sum(1))
+     B_norm = torch.sqrt((B**2).sum(1))
+
+     A = (A/A_norm.unsqueeze(dim=1).expand(A.shape)).permute(0, 2, 1)
+     B = (B/B_norm.unsqueeze(dim=1).expand(B.shape))
+     dismat = 1. - torch.bmm(A, B)
+
+     return dismat
+
+ def calc_remd_loss(A, B):
+     C = cosine_dismat(A, B)
+     m1, _ = C.min(1)
+     m2, _ = C.min(2)
+
+     remd = torch.max(m1.mean(), m2.mean())
+
+     return remd
+
+ def calc_ss_loss(A, B):
+     MA = cosine_dismat(A, A)
+     MB = cosine_dismat(B, B)
+     Lself_similarity = torch.abs(MA - MB).mean()
+
+     return Lself_similarity
+
+ def calc_moment_loss(A, B):
+     A = A.view(A.shape[0], A.shape[1], -1)
+     B = B.view(B.shape[0], B.shape[1], -1)
+
+     mu_a = torch.mean(A, 1, keepdim=True)
+     mu_b = torch.mean(B, 1, keepdim=True)
+     mu_d = torch.abs(mu_a - mu_b).mean()
+
+     A_c = A - mu_a
+     B_c = B - mu_b
+     cov_a = torch.bmm(A_c, A_c.permute(0, 2, 1)) / (A.shape[2] - 1)
+     cov_b = torch.bmm(B_c, B_c.permute(0, 2, 1)) / (B.shape[2] - 1)
+     cov_d = torch.abs(cov_a - cov_b).mean()
+     loss = mu_d + cov_d
+     return loss
+
+ def calc_mse_loss(A, B):
+     return mse(A, B)
+