import os
import pickle

import numpy as np
import torch
from diffusers import ControlNetModel, AutoencoderKL
from PIL import Image
from tqdm.auto import tqdm
from transformers import pipeline as transformers_pipeline

from relighting.argument import SAMPLERS, VAE_MODELS, DEPTH_ESTIMATOR, get_control_signal_type
from relighting.ball_processor import get_ideal_normal_ball, crop_ball
from relighting.image_processor import (
    estimate_scene_depth,
    estimate_scene_normal,
    merge_normal_map,
    fill_depth_circular,
)
from relighting.pipeline import CustomStableDiffusionControlNetInpaintPipeline
from relighting.pipeline_inpaintonly import (
    CustomStableDiffusionInpaintPipeline,
    CustomStableDiffusionXLInpaintPipeline,
)
from relighting.pipeline_xl import CustomStableDiffusionXLControlNetInpaintPipeline


class NoWaterMark:
    """Stub watermarker that returns the input images unchanged."""

    def apply_watermark(self, *args, **kwargs):
        return args[0]


class ControlSignalGenerator:
    """Lazily builds depth or normal control images for ControlNet conditioning."""

    def __init__(self, sd_arch, control_signal_type, device):
        self.sd_arch = sd_arch
        self.control_signal_type = control_signal_type
        self.device = device

    def _device_index(self):
        # transformers pipelines expect an int device index; self.device may be
        # either an int or a torch.device.
        return self.device.index if isinstance(self.device, torch.device) else self.device

    def process_sd_depth(self, input_image, normal_ball=None, mask_ball=None, x=None, y=None, r=None):
        if getattr(self, "depth_estimator", None) is None:
            self.depth_estimator = transformers_pipeline("depth-estimation", device=self._device_index())

        control_image = self.depth_estimator(input_image)["depth"]
        control_image = np.array(control_image)
        # Replicate the single-channel depth map into a 3-channel image.
        control_image = control_image[:, :, None]
        control_image = np.concatenate([control_image, control_image, control_image], axis=2)
        control_image = Image.fromarray(control_image)

        control_image = fill_depth_circular(control_image, x, y, r)
        return control_image

    def process_sdxl_depth(self, input_image, normal_ball=None, mask_ball=None, x=None, y=None, r=None):
        if getattr(self, "depth_estimator", None) is None:
            self.depth_estimator = transformers_pipeline("depth-estimation", model=DEPTH_ESTIMATOR, device=self._device_index())

        control_image = estimate_scene_depth(input_image, depth_estimator=self.depth_estimator)
        xs = x if isinstance(x, list) else [x]
        ys = y if isinstance(y, list) else [y]
        rs = r if isinstance(r, list) else [r]

        # Fill a circular depth patch at every ball location.
        for x, y, r in zip(xs, ys, rs):
            control_image = fill_depth_circular(control_image, x, y, r)
        return control_image

    def process_sd_normal(self, input_image, normal_ball, mask_ball, x, y, r=None, normal_ball_path=None):
        if getattr(self, "depth_estimator", None) is None:
            self.depth_estimator = transformers_pipeline("depth-estimation", model=DEPTH_ESTIMATOR, device=self._device_index())

        normal_scene = estimate_scene_normal(input_image, depth_estimator=self.depth_estimator)
        normal_image = merge_normal_map(normal_scene, normal_ball, mask_ball, x, y)
        # Map normals from [-1, 1] into the [0, 255] image range.
        normal_image = (normal_image * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
        control_image = Image.fromarray(normal_image)
        return control_image

    def __call__(self, *args, **kwargs):
        process_fn = getattr(self, f"process_{self.sd_arch}_{self.control_signal_type}", None)
        if process_fn is None:
            raise ValueError(
                f"No control signal processor for sd_arch={self.sd_arch!r}, "
                f"control_signal_type={self.control_signal_type!r}"
            )
        return process_fn(*args, **kwargs)
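

# Usage sketch for the generator on its own (illustrative; the file name and
# ball placement below are assumptions, not from the original code):
#
#   generator = ControlSignalGenerator("sdxl", "depth", device=torch.device("cuda:0"))
#   scene = Image.open("scene.png").convert("RGB")
#   control = generator(scene, x=384, y=384, r=256)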


class BallInpainter:
    """Inpaints a ball into a scene with an SD or SDXL inpainting pipeline."""

    def __init__(self, pipeline, sd_arch, control_generator, disable_water_mask=True):
        self.pipeline = pipeline
        self.sd_arch = sd_arch
        self.control_generator = control_generator
        self.median = {}
        if disable_water_mask:
            self._disable_water_mask()

    def _disable_water_mask(self):
        if hasattr(self.pipeline, "watermark"):
            self.pipeline.watermark = NoWaterMark()
            print("Disabled watermarking")

    @classmethod
    def from_sd(
        cls,
        model,
        controlnet=None,
        device=0,
        sampler="unipc",
        torch_dtype=torch.float16,
        disable_water_mask=True,
        offload=False,
    ):
        # NOTE: offload is accepted for parity with from_sdxl but is not used here.
        if controlnet is not None:
            control_signal_type = get_control_signal_type(controlnet)
            controlnet = ControlNetModel.from_pretrained(controlnet, torch_dtype=torch_dtype)
            pipe = CustomStableDiffusionControlNetInpaintPipeline.from_pretrained(
                model,
                controlnet=controlnet,
                torch_dtype=torch_dtype,
            ).to(device)
            control_generator = ControlSignalGenerator("sd", control_signal_type, device=device)
        else:
            pipe = CustomStableDiffusionInpaintPipeline.from_pretrained(
                model,
                torch_dtype=torch_dtype,
            ).to(device)
            control_generator = None

        try:
            if torch_dtype == torch.float16 and device != torch.device("cpu"):
                pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            # xformers is optional; fall back to the default attention implementation.
            pass
        pipe.set_progress_bar_config(disable=True)

        pipe.scheduler = SAMPLERS[sampler].from_config(pipe.scheduler.config)

        return cls(pipe, "sd", control_generator, disable_water_mask)

    @classmethod
    def from_sdxl(
        cls,
        model,
        controlnet=None,
        device=0,
        sampler="unipc",
        torch_dtype=torch.float16,
        disable_water_mask=True,
        use_fixed_vae=True,
        offload=False,
    ):
        # Optionally swap in the fixed VAE configured in relighting.argument.
        vae = VAE_MODELS["sdxl"]
        vae = AutoencoderKL.from_pretrained(vae, torch_dtype=torch_dtype).to(device) if use_fixed_vae else None
        extra_kwargs = {"vae": vae} if vae is not None else {}

        if controlnet is not None:
            control_signal_type = get_control_signal_type(controlnet)
            controlnet = ControlNetModel.from_pretrained(
                controlnet,
                variant="fp16" if torch_dtype == torch.float16 else None,
                use_safetensors=True,
                torch_dtype=torch_dtype,
            ).to(device)
            pipe = CustomStableDiffusionXLControlNetInpaintPipeline.from_pretrained(
                model,
                controlnet=controlnet,
                variant="fp16" if torch_dtype == torch.float16 else None,
                use_safetensors=True,
                torch_dtype=torch_dtype,
                **extra_kwargs,
            ).to(device)
            control_generator = ControlSignalGenerator("sdxl", control_signal_type, device=device)
        else:
            pipe = CustomStableDiffusionXLInpaintPipeline.from_pretrained(
                model,
                variant="fp16" if torch_dtype == torch.float16 else None,
                use_safetensors=True,
                torch_dtype=torch_dtype,
                **extra_kwargs,
            ).to(device)
            control_generator = None

        try:
            if torch_dtype == torch.float16 and device != torch.device("cpu"):
                pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            # xformers is optional; fall back to the default attention implementation.
            pass

        if offload and device != torch.device("cpu"):
            pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=True)
        pipe.scheduler = SAMPLERS[sampler].from_config(pipe.scheduler.config)

        return cls(pipe, "sdxl", control_generator, disable_water_mask)

    def __call__(self, *args, **kwargs):
        return self.pipeline(*args, **kwargs)

    def _default_height_width(self, height=None, width=None):
        if (height is not None) and (width is not None):
            return height, width
        if self.sd_arch == "sd":
            return (512, 512)
        elif self.sd_arch == "sdxl":
            return (1024, 1024)
        else:
            raise NotImplementedError(f"Unknown sd_arch: {self.sd_arch!r}")

    def get_cache_control_image(self):
        return getattr(self, "cache_control_image", None)

    def prepare_control_signal(self, image, controlnet_conditioning_scale, extra_kwargs):
        if self.control_generator is not None:
            control_image = self.control_generator(image, **extra_kwargs)
            controlnet_kwargs = {
                "control_image": control_image,
                "controlnet_conditioning_scale": controlnet_conditioning_scale,
            }
            self.cache_control_image = control_image
        else:
            controlnet_kwargs = {}

        return controlnet_kwargs

    def get_cache_median(self, it):
        return self.median.get(it)

    def reset_median(self):
        self.median = {}
        print("Reset median")

    def load_median(self, path):
        if os.path.exists(path):
            with open(path, "rb") as f:
                self.median = pickle.load(f)
            print(f"Loaded median from {path}")
        else:
            print(f"Median not found at {path}!")

    def inpaint_iterative(
        self,
        prompt=None,
        negative_prompt="",
        num_inference_steps=30,
        generator=None,
        image=None,
        mask_image=None,
        height=None,
        width=None,
        controlnet_conditioning_scale=0.5,
        num_images_per_prompt=1,
        current_seed=0,
        cross_attention_kwargs={},
        strength=0.8,
        num_iteration=2,
        ball_per_iteration=30,
        agg_mode="median",  # only median aggregation is implemented here
        save_intermediate=True,
        cache_dir="./temp_inpaint_iterative",
        disable_progress=False,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        use_cache_median=False,
        guidance_scale=5.0,
        **extra_kwargs,
    ):
        """Inpaint many balls per iteration, aggregate the crops into a per-pixel
        median ball, paste it back into the scene, and repeat to stabilize it."""

        def compute_median(ball_images):
            stacked = np.stack(ball_images, axis=0)
            median = np.median(stacked, axis=0)
            # Per-pixel index of the middle sample in sorted order, cached so the
            # median ball can be rebuilt without recomputing.
            idx_median = np.argsort(stacked, axis=0)[stacked.shape[0] // 2]
            return median, idx_median

        def generate_balls(avg_image, current_strength, ball_per_iteration, current_iteration):
            # Uses x, y, r, and mask_ball_for_crop from the enclosing scope.
            print(f"Inpainting balls for iteration {current_iteration}...")
            controlnet_kwargs = self.prepare_control_signal(
                image=avg_image,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                extra_kwargs=extra_kwargs,
            )

            ball_images = []
            for i in tqdm(range(ball_per_iteration), disable=disable_progress):
                seed = current_seed + i
                new_generator = torch.Generator().manual_seed(seed)
                output_image = self.pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    generator=new_generator,
                    image=avg_image,
                    mask_image=mask_image,
                    height=height,
                    width=width,
                    num_images_per_prompt=num_images_per_prompt,
                    strength=current_strength,
                    newx=x,
                    newy=y,
                    newr=r,
                    current_seed=seed,
                    cross_attention_kwargs=cross_attention_kwargs,
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    guidance_scale=guidance_scale,
                    **controlnet_kwargs,
                ).images[0]

                ball_image = crop_ball(output_image, mask_ball_for_crop, x, y, r)
                ball_images.append(ball_image)

                if save_intermediate:
                    os.makedirs(os.path.join(cache_dir, str(current_iteration)), mode=0o777, exist_ok=True)
                    output_image.save(os.path.join(cache_dir, str(current_iteration), f"raw_{i}.png"))
                    Image.fromarray(ball_image).save(os.path.join(cache_dir, str(current_iteration), f"ball_{i}.png"))

                    # Make the cached images world-writable for shared scratch directories.
                    os.chmod(os.path.join(cache_dir, str(current_iteration), f"raw_{i}.png"), 0o777)
                    os.chmod(os.path.join(cache_dir, str(current_iteration), f"ball_{i}.png"), 0o777)

            return ball_images

        if save_intermediate:
            os.makedirs(cache_dir, exist_ok=True)

        height, width = self._default_height_width(height, width)

        x = extra_kwargs["x"]
        y = extra_kwargs["y"]
        r = extra_kwargs.get("r", 256)
        _, mask_ball_for_crop = get_ideal_normal_ball(size=r)

        # Iteration 0: inpaint from the raw input at full strength.
        avg_image = image
        ball_images = generate_balls(
            avg_image,
            current_strength=1.0,
            ball_per_iteration=ball_per_iteration,
            current_iteration=0,
        )

        image = np.array(image)
        for it in range(1, num_iteration + 1):
            if use_cache_median and (self.get_cache_median(it) is not None):
                print("Use existing median")
                stacked = np.stack(ball_images, axis=0)
                idx_median = self.get_cache_median(it)
                # Rebuild the median ball from the cached per-pixel indices.
                avg_ball = stacked[
                    idx_median,
                    np.arange(idx_median.shape[0])[:, np.newaxis, np.newaxis],
                    np.arange(idx_median.shape[1])[np.newaxis, :, np.newaxis],
                    np.arange(idx_median.shape[2])[np.newaxis, np.newaxis, :],
                ]
            else:
                avg_ball, idx_median = compute_median(ball_images)
                print("Add new median")
                self.median[it] = idx_median

            # Paste the aggregated ball back into the scene and continue from it.
            avg_image = merge_normal_map(image, avg_ball, mask_ball_for_crop, x, y)
            avg_image = Image.fromarray(avg_image.astype(np.uint8))
            if save_intermediate:
                avg_image.save(os.path.join(cache_dir, f"average_{it}.png"))
                os.chmod(os.path.join(cache_dir, f"average_{it}.png"), 0o777)

            # The final iteration only needs a single ball.
            ball_images = generate_balls(
                avg_image,
                current_strength=strength,
                ball_per_iteration=ball_per_iteration if it < num_iteration else 1,
                current_iteration=it,
            )

        best_ball = ball_images[0]
        output_image = merge_normal_map(image, best_ball, mask_ball_for_crop, x, y)
        return Image.fromarray(output_image.astype(np.uint8))

    def inpaint(
        self,
        prompt=None,
        negative_prompt=None,
        num_inference_steps=30,
        generator=None,
        image=None,
        mask_image=None,
        height=None,
        width=None,
        controlnet_conditioning_scale=0.5,
        num_images_per_prompt=1,
        strength=1.0,
        current_seed=0,
        cross_attention_kwargs={},
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        guidance_scale=5.0,
        **extra_kwargs,
    ):
        height, width = self._default_height_width(height, width)

        controlnet_kwargs = self.prepare_control_signal(
            image=image,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            extra_kwargs=extra_kwargs,
        )

        if generator is None:
            generator = torch.Generator().manual_seed(0)

        output_image = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            generator=generator,
            image=image,
            mask_image=mask_image,
            height=height,
            width=width,
            num_images_per_prompt=num_images_per_prompt,
            strength=strength,
            newx=extra_kwargs["x"],
            newy=extra_kwargs["y"],
            # extra_kwargs is a dict, so use .get() (getattr() would always
            # miss and silently return the default).
            newr=extra_kwargs.get("r", 256),
            current_seed=current_seed,
            cross_attention_kwargs=cross_attention_kwargs,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            guidance_scale=guidance_scale,
            **controlnet_kwargs,
        )

        return output_image
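

# Usage sketch (illustrative, not from the original file): wiring the pieces
# together end to end. The checkpoint IDs, file paths, prompt, and ball
# placement below are assumptions; substitute whatever your setup uses.
if __name__ == "__main__":
    inpainter = BallInpainter.from_sdxl(
        model="stabilityai/stable-diffusion-xl-base-1.0",
        controlnet="diffusers/controlnet-depth-sdxl-1.0",
        device=torch.device("cuda:0"),
    )
    scene = Image.open("example_scene.png").convert("RGB").resize((1024, 1024))
    mask = Image.open("example_mask.png").convert("L").resize((1024, 1024))
    result = inpainter.inpaint_iterative(
        prompt="a perfect mirrored reflective chrome ball sphere",
        negative_prompt="matte, diffuse, flat, dull",
        image=scene,
        mask_image=mask,
        x=384,
        y=384,
        r=256,
    )
    result.save("ball_result.png")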