"""Build a face dataset from an FFHQ image directory.

Loads images in parallel, runs them through a pretrained Trainer to get
256px images, key points, and latents, and saves the result with torch.save.
"""
import argparse
import os
import random
import sys
from pathlib import Path

import torch
from PIL import Image
from joblib import Parallel, delayed
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tqdm.auto import tqdm

# Make the repository root importable before pulling in project modules.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from scripts.rotate_train import Trainer
from utils.train import seed_everything
from utils.image_utils import list_image_files
# Module-level singletons shared by the helpers below.
toTensor = T.ToTensor()  # PIL image -> float tensor with values in [0, 1]
# NOTE(review): Trainer() is constructed at import time — presumably it loads
# model weights / claims a device as a side effect; confirm this is intended.
net_trainer = Trainer()
def load_image(path):
    """Load the image at *path* and return it as a float tensor in [0, 1].

    ``Image.open`` is lazy and keeps the file handle open until the pixel
    data is actually read; the context manager guarantees the handle is
    closed once ``toTensor`` has forced the load.
    """
    with Image.open(path) as img:
        return toTensor(img)
def load_dataset_images(imgs, base_dir=None):
    """Load images and run them through the trainer in inference mode.

    Parameters:
        imgs: iterable of image file names (relative to *base_dir*).
        base_dir: root directory of the image files. Defaults to the global
            CLI ``args.FFHQ`` for backward compatibility with the original
            implicit-global behavior.

    Returns:
        Tuple of three lists of CPU tensors: (images, key_points, latents).
    """
    # Fall back to the CLI argument only when no explicit root is given,
    # so the function no longer *requires* the module-level `args` global.
    root = args.FFHQ if base_dir is None else base_dir
    tensors_images = Parallel(n_jobs=-1)(
        delayed(load_image)(os.path.join(root, str(img))) for img in tqdm(imgs))
    tensors_dataloader = DataLoader(tensors_images, batch_size=32, pin_memory=False,
                                    shuffle=False, drop_last=False)
    images, key_points, latents = [], [], []
    # Pure inference: disable autograd so the tensors we keep via .extend()
    # do not retain computation graphs (the original leaked graph memory).
    with torch.no_grad():
        for batch in tqdm(tensors_dataloader):
            batch = batch.to(net_trainer.device)
            images_256 = net_trainer.downsample_256(batch).clip(0, 1)
            images.extend(images_256.cpu())
            # `* 2 - 1` maps the clipped [0, 1] range to [-1, 1] before
            # latent generation.
            latents.extend(net_trainer.generate_latents(images_256 * 2 - 1).cpu())
            key_points.extend(net_trainer.generate_key_points(batch).cpu())
    return images, key_points, latents
def main(args):
    """Sample up to ``args.size`` images from ``args.FFHQ``, process them,
    and save the resulting dataset dict to ``args.output``.

    Parameters:
        args: parsed CLI namespace with ``seed``, ``FFHQ``, ``size``,
            and ``output`` attributes.
    """
    seed_everything(args.seed)
    images = list_image_files(args.FFHQ)
    random.shuffle(images)  # deterministic given the seed set above
    images, key_points, latents = load_dataset_images(images[:args.size])
    # Create the destination directory up front so torch.save cannot fail
    # with FileNotFoundError after minutes of processing.
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    torch.save({'images': images, 'key_points': key_points, 'latents': latents}, output)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Rotate dataset')
    # --FFHQ is mandatory: without it the script would crash much later
    # with an opaque `os.path.join(None, ...)` TypeError.
    parser.add_argument('--FFHQ', type=Path, required=True,
                        help='Directory containing the FFHQ images')
    parser.add_argument('--seed', type=int, default=3407,
                        help='Random seed for reproducible sampling')
    parser.add_argument('--size', type=int, default=10_000,
                        help='Maximum number of images to process')
    parser.add_argument('--output', type=Path, default='input/rotate_dataset.pkl',
                        help='Destination file for the saved dataset')
    args = parser.parse_args()
    main(args)