Spaces:
Runtime error
Runtime error
Commit
·
504f20d
1
Parent(s):
d54f4b1
code cleanup
Browse files
test.py
DELETED
@@ -1,252 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
from pathlib import Path
|
3 |
-
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
from PIL import Image
|
7 |
-
from torchvision import transforms
|
8 |
-
from torchvision.utils import save_image
|
9 |
-
import time
|
10 |
-
import net
|
11 |
-
from function import adaptive_instance_normalization, coral
|
12 |
-
from function import adaptive_mean_normalization
|
13 |
-
from function import adaptive_std_normalization
|
14 |
-
from function import exact_feature_distribution_matching, histogram_matching
|
15 |
-
|
16 |
-
def test_transform(size, crop):
|
17 |
-
transform_list = []
|
18 |
-
if size != 0:
|
19 |
-
transform_list.append(transforms.Resize(size))
|
20 |
-
if crop:
|
21 |
-
transform_list.append(transforms.CenterCrop(size))
|
22 |
-
transform_list.append(transforms.ToTensor())
|
23 |
-
transform = transforms.Compose(transform_list)
|
24 |
-
return transform
|
25 |
-
|
26 |
-
|
27 |
-
def style_transfer(vgg, decoder, content, style, alpha=1.0,
|
28 |
-
interpolation_weights=None, style_type='adain'):
|
29 |
-
assert (0.0 <= alpha <= 1.0)
|
30 |
-
content_f = vgg(content)
|
31 |
-
style_f = vgg(style)
|
32 |
-
if interpolation_weights:
|
33 |
-
_, C, H, W = content_f.size()
|
34 |
-
feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
|
35 |
-
if style_type == 'adain':
|
36 |
-
base_feat = adaptive_instance_normalization(content_f, style_f)
|
37 |
-
elif style_type == 'adamean':
|
38 |
-
base_feat = adaptive_mean_normalization(content_f, style_f)
|
39 |
-
elif style_type == 'adastd':
|
40 |
-
base_feat = adaptive_std_normalization(content_f, style_f)
|
41 |
-
elif style_type == 'efdm':
|
42 |
-
base_feat = exact_feature_distribution_matching(content_f, style_f)
|
43 |
-
elif style_type == 'hm':
|
44 |
-
feat = histogram_matching(content_f, style_f)
|
45 |
-
else:
|
46 |
-
raise NotImplementedError
|
47 |
-
for i, w in enumerate(interpolation_weights):
|
48 |
-
feat = feat + w * base_feat[i:i + 1]
|
49 |
-
content_f = content_f[0:1]
|
50 |
-
else:
|
51 |
-
if style_type == 'adain':
|
52 |
-
feat = adaptive_instance_normalization(content_f, style_f)
|
53 |
-
elif style_type == 'adamean':
|
54 |
-
feat = adaptive_mean_normalization(content_f, style_f)
|
55 |
-
elif style_type == 'adastd':
|
56 |
-
feat = adaptive_std_normalization(content_f, style_f)
|
57 |
-
elif style_type == 'efdm':
|
58 |
-
feat = exact_feature_distribution_matching(content_f, style_f)
|
59 |
-
elif style_type == 'hm':
|
60 |
-
feat = histogram_matching(content_f, style_f)
|
61 |
-
else:
|
62 |
-
raise NotImplementedError
|
63 |
-
feat = feat * alpha + content_f * (1 - alpha)
|
64 |
-
return decoder(feat)
|
65 |
-
|
66 |
-
|
67 |
-
parser = argparse.ArgumentParser()
|
68 |
-
# Basic options
|
69 |
-
parser.add_argument('--content', type=str,
|
70 |
-
help='File path to the content image')
|
71 |
-
parser.add_argument('--content_dir', type=str,
|
72 |
-
help='Directory path to a batch of content images')
|
73 |
-
parser.add_argument('--style', type=str,
|
74 |
-
help='File path to the style image, or multiple style \
|
75 |
-
images separated by commas if you want to do style \
|
76 |
-
interpolation or spatial control')
|
77 |
-
parser.add_argument('--style_dir', type=str,
|
78 |
-
help='Directory path to a batch of style images')
|
79 |
-
parser.add_argument('--vgg', type=str, default='pretrained/vgg_normalised.pth')
|
80 |
-
parser.add_argument('--decoder', type=str, default='pretrained/efdm_decoder_iter_160000.pth.tar')
|
81 |
-
parser.add_argument('--style_type', type=str, default='adain', help='adain | adamean | adastd | efdm')
|
82 |
-
parser.add_argument('--test_style_type', type=str, default='', help='adain | adamean | adastd | efdm')
|
83 |
-
# Additional options
|
84 |
-
parser.add_argument('--content_size', type=int, default=512,
|
85 |
-
help='New (minimum) size for the content image, \
|
86 |
-
keeping the original size if set to 0')
|
87 |
-
parser.add_argument('--style_size', type=int, default=512,
|
88 |
-
help='New (minimum) size for the style image, \
|
89 |
-
keeping the original size if set to 0')
|
90 |
-
parser.add_argument('--crop', action='store_true',
|
91 |
-
help='do center crop to create squared image')
|
92 |
-
parser.add_argument('--save_ext', default='.jpg',
|
93 |
-
help='The extension name of the output image')
|
94 |
-
parser.add_argument('--output', type=str, default='output',
|
95 |
-
help='Directory to save the output image(s)')
|
96 |
-
parser.add_argument('--photo', action='store_true',
|
97 |
-
help='apply on the photo style transfer')
|
98 |
-
# Advanced options
|
99 |
-
parser.add_argument('--preserve_color', action='store_true',
|
100 |
-
help='If specified, preserve color of the content image')
|
101 |
-
parser.add_argument('--alpha', type=float, default=1.0,
|
102 |
-
help='The weight that controls the degree of \
|
103 |
-
stylization. Should be between 0 and 1')
|
104 |
-
parser.add_argument(
|
105 |
-
'--style_interpolation_weights', type=str, default='',
|
106 |
-
help='The weight for blending the style of multiple style images')
|
107 |
-
|
108 |
-
args = parser.parse_args()
|
109 |
-
if not args.test_style_type:
|
110 |
-
args.test_style_type = args.style_type
|
111 |
-
|
112 |
-
print('Note: the style type: %s and the pre-trained model: %s should be consistent' % (args.style_type, args.decoder))
|
113 |
-
print('The test style type is:', args.test_style_type)
|
114 |
-
|
115 |
-
do_interpolation = False
|
116 |
-
|
117 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
118 |
-
|
119 |
-
output_dir = Path(args.output + '_' + args.style_type + '_' + args.test_style_type)
|
120 |
-
output_dir.mkdir(exist_ok=True, parents=True)
|
121 |
-
|
122 |
-
# Either --content or --contentDir should be given.
|
123 |
-
assert (args.content or args.content_dir)
|
124 |
-
if args.content:
|
125 |
-
content_paths = [Path(args.content)]
|
126 |
-
else:
|
127 |
-
content_dir = Path(args.content_dir)
|
128 |
-
content_paths = [f for f in content_dir.glob('*')]
|
129 |
-
|
130 |
-
# Either --style or --styleDir should be given.
|
131 |
-
assert (args.style or args.style_dir)
|
132 |
-
if args.style:
|
133 |
-
style_paths = args.style.split(',')
|
134 |
-
if len(style_paths) == 1:
|
135 |
-
style_paths = [Path(args.style)]
|
136 |
-
else:
|
137 |
-
do_interpolation = True
|
138 |
-
# assert (args.style_interpolation_weights != ''), \
|
139 |
-
# 'Please specify interpolation weights'
|
140 |
-
# weights = [int(i) for i in args.style_interpolation_weights.split(',')]
|
141 |
-
# interpolation_weights = [w / sum(weights) for w in weights]
|
142 |
-
else:
|
143 |
-
style_dir = Path(args.style_dir)
|
144 |
-
style_paths = [f for f in style_dir.glob('*')]
|
145 |
-
|
146 |
-
decoder = net.decoder
|
147 |
-
vgg = net.vgg
|
148 |
-
|
149 |
-
decoder.eval()
|
150 |
-
vgg.eval()
|
151 |
-
|
152 |
-
decoder.load_state_dict(torch.load(args.decoder))
|
153 |
-
vgg.load_state_dict(torch.load(args.vgg))
|
154 |
-
vgg = nn.Sequential(*list(vgg.children())[:31])
|
155 |
-
|
156 |
-
vgg.to(device)
|
157 |
-
decoder.to(device)
|
158 |
-
|
159 |
-
content_tf = test_transform(args.content_size, args.crop)
|
160 |
-
style_tf = test_transform(args.style_size, args.crop)
|
161 |
-
|
162 |
-
timer = []
|
163 |
-
for content_path in content_paths:
|
164 |
-
if do_interpolation:
|
165 |
-
# one content image, 4 style image
|
166 |
-
style = torch.stack([style_tf(Image.open(str(p))) for p in style_paths])
|
167 |
-
content = content_tf(Image.open(str(content_path))) \
|
168 |
-
.unsqueeze(0).expand_as(style)
|
169 |
-
style = style.to(device)
|
170 |
-
content = content.to(device)
|
171 |
-
list = []
|
172 |
-
steps = [1, 0.75, 0.5, 0.25, 0]
|
173 |
-
for i in steps:
|
174 |
-
for j in steps:
|
175 |
-
list.append([i*j, i*(1-j), (1-i)*j, (1-i)*(1-j)])
|
176 |
-
count = 1
|
177 |
-
for interpolation_weights in list:
|
178 |
-
with torch.no_grad():
|
179 |
-
output = style_transfer(vgg, decoder, content, style,
|
180 |
-
args.alpha, interpolation_weights, style_type=args.test_style_type)
|
181 |
-
output = output.cpu()
|
182 |
-
output_name = output_dir / '{:s}_interpolate_{:s}_{:s}'.format(
|
183 |
-
content_path.stem, str(count), args.save_ext)
|
184 |
-
save_image(output, str(output_name))
|
185 |
-
count+=1
|
186 |
-
|
187 |
-
#### content & style trade-off.
|
188 |
-
# alpha = [0.0, 0.25, 0.5, 0.75, 1.0]
|
189 |
-
# for style_path in style_paths:
|
190 |
-
# content = content_tf(Image.open(str(content_path)))
|
191 |
-
# style = style_tf(Image.open(str(style_path)))
|
192 |
-
# if args.preserve_color:
|
193 |
-
# style = coral(style, content)
|
194 |
-
# style = style.to(device).unsqueeze(0)
|
195 |
-
# content = content.to(device).unsqueeze(0)
|
196 |
-
# ## replace the style image with Gaussian noise
|
197 |
-
# # style.normal_(0,1)
|
198 |
-
# # style = torch.rand(style.size()).to(device)
|
199 |
-
# ### for paired images.
|
200 |
-
# if args.photo:
|
201 |
-
# if content_path.stem[2:] == style_path.stem[3:]:
|
202 |
-
# for sample_alpha in alpha:
|
203 |
-
# with torch.no_grad():
|
204 |
-
# output = style_transfer(vgg, decoder, content, style,
|
205 |
-
# sample_alpha, style_type=args.test_style_type)
|
206 |
-
# output = output.cpu()
|
207 |
-
# output_name = output_dir / '{:s}_stylized_{:s}{:s}{:s}'.format(
|
208 |
-
# content_path.stem, style_path.stem, str(sample_alpha), args.save_ext)
|
209 |
-
# save_image(output, str(output_name))
|
210 |
-
# else:
|
211 |
-
# for sample_alpha in alpha:
|
212 |
-
# with torch.no_grad():
|
213 |
-
# output = style_transfer(vgg, decoder, content, style,
|
214 |
-
# sample_alpha, style_type=args.test_style_type)
|
215 |
-
# output = output.cpu()
|
216 |
-
# output_name = output_dir / '{:s}_stylized_{:s}{:s}{:s}'.format(
|
217 |
-
# content_path.stem, style_path.stem, str(sample_alpha), args.save_ext)
|
218 |
-
# save_image(output, str(output_name))
|
219 |
-
else: # process one content and one style
|
220 |
-
for style_path in style_paths:
|
221 |
-
content = content_tf(Image.open(str(content_path)))
|
222 |
-
style = style_tf(Image.open(str(style_path)))
|
223 |
-
if args.preserve_color:
|
224 |
-
style = coral(style, content)
|
225 |
-
style = style.to(device).unsqueeze(0)
|
226 |
-
content = content.to(device).unsqueeze(0)
|
227 |
-
## replace the style image with Gaussian noise
|
228 |
-
# style.normal_(0,1)
|
229 |
-
# style = torch.rand(style.size()).to(device)
|
230 |
-
### for paired images.
|
231 |
-
if args.photo:
|
232 |
-
if content_path.stem[2:] == style_path.stem[3:]:
|
233 |
-
with torch.no_grad():
|
234 |
-
start_time = time.time()
|
235 |
-
output = style_transfer(vgg, decoder, content, style,
|
236 |
-
args.alpha, style_type=args.test_style_type)
|
237 |
-
timer.append(time.time() - start_time)
|
238 |
-
print(timer)
|
239 |
-
|
240 |
-
output = output.cpu()
|
241 |
-
output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format(
|
242 |
-
content_path.stem, style_path.stem, args.save_ext)
|
243 |
-
save_image(output, str(output_name))
|
244 |
-
else:
|
245 |
-
with torch.no_grad():
|
246 |
-
output = style_transfer(vgg, decoder, content, style,
|
247 |
-
args.alpha, style_type=args.test_style_type)
|
248 |
-
output = output.cpu()
|
249 |
-
output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format(
|
250 |
-
content_path.stem, style_path.stem, args.save_ext)
|
251 |
-
save_image(output, str(output_name))
|
252 |
-
print(torch.FloatTensor(timer).mean())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|