|
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler

import os
import random

import cv2
import numpy as np
import torch
import webdataset as wds
from PIL import Image
from tqdm import tqdm

from diffusers.utils import load_image
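
# Inpaint CC3M images with BrushNet: for each sample that has a pre-computed
# skin-segmentation mask, black out the masked region and regenerate it from a
# text caption, then save the result as a synthetic image.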


def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    """Resize an image to a target width or height, keeping the aspect ratio.

    If neither dimension is given the image is returned unchanged; if both
    are given, `width` takes precedence.
    """
    (h, w) = image.shape[:2]

    if width is None and height is None:
        return image

    if width is None:
        # Scale by the target height and derive the width.
        r = height / float(h)
        dim = (int(w * r), height)
    else:
        # Scale by the target width and derive the height.
        r = width / float(w)
        dim = (width, int(h * r))

    return cv2.resize(image, dim, interpolation=inter)
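
# Example: image_resize(frame, width=256) returns `frame` scaled so its width
# is 256 px, with the height derived from the original aspect ratio.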


# Paths to the base Stable Diffusion checkpoint and the BrushNet checkpoint.
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"

# If True, paste the generated pixels back onto the original image after
# generation instead of keeping the raw pipeline output.
blended = False

out_dir = "cc3m_synthetic"
captions = ["A photo of a man"]
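
# Results are written to examples/brushnet/<out_dir>/<key>_man.png, where
# <key> is the CC3M webdataset sample key (see the save call at the bottom).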

print("Generating images with captions:")
for caption in captions:
    print("\t", caption)

# Pre-computed skin masks (one PNG per CC3M sample key) and the directory
# holding the original CC3M webdataset shards.
mask_root = "../SemanticSegmentation/mask_skin"
cc3m_original_path = "PASTE_HERE"

brushnet_conditioning_scale = 1.0
seed = 12345
random.seed(seed)  # assumption: `seed` is meant to make the per-image seeds drawn below reproducible

brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path, brushnet=brushnet, torch_dtype=torch.float16,
    low_cpu_mem_usage=False, safety_checker=None, requires_safety_checker=False
)

# Number of denoising steps per generation call.
n_steps = 10

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Offload submodules to the CPU between forward passes to save GPU memory.
pipe.enable_model_cpu_offload()
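
# Note: UniPCMultistepScheduler is a fast multistep solver, which is why a
# small step count like 10 can still give usable samples; raise n_steps for
# higher quality at the cost of speed.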

# Make sure the output directory exists before writing results.
os.makedirs(f"examples/brushnet/{out_dir}", exist_ok=True)

generated = 0
dataset = wds.WebDataset(cc3m_original_path + "/00{000..331}.tar").decode("rgb8").to_tuple("jpg;png", "json")
for sample in tqdm(dataset):
    img = sample[0]
    name = sample[1]["key"]
    mask_path = f"{mask_root}/{name}.png"

    # Process only samples that have a mask and were not generated already.
    if os.path.exists(mask_path) and not os.path.exists(f"examples/brushnet/{out_dir}/{name}_man.png"):
        init_image = img
        # Binarize the mask: a pixel counts as masked if its channel sum exceeds 255.
        mask_image = 1. * (cv2.imread(mask_path).sum(-1) > 255)

        # Rescale so the shorter side becomes 768 px, keeping the aspect ratio.
        h, w, _ = init_image.shape
        if w < h:
            scale = 768 / w
        else:
            scale = 768 / h
        new_h = int(h * scale)
        new_w = int(w * scale)

        init_image = cv2.resize(init_image, (new_w, new_h))
        mask_image = cv2.resize(mask_image, (new_w, new_h))[:, :, np.newaxis]

        # Skip (near-)empty masks: fewer than 50 masked pixels.
        # (Assumption: the emptiness check is meant for the mask, not the RGB image.)
        is_empty = np.sum(mask_image > 0.9) < 50

        # Black out the region to be inpainted.
        init_image = init_image * (1 - mask_image)

        init_image = Image.fromarray(init_image.astype(np.uint8))
        mask_image = Image.fromarray(mask_image.astype(np.uint8).repeat(3, -1) * 255)

        if not is_empty:
            for caption in captions:
                # Fresh random seed for every generated image.
                generator = torch.Generator("cuda").manual_seed(random.randint(3000000000, 6000000000))
                image = pipe(
                    caption,
                    init_image,
                    mask_image,
                    num_inference_steps=n_steps,
                    generator=generator,
                    brushnet_conditioning_scale=brushnet_conditioning_scale
                ).images[0]

                if blended:
                    # Paste the generated pixels back onto the original image
                    # (expects the resized original at <key>_original.png).
                    init_image_np = np.asarray(load_image(f"examples/brushnet/{out_dir}/{name}_original.png"))
                    # cv2.resize takes (width, height); shape[:2] is (height, width).
                    image_np = cv2.resize(np.array(image), (init_image_np.shape[1], init_image_np.shape[0]), interpolation=cv2.INTER_AREA)
                    mask_np = np.asarray(load_image(mask_path))
                    _, mask_np = cv2.threshold(mask_np, 122, 255, cv2.THRESH_BINARY)
                    mask_np = mask_np / 255
                    # Keep original pixels outside the mask, generated pixels inside.
                    image_pasted = init_image_np * (1 - mask_np) + image_np * mask_np
                    image_pasted = image_pasted.astype(image_np.dtype)
                    image = Image.fromarray(image_pasted)  # use the blended result from here on

                # Downscale to width 256 and flip RGB -> BGR for cv2.imwrite.
                image = image_resize(np.asarray(image), width=256)[:, :, ::-1]
                cv2.imwrite(f"examples/brushnet/{out_dir}/{name}_man.png", image)
                generated += 1
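
# Simple run summary (hypothetical addition; `generated` is otherwise unused).
print(f"Done: {generated} images written to examples/brushnet/{out_dir}")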
|
|