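"""Generate one image per caption from a sampled COCO captions CSV.

The selected diffusion pipeline runs over the captions in batches, and the
resulting images are saved to `--root_img_path` as `<row_index>.jpg`.
"""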
import argparse
import os
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
import torch
from diffusers import DiffusionPipeline

# Checkpoints supported by this script; `--pipeline_id` must be one of these.
ALL_CKPTS = [
    "runwayml/stable-diffusion-v1-5",
    "segmind/SSD-1B",
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/sdxl-turbo",
]
SEED = 2024

def load_dataframe():
    # 30k captions randomly sampled from the COCO 2014 validation split.
    dataframe = pd.read_csv(
        "https://huggingface.co/datasets/sayakpaul/sample-datasets/raw/main/coco_30k_randomly_sampled_2014_val.csv"
    )
    return dataframe

def load_pipeline(args):
    if "runway" in args.pipeline_id:
        # SD v1-5 ships a safety checker; drop it so every pipeline does the same work.
        pipeline = DiffusionPipeline.from_pretrained(
            args.pipeline_id, torch_dtype=torch.float16, safety_checker=None
        ).to("cuda")
    else:
        pipeline = DiffusionPipeline.from_pretrained(args.pipeline_id, torch_dtype=torch.float16).to("cuda")
    pipeline.set_progress_bar_config(disable=True)
    return pipeline

def generate_images(args, dataframe, pipeline):
    all_images = []
    # Process the captions in batches of `chunk_size` so each batch fits on the GPU.
    for i in range(0, len(dataframe), args.chunk_size):
        if "sdxl-turbo" not in args.pipeline_id:
            images = pipeline(
                dataframe.iloc[i : i + args.chunk_size]["caption"].tolist(),
                num_inference_steps=args.num_inference_steps,
                generator=torch.manual_seed(SEED),
            ).images
        else:
            # SDXL Turbo is distilled to run without classifier-free guidance,
            # so guidance_scale is pinned to 0.0.
            images = pipeline(
                dataframe.iloc[i : i + args.chunk_size]["caption"].tolist(),
                num_inference_steps=args.num_inference_steps,
                generator=torch.manual_seed(SEED),
                guidance_scale=0.0,
            ).images
        all_images.extend(images)
    return all_images

def serialize_image(image, path):
    image.save(path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pipeline_id", default="runwayml/stable-diffusion-v1-5", type=str, choices=ALL_CKPTS)
    parser.add_argument("--num_inference_steps", default=30, type=int)
    parser.add_argument("--chunk_size", default=2, type=int)
    parser.add_argument("--root_img_path", default="sdv15", type=str)
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()

    dataset = load_dataframe()
    pipeline = load_pipeline(args)
    images = generate_images(args, dataset, pipeline)

    os.makedirs(args.root_img_path, exist_ok=True)
    image_paths = [os.path.join(args.root_img_path, f"{i}.jpg") for i in range(len(images))]
    # Saving images is I/O bound, so a thread pool parallelizes serialization effectively.
    with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
        executor.map(serialize_image, images, image_paths)
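
# Example invocation (a sketch; the filename `generate_images.py` is an assumption,
# and a CUDA-capable GPU with `diffusers`, `torch`, and `pandas` installed is required).
# SDXL Turbo is a few-step model, so a single inference step is typical for it:
#
#   python generate_images.py \
#       --pipeline_id stabilityai/sdxl-turbo \
#       --num_inference_steps 1 \
#       --chunk_size 4 \
#       --root_img_path sdxl_turbo_images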