#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os, random, sys
import numpy as np
import requests
import torch
from pathlib import Path
import pandas as pd
import concurrent.futures
import faiss
import gradio as gr
import json
import logging
from datetime import datetime
import time
from PIL import Image
from torch import nn
from typing import Dict, List, Tuple
import torch.nn.functional as F
from huggingface_hub import hf_hub_download, snapshot_download
from scipy.ndimage import binary_dilation, binary_erosion
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration, Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from app.src.vlm_pipeline import (
vlm_response_editing_type,
vlm_response_object_wait_for_edit,
vlm_response_mask,
vlm_response_prompt_after_apply_instruction
)
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
from app.utils.utils import load_grounding_dino_model
from app.src.vlm_template import vlms_template
from app.src.base_model_template import base_models_template
from app.src.aspect_ratio_template import aspect_ratios
from openai import OpenAI
base_openai_url = "https://api.deepseek.com/"
# Read the DeepSeek API key from the environment rather than hard-coding a secret in source.
base_api_key = os.getenv("DEEPSEEK_API_KEY", "")
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from transformers import CLIPProcessor, CLIPModel
from app.deepseek.instructions import (
create_apply_editing_messages_deepseek,
create_decomposed_query_messages_deepseek
)
from clip_retrieval.clip_client import ClipClient
#### Description ####
logo = r"""
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
"""
head = r"""
<div style="text-align: center; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;">
<h1 style="font-family: 'Georgia', serif; font-weight: 600; letter-spacing: 0.05em; line-height: 1.2;">
Zero-Shot Composed Image Retrieval<br>
<span style="display: inline-block; margin-top: 0.5em;">Based on Diffusion Model Priors and Large Language Models</span>
</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center; margin-top: 1em;">
<a href=''><img src="https://img.shields.io/badge/Project_Page-毕业设计_ZS_CIR-green" alt="Project Page"></a>
<a href=''><img src='https://img.shields.io/badge/Paper-Overleaf-blue'></a>
<a href=''><img src='https://img.shields.io/badge/Code-Github-orange'></a>
</div>
<br>
</div>
"""
descriptions = r"""
<div style="font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; max-width: 800px; margin: 0 auto; color: #333; line-height: 1.6;">
<p style="font-size: 1.1rem; margin: 0 0 1.2rem 0; display: flex; align-items: center;">
<span style="font-size: 1.3rem; margin-right: 0.5rem;">🎨</span>
A training-free interactive system for composed image retrieval: it edits a reference image via text instructions and then performs semantic retrieval.
</p>
</div>
"""
instructions = r"""
<div style="font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; max-width: 800px; margin: 0 auto; color: #333; line-height: 1.6;">
<ol style="padding-left: 1.5rem; margin: 0 0 1.2rem 0;">
<li style="margin-bottom: 0.5rem;"><strong>上传图像</strong>:点击画布或上传按钮添加参考图像</li>
<li style="margin-bottom: 0.5rem;"><strong>输入指令</strong>:在文本框中描述您想对图像进行的修改</li>
<li style="margin-bottom: 0.5rem;"><strong>生成掩膜</strong>:使用掩膜工具精确控制编辑区域</li>
<li style="margin-bottom: 0.5rem;"><strong>智能增强</strong>:系统会自动生成图像描述,并可进一步优化</li>
<li style="margin-bottom: 0.5rem;"><strong>执行编辑</strong>:点击"图像编辑"按钮生成修改后的图像</li>
<li style="margin-bottom: 0.5rem;"><strong>检索结果</strong>:点击"开始检索"获取相似图像结果</li>
</ol>
</div>
"""
tips = r"""
<div style="font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; max-width: 800px; margin: 0 auto; color: #333; line-height: 1.6;">
<h4 style="margin: 1rem 0 0.5rem 0; display: flex; align-items: center;">
<span style="margin-right: 0.5rem;">🖌️</span> Image Editing
</h4>
<ul style="padding-left: 1.5rem; margin: 0 0 1rem 0;">
<li>Brush tool for drawing precise masks</li>
<li>Mask dilation/erosion, inversion, movement, and other operations</li>
<li>Multiple parameters to control the generation</li>
</ul>
<h4 style="margin: 1rem 0 0.5rem 0; display: flex; align-items: center;">
<span style="margin-right: 0.5rem;">🧠</span> Smart Captioning
</h4>
<ul style="padding-left: 1.5rem; margin: 0 0 1rem 0;">
<li>Automatic image captions (BLIP-2)</li>
<li>Instruction enhancement into optimized prompts</li>
<li>Structured decomposition of complex descriptions</li>
</ul>
<h4 style="margin: 1rem 0 0.5rem 0; display: flex; align-items: center;">
<span style="margin-right: 0.5rem;">🔍</span> Advanced Retrieval
</h4>
<ul style="padding-left: 1.5rem; margin: 0 0 1rem 0;">
<li>Zero-shot: no training required</li>
<li>Vision-language model understanding</li>
<li>Multimodal queries (image + text)</li>
</ul>
<h4 style="margin: 1rem 0 0.5rem 0; display: flex; align-items: center;">
<span style="margin-right: 0.5rem;">⚙️</span> Parameter Tuning
</h4>
<ul style="padding-left: 1.5rem; margin: 0 0 1rem 0;">
<li>Adjustable control strength, guidance scale, and more</li>
<li>Choice of multiple base models</li>
<li>Custom output size and aspect ratio</li>
</ul>
<h4 style="margin: 1rem 0 0.5rem 0; display: flex; align-items: center;">
<span style="margin-right: 0.5rem;">💡</span> Usage Tips
</h4>
<ul style="padding-left: 1.5rem; margin: 0;">
<li>Clear, specific instructions yield better results</li>
<li>Careful masking improves edit precision</li>
<li>Try different parameter combinations to optimize results</li>
</ul>
</div>
"""
citation = r"""
"""
# - - - - - examples - - - - - #
EXAMPLES = [
[
Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
"add a magic hat on frog head.",
642087011,
"frog",
"frog",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
"replace the background to ancient China.",
648464818,
"chinese_girl",
"chinese_girl",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
"remove the deer.",
648464818,
"angel_christmas",
"angel_christmas",
False,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
"add a wreath on head.",
648464818,
"sunflower_girl",
"sunflower_girl",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
"add a butterfly fairy.",
648464818,
"girl_on_sun",
"girl_on_sun",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
"remove the christmas hat.",
642087011,
"spider_man_rm",
"spider_man_rm",
False,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
"remove the flower.",
642087011,
"anime_flower",
"anime_flower",
False,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
"replace the clothes to a delicated floral skirt.",
648464818,
"chenduling",
"chenduling",
True,
False,
"GPT4-o (Highly Recommended)"
],
[
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
"make the hedgehog in Italy.",
648464818,
"hedgehog_rp_bg",
"hedgehog_rp_bg",
True,
False,
"GPT4-o (Highly Recommended)"
],
]
INPUT_IMAGE_PATH = {
"frog": "./assets/frog/frog.jpeg",
"chinese_girl": "./assets/chinese_girl/chinese_girl.png",
"angel_christmas": "./assets/angel_christmas/angel_christmas.png",
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
"spider_man_rm": "./assets/spider_man_rm/spider_man.png",
"anime_flower": "./assets/anime_flower/anime_flower.png",
"chenduling": "./assets/chenduling/chengduling.jpg",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
}
MASK_IMAGE_PATH = {
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
}
MASKED_IMAGE_PATH = {
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
}
OUTPUT_IMAGE_PATH = {
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
}
VLM_MODEL_NAMES = list(vlms_template.keys())
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
BASE_MODELS = list(base_models_template.keys())
DEFAULT_BASE_MODEL = "realisticVision (Default)"
ASPECT_RATIO_LABELS = list(aspect_ratios)
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
## init device
try:
if torch.cuda.is_available():
device = "cuda"
elif sys.platform == "darwin" and torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
except Exception:
    device = "cpu"
# ## init torch dtype
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
# torch_dtype = torch.bfloat16
# else:
# torch_dtype = torch.float16
# if device == "mps":
# torch_dtype = torch.float16
torch_dtype = torch.float16
# download hf models
BrushEdit_path = "models/"
if not os.path.exists(BrushEdit_path):
BrushEdit_path = snapshot_download(
repo_id="TencentARC/BrushEdit",
local_dir=BrushEdit_path,
token=os.getenv("HF_TOKEN"),
)
## init default VLM
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
if vlm_processor != "" and vlm_model != "":
vlm_model.to(device)
else:
raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
## init default LLM
llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)
## init base model
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
# input brushnetX ckpt path
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
)
# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# enable the following line if xformers is installed (it is unnecessary with Torch 2.0)
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()
## init SAM
sam = build_sam(checkpoint=sam_path)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
sam_automask_generator = SamAutomaticMaskGenerator(sam)
## init groundingdino_model
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", revision="51572668da0eb669e01a189dc22abe6088589a24")
blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", revision="51572668da0eb669e01a189dc22abe6088589a24", torch_dtype=torch.float16).to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",torch_dtype=torch.float16).to(device)
# clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
# clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14",torch_dtype=torch.float16).to(device)
# General-purpose helpers
def crop_and_resize(image: Image.Image,
target_width: int,
target_height: int) -> Image.Image:
"""
Crops and resizes an image while preserving the aspect ratio.
Args:
image (Image.Image): Input PIL image to be cropped and resized.
target_width (int): Target width of the output image.
target_height (int): Target height of the output image.
Returns:
Image.Image: Cropped and resized image.
"""
# Original dimensions
original_width, original_height = image.size
original_aspect = original_width / original_height
target_aspect = target_width / target_height
# Calculate crop box to maintain aspect ratio
if original_aspect > target_aspect:
# Crop horizontally
new_width = int(original_height * target_aspect)
new_height = original_height
left = (original_width - new_width) / 2
top = 0
right = left + new_width
bottom = original_height
else:
# Crop vertically
new_width = original_width
new_height = int(original_width / target_aspect)
left = 0
top = (original_height - new_height) / 2
right = original_width
bottom = top + new_height
# Crop and resize
cropped_image = image.crop((left, top, right, bottom))
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
return resized_image
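
# Hedged usage sketch for crop_and_resize (illustrative only; the 512x512 target
# below is an arbitrary assumption, not a value this app fixes):
#
#   img = Image.open("./assets/frog/frog.jpeg").convert("RGB")      # e.g. 640x480
#   out = crop_and_resize(img, target_width=512, target_height=512)
#   # 640x480 is wider than 1:1, so the image is center-cropped to 480x480
#   # (left = (640 - 480) / 2 = 80) before being resized to 512x512.
#   assert out.size == (512, 512)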
def resize(image: Image.Image,
target_width: int,
target_height: int) -> Image.Image:
"""
Crops and resizes an image while preserving the aspect ratio.
Args:
image (Image.Image): Input PIL image to be cropped and resized.
target_width (int): Target width of the output image.
target_height (int): Target height of the output image.
Returns:
Image.Image: Cropped and resized image.
"""
# Original dimensions
resized_image = image.resize((target_width, target_height), Image.NEAREST)
return resized_image
def move_mask_func(mask, direction, units):
binary_mask = mask.squeeze()>0
rows, cols = binary_mask.shape
moved_mask = np.zeros_like(binary_mask, dtype=bool)
if direction == 'down':
# move down
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
elif direction == 'up':
# move up
moved_mask[:rows - units, :] = binary_mask[units:, :]
elif direction == 'right':
        # move right
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
elif direction == 'left':
        # move left
moved_mask[:, :cols - units] = binary_mask[:, units:]
return moved_mask
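
# Minimal worked example for move_mask_func (a sketch, not part of the app flow):
#
#   m = np.zeros((4, 4), dtype=np.uint8)
#   m[1, 1] = 255                        # single "on" pixel at row 1, col 1
#   moved = move_mask_func(m, 'down', 2)
#   # `moved` is a boolean array with the pixel shifted to moved[3, 1];
#   # anything shifted past the border is discarded.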
def random_mask_func(mask, dilation_type='square', dilation_size=20):
    # Apply the requested morphological operation to the binary mask
binary_mask = mask.squeeze()>0
if dilation_type == 'square_dilation':
structure = np.ones((dilation_size, dilation_size), dtype=bool)
dilated_mask = binary_dilation(binary_mask, structure=structure)
elif dilation_type == 'square_erosion':
structure = np.ones((dilation_size, dilation_size), dtype=bool)
dilated_mask = binary_erosion(binary_mask, structure=structure)
elif dilation_type == 'bounding_box':
        # find the bounding extremes of the mask pixels
rows, cols = np.where(binary_mask)
if len(rows) == 0 or len(cols) == 0:
return mask # return original mask if no valid points
min_row = np.min(rows)
max_row = np.max(rows)
min_col = np.min(cols)
max_col = np.max(cols)
# create a bounding box
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
elif dilation_type == 'bounding_ellipse':
        # find the bounding extremes of the mask pixels
rows, cols = np.where(binary_mask)
if len(rows) == 0 or len(cols) == 0:
return mask # return original mask if no valid points
min_row = np.min(rows)
max_row = np.max(rows)
min_col = np.min(cols)
max_col = np.max(cols)
# calculate the center and axis length of the ellipse
center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
        a = (max_col - min_col) // 2  # horizontal semi-axis
        b = (max_row - min_row) // 2  # vertical semi-axis
# create a bounding ellipse
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
dilated_mask[ellipse_mask] = True
    else:
        raise ValueError("dilation_type must be one of 'square_dilation', 'square_erosion', 'bounding_box', or 'bounding_ellipse'")
    # convert the boolean mask to uint8 HxWx1 with values in {0, 255}
    dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
return dilated_mask
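
# Worked sketch of the 'bounding_box' branch of random_mask_func (illustrative):
#
#   m = np.zeros((8, 8, 1), dtype=np.uint8)
#   m[2:4, 5:7] = 255                           # small rectangular blob
#   box = random_mask_func(m, 'bounding_box')   # -> uint8 HxWx1 with values {0, 255}
#   # every pixel inside the blob's tight bounding box (rows 2..3, cols 5..6)
#   # becomes 255; all other pixels stay 0.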
## Gradio component function
def update_vlm_model(vlm_name):
global vlm_model, vlm_processor
if vlm_model is not None:
del vlm_model
torch.cuda.empty_cache()
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
    ## We recommend using preloaded models; otherwise the first download can take a long time. You can change this in vlm_template.py.
if vlm_type == "llava-next":
if vlm_processor != "" and vlm_model != "":
vlm_model.to(device)
return vlm_model_dropdown
else:
if os.path.exists(vlm_local_path):
vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
else:
if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llava-v1.6-34b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
elif vlm_name == "llava-next-72b-hf (Preload)":
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
elif vlm_type == "qwen2-vl":
if vlm_processor != "" and vlm_model != "":
vlm_model.to(device)
return vlm_model_dropdown
else:
if os.path.exists(vlm_local_path):
vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
else:
if vlm_name == "qwen2-vl-2b-instruct (Preload)":
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
elif vlm_type == "openai":
pass
return "success"
def update_base_model(base_model_name):
global pipe
    ## We recommend using preloaded models; otherwise the first download can take a long time. You can change this in base_model_template.py.
if pipe is not None:
del pipe
torch.cuda.empty_cache()
base_model_path, pipe = base_models_template[base_model_name]
if pipe != "":
pipe.to(device)
else:
if os.path.exists(base_model_path):
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
)
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()
else:
raise gr.Error(f"The base model {base_model_name} does not exist")
return "success"
def process_random_mask(input_image,
original_image,
original_mask,
resize_default,
aspect_ratio_name,
):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
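
# The resize_default branch above (shared by all the mask operations below)
# rescales so the short side becomes 640 while keeping the aspect ratio.
# Worked example: a 1024x768 input has short side 768, so
# scale_ratio = 640 / 768 ≈ 0.833 and the output is 853x640.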
def process_dilation_mask(input_image,
original_image,
original_mask,
resize_default,
aspect_ratio_name,
dilation_size=20):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
dilation_type = np.random.choice(['square_dilation'])
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
def process_erosion_mask(input_image,
original_image,
original_mask,
resize_default,
aspect_ratio_name,
dilation_size=20):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
dilation_type = np.random.choice(['square_erosion'])
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
def move_mask_left(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_right(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_up(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_down(input_image,
original_image,
original_mask,
moving_pixels,
resize_default,
aspect_ratio_name):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    if input_mask.max() != 0:
        original_mask = input_mask
if original_mask is None:
raise gr.Error('Please generate mask first')
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
if moved_mask.max() <= 1:
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
original_mask = moved_mask
return [masked_image], [mask_image], original_mask.astype(np.uint8)
def invert_mask(input_image,
original_image,
original_mask,
):
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
    if original_mask is None and input_mask.max() == 0:
        raise gr.Error('Please generate mask first')
    if input_mask.max() == 0:
        original_mask = 1 - (original_mask>0).astype(np.uint8)
    else:
        original_mask = 1 - (input_mask>0).astype(np.uint8)
original_mask = original_mask.squeeze()
mask_image = Image.fromarray(original_mask*255).convert("RGB")
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
if original_mask.max() <= 1:
original_mask = (original_mask * 255).astype(np.uint8)
masked_image = original_image * (1 - (original_mask>0))
masked_image = masked_image.astype(original_image.dtype)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], original_mask, True
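
# Worked note on the inversion above (sketch): a uint8 mask with values {0, 255}
# becomes {1, 0} after 1 - (mask > 0) and is rescaled to {255, 0} before being
# returned, so the edited and protected regions swap roles.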
def reset_func(input_image,
original_image,
original_mask,
prompt,
blip2_output,
target_prompt,
):
input_image = None
original_image = None
original_mask = None
prompt = ''
mask_gallery = []
masked_gallery = []
result_gallery = []
blip2_output = ''
target_prompt = ''
if torch.cuda.is_available():
torch.cuda.empty_cache()
return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, blip2_output, target_prompt, True, False
def update_example(example_type,
prompt,
example_change_times):
input_image = INPUT_IMAGE_PATH[example_type]
image_pil = Image.open(input_image).convert("RGB")
mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
width, height = image_pil.size
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
image_pil = image_pil.resize((width_new, height_new))
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
original_image = np.array(image_pil)
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
aspect_ratio = "Custom resolution"
example_change_times += 1
return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
def generate_caption(blip2_processor, blip2_model, input_image, device):
image_pil = input_image["background"].convert("RGB")
inputs = blip2_processor(images=image_pil, return_tensors="pt").to(device, torch.float16)
generated_ids = blip2_model.generate(**inputs)
caption = blip2_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
return caption
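
# Hedged usage sketch for generate_caption; the dict payload mirrors the Gradio
# ImageEditor value (a PIL image under "background"), which is an assumption
# about the editor's format rather than something this file defines:
#
#   payload = {"background": Image.open("./assets/frog/frog.jpeg").convert("RGBA")}
#   cap = generate_caption(blip2_processor, blip2_model, payload, device)
#   print(cap)  # a short BLIP-2 caption, e.g. "a frog sitting on a rock"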
def generate_blip2_description(input_image):
try:
description = generate_caption(blip2_processor, blip2_model, input_image, device)
return description, description
except Exception as e:
return "", f"Caption generation failed: {str(e)}"
def generate_target_prompt(input_image,
original_image,
prompt):
# load example image
if isinstance(original_image, str):
original_image = input_image
image_caption = generate_caption(blip2_processor, blip2_model, input_image, device)
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
vlm_processor,
vlm_model,
llm_model,
original_image,
image_caption,
prompt,
device)
return prompt_after_apply_instruction
def init_img(base,
init_type,
prompt,
aspect_ratio,
example_change_times
):
image_pil = base["background"].convert("RGB")
original_image = np.array(image_pil)
if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
        raise gr.Error('Image aspect ratio cannot be larger than 2.0')
if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
width, height = image_pil.size
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
image_pil = image_pil.resize((width_new, height_new))
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
else:
if aspect_ratio not in ASPECT_RATIO_LABELS:
aspect_ratio = "Custom resolution"
return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
def process_mask(input_image,
original_image,
prompt,
resize_default,
aspect_ratio_name):
if original_image is None:
raise gr.Error('Please upload the input image')
if prompt is None:
raise gr.Error("Please input your instructions, e.g., remove the xxx")
## load mask
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.array(alpha_mask)
# load example image
if isinstance(original_image, str):
original_image = input_image["background"]
image_caption = generate_caption(blip2_processor, blip2_model, input_image, device)
if input_mask.max() == 0:
# category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
category = vlm_response_editing_type(vlm_processor, vlm_model, llm_model, original_image, image_caption, prompt, device)
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
vlm_model,
llm_model,
original_image,
image_caption,
category,
prompt,
device)
# original mask: h,w,1 [0, 255]
original_mask = vlm_response_mask(
vlm_processor,
vlm_model,
category,
original_image,
prompt,
object_wait_for_edit,
sam,
sam_predictor,
sam_automask_generator,
groundingdino_model,
device).astype(np.uint8)
else:
original_mask = input_mask.astype(np.uint8)
category = None
## resize mask if needed
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
masked_image = original_image * (1 - (original_mask>0))
masked_image = masked_image.astype(np.uint8)
masked_image = Image.fromarray(masked_image)
return [masked_image], [mask_image], original_mask.astype(np.uint8), category
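
# Summary of the automatic mask cascade above (used when the user drew no mask):
# the VLM/LLM first classifies the edit type (vlm_response_editing_type), then
# names the object to edit (vlm_response_object_wait_for_edit), and finally
# vlm_response_mask grounds that object with GroundingDINO + SAM, yielding an
# HxWx1 uint8 mask in [0, 255].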
def process(input_image,
original_image,
original_mask,
prompt,
negative_prompt,
control_strength,
seed,
randomize_seed,
guidance_scale,
num_inference_steps,
num_samples,
blending,
category,
target_prompt,
resize_default,
aspect_ratio_name,
invert_mask_state):
if original_image is None:
if input_image is None:
raise gr.Error('Please upload the input image')
else:
image_pil = input_image["background"].convert("RGB")
original_image = np.array(image_pil)
if prompt is None or prompt == "":
if target_prompt is None or target_prompt == "":
raise gr.Error("Please input your instructions, e.g., remove the xxx")
alpha_mask = input_image["layers"][0].split()[3]
input_mask = np.asarray(alpha_mask)
output_w, output_h = aspect_ratios[aspect_ratio_name]
if output_w == "" or output_h == "":
output_h, output_w = original_image.shape[:2]
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
else:
if resize_default:
short_side = min(output_w, output_h)
scale_ratio = 640 / short_side
output_w = int(output_w * scale_ratio)
output_h = int(output_h * scale_ratio)
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
original_image = np.array(original_image)
if input_mask is not None:
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
input_mask = np.array(input_mask)
if original_mask is not None:
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
original_mask = np.array(original_mask)
    if not invert_mask_state and input_mask.max() != 0:
        original_mask = input_mask
image_caption = generate_caption(blip2_processor, blip2_model, input_image, device)
    # skip VLM classification when a category or a non-empty target prompt is already available
if category is not None:
pass
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
pass
else:
try:
# category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
category = vlm_response_editing_type(vlm_processor, vlm_model, llm_model, original_image, image_caption, prompt, device)
print(category)
except Exception as e:
raise gr.Error("Time1. Please select the correct VLM model and input the correct API Key first!")
if original_mask is not None:
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
else:
try:
object_wait_for_edit = vlm_response_object_wait_for_edit(
vlm_processor,
vlm_model,
llm_model,
original_image,
image_caption,
category,
prompt,
device)
print(object_wait_for_edit)
original_mask = vlm_response_mask(vlm_processor,
vlm_model,
category,
original_image,
prompt,
object_wait_for_edit,
sam,
sam_predictor,
sam_automask_generator,
groundingdino_model,
device).astype(np.uint8)
print("Got genarated mask!")
except Exception as e:
raise gr.Error("Time2. Please select the correct VLM model and input the correct API Key first!")
if original_mask.ndim == 2:
original_mask = original_mask[:,:,None]
if target_prompt is not None and len(target_prompt) >= 1:
prompt_after_apply_instruction = target_prompt
else:
try:
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
vlm_processor,
vlm_model,
llm_model,
original_image,
image_caption,
prompt,
device)
except Exception as e:
raise gr.Error("Time3. Please select the correct VLM model and input the correct API Key first!")
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
with torch.autocast(device):
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
prompt_after_apply_instruction,
original_mask,
original_image,
generator,
num_inference_steps,
guidance_scale,
control_strength,
negative_prompt,
num_samples,
blending)
original_image = np.array(init_image_np)
masked_image = original_image * (1 - (mask_np>0))
masked_image = masked_image.astype(np.uint8)
masked_image = Image.fromarray(masked_image)
return image, [mask_image], [masked_image], prompt, prompt_after_apply_instruction, False
def submit_GPT4o_KEY(GPT4o_KEY):
    # Note: despite the function name, the key is validated against the DeepSeek endpoint below.
    global vlm_model, vlm_processor
if vlm_model is not None:
del vlm_model
torch.cuda.empty_cache()
try:
vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
vlm_processor = ""
response = vlm_model.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello."}
]
)
response_str = response.choices[0].message.content
return "Success. " + response_str, "GPT4-o (Highly Recommended)"
except Exception as e:
return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
def verify_deepseek_api():
try:
response = llm_model.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello."}
]
)
response_str = response.choices[0].message.content
return True, "Success. " + response_str
except Exception as e:
return False, "Invalid DeepSeek API Key"
def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
try:
messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
response = llm_model.chat.completions.create(
model="deepseek-chat",
messages=messages
)
response_str = response.choices[0].message.content
return response_str
except Exception as e:
raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
def llm_decomposed_prompt_after_apply_instruction(integrated_query):
try:
messages = create_decomposed_query_messages_deepseek(integrated_query)
response = llm_model.chat.completions.create(
model="deepseek-chat",
messages=messages
)
response_str = response.choices[0].message.content
return response_str
except Exception as e:
raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
def enhance_description(blip2_description, prompt):
try:
if not prompt or not blip2_description:
print("Empty prompt or blip2_description detected")
return "", ""
print(f"Enhancing with prompt: {prompt}")
enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip2_description, prompt)
return enhanced_description, enhanced_description
except Exception as e:
print(f"Enhancement failed: {str(e)}")
return "Error occurred", "Error occurred"
def decompose_description(enhanced_description):
try:
if not enhanced_description:
print("Empty enhanced_description detected")
return "", ""
print(f"Decomposing the enhanced description: {enhanced_description}")
decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
return decomposed_description, decomposed_description
except Exception as e:
print(f"Decomposition failed: {str(e)}")
return "Error occurred", "Error occurred"
@torch.no_grad()
def mix_and_search(enhanced_text: str, gallery_images: list):
    # Grab the most recently generated image tuple
latest_item = gallery_images[-1] if gallery_images else None
    # Initialize the feature list
features = []
    # Image feature extraction
if latest_item and isinstance(latest_item, tuple):
try:
image_path = latest_item[0]
pil_image = Image.open(image_path).convert("RGB")
            # Encode the image with CLIPProcessor
image_inputs = clip_processor(
images=pil_image,
return_tensors="pt"
).to(device)
image_features = clip_model.get_image_features(**image_inputs)
features.append(F.normalize(image_features, dim=-1))
except Exception as e:
print(f"图像处理失败: {str(e)}")
    # Text feature extraction
if enhanced_text.strip():
text_inputs = clip_processor(
text=enhanced_text,
return_tensors="pt",
padding=True,
truncation=True
).to(device)
text_features = clip_model.get_text_features(**text_inputs)
features.append(F.normalize(text_features, dim=-1))
if not features:
return []
    # Feature fusion
    if len(features) == 2:
        # Both image and text features present: weighted sum, 40% image + 60% text
        mixed = 0.4 * features[0] + 0.6 * features[1]
    else:
        # Single feature: plain sum (keeps the original behaviour)
        mixed = sum(features)
mixed = F.normalize(mixed, dim=-1)
    # Load the Faiss index and the image-path mapping
    index_path = "/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_knn.index"
    input_data_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_embedding_folder/metadata")
    base_image_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/")
    # Sort the parquet files by the number in their filenames and read them directly
parquet_files = sorted(
input_data_dir.glob('*.parquet'),
key=lambda x: int(x.stem.split("_")[-1])
)
    # Concatenate all parquet data
    dfs = [pd.read_parquet(file) for file in parquet_files]  # read inline
df = pd.concat(dfs, ignore_index=True)
image_paths = df["image_path"].tolist()
index = faiss.read_index(index_path)
    assert mixed.shape[1] == index.d, "Feature dimension does not match the index"
mixed = mixed.cpu().detach().numpy().astype('float32')
distances, indices = index.search(mixed, 50)
    # Fetch and validate the retrieved image paths
retrieved_images = []
for idx in indices[0]:
if 0 <= idx < len(image_paths):
img_path = base_image_dir / image_paths[idx]
try:
if img_path.exists():
retrieved_images.append(Image.open(img_path).convert("RGB"))
                else:
                    print(f"Warning: missing file {img_path}")
            except Exception as e:
                print(f"Failed to load image: {str(e)}")
    return retrieved_images
# Open question: in cap_file, each image appears as the "reference" of multiple entries, and what tells those entries apart is their "caption". How should the save logic be adjusted so that different captions for the same reference are kept separately instead of overwriting each other?
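# One possible answer, sketched below: derive the output file stem from both the
# reference and something unique to each entry. The live implementation further
# down uses the dataset's "pairid" field for this; when no such id exists, a
# short hash of the caption works as a fallback. The helper name and behavior
# here are illustrative only; it is not wired into the pipeline.
def unique_output_stem(reference: str, caption: str, pairid=None) -> str:
    """Return a filename stem that is unique per (reference, caption) pair."""
    import hashlib  # local import keeps this illustrative helper self-contained
    if pairid is not None:
        return f"{reference}_{pairid}"
    # Fall back to an 8-character caption digest when no pairid is available
    digest = hashlib.sha1(caption.encode("utf-8")).hexdigest()[:8]
    return f"{reference}_{digest}"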
# def process_cirr_images():
# if not all([vlm_model, sam_predictor, groundingdino_model]):
# raise RuntimeError("Required models not initialized")
# # Define paths
# dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev")
# cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json")
# output_dirs = {
# "edited": Path("/home/zt/data/BrushEdit/cirr/img_paint/cirr_edited"),
# "mask": Path("/home/zt/data/BrushEdit/cirr/img_paint/cirr_mask"),
# "masked": Path("/home/zt/data/BrushEdit/cirr/img_paint/cirr_masked")
# }
# output_json_path = Path("/home/zt/data/BrushEdit/cirr/image_paint.json")
# descriptions = {}
# # Create output directories
# for dir_path in output_dirs.values():
# dir_path.mkdir(parents=True, exist_ok=True)
# # Load captions
# with open(cap_file, 'r') as f:
# captions = json.load(f)
# for img_path in dev_dir.glob("*.png"):
# base_name = img_path.stem
# caption = next((item["caption"] for item in captions if item.get("reference") == base_name), None)
# if not caption:
# print(f"Warning: No caption for {base_name}")
# continue
# try:
# # Build an empty (all-zero) alpha channel
# rgb_image = Image.open(img_path).convert("RGB")
# empty_alpha = Image.new("L", rgb_image.size, 0) # fully transparent alpha channel
# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
# # Initialize via init_img
# base = {"background": image, "layers": [image]}
# init_results = init_img(
# base=base,
# init_type="custom", # 使用自定义初始化
# prompt=caption,
# aspect_ratio="Custom resolution",
# example_change_times=0
# )
# # Unpack the initialized parameters
# input_image = init_results[0]
# original_image = init_results[1]
# original_mask = init_results[2]
# # Set the process() arguments correctly
# result_images, mask_images, masked_images, _, target_description, _ = process(
# input_image=input_image,
# original_image=original_image,
# original_mask=original_mask, # pass the initialized mask
# prompt=caption,
# negative_prompt="ugly, low quality",
# control_strength=1.0,
# seed=648464818,
# randomize_seed=False,
# guidance_scale=7.5,
# num_inference_steps=50,
# num_samples=1,
# blending=True,
# category=None,
# target_prompt="",
# resize_default=True,
# aspect_ratio_name="Custom resolution",
# invert_mask_state=False
# )
# # Save images
# output_dirs["edited"].mkdir(exist_ok=True)
# result_images[0].save(output_dirs["edited"] / f"{base_name}.png")
# mask_images[0].save(output_dirs["mask"] / f"{base_name}_mask.png")
# masked_images[0].save(output_dirs["masked"] / f"{base_name}_masked.png")
# # Generate BLIP2 description
# blip2_desc, _ = generate_blip2_description(input_image)
# descriptions[base_name] = {
# "original_caption": caption,
# "blip2_description": blip2_desc,
# "llm_enhanced_caption": target_description
# }
# with open(output_json_path, 'w') as f:
# json.dump(descriptions, f, indent=4) # indent keeps it readable
# print(f"Processed {base_name}")
# except Exception as e:
# print(f"Error processing {base_name}: {str(e)}")
# continue
# print("Processing completed!")
# CIRR val split. This function currently looks essentially fine.
# def process_cirr_images():
# if not all([vlm_model, sam_predictor, groundingdino_model]):
# raise RuntimeError("Required models not initialized")
# # Define paths
# # dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev")
# # cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json")
# # output_dirs = {
# # "edited": Path("/home/zt/data/BrushEdit/cirr/img_paint_pairid/cirr_edited"),
# # "mask": Path("/home/zt/data/BrushEdit/cirr/img_paint_pairid/cirr_mask"),
# # "masked": Path("/home/zt/data/BrushEdit/cirr/img_paint_pairid/cirr_masked")
# # }
# # output_json_path = Path("/home/zt/data/BrushEdit/cirr/image_paint_pairid.json")
# dev_dir = Path("/home/zt/data/BrushEdit/CIRR/dev")
# cap_file = Path("/home/zt/data/BrushEdit/CIRR/cirr/captions/val_deepseek_missed_174.json")
# output_dirs = {
# "edited": Path("/home/zt/data/BrushEdit/CIRR/img_paint_pairid/missed/cirr_edited"),
# "mask": Path("/home/zt/data/BrushEdit/CIRR/img_paint_pairid/missed/cirr_mask"),
# "masked": Path("/home/zt/data/BrushEdit/CIRR/img_paint_pairid/missed/cirr_masked")
# }
# output_json_path = Path("/home/zt/data/BrushEdit/CIRR/cirr/image_paint_deepseek_missed_pairid.json")
# descriptions = {}
# # Create output directories
# for dir_path in output_dirs.values():
# dir_path.mkdir(parents=True, exist_ok=True)
# # Load captions
# with open(cap_file, 'r') as f:
# captions = json.load(f)
# for img_path in dev_dir.glob("*.png"):
# base_name = img_path.stem
# # Collect every caption entry matching this reference
# matched_items = [item for item in captions if item.get("reference") == base_name]
# if not matched_items:
# print(f"Warning: No captions for {base_name}")
# continue
# for item in matched_items:
# # Validate that the required fields exist
# pairid = item.get("pairid")
# caption = item.get("caption")
# if not all([pairid, caption]):
# print(f"Skipping invalid item for {base_name}: {item}")
# continue
# # Build a unique identifier from the pairid
# processed_base = f"{base_name}_{pairid}"
# try:
# # Build an empty alpha channel
# rgb_image = Image.open(img_path).convert("RGB")
# empty_alpha = Image.new("L", rgb_image.size, 0)
# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
# # Initialize the image
# base = {"background": image, "layers": [image]}
# init_results = init_img(
# base=base,
# init_type="custom",
# prompt=caption,
# aspect_ratio="Custom resolution",
# example_change_times=0
# )
# # Unpack the processing parameters
# input_image = init_results[0]
# original_image = init_results[1]
# original_mask = init_results[2]
# # Set the process() arguments correctly
# result_images, mask_images, masked_images, _, target_description, _ = process(
# input_image=input_image,
# original_image=original_image,
# original_mask=original_mask, # pass the initialized mask
# prompt=caption,
# negative_prompt="ugly, low quality",
# control_strength=1.0,
# seed=648464818,
# randomize_seed=False,
# guidance_scale=7.5,
# num_inference_steps=50,
# num_samples=1,
# blending=True,
# category=None,
# target_prompt="",
# resize_default=True,
# aspect_ratio_name="Custom resolution",
# invert_mask_state=False
# )
# # Save the files (named by pairid)
# result_images[0].save(output_dirs["edited"] / f"{processed_base}.png")
# mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.png")
# masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.png")
# # Generate the description
# blip2_desc, _ = generate_blip2_description(input_image)
# # Store the metadata keyed by pairid
# descriptions[pairid] = {
# "reference": base_name,
# "user_editing_prompt": caption,
# "blip2_description": blip2_desc,
# "llm_enhanced_caption": target_description,
# "processed_files": {
# "edited": f"{processed_base}.png",
# "mask": f"{processed_base}_mask.png",
# "masked": f"{processed_base}_masked.png"
# }
# }
# print(f"Processed {processed_base}")
# except Exception as e:
# print(f"Error processing {processed_base}: {str(e)}")
# continue
# # Save the metadata
# with open(output_json_path, 'w') as f:
# json.dump(descriptions, f, indent=4)
# print("Processing completed!")
# CIRR test1 split.
# def process_cirr_images():
# if not all([vlm_model, sam_predictor, groundingdino_model]):
# raise RuntimeError("Required models not initialized")
# # Define paths
# dev_dir = Path("/home/zt/data/BrushEdit/CIRR/test1")
# cap_file = Path("/home/zt/data/BrushEdit/CIRR/cirr/captions/cap.rc2.test1.json")
# output_dirs = {
# "edited": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/cirr_edited"),
# "mask": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/cirr_mask"),
# "masked": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/cirr_masked")
# }
# output_json_path = Path("/home/zt/data/BrushEdit/CIRR/test1_image_paint_pairid.json")
# descriptions = {}
# # Create output directories
# for dir_path in output_dirs.values():
# dir_path.mkdir(parents=True, exist_ok=True)
# # Load captions
# with open(cap_file, 'r') as f:
# captions = json.load(f)
# for img_path in dev_dir.glob("*.png"):
# base_name = img_path.stem
# # Collect every caption entry matching this reference
# matched_items = [item for item in captions if item.get("reference") == base_name]
# if not matched_items:
# print(f"Warning: No captions for {base_name}")
# continue
# for item in matched_items:
# # Validate that the required fields exist
# pairid = item.get("pairid")
# caption = item.get("caption")
# if not all([pairid, caption]):
# print(f"Skipping invalid item for {base_name}: {item}")
# continue
# # Build a unique identifier from the pairid
# processed_base = f"{base_name}_{pairid}"
# try:
# # Build an empty alpha channel
# rgb_image = Image.open(img_path).convert("RGB")
# empty_alpha = Image.new("L", rgb_image.size, 0)
# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
# # Initialize the image
# base = {"background": image, "layers": [image]}
# init_results = init_img(
# base=base,
# init_type="custom",
# prompt=caption,
# aspect_ratio="Custom resolution",
# example_change_times=0
# )
# # Unpack the processing parameters
# input_image = init_results[0]
# original_image = init_results[1]
# original_mask = init_results[2]
# # Set the process() arguments correctly
# result_images, mask_images, masked_images, _, target_description, _ = process(
# input_image=input_image,
# original_image=original_image,
# original_mask=original_mask, # pass the initialized mask
# prompt=caption,
# negative_prompt="ugly, low quality",
# control_strength=1.0,
# seed=648464818,
# randomize_seed=False,
# guidance_scale=7.5,
# num_inference_steps=50,
# num_samples=1,
# blending=True,
# category=None,
# target_prompt="",
# resize_default=True,
# aspect_ratio_name="Custom resolution",
# invert_mask_state=False
# )
# # Save the files (named by pairid)
# result_images[0].save(output_dirs["edited"] / f"{processed_base}.png")
# mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.png")
# masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.png")
# # Generate the description
# blip2_desc, _ = generate_blip2_description(input_image)
# # Store the metadata keyed by pairid
# descriptions[pairid] = {
# "reference": base_name,
# "user_editing_prompt": caption,
# "blip2_description": blip2_desc,
# "llm_enhanced_caption": target_description,
# "processed_files": {
# "edited": f"{processed_base}.png",
# "mask": f"{processed_base}_mask.png",
# "masked": f"{processed_base}_masked.png"
# }
# }
# print(f"Processed {processed_base}")
# except Exception as e:
# print(f"Error processing {processed_base}: {str(e)}")
# continue
# # Save the metadata
# with open(output_json_path, 'w') as f:
# json.dump(descriptions, f, indent=4)
# print("Processing completed!")
def process_cirr_images():
if not all([vlm_model, sam_predictor, groundingdino_model]):
raise RuntimeError("Required models not initialized")
# Define paths
dev_dir = Path("/home/zt/data/BrushEdit/CIRR/test1")
cap_file = Path("/home/zt/data/BrushEdit/CIRR/cirr/captions/cap.rc2.test1.json")
output_dirs = {
"edited": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/qw_cirr_edited"),
"mask": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/qw_cirr_mask"),
"masked": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/qw_cirr_masked")
}
output_json_path = Path("/home/zt/data/BrushEdit/CIRR/qw_test1_image_paint_pairid.json")
# Create output directories
for dir_path in output_dirs.values():
dir_path.mkdir(parents=True, exist_ok=True)
    # 1. Load any existing results so an interrupted run can resume
processed_pairids = set()
if output_json_path.exists():
try:
with open(output_json_path, 'r') as f:
descriptions = json.load(f)
processed_pairids = set(descriptions.keys())
print(f"Loaded {len(processed_pairids)} previously processed pairids")
except Exception as e:
print(f"Error loading existing results: {str(e)}, starting from scratch")
descriptions = {}
else:
descriptions = {}
    # 2. Temp-file path used for atomic writes
temp_json_path = output_json_path.with_suffix(".tmp")
# Load captions
with open(cap_file, 'r') as f:
captions = json.load(f)
    # 3. Progress tracking
total = 0
processed = 0
for img_path in dev_dir.glob("*.png"):
base_name = img_path.stem
matched_items = [item for item in captions if item.get("reference") == base_name]
if not matched_items:
print(f"Warning: No captions for {base_name}")
continue
for item in matched_items:
total += 1
            pairid = item.get("pairid")
            caption = item.get("caption")
            if not all([pairid, caption]):
                print(f"Skipping invalid item for {base_name}: {item}")
                continue
            pairid = str(pairid)  # ensure a string key; note str(None) would wrongly pass the check above
            # 4. Skip pairids that were already processed
if pairid in processed_pairids:
print(f"Skipping already processed pairid: {pairid}")
continue
processed_base = f"{base_name}_{pairid}"
print(f"Processing {processed_base} ({processed+1}/{total})")
try:
                # Build an empty alpha channel
                rgb_image = Image.open(img_path).convert("RGB")
                empty_alpha = Image.new("L", rgb_image.size, 0)
                image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
                # Initialize the image
base = {"background": image, "layers": [image]}
init_results = init_img(
base=base,
init_type="custom",
prompt=caption,
aspect_ratio="Custom resolution",
example_change_times=0
)
                # Unpack the processing parameters
input_image = init_results[0]
original_image = init_results[1]
original_mask = init_results[2]
                # Set the process() arguments correctly
result_images, mask_images, masked_images, _, target_description, _ = process(
input_image=input_image,
original_image=original_image,
                    original_mask=original_mask,  # pass the initialized mask
prompt=caption,
negative_prompt="ugly, low quality",
control_strength=1.0,
seed=648464818,
randomize_seed=False,
guidance_scale=7.5,
num_inference_steps=50,
num_samples=1,
blending=True,
category=None,
target_prompt="",
resize_default=True,
aspect_ratio_name="Custom resolution",
invert_mask_state=False
)
                # Save the files (named by pairid)
result_images[0].save(output_dirs["edited"] / f"{processed_base}.png")
mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.png")
masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.png")
                # Generate the description
                blip2_desc, _ = generate_blip2_description(input_image)
                # Update the metadata, keyed by pairid
descriptions[pairid] = {
"reference": base_name,
"user_editing_prompt": caption,
"blip2_description": blip2_desc,
"llm_enhanced_caption": target_description,
"processed_files": {
"edited": f"{processed_base}.png",
"mask": f"{processed_base}_mask.png",
"masked": f"{processed_base}_masked.png"
}
}
                # 5. Atomic write: write to the temp file first, then replace the original
with open(temp_json_path, 'w') as f:
json.dump(descriptions, f, indent=4)
temp_json_path.replace(output_json_path)
                processed += 1
processed_pairids.add(pairid)
print(f"Successfully processed {pairid}")
            except Exception as e:
                print(f"Error processing {pairid}: {str(e)}")
                # Remove any partially written files so a resumed run does not see stale output
                for out_key, suffix in [("edited", ".png"), ("mask", "_mask.png"), ("masked", "_masked.png")]:
                    incomplete_file = output_dirs[out_key] / f"{processed_base}{suffix}"
                    if incomplete_file.exists():
                        incomplete_file.unlink()
                continue
print(f"Processing completed! Total processed: {processed}/{total}")
# CIRCO val split. This function currently looks essentially fine.
def process_circo_val_images():
if not all([vlm_model, sam_predictor, groundingdino_model]):
raise RuntimeError("Required models not initialized")
# Define paths
    # This directory holds not only the dev-split images; all images (including the test set and unused ones) live here
dev_dir = Path("/home/zt/data/BrushEdit/CIRCO/COCO2017_unlabeled/unlabeled2017")
cap_file = Path("/home/zt/data/BrushEdit/CIRCO/annotations/val.json")
output_dirs = {
"edited": Path("/home/zt/data/BrushEdit/CIRCO/img_paint_pairid/circo_edited"),
"mask": Path("/home/zt/data/BrushEdit/CIRCO/img_paint_pairid/circo_mask"),
"masked": Path("/home/zt/data/BrushEdit/CIRCO/img_paint_pairid/circo_masked")
}
output_json_path = Path("/home/zt/data/BrushEdit/CIRCO/image_paint_pairid.json")
descriptions = {}
# Create output directories
for dir_path in output_dirs.values():
dir_path.mkdir(parents=True, exist_ok=True)
# Load captions
with open(cap_file, 'r') as f:
captions = json.load(f)
for img_path in dev_dir.glob("*.jpg"):
        # stem, i.e. the filename without its extension
        base_name = img_path.stem
        # Take the last six characters as the reference ID (reference_part is a str)
        reference_part = base_name[-6:]
        # reference_img_id is an int in the JSON, so convert to str before comparing
        matched_items = [item for item in captions if str(item.get("reference_img_id")) == reference_part]
if not matched_items:
print(f"Warning: No captions for {base_name}")
continue
for item in matched_items:
            # Validate that the required fields exist (an id of 0 is valid, so test for None explicitly)
            pairid = item.get("id")
            caption = item.get("relative_caption")
            if pairid is None or not caption:
                print(f"Skipping invalid item for {base_name}: {item}")
                continue
            # Build a unique identifier from the pairid
            # (f-string formatting converts non-string values to strings automatically)
processed_base = f"{reference_part}_{pairid}"
try:
                # Build an empty alpha channel
                rgb_image = Image.open(img_path).convert("RGB")
                empty_alpha = Image.new("L", rgb_image.size, 0)
                image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
                # Initialize the image
base = {"background": image, "layers": [image]}
init_results = init_img(
base=base,
init_type="custom",
prompt=caption,
aspect_ratio="Custom resolution",
example_change_times=0
)
                # Unpack the processing parameters
input_image = init_results[0]
original_image = init_results[1]
original_mask = init_results[2]
                # Set the process() arguments correctly
result_images, mask_images, masked_images, _, target_description, _ = process(
input_image=input_image,
original_image=original_image,
                    original_mask=original_mask,  # pass the initialized mask
prompt=caption,
negative_prompt="ugly, low quality",
control_strength=1.0,
seed=648464818,
randomize_seed=False,
guidance_scale=7.5,
num_inference_steps=50,
num_samples=1,
blending=True,
category=None,
target_prompt="",
resize_default=True,
aspect_ratio_name="Custom resolution",
invert_mask_state=False
)
                # Save the files (named by pairid)
result_images[0].save(output_dirs["edited"] / f"{processed_base}.jpg")
mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.jpg")
masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.jpg")
                # Generate the description
                blip2_desc, _ = generate_blip2_description(input_image)
                # Store the metadata keyed by pairid
descriptions[pairid] = {
"reference": int(reference_part),
"user_editing_prompt": caption,
"blip2_description": blip2_desc,
"llm_enhanced_caption": target_description,
"processed_files": {
"edited": f"{processed_base}.jpg",
"mask": f"{processed_base}_mask.jpg",
"masked": f"{processed_base}_masked.jpg"
}
}
print(f"Processed {processed_base}")
except Exception as e:
print(f"Error processing {processed_base}: {str(e)}")
continue
    # Save the metadata
with open(output_json_path, 'w') as f:
json.dump(descriptions, f, indent=4)
print("Processing completed!")
# CIRCO test split. This function currently looks essentially fine.
def process_circo_test_images():
if not all([vlm_model, sam_predictor, groundingdino_model]):
raise RuntimeError("Required models not initialized")
# Define paths
    # This directory holds not only the dev-split images; all images (including the test set and unused ones) live here
dev_dir = Path("/home/zt/data/BrushEdit/CIRCO/COCO2017_unlabeled/unlabeled2017")
cap_file = Path("/home/zt/data/BrushEdit/CIRCO/annotations/test.json")
output_dirs = {
"edited": Path("/home/zt/data/BrushEdit/CIRCO/test_img_paint_pairid/circo_edited"),
"mask": Path("/home/zt/data/BrushEdit/CIRCO/test_img_paint_pairid/circo_mask"),
"masked": Path("/home/zt/data/BrushEdit/CIRCO/test_img_paint_pairid/circo_masked")
}
output_json_path = Path("/home/zt/data/BrushEdit/CIRCO/test_image_paint_pairid.json")
descriptions = {}
# Create output directories
for dir_path in output_dirs.values():
dir_path.mkdir(parents=True, exist_ok=True)
# Load captions
with open(cap_file, 'r') as f:
captions = json.load(f)
for img_path in dev_dir.glob("*.jpg"):
        # stem, i.e. the filename without its extension
        base_name = img_path.stem
        # Take the last six characters as the reference ID (reference_part is a str)
        reference_part = base_name[-6:]
        # reference_img_id is an int in the JSON, so convert to str before comparing
        matched_items = [item for item in captions if str(item.get("reference_img_id")) == reference_part]
if not matched_items:
print(f"Warning: No captions for {base_name}")
continue
for item in matched_items:
            # Validate that the required fields exist (an id of 0 is valid, so test for None explicitly)
            pairid = item.get("id")
            caption = item.get("relative_caption")
            if pairid is None or not caption:
                print(f"Skipping invalid item for {base_name}: {item}")
                continue
            # Build a unique identifier from the pairid
            # (f-string formatting converts non-string values to strings automatically)
processed_base = f"{reference_part}_{pairid}"
try:
                # Build an empty alpha channel
                rgb_image = Image.open(img_path).convert("RGB")
                empty_alpha = Image.new("L", rgb_image.size, 0)
                image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
                # Initialize the image
base = {"background": image, "layers": [image]}
init_results = init_img(
base=base,
init_type="custom",
prompt=caption,
aspect_ratio="Custom resolution",
example_change_times=0
)
                # Unpack the processing parameters
input_image = init_results[0]
original_image = init_results[1]
original_mask = init_results[2]
                # Set the process() arguments correctly
result_images, mask_images, masked_images, _, target_description, _ = process(
input_image=input_image,
original_image=original_image,
                    original_mask=original_mask,  # pass the initialized mask
prompt=caption,
negative_prompt="ugly, low quality",
control_strength=1.0,
seed=648464818,
randomize_seed=False,
guidance_scale=7.5,
num_inference_steps=50,
num_samples=1,
blending=True,
category=None,
target_prompt="",
resize_default=True,
aspect_ratio_name="Custom resolution",
invert_mask_state=False
)
                # Save the files (named by pairid)
result_images[0].save(output_dirs["edited"] / f"{processed_base}.jpg")
mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.jpg")
masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.jpg")
                # Generate the description
                blip2_desc, _ = generate_blip2_description(input_image)
                # Store the metadata keyed by pairid
descriptions[pairid] = {
"reference": int(reference_part),
"user_editing_prompt": caption,
"blip2_description": blip2_desc,
"llm_enhanced_caption": target_description,
"processed_files": {
"edited": f"{processed_base}.jpg",
"mask": f"{processed_base}_mask.jpg",
"masked": f"{processed_base}_masked.jpg"
}
}
print(f"Processed {processed_base}")
except Exception as e:
print(f"Error processing {processed_base}: {str(e)}")
continue
    # Save the metadata
with open(output_json_path, 'w') as f:
json.dump(descriptions, f, indent=4)
print("Processing completed!")
# This function currently looks essentially fine.
@torch.no_grad()
def batch_mix_and_search_cirr(
json_path: str = "/home/zt/data/BrushEdit/cirr/image_paint_pairid.json",
image_dir: str = "/home/zt/data/BrushEdit/cirr/img_paint/cirr_edited",
alpha: float = 0.8,
batch_size: int = 32,
output_json_path: str = "retrieval_results_pairid.json"
) -> Dict[str, List[Dict]]:
    # Load the index and metadata
index = faiss.read_index("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_knn.index")
metadata = pd.read_parquet("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_embedding_folder/metadata/metadata_0.parquet")
all_index_ids = metadata["image_path"].tolist()
    # Load and validate the input data
with open(json_path) as f:
samples = json.load(f)
valid_samples = []
image_dir = Path(image_dir)
for pair_id, sample_info in samples.items():
reference = sample_info["reference"]
img_path = image_dir / f"{reference}.png"
if img_path.exists():
valid_samples.append((
pair_id,
reference,
img_path,
sample_info['llm_enhanced_caption'],
sample_info['user_editing_prompt']
))
    # Initialize the results dict
results = {}
total_samples = len(valid_samples)
    # Process in batches
for batch_idx in range(0, total_samples, batch_size):
batch_end = min(batch_idx + batch_size, total_samples)
current_batch = valid_samples[batch_idx:batch_end]
        # Batch-process the images (same logic as before)
batch_images = [Image.open(s[2]).convert("RGB") for s in current_batch]
image_inputs = clip_processor(images=batch_images, return_tensors="pt").to(device)
image_features = clip_model.get_image_features(**image_inputs)
image_features = nn.functional.normalize(image_features, dim=-1)
        # Batch-process the texts (same logic as before)
batch_texts = [s[3] for s in current_batch]
text_inputs = clip_processor(
text=batch_texts,
return_tensors="pt",
padding=True,
truncation=True
).to(device)
text_features = clip_model.get_text_features(**text_inputs)
text_features = nn.functional.normalize(text_features, dim=-1)
        # Fuse the features
mixed_features = (1 - alpha) * image_features + alpha * text_features
mixed_features = nn.functional.normalize(mixed_features, dim=-1)
        # Batch retrieval
query_features = mixed_features.cpu().numpy().astype("float32")
distances, indices = index.search(query_features, 100)
        # Store this batch's results
for (pair_id, reference, _, enhanced_cap, editing_prompt), dist_row, idx_row in zip(current_batch, distances, indices):
results[pair_id] = {
"reference": reference,
"llm_enhanced_caption": enhanced_cap,
"user_editing_prompt": editing_prompt,
"retrieved_results": []
}
for distance, idx in zip(dist_row, idx_row):
if 0 <= idx < len(all_index_ids):
raw_id = all_index_ids[idx]
base_name = os.path.basename(raw_id)
file_name = os.path.splitext(base_name)[0]
results[pair_id]["retrieved_results"].append({
"retrieved_id": file_name,
"score": float(distance),
})
    # Save the results to JSON
with open(output_json_path, 'w') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print("Retrieving completed!")
return results
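# Both retrieval entry points implement the same convex fusion of CLIP features:
# mix_and_search hard-codes 0.4/0.6 weights, while batch_mix_and_search_cirr
# exposes alpha. A minimal shared helper, assuming unit-normalized inputs;
# illustrative only, not wired into either function:
def fuse_clip_features(image_features, text_features, alpha: float = 0.6):
    """Convex combination of CLIP image/text features, re-normalized to unit length."""
    mixed = (1 - alpha) * image_features + alpha * text_features
    return F.normalize(mixed, dim=-1)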
# @torch.no_grad()
# def batch_mix_and_search_cirr(
# json_path: str = "/home/zt/data/BrushEdit/cirr/image_paint_pairid.json",
# image_dir: str = "/home/zt/data/BrushEdit/cirr/img_paint/cirr_edited",
# alpha: float = 0.8,
# batch_size: int = 32,
# output_json_path: str = "retrieval_results_pairid.json"
# ) -> Dict[str, List[Dict]]:
# # Load the index and metadata
# index = faiss.read_index("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_knn.index")
# metadata = pd.read_parquet("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_embedding_folder/metadata/metadata_0.parquet")
# all_index_ids = metadata["image_path"].tolist()
# # Load and validate the input data
# with open(json_path) as f:
# samples = json.load(f)
# valid_samples = []
# image_dir = Path(image_dir)
# for image_id, sample_info in samples.items():
# img_path = image_dir / f"{image_id}.png"
# if img_path.exists():
# valid_samples.append((
# image_id,
# img_path,
# sample_info['llm_enhanced_caption'],
# sample_info['user_editing_prompt']
# ))
# # Initialize the results dict
# results = {}
# total_samples = len(valid_samples)
# # Process in batches
# for batch_idx in range(0, total_samples, batch_size):
# batch_end = min(batch_idx + batch_size, total_samples)
# current_batch = valid_samples[batch_idx:batch_end]
# # Batch-process the images
# batch_images = [Image.open(s[1]).convert("RGB") for s in current_batch]
# image_inputs = clip_processor(images=batch_images, return_tensors="pt").to(device)
# image_features = clip_model.get_image_features(**image_inputs)
# image_features = nn.functional.normalize(image_features, dim=-1)
# # Batch-process the texts
# batch_texts = [s[2] for s in current_batch]
# text_inputs = clip_processor(
# text=batch_texts,
# return_tensors="pt",
# padding=True,
# truncation=True
# ).to(device)
# text_features = clip_model.get_text_features(**text_inputs)
# text_features = nn.functional.normalize(text_features, dim=-1)
# # Fuse the features
# mixed_features = (1 - alpha) * image_features + alpha * text_features
# mixed_features = nn.functional.normalize(mixed_features, dim=-1)
# # Batch retrieval
# query_features = mixed_features.cpu().numpy().astype("float32")
# distances, indices = index.search(query_features, 100)
# # Store this batch's results
# for (sample_id, _, enhanced_cap, original_cap), dist_row, idx_row in zip(current_batch, distances, indices):
# results[sample_id] = {
# "llm_enhanced_caption": enhanced_cap, # 增强描述
# "original_caption": original_cap, # 原始描述
# "retrieved_results": []
# }
# for distance, idx in zip(dist_row, idx_row):
# if 0 <= idx < len(all_index_ids):
# raw_id = all_index_ids[idx]
# base_name = os.path.basename(raw_id)
# file_name = os.path.splitext(base_name)[0]
# results[sample_id]["retrieved_results"].append({
# "retrieved_id": file_name,
# "score": float(distance),
# })
# # Save the results to JSON
# with open(output_json_path, 'w') as f:
# json.dump(results, f, indent=2, ensure_ascii=False)
# print("Retrieving completed!")
# return results
# def evaluate_cirr_scores() -> List[Tuple[str, float]]:
# # Paths to the dataset and the retrieval results
# dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json"
# retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_quchong.json"
# # Load the dataset
# with open(dataset_path, 'r') as f:
# dataset = json.load(f)
# print(len(dataset))
# # Load the retrieval results
# with open(retrieval_results_path, 'r') as f:
# retrieval_results = json.load(f)
# print(len(retrieval_results))
# # Initialize the data structures
# all_target_captions_soft = []
# all_set_member_idx = []
# nn_result = []
# # Build the matching data structures
# for sample in dataset:
# all_target_captions_soft.append(sample["target_soft"])
# all_set_member_idx.append(sample["img_set"]["members"])
# query_id = str(sample["reference"])
# retrieved_items = retrieval_results.get(query_id, [])
# nn_result.append([item["retrieved_id"] for item in retrieved_items])
# # Compute the recall metrics
# out = []
# # Recall@K (global retrieval)
# for k in [1, 5, 10, 50]:
# total_score = 0.0
# for i in range(len(dataset)):
# query_id = str(dataset[i]["reference"]) # reference ID of the current query
# # Filter out the reference image itself
# filtered_results = [rid for rid in nn_result[i] if rid != query_id]
# top_k = filtered_results[:k]
# best_score = 0.0
# for target_id, score in all_target_captions_soft[i].items():
# if target_id in top_k:
# best_score = max(best_score, score)
# total_score += best_score
# recall = total_score / len(dataset) * 100
# out.append((f"recall_top{k}_correct_composition", recall))
# # Recall_subset@K (subset retrieval)
# for k in [1, 2, 3]:
# total_score = 0.0
# for i in range(len(dataset)):
# query_id = str(dataset[i]["reference"])
# # Double filter: subset members + exclude the reference image
# subset_results = [
# rid for rid in nn_result[i]
# if rid in all_set_member_idx[i] and rid != query_id
# ]
# top_k_subset = subset_results[:k]
# best_score = 0.0
# for target_id, score in all_target_captions_soft[i].items():
# if target_id in top_k_subset:
# best_score = max(best_score, score)
# total_score += best_score
# recall = total_score / len(dataset) * 100
# out.append((f"recall_inset_top{k}_correct_composition", recall))
# # Print and save the results
# print("\n" + "="*30 + " Evaluation Results " + "="*30)
# for metric, value in out:
# print(f"{metric:<40}: {value:.4f}")
# output_dir = os.path.dirname(retrieval_results_path)
# output_filename = "evaluation_results_quchong.txt"
# output_path = os.path.join(output_dir, output_filename)
# with open(output_path, 'w') as f:
# f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n")
# for metric, value in out:
# line = f"{metric:<40}: {value:.4f}\n"
# f.write(line)
# print(f"\nResults saved to {output_path}")
# return out
# # This is because the "reference" field repeats heavily across the dataset samples (in CIRR, each reference image is paired with several different target captions)
# def evaluate_cirr_scores() -> List[Tuple[str, float]]:
# dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json"
# retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_noquery_a08.json"
# # Load the dataset
# with open(dataset_path, 'r') as f:
# dataset = json.load(f)
# # Load the retrieval results
# with open(retrieval_results_path, 'r') as f:
# retrieval_results = json.load(f)
# # Debug: check data consistency
# print(len(dataset))
# print(len(retrieval_results))
# # Filter valid samples (enforce strict type agreement)
# valid_samples = []
# for sample in dataset:
# query_id = str(sample["reference"])
# if query_id in retrieval_results:
# valid_samples.append(sample)
# print(f"Valid samples count: {len(valid_samples)} (must be 2086)")
# # Build the data structures
# all_target_captions_soft = []
# all_set_member_idx = []
# nn_result = []
# for sample in valid_samples:
# all_target_captions_soft.append(sample["target_soft"])
# all_set_member_idx.append(sample["img_set"]["members"])
# query_id = str(sample["reference"])
# # Fetch the retrieval-result list for this query_id directly
# retrieved_items = retrieval_results[query_id]
# nn_result.append([item["retrieved_id"] for item in retrieved_items])
# out = []
# # Recall@K computation (over the valid-sample count and filtered data)
# for k in [1, 5, 10, 50]:
# total_score = 0.0
# for i in range(len(valid_samples)):
# query_id = str(valid_samples[i]["reference"])
# filtered_results = [rid for rid in nn_result[i] if rid != query_id]
# top_k = filtered_results[:k]
# best_score = 0.0
# for target_id, score in all_target_captions_soft[i].items():
# if target_id in top_k:
# best_score = max(best_score, score)
# total_score += best_score
# recall = total_score / len(valid_samples) * 100
# print(len(valid_samples))
# out.append((f"recall_top{k}_correct_composition", recall))
# # Recall_subset@K computation
# for k in [1, 2, 3]:
# total_score = 0.0
# for i in range(len(valid_samples)):
# query_id = str(valid_samples[i]["reference"])
# subset_results = [
# rid for rid in nn_result[i]
# if rid in all_set_member_idx[i] and rid != query_id
# ]
# top_k_subset = subset_results[:k]
# best_score = 0.0
# for target_id, score in all_target_captions_soft[i].items():
# if target_id in top_k_subset:
# best_score = max(best_score, score)
# total_score += best_score
# recall = total_score / len(valid_samples) * 100
# out.append((f"recall_inset_top{k}_correct_composition", recall))
# # Output the results (unchanged)
# print("\n" + "="*30 + " Evaluation Results " + "="*30)
# for metric, value in out:
# print(f"{metric:<40}: {value:.4f}")
# output_dir = os.path.dirname(retrieval_results_path)
# output_filename = "evaluation_results_valid.txt"
# output_path = os.path.join(output_dir, output_filename)
# with open(output_path, 'w') as f:
# f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n")
# for metric, value in out:
# line = f"{metric:<40}: {value:.4f}\n"
# f.write(line)
# print(f"\nResults saved to {output_path}")
# return out
# # A prefix of the first n samples can be used when computing scores
# def evaluate_cirr_scores() -> List[Tuple[str, float]]:
# # Paths to the dataset and the retrieval results
# dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json"
# retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_quchongnew.json"
# # Load the dataset
# with open(dataset_path, 'r') as f:
# dataset = json.load(f)
# print(f"Total samples in dataset: {len(dataset)}")
# # Load the retrieval results
# with open(retrieval_results_path, 'r') as f:
# retrieval_results = json.load(f)
# print(f"Total queries in retrieval results: {len(retrieval_results)}")
# # Initialize the data structures
# valid_query_ids = [] # query_ids that matched successfully
# all_target_captions_soft = [] # target_soft of the matched samples
# all_set_member_idx = [] # set members of the matched samples
# nn_result = [] # retrieval results of the matched samples
# # Build the matching data structures (with the added double verification)
# for sample in dataset:
# query_id = str(sample["reference"])
# caption = sample["caption"]
# # Fetch the corresponding retrieval entry
# retrieved_entry = retrieval_results.get(query_id)
# # Double verification: the query_id exists and the caption matches
# if retrieved_entry and retrieved_entry["original_caption"] == caption:
# valid_query_ids.append(query_id)
# all_target_captions_soft.append(sample["target_soft"])
# all_set_member_idx.append(sample["img_set"]["members"])
# # Extract the retrieval results as a list of ids
# retrieved_items = retrieved_entry["retrieved_results"]
# nn_result.append([item["retrieved_id"] for item in retrieved_items])
# print(f"Valid matched samples after verification: {len(valid_query_ids)}")
# ##############################################
# # Change: evaluate on a truncated prefix of samples only
# # (here the first 300 valid samples; take all if there are fewer)
# valid_query_ids = valid_query_ids[:300]
# all_target_captions_soft = all_target_captions_soft[:300]
# all_set_member_idx = all_set_member_idx[:300]
# nn_result = nn_result[:300]
# total_samples = len(valid_query_ids)
# ##############################################
# print(f"Evaluating on first {total_samples} samples")
# # Compute the recall metrics (valid samples only)
# out = []
# total_samples = len(valid_query_ids)
# # Recall@K (global retrieval)
# for k in [1, 5, 10, 50]:
# total_score = 0.0
# for i in range(total_samples):
# # Filter out the reference image itself
# filtered_results = [rid for rid in nn_result[i] if rid != valid_query_ids[i]]
# top_k = filtered_results[:k]
# # Compute the best matching score
# best_score = max(
# (score for target_id, score in all_target_captions_soft[i].items() if target_id in top_k),
# default=0.0
# )
# total_score += best_score
# recall = (total_score / total_samples) * 100
# out.append((f"recall_top{k}_correct_composition", recall))
# # Recall_subset@K (subset retrieval)
# for k in [1, 2, 3]:
# total_score = 0.0
# for i in range(total_samples):
# # Double filter: subset members, excluding the reference image
# subset_results = [
# rid for rid in nn_result[i]
# if rid in all_set_member_idx[i] and rid != valid_query_ids[i]
# ]
# top_k_subset = subset_results[:k]
# # Compute the best matching score
# best_score = max(
# (score for target_id, score in all_target_captions_soft[i].items() if target_id in top_k_subset),
# default=0.0
# )
# total_score += best_score
# recall = (total_score / total_samples) * 100
# out.append((f"recall_inset_top{k}_correct_composition", recall))
# # Print and save the results
# print("\n" + "="*30 + " Evaluation Results " + "="*30)
# for metric, value in out:
# print(f"{metric:<40}: {value:.4f}")
# output_dir = os.path.dirname(retrieval_results_path)
# output_filename = "evaluation_results_quchongnew.txt"
# output_path = os.path.join(output_dir, output_filename)
# with open(output_path, 'w') as f:
# f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n")
# for metric, value in out:
# line = f"{metric:<40}: {value:.4f}\n"
# f.write(line)
# print(f"\nResults saved to {output_path}")
# return out
# This function currently looks essentially fine.
def evaluate_cirr_scores() -> List[Tuple[str, float]]:
    # Paths to the dataset and the retrieval results
dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json"
retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_quchongnew.json"
    # Load the dataset
with open(dataset_path, 'r') as f:
dataset = json.load(f)
print(f"Total samples in dataset: {len(dataset)}")
    # Load the retrieval results
with open(retrieval_results_path, 'r') as f:
retrieval_results = json.load(f)
print(f"Total queries in retrieval results: {len(retrieval_results)}")
    # Initialize the data structures
    valid_query_ids = []  # query_ids that matched successfully
    all_target_captions_soft = []  # target_soft of the matched samples
    all_set_member_idx = []  # set members of the matched samples
    nn_result = []  # retrieval results of the matched samples
    # Build the matching data structures (with the added double verification)
for sample in dataset:
query_id = str(sample["reference"])
caption = sample["caption"]
        # Fetch the corresponding retrieval entry
        retrieved_entry = retrieval_results.get(query_id)
        # Double verification: the query_id exists and the caption matches
if retrieved_entry and retrieved_entry["original_caption"] == caption:
valid_query_ids.append(query_id)
all_target_captions_soft.append(sample["target_soft"])
all_set_member_idx.append(sample["img_set"]["members"])
            # Extract the retrieval results as a list of ids
retrieved_items = retrieved_entry["retrieved_results"]
nn_result.append([item["retrieved_id"] for item in retrieved_items])
print(f"Valid matched samples after verification: {len(valid_query_ids)}")
    # Compute the recall metrics (valid samples only)
out = []
total_samples = len(valid_query_ids)
    # Recall@K (global retrieval)
for k in [1, 5, 10, 50]:
total_score = 0.0
for i in range(total_samples):
            # Filter out the reference image itself
filtered_results = [rid for rid in nn_result[i] if rid != valid_query_ids[i]]
top_k = filtered_results[:k]
            # Compute the best matching score
best_score = max(
(score for target_id, score in all_target_captions_soft[i].items() if target_id in top_k),
default=0.0
)
total_score += best_score
recall = (total_score / total_samples) * 100
out.append((f"recall_top{k}_correct_composition", recall))
    # Recall_subset@K (subset retrieval)
for k in [1, 2, 3]:
total_score = 0.0
for i in range(total_samples):
            # Double filter: subset members, excluding the reference image
subset_results = [
rid for rid in nn_result[i]
if rid in all_set_member_idx[i] and rid != valid_query_ids[i]
]
top_k_subset = subset_results[:k]
            # Compute the best matching score
best_score = max(
(score for target_id, score in all_target_captions_soft[i].items() if target_id in top_k_subset),
default=0.0
)
total_score += best_score
recall = (total_score / total_samples) * 100
out.append((f"recall_inset_top{k}_correct_composition", recall))
    # Print and save the results
print("\n" + "="*30 + " Evaluation Results " + "="*30)
for metric, value in out:
print(f"{metric:<40}: {value:.4f}")
output_dir = os.path.dirname(retrieval_results_path)
output_filename = "evaluation_results_quchongnew.txt"
output_path = os.path.join(output_dir, output_filename)
with open(output_path, 'w') as f:
f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n")
for metric, value in out:
line = f"{metric:<40}: {value:.4f}\n"
f.write(line)
print(f"\nResults saved to {output_path}")
return out
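# A tiny self-contained sanity check for the soft Recall@K logic above: for one
# query with soft targets {"b": 1.0, "c": 0.6} and ranked results ["a", "b", "c"],
# Recall@1 should be 0 (only "a" is in the top-1) and Recall@2 should be 100
# (the best soft score inside the top-2 is 1.0). The data here is made up purely
# for illustration; call the function manually to verify the metric.
def _sanity_check_soft_recall():
    nn = ["a", "b", "c"]          # ranked retrieval ids for a single query
    soft = {"b": 1.0, "c": 0.6}   # soft relevance scores of the targets
    for k, expected in [(1, 0.0), (2, 100.0)]:
        top_k = nn[:k]
        best = max((s for t, s in soft.items() if t in top_k), default=0.0)
        recall = best * 100       # one sample, so the mean equals the value
        assert recall == expected, (k, recall)
    print("Soft Recall@K sanity check passed")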
if __name__ == "__main__":
# process_circo_val_images()
# process_circo_test_images()
process_cirr_images()
# def process_circo_images():
# if not all([vlm_model, sam_predictor, groundingdino_model]):
# raise RuntimeError("Required models not initialized")
# # Define paths
# dev_dir = Path("/home/zt/data/BrushEdit/CIRCO/img_raw/dev")
# cap_file = Path("/home/zt/data/BrushEdit/CIRCO/annotations/val.json")
# output_dirs = {
# "edited": Path("/home/zt/data/BrushEdit/CIRCO/img_paint/circo_edited"),
# "mask": Path("/home/zt/data/BrushEdit/CIRCO/img_paint/circo_mask"),
# "masked": Path("/home/zt/data/BrushEdit/CIRCO/img_paint/circo_masked")
# }
# output_json_path = Path("/home/zt/data/BrushEdit/CIRCO/image_paint.json")
# descriptions = {}
# # Create output directories
# for dir_path in output_dirs.values():
# dir_path.mkdir(parents=True, exist_ok=True)
# # Load captions
# with open(cap_file, 'r') as f:
# captions = json.load(f)
# for img_path in dev_dir.glob("*.jpg"):
# base_name = img_path.stem
# # Take the last six characters as the reference ID
# reference_part = base_name[-6:]
# # Convert reference_img_id from the JSON to str before comparing
# caption = next(
# (item["relative_caption"] for item in captions
# if str(item.get("reference_img_id")) == reference_part),
# None
# )
# if not caption:
# print(f"Warning: No caption for {base_name}")
# continue
# try:
# # Build an empty (all-zero) alpha channel
# rgb_image = Image.open(img_path).convert("RGB")
# empty_alpha = Image.new("L", rgb_image.size, 0) # fully transparent alpha channel
# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
# # Initialize via init_img
# base = {"background": image, "layers": [image]}
# init_results = init_img(
# base=base,
# init_type="custom", # 使用自定义初始化
# prompt=caption,
# aspect_ratio="Custom resolution",
# example_change_times=0
# )
# # Unpack the initialized parameters
# input_image = init_results[0]
# original_image = init_results[1]
# original_mask = init_results[2]
# # Set the process() arguments correctly
# result_images, mask_images, masked_images, _, target_description, _ = process(
# input_image=input_image,
# original_image=original_image,
# original_mask=original_mask, # pass the initialized mask
# prompt=caption,
# negative_prompt="ugly, low quality",
# control_strength=1.0,
# seed=648464818,
# randomize_seed=False,
# guidance_scale=7.5,
# num_inference_steps=50,
# num_samples=1,
# blending=True,
# category=None,
# target_prompt="",
# resize_default=True,
# aspect_ratio_name="Custom resolution",
# invert_mask_state=False
# )
# # Save images
# output_dirs["edited"].mkdir(exist_ok=True)
# result_images[0].save(output_dirs["edited"] / f"{base_name}.jpg")
# mask_images[0].save(output_dirs["mask"] / f"{base_name}_mask.jpg")
# masked_images[0].save(output_dirs["masked"] / f"{base_name}_masked.jpg")
# # Generate BLIP2 description
# blip2_desc, _ = generate_blip2_description(input_image)
# descriptions[base_name] = {
# "original_caption": caption,
# "blip2_description": blip2_desc,
# "llm_enhanced_caption": target_description
# }
# with open(output_json_path, 'w') as f:
# json.dump(descriptions, f, indent=4) # indent keeps it readable
# print(f"Processed {base_name}")
# except Exception as e:
# print(f"Error processing {base_name}: {str(e)}")
# continue
# print("Processing completed!")
# @torch.no_grad()
# def batch_mix_and_search_circo(
# json_path: str = "/home/zt/data/BrushEdit/CIRCO/image_paint.json",
# image_dir: str = "/home/zt/data/BrushEdit/CIRCO/img_paint/circo_edited",
# alpha: float = 0.6,
# batch_size: int = 32,
# output_json_path: str = "circo_retrieval_results.json"
# ) -> Dict[str, List[Dict]]:
# # Load the index and metadata
# index = faiss.read_index("/home/zt/data/BrushEdit/CIRCO/img_raw/dev/dev_knn.index")
# metadata = pd.read_parquet("/home/zt/data/BrushEdit/CIRCO/img_raw/dev/dev_embedding_folder/metadata/metadata_0.parquet")
# all_index_ids = metadata["image_path"].tolist()
# # Load and validate the input data
# with open(json_path) as f:
# samples = json.load(f)
# valid_samples = []
# image_dir = Path(image_dir)
# for image_id, sample_info in samples.items(): # key change here
# img_path = image_dir / f"{image_id}.jpg" # use the dict key directly as the image_id
# if img_path.exists():
# valid_samples.append(
# (image_id, img_path, sample_info['llm_enhanced_caption']) # take the caption from the value
# )
# # Initialize the results dict
# results = {}
# total_samples = len(valid_samples)
# # Process in batches
# for batch_idx in range(0, total_samples, batch_size):
# batch_end = min(batch_idx + batch_size, total_samples)
# current_batch = valid_samples[batch_idx:batch_end]
# # Batch-process the images
# batch_images = [Image.open(s[1]).convert("RGB") for s in current_batch]
# image_inputs = clip_processor(images=batch_images, return_tensors="pt").to(device)
# image_features = clip_model.get_image_features(**image_inputs)
# image_features = nn.functional.normalize(image_features, dim=-1)
# # Batch-process the texts
# batch_texts = [s[2] for s in current_batch]
# text_inputs = clip_processor(
# text=batch_texts,
# return_tensors="pt",
# padding=True,
# truncation=True
# ).to(device)
# text_features = clip_model.get_text_features(**text_inputs)
# text_features = nn.functional.normalize(text_features, dim=-1)
# # Fuse the features
# mixed_features = (1 - alpha) * image_features + alpha * text_features
# mixed_features = nn.functional.normalize(mixed_features, dim=-1)
# # Batch retrieval
# query_features = mixed_features.cpu().numpy().astype("float32")
# distances, indices = index.search(query_features, 100)
# # Store this batch's results
# for (sample_id, _, _), dist_row, idx_row in zip(current_batch, distances, indices):
# results[sample_id] = []
# for distance, idx in zip(dist_row, idx_row):
# if 0 <= idx < len(all_index_ids):
# results[sample_id].append({
# "retrieved_id": all_index_ids[idx],
# "score": float(distance)
# })
# # Save the results to JSON
# with open(output_json_path, 'w') as f:
# json.dump(results, f, indent=2, ensure_ascii=False)
# print("Retrieving completed!")
# return results
# block = gr.Blocks(
# theme=gr.themes.Soft(
# radius_size=gr.themes.sizes.radius_none,
# text_size=gr.themes.sizes.text_md
# )
# )
# with block as demo:
# with gr.Row():
# with gr.Column():
# gr.HTML(head)
# gr.HTML(descriptions)
# with gr.Accordion(label="🧭 小白也能秒懂的魔法指南:", open=True, elem_id="accordion"):
# with gr.Row(equal_height=True):
# gr.HTML(instructions)
# original_image = gr.State(value=None)
# original_mask = gr.State(value=None)
# category = gr.State(value=None)
# status = gr.State(value=None)
# invert_mask_state = gr.State(value=False)
# example_change_times = gr.State(value=0)
# deepseek_verified = gr.State(value=False)
# blip2_description = gr.State(value="")
# enhanced_description = gr.State(value="")
# decomposed_description = gr.State(value="")
# with gr.Row():
# with gr.Column():
# with gr.Group():
# input_image = gr.ImageEditor(
# label="参考图像",
# type="pil",
# brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
# layers = False,
# interactive=True,
# # height=1024,
# height=408,
# sources=["upload"],
# placeholder="🫧 点击此处或下面的图标上传图像 🫧",
# )
# prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期 😜", value="",lines=2)
# with gr.Group():
# with gr.Row():
# mask_button = gr.Button("💎 掩膜生成")
# invert_mask_button = gr.Button("👐 掩膜翻转")
# # random_mask_button = gr.Button("⭕️ 随机掩膜")
# with gr.Row():
# masked_gallery = gr.Gallery(label="掩膜图像", show_label=True, preview=True, height=360)
# mask_gallery = gr.Gallery(label="掩膜", show_label=True, preview=True, height=360)
# with gr.Accordion("高级掩膜选项", open=False, elem_id="accordion1"):
# dilation_size = gr.Slider(
# label="每次放缩的尺度: ", show_label=True,minimum=0, maximum=50, step=1, value=20
# )
# with gr.Row():
# dilation_mask_button = gr.Button("放大掩膜")
# erosion_mask_button = gr.Button("缩小掩膜")
# moving_pixels = gr.Slider(
# label="每次移动的像素:", show_label=True, minimum=0, maximum=50, value=4, step=1
# )
# with gr.Row():
# move_left_button = gr.Button("左移")
# move_right_button = gr.Button("右移")
# with gr.Row():
# move_up_button = gr.Button("上移")
# move_down_button = gr.Button("下移")
# with gr.Column():
# with gr.Row():
# deepseek_key = gr.Textbox(label="LLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1, container=False, type="password")
# verify_deepseek = gr.Button("🔑 验证密钥", scale=0)
# blip2_output = gr.Textbox(label="1. 原图描述(BLIP2生成)", placeholder="🖼️ 上传图片后自动生成图片描述 🖼️", lines=2, interactive=True)
# with gr.Row():
# target_prompt = gr.Textbox(label="2. 整合增强版", lines=4, interactive=True, placeholder="🚀 点击图片编辑同时生成增强描述 or 点击右侧按钮单独生成增强描述 🚀")
# enhance_button = gr.Button("✨ 智能整合")
# with gr.Row():
# decomposed_output = gr.Textbox(label="3. 结构分解版", lines=4, interactive=True, placeholder="📝 点击右侧按钮生成结构化描述 📝")
# decompose_button = gr.Button("🔧 结构分解")
# with gr.Group():
# run_button = gr.Button("💫 图像编辑")
# result_gallery = gr.Gallery(label="💥 编辑结果", show_label=True, columns=2, preview=True, height=360)
# with gr.Accordion("高级编辑选项", open=False, elem_id="accordion1"):
# vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
# with gr.Group():
# with gr.Row():
# # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
# GPT4o_KEY = gr.Textbox(label="VLM API密钥", value="", lines=2, container=False, type="password")
# GPT4o_KEY_submit = gr.Button("🔑 验证密钥", scale=0)
# aspect_ratio = gr.Dropdown(label="输出纵横比", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
# resize_default = gr.Checkbox(label="短边裁剪到640像素", value=True)
# base_model_dropdown = gr.Dropdown(label="基础模型", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
# negative_prompt = gr.Textbox(label="负向提示", max_lines=5, placeholder="请输入你的负向提示", value='ugly, low quality',lines=1)
# control_strength = gr.Slider(label="控制强度: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01)
# with gr.Group():
# seed = gr.Slider(label="种子: ", minimum=0, maximum=2147483647, step=1, value=648464818)
# randomize_seed = gr.Checkbox(label="随机种子", value=False)
# blending = gr.Checkbox(label="混合模式", value=True)
# num_samples = gr.Slider(label="生成个数", minimum=0, maximum=4, step=1, value=2)
# with gr.Group():
# with gr.Row():
# guidance_scale = gr.Slider(label="指导尺度", minimum=1, maximum=12, step=0.1, value=7.5)
# num_inference_steps = gr.Slider(label="推理步数", minimum=1, maximum=50, step=1, value=50)
# # target_prompt = gr.Textbox(label="Input Target Prompt", max_lines=5, placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)", value='', lines=2)
# init_type = gr.Textbox(label="Init Name", value="", visible=False)
# example_type = gr.Textbox(label="Example Name", value="", visible=False)
# with gr.Row():
# reset_button = gr.Button("Reset")
# retrieve_button = gr.Button("🔍 开始检索")
# with gr.Row():
# retrieve_gallery = gr.Gallery(label="🎊 检索结果", show_label=True, columns=10, preview=True, height=660)
# with gr.Row():
# example = gr.Examples(
# label="Quick Example",
# examples=EXAMPLES,
# inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
# examples_per_page=10,
# cache_examples=False,
# visible=False
# )
# with gr.Accordion(label="🎬 隐藏玩法大公开:", open=True, elem_id="accordion"):
# with gr.Row(equal_height=True):
# gr.HTML(tips)
# with gr.Row():
# gr.HTML(citation)
# ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery.
# ## And we need to solve the conflict between the upload and change example functions.
# input_image.upload(
# init_img,
# [input_image, init_type, prompt, aspect_ratio, example_change_times],
# [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
# )
# example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
# ## vlm and base model dropdown
# vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
# base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
# GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
# ips=[input_image,
# original_image,
# original_mask,
# prompt,
# negative_prompt,
# control_strength,
# seed,
# randomize_seed,
# guidance_scale,
# num_inference_steps,
# num_samples,
# blending,
# category,
# target_prompt,
# resize_default,
# aspect_ratio,
# invert_mask_state]
# ## run brushedit
# run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
# ## mask func
# mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
# # random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
# dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
# erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
# invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
# ## reset func
# reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, blip2_output, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, blip2_output, target_prompt, resize_default, invert_mask_state])
# input_image.upload(fn=generate_blip2_description, inputs=[input_image], outputs=[blip2_description, blip2_output])
# verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
# # enhance_button.click(fn=enhance_description, inputs=[blip2_output, prompt], outputs=[enhanced_description, enhanced_output])
# # decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output])
# # retrieve_button.click(fn=mix_and_search, inputs=[enhanced_output, result_gallery], outputs=[retrieve_gallery])
# enhance_button.click(fn=enhance_description, inputs=[blip2_output, prompt], outputs=[enhanced_description, target_prompt])
# decompose_button.click(fn=decompose_description, inputs=[target_prompt], outputs=[decomposed_description, decomposed_output])
# retrieve_button.click(fn=mix_and_search, inputs=[target_prompt, result_gallery], outputs=[retrieve_gallery])
# demo.launch(server_name="0.0.0.0", server_port=12345, share=True)