# Spaces: Build error (status text captured from the hosting page; not part of the script)
import argparse | |
import glob | |
import os | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.utils.data as data | |
import yaml | |
from PIL import Image | |
from tqdm import tqdm | |
from torchvision import transforms, utils | |
from utils.datasets import * | |
from utils.functions import * | |
from trainer import * | |
# Global runtime switches, applied once at import time.

# cuDNN backend: enabled, restricted to deterministic algorithms, with the
# benchmark autotuner switched on.
_cudnn = torch.backends.cudnn
_cudnn.enabled = True
_cudnn.deterministic = True
_cudnn.benchmark = True

# Turn on autograd anomaly detection for the whole process.
torch.autograd.set_detect_anomaly(True)

# Remove Pillow's upper bound on image pixel count so very large inputs open.
Image.MAX_IMAGE_PIXELS = None
# Target device for the model and input tensors: CUDA when a GPU is visible
# to PyTorch, otherwise fall back to the CPU instead of failing later on
# `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` would treat every non-empty string (including ``"False"``
    and ``"0"``) as True; this parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


# Command-line options for the inference script.
parser = argparse.ArgumentParser()
# The config value is a name, not a path: it is expanded to ./configs/<name>.yaml.
parser.add_argument('--config', type=str, default='001', help='Name of the config file in ./configs/ (without the .yaml extension).')
# These weights are loaded into the trainer's encoder (trainer.enc).
parser.add_argument('--pretrained_model_path', type=str, default='./pretrained_models/143_enc.pth', help='pretrained encoder model')
parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='pretrained stylegan2 model')
parser.add_argument('--arcface_model_path', type=str, default='./pretrained_models/backbone.pth', help='pretrained ArcFace model')
parser.add_argument('--parsing_model_path', type=str, default='./pretrained_models/79999_iter.pth', help='pretrained parsing model')
parser.add_argument('--log_path', type=str, default='./logs/', help='log file path')
parser.add_argument('--resume', action='store_true', help='resume from checkpoint')
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint file path')
parser.add_argument('--checkpoint_noiser', type=str, default='', help='checkpoint file path')
# Accepts `--multigpu`, `--multigpu True` and `--multigpu False`; the last
# form now really yields False.
parser.add_argument('--multigpu', type=_str2bool, nargs='?', const=True, default=False, help='use multiple gpus')
parser.add_argument('--input_path', type=str, default='./test/', help='evaluation data file path')
parser.add_argument('--save_path', type=str, default='./output/image/', help='output data save path')
opts = parser.parse_args()
# Per-config log directory (with trailing slash), derived from the CLI options.
log_dir = os.path.join(opts.log_path, opts.config) + '/'

# Read the experiment configuration; the context manager closes the file
# handle as soon as parsing is done.
with open('./configs/' + opts.config + '.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Initialize trainer
trainer = Trainer(config, opts)
trainer.initialize(opts.stylegan_model_path, opts.arcface_model_path, opts.parsing_model_path)
trainer.to(device)

# Read the encoder checkpoint from disk once and reuse it. `map_location`
# places the tensors on the selected device regardless of where they were
# saved.
state_dict = torch.load(opts.pretrained_model_path, map_location=device)
trainer.enc.load_state_dict(state_dict)
# Inference only: put the encoder in evaluation mode.
trainer.enc.eval()
# Preprocessing applied to every input image, built step by step:
#   1. resize to a fixed 1024x1024,
#   2. convert to a tensor,
#   3. normalize each of the three channels with mean 0.5 and std 0.5.
_resize_step = transforms.Resize((1024, 1024))
_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
img_to_tensor = transforms.Compose([_resize_step, _tensor_step, _normalize_step])
# simple inference
# Runs every *jpg / *png file found directly inside the input directory
# through `trainer.test` and writes one result image per input, under the
# same file name, into the output directory.
image_dir = opts.input_path
save_dir = opts.save_path
os.makedirs(save_dir, exist_ok=True)

# Collect bare file names. `glob.escape` keeps glob metacharacters in the
# directory name from being interpreted as wildcards.
img_list = []
for pattern in ('*jpg', '*png'):
    matches = glob.glob(os.path.join(glob.escape(image_dir), pattern))
    img_list.extend(os.path.basename(match) for match in matches)
img_list.sort()

with torch.no_grad():
    for i, img_name in enumerate(img_list):
        # os.path.join works whether or not the directory option ends with a
        # path separator. The file is closed once the tensor is built.
        with Image.open(os.path.join(image_dir, img_name)) as img:
            # Force three channels so RGBA / grayscale / palette files match
            # the 3-channel normalization in `img_to_tensor`.
            image_A = img_to_tensor(img.convert('RGB')).unsqueeze(0).to(device)

        output = trainer.test(img=image_A, return_latent=True)
        # The last two entries of the returned list are removed here; only
        # output[1] is written out below.
        feature = output.pop()
        latent = output.pop()
        # To also keep the latent code:
        # np.save(os.path.join(save_dir, 'latent_code_%d.npy' % i), latent.cpu().numpy())
        utils.save_image(clip_img(output[1]), os.path.join(save_dir, img_name))

        # Safety cap: stop once the index exceeds 1000, i.e. after at most
        # 1002 images have been processed.
        if i > 1000:
            break