Spaces:
Build error
Build error
File size: 1,602 Bytes
6d314be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from argparse import Namespace
import glob
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import yaml
import sys
current_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, current_dir)
from PIL import Image
from tqdm import tqdm
from torchvision import transforms, utils
from trainer import *
# cuDNN flags applied for everything this module runs.
# NOTE(review): `deterministic=True` combined with `benchmark=True` is a mixed
# signal -- PyTorch's reproducibility notes recommend `benchmark=False` when
# run-to-run reproducibility is the goal. Left unchanged to preserve current
# behavior; confirm which property is actually wanted.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

# Remove PIL's MAX_IMAGE_PIXELS cap (None disables the size check) so large
# input images can be opened.
Image.MAX_IMAGE_PIXELS = None

# argparse-style options built in code rather than parsed from the command
# line. NOTE(review): the model paths below are relative to the current
# working directory, whereas the config file further down is resolved
# against this file's directory -- confirm the app is always launched from
# the directory containing `pretrained_models/`.
opts = Namespace(
    config='001',
    pretrained_model_path='pretrained_models/FeatureStyleEncoder/143_enc.pth',
    stylegan_model_path='pretrained_models/FeatureStyleEncoder/psp_ffhq_encode.pt',
    arcface_model_path='pretrained_models/FeatureStyleEncoder/backbone.pth',
    parsing_model_path='pretrained_models/FeatureStyleEncoder/79999_iter.pth',
    log_path='./logs/',
    resume=False,
    checkpoint='',
    checkpoint_noiser='',
    multigpu=False,
    input_path='./test/',
    save_path='./',
)

# Load configs/<opts.config>.yaml from next to this file. The `with` block
# guarantees the file handle is closed once parsing finishes.
with open(os.path.join(current_dir, 'configs', opts.config + '.yaml'), 'r', encoding='utf-8') as _config_file:
    config = yaml.load(_config_file, Loader=yaml.FullLoader)
def get_trainer(device):
    """Build the inference Trainer and load the pretrained encoder weights.

    Constructs ``Trainer(config, opts)`` from the module-level ``config`` and
    ``opts``, initializes it with the StyleGAN / ArcFace / face-parsing model
    paths from ``opts``, moves it to ``device``, loads the encoder checkpoint
    at ``opts.pretrained_model_path`` into ``trainer.enc`` and switches the
    encoder to eval mode.

    Args:
        device: Passed to ``trainer.to``; presumably anything
            ``nn.Module.to`` accepts (``'cuda'``, ``'cpu'``, a
            ``torch.device``) -- TODO confirm against ``Trainer.to``.

    Returns:
        The initialized Trainer, with ``trainer.enc`` in eval mode.
    """
    trainer = Trainer(config, opts)
    trainer.initialize(opts.stylegan_model_path, opts.arcface_model_path, opts.parsing_model_path)
    trainer.to(device)
    # Deserialize onto the CPU so a checkpoint saved from a GPU still loads on
    # a CPU-only host. load_state_dict (assuming `enc` is an nn.Module) then
    # copies the values into the existing encoder parameters, which are
    # presumably already on `device` after trainer.to(...) above.
    state_dict = torch.load(opts.pretrained_model_path, map_location='cpu')
    trainer.enc.load_state_dict(state_dict)
    trainer.enc.eval()
    return trainer