Spaces:
Build error
Build error
from argparse import Namespace
import glob
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import yaml
from PIL import Image
from tqdm import tqdm
from torchvision import transforms, utils

# Put this file's directory first on sys.path so sibling modules can be
# imported no matter which working directory the script is launched from.
# This must happen before the project-local import below.
current_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, current_dir)

# Presumably provides `Trainer` (used by get_trainer) — TODO confirm and
# replace the wildcard with explicit names.
from trainer import *  # noqa: E402,F401,F403

# cuDNN configuration.
# NOTE(review): per PyTorch's reproducibility notes, `benchmark = True` can
# select different algorithms from run to run, which undermines
# `deterministic = True`. Set `benchmark = False` if bit-exact
# reproducibility is actually required.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
# Debug aid (slow): uncomment to localize NaN / in-place autograd errors.
# torch.autograd.set_detect_anomaly(True)

# Remove PIL's pixel-count limit so arbitrarily large images can be opened.
# NOTE(review): this also disables PIL's decompression-bomb protection;
# consider a finite limit if input images can come from untrusted users.
Image.MAX_IMAGE_PIXELS = None

# Hard-coded run options (stand-in for parsed command-line arguments).
# NOTE(review): these model paths are relative to the current working
# directory, whereas the config file below is resolved relative to this file.
opts = Namespace(
    config='001',
    pretrained_model_path='pretrained_models/FeatureStyleEncoder/143_enc.pth',
    stylegan_model_path='pretrained_models/FeatureStyleEncoder/psp_ffhq_encode.pt',
    arcface_model_path='pretrained_models/FeatureStyleEncoder/backbone.pth',
    parsing_model_path='pretrained_models/FeatureStyleEncoder/79999_iter.pth',
    log_path='./logs/',
    resume=False,
    checkpoint='',
    checkpoint_noiser='',
    multigpu=False,
    input_path='./test/',
    save_path='./',
)

# Load configs/<opts.config>.yaml from next to this file. The context manager
# guarantees the file handle is closed once parsing finishes.
with open(os.path.join(current_dir, 'configs', opts.config + '.yaml'), 'r') as _config_file:
    config = yaml.load(_config_file, Loader=yaml.FullLoader)
def get_trainer(device):
    """Build a Trainer from the module-level config/opts and load encoder weights.

    The trainer is constructed with the module-level ``config`` and ``opts``.
    It is initialized with the StyleGAN, ArcFace and parsing model paths from
    ``opts`` and moved to ``device``. The state dict at
    ``opts.pretrained_model_path`` is then loaded into ``trainer.enc``.

    Args:
        device: Target device, presumably a ``torch.device`` or a device
            string such as ``"cuda"`` / ``"cpu"`` (TODO confirm with callers).
            It is passed to ``trainer.to`` and used as ``map_location`` when
            reading the encoder checkpoint.

    Returns:
        The trainer, with ``trainer.enc`` holding the pretrained weights and
        switched to eval mode.
    """
    trainer = Trainer(config, opts)
    trainer.initialize(
        opts.stylegan_model_path,
        opts.arcface_model_path,
        opts.parsing_model_path,
    )
    trainer.to(device)

    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # from GPU tensors can still be loaded on a CPU-only host.
    encoder_state = torch.load(opts.pretrained_model_path, map_location=device)
    trainer.enc.load_state_dict(encoder_state)
    trainer.enc.eval()
    return trainer