Upload 6 files
- app.py +75 -0
- hist_loss.py +208 -0
- ics.jpg +0 -0
- net.py +281 -0
- style_trsfer.py +80 -0
- utils.py +168 -0
app.py
ADDED
@@ -0,0 +1,75 @@
from diffusers import DiffusionPipeline, UniPCMultistepScheduler
import gradio as gr
import torch
import gc
from style_trsfer import style_transfer_method


def generate(style_image, text, negative_prompts, steps, guidance_scale):
    pipeline = DiffusionPipeline.from_pretrained("./CCLAP")
    pipeline.scheduler = UniPCMultistepScheduler.from_config(
        pipeline.scheduler.config)
    device = torch.device(
        'cuda:0' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        pipeline.enable_xformers_memory_efficient_attention()
    pipeline.to(device)
    torch.cuda.empty_cache()
    gc.collect()
    content_image = pipeline(text,
                             num_inference_steps=steps,
                             negative_prompt=negative_prompts,
                             guidance_scale=guidance_scale).images[0]
    result = style_transfer_method(content_image, style_image)
    return content_image, result


if __name__ == '__main__':
    demo = gr.Interface(title="CCLAP",
                        description=(
                            "This is the demo of CCLAP to generate Chinese landscape painting."
                        ),
                        css="",
                        fn=generate,
                        inputs=[gr.Image(label="Style Image", shape=(512, 512)),
                                gr.Textbox(lines=3, placeholder="Input the prompt", label="Prompt"),
                                gr.Textbox(lines=3, placeholder="low quality", label="Negative prompt"),
                                gr.Slider(minimum=0, maximum=100, value=20, label='Steps'),
                                gr.Slider(minimum=0, maximum=30, value=7.5, label='Guidance_scale'),
                                ],
                        outputs=[gr.Image(label="Content Output", shape=(256, 256)),
                                 gr.Image(label="Final Output", shape=(256, 256))],
                        examples=[
                            [
                                'style_image/style1.jpg',
                                'A Chinese landscape painting of a mountain landscape with trees',
                                'low quality',
                                20,
                                7.5
                            ],
                            [
                                'style_image/style2.jpg',
                                'A Chinese landscape painting of a building with trees in front of it',
                                'low quality',
                                20,
                                7.5
                            ],
                            [
                                'style_image/style3.jpg',
                                'A Chinese landscape painting of a landscape with mountains in the background',
                                'low quality',
                                20,
                                7.5
                            ],
                            [
                                'style_image/style4.jpg',
                                'A Chinese landscape painting of a landscape with mountains and a river',
                                'low quality',
                                20,
                                7.5
                            ],
                        ],
                        )

    demo.launch()
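Running `python app.py` launches the Gradio demo defined above. For reference, `generate` can also be called outside the UI; the snippet below is a minimal sketch (not part of the upload), assuming the `./CCLAP` diffusion weights are present, a CUDA device is available (the style-transfer stage hard-codes `utils.DEVICE = 'cuda'`), and the style image is passed as a NumPy array, as Gradio's image input would provide.

```python
# Sketch: programmatic use of generate() (assumes ./CCLAP weights and a CUDA device).
import numpy as np
from PIL import Image
from app import generate

style = np.array(Image.open("style_image/style1.jpg").convert("RGB"))
content, stylized = generate(style,
                             "A Chinese landscape painting of a mountain landscape with trees",
                             "low quality",
                             steps=20,
                             guidance_scale=7.5)
stylized.save("result.png")
```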
hist_loss.py
ADDED
@@ -0,0 +1,208 @@
"""
Copyright 2021 Mahmoud Afifi.
Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN:
Controlling Colors of GAN-Generated and Real Images via Color Histograms."
In CVPR, 2021.

@inproceedings{afifi2021histogan,
  title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via
  Color Histograms},
  author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.},
  booktitle={CVPR},
  year={2021}
}
"""

import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np

EPS = 1e-6

class RGBuvHistBlock(nn.Module):
  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=True,
               device='cuda'):
    """ Computes the RGB-uv histogram feature of a given image.
    Args:
      h: histogram dimension size (scalar). The default value is 64.
      insz: maximum size of the input image; if it is larger than this size, the
        image will be resized (scalar). Default value is 150 (i.e., 150 x 150
        pixels).
      resizing: resizing method if applicable. Options are: 'interpolation' or
        'sampling'. Default is 'interpolation'.
      method: the method used to count the number of pixels for each bin in the
        histogram feature. Options are: 'thresholding', 'RBF' (radial basis
        function), or 'inverse-quadratic'. Default value is 'inverse-quadratic'.
      sigma: if the method value is 'RBF' or 'inverse-quadratic', then this is
        the sigma parameter of the kernel function. The default value is 0.02.
      intensity_scale: boolean variable to use the intensity scale (I_y in
        Equation 2). Default value is True.

    Methods:
      forward: accepts input image and returns its histogram feature. Note that
        unless the method is 'thresholding', this is a differentiable function
        and can be easily integrated with the loss function. As mentioned in the
        paper, the 'inverse-quadratic' was found more stable than 'RBF' in our
        training.
    """
    super(RGBuvHistBlock, self).__init__()
    self.h = h
    self.insz = insz
    self.device = device
    self.resizing = resizing
    self.method = method
    self.intensity_scale = intensity_scale
    if self.method == 'thresholding':
      self.eps = 6.0 / h
    else:
      self.sigma = sigma

  def forward(self, x):
    x = torch.clamp(x, 0, 1)
    if x.shape[2] > self.insz or x.shape[3] > self.insz:
      if self.resizing == 'interpolation':
        x_sampled = F.interpolate(x, size=(self.insz, self.insz),
                                  mode='bilinear', align_corners=False)
      elif self.resizing == 'sampling':
        inds_1 = torch.LongTensor(
          np.linspace(0, x.shape[2], self.h, endpoint=False)).to(
          device=self.device)
        inds_2 = torch.LongTensor(
          np.linspace(0, x.shape[3], self.h, endpoint=False)).to(
          device=self.device)
        x_sampled = x.index_select(2, inds_1)
        x_sampled = x_sampled.index_select(3, inds_2)
      else:
        raise Exception(
          f'Wrong resizing method. It should be: interpolation or sampling. '
          f'But the given value is {self.resizing}.')
    else:
      x_sampled = x

    L = x_sampled.shape[0]  # size of mini-batch
    if x_sampled.shape[1] > 3:
      x_sampled = x_sampled[:, :3, :, :]
    X = torch.unbind(x_sampled, dim=0)
    hists = torch.zeros((x_sampled.shape[0], 3, self.h, self.h)).to(
      device=self.device)
    for l in range(L):
      I = torch.t(torch.reshape(X[l], (3, -1)))
      II = torch.pow(I, 2)
      if self.intensity_scale:
        Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS),
                             dim=1)
      else:
        Iy = 1

      Iu0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 1] + EPS),
                            dim=1)
      Iv0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 2] + EPS),
                            dim=1)
      diff_u0 = abs(
        Iu0 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      diff_v0 = abs(
        Iv0 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      if self.method == 'thresholding':
        diff_u0 = torch.reshape(diff_u0, (-1, self.h)) <= self.eps / 2
        diff_v0 = torch.reshape(diff_v0, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u0 = torch.exp(-diff_u0)  # Radial basis function
        diff_v0 = torch.exp(-diff_v0)
      elif self.method == 'inverse-quadratic':
        diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u0 = 1 / (1 + diff_u0)  # Inverse quadratic
        diff_v0 = 1 / (1 + diff_v0)
      else:
        raise Exception(
          f'Wrong kernel method. It should be either thresholding, RBF,'
          f' inverse-quadratic. But the given value is {self.method}.')
      diff_u0 = diff_u0.type(torch.float32)
      diff_v0 = diff_v0.type(torch.float32)
      a = torch.t(Iy * diff_u0)
      hists[l, 0, :, :] = torch.mm(a, diff_v0)

      Iu1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 0] + EPS),
                            dim=1)
      Iv1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 2] + EPS),
                            dim=1)
      diff_u1 = abs(
        Iu1 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      diff_v1 = abs(
        Iv1 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))

      if self.method == 'thresholding':
        diff_u1 = torch.reshape(diff_u1, (-1, self.h)) <= self.eps / 2
        diff_v1 = torch.reshape(diff_v1, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = torch.exp(-diff_u1)  # Gaussian
        diff_v1 = torch.exp(-diff_v1)
      elif self.method == 'inverse-quadratic':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = 1 / (1 + diff_u1)  # Inverse quadratic
        diff_v1 = 1 / (1 + diff_v1)

      diff_u1 = diff_u1.type(torch.float32)
      diff_v1 = diff_v1.type(torch.float32)
      a = torch.t(Iy * diff_u1)
      hists[l, 1, :, :] = torch.mm(a, diff_v1)

      Iu2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 0] + EPS),
                            dim=1)
      Iv2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 1] + EPS),
                            dim=1)
      diff_u2 = abs(
        Iu2 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      diff_v2 = abs(
        Iv2 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      if self.method == 'thresholding':
        diff_u2 = torch.reshape(diff_u2, (-1, self.h)) <= self.eps / 2
        diff_v2 = torch.reshape(diff_v2, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u2 = torch.exp(-diff_u2)  # Gaussian
        diff_v2 = torch.exp(-diff_v2)
      elif self.method == 'inverse-quadratic':
        diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u2 = 1 / (1 + diff_u2)  # Inverse quadratic
        diff_v2 = 1 / (1 + diff_v2)
      diff_u2 = diff_u2.type(torch.float32)
      diff_v2 = diff_v2.type(torch.float32)
      a = torch.t(Iy * diff_u2)
      hists[l, 2, :, :] = torch.mm(a, diff_v2)

    # normalization
    hists_normalized = hists / (
        ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)

    return hists_normalized
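As a quick sanity check, the histogram block can be run on a random batch; this is a small sketch (not part of the upload) showing that the output is one 3 x h x h histogram per image, normalized to sum to one. Running on CPU with `device='cpu'` is an assumption for the example only; the upload defaults to `'cuda'`.

```python
# Sketch: RGB-uv histogram of a random batch on CPU (not part of the upload).
import torch
from hist_loss import RGBuvHistBlock

hist_block = RGBuvHistBlock(h=64, insz=150, method='inverse-quadratic', device='cpu')
x = torch.rand(2, 3, 256, 256)      # batch of 2 RGB images in [0, 1]
hists = hist_block(x)
print(hists.shape)                  # torch.Size([2, 3, 64, 64])
print(hists.sum(dim=(1, 2, 3)))     # ~1.0 per image (up to EPS)
```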
ics.jpg
ADDED
net.py
ADDED
@@ -0,0 +1,281 @@
import torch
import torch.nn as nn
from utils import mean_variance_norm, DEVICE
from utils import calc_ss_loss, calc_remd_loss, calc_moment_loss, calc_mse_loss, calc_histogram_loss
from hist_loss import RGBuvHistBlock
import torch

class Net(nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        self.args = args
        self.vgg = vgg19[:44]
        self.vgg.load_state_dict(torch.load('./checkpoints/encoder.pth', map_location='cpu'), strict=False)
        for param in self.vgg.parameters():
            param.requires_grad = False

        self.align1 = PAMA(512)
        self.align2 = PAMA(512)
        self.align3 = PAMA(512)

        self.decoder = decoder
        self.hist = RGBuvHistBlock(insz=64, h=256,
                                   intensity_scale=True,
                                   method='inverse-quadratic',
                                   device=DEVICE)

        if args.pretrained == True:
            self.align1.load_state_dict(torch.load('./checkpoints/PAMA1.pth', map_location='cpu'), strict=True)
            self.align2.load_state_dict(torch.load('./checkpoints/PAMA2.pth', map_location='cpu'), strict=True)
            self.align3.load_state_dict(torch.load('./checkpoints/PAMA3.pth', map_location='cpu'), strict=True)
            self.decoder.load_state_dict(torch.load('./checkpoints/decoder.pth', map_location='cpu'), strict=False)

        if args.requires_grad == False:
            for param in self.parameters():
                param.requires_grad = False


    def forward(self, Ic, Is):
        feat_c = self.forward_vgg(Ic)
        feat_s = self.forward_vgg(Is)
        Fc, Fs = feat_c[3], feat_s[3]

        Fcs1 = self.align1(Fc, Fs)
        Fcs2 = self.align2(Fcs1, Fs)
        Fcs3 = self.align3(Fcs2, Fs)

        Ics3 = self.decoder(Fcs3)

        if self.args.training == True:
            Ics1 = self.decoder(Fcs1)
            Ics2 = self.decoder(Fcs2)
            Irc = self.decoder(Fc)
            Irs = self.decoder(Fs)
            feat_cs1 = self.forward_vgg(Ics1)
            feat_cs2 = self.forward_vgg(Ics2)
            feat_cs3 = self.forward_vgg(Ics3)
            feat_rc = self.forward_vgg(Irc)
            feat_rs = self.forward_vgg(Irs)

            content_loss1, remd_loss1, moment_loss1, color_loss1 = 0.0, 0.0, 0.0, 0.0
            content_loss2, remd_loss2, moment_loss2, color_loss2 = 0.0, 0.0, 0.0, 0.0
            content_loss3, remd_loss3, moment_loss3, color_loss3 = 0.0, 0.0, 0.0, 0.0
            loss_rec = 0.0

            for l in range(2, 5):
                content_loss1 += self.args.w_content1 * calc_ss_loss(feat_cs1[l], feat_c[l])
                remd_loss1 += self.args.w_remd1 * calc_remd_loss(feat_cs1[l], feat_s[l])
                moment_loss1 += self.args.w_moment1 * calc_moment_loss(feat_cs1[l], feat_s[l])

                content_loss2 += self.args.w_content2 * calc_ss_loss(feat_cs2[l], feat_c[l])
                remd_loss2 += self.args.w_remd2 * calc_remd_loss(feat_cs2[l], feat_s[l])
                moment_loss2 += self.args.w_moment2 * calc_moment_loss(feat_cs2[l], feat_s[l])

                content_loss3 += self.args.w_content3 * calc_ss_loss(feat_cs3[l], feat_c[l])
                remd_loss3 += self.args.w_remd3 * calc_remd_loss(feat_cs3[l], feat_s[l])
                moment_loss3 += self.args.w_moment3 * calc_moment_loss(feat_cs3[l], feat_s[l])

                loss_rec += 0.5 * calc_mse_loss(feat_rc[l], feat_c[l]) + 0.5 * calc_mse_loss(feat_rs[l], feat_s[l])
            loss_rec += 25 * calc_mse_loss(Irc, Ic)
            loss_rec += 25 * calc_mse_loss(Irs, Is)

            if self.args.color_on:
                color_loss1 += self.args.w_color1 * calc_histogram_loss(Ics1, Is, self.hist)
                color_loss2 += self.args.w_color2 * calc_histogram_loss(Ics2, Is, self.hist)
                color_loss3 += self.args.w_color3 * calc_histogram_loss(Ics3, Is, self.hist)

            loss1 = (content_loss1+remd_loss1+moment_loss1+color_loss1)/(self.args.w_content1+self.args.w_remd1+self.args.w_moment1+self.args.w_color1)
            loss2 = (content_loss2+remd_loss2+moment_loss2+color_loss2)/(self.args.w_content2+self.args.w_remd2+self.args.w_moment2+self.args.w_color2)
            loss3 = (content_loss3+remd_loss3+moment_loss3+color_loss3)/(self.args.w_content3+self.args.w_remd3+self.args.w_moment3+self.args.w_color3)
            loss = loss1 + loss2 + loss3 + loss_rec
            return loss
        else:
            return Ics3

    def forward_vgg(self, x):
        relu1_1 = self.vgg[:4](x)
        relu2_1 = self.vgg[4:11](relu1_1)
        relu3_1 = self.vgg[11:18](relu2_1)
        relu4_1 = self.vgg[18:31](relu3_1)
        relu5_1 = self.vgg[31:44](relu4_1)
        return [relu1_1, relu2_1, relu3_1, relu4_1, relu5_1]

    def save_ckpts(self):
        torch.save(self.align1.state_dict(), "./checkpoints/PAMA1.pth")
        torch.save(self.align2.state_dict(), "./checkpoints/PAMA2.pth")
        torch.save(self.align3.state_dict(), "./checkpoints/PAMA3.pth")
        torch.save(self.decoder.state_dict(), "./checkpoints/decoder.pth")

#---------------------------------------------------------------------------------------------------------------

vgg19 = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)

#---------------------------------------------------------------------------------------------------------------

decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),  # relu4_1
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),  # relu3_1
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),  # relu2_1
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1_1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)

#---------------------------------------------------------------------------------------------------------------

class AttentionUnit(nn.Module):
    def __init__(self, channels):
        super(AttentionUnit, self).__init__()
        self.relu6 = nn.ReLU6()
        self.f = nn.Conv2d(channels, channels//2, (1, 1))
        self.g = nn.Conv2d(channels, channels//2, (1, 1))
        self.h = nn.Conv2d(channels, channels//2, (1, 1))

        self.out_conv = nn.Conv2d(channels//2, channels, (1, 1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Fc, Fs):
        B, C, H, W = Fc.shape
        f_Fc = self.relu6(self.f(mean_variance_norm(Fc)))
        g_Fs = self.relu6(self.g(mean_variance_norm(Fs)))
        h_Fs = self.relu6(self.h(Fs))
        f_Fc = f_Fc.view(f_Fc.shape[0], f_Fc.shape[1], -1).permute(0, 2, 1)
        g_Fs = g_Fs.view(g_Fs.shape[0], g_Fs.shape[1], -1)

        Attention = self.softmax(torch.bmm(f_Fc, g_Fs))

        h_Fs = h_Fs.view(h_Fs.shape[0], h_Fs.shape[1], -1)

        Fcs = torch.bmm(h_Fs, Attention.permute(0, 2, 1))
        Fcs = Fcs.view(B, C//2, H, W)
        Fcs = self.relu6(self.out_conv(Fcs))

        return Fcs

class FuseUnit(nn.Module):
    def __init__(self, channels):
        super(FuseUnit, self).__init__()
        self.proj1 = nn.Conv2d(2*channels, channels, (1, 1))
        self.proj2 = nn.Conv2d(channels, channels, (1, 1))
        self.proj3 = nn.Conv2d(channels, channels, (1, 1))

        self.fuse1x = nn.Conv2d(channels, 1, (1, 1), stride=1)
        self.fuse3x = nn.Conv2d(channels, 1, (3, 3), stride=1)
        self.fuse5x = nn.Conv2d(channels, 1, (5, 5), stride=1)

        self.pad3x = nn.ReflectionPad2d((1, 1, 1, 1))
        self.pad5x = nn.ReflectionPad2d((2, 2, 2, 2))
        self.sigmoid = nn.Sigmoid()

    def forward(self, F1, F2):
        Fcat = self.proj1(torch.cat((F1, F2), dim=1))
        F1 = self.proj2(F1)
        F2 = self.proj3(F2)

        fusion1 = self.sigmoid(self.fuse1x(Fcat))
        fusion3 = self.sigmoid(self.fuse3x(self.pad3x(Fcat)))
        fusion5 = self.sigmoid(self.fuse5x(self.pad5x(Fcat)))
        fusion = (fusion1 + fusion3 + fusion5) / 3

        return torch.clamp(fusion, min=0, max=1.0)*F1 + torch.clamp(1 - fusion, min=0, max=1.0)*F2

class PAMA(nn.Module):
    def __init__(self, channels):
        super(PAMA, self).__init__()
        self.conv_in = nn.Conv2d(channels, channels, (3, 3), stride=1)
        self.attn = AttentionUnit(channels)
        self.fuse = FuseUnit(channels)
        self.conv_out = nn.Conv2d(channels, channels, (3, 3), stride=1)

        self.pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.relu6 = nn.ReLU6()

    def forward(self, Fc, Fs):
        Fc = self.relu6(self.conv_in(self.pad(Fc)))
        Fs = self.relu6(self.conv_in(self.pad(Fs)))
        Fcs = self.attn(Fc, Fs)
        Fcs = self.relu6(self.conv_out(self.pad(Fcs)))
        Fcs = self.fuse(Fc, Fcs)

        return Fcs

#---------------------------------------------------------------------------------------------------------------
style_trsfer.py
ADDED
@@ -0,0 +1,80 @@
import argparse
import torch
from torchvision.utils import make_grid
from PIL import Image, ImageFile
from net import Net
from utils import DEVICE, test_transform
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True



def style_transfer_method(content_image, style_img):
    main_parser = argparse.ArgumentParser(description="main parser")
    subparsers = main_parser.add_subparsers(title="subcommands", dest="subcommand")

    main_parser.add_argument("--pretrained", type=bool, default=True,
                             help="whether to use the pre-trained checkpoints")
    main_parser.add_argument("--requires_grad", type=bool, default=True,
                             help="set to True if the model requires model gradient")

    train_parser = subparsers.add_parser("train", help="training mode parser")
    train_parser.add_argument("--training", type=bool, default=True)
    train_parser.add_argument("--iterations", type=int, default=60000,
                              help="total training epochs (default: 160000)")
    train_parser.add_argument("--batch_size", type=int, default=2,
                              help="training batch size (default: 8)")
    train_parser.add_argument("--num_workers", type=int, default=2,
                              help="iterator threads (default: 8)")
    train_parser.add_argument("--lr", type=float, default=1e-4, help="the learning rate during training (default: 1e-4)")
    train_parser.add_argument("--content_folder", type=str, required=True,
                              help="the root of content images, the path should point to a folder")
    train_parser.add_argument("--style_folder", type=str, required=True,
                              help="the root of style images, the path should point to a folder")
    train_parser.add_argument("--log_interval", type=int, default=10000,
                              help="number of images after which the training loss is logged (default: 20000)")

    train_parser.add_argument("--w_content1", type=float, default=12, help="the stage1 content loss weight")
    train_parser.add_argument("--w_content2", type=float, default=9, help="the stage2 content loss weight")
    train_parser.add_argument("--w_content3", type=float, default=7, help="the stage3 content loss weight")
    train_parser.add_argument("--w_remd1", type=float, default=2, help="the stage1 remd loss weight")
    train_parser.add_argument("--w_remd2", type=float, default=2, help="the stage2 remd loss weight")
    train_parser.add_argument("--w_remd3", type=float, default=2, help="the stage3 remd loss weight")
    train_parser.add_argument("--w_moment1", type=float, default=2, help="the stage1 moment loss weight")
    train_parser.add_argument("--w_moment2", type=float, default=2, help="the stage2 moment loss weight")
    train_parser.add_argument("--w_moment3", type=float, default=2, help="the stage3 moment loss weight")
    train_parser.add_argument("--color_on", type=str, default=True, help="turn on the color loss")
    train_parser.add_argument("--w_color1", type=float, default=0.25, help="the stage1 color loss weight")
    train_parser.add_argument("--w_color2", type=float, default=0.5, help="the stage2 color loss weight")
    train_parser.add_argument("--w_color3", type=float, default=1, help="the stage3 color loss weight")


    eval_parser = subparsers.add_parser("eval", help="evaluation mode parser")
    eval_parser.add_argument("--training", type=bool, default=False)
    eval_parser.add_argument("--run_folder", type=bool, default=False)

    args = main_parser.parse_args()

    args.training = False

    model = Net(args)
    model.eval()
    model = model.to(DEVICE)

    tf = test_transform()

    Ic = tf(content_image).to(DEVICE)
    Is = tf(Image.fromarray(style_img)).to(DEVICE)

    Ic = Ic.unsqueeze(dim=0)
    Is = Is.unsqueeze(dim=0)

    with torch.no_grad():
        Ics = model(Ic, Is)

    grid = make_grid(Ics[0])
    # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)

    return im
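Usage note (a sketch, not part of the upload): `style_transfer_method` expects the content image as a PIL image and the style image as a NumPy array (it calls `Image.fromarray` on it), and returns the stylized result as a PIL image. It assumes the checkpoints referenced by `net.py` exist under `./checkpoints/` and that a CUDA device is available, since `utils.DEVICE` is hard-coded to `'cuda'`. The content path below is hypothetical.

```python
# Sketch: calling style_transfer_method directly (assumes ./checkpoints/*.pth and CUDA).
import numpy as np
from PIL import Image
from style_trsfer import style_transfer_method

content = Image.open("content.jpg").convert("RGB")                      # hypothetical content image
style = np.array(Image.open("style_image/style1.jpg").convert("RGB"))   # bundled example style
result = style_transfer_method(content, style)
result.save("stylized.png")
```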
utils.py
ADDED
@@ -0,0 +1,168 @@
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import transforms
import PIL.Image as Image

DEVICE = 'cuda'
mse = nn.MSELoss()


def calc_histogram_loss(A, B, histogram_block):
    input_hist = histogram_block(A)
    target_hist = histogram_block(B)
    histogram_loss = (1/np.sqrt(2.0) * (torch.sqrt(torch.sum(
        torch.pow(torch.sqrt(target_hist) - torch.sqrt(input_hist), 2)))) /
        input_hist.shape[0])

    return histogram_loss

# B, C, H, W; mean var on HW
def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def mean_variance_norm(feat):
    size = feat.size()
    mean, std = calc_mean_std(feat)
    normalized_feat = (feat - mean.expand(size)) / std.expand(size)
    return normalized_feat

def train_transform():
    transform_list = [
        transforms.Resize(size=512),
        transforms.RandomCrop(256),
        transforms.ToTensor()
    ]
    return transforms.Compose(transform_list)

def test_transform():
    transform_list = []
    transform_list.append(transforms.Resize(size=(512)))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

# https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/7
def plot_grad_flow(named_parameters):
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.

    Usage: Plug this function in Trainer class after loss.backwards() as
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if (p.requires_grad) and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean())
            max_grads.append(p.grad.abs().max())
            print('-'*82)
            print(n, p.grad.abs().mean(), p.grad.abs().max())
            print('-'*82)

def InfiniteSampler(n):
    # i = 0
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0

class InfiniteSamplerWrapper(data.sampler.Sampler):
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31

class FlatFolderDataset(data.Dataset):
    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.root = root
        self.paths = os.listdir(self.root)
        self.transform = transform

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        img = self.transform(img)
        return img

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'FlatFolderDataset'

def adjust_learning_rate(optimizer, iteration_count, args):
    """Imitating the original implementation"""
    lr = args.lr / (1.0 + 5e-5 * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def cosine_dismat(A, B):
    A = A.view(A.shape[0], A.shape[1], -1)
    B = B.view(B.shape[0], B.shape[1], -1)

    A_norm = torch.sqrt((A**2).sum(1))
    B_norm = torch.sqrt((B**2).sum(1))

    A = (A/A_norm.unsqueeze(dim=1).expand(A.shape)).permute(0, 2, 1)
    B = (B/B_norm.unsqueeze(dim=1).expand(B.shape))
    dismat = 1. - torch.bmm(A, B)

    return dismat

def calc_remd_loss(A, B):
    C = cosine_dismat(A, B)
    m1, _ = C.min(1)
    m2, _ = C.min(2)

    remd = torch.max(m1.mean(), m2.mean())

    return remd

def calc_ss_loss(A, B):
    MA = cosine_dismat(A, A)
    MB = cosine_dismat(B, B)
    Lself_similarity = torch.abs(MA - MB).mean()

    return Lself_similarity

def calc_moment_loss(A, B):
    A = A.view(A.shape[0], A.shape[1], -1)
    B = B.view(B.shape[0], B.shape[1], -1)

    mu_a = torch.mean(A, 1, keepdim=True)
    mu_b = torch.mean(B, 1, keepdim=True)
    mu_d = torch.abs(mu_a - mu_b).mean()

    A_c = A - mu_a
    B_c = B - mu_b
    cov_a = torch.bmm(A_c, A_c.permute(0, 2, 1)) / (A.shape[2] - 1)
    cov_b = torch.bmm(B_c, B_c.permute(0, 2, 1)) / (B.shape[2] - 1)
    cov_d = torch.abs(cov_a - cov_b).mean()
    loss = mu_d + cov_d
    return loss

def calc_mse_loss(A, B):
    return mse(A, B)
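For reference, the loss helpers operate on B x C x H x W feature maps that are flattened to B x C x (HW) internally. The sketch below (not part of the upload) evaluates them on random data and shows how the color term pairs `calc_histogram_loss` with `RGBuvHistBlock`; running on CPU with `device='cpu'` is an assumption for the example only.

```python
# Sketch: exercising the loss helpers on random data on CPU (not part of the upload).
import torch
from hist_loss import RGBuvHistBlock
from utils import calc_ss_loss, calc_remd_loss, calc_moment_loss, calc_histogram_loss

A = torch.rand(1, 512, 32, 32)   # e.g. features of a stylized image
B = torch.rand(1, 512, 32, 32)   # e.g. features of the style image

print(calc_ss_loss(A, B))        # self-similarity (content structure) term
print(calc_remd_loss(A, B))      # relaxed-EMD term on cosine distances
print(calc_moment_loss(A, B))    # mean / covariance matching term

hist_block = RGBuvHistBlock(insz=64, h=256, method='inverse-quadratic', device='cpu')
img_a, img_b = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
print(calc_histogram_loss(img_a, img_b, hist_block))  # Hellinger-style color histogram distance
```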