#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os, random, sys | |
import numpy as np | |
import requests | |
import torch | |
from pathlib import Path | |
import pandas as pd | |
import concurrent.futures | |
import faiss | |
import gradio as gr | |
from PIL import Image | |
import torch.nn.functional as F  # newly added (used for F.normalize in retrieval)
from huggingface_hub import hf_hub_download, snapshot_download | |
from scipy.ndimage import binary_dilation, binary_erosion | |
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration, | |
Qwen2VLForConditionalGeneration, Qwen2VLProcessor) | |
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator | |
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler | |
from diffusers.image_processor import VaeImageProcessor | |
from app.src.vlm_pipeline import ( | |
vlm_response_editing_type, | |
vlm_response_object_wait_for_edit, | |
vlm_response_mask, | |
vlm_response_prompt_after_apply_instruction | |
) | |
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline | |
from app.utils.utils import load_grounding_dino_model | |
from app.src.vlm_template import vlms_template | |
from app.src.base_model_template import base_models_template | |
from app.src.aspect_ratio_template import aspect_ratios | |
from openai import OpenAI | |
base_openai_url = "https://api.deepseek.com/" | |
base_api_key = os.getenv("DEEPSEEK_API_KEY", "")  # hard-coded key removed; read the DeepSeek API key from the environment
from transformers import BlipProcessor, BlipForConditionalGeneration | |
from transformers import CLIPProcessor, CLIPModel | |
from app.deepseek.instructions import ( | |
create_apply_editing_messages_deepseek, | |
create_decomposed_query_messages_deepseek | |
) | |
from clip_retrieval.clip_client import ClipClient | |
#### Description #### | |
logo = r""" | |
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center> | |
""" | |
head = r""" | |
<div style="text-align: center;"> | |
<h1> Zero-Shot Composed Image Retrieval Based on Diffusion Model Priors and Large Language Models</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> | |
<a href=''><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a> | |
<a href=''><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a> | |
<a href=''><img src='https://img.shields.io/badge/Code-Github-orange'></a> | |
</div> | |
<br>
</div> | |
""" | |
descriptions = r""" | |
Demo for ZS-CIR""" | |
instructions = r""" | |
Demo for ZS-CIR""" | |
tips = r""" | |
Demo for ZS-CIR | |
""" | |
citation = r""" | |
Demo for ZS-CIR""" | |
# - - - - - examples - - - - - # | |
EXAMPLES = [ | |
[ | |
Image.open("./assets/frog/frog.jpeg").convert("RGBA"), | |
"add a magic hat on frog head.", | |
642087011, | |
"frog", | |
"frog", | |
True, | |
False, | |
"GPT4-o (Highly Recommended)" | |
], | |
[ | |
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"), | |
"replace the background to ancient China.", | |
648464818, | |
"chinese_girl", | |
"chinese_girl", | |
True, | |
False, | |
"GPT4-o (Highly Recommended)" | |
], | |
[ | |
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"), | |
"remove the deer.", | |
648464818, | |
"angel_christmas", | |
"angel_christmas", | |
False, | |
False, | |
"GPT4-o (Highly Recommended)" | |
], | |
[ | |
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"), | |
"add a wreath on head.", | |
648464818, | |
"sunflower_girl", | |
"sunflower_girl", | |
True, | |
False, | |
"GPT4-o (Highly Recommended)" | |
], | |
[ | |
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"), | |
"add a butterfly fairy.", | |
648464818, | |
"girl_on_sun", | |
"girl_on_sun", | |
True, | |
False, | |
"GPT4-o (Highly Recommended)" | |
], | |
[ | |
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"), | |
"remove the christmas hat.", | |
642087011, | |
"spider_man_rm", | |
"spider_man_rm", | |
False, | |
False, | |
"GPT4-o (Highly Recommended)" | |
], | |
[ | |
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"), | |
"remove the flower.", | |
642087011, | |
"anime_flower", | |
"anime_flower", | |
False, | |
False, | |
"GPT4-o (Highly Recommended)" | |
], | |
[ | |
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"), | |
"replace the clothes to a delicated floral skirt.", | |
648464818, | |
"chenduling", | |
"chenduling", | |
True, | |
False, | |
"GPT4-o (Highly Recommended)" | |
], | |
[ | |
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"), | |
"make the hedgehog in Italy.", | |
648464818, | |
"hedgehog_rp_bg", | |
"hedgehog_rp_bg", | |
True, | |
False, | |
"GPT4-o (Highly Recommended)" | |
], | |
] | |
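# Lookup tables for the built-in examples: source image, cached mask, masked preview, and precomputed edited result for each example name.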
INPUT_IMAGE_PATH = { | |
"frog": "./assets/frog/frog.jpeg", | |
"chinese_girl": "./assets/chinese_girl/chinese_girl.png", | |
"angel_christmas": "./assets/angel_christmas/angel_christmas.png", | |
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png", | |
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png", | |
"spider_man_rm": "./assets/spider_man_rm/spider_man.png", | |
"anime_flower": "./assets/anime_flower/anime_flower.png", | |
"chenduling": "./assets/chenduling/chengduling.jpg", | |
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png", | |
} | |
MASK_IMAGE_PATH = { | |
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png", | |
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png", | |
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png", | |
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png", | |
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png", | |
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png", | |
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png", | |
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png", | |
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png", | |
} | |
MASKED_IMAGE_PATH = { | |
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png", | |
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png", | |
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png", | |
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png", | |
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png", | |
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png", | |
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png", | |
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png", | |
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png", | |
} | |
OUTPUT_IMAGE_PATH = { | |
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png", | |
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png", | |
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png", | |
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png", | |
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png", | |
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png", | |
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png", | |
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png", | |
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png", | |
} | |
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir' | |
# os.makedirs('gradio_temp_dir', exist_ok=True) | |
VLM_MODEL_NAMES = list(vlms_template.keys()) | |
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)" | |
BASE_MODELS = list(base_models_template.keys()) | |
DEFAULT_BASE_MODEL = "realisticVision (Default)" | |
ASPECT_RATIO_LABELS = list(aspect_ratios) | |
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0] | |
## init device | |
try: | |
if torch.cuda.is_available(): | |
device = "cuda" | |
elif sys.platform == "darwin" and torch.backends.mps.is_available(): | |
device = "mps" | |
else: | |
device = "cpu" | |
except Exception:
device = "cpu" | |
# ## init torch dtype | |
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): | |
# torch_dtype = torch.bfloat16 | |
# else: | |
# torch_dtype = torch.float16 | |
# if device == "mps": | |
# torch_dtype = torch.float16 | |
torch_dtype = torch.float16 | |
# download hf models | |
BrushEdit_path = "models/" | |
if not os.path.exists(BrushEdit_path): | |
BrushEdit_path = snapshot_download( | |
repo_id="TencentARC/BrushEdit", | |
local_dir=BrushEdit_path, | |
token=os.getenv("HF_TOKEN"), | |
) | |
## init default VLM | |
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME] | |
if vlm_processor != "" and vlm_model != "": | |
vlm_model.to(device) | |
else: | |
raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.") | |
## init default LLM | |
llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url) | |
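# This DeepSeek chat client is reused below for caption enhancement (enhance_description) and query decomposition (decompose_description).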
## init base model | |
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE") | |
brushnet_path = os.path.join(BrushEdit_path, "brushnetX") | |
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth") | |
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth") | |
# input brushnetX ckpt path | |
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype) | |
pipe = StableDiffusionBrushNetPipeline.from_pretrained( | |
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False | |
) | |
# speed up diffusion process with faster scheduler and memory optimization | |
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) | |
# uncomment the following line to enable xformers memory-efficient attention (not needed with PyTorch 2.0's built-in attention)
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload() | |
## init SAM | |
sam = build_sam(checkpoint=sam_path) | |
sam.to(device=device) | |
sam_predictor = SamPredictor(sam) | |
sam_automask_generator = SamAutomaticMaskGenerator(sam) | |
## init groundingdino_model | |
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py' | |
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device) | |
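# At this point all heavy models are loaded: the BrushNet inpainting pipeline, SAM (predictor + automatic mask generator), and GroundingDINO.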
## Ordinary function | |
def crop_and_resize(image: Image.Image, | |
target_width: int, | |
target_height: int) -> Image.Image: | |
""" | |
Crops and resizes an image while preserving the aspect ratio. | |
Args: | |
image (Image.Image): Input PIL image to be cropped and resized. | |
target_width (int): Target width of the output image. | |
target_height (int): Target height of the output image. | |
Returns: | |
Image.Image: Cropped and resized image. | |
""" | |
# Original dimensions | |
original_width, original_height = image.size | |
original_aspect = original_width / original_height | |
target_aspect = target_width / target_height | |
# Calculate crop box to maintain aspect ratio | |
if original_aspect > target_aspect: | |
# Crop horizontally | |
new_width = int(original_height * target_aspect) | |
new_height = original_height | |
left = (original_width - new_width) / 2 | |
top = 0 | |
right = left + new_width | |
bottom = original_height | |
else: | |
# Crop vertically | |
new_width = original_width | |
new_height = int(original_width / target_aspect) | |
left = 0 | |
top = (original_height - new_height) / 2 | |
right = original_width | |
bottom = top + new_height | |
# Crop and resize | |
cropped_image = image.crop((left, top, right, bottom)) | |
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST) | |
return resized_image | |
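# Example (hypothetical sizes): crop_and_resize(img, 640, 480) center-crops img to a 4:3 window before resizing, so the content is not distorted.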
## Ordinary function | |
def resize(image: Image.Image, | |
target_width: int, | |
target_height: int) -> Image.Image: | |
""" | |
Crops and resizes an image while preserving the aspect ratio. | |
Args: | |
image (Image.Image): Input PIL image to be cropped and resized. | |
target_width (int): Target width of the output image. | |
target_height (int): Target height of the output image. | |
Returns: | |
Image.Image: Cropped and resized image. | |
""" | |
# Original dimensions | |
resized_image = image.resize((target_width, target_height), Image.NEAREST) | |
return resized_image | |
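# Shift a binary mask by `units` pixels in the given direction; vacated rows/columns are left empty (False).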
def move_mask_func(mask, direction, units): | |
binary_mask = mask.squeeze()>0 | |
rows, cols = binary_mask.shape | |
moved_mask = np.zeros_like(binary_mask, dtype=bool) | |
if direction == 'down': | |
# move down | |
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :] | |
elif direction == 'up': | |
# move up | |
moved_mask[:rows - units, :] = binary_mask[units:, :] | |
elif direction == 'right': | |
# move right
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units] | |
elif direction == 'left': | |
# move left
moved_mask[:, :cols - units] = binary_mask[:, units:] | |
return moved_mask | |
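# Expand or shrink a binary mask: square dilation, square erosion, or replacement by its bounding box / bounding ellipse; returns an h x w x 1 uint8 mask (0 / 255).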
def random_mask_func(mask, dilation_type='square_dilation', dilation_size=20):
# Transform the mask with the selected operation: dilation, erosion, bounding box, or bounding ellipse
binary_mask = mask.squeeze()>0 | |
if dilation_type == 'square_dilation': | |
structure = np.ones((dilation_size, dilation_size), dtype=bool) | |
dilated_mask = binary_dilation(binary_mask, structure=structure) | |
elif dilation_type == 'square_erosion': | |
structure = np.ones((dilation_size, dilation_size), dtype=bool) | |
dilated_mask = binary_erosion(binary_mask, structure=structure) | |
elif dilation_type == 'bounding_box': | |
# find the bounding box of the foreground pixels
rows, cols = np.where(binary_mask) | |
if len(rows) == 0 or len(cols) == 0: | |
return mask # return original mask if no valid points | |
min_row = np.min(rows) | |
max_row = np.max(rows) | |
min_col = np.min(cols) | |
max_col = np.max(cols) | |
# create a bounding box | |
dilated_mask = np.zeros_like(binary_mask, dtype=bool) | |
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True | |
elif dilation_type == 'bounding_ellipse': | |
# find the bounding box of the foreground pixels to fit an enclosing ellipse
rows, cols = np.where(binary_mask) | |
if len(rows) == 0 or len(cols) == 0: | |
return mask # return original mask if no valid points | |
min_row = np.min(rows) | |
max_row = np.max(rows) | |
min_col = np.min(cols) | |
max_col = np.max(cols) | |
# calculate the center and axis length of the ellipse | |
center = ((min_col + max_col) // 2, (min_row + max_row) // 2) | |
a = (max_col - min_col) // 2 # half long axis | |
b = (max_row - min_row) // 2 # half short axis | |
# create a bounding ellipse | |
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]] | |
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1 | |
dilated_mask = np.zeros_like(binary_mask, dtype=bool) | |
dilated_mask[ellipse_mask] = True | |
else:
raise ValueError("dilation_type must be 'square_dilation', 'square_erosion', 'bounding_box' or 'bounding_ellipse'")
# convert the boolean mask to a single-channel uint8 image (0 / 255)
dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
return dilated_mask | |
## Gradio component function | |
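# Swap the active VLM: free the current model, then load the selected one from a local path if available, otherwise from the Hugging Face Hub.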
def update_vlm_model(vlm_name): | |
global vlm_model, vlm_processor | |
if vlm_model is not None: | |
del vlm_model | |
torch.cuda.empty_cache() | |
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name] | |
## We recommend preloading models; otherwise downloading them can take a long time. You can adjust this in vlm_template.py.
if vlm_type == "llava-next": | |
if vlm_processor != "" and vlm_model != "": | |
vlm_model.to(device) | |
return vlm_model_dropdown | |
else: | |
if os.path.exists(vlm_local_path): | |
vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path) | |
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto") | |
else: | |
if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)": | |
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") | |
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto") | |
elif vlm_name == "llama3-llava-next-8b-hf (Preload)": | |
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf") | |
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto") | |
elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)": | |
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf") | |
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto") | |
elif vlm_name == "llava-v1.6-34b-hf (Preload)": | |
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf") | |
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto") | |
elif vlm_name == "llava-next-72b-hf (Preload)": | |
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf") | |
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto") | |
elif vlm_type == "qwen2-vl": | |
if vlm_processor != "" and vlm_model != "": | |
vlm_model.to(device) | |
return vlm_model_dropdown | |
else: | |
if os.path.exists(vlm_local_path): | |
vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path) | |
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto") | |
else: | |
if vlm_name == "qwen2-vl-2b-instruct (Preload)": | |
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") | |
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto") | |
elif vlm_name == "qwen2-vl-7b-instruct (Preload)": | |
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") | |
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto") | |
elif vlm_name == "qwen2-vl-72b-instruct (Preload)": | |
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct") | |
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto") | |
elif vlm_type == "openai": | |
pass | |
return "success" | |
def update_base_model(base_model_name): | |
global pipe | |
## We recommend preloading models; otherwise downloading them can take a long time. You can adjust this in base_model_template.py.
if pipe is not None: | |
del pipe | |
torch.cuda.empty_cache() | |
base_model_path, pipe = base_models_template[base_model_name] | |
if pipe != "": | |
pipe.to(device) | |
else: | |
if os.path.exists(base_model_path): | |
pipe = StableDiffusionBrushNetPipeline.from_pretrained( | |
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False | |
) | |
# pipe.enable_xformers_memory_efficient_attention() | |
pipe.enable_model_cpu_offload() | |
else: | |
raise gr.Error(f"The base model {base_model_name} does not exist") | |
return "success" | |
def process_random_mask(input_image, | |
original_image, | |
original_mask, | |
resize_default, | |
aspect_ratio_name, | |
): | |
alpha_mask = input_image["layers"][0].split()[3] | |
input_mask = np.asarray(alpha_mask) | |
output_w, output_h = aspect_ratios[aspect_ratio_name] | |
if output_w == "" or output_h == "": | |
output_h, output_w = original_image.shape[:2] | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
else: | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
pass | |
else: | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
if input_mask.max() == 0: | |
original_mask = original_mask | |
else: | |
original_mask = input_mask | |
if original_mask is None: | |
raise gr.Error('Please generate mask first') | |
if original_mask.ndim == 2: | |
original_mask = original_mask[:,:,None] | |
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse']) | |
random_mask = random_mask_func(original_mask, dilation_type).squeeze() | |
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") | |
masked_image = original_image * (1 - (random_mask[:,:,None]>0)) | |
masked_image = masked_image.astype(original_image.dtype) | |
masked_image = Image.fromarray(masked_image) | |
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) | |
def process_dilation_mask(input_image, | |
original_image, | |
original_mask, | |
resize_default, | |
aspect_ratio_name, | |
dilation_size=20): | |
alpha_mask = input_image["layers"][0].split()[3] | |
input_mask = np.asarray(alpha_mask) | |
output_w, output_h = aspect_ratios[aspect_ratio_name] | |
if output_w == "" or output_h == "": | |
output_h, output_w = original_image.shape[:2] | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
else: | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
pass | |
else: | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
if input_mask.max() == 0: | |
original_mask = original_mask | |
else: | |
original_mask = input_mask | |
if original_mask is None: | |
raise gr.Error('Please generate mask first') | |
if original_mask.ndim == 2: | |
original_mask = original_mask[:,:,None] | |
dilation_type = np.random.choice(['square_dilation']) | |
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze() | |
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") | |
masked_image = original_image * (1 - (random_mask[:,:,None]>0)) | |
masked_image = masked_image.astype(original_image.dtype) | |
masked_image = Image.fromarray(masked_image) | |
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) | |
def process_erosion_mask(input_image, | |
original_image, | |
original_mask, | |
resize_default, | |
aspect_ratio_name, | |
dilation_size=20): | |
alpha_mask = input_image["layers"][0].split()[3] | |
input_mask = np.asarray(alpha_mask) | |
output_w, output_h = aspect_ratios[aspect_ratio_name] | |
if output_w == "" or output_h == "": | |
output_h, output_w = original_image.shape[:2] | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
else: | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
pass | |
else: | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
if input_mask.max() == 0: | |
original_mask = original_mask | |
else: | |
original_mask = input_mask | |
if original_mask is None: | |
raise gr.Error('Please generate mask first') | |
if original_mask.ndim == 2: | |
original_mask = original_mask[:,:,None] | |
dilation_type = np.random.choice(['square_erosion']) | |
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze() | |
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") | |
masked_image = original_image * (1 - (random_mask[:,:,None]>0)) | |
masked_image = masked_image.astype(original_image.dtype) | |
masked_image = Image.fromarray(masked_image) | |
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) | |
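# The four move_mask_* callbacks share the same resize logic and shift the current mask by `moving_pixels` in one direction.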
def move_mask_left(input_image, | |
original_image, | |
original_mask, | |
moving_pixels, | |
resize_default, | |
aspect_ratio_name): | |
alpha_mask = input_image["layers"][0].split()[3] | |
input_mask = np.asarray(alpha_mask) | |
output_w, output_h = aspect_ratios[aspect_ratio_name] | |
if output_w == "" or output_h == "": | |
output_h, output_w = original_image.shape[:2] | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
else: | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
pass | |
else: | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
if input_mask.max() == 0: | |
original_mask = original_mask | |
else: | |
original_mask = input_mask | |
if original_mask is None: | |
raise gr.Error('Please generate mask first') | |
if original_mask.ndim == 2: | |
original_mask = original_mask[:,:,None] | |
moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze() | |
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") | |
masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) | |
masked_image = masked_image.astype(original_image.dtype) | |
masked_image = Image.fromarray(masked_image) | |
if moved_mask.max() <= 1: | |
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) | |
original_mask = moved_mask | |
return [masked_image], [mask_image], original_mask.astype(np.uint8) | |
def move_mask_right(input_image, | |
original_image, | |
original_mask, | |
moving_pixels, | |
resize_default, | |
aspect_ratio_name): | |
alpha_mask = input_image["layers"][0].split()[3] | |
input_mask = np.asarray(alpha_mask) | |
output_w, output_h = aspect_ratios[aspect_ratio_name] | |
if output_w == "" or output_h == "": | |
output_h, output_w = original_image.shape[:2] | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
else: | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
pass | |
else: | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
if input_mask.max() == 0: | |
original_mask = original_mask | |
else: | |
original_mask = input_mask | |
if original_mask is None: | |
raise gr.Error('Please generate mask first') | |
if original_mask.ndim == 2: | |
original_mask = original_mask[:,:,None] | |
moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze() | |
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") | |
masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) | |
masked_image = masked_image.astype(original_image.dtype) | |
masked_image = Image.fromarray(masked_image) | |
if moved_mask.max() <= 1: | |
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) | |
original_mask = moved_mask | |
return [masked_image], [mask_image], original_mask.astype(np.uint8) | |
def move_mask_up(input_image, | |
original_image, | |
original_mask, | |
moving_pixels, | |
resize_default, | |
aspect_ratio_name): | |
alpha_mask = input_image["layers"][0].split()[3] | |
input_mask = np.asarray(alpha_mask) | |
output_w, output_h = aspect_ratios[aspect_ratio_name] | |
if output_w == "" or output_h == "": | |
output_h, output_w = original_image.shape[:2] | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
else: | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
pass | |
else: | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
if input_mask.max() == 0: | |
original_mask = original_mask | |
else: | |
original_mask = input_mask | |
if original_mask is None: | |
raise gr.Error('Please generate mask first') | |
if original_mask.ndim == 2: | |
original_mask = original_mask[:,:,None] | |
moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze() | |
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") | |
masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) | |
masked_image = masked_image.astype(original_image.dtype) | |
masked_image = Image.fromarray(masked_image) | |
if moved_mask.max() <= 1: | |
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) | |
original_mask = moved_mask | |
return [masked_image], [mask_image], original_mask.astype(np.uint8) | |
def move_mask_down(input_image, | |
original_image, | |
original_mask, | |
moving_pixels, | |
resize_default, | |
aspect_ratio_name): | |
alpha_mask = input_image["layers"][0].split()[3] | |
input_mask = np.asarray(alpha_mask) | |
output_w, output_h = aspect_ratios[aspect_ratio_name] | |
if output_w == "" or output_h == "": | |
output_h, output_w = original_image.shape[:2] | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
else: | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
pass | |
else: | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
if input_mask.max() == 0: | |
original_mask = original_mask | |
else: | |
original_mask = input_mask | |
if original_mask is None: | |
raise gr.Error('Please generate mask first') | |
if original_mask.ndim == 2: | |
original_mask = original_mask[:,:,None] | |
moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze() | |
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") | |
masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) | |
masked_image = masked_image.astype(original_image.dtype) | |
masked_image = Image.fromarray(masked_image) | |
if moved_mask.max() <= 1: | |
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) | |
original_mask = moved_mask | |
return [masked_image], [mask_image], original_mask.astype(np.uint8) | |
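# Invert the active mask (the drawn layer if present, otherwise the stored mask) and preview the result.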
def invert_mask(input_image, | |
original_image, | |
original_mask, | |
): | |
alpha_mask = input_image["layers"][0].split()[3] | |
input_mask = np.asarray(alpha_mask) | |
if original_mask is None and input_mask.max() == 0:
raise gr.Error('Please generate mask first')
if input_mask.max() == 0:
original_mask = 1 - (original_mask>0).astype(np.uint8)
else:
original_mask = 1 - (input_mask>0).astype(np.uint8)
original_mask = original_mask.squeeze() | |
mask_image = Image.fromarray(original_mask*255).convert("RGB") | |
if original_mask.ndim == 2: | |
original_mask = original_mask[:,:,None] | |
if original_mask.max() <= 1: | |
original_mask = (original_mask * 255).astype(np.uint8) | |
masked_image = original_image * (1 - (original_mask>0)) | |
masked_image = masked_image.astype(original_image.dtype) | |
masked_image = Image.fromarray(masked_image) | |
return [masked_image], [mask_image], original_mask, True | |
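# Initialize the editor: for known examples (first two loads) show the cached mask/masked/result galleries, otherwise start fresh from the uploaded image.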
def init_img(base, | |
init_type, | |
prompt, | |
aspect_ratio, | |
example_change_times | |
): | |
image_pil = base["background"].convert("RGB") | |
original_image = np.array(image_pil) | |
if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0: | |
raise gr.Error('image aspect ratio cannot be larger than 2.0') | |
if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2: | |
mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")] | |
masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")] | |
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")] | |
width, height = image_pil.size | |
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True) | |
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width) | |
image_pil = image_pil.resize((width_new, height_new)) | |
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new)) | |
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new)) | |
result_gallery[0] = result_gallery[0].resize((width_new, height_new)) | |
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1 | |
return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times | |
else: | |
if aspect_ratio not in ASPECT_RATIO_LABELS: | |
aspect_ratio = "Custom resolution" | |
return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0 | |
def reset_func(input_image, | |
original_image, | |
original_mask, | |
prompt, | |
target_prompt, | |
): | |
input_image = None | |
original_image = None | |
original_mask = None | |
prompt = '' | |
mask_gallery = [] | |
masked_gallery = [] | |
result_gallery = [] | |
target_prompt = '' | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False | |
def update_example(example_type, | |
prompt, | |
example_change_times): | |
input_image = INPUT_IMAGE_PATH[example_type] | |
image_pil = Image.open(input_image).convert("RGB") | |
mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")] | |
masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")] | |
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")] | |
width, height = image_pil.size | |
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True) | |
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width) | |
image_pil = image_pil.resize((width_new, height_new)) | |
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new)) | |
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new)) | |
result_gallery[0] = result_gallery[0].resize((width_new, height_new)) | |
original_image = np.array(image_pil) | |
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1 | |
aspect_ratio = "Custom resolution" | |
example_change_times += 1 | |
return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times | |
def generate_target_prompt(input_image, | |
original_image, | |
prompt): | |
# load example image | |
if isinstance(original_image, str): | |
original_image = input_image | |
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction( | |
vlm_processor, | |
vlm_model, | |
original_image, | |
prompt, | |
device) | |
return prompt_after_apply_instruction | |
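# Build the editing mask: use the user-drawn mask if present, otherwise let the VLM classify the edit, name the target object, and localize it with GroundingDINO + SAM.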
def process_mask(input_image, | |
original_image, | |
prompt, | |
resize_default, | |
aspect_ratio_name): | |
if original_image is None: | |
raise gr.Error('Please upload the input image') | |
if prompt is None: | |
raise gr.Error("Please input your instructions, e.g., remove the xxx") | |
## load mask | |
alpha_mask = input_image["layers"][0].split()[3] | |
input_mask = np.array(alpha_mask) | |
# load example image | |
if isinstance(original_image, str): | |
original_image = input_image["background"] | |
if input_mask.max() == 0: | |
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device) | |
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor, | |
vlm_model, | |
original_image, | |
category, | |
prompt, | |
device) | |
# original mask: h,w,1 [0, 255] | |
original_mask = vlm_response_mask( | |
vlm_processor, | |
vlm_model, | |
category, | |
original_image, | |
prompt, | |
object_wait_for_edit, | |
sam, | |
sam_predictor, | |
sam_automask_generator, | |
groundingdino_model, | |
device).astype(np.uint8) | |
else: | |
original_mask = input_mask.astype(np.uint8) | |
category = None | |
## resize mask if needed | |
output_w, output_h = aspect_ratios[aspect_ratio_name] | |
if output_w == "" or output_h == "": | |
output_h, output_w = original_image.shape[:2] | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
else: | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
pass | |
else: | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
if original_mask.ndim == 2: | |
original_mask = original_mask[:,:,None] | |
mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB") | |
masked_image = original_image * (1 - (original_mask>0)) | |
masked_image = masked_image.astype(np.uint8) | |
masked_image = Image.fromarray(masked_image) | |
return [masked_image], [mask_image], original_mask.astype(np.uint8), category | |
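# Main editing callback: resolve the mask and the target prompt (falling back to the VLM when they are missing), then run the BrushEdit inpainting pipeline.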
def process(input_image, | |
original_image, | |
original_mask, | |
prompt, | |
negative_prompt, | |
control_strength, | |
seed, | |
randomize_seed, | |
guidance_scale, | |
num_inference_steps, | |
num_samples, | |
blending, | |
category, | |
target_prompt, | |
resize_default, | |
aspect_ratio_name, | |
invert_mask_state): | |
if original_image is None: | |
if input_image is None: | |
raise gr.Error('Please upload the input image') | |
else: | |
image_pil = input_image["background"].convert("RGB") | |
original_image = np.array(image_pil) | |
if prompt is None or prompt == "": | |
if target_prompt is None or target_prompt == "": | |
raise gr.Error("Please input your instructions, e.g., remove the xxx") | |
alpha_mask = input_image["layers"][0].split()[3] | |
input_mask = np.asarray(alpha_mask) | |
output_w, output_h = aspect_ratios[aspect_ratio_name] | |
if output_w == "" or output_h == "": | |
output_h, output_w = original_image.shape[:2] | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
else: | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
pass | |
else: | |
if resize_default: | |
short_side = min(output_w, output_h) | |
scale_ratio = 640 / short_side | |
output_w = int(output_w * scale_ratio) | |
output_h = int(output_h * scale_ratio) | |
gr.Info(f"Output aspect ratio: {output_w}:{output_h}") | |
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) | |
original_image = np.array(original_image) | |
if input_mask is not None: | |
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) | |
input_mask = np.array(input_mask) | |
if original_mask is not None: | |
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) | |
original_mask = np.array(original_mask) | |
if invert_mask_state: | |
original_mask = original_mask | |
else: | |
if input_mask.max() == 0: | |
original_mask = original_mask | |
else: | |
original_mask = input_mask | |
# inpainting directly if target_prompt is not None | |
if category is not None: | |
pass | |
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None: | |
pass | |
else: | |
try: | |
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device) | |
except Exception as e: | |
raise gr.Error("Please select the correct VLM model and input the correct API Key first!") | |
if original_mask is not None: | |
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8) | |
else: | |
try: | |
object_wait_for_edit = vlm_response_object_wait_for_edit( | |
vlm_processor, | |
vlm_model, | |
original_image, | |
category, | |
prompt, | |
device) | |
original_mask = vlm_response_mask(vlm_processor, | |
vlm_model, | |
category, | |
original_image, | |
prompt, | |
object_wait_for_edit, | |
sam, | |
sam_predictor, | |
sam_automask_generator, | |
groundingdino_model, | |
device).astype(np.uint8) | |
except Exception as e: | |
raise gr.Error("Please select the correct VLM model and input the correct API Key first!") | |
if original_mask.ndim == 2: | |
original_mask = original_mask[:,:,None] | |
if target_prompt is not None and len(target_prompt) >= 1: | |
prompt_after_apply_instruction = target_prompt | |
else: | |
try: | |
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction( | |
vlm_processor, | |
vlm_model, | |
original_image, | |
prompt, | |
device) | |
except Exception as e: | |
raise gr.Error("Please select the correct VLM model and input the correct API Key first!") | |
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed) | |
with torch.autocast(device): | |
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe, | |
prompt_after_apply_instruction, | |
original_mask, | |
original_image, | |
generator, | |
num_inference_steps, | |
guidance_scale, | |
control_strength, | |
negative_prompt, | |
num_samples, | |
blending) | |
original_image = np.array(init_image_np) | |
masked_image = original_image * (1 - (mask_np>0)) | |
masked_image = masked_image.astype(np.uint8) | |
masked_image = Image.fromarray(masked_image) | |
# Save the images (optional) | |
# import uuid | |
# uuid = str(uuid.uuid4()) | |
# image[0].save(f"outputs/image_edit_{uuid}_0.png") | |
# image[1].save(f"outputs/image_edit_{uuid}_1.png") | |
# image[2].save(f"outputs/image_edit_{uuid}_2.png") | |
# image[3].save(f"outputs/image_edit_{uuid}_3.png") | |
# mask_image.save(f"outputs/mask_{uuid}.png") | |
# masked_image.save(f"outputs/masked_image_{uuid}.png") | |
# gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20) | |
return image, [mask_image], [masked_image], prompt, '', False | |
# Newly added event handler functions
def generate_blip_description(input_image): | |
if input_image is None: | |
return "", "Input image cannot be None" | |
try: | |
image_pil = input_image["background"].convert("RGB") | |
except KeyError: | |
return "", "Input image missing 'background' key" | |
except AttributeError as e: | |
return "", f"Invalid image object: {str(e)}" | |
try: | |
description = generate_caption(blip_processor, blip_model, image_pil, device) | |
return description, description  # update both the state and the display component
except Exception as e: | |
return "", f"Caption generation failed: {str(e)}" | |
from app.utils.utils import generate_caption | |
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large") | |
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") | |
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device) | |
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",torch_dtype=torch.float16).to(device) | |
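# Validate an API key with a minimal chat request; note that this demo routes the 'GPT4-o' option through the DeepSeek endpoint.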
def submit_GPT4o_KEY(GPT4o_KEY): | |
global vlm_model, vlm_processor | |
if vlm_model is not None: | |
del vlm_model | |
torch.cuda.empty_cache() | |
try: | |
vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com") | |
vlm_processor = "" | |
response = vlm_model.chat.completions.create( | |
model="deepseek-chat", | |
messages=[ | |
{"role": "system", "content": "You are a helpful assistant."}, | |
{"role": "user", "content": "Hello."} | |
] | |
) | |
response_str = response.choices[0].message.content | |
return "Success. " + response_str, "GPT4-o (Highly Recommended)" | |
except Exception as e: | |
return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)" | |
def verify_deepseek_api(): | |
try: | |
response = llm_model.chat.completions.create( | |
model="deepseek-chat", | |
messages=[ | |
{"role": "system", "content": "You are a helpful assistant."}, | |
{"role": "user", "content": "Hello."} | |
] | |
) | |
response_str = response.choices[0].message.content | |
return True, "Success. " + response_str | |
except Exception as e: | |
return False, "Invalid DeepSeek API Key" | |
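# Ask DeepSeek to merge the BLIP caption with the editing instruction into a single target description.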
def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt): | |
try: | |
messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt) | |
response = llm_model.chat.completions.create( | |
model="deepseek-chat", | |
messages=messages | |
) | |
response_str = response.choices[0].message.content | |
return response_str | |
except Exception as e: | |
raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息") | |
def llm_decomposed_prompt_after_apply_instruction(integrated_query): | |
try: | |
messages = create_decomposed_query_messages_deepseek(integrated_query) | |
response = llm_model.chat.completions.create( | |
model="deepseek-chat", | |
messages=messages | |
) | |
response_str = response.choices[0].message.content | |
return response_str | |
except Exception as e: | |
raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息") | |
def enhance_description(blip_description, prompt): | |
try: | |
if not prompt or not blip_description: | |
print("Empty prompt or blip_description detected") | |
return "", "" | |
print(f"Enhancing with prompt: {prompt}") | |
enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt) | |
return enhanced_description, enhanced_description | |
except Exception as e: | |
print(f"Enhancement failed: {str(e)}") | |
return "Error occurred", "Error occurred" | |
def decompose_description(enhanced_description): | |
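    """UI callback: wrap llm_decomposed_prompt_after_apply_instruction with empty-input
    and error handling; returns the decomposed description for both state and textbox."""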
try: | |
if not enhanced_description: | |
print("Empty enhanced_description detected") | |
return "", "" | |
print(f"Decomposing the enhanced description: {enhanced_description}") | |
decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description) | |
return decomposed_description, decomposed_description | |
except Exception as e: | |
print(f"Decomposition failed: {str(e)}") | |
return "Error occurred", "Error occurred" | |
def mix_and_search(enhanced_text: str, gallery_images: list): | |
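    """Compose a CLIP query from the newest edited image and the enhanced text, then
    retrieve the top-5 nearest images from the Faiss index.

    Image and text embeddings are L2-normalized, averaged, and re-normalized before
    the search; returns a Markdown status string and a list of PIL images.
    """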
    # take the most recently generated image; Gradio gallery items arrive as (path, caption) tuples
    latest_item = gallery_images[-1] if gallery_images else None
    # collect CLIP features from the edited image and/or the enhanced text
    features = []
    # --- image feature extraction ---
if latest_item and isinstance(latest_item, tuple): | |
try: | |
image_path = latest_item[0] | |
pil_image = Image.open(image_path).convert("RGB") | |
            # preprocess the image with CLIPProcessor
            image_inputs = clip_processor(
                images=pil_image,
                return_tensors="pt"
            ).to(device)
            # cast pixel_values to the model dtype (fp16) so the half-precision CLIP weights accept them
            image_inputs["pixel_values"] = image_inputs["pixel_values"].to(clip_model.dtype)
            with torch.no_grad():
                image_features = clip_model.get_image_features(**image_inputs)
            features.append(F.normalize(image_features, dim=-1))
except Exception as e: | |
print(f"图像处理失败: {str(e)}") | |
# 文本特征提取 | |
    if enhanced_text.strip():
        text_inputs = clip_processor(
            text=enhanced_text,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(device)
        with torch.no_grad():
            text_features = clip_model.get_text_features(**text_inputs)
        features.append(F.normalize(text_features, dim=-1))
if not features: | |
return "## 错误:请先完成图像编辑并生成描述", [] | |
# 特征融合与检索 | |
mixed = sum(features) / len(features) | |
mixed = F.normalize(mixed, dim=-1) | |
    # load the Faiss index and the image-path metadata (machine-specific paths)
    index_path = "/home/zt/data/open-images/train/knn.index"
    input_data_dir = Path("/home/zt/data/open-images/train/embedding_folder/metadata")
    base_image_dir = Path("/home/zt/data/open-images/train/")
    # sort the metadata parquet files by the numeric suffix in their filenames
parquet_files = sorted( | |
input_data_dir.glob('*.parquet'), | |
key=lambda x: int(x.stem.split("_")[-1]) | |
) | |
    # concatenate all parquet shards into one image-path table
    dfs = [pd.read_parquet(file) for file in parquet_files]
df = pd.concat(dfs, ignore_index=True) | |
image_paths = df["image_path"].tolist() | |
    # read the Faiss index
    index = faiss.read_index(index_path)
    assert mixed.shape[1] == index.d, "query feature dimension does not match the index"
    # run the nearest-neighbour search (top-5)
mixed = mixed.cpu().detach().numpy().astype('float32') | |
distances, indices = index.search(mixed, 5) | |
    # resolve and validate the retrieved image paths
retrieved_images = [] | |
for idx in indices[0]: | |
if 0 <= idx < len(image_paths): | |
img_path = base_image_dir / image_paths[idx] | |
try: | |
if img_path.exists(): | |
retrieved_images.append(Image.open(img_path).convert("RGB")) | |
else: | |
print(f"警告:文件缺失 {img_path}") | |
except Exception as e: | |
print(f"图片加载失败: {str(e)}") | |
return "## 检索到以下相似图片:", retrieved_images if retrieved_images else ("## 未找到匹配的图片", []) | |
block = gr.Blocks( | |
theme=gr.themes.Soft( | |
radius_size=gr.themes.sizes.radius_none, | |
text_size=gr.themes.sizes.text_md | |
) | |
) | |
with block as demo: | |
with gr.Row(): | |
with gr.Column(): | |
gr.HTML(head) | |
gr.Markdown(descriptions) | |
with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"): | |
with gr.Row(equal_height=True): | |
gr.Markdown(instructions) | |
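    # ---- per-session state shared by the callbacks below ----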
original_image = gr.State(value=None) | |
original_mask = gr.State(value=None) | |
category = gr.State(value=None) | |
status = gr.State(value=None) | |
invert_mask_state = gr.State(value=False) | |
example_change_times = gr.State(value=0) | |
deepseek_verified = gr.State(value=False) | |
blip_description = gr.State(value="") | |
enhanced_description = gr.State(value="") | |
decomposed_description = gr.State(value="") | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Row(): | |
                input_image = gr.ImageEditor(
                    label="Reference Image",
                    type="pil",
                    brush=gr.Brush(colors=["#FFFFFF"], default_size=30, color_mode="fixed"),
                    layers=False,
                    interactive=True,
                    # height=1024,
                    height=512,
                    sources=["upload"],
                    placeholder="🫧 Click here or use the icon below to upload an image 🫧",
                )
            prompt = gr.Textbox(label="Editing Instruction", placeholder="😜 Describe how the reference image should be modified 😜", value="", lines=1)
            run_button = gr.Button("💫 Edit Image")
            vlm_model_dropdown = gr.Dropdown(label="VLM Model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
with gr.Group(): | |
with gr.Row(): | |
# GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1) | |
                    GPT4o_KEY = gr.Textbox(label="API Key", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
                    GPT4o_KEY_submit = gr.Button("🙈 Verify")
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO) | |
resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True) | |
with gr.Row(): | |
                mask_button = gr.Button("💎 Generate Mask")
                random_mask_button = gr.Button("Square/Circle Mask")
            # retrieval controls (added after the decompose button)
with gr.Group(): | |
with gr.Row(): | |
                    retrieve_button = gr.Button("🔍 Retrieve")
                with gr.Row():
                    retrieve_output = gr.Markdown(elem_id="accordion")
                with gr.Row():
                    retrieve_gallery = gr.Gallery(label="🎊 Retrieval Results", show_label=True, elem_id="gallery", preview=True, height=400)  # gallery for the retrieved images
with gr.Row(): | |
generate_target_prompt_button = gr.Button("Generate Target Prompt") | |
target_prompt = gr.Text( | |
label="Input Target Prompt", | |
max_lines=5, | |
placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)", | |
value='', | |
lines=2 | |
) | |
with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"): | |
base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True) | |
negative_prompt = gr.Text( | |
label="Negative Prompt", | |
max_lines=5, | |
placeholder="Please input your negative prompt", | |
value='ugly, low quality',lines=1 | |
) | |
control_strength = gr.Slider( | |
label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01 | |
) | |
with gr.Group(): | |
seed = gr.Slider( | |
label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818 | |
) | |
randomize_seed = gr.Checkbox(label="Randomize seed", value=False) | |
blending = gr.Checkbox(label="Blending mode", value=True) | |
num_samples = gr.Slider( | |
label="Num samples", minimum=0, maximum=4, step=1, value=4 | |
) | |
with gr.Group(): | |
with gr.Row(): | |
guidance_scale = gr.Slider( | |
label="Guidance scale", | |
minimum=1, | |
maximum=12, | |
step=0.1, | |
value=7.5, | |
) | |
num_inference_steps = gr.Slider( | |
label="Number of inference steps", | |
minimum=1, | |
maximum=50, | |
step=1, | |
value=50, | |
) | |
with gr.Group(visible=True): | |
                # BLIP-generated caption of the original image
                blip_output = gr.Textbox(label="Original Image Caption", placeholder="💭 Base image caption generated by BLIP 💭", interactive=True, lines=1)
                # DeepSeek API verification
                with gr.Row():
                    deepseek_key = gr.Textbox(label="API Key", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
                    verify_deepseek = gr.Button("🙈 Verify")
                # integrated (enhanced) description
                with gr.Row():
                    enhanced_output = gr.Textbox(label="Integrated Description", placeholder="💭 Enhanced description generated by DeepSeek 💭", interactive=True, lines=3)
                    enhance_button = gr.Button("✨ Integrate")
                # decomposed description
                with gr.Row():
                    decomposed_output = gr.Textbox(label="Decomposed Description", placeholder="💭 Decomposed description generated by DeepSeek 💭", interactive=True, lines=3)
                    decompose_button = gr.Button("🔧 Decompose")
with gr.Row(): | |
with gr.Tab(elem_classes="feedback", label="Masked Image"): | |
masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360) | |
with gr.Tab(elem_classes="feedback", label="Mask"): | |
mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360) | |
invert_mask_button = gr.Button("Invert Mask") | |
dilation_size = gr.Slider( | |
label="Dilation size: ", minimum=0, maximum=50, step=1, value=20 | |
) | |
with gr.Row(): | |
                        dilation_mask_button = gr.Button("Dilate Generated Mask")
                        erosion_mask_button = gr.Button("Erode Generated Mask")
moving_pixels = gr.Slider( | |
label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1 | |
) | |
with gr.Row(): | |
move_left_button = gr.Button("Move Left") | |
move_right_button = gr.Button("Move Right") | |
with gr.Row(): | |
move_up_button = gr.Button("Move Up") | |
move_down_button = gr.Button("Move Down") | |
with gr.Tab(elem_classes="feedback", label="Output"): | |
result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400) | |
target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False) | |
reset_button = gr.Button("Reset") | |
init_type = gr.Textbox(label="Init Name", value="", visible=False) | |
example_type = gr.Textbox(label="Example Name", value="", visible=False) | |
with gr.Row(): | |
example = gr.Examples( | |
label="Quick Example", | |
examples=EXAMPLES, | |
inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown], | |
examples_per_page=10, | |
cache_examples=False, | |
) | |
with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"): | |
with gr.Row(equal_height=True): | |
gr.Markdown(tips) | |
with gr.Row(): | |
gr.Markdown(citation) | |
    ## gr.Examples cannot update a gr.Gallery directly, so the two callbacks below handle that,
    ## and they also resolve the conflict between the upload and example-change events.
input_image.upload( | |
init_img, | |
[input_image, init_type, prompt, aspect_ratio, example_change_times], | |
[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times] | |
) | |
example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times]) | |
## vlm and base model dropdown | |
vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status]) | |
base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status]) | |
GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown]) | |
invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state]) | |
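    # inputs forwarded to the BrushEdit `process` callback; the order must match its signature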
ips=[input_image, | |
original_image, | |
original_mask, | |
prompt, | |
negative_prompt, | |
control_strength, | |
seed, | |
randomize_seed, | |
guidance_scale, | |
num_inference_steps, | |
num_samples, | |
blending, | |
category, | |
target_prompt, | |
resize_default, | |
aspect_ratio, | |
invert_mask_state] | |
## run brushedit | |
run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state]) | |
## mask func | |
mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category]) | |
random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask]) | |
dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask]) | |
erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask]) | |
## reset func | |
reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state]) | |
    # wire up the captioning / DeepSeek / retrieval event handlers
input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output]) | |
verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key]) | |
enhance_button.click(fn=enhance_description, inputs=[blip_output, prompt], outputs=[enhanced_description, enhanced_output]) | |
decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output]) | |
    # retrieval event binding
retrieve_button.click( | |
fn=mix_and_search, | |
inputs=[enhanced_output, result_gallery], | |
outputs=[retrieve_output, retrieve_gallery] | |
) | |
demo.launch(server_name="0.0.0.0", server_port=12345, share=True) | |