#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os, random, sys
import numpy as np
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import snapshot_download
from scipy.ndimage import binary_dilation, binary_erosion
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from app.src.vlm_pipeline import (
vlm_response_editing_type,
vlm_response_object_wait_for_edit,
vlm_response_mask,
vlm_response_prompt_after_apply_instruction
)
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
from app.utils.utils import load_grounding_dino_model
from app.src.vlm_template import vlms_template
from app.src.base_model_template import base_models_template
from app.src.aspect_ratio_template import aspect_ratios
from openai import OpenAI
base_openai_url = "https://api.deepseek.com/"
base_api_key = os.getenv("DEEPSEEK_API_KEY", "")  # read the DeepSeek key from the environment; never hardcode secrets
from transformers import BlipProcessor, BlipForConditionalGeneration
from app.deepseek.instructions import (
create_apply_editing_messages_deepseek,
create_decomposed_query_messages_deepseek
)
#### Description ####
logo = r"""
"""
head = r"""
Zero-Shot Composed Image Retrieval Based on Diffusion Model Priors and Large Language Models
"""
descriptions = r"""
Demo for ZS-CIR"""
instructions = r"""
Demo for ZS-CIR"""
tips = r"""
Demo for ZS-CIR
"""
citation = r"""
Demo for ZS-CIR"""
# - - - - - examples - - - - - #
EXAMPLES = [
[
Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
"add a magic hat on frog head.",
642087011,
"frog",
"frog",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
"replace the background to ancient China.",
648464818,
"chinese_girl",
"chinese_girl",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
"remove the deer.",
648464818,
"angel_christmas",
"angel_christmas",
False,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
"add a wreath on head.",
648464818,
"sunflower_girl",
"sunflower_girl",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
"add a butterfly fairy.",
648464818,
"girl_on_sun",
"girl_on_sun",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
"remove the christmas hat.",
642087011,
"spider_man_rm",
"spider_man_rm",
False,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
"remove the flower.",
642087011,
"anime_flower",
"anime_flower",
False,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
"replace the clothes to a delicated floral skirt.",
648464818,
"chenduling",
"chenduling",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
"make the hedgehog in Italy.",
648464818,
"hedgehog_rp_bg",
"hedgehog_rp_bg",
True,
False,
"GPT4-o (Highly Recommended)"
],
]
INPUT_IMAGE_PATH = {
"frog": "./assets/frog/frog.jpeg",
"chinese_girl": "./assets/chinese_girl/chinese_girl.png",
"angel_christmas": "./assets/angel_christmas/angel_christmas.png",
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
"spider_man_rm": "./assets/spider_man_rm/spider_man.png",
"anime_flower": "./assets/anime_flower/anime_flower.png",
"chenduling": "./assets/chenduling/chengduling.jpg",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
}
MASK_IMAGE_PATH = {
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
}
MASKED_IMAGE_PATH = {
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
}
OUTPUT_IMAGE_PATH = {
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
}
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
# os.makedirs('gradio_temp_dir', exist_ok=True)
VLM_MODEL_NAMES = list(vlms_template.keys())
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
BASE_MODELS = list(base_models_template.keys())
DEFAULT_BASE_MODEL = "realisticVision (Default)"
ASPECT_RATIO_LABELS = list(aspect_ratios)
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
## init device
try:
if torch.cuda.is_available():
device = "cuda"
elif sys.platform == "darwin" and torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
except Exception:
    device = "cpu"
# ## init torch dtype
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
# torch_dtype = torch.bfloat16
# else:
# torch_dtype = torch.float16
# if device == "mps":
# torch_dtype = torch.float16
torch_dtype = torch.float16
# download hf models
BrushEdit_path = "models/"
if not os.path.exists(BrushEdit_path):
BrushEdit_path = snapshot_download(
repo_id="TencentARC/BrushEdit",
local_dir=BrushEdit_path,
token=os.getenv("HF_TOKEN"),
)
## init default VLM
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
if vlm_processor != "" and vlm_model != "":
vlm_model.to(device)
else:
    raise gr.Error("Please download the default VLM model " + DEFAULT_VLM_MODEL_NAME + " first.")
## init default LLM
llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)
## init base model
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
# load the brushnetX checkpoint and attach it to the BrushNet inpainting pipeline
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
)
# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# uncomment the following line if xformers is installed (not needed with Torch 2.0).
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()
## init SAM
sam = build_sam(checkpoint=sam_path)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
sam_automask_generator = SamAutomaticMaskGenerator(sam)
## init groundingdino_model
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
## Plain image helpers
def crop_and_resize(image: Image.Image,
target_width: int,
target_height: int) -> Image.Image:
"""
Crops and resizes an image while preserving the aspect ratio.
Args:
image (Image.Image): Input PIL image to be cropped and resized.
target_width (int): Target width of the output image.
target_height (int): Target height of the output image.
Returns:
Image.Image: Cropped and resized image.
"""
# Original dimensions
original_width, original_height = image.size
original_aspect = original_width / original_height
target_aspect = target_width / target_height
# Calculate crop box to maintain aspect ratio
if original_aspect > target_aspect:
# Crop horizontally
new_width = int(original_height * target_aspect)
new_height = original_height
        left = (original_width - new_width) // 2
top = 0
right = left + new_width
bottom = original_height
else:
# Crop vertically
new_width = original_width
new_height = int(original_width / target_aspect)
left = 0
        top = (original_height - new_height) // 2
right = original_width
bottom = top + new_height
# Crop and resize
cropped_image = image.crop((left, top, right, bottom))
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
return resized_image
def resize(image: Image.Image,
target_width: int,
target_height: int) -> Image.Image:
"""
Crops and resizes an image while preserving the aspect ratio.
Args:
image (Image.Image): Input PIL image to be cropped and resized.
target_width (int): Target width of the output image.
target_height (int): Target height of the output image.
Returns:
Image.Image: Cropped and resized image.
"""
# Original dimensions
resized_image = image.resize((target_width, target_height), Image.NEAREST)
return resized_image
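## Mask manipulation helpers
# move_mask_func shifts a binary mask by `units` pixels in the given
# direction; pixels shifted past the image border are discarded.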
def move_mask_func(mask, direction, units):
binary_mask = mask.squeeze()>0
rows, cols = binary_mask.shape
moved_mask = np.zeros_like(binary_mask, dtype=bool)
if direction == 'down':
# move down
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
elif direction == 'up':
# move up
moved_mask[:rows - units, :] = binary_mask[units:, :]
    elif direction == 'right':
        # move right
        moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
    elif direction == 'left':
        # move left
        moved_mask[:, :cols - units] = binary_mask[:, units:]
return moved_mask
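# random_mask_func reshapes a binary mask: 'square_dilation' / 'square_erosion'
# grow or shrink it with a square structuring element, while 'bounding_box' /
# 'bounding_ellipse' replace it with its enclosing box or ellipse.
# Returns an h,w,1 uint8 mask in [0, 255].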
def random_mask_func(mask, dilation_type='square', dilation_size=20):
    # reshape the binary mask according to dilation_type
binary_mask = mask.squeeze()>0
if dilation_type == 'square_dilation':
structure = np.ones((dilation_size, dilation_size), dtype=bool)
dilated_mask = binary_dilation(binary_mask, structure=structure)
elif dilation_type == 'square_erosion':
structure = np.ones((dilation_size, dilation_size), dtype=bool)
dilated_mask = binary_erosion(binary_mask, structure=structure)
elif dilation_type == 'bounding_box':
        # find the bounding box of the mask
rows, cols = np.where(binary_mask)
if len(rows) == 0 or len(cols) == 0:
return mask # return original mask if no valid points
min_row = np.min(rows)
max_row = np.max(rows)
min_col = np.min(cols)
max_col = np.max(cols)
# create a bounding box
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
elif dilation_type == 'bounding_ellipse':
        # find the bounding box of the mask
rows, cols = np.where(binary_mask)
if len(rows) == 0 or len(cols) == 0:
return mask # return original mask if no valid points
min_row = np.min(rows)
max_row = np.max(rows)
min_col = np.min(cols)
max_col = np.max(cols)
        # calculate the center and semi-axes of the ellipse
        center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
        a = max((max_col - min_col) // 2, 1)  # horizontal semi-axis (guard against zero division)
        b = max((max_row - min_row) // 2, 1)  # vertical semi-axis (guard against zero division)
# create a bounding ellipse
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
dilated_mask[ellipse_mask] = True
    else:
        raise ValueError("dilation_type must be one of 'square_dilation', 'square_erosion', 'bounding_box', or 'bounding_ellipse'")
    # convert the boolean mask to an h,w,1 uint8 mask in [0, 255]
    dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
return dilated_mask
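# Illustrative usage (commented out, toy values for demonstration only):
# toy_mask = np.zeros((64, 64, 1), dtype=np.uint8); toy_mask[24:40, 24:40] = 255
# dilated = random_mask_func(toy_mask, 'square_dilation', dilation_size=20)  # -> h,w,1 uint8 in [0, 255]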
## Gradio component function
def update_vlm_model(vlm_name):
global vlm_model, vlm_processor
if vlm_model is not None:
del vlm_model
torch.cuda.empty_cache()
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
    ## we recommend preloading models, otherwise downloading them takes a long time; edit the preload list in vlm_template.py
if vlm_type == "llava-next":
if vlm_processor != "" and vlm_model != "":
vlm_model.to(device)
            return "success"
else:
if os.path.exists(vlm_local_path):
vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
else:
if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llava-v1.6-34b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llava-next-72b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
elif vlm_type == "qwen2-vl":
if vlm_processor != "" and vlm_model != "":
vlm_model.to(device)
            return "success"
else:
if os.path.exists(vlm_local_path):
vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
else:
if vlm_name == "qwen2-vl-2b-instruct (Preload)":
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
elif vlm_type == "openai":
pass
return "success"
def update_base_model(base_model_name):
global pipe
    ## we recommend preloading models, otherwise downloading them takes a long time; edit the preload list in base_model_template.py
if pipe is not None:
del pipe
torch.cuda.empty_cache()
base_model_path, pipe = base_models_template[base_model_name]
if pipe != "":
pipe.to(device)
else:
if os.path.exists(base_model_path):
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
)
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()
else:
raise gr.Error(f"The base model {base_model_name} does not exist")
return "success"
def process(input_image,
original_image,
original_mask,
prompt,
negative_prompt,
control_strength,
seed,
randomize_seed,
guidance_scale,
num_inference_steps,
num_samples,
blending,
category,
target_prompt,
resize_default,
aspect_ratio_name,
invert_mask_state):
if original_image is None:
if input_image is None:
raise gr.Error('Please upload the input image')
else:
print("input_image的键:", input_image.keys()) # 打印字典键
image_pil = input_image["background"].convert("RGB")
original_image = np.array(image_pil)
if prompt is None or prompt == "":
if target_prompt is None or target_prompt == "":
raise gr.Error("Please input your instructions, e.g., remove the xxx")
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    # keep the inverted mask when invert is active; otherwise prefer a freshly drawn mask
    if not invert_mask_state and input_mask.max() != 0:
        original_mask = input_mask
## inpainting directly if target_prompt is not None
if category is not None:
pass
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
pass
else:
try:
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
except Exception as e:
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
if original_mask is not None:
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
else:
try:
object_wait_for_edit = vlm_response_object_wait_for_edit(
vlm_processor,
vlm_model,
original_image,
category,
prompt,
device)
original_mask = vlm_response_mask(vlm_processor,
vlm_model,
category,
original_image,
prompt,
object_wait_for_edit,
sam,
sam_predictor,
sam_automask_generator,
groundingdino_model,
device).astype(np.uint8)
except Exception as e:
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
if target_prompt is not None and len(target_prompt) >= 1:
prompt_after_apply_instruction = target_prompt
else:
try:
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
vlm_processor,
vlm_model,
original_image,
prompt,
device)
except Exception as e:
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
with torch.autocast(device):
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
prompt_after_apply_instruction,
original_mask,
original_image,
generator,
num_inference_steps,
guidance_scale,
control_strength,
negative_prompt,
num_samples,
blending)
original_image = np.array(init_image_np)
masked_image = original_image * (1 - (mask_np>0))
masked_image = masked_image.astype(np.uint8)
masked_image = Image.fromarray(masked_image)
# Save the images (optional)
# import uuid
# uuid = str(uuid.uuid4())
# image[0].save(f"outputs/image_edit_{uuid}_0.png")
# image[1].save(f"outputs/image_edit_{uuid}_1.png")
# image[2].save(f"outputs/image_edit_{uuid}_2.png")
# image[3].save(f"outputs/image_edit_{uuid}_3.png")
# mask_image.save(f"outputs/mask_{uuid}.png")
# masked_image.save(f"outputs/masked_image_{uuid}.png")
gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
return image, [mask_image], [masked_image], prompt, '', False
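# Mask-only entry point: returns the VLM-generated mask (or the user-drawn one
# if present) together with its masked-image preview, without running the editor.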
def process_mask(input_image,
original_image,
prompt,
resize_default,
aspect_ratio_name):
if original_image is None:
raise gr.Error('Please upload the input image')
if prompt is None:
raise gr.Error("Please input your instructions, e.g., remove the xxx")
## load mask
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.array(alpha_mask)
    # when an example is active, original_image holds a path string; rebuild the array from the editor background
    if isinstance(original_image, str):
        original_image = np.array(input_image["background"].convert("RGB"))
if input_mask.max() == 0:
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
vlm_model,
original_image,
category,
prompt,
device)
# original mask: h,w,1 [0, 255]
original_mask = vlm_response_mask(
vlm_processor,
vlm_model,
category,
original_image,
prompt,
object_wait_for_edit,
sam,
sam_predictor,
sam_automask_generator,
groundingdino_model,
device).astype(np.uint8)
else:
original_mask = input_mask.astype(np.uint8)
category = None
## resize mask if needed
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (original_mask>0))
masked_image = masked_image.astype(np.uint8)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], original_mask.astype(np.uint8), category
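# The three functions below perturb the current mask (random bounding shape,
# square dilation, square erosion) via random_mask_func; they share the same
# resize logic as `process`.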
def process_random_mask(input_image,
original_image,
original_mask,
resize_default,
aspect_ratio_name,
):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    # prefer a freshly drawn mask over the stored one
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
def process_dilation_mask(input_image,
original_image,
original_mask,
resize_default,
aspect_ratio_name,
dilation_size=20):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    # prefer a freshly drawn mask over the stored one
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
dilation_type = np.random.choice(['square_dilation'])
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
def process_erosion_mask(input_image,
original_image,
original_mask,
resize_default,
aspect_ratio_name,
dilation_size=20):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    # prefer a freshly drawn mask over the stored one
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
dilation_type = np.random.choice(['square_erosion'])
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
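# The four move_mask_* handlers below shift the current mask by
# `moving_pixels` in one direction via move_mask_func.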
def move_mask_left(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    # prefer a freshly drawn mask over the stored one
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_right(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    # prefer a freshly drawn mask over the stored one
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_up(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    # prefer a freshly drawn mask over the stored one
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_down(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    # prefer a freshly drawn mask over the stored one
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def invert_mask(input_image,
original_image,
original_mask,
):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
    if input_mask.max() == 0 and original_mask is None:
        raise gr.Error('Please generate mask first')
    if input_mask.max() == 0:
        original_mask = 1 - (original_mask>0).astype(np.uint8)
    else:
        original_mask = 1 - (input_mask>0).astype(np.uint8)
original_mask = original_mask.squeeze()
mask_image = Image.fromarray(original_mask*255).convert("RGB")
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
if original_mask.max() <= 1:
original_mask = (original_mask * 255).astype(np.uint8)
masked_image = original_image * (1 - (original_mask>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], original_mask, True
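# init_img: when a known example is selected (and has not been edited more than
# once), preload its cached mask / masked-image / result galleries; otherwise
# reset all galleries for a fresh upload.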
def init_img(base,
init_type,
prompt,
aspect_ratio,
example_change_times
):
image_pil = base["background"].convert("RGB")
original_image = np.array(image_pil)
if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
        raise gr.Error('Image aspect ratio cannot be larger than 2.0')
if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
width, height = image_pil.size
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
image_pil = image_pil.resize((width_new, height_new))
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
else:
if aspect_ratio not in ASPECT_RATIO_LABELS:
aspect_ratio = "Custom resolution"
return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
def reset_func(input_image,
original_image,
original_mask,
prompt,
target_prompt,
):
input_image = None
original_image = None
original_mask = None
prompt = ''
mask_gallery = []
masked_gallery = []
result_gallery = []
target_prompt = ''
if torch.cuda.is_available():
torch.cuda.empty_cache()
return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
def update_example(example_type,
prompt,
example_change_times):
input_image = INPUT_IMAGE_PATH[example_type]
image_pil = Image.open(input_image).convert("RGB")
mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
width, height = image_pil.size
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
image_pil = image_pil.resize((width_new, height_new))
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
original_image = np.array(image_pil)
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
aspect_ratio = "Custom resolution"
example_change_times += 1
return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
def generate_target_prompt(input_image,
original_image,
prompt):
    # when an example is active, original_image holds a path string; rebuild it from the editor background
    if isinstance(original_image, str):
        original_image = np.array(input_image["background"].convert("RGB"))
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
vlm_processor,
vlm_model,
original_image,
prompt,
device)
return prompt_after_apply_instruction
# Event handlers for the BLIP caption / DeepSeek description workflow
def generate_blip_description(input_image):
if input_image is None:
return "", "Input image cannot be None"
from app.utils.utils import generate_caption
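    # note: the BLIP captioner is re-loaded on every upload; caching it globally would avoid repeated loads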
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
try:
image_pil = input_image["background"].convert("RGB")
except KeyError:
return "", "Input image missing 'background' key"
except AttributeError as e:
return "", f"Invalid image object: {str(e)}"
try:
description = generate_caption(blip_processor, blip_model, image_pil, device)
        return description, description  # update both the state and the visible textbox
except Exception as e:
return "", f"Caption generation failed: {str(e)}"
def submit_GPT4o_KEY(GPT4o_KEY):
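    # note: despite the GPT-4o naming, the key is validated against the DeepSeek endpoint below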
global vlm_model, vlm_processor
if vlm_model is not None:
del vlm_model
torch.cuda.empty_cache()
try:
vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
vlm_processor = ""
response = vlm_model.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello."}
]
)
response_str = response.choices[0].message.content
return "Success. " + response_str, "GPT4-o (Highly Recommended)"
except Exception as e:
return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
def verify_deepseek_api():
try:
response = llm_model.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello."}
]
)
response_str = response.choices[0].message.content
return True, "Success. " + response_str
except Exception as e:
return False, "Invalid DeepSeek API Key"
def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
try:
messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
response = llm_model.chat.completions.create(
model="deepseek-chat",
messages=messages
)
response_str = response.choices[0].message.content
return response_str
except Exception as e:
raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
def llm_decomposed_prompt_after_apply_instruction(integrated_query):
try:
messages = create_decomposed_query_messages_deepseek(integrated_query)
response = llm_model.chat.completions.create(
model="deepseek-chat",
messages=messages
)
response_str = response.choices[0].message.content
return response_str
except Exception as e:
raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
def enhance_description(blip_description, prompt):
try:
if not prompt or not blip_description:
print("Empty prompt or blip_description detected")
return "", ""
print(f"Enhancing with prompt: {prompt}")
enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt)
return enhanced_description, enhanced_description
except Exception as e:
print(f"Enhancement failed: {str(e)}")
return "Error occurred", "Error occurred"
def decompose_description(enhanced_description):
try:
if not enhanced_description:
print("Empty enhanced_description detected")
return "", ""
print(f"Decomposing the enhanced description: {enhanced_description}")
decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
return decomposed_description, decomposed_description
except Exception as e:
print(f"Decomposition failed: {str(e)}")
return "Error occurred", "Error occurred"
block = gr.Blocks(
theme=gr.themes.Soft(
radius_size=gr.themes.sizes.radius_none,
text_size=gr.themes.sizes.text_md
)
)
with block as demo:
with gr.Row():
with gr.Column():
gr.HTML(head)
gr.Markdown(descriptions)
with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
with gr.Row(equal_height=True):
gr.Markdown(instructions)
original_image = gr.State(value=None)
original_mask = gr.State(value=None)
category = gr.State(value=None)
status = gr.State(value=None)
invert_mask_state = gr.State(value=False)
example_change_times = gr.State(value=0)
deepseek_verified = gr.State(value=False)
blip_description = gr.State(value="")
enhanced_description = gr.State(value="")
decomposed_description = gr.State(value="")
with gr.Row():
with gr.Column():
with gr.Row():
input_image = gr.ImageEditor(
label="参考图像",
type="pil",
brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
layers = False,
interactive=True,
# height=1024,
height=512,
sources=["upload"],
placeholder="🫧 点击此处或下面的图标上传图像 🫧",
)
prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期 😜", value="",lines=1)
run_button = gr.Button("💫 图像编辑")
vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
with gr.Group():
with gr.Row():
# GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
GPT4o_KEY = gr.Textbox(label="密钥输入", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
GPT4o_KEY_submit = gr.Button("🙈 验证")
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
with gr.Row():
mask_button = gr.Button("💎 掩膜生成")
random_mask_button = gr.Button("Square/Circle Mask ")
with gr.Row():
generate_target_prompt_button = gr.Button("Generate Target Prompt")
target_prompt = gr.Text(
label="Input Target Prompt",
max_lines=5,
placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
value='',
lines=2
)
with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
negative_prompt = gr.Text(
label="Negative Prompt",
max_lines=5,
placeholder="Please input your negative prompt",
value='ugly, low quality',lines=1
)
control_strength = gr.Slider(
label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
)
with gr.Group():
seed = gr.Slider(
label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
blending = gr.Checkbox(label="Blending mode", value=True)
num_samples = gr.Slider(
label="Num samples", minimum=0, maximum=4, step=1, value=4
)
with gr.Group():
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=1,
maximum=12,
step=0.1,
value=7.5,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=50,
)
            with gr.Group(visible=True):
                # BLIP-generated description
                blip_output = gr.Textbox(label="Source Image Description", placeholder="💬 Base image description generated by BLIP 💬", interactive=True, lines=3)
                # DeepSeek API key verification
                with gr.Row():
                    deepseek_key = gr.Textbox(label="API Key", placeholder="Please input your DeepSeek API key", value="", lines=1)
                    verify_deepseek = gr.Button("🙈 Verify")
                # integrated description area
                with gr.Row():
                    enhanced_output = gr.Textbox(label="Integrated Description", placeholder="💭 Enhanced description generated by DeepSeek 💭", interactive=True, lines=3)
                    enhance_button = gr.Button("✨ Integrate")
                # decomposed description area
                with gr.Row():
                    decomposed_output = gr.Textbox(label="Decomposed Description", placeholder="🔍 Decomposed description generated by DeepSeek 🔍", interactive=True, lines=3)
                    decompose_button = gr.Button("🔧 Decompose")
with gr.Row():
with gr.Tab(elem_classes="feedback", label="Masked Image"):
masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
with gr.Tab(elem_classes="feedback", label="Mask"):
mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
invert_mask_button = gr.Button("Invert Mask")
dilation_size = gr.Slider(
label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
)
with gr.Row():
dilation_mask_button = gr.Button("Dilation Generated Mask")
erosion_mask_button = gr.Button("Erosion Generated Mask")
moving_pixels = gr.Slider(
label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
)
with gr.Row():
move_left_button = gr.Button("Move Left")
move_right_button = gr.Button("Move Right")
with gr.Row():
move_up_button = gr.Button("Move Up")
move_down_button = gr.Button("Move Down")
with gr.Tab(elem_classes="feedback", label="Output"):
result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
# target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
reset_button = gr.Button("Reset")
init_type = gr.Textbox(label="Init Name", value="", visible=False)
example_type = gr.Textbox(label="Example Name", value="", visible=False)
with gr.Row():
example = gr.Examples(
label="Quick Example",
examples=EXAMPLES,
inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
examples_per_page=10,
cache_examples=False,
)
with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
with gr.Row(equal_height=True):
gr.Markdown(tips)
with gr.Row():
gr.Markdown(citation)
    ## gr.Examples cannot update a gr.Gallery directly, so the two handlers below refresh the galleries,
    ## and we must resolve the conflict between the upload and change-example events.
input_image.upload(
init_img,
[input_image, init_type, prompt, aspect_ratio, example_change_times],
[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
)
example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
## vlm and base model dropdown
vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
ips=[input_image,
original_image,
original_mask,
prompt,
negative_prompt,
control_strength,
seed,
randomize_seed,
guidance_scale,
num_inference_steps,
num_samples,
blending,
category,
target_prompt,
resize_default,
aspect_ratio,
invert_mask_state]
## run brushedit
run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
## mask func
mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[masked_gallery, mask_gallery, original_mask])
    erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[masked_gallery, mask_gallery, original_mask])
## move mask func
move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
## prompt func
generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
## reset func
reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
    # bind BLIP caption / DeepSeek description event handlers
input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
enhance_button.click(fn=enhance_description, inputs=[blip_description, prompt], outputs=[enhanced_description, enhanced_output])
decompose_button.click(fn=decompose_description, inputs=[enhanced_description], outputs=[decomposed_description, decomposed_output])
demo.launch(server_name="0.0.0.0", server_port=12345, share=True)