|
from PIL import Image, ImageEnhance |
|
from diffusers.image_processor import VaeImageProcessor |
|
|
|
import numpy as np |
|
import cv2 |
|
|
|
|
|
|
|
def BrushEdit_Pipeline(pipe,
                       prompts,
                       mask_np,
                       original_image,
                       generator,
                       num_inference_steps,
                       guidance_scale,
                       control_strength,
                       negative_prompt,
                       num_samples,
                       blending):
    """Inpaint the masked region of an image with ``pipe`` and optionally blend.

    The mask and the image are resized to the size reported by
    ``VaeImageProcessor.get_default_height_width``, the masked pixels of the
    image are zeroed, and ``pipe`` is called once with ``num_samples`` copies
    of the prompt / negative prompt.

    Args:
        pipe: Callable pipeline exposing ``vae_scale_factor`` and returning an
            object with an ``images`` attribute.
        prompts: Prompt; replicated ``num_samples`` times before the call.
        mask_np: Mask array, 2-D or 3-D. Values are divided by 255, so a
            0-255 scale is assumed (TODO confirm against callers). A 3-D mask
            with more than one channel is reduced to its first channel.
        original_image: Image array accepted by ``cv2.resize`` and
            ``Image.fromarray`` (presumably uint8 H x W x 3 -- verify against
            callers).
        generator: Forwarded to ``pipe`` as ``generator``.
        num_inference_steps: Forwarded to ``pipe``.
        guidance_scale: Forwarded to ``pipe``.
        control_strength: Converted to ``float`` and forwarded as
            ``brushnet_conditioning_scale``.
        negative_prompt: Negative prompt; replicated ``num_samples`` times.
        num_samples: Number of prompt copies passed to ``pipe``.
        blending: When truthy, each generated image is mixed with the
            processed original image using the blurred mask.

    Returns:
        A tuple ``(image_all, mask_image, mask_np, init_image_np)``:
        ``image_all`` is a list of blended PIL images when ``blending`` is
        truthy, otherwise ``pipe(...).images`` unchanged; ``mask_image`` is
        the resized mask as an RGB PIL image; ``mask_np`` is the resized,
        scaled mask with a trailing channel axis; ``init_image_np`` is the
        processed (unmasked) original image as a uint8 array.
    """
    if mask_np.ndim != 3:
        mask_np = mask_np[:, :, np.newaxis]
    if mask_np.shape[2] != 1:
        # A multi-channel mask would come back from cv2.resize still 3-D, and
        # the channel axis re-added below would then make it 4-D. Keep a
        # single channel so every later step sees an (H, W, 1) mask.
        mask_np = mask_np[:, :, :1]

    mask_np = mask_np / 255
    height, width = mask_np.shape[0], mask_np.shape[1]

    image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
    height_new, width_new = image_processor.get_default_height_width(original_image, height, width)

    # The channel axis is restored after each cv2 call on the single-channel mask.
    mask_np = cv2.resize(mask_np, (width_new, height_new))[:, :, np.newaxis]
    mask_blurred = cv2.GaussianBlur(mask_np * 255, (21, 21), 0) / 255
    mask_blurred = mask_blurred[:, :, np.newaxis]

    original_image = cv2.resize(original_image, (width_new, height_new))

    # Zero the masked pixels; this masked image is what the pipeline receives.
    init_image = original_image * (1 - mask_np)
    init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray((mask_np.repeat(3, -1) * 255).astype(np.uint8)).convert("RGB")

    brushnet_conditioning_scale = float(control_strength)

    images = pipe(
        [prompts] * num_samples,
        init_image,
        mask_image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=brushnet_conditioning_scale,
        negative_prompt=[negative_prompt] * num_samples,
        height=height_new,
        width=width_new,
    ).images

    # Run the original image through the processor, then map the result back
    # with (x + 1) / 2 * 255 and move the leading axis last.
    original_image_pil = Image.fromarray(original_image).convert("RGB")
    init_image_np = np.array(image_processor.preprocess(original_image_pil, height=height_new, width=width_new).squeeze())
    init_image_np = ((init_image_np.transpose(1, 2, 0) + 1.) / 2.) * 255
    init_image_np = init_image_np.astype(np.uint8)

    if blending:
        # Rescale the blend weight by 0.5 and offset it by 0.5 before mixing.
        mask_blurred = mask_blurred * 0.5 + 0.5
        image_all = []
        for image_i in images:
            image_np = np.array(image_i)

            image_pasted = init_image_np * (1 - mask_blurred) + mask_blurred * image_np
            image_pasted = image_pasted.astype(np.uint8)
            image = Image.fromarray(image_pasted)
            image_all.append(image)
    else:
        image_all = images

    return image_all, mask_image, mask_np, init_image_np
|
|
|
|
|
|