from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler

import os
import random

import cv2
import numpy as np
import torch
import webdataset as wds
from PIL import Image
from tqdm import tqdm

from diffusers.utils import load_image


def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    # initialize the dimensions of the image to be resized and
    # grab the image size
    dim = None
    (h, w) = image.shape[:2]

    # if both the width and height are None, then return the
    # original image
    if width is None and height is None:
        return image

    # check to see if the width is None
    if width is None:
        # calculate the ratio of the height and construct the
        # dimensions
        r = height / float(h)
        dim = (int(w * r), height)

    # otherwise, the height is None
    else:
        # calculate the ratio of the width and construct the
        # dimensions
        r = width / float(w)
        dim = (width, int(h * r))

    # resize the image
    resized = cv2.resize(image, dim, interpolation=inter)

    # return the resized image
    return resized


# choose the base model here
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
# base_model_path = "runwayml/stable-diffusion-v1-5"

# input brushnet ckpt path
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"

# choose whether using blended operation
blended = False

out_dir = "cc3m_synthetic"

captions = ["A photo of a man",]

print("Generating images with captions:")
for caption in captions:
    print("\t", caption)

mask_root = "../SemanticSegmentation/mask_skin"
cc3m_original_path = "PASTE_HERE"

# conditioning scale
brushnet_conditioning_scale = 1.0

seed = 12345

brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path,
    brushnet=brushnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
    safety_checker=None,
    requires_safety_checker=False,
)

n_steps = 50
high_noise_frac = 0.8

# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed or when using Torch 2.0.
# pipe.enable_xformers_memory_efficient_attention()
# memory optimization.
pipe.enable_model_cpu_offload()

generated = 0

dataset = wds.WebDataset(cc3m_original_path + "/00{000..331}.tar").decode("rgb8").to_tuple("jpg;png", "json")

# make sure the output directory exists before the first imwrite
os.makedirs(f"examples/brushnet/{out_dir}", exist_ok=True)

for sample in tqdm(dataset):
    img = sample[0]
    name = sample[1]["key"]
    mask_path = f"{mask_root}/{name}.png"
    if os.path.exists(mask_path) and not os.path.exists(f"examples/brushnet/{out_dir}/{name}_man.png"):
        init_image = img
        # binarize the segmentation mask: 1 where skin was detected, 0 elsewhere
        mask_image = (cv2.imread(mask_path).sum(-1) > 255).astype(np.uint8)

        # keep a full-resolution copy of the original (RGB -> BGR for cv2);
        # the blended branch below reads it back
        cv2.imwrite(f"examples/brushnet/{out_dir}/{name}_original.png", init_image[:, :, ::-1])

        # resize image and mask so the shorter side is 512 px; this block was
        # corrupted in the source, so the target size of 512 (the usual SD1.5 /
        # BrushNet resolution) is an assumption
        h, w, _ = init_image.shape
        if w < h:
            init_image = image_resize(init_image, width=512)
            mask_image = image_resize(mask_image, width=512, inter=cv2.INTER_NEAREST)
        else:
            init_image = image_resize(init_image, height=512)
            mask_image = image_resize(mask_image, height=512, inter=cv2.INTER_NEAREST)

        # skip samples whose mask is (almost) empty
        is_empty = True if np.sum(mask_image > 0.9) < 50 else False

        if not is_empty:
            # convert to the PIL inputs the pipeline expects (reconstructed,
            # following the stock BrushNet example): the init image with the
            # masked region blanked out, and a 3-channel mask
            masked_image = init_image * (1 - mask_image[:, :, np.newaxis])
            init_image_pil = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
            mask_image_pil = Image.fromarray(
                mask_image[:, :, np.newaxis].repeat(3, -1) * 255
            ).convert("RGB")

            for caption in captions:
                generator = torch.Generator("cuda").manual_seed(random.randint(3000000000, 6000000000))

                image = pipe(
                    caption,
                    init_image_pil,
                    mask_image_pil,
                    num_inference_steps=10,
                    generator=generator,
                    brushnet_conditioning_scale=brushnet_conditioning_scale,
                ).images[0]

                if blended:
                    init_image_np = np.asarray(load_image(f"examples/brushnet/{out_dir}/{name}_original.png"))
                    # cv2.resize expects (width, height); shape[:2] is (height, width)
                    image_np = cv2.resize(
                        np.array(image),
                        (init_image_np.shape[1], init_image_np.shape[0]),
                        interpolation=cv2.INTER_AREA,
                    )
                    mask_np = np.asarray(load_image(mask_path))
                    _, mask_np = cv2.threshold(mask_np, 122, 255, cv2.THRESH_BINARY)
                    mask_np = mask_np / 255

                    # paste the generated pixels into the masked region, keeping
                    # the original content everywhere else
                    image_pasted = init_image_np * (1 - mask_np) + image_np * mask_np
                    image_pasted = image_pasted.astype(image_np.dtype)
                    image = Image.fromarray(image_pasted)

                # downscale to 256 px wide and convert RGB -> BGR for cv2.imwrite
                image = image_resize(np.asarray(image), width=256)[:, :, ::-1]
                cv2.imwrite(f"examples/brushnet/{out_dir}/{name}_man.png", image)