Spaces:
Build error
Build error
File size: 1,980 Bytes
6d314be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import argparse
import os
import random
import sys
from pathlib import Path
import torch
from PIL import Image
from joblib import Parallel, delayed
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tqdm.auto import tqdm
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from scripts.rotate_train import Trainer
from utils.train import seed_everything
from utils.image_utils import list_image_files
# Module-level helpers shared by the functions below: a PIL -> tensor
# converter (used by load_image) and the Trainer whose networks process every
# batch. NOTE: Trainer() is built at import time, so importing this module
# constructs it.
toTensor, net_trainer = T.ToTensor(), Trainer()
def load_image(path):
    """Read the image at *path* and convert it with ``toTensor``.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixels have been read. Without it, closing depends on Pillow's
    lazy-loading rules, which leave the file open for some formats
    (e.g. multi-frame images).

    Args:
        path: Filesystem path of the image file.

    Returns:
        The tensor produced by ``toTensor`` (``torchvision.transforms.ToTensor``).
        Per the torchvision docs, that is a CxHxW float tensor in [0, 1] for
        8-bit images.
    """
    with Image.open(path) as img:
        # toTensor reads the pixel data, so it must run before the file closes.
        return toTensor(img)
@torch.no_grad()
def load_dataset_images(imgs, root=None, batch_size=32, n_jobs=-1):
    """Load images and compute their downsampled versions, key points and latents.

    Images are decoded in parallel with joblib one chunk of ``batch_size`` at a
    time. Each chunk goes through ``net_trainer``'s networks immediately, and
    only the per-image outputs are kept. The input-resolution tensors of the
    whole dataset are therefore never held in memory at once; previously every
    image was decoded up front before the first batch was processed.

    Args:
        imgs: Sequence of image file names/paths, each joined onto *root*.
        root: Directory the entries of *imgs* are relative to. Defaults to the
            module-level ``args.FFHQ`` that exists when this file runs as a
            script, so the original call ``load_dataset_images(imgs)`` still
            works. Pass it explicitly when calling from other code.
        batch_size: Number of images loaded and processed per step.
        n_jobs: joblib worker count for image decoding (-1 = all CPUs).

    Returns:
        Tuple ``(images, key_points, latents)`` of lists with one CPU tensor
        per input image, in the order of *imgs*.
    """
    if root is None:
        # Fallback to the global set in the __main__ block (original behaviour).
        root = args.FFHQ
    imgs = list(imgs)
    images, key_points, latents = [], [], []
    # One Parallel context reuses its worker pool across all chunks.
    with Parallel(n_jobs=n_jobs) as parallel, tqdm(total=len(imgs)) as progress:
        for start in range(0, len(imgs), batch_size):
            chunk = imgs[start:start + batch_size]
            loaded = parallel(
                delayed(load_image)(os.path.join(root, str(img))) for img in chunk
            )
            # Equivalent to DataLoader's default collation of a list of tensors.
            batch = torch.stack(loaded).to(net_trainer.device)
            # Clip back to [0, 1]; presumably the resampling can overshoot
            # (TODO confirm against downsample_256).
            images_256 = net_trainer.downsample_256(batch).clip(0, 1)
            # extend() over a tensor unbinds it along dim 0: one entry per image.
            images.extend(images_256.cpu())
            # Map [0, 1] -> [-1, 1]; presumably the latent encoder's input range.
            latents.extend(net_trainer.generate_latents(images_256 * 2 - 1).cpu())
            # Key points come from the input-resolution batch, not the downsampled one.
            key_points.extend(net_trainer.generate_key_points(batch).cpu())
            progress.update(len(chunk))
    return images, key_points, latents
def main(args):
    """Build the rotate dataset described by the parsed CLI *args*.

    Calls ``seed_everything(args.seed)``, then shuffles the image files listed
    under ``args.FFHQ`` and keeps the first ``args.size`` of them. It computes
    their images/key points/latents with ``load_dataset_images`` and writes the
    three lists to ``args.output`` with ``torch.save``.
    """
    seed_everything(args.seed)
    images = list_image_files(args.FFHQ)
    # Shuffled after seeding, so the subset should be reproducible for a given
    # seed — assuming seed_everything seeds `random` and list_image_files
    # returns files in a stable order (TODO confirm).
    random.shuffle(images)
    images, key_points, latents = load_dataset_images(images[:args.size])
    output = Path(args.output)
    # torch.save does not create directories; the default target is under input/.
    output.parent.mkdir(parents=True, exist_ok=True)
    torch.save({'images': images, 'key_points': key_points, 'latents': latents}, output)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Rotate dataset')
    # Required: otherwise list_image_files would receive None and presumably
    # fail later with a far less helpful error.
    parser.add_argument('--FFHQ', type=Path, required=True,
                        help='directory containing the source images')
    parser.add_argument('--seed', type=int, default=3407, help='random seed')
    parser.add_argument('--size', type=int, default=10_000,
                        help='number of images sampled into the dataset')
    # argparse applies type=Path to string defaults, so this is a Path too.
    parser.add_argument('--output', type=Path, default='input/rotate_dataset.pkl',
                        help='path of the torch.save archive to write')
    # Kept as a module-level global: load_dataset_images reads args.FFHQ.
    args = parser.parse_args()
    main(args)
|