|
from diffusers import DiffusionPipeline |
|
from concurrent.futures import ThreadPoolExecutor |
|
import pandas as pd |
|
import argparse |
|
import torch |
|
import os |
|
|
|
|
|
# Checkpoint identifiers accepted as ``--pipeline_id`` on the command line;
# the chosen one is passed to ``DiffusionPipeline.from_pretrained``.
ALL_CKPTS = [
    "runwayml/stable-diffusion-v1-5",
    "segmind/SSD-1B",
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/sdxl-turbo",
]

# Value given to ``torch.manual_seed`` before every pipeline call.
SEED = 2024
|
|
|
|
|
def load_dataframe():
    """Download the prompt table and return it as a DataFrame.

    The CSV is fetched over HTTPS from huggingface.co on every call; nothing
    is cached locally by this function. Callers in this file read the
    ``"caption"`` column of the result.

    Returns:
        The ``pandas.DataFrame`` parsed from the remote CSV file.
    """
    # NOTE(review): the file name suggests a 30k random sample of the COCO
    # 2014 validation captions -- confirm against the dataset card.
    csv_url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/raw/main/coco_30k_randomly_sampled_2014_val.csv"
    return pd.read_csv(csv_url)
|
|
|
|
|
def load_pipeline(args):
    """Instantiate the diffusion pipeline selected by ``args.pipeline_id``.

    The weights are loaded with ``torch.float16``, the pipeline is moved to
    the ``"cuda"`` device, and its progress bar is disabled.

    Args:
        args: Object exposing a ``pipeline_id`` string (the parsed
            command-line namespace in this script).

    Returns:
        The pipeline object, ready to be called with a list of prompts.
    """
    # Only checkpoints whose id contains "runway" are loaded with
    # ``safety_checker=None``; every other checkpoint gets no extra argument.
    extra_kwargs = {"safety_checker": None} if "runway" in args.pipeline_id else {}

    pipe = DiffusionPipeline.from_pretrained(
        args.pipeline_id, torch_dtype=torch.float16, **extra_kwargs
    )
    pipe = pipe.to("cuda")
    pipe.set_progress_bar_config(disable=True)
    return pipe
|
|
|
|
|
def generate_images(args, dataframe, pipeline):
    """Run the pipeline over every caption, ``args.chunk_size`` at a time.

    Args:
        args: Object exposing ``pipeline_id``, ``chunk_size`` and
            ``num_inference_steps``.
        dataframe: Table with a ``"caption"`` column; each row is one prompt.
        pipeline: Callable taking a list of prompts plus keyword arguments and
            returning an object with an ``images`` attribute.

    Returns:
        A list with the images of all chunks, in dataframe row order. The
        whole result is accumulated in memory before being returned.
    """
    generated = []
    for start in range(0, len(dataframe), args.chunk_size):
        # For "sdxl-turbo" checkpoints the call additionally passes
        # ``guidance_scale=0.0``; all other arguments are shared.
        is_turbo = "sdxl-turbo" in args.pipeline_id
        prompts = dataframe.iloc[start : start + args.chunk_size]["caption"].tolist()

        call_kwargs = {
            "num_inference_steps": args.num_inference_steps,
            # Re-seeded for every chunk, so each call starts from the same
            # random state.
            "generator": torch.manual_seed(SEED),
        }
        if is_turbo:
            call_kwargs["guidance_scale"] = 0.0

        generated.extend(pipeline(prompts, **call_kwargs).images)
    return generated
|
|
|
|
|
def serialize_image(image, path):
    """Write ``image`` to ``path`` by calling its ``save`` method.

    Args:
        image: Object exposing ``save(path)``.
        path: Destination passed through to ``image.save`` unchanged.
    """
    image.save(path)
    return None
|
|
|
|
|
if __name__ == "__main__":
    # Command-line interface: which checkpoint to sample from, how many
    # denoising steps and prompts per call, where to store the images, and
    # how many threads write them to disk.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pipeline_id", default="runwayml/stable-diffusion-v1-5", type=str, choices=ALL_CKPTS)
    parser.add_argument("--num_inference_steps", default=30, type=int)
    parser.add_argument("--chunk_size", default=2, type=int)
    parser.add_argument("--root_img_path", default="sdv15", type=str)
    parser.add_argument("--num_workers", type=int, default=4)
    args = parser.parse_args()

    # Create the output directory before doing any expensive work, so an
    # unwritable destination fails immediately instead of after the whole
    # generation run. ``exist_ok=True`` replaces the racy
    # "check os.path.exists, then create" sequence.
    os.makedirs(args.root_img_path, exist_ok=True)

    dataset = load_dataframe()
    pipeline = load_pipeline(args)
    images = generate_images(args, dataset, pipeline)

    # Images are named by their position in the generated list: 0.jpg, 1.jpg, ...
    image_paths = [os.path.join(args.root_img_path, f"{i}.jpg") for i in range(len(images))]

    with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
        # ``executor.map`` only re-raises a worker's exception when the
        # corresponding result is retrieved. Draining the iterator ensures a
        # failed save is reported instead of being silently discarded.
        list(executor.map(serialize_image, images, image_paths))
|
|