biubiubiiu committed
Commit 504f20d · 1 parent: d54f4b1

code cleanup

Files changed (1)
  1. test.py +0 -252
test.py DELETED
@@ -1,252 +0,0 @@
- import argparse
- import time
- from pathlib import Path
-
- import torch
- import torch.nn as nn
- from PIL import Image
- from torchvision import transforms
- from torchvision.utils import save_image
-
- import net
- from function import adaptive_instance_normalization, coral
- from function import adaptive_mean_normalization
- from function import adaptive_std_normalization
- from function import exact_feature_distribution_matching, histogram_matching
-
-
- def test_transform(size, crop):
-     transform_list = []
-     if size != 0:
-         transform_list.append(transforms.Resize(size))
-     if crop:
-         transform_list.append(transforms.CenterCrop(size))
-     transform_list.append(transforms.ToTensor())
-     transform = transforms.Compose(transform_list)
-     return transform
-
-
- def style_transfer(vgg, decoder, content, style, alpha=1.0,
-                    interpolation_weights=None, style_type='adain'):
-     assert 0.0 <= alpha <= 1.0
-     content_f = vgg(content)
-     style_f = vgg(style)
-     if interpolation_weights:
-         # blend the stylized features of several style images into one map
-         _, C, H, W = content_f.size()
-         feat = torch.zeros(1, C, H, W, device=content_f.device)
-         if style_type == 'adain':
-             base_feat = adaptive_instance_normalization(content_f, style_f)
-         elif style_type == 'adamean':
-             base_feat = adaptive_mean_normalization(content_f, style_f)
-         elif style_type == 'adastd':
-             base_feat = adaptive_std_normalization(content_f, style_f)
-         elif style_type == 'efdm':
-             base_feat = exact_feature_distribution_matching(content_f, style_f)
-         elif style_type == 'hm':
-             base_feat = histogram_matching(content_f, style_f)
-         else:
-             raise NotImplementedError
-         for i, w in enumerate(interpolation_weights):
-             feat = feat + w * base_feat[i:i + 1]
-         content_f = content_f[0:1]
-     else:
-         if style_type == 'adain':
-             feat = adaptive_instance_normalization(content_f, style_f)
-         elif style_type == 'adamean':
-             feat = adaptive_mean_normalization(content_f, style_f)
-         elif style_type == 'adastd':
-             feat = adaptive_std_normalization(content_f, style_f)
-         elif style_type == 'efdm':
-             feat = exact_feature_distribution_matching(content_f, style_f)
-         elif style_type == 'hm':
-             feat = histogram_matching(content_f, style_f)
-         else:
-             raise NotImplementedError
-     # content/style trade-off: alpha=1 keeps full stylization
-     feat = feat * alpha + content_f * (1 - alpha)
-     return decoder(feat)
-
-
- parser = argparse.ArgumentParser()
- # Basic options
- parser.add_argument('--content', type=str,
-                     help='File path to the content image')
- parser.add_argument('--content_dir', type=str,
-                     help='Directory path to a batch of content images')
- parser.add_argument('--style', type=str,
-                     help='File path to the style image, or to multiple style '
-                          'images separated by commas for style interpolation '
-                          'or spatial control')
- parser.add_argument('--style_dir', type=str,
-                     help='Directory path to a batch of style images')
- parser.add_argument('--vgg', type=str, default='pretrained/vgg_normalised.pth')
- parser.add_argument('--decoder', type=str, default='pretrained/efdm_decoder_iter_160000.pth.tar')
- parser.add_argument('--style_type', type=str, default='adain',
-                     help='adain | adamean | adastd | efdm | hm')
- parser.add_argument('--test_style_type', type=str, default='',
-                     help='adain | adamean | adastd | efdm | hm; defaults to --style_type')
- # Additional options
- parser.add_argument('--content_size', type=int, default=512,
-                     help='New (minimum) size for the content image; '
-                          'keeps the original size if set to 0')
- parser.add_argument('--style_size', type=int, default=512,
-                     help='New (minimum) size for the style image; '
-                          'keeps the original size if set to 0')
- parser.add_argument('--crop', action='store_true',
-                     help='center-crop to create a square image')
- parser.add_argument('--save_ext', default='.jpg',
-                     help='Extension of the output image')
- parser.add_argument('--output', type=str, default='output',
-                     help='Directory to save the output image(s)')
- parser.add_argument('--photo', action='store_true',
-                     help='Photorealistic style transfer on paired images')
- # Advanced options
- parser.add_argument('--preserve_color', action='store_true',
-                     help='If specified, preserve the color of the content image')
- parser.add_argument('--alpha', type=float, default=1.0,
-                     help='Weight that controls the degree of stylization; '
-                          'should be between 0 and 1')
- parser.add_argument(
-     '--style_interpolation_weights', type=str, default='',
-     help='Weights for blending the styles of multiple style images')
-
- args = parser.parse_args()
- if not args.test_style_type:
-     args.test_style_type = args.style_type
-
- print('Note: the style type (%s) and the pre-trained decoder (%s) should be consistent' % (args.style_type, args.decoder))
- print('The test style type is:', args.test_style_type)
-
- do_interpolation = False
-
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
- output_dir = Path(args.output + '_' + args.style_type + '_' + args.test_style_type)
- output_dir.mkdir(exist_ok=True, parents=True)
-
- # Either --content or --content_dir should be given.
- assert args.content or args.content_dir
- if args.content:
-     content_paths = [Path(args.content)]
- else:
-     content_dir = Path(args.content_dir)
-     content_paths = [f for f in content_dir.glob('*')]
-
- # Either --style or --style_dir should be given.
- assert args.style or args.style_dir
- if args.style:
-     style_paths = [Path(p) for p in args.style.split(',')]
-     if len(style_paths) > 1:
-         do_interpolation = True
-         # assert (args.style_interpolation_weights != ''), \
-         #     'Please specify interpolation weights'
-         # weights = [int(i) for i in args.style_interpolation_weights.split(',')]
-         # interpolation_weights = [w / sum(weights) for w in weights]
- else:
-     style_dir = Path(args.style_dir)
-     style_paths = [f for f in style_dir.glob('*')]
-
- decoder = net.decoder
- vgg = net.vgg
-
- decoder.eval()
- vgg.eval()
-
- decoder.load_state_dict(torch.load(args.decoder, map_location='cpu'))
- vgg.load_state_dict(torch.load(args.vgg, map_location='cpu'))
- vgg = nn.Sequential(*list(vgg.children())[:31])  # keep layers up to relu4_1
-
- vgg.to(device)
- decoder.to(device)
-
- content_tf = test_transform(args.content_size, args.crop)
- style_tf = test_transform(args.style_size, args.crop)
-
- timer = []
- for content_path in content_paths:
-     if do_interpolation:
-         # one content image, four style images (the weight grid below assumes 4)
-         style = torch.stack([style_tf(Image.open(str(p))) for p in style_paths])
-         content = content_tf(Image.open(str(content_path))) \
-             .unsqueeze(0).expand_as(style)
-         style = style.to(device)
-         content = content.to(device)
-         # 5x5 grid of bilinear weights over the four styles; each row of
-         # weights sums to 1 (see the check after this diff)
-         weight_grid = []
-         steps = [1, 0.75, 0.5, 0.25, 0]
-         for i in steps:
-             for j in steps:
-                 weight_grid.append([i * j, i * (1 - j), (1 - i) * j, (1 - i) * (1 - j)])
-         count = 1
-         for interpolation_weights in weight_grid:
-             with torch.no_grad():
-                 output = style_transfer(vgg, decoder, content, style,
-                                         args.alpha, interpolation_weights,
-                                         style_type=args.test_style_type)
-             output = output.cpu()
-             output_name = output_dir / '{:s}_interpolate_{:d}{:s}'.format(
-                 content_path.stem, count, args.save_ext)
-             save_image(output, str(output_name))
-             count += 1
-
-     #### content & style trade-off.
-     # alpha = [0.0, 0.25, 0.5, 0.75, 1.0]
-     # for style_path in style_paths:
-     #     content = content_tf(Image.open(str(content_path)))
-     #     style = style_tf(Image.open(str(style_path)))
-     #     if args.preserve_color:
-     #         style = coral(style, content)
-     #     style = style.to(device).unsqueeze(0)
-     #     content = content.to(device).unsqueeze(0)
-     #     ## replace the style image with Gaussian noise
-     #     # style.normal_(0, 1)
-     #     # style = torch.rand(style.size()).to(device)
-     #     ### for paired images
-     #     if args.photo:
-     #         if content_path.stem[2:] == style_path.stem[3:]:
-     #             for sample_alpha in alpha:
-     #                 with torch.no_grad():
-     #                     output = style_transfer(vgg, decoder, content, style,
-     #                                             sample_alpha, style_type=args.test_style_type)
-     #                 output = output.cpu()
-     #                 output_name = output_dir / '{:s}_stylized_{:s}{:s}{:s}'.format(
-     #                     content_path.stem, style_path.stem, str(sample_alpha), args.save_ext)
-     #                 save_image(output, str(output_name))
-     #     else:
-     #         for sample_alpha in alpha:
-     #             with torch.no_grad():
-     #                 output = style_transfer(vgg, decoder, content, style,
-     #                                         sample_alpha, style_type=args.test_style_type)
-     #             output = output.cpu()
-     #             output_name = output_dir / '{:s}_stylized_{:s}{:s}{:s}'.format(
-     #                 content_path.stem, style_path.stem, str(sample_alpha), args.save_ext)
-     #             save_image(output, str(output_name))
-     else:  # process one content image and one style image
-         for style_path in style_paths:
-             content = content_tf(Image.open(str(content_path)))
-             style = style_tf(Image.open(str(style_path)))
-             if args.preserve_color:
-                 style = coral(style, content)
-             style = style.to(device).unsqueeze(0)
-             content = content.to(device).unsqueeze(0)
-             ## replace the style image with Gaussian noise
-             # style.normal_(0, 1)
-             # style = torch.rand(style.size()).to(device)
-             ### for paired images
-             if args.photo:
-                 # stylize only matching pairs; the stems are assumed to match
-                 # once their prefixes are dropped, e.g. 'in01' vs. 'tar01'
-                 if content_path.stem[2:] == style_path.stem[3:]:
-                     with torch.no_grad():
-                         start_time = time.time()
-                         output = style_transfer(vgg, decoder, content, style,
-                                                 args.alpha, style_type=args.test_style_type)
-                         timer.append(time.time() - start_time)
-                         print(timer)
-
-                     output = output.cpu()
-                     output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format(
-                         content_path.stem, style_path.stem, args.save_ext)
-                     save_image(output, str(output_name))
-             else:
-                 with torch.no_grad():
-                     output = style_transfer(vgg, decoder, content, style,
-                                             args.alpha, style_type=args.test_style_type)
-                 output = output.cpu()
-                 output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format(
-                     content_path.stem, style_path.stem, args.save_ext)
-                 save_image(output, str(output_name))
- if timer:  # timings are only recorded in --photo mode
-     print(torch.FloatTensor(timer).mean())
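
For context, the `function` module imported by the deleted script is not part of this diff. A minimal sketch of what `adaptive_instance_normalization` presumably computes, following the standard AdaIN formulation (re-normalize each content channel with the style's per-channel statistics); the repo's actual implementation may differ in detail:

import torch

def adaptive_instance_normalization(content_feat, style_feat, eps=1e-5):
    # AdaIN: sigma(style) * (content - mu(content)) / sigma(content) + mu(style),
    # with mean/std taken per sample and per channel over the spatial
    # dimensions; NCHW feature maps are assumed.
    n, c = content_feat.shape[:2]
    c_flat = content_feat.reshape(n, c, -1)
    s_flat = style_feat.reshape(n, c, -1)
    c_mean = c_flat.mean(dim=-1, keepdim=True)
    c_std = c_flat.std(dim=-1, keepdim=True) + eps
    s_mean = s_flat.mean(dim=-1, keepdim=True)
    s_std = s_flat.std(dim=-1, keepdim=True) + eps
    return ((c_flat - c_mean) / c_std * s_std + s_mean).view_as(content_feat)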
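
Similarly, `exact_feature_distribution_matching` in the EFDM paper matches the empirical feature distributions exactly by sorting: the k-th smallest content activation in each channel is replaced by the k-th smallest style activation. A sketch under the assumption that both feature maps have the same number of spatial positions (the paper also adds a straight-through term for gradients, omitted here since this script is test-time only):

def exact_feature_distribution_matching(content_feat, style_feat):
    # Sort matching: give each content position the style value of equal rank.
    n, c = content_feat.shape[:2]
    c_flat = content_feat.reshape(n, c, -1)
    s_flat = style_feat.reshape(n, c, -1)
    _, c_rank = torch.sort(c_flat, dim=-1)    # indices of ascending content values
    s_sorted, _ = torch.sort(s_flat, dim=-1)  # ascending style values
    out = torch.empty_like(c_flat)
    # place the k-th smallest style value at the position of the
    # k-th smallest content value
    out.scatter_(-1, c_rank, s_sorted)
    return out.view_as(content_feat)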
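
Finally, the nested `steps` loop in the interpolation branch builds a 5x5 grid of bilinear weights over the four style images: for any pair (i, j), the four weights i*j, i*(1-j), (1-i)*j and (1-i)*(1-j) sum to i + (1-i) = 1, so each saved image is a convex combination of the four stylized feature maps. A standalone check:

steps = [1, 0.75, 0.5, 0.25, 0]
weight_grid = [[i * j, i * (1 - j), (1 - i) * j, (1 - i) * (1 - j)]
               for i in steps for j in steps]
assert len(weight_grid) == 25
assert all(abs(sum(w) - 1.0) < 1e-9 for w in weight_grid)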