#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os, random, sys
import numpy as np
import requests
import torch
from pathlib import Path
import pandas as pd
import concurrent.futures
import faiss
import gradio as gr
import json
from PIL import Image
import torch.nn.functional as F  # newly added: used to normalize CLIP features
from huggingface_hub import hf_hub_download, snapshot_download
from scipy.ndimage import binary_dilation, binary_erosion
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from app.src.vlm_pipeline import (
vlm_response_editing_type,
vlm_response_object_wait_for_edit,
vlm_response_mask,
vlm_response_prompt_after_apply_instruction
)
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
from app.utils.utils import load_grounding_dino_model
from app.src.vlm_template import vlms_template
from app.src.base_model_template import base_models_template
from app.src.aspect_ratio_template import aspect_ratios
from openai import OpenAI
base_openai_url = "https://api.deepseek.com/"
base_api_key = os.getenv("DEEPSEEK_API_KEY", "")  # read the DeepSeek key from the environment rather than hard-coding it
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import CLIPProcessor, CLIPModel
from app.deepseek.instructions import (
create_apply_editing_messages_deepseek,
create_decomposed_query_messages_deepseek
)
from clip_retrieval.clip_client import ClipClient
#### Description ####
logo = r"""
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
"""
head = r"""
<div style="text-align: center;">
<h1>Zero-Shot Composed Image Retrieval Based on Diffusion Model Priors and Large Language Models</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href=''><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
<a href=''><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
<a href=''><img src='https://img.shields.io/badge/Code-Github-orange'></a>
</div>
<br>
</div>
"""
descriptions = r"""
Demo for ZS-CIR"""
instructions = r"""
Demo for ZS-CIR"""
tips = r"""
Demo for ZS-CIR
"""
citation = r"""
Demo for ZS-CIR"""
# - - - - - examples - - - - - #
EXAMPLES = [
[
Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
"add a magic hat on frog head.",
642087011,
"frog",
"frog",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
"replace the background to ancient China.",
648464818,
"chinese_girl",
"chinese_girl",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
"remove the deer.",
648464818,
"angel_christmas",
"angel_christmas",
False,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
"add a wreath on head.",
648464818,
"sunflower_girl",
"sunflower_girl",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
"add a butterfly fairy.",
648464818,
"girl_on_sun",
"girl_on_sun",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
"remove the christmas hat.",
642087011,
"spider_man_rm",
"spider_man_rm",
False,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
"remove the flower.",
642087011,
"anime_flower",
"anime_flower",
False,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
"replace the clothes to a delicated floral skirt.",
648464818,
"chenduling",
"chenduling",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
"make the hedgehog in Italy.",
648464818,
"hedgehog_rp_bg",
"hedgehog_rp_bg",
True,
False,
"GPT4-o (Highly Recommended)"
],
]
INPUT_IMAGE_PATH = {
"frog": "./assets/frog/frog.jpeg",
"chinese_girl": "./assets/chinese_girl/chinese_girl.png",
"angel_christmas": "./assets/angel_christmas/angel_christmas.png",
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
"spider_man_rm": "./assets/spider_man_rm/spider_man.png",
"anime_flower": "./assets/anime_flower/anime_flower.png",
"chenduling": "./assets/chenduling/chengduling.jpg",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
}
MASK_IMAGE_PATH = {
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
}
MASKED_IMAGE_PATH = {
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
}
OUTPUT_IMAGE_PATH = {
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
}
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
# os.makedirs('gradio_temp_dir', exist_ok=True)
VLM_MODEL_NAMES = list(vlms_template.keys())
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
BASE_MODELS = list(base_models_template.keys())
DEFAULT_BASE_MODEL = "realisticVision (Default)"
ASPECT_RATIO_LABELS = list(aspect_ratios)
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
## init device
try:
if torch.cuda.is_available():
device = "cuda"
elif sys.platform == "darwin" and torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
except:
device = "cpu"
# ## init torch dtype
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
# torch_dtype = torch.bfloat16
# else:
# torch_dtype = torch.float16
# if device == "mps":
# torch_dtype = torch.float16
torch_dtype = torch.float16
# download hf models
BrushEdit_path = "models/"
if not os.path.exists(BrushEdit_path):
BrushEdit_path = snapshot_download(
repo_id="TencentARC/BrushEdit",
local_dir=BrushEdit_path,
token=os.getenv("HF_TOKEN"),
)
## init default VLM
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
if vlm_processor != "" and vlm_model != "":
vlm_model.to(device)
else:
raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
## init default LLM
llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)
## init base model
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
# input brushnetX ckpt path
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
)
# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# uncomment the following line to enable xformers memory-efficient attention
# (requires xformers; unnecessary with PyTorch 2.0's native attention)
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()
## init SAM
sam = build_sam(checkpoint=sam_path)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
sam_automask_generator = SamAutomaticMaskGenerator(sam)
## init groundingdino_model
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
## Ordinary function
def crop_and_resize(image: Image.Image,
target_width: int,
target_height: int) -> Image.Image:
"""
Crops and resizes an image while preserving the aspect ratio.
Args:
image (Image.Image): Input PIL image to be cropped and resized.
target_width (int): Target width of the output image.
target_height (int): Target height of the output image.
Returns:
Image.Image: Cropped and resized image.
"""
# Original dimensions
original_width, original_height = image.size
original_aspect = original_width / original_height
target_aspect = target_width / target_height
# Calculate crop box to maintain aspect ratio
if original_aspect > target_aspect:
# Crop horizontally
new_width = int(original_height * target_aspect)
new_height = original_height
left = (original_width - new_width) / 2
top = 0
right = left + new_width
bottom = original_height
else:
# Crop vertically
new_width = original_width
new_height = int(original_width / target_aspect)
left = 0
top = (original_height - new_height) / 2
right = original_width
bottom = top + new_height
# Crop and resize
cropped_image = image.crop((left, top, right, bottom))
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
return resized_image
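# Illustrative usage sketch (a hypothetical helper, not wired into the app): center-crop a
# 1280x720 image to a square 512x512 output. Since 1280/720 > 1, the horizontal branch above
# is taken and roughly 280 pixels are trimmed from each side before the final resize.
def _crop_and_resize_example() -> Image.Image:
    demo_image = Image.new("RGB", (1280, 720), color=(127, 127, 127))
    result = crop_and_resize(demo_image, target_width=512, target_height=512)
    assert result.size == (512, 512)
    return result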
## Ordinary function
def resize(image: Image.Image,
target_width: int,
target_height: int) -> Image.Image:
"""
Crops and resizes an image while preserving the aspect ratio.
Args:
image (Image.Image): Input PIL image to be cropped and resized.
target_width (int): Target width of the output image.
target_height (int): Target height of the output image.
Returns:
Image.Image: Cropped and resized image.
"""
# Original dimensions
resized_image = image.resize((target_width, target_height), Image.NEAREST)
return resized_image
def move_mask_func(mask, direction, units):
binary_mask = mask.squeeze()>0
rows, cols = binary_mask.shape
moved_mask = np.zeros_like(binary_mask, dtype=bool)
if direction == 'down':
# move down
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
elif direction == 'up':
# move up
moved_mask[:rows - units, :] = binary_mask[units:, :]
elif direction == 'right':
        # move right
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
elif direction == 'left':
        # move left
moved_mask[:, :cols - units] = binary_mask[:, units:]
return moved_mask
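# Minimal sketch (a hypothetical helper, illustrative only) of move_mask_func on a toy 5x5
# mask: the single foreground pixel at (2, 2) moved 'right' by 1 unit ends up at (2, 3);
# anything shifted past the border is discarded, matching the slicing above.
def _move_mask_example() -> np.ndarray:
    toy_mask = np.zeros((5, 5), dtype=np.uint8)
    toy_mask[2, 2] = 255
    moved = move_mask_func(toy_mask, "right", 1)
    assert moved[2, 3] and not moved[2, 2]
    return moved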
def random_mask_func(mask, dilation_type='square', dilation_size=20):
    # Apply the selected dilation / erosion / bounding-region operation to the binary mask
binary_mask = mask.squeeze()>0
if dilation_type == 'square_dilation':
structure = np.ones((dilation_size, dilation_size), dtype=bool)
dilated_mask = binary_dilation(binary_mask, structure=structure)
elif dilation_type == 'square_erosion':
structure = np.ones((dilation_size, dilation_size), dtype=bool)
dilated_mask = binary_erosion(binary_mask, structure=structure)
elif dilation_type == 'bounding_box':
        # find the bounding box of the foreground pixels
rows, cols = np.where(binary_mask)
if len(rows) == 0 or len(cols) == 0:
return mask # return original mask if no valid points
min_row = np.min(rows)
max_row = np.max(rows)
min_col = np.min(cols)
max_col = np.max(cols)
# create a bounding box
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
elif dilation_type == 'bounding_ellipse':
        # find the bounding box of the foreground pixels
rows, cols = np.where(binary_mask)
if len(rows) == 0 or len(cols) == 0:
return mask # return original mask if no valid points
min_row = np.min(rows)
max_row = np.max(rows)
min_col = np.min(cols)
max_col = np.max(cols)
# calculate the center and axis length of the ellipse
center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
a = (max_col - min_col) // 2 # half long axis
b = (max_row - min_row) // 2 # half short axis
# create a bounding ellipse
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
dilated_mask[ellipse_mask] = True
    else:
        raise ValueError("dilation_type must be one of 'square_dilation', 'square_erosion', 'bounding_box', or 'bounding_ellipse'")
    # convert the boolean mask to a uint8 image in {0, 255} with a trailing channel axis
dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
return dilated_mask
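# Minimal sketch (a hypothetical helper, illustrative only) of random_mask_func: with
# 'bounding_box' the result covers the tight axis-aligned box around the foreground pixels,
# returned as a uint8 {0, 255} array with a trailing channel axis, as produced above.
def _random_mask_example() -> np.ndarray:
    toy_mask = np.zeros((8, 8, 1), dtype=np.uint8)
    toy_mask[2:4, 3:6, 0] = 255
    boxed = random_mask_func(toy_mask, dilation_type="bounding_box")
    assert boxed.shape == (8, 8, 1) and boxed.max() == 255
    return boxed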
## Gradio component function
def update_vlm_model(vlm_name):
global vlm_model, vlm_processor
if vlm_model is not None:
del vlm_model
torch.cuda.empty_cache()
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
    ## We recommend using preloaded models; otherwise the download can take a long time. The preload list can be edited in vlm_template.py.
if vlm_type == "llava-next":
if vlm_processor != "" and vlm_model != "":
vlm_model.to(device)
return vlm_model_dropdown
else:
if os.path.exists(vlm_local_path):
vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
else:
if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llava-v1.6-34b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llava-next-72b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
elif vlm_type == "qwen2-vl":
if vlm_processor != "" and vlm_model != "":
vlm_model.to(device)
return vlm_model_dropdown
else:
if os.path.exists(vlm_local_path):
vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
else:
if vlm_name == "qwen2-vl-2b-instruct (Preload)":
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
elif vlm_type == "openai":
pass
return "success"
def update_base_model(base_model_name):
global pipe
    ## We recommend using preloaded models; otherwise the download can take a long time. The preload list can be edited in base_model_template.py.
if pipe is not None:
del pipe
torch.cuda.empty_cache()
base_model_path, pipe = base_models_template[base_model_name]
if pipe != "":
pipe.to(device)
else:
if os.path.exists(base_model_path):
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
)
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()
else:
raise gr.Error(f"The base model {base_model_name} does not exist")
return "success"
def process_random_mask(input_image,
original_image,
original_mask,
resize_default,
aspect_ratio_name,
):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
pass
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
if input_mask.max() == 0:
original_mask = original_mask
else:
original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
def process_dilation_mask(input_image,
original_image,
original_mask,
resize_default,
aspect_ratio_name,
dilation_size=20):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
pass
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
if input_mask.max() == 0:
original_mask = original_mask
else:
original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
dilation_type = np.random.choice(['square_dilation'])
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
def process_erosion_mask(input_image,
original_image,
original_mask,
resize_default,
aspect_ratio_name,
dilation_size=20):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
pass
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
if input_mask.max() == 0:
original_mask = original_mask
else:
original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
dilation_type = np.random.choice(['square_erosion'])
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
def move_mask_left(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
pass
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
if input_mask.max() == 0:
original_mask = original_mask
else:
original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_right(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
pass
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
if input_mask.max() == 0:
original_mask = original_mask
else:
original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_up(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
pass
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
if input_mask.max() == 0:
original_mask = original_mask
else:
original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_down(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
pass
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
if input_mask.max() == 0:
original_mask = original_mask
else:
original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def invert_mask(input_image,
original_image,
original_mask,
):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
    if input_mask.max() == 0:
        if original_mask is None:
            raise gr.Error('Please generate mask first')
        original_mask = 1 - (original_mask>0).astype(np.uint8)
    else:
        original_mask = 1 - (input_mask>0).astype(np.uint8)
original_mask = original_mask.squeeze()
mask_image = Image.fromarray(original_mask*255).convert("RGB")
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
if original_mask.max() <= 1:
original_mask = (original_mask * 255).astype(np.uint8)
masked_image = original_image * (1 - (original_mask>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], original_mask, True
def reset_func(input_image,
original_image,
original_mask,
prompt,
target_prompt,
):
input_image = None
original_image = None
original_mask = None
prompt = ''
mask_gallery = []
masked_gallery = []
result_gallery = []
target_prompt = ''
if torch.cuda.is_available():
torch.cuda.empty_cache()
return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
def update_example(example_type,
prompt,
example_change_times):
input_image = INPUT_IMAGE_PATH[example_type]
image_pil = Image.open(input_image).convert("RGB")
mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
width, height = image_pil.size
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
image_pil = image_pil.resize((width_new, height_new))
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
original_image = np.array(image_pil)
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
aspect_ratio = "Custom resolution"
example_change_times += 1
return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
def generate_target_prompt(input_image,
original_image,
prompt):
# load example image
if isinstance(original_image, str):
original_image = input_image
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
vlm_processor,
vlm_model,
original_image,
prompt,
device)
return prompt_after_apply_instruction
from app.utils.utils import generate_caption
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",torch_dtype=torch.float16).to(device)
def init_img(base,
init_type,
prompt,
aspect_ratio,
example_change_times
):
image_pil = base["background"].convert("RGB")
original_image = np.array(image_pil)
if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
raise gr.Error('image aspect ratio cannot be larger than 2.0')
if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
width, height = image_pil.size
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
image_pil = image_pil.resize((width_new, height_new))
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
else:
if aspect_ratio not in ASPECT_RATIO_LABELS:
aspect_ratio = "Custom resolution"
return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
def process_mask(input_image,
original_image,
prompt,
resize_default,
aspect_ratio_name):
if original_image is None:
raise gr.Error('Please upload the input image')
if prompt is None:
raise gr.Error("Please input your instructions, e.g., remove the xxx")
## load mask
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.array(alpha_mask)
# load example image
if isinstance(original_image, str):
original_image = input_image["background"]
if input_mask.max() == 0:
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
vlm_model,
original_image,
category,
prompt,
device)
# original mask: h,w,1 [0, 255]
original_mask = vlm_response_mask(
vlm_processor,
vlm_model,
category,
original_image,
prompt,
object_wait_for_edit,
sam,
sam_predictor,
sam_automask_generator,
groundingdino_model,
device).astype(np.uint8)
else:
original_mask = input_mask.astype(np.uint8)
category = None
## resize mask if needed
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
pass
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (original_mask>0))
masked_image = masked_image.astype(np.uint8)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], original_mask.astype(np.uint8), category
def process(input_image,
original_image,
original_mask,
prompt,
negative_prompt,
control_strength,
seed,
randomize_seed,
guidance_scale,
num_inference_steps,
num_samples,
blending,
category,
target_prompt,
resize_default,
aspect_ratio_name,
invert_mask_state):
if original_image is None:
if input_image is None:
raise gr.Error('Please upload the input image')
else:
image_pil = input_image["background"].convert("RGB")
original_image = np.array(image_pil)
if prompt is None or prompt == "":
if target_prompt is None or target_prompt == "":
raise gr.Error("Please input your instructions, e.g., remove the xxx")
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
pass
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
if invert_mask_state:
original_mask = original_mask
else:
if input_mask.max() == 0:
original_mask = original_mask
else:
original_mask = input_mask
# inpainting directly if target_prompt is not None
if category is not None:
pass
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
pass
else:
try:
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
except Exception as e:
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
if original_mask is not None:
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
else:
try:
object_wait_for_edit = vlm_response_object_wait_for_edit(
vlm_processor,
vlm_model,
original_image,
category,
prompt,
device)
original_mask = vlm_response_mask(vlm_processor,
vlm_model,
category,
original_image,
prompt,
object_wait_for_edit,
sam,
sam_predictor,
sam_automask_generator,
groundingdino_model,
device).astype(np.uint8)
except Exception as e:
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
if target_prompt is not None and len(target_prompt) >= 1:
prompt_after_apply_instruction = target_prompt
else:
try:
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
vlm_processor,
vlm_model,
original_image,
prompt,
device)
except Exception as e:
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
with torch.autocast(device):
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
prompt_after_apply_instruction,
original_mask,
original_image,
generator,
num_inference_steps,
guidance_scale,
control_strength,
negative_prompt,
num_samples,
blending)
original_image = np.array(init_image_np)
masked_image = original_image * (1 - (mask_np>0))
masked_image = masked_image.astype(np.uint8)
masked_image = Image.fromarray(masked_image)
return image, [mask_image], [masked_image], prompt, '', False
def process_cirr_images():
    # The VLM / SAM / GroundingDINO models must already be initialized (loaded at module level above)
global vlm_model, sam_predictor, groundingdino_model
if not all([vlm_model, sam_predictor, groundingdino_model]):
raise RuntimeError("Required models not initialized")
# Define paths
dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev")
cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json")
output_dirs = {
"edited": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_edited"),
"mask": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_mask"),
"masked": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_masked")
}
# Create output directories
for dir_path in output_dirs.values():
dir_path.mkdir(parents=True, exist_ok=True)
# Load captions
with open(cap_file, 'r') as f:
captions = json.load(f)
descriptions = {}
for img_path in dev_dir.glob("*.png"):
base_name = img_path.stem
caption = next((item["caption"] for item in captions if item.get("reference") == base_name), None)
if not caption:
print(f"Warning: No caption for {base_name}")
continue
try:
            # Key change 1: build an empty (all-zero) alpha channel
            rgb_image = Image.open(img_path).convert("RGB")
            empty_alpha = Image.new("L", rgb_image.size, 0)  # fully transparent alpha channel
image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
            # Key change 2: initialize the editing state via init_img
base = {"background": image, "layers": [image]}
init_results = init_img(
base=base,
init_type="custom", # 使用自定义初始化
prompt=caption,
aspect_ratio="Custom resolution",
example_change_times=0
)
            # Unpack the initialized values
input_image = init_results[0]
original_image = init_results[1]
original_mask = init_results[2]
            # Key change 3: pass the correct arguments to process()
process_results = process(
input_image=input_image,
original_image=original_image,
                original_mask=original_mask,  # pass the mask produced by init_img
prompt=caption,
negative_prompt="ugly, low quality",
control_strength=1.0,
seed=648464818,
randomize_seed=False,
guidance_scale=7.5,
num_inference_steps=50,
num_samples=1,
blending=True,
category=None,
target_prompt="",
resize_default=True,
aspect_ratio_name="Custom resolution",
invert_mask_state=False
)
            # Handle the results (same logic as before)
result_images, mask_images, masked_images = process_results[:3]
# Save images
output_dirs["edited"].mkdir(exist_ok=True)
result_images[0].save(output_dirs["edited"] / f"{base_name}.png")
mask_images[0].save(output_dirs["mask"] / f"{base_name}_mask.png")
masked_images[0].save(output_dirs["masked"] / f"{base_name}_masked.png")
# Generate BLIP description
blip_desc, _ = generate_blip_description({"background": image})
descriptions[base_name] = {
"original_caption": caption,
"blip_description": blip_desc
}
print(f"Processed {base_name}")
except Exception as e:
print(f"Error processing {base_name}: {str(e)}")
continue
# Save descriptions
with open("/home/zt/data/BrushEdit/cirr/cirr_description_fix.json", 'w') as f:
json.dump(descriptions, f, indent=4)
print("Processing completed!")
# def process_cirr_images():
# # Define paths
# dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev")
# cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json")
# output_dirs = {
# "edited": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_edited"),
# "mask": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_mask"),
# "masked": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_masked")
# }
# # Create output directories if they don't exist
# for dir_path in output_dirs.values():
# dir_path.mkdir(parents=True, exist_ok=True)
# # Load captions from JSON file
# with open(cap_file, 'r') as f:
# captions = json.load(f)
# # Initialize description dictionary
# descriptions = {}
# # Process each PNG image in dev directory
# for img_path in dev_dir.glob("*.png"):
# # Get base name without extension
# base_name = img_path.stem
# # Find matching caption
# caption = None
# for item in captions:
# if item.get("reference") == base_name:
# caption = item.get("caption")
# break
# if caption is None:
# print(f"Warning: No caption found for {base_name}")
# continue
# # Load and convert image to RGB
# try:
# rgb_image = Image.open(img_path).convert("RGB")
# a = Image.new("L", rgb_image.size, 255) # 全不透明alpha通道
# image = Image.merge("RGBA", (*rgb_image.split(), a))
# except Exception as e:
# print(f"Error loading image {img_path}: {e}")
# continue
# # Generate BLIP description
# try:
# blip_desc, _ = generate_blip_description({"background": image})
# except Exception as e:
# print(f"Error generating BLIP description for {base_name}: {e}")
# continue
# # Process image
# try:
# # Prepare input parameters for process function
# input_image = {"background": image, "layers": [image]}
# original_image = np.array(image)
# original_mask = None
# prompt = caption
# negative_prompt = "ugly, low quality"
# control_strength = 1.0
# seed = 648464818
# randomize_seed = False
# guidance_scale = 7.5
# num_inference_steps = 50
# num_samples = 1
# blending = True
# category = None
# target_prompt = ""
# resize_default = True
# aspect_ratio = "Custom resolution"
# invert_mask_state = False
# # Call process function and handle return values properly
# process_results = process(
# input_image,
# original_image,
# original_mask,
# prompt,
# negative_prompt,
# control_strength,
# seed,
# randomize_seed,
# guidance_scale,
# num_inference_steps,
# num_samples,
# blending,
# category,
# target_prompt,
# resize_default,
# aspect_ratio,
# invert_mask_state
# )
# # Extract results safely
# result_images = process_results[0]
# mask_images = process_results[1]
# masked_images = process_results[2]
# # Ensure we have valid images to save
# if not result_images or not mask_images or not masked_images:
# print(f"Warning: No output images generated for {base_name}")
# continue
# # Save processed images
# # Save edited image
# edited_path = output_dirs["edited"] / f"{base_name}.png"
# if isinstance(result_images, (list, tuple)):
# result_images[0].save(edited_path)
# else:
# result_images.save(edited_path)
# # Save mask image
# mask_path = output_dirs["mask"] / f"{base_name}_mask.png"
# if isinstance(mask_images, (list, tuple)):
# mask_images[0].save(mask_path)
# else:
# mask_images.save(mask_path)
# # Save masked image
# masked_path = output_dirs["masked"] / f"{base_name}_masked.png"
# if isinstance(masked_images, (list, tuple)):
# masked_images[0].save(masked_path)
# else:
# masked_images.save(masked_path)
# # Store description
# descriptions[base_name] = {
# "original_caption": caption,
# "blip_description": blip_desc
# }
# print(f"Successfully processed {base_name}")
# except Exception as e:
# print(f"Error processing image {base_name}: {e}")
# continue
# # Save descriptions to JSON file
# with open("/home/zt/data/BrushEdit/cirr/cirr_description_fix.json", 'w') as f:
# json.dump(descriptions, f, indent=4)
# print("Processing completed!")
def generate_blip_description(input_image):
if input_image is None:
return "", "Input image cannot be None"
try:
image_pil = input_image["background"].convert("RGB")
except KeyError:
return "", "Input image missing 'background' key"
except AttributeError as e:
return "", f"Invalid image object: {str(e)}"
try:
description = generate_caption(blip_processor, blip_model, image_pil, device)
        return description, description  # update both the state and the display component
except Exception as e:
return "", f"Caption generation failed: {str(e)}"
def submit_GPT4o_KEY(GPT4o_KEY):
global vlm_model, vlm_processor
if vlm_model is not None:
del vlm_model
torch.cuda.empty_cache()
try:
vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
vlm_processor = ""
response = vlm_model.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello."}
]
)
response_str = response.choices[0].message.content
return "Success. " + response_str, "GPT4-o (Highly Recommended)"
except Exception as e:
return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
def verify_deepseek_api():
try:
response = llm_model.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello."}
]
)
response_str = response.choices[0].message.content
return True, "Success. " + response_str
except Exception as e:
return False, "Invalid DeepSeek API Key"
def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
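"""Ask DeepSeek to merge the BLIP caption and the editing prompt into one
integrated target description, using the message template from
create_apply_editing_messages_deepseek. Raises gr.Error if the call fails.
"""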
try:
messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
response = llm_model.chat.completions.create(
model="deepseek-chat",
messages=messages
)
response_str = response.choices[0].message.content
return response_str
except Exception as e:
raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
def llm_decomposed_prompt_after_apply_instruction(integrated_query):
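"""Ask DeepSeek to break the integrated query into a structured, decomposed
description, using create_decomposed_query_messages_deepseek. Raises gr.Error
if the call fails.
"""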
try:
messages = create_decomposed_query_messages_deepseek(integrated_query)
response = llm_model.chat.completions.create(
model="deepseek-chat",
messages=messages
)
response_str = response.choices[0].message.content
return response_str
except Exception as e:
raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
def enhance_description(blip_description, prompt):
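"""UI wrapper for the integration step: validates the inputs, then returns the
integrated description twice (state value and textbox value). Empty inputs give
("", ""); an LLM failure gives an error marker instead.
"""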
try:
if not prompt or not blip_description:
print("Empty prompt or blip_description detected")
return "", ""
print(f"Enhancing with prompt: {prompt}")
enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt)
return enhanced_description, enhanced_description
except Exception as e:
print(f"Enhancement failed: {str(e)}")
return "Error occurred", "Error occurred"
def decompose_description(enhanced_description):
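"""UI wrapper for the decomposition step: returns the structured description
twice (state value and textbox value); empty input gives ("", ""), and an LLM
failure gives an error marker instead.
"""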
try:
if not enhanced_description:
print("Empty enhanced_description detected")
return "", ""
print(f"Decomposing the enhanced description: {enhanced_description}")
decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
return decomposed_description, decomposed_description
except Exception as e:
print(f"Decomposition failed: {str(e)}")
return "Error occurred", "Error occurred"
@torch.no_grad()
def mix_and_search(enhanced_text: str, gallery_images: list):
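"""Compose the query and retrieve candidates from the CIRR dev index.

The CLIP image feature of the latest edited image in gallery_images and the
CLIP text feature of enhanced_text are each L2-normalized, averaged, and
re-normalized, then used to query a prebuilt FAISS index; the top-50 hits are
mapped back to image files through the parquet metadata and returned as PIL
images.
"""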
# Take the most recently generated image entry (a tuple) from the gallery
latest_item = gallery_images[-1] if gallery_images else None
# Collect the feature vectors to be fused
features = []
# Image feature extraction
if latest_item and isinstance(latest_item, tuple):
try:
image_path = latest_item[0]
pil_image = Image.open(image_path).convert("RGB")
# Preprocess the image with CLIPProcessor
image_inputs = clip_processor(
images=pil_image,
return_tensors="pt"
).to(device)
image_features = clip_model.get_image_features(**image_inputs)
features.append(F.normalize(image_features, dim=-1))
except Exception as e:
print(f"图像处理失败: {str(e)}")
# 文本特征提取
if enhanced_text.strip():
text_inputs = clip_processor(
text=enhanced_text,
return_tensors="pt",
padding=True,
truncation=True
).to(device)
text_features = clip_model.get_text_features(**text_inputs)
features.append(F.normalize(text_features, dim=-1))
if not features:
return []
# Feature fusion and retrieval
mixed = sum(features) / len(features)
mixed = F.normalize(mixed, dim=-1)
# Load the FAISS index and the image-path mapping
# index_path = "/home/zt/data/open-images/train/knn.index"
# input_data_dir = Path("/home/zt/data/open-images/train/embedding_folder/metadata")
# base_image_dir = Path("/home/zt/data/open-images/train/")
index_path = "/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_knn.index"
input_data_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_embedding_folder/metadata")
base_image_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/")
# Sort the parquet files by the numeric suffix in their names and read them directly
parquet_files = sorted(
input_data_dir.glob('*.parquet'),
key=lambda x: int(x.stem.split("_")[-1])
)
# Concatenate all parquet metadata
dfs = [pd.read_parquet(file) for file in parquet_files]  # read each file inline
df = pd.concat(dfs, ignore_index=True)
image_paths = df["image_path"].tolist()
# Load the FAISS index
index = faiss.read_index(index_path)
assert mixed.shape[1] == index.d, "feature dimension does not match the index"
# Run the nearest-neighbour search
mixed = mixed.cpu().detach().numpy().astype('float32')
distances, indices = index.search(mixed, 50)
# Resolve and validate the retrieved image paths
retrieved_images = []
for idx in indices[0]:
if 0 <= idx < len(image_paths):
img_path = base_image_dir / image_paths[idx]
try:
if img_path.exists():
retrieved_images.append(Image.open(img_path).convert("RGB"))
else:
print(f"警告:文件缺失 {img_path}")
except Exception as e:
print(f"图片加载失败: {str(e)}")
return retrieved_images if retrieved_images else ([])
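# ------------------------------------------------------------------------------
# Hedged sketch (not wired into the app): one plausible way to build a FAISS
# index and metadata parquet compatible with mix_and_search() above. The function
# name, file layout and batch size are illustrative assumptions; the actual
# dev_knn.index / metadata files may have been produced differently (e.g. with
# the clip-retrieval / autofaiss tooling). Relies on the module-level clip_model,
# clip_processor and device defined earlier in this file.
# ------------------------------------------------------------------------------
@torch.no_grad()
def build_retrieval_index_sketch(image_dir, index_path, metadata_path, batch_size=32):
    image_dir = Path(image_dir)
    paths = sorted(p for p in image_dir.rglob("*") if p.suffix.lower() in {".png", ".jpg", ".jpeg"})
    feats = []
    for start in range(0, len(paths), batch_size):
        batch = [Image.open(p).convert("RGB") for p in paths[start:start + batch_size]]
        inputs = clip_processor(images=batch, return_tensors="pt").to(device)
        emb = clip_model.get_image_features(**inputs)
        feats.append(F.normalize(emb, dim=-1).cpu().numpy().astype("float32"))
    if not feats:
        raise ValueError(f"no images found under {image_dir}")
    feats = np.concatenate(feats, axis=0)
    index = faiss.IndexFlatIP(feats.shape[1])  # inner product == cosine similarity on L2-normalized vectors
    index.add(feats)
    faiss.write_index(index, str(index_path))
    # mix_and_search() expects a parquet column named "image_path" holding paths
    # relative to its base_image_dir.
    pd.DataFrame({"image_path": [str(p.relative_to(image_dir)) for p in paths]}).to_parquet(metadata_path)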
if __name__ == "__main__":
process_cirr_images()
# block = gr.Blocks(
# theme=gr.themes.Soft(
# radius_size=gr.themes.sizes.radius_none,
# text_size=gr.themes.sizes.text_md
# )
# )
# with block as demo:
# with gr.Row():
# with gr.Column():
# gr.HTML(head)
# gr.Markdown(descriptions)
# with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
# with gr.Row(equal_height=True):
# gr.Markdown(instructions)
# original_image = gr.State(value=None)
# original_mask = gr.State(value=None)
# category = gr.State(value=None)
# status = gr.State(value=None)
# invert_mask_state = gr.State(value=False)
# example_change_times = gr.State(value=0)
# deepseek_verified = gr.State(value=False)
# blip_description = gr.State(value="")
# enhanced_description = gr.State(value="")
# decomposed_description = gr.State(value="")
# with gr.Row():
# with gr.Column():
# with gr.Group():
# input_image = gr.ImageEditor(
# label="参考图像",
# type="pil",
# brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
# layers = False,
# interactive=True,
# # height=1024,
# height=412,
# sources=["upload"],
# placeholder="🫧 点击此处或下面的图标上传图像 🫧",
# )
# prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期 😜", value="",lines=1)
# with gr.Group():
# mask_button = gr.Button("💎 掩膜生成")
# with gr.Row():
# invert_mask_button = gr.Button("👐 掩膜翻转")
# random_mask_button = gr.Button("⭕️ 随机掩膜")
# with gr.Row():
# masked_gallery = gr.Gallery(label="掩膜图像", show_label=True, preview=True, height=360)
# mask_gallery = gr.Gallery(label="掩膜", show_label=True, preview=True, height=360)
# with gr.Accordion("高级掩膜选项", open=False, elem_id="accordion1"):
# dilation_size = gr.Slider(
# label="每次放缩的尺度: ", show_label=True,minimum=0, maximum=50, step=1, value=20
# )
# with gr.Row():
# dilation_mask_button = gr.Button("放大掩膜")
# erosion_mask_button = gr.Button("缩小掩膜")
# moving_pixels = gr.Slider(
# label="每次移动的像素:", show_label=True, minimum=0, maximum=50, value=4, step=1
# )
# with gr.Row():
# move_left_button = gr.Button("左移")
# move_right_button = gr.Button("右移")
# with gr.Row():
# move_up_button = gr.Button("上移")
# move_down_button = gr.Button("下移")
# with gr.Column():
# with gr.Row():
# deepseek_key = gr.Textbox(label="LLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=2, container=False)
# verify_deepseek = gr.Button("🔑 验证密钥", scale=0)
# blip_output = gr.Textbox(label="1. 原图描述(BLIP生成)", placeholder="🖼️ 上传图片后自动生成图片描述 🖼️", lines=2, interactive=True)
# with gr.Row():
# enhanced_output = gr.Textbox(label="2. 整合增强版", lines=4, interactive=True, placeholder="🚀 点击右侧按钮生成增强描述 🚀")
# enhance_button = gr.Button("✨ 智能整合")
# with gr.Row():
# decomposed_output = gr.Textbox(label="3. 结构分解版", lines=4, interactive=True, placeholder="📝 点击右侧按钮生成结构化描述 📝")
# decompose_button = gr.Button("🔧 结构分解")
# with gr.Group():
# run_button = gr.Button("💫 图像编辑")
# result_gallery = gr.Gallery(label="💥 编辑结果", show_label=True, columns=2, preview=True, height=360)
# with gr.Accordion("高级编辑选项", open=False, elem_id="accordion1"):
# vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
# with gr.Group():
# with gr.Row():
# # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
# GPT4o_KEY = gr.Textbox(label="VLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
# GPT4o_KEY_submit = gr.Button("🔑 验证密钥")
# aspect_ratio = gr.Dropdown(label="输出纵横比", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
# resize_default = gr.Checkbox(label="短边裁剪到640像素", value=True)
# base_model_dropdown = gr.Dropdown(label="基础模型", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
# negative_prompt = gr.Text(label="负向提示", max_lines=5, placeholder="请输入你的负向提示", value='ugly, low quality',lines=1)
# control_strength = gr.Slider(label="控制强度: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01)
# with gr.Group():
# seed = gr.Slider(label="种子: ", minimum=0, maximum=2147483647, step=1, value=648464818)
# randomize_seed = gr.Checkbox(label="随机种子", value=False)
# blending = gr.Checkbox(label="混合模式", value=True)
# num_samples = gr.Slider(label="生成个数", minimum=0, maximum=4, step=1, value=2)
# with gr.Group():
# with gr.Row():
# guidance_scale = gr.Slider(label="指导尺度", minimum=1, maximum=12, step=0.1, value=7.5)
# num_inference_steps = gr.Slider(label="推理步数", minimum=1, maximum=50, step=1, value=50)
# target_prompt = gr.Text(label="Input Target Prompt", max_lines=5, placeholder="VLM-generated target prompt; you can generate it first and then modify it (optional)", value='', lines=2)
# init_type = gr.Textbox(label="Init Name", value="", visible=False)
# example_type = gr.Textbox(label="Example Name", value="", visible=False)
# with gr.Row():
# reset_button = gr.Button("Reset")
# retrieve_button = gr.Button("🔍 开始检索")
# with gr.Row():
# retrieve_gallery = gr.Gallery(label="🎊 检索结果", show_label=True, columns=10, preview=True, height=660)
# with gr.Row():
# example = gr.Examples(
# label="Quick Example",
# examples=EXAMPLES,
# inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
# examples_per_page=10,
# cache_examples=False,
# )
# with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
# with gr.Row(equal_height=True):
# gr.Markdown(tips)
# with gr.Row():
# gr.Markdown(citation)
# ## gr.Examples cannot be used to update the gr.Gallery directly, so the following two callbacks refresh it instead.
# ## They also resolve the conflict between the image-upload handler and the example-change handler.
# input_image.upload(
# init_img,
# [input_image, init_type, prompt, aspect_ratio, example_change_times],
# [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
# )
# example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
# ## vlm and base model dropdown
# vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
# base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
# GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
# ips=[input_image,
# original_image,
# original_mask,
# prompt,
# negative_prompt,
# control_strength,
# seed,
# randomize_seed,
# guidance_scale,
# num_inference_steps,
# num_samples,
# blending,
# category,
# target_prompt,
# resize_default,
# aspect_ratio,
# invert_mask_state]
# ## run brushedit
# run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
# ## mask func
# mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
# random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
# dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
# erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
# invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
# ## reset func
# reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
# input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
# verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
# enhance_button.click(fn=enhance_description, inputs=[blip_output, prompt], outputs=[enhanced_description, enhanced_output])
# decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output])
# retrieve_button.click(fn=mix_and_search, inputs=[enhanced_output, result_gallery], outputs=[retrieve_gallery])
# demo.launch(server_name="0.0.0.0", server_port=12345, share=True)